ShapeONNX: Solving ONNX’s Dynamic Shape Problem
Introduction
ONNX’s shape inference is designed for dynamic models where shapes can vary at runtime—exactly what’s needed for flexible deployment across different batch sizes and input dimensions. For verification workflows that prove mathematical properties, we need a complementary approach that resolves shapes to concrete static values.
When you export a PyTorch model with dynamic operations, ONNX returns shape [-1] for dimensions that vary at runtime. This design enables models to adapt to different inputs during deployment. For neural network verification tools, which need to prove properties about every neuron in every layer, we need exact static shapes to function.
Verification tools perform layer-by-layer bound propagation, generate SMT constraints for formal proofs, and allocate symbolic execution environments. All of these require knowing exact tensor dimensions: not “this tensor has rank 4” or “batch dimension is dynamic,” but concrete shapes like [1, 48, 2, 2].
The challenge is about different design goals: ONNX tracks tensor shape metadata (this is a 1D tensor with 4 elements), while verification needs the actual values (those 4 elements are [1, 48, 2, 2]). When you have operator chains like Shape(input) → Gather(indices=[2,3]) → ConstantOfShape, ONNX correctly reports that Gather outputs a 1D tensor with 2 elements. However, it doesn’t track that those elements are [2, 2], so ConstantOfShape returns a dynamic shape [-1, -1].
ShapeONNX complements ONNX’s design through dual-track shape representation. We track both traditional tensor metadata (data_shapes) and the actual constant values flowing through shape operations (explicit_shapes). When the Shape operator extracts [1, 48, 2, 2] from an input tensor, we store those values explicitly. When Gather slices to [2, 2], we propagate the sliced values. When ConstantOfShape needs a shape, we can provide the concrete answer: [2, 2].
The result: 136 VNN-COMP 2024 models processed in ~9.5 seconds total (~70ms average per model), with shapes resolved to concrete static values for verification workflows that require them.
This is the story of ShapeONNX—why tracking shape values matters, how dual-track representation works, and how it enables verification tools to work with models that use dynamic shape operations.
Part 1: The “Why” - Understanding the Shape Problem
ONNX’s Shape Inference: What It Does
ONNX’s built-in shape inference (onnx.shape_inference.infer_shapes) is excellent at what it’s designed for: propagating tensor metadata through computational graphs. Given an input with shape [1, 48, 2, 2] flowing through a Conv layer, it correctly computes the output shape based on kernel size, stride, and padding.
For most machine learning workflows—training, inference, deployment—this is perfect. Frameworks only need to know tensor dimensions to allocate memory and schedule operations.
But there’s a critical distinction: ONNX tracks what shape a tensor has, not what values a shape tensor contains.
The Gap: Data Shapes vs Shape Values
Consider this PyTorch pattern that appears in Vision Transformers, dynamic GANs, and custom architectures:
# PyTorch: Create a tensor with dimensions from input
import torch

x = torch.randn(1, 48, 2, 2)
spatial_dims = x.shape[2:]  # Extract [2, 2]
new_tensor = torch.zeros(*spatial_dims, 64)  # Shape [2, 2, 64]
When exported to ONNX, this becomes an operator chain:
Input: tensor[1, 48, 2, 2]
↓
Shape → output: ???
↓
Slice(start=2, end=4) → output: ???
↓
Concat(other_dims) → output: ???
↓
ConstantOfShape → output: ???
What ONNX Knows:
Shape output: 1D tensor with 4 elements (metadata: [4])
Slice output: 1D tensor with 2 elements (metadata: [2])
Concat output: 1D tensor with 3 elements (metadata: [3])
ConstantOfShape output: Some tensor (shape: [-1, -1, -1], dynamic!)
What ONNX Doesn’t Know:
Shape output contains the values [1, 48, 2, 2]
Slice output contains the values [2, 2]
Concat output contains the values [2, 2, 64]
ConstantOfShape output has static shape [2, 2, 64]
ONNX sees containers (1D tensors). ShapeONNX sees contents ([2, 2]).
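The difference can be sketched with plain Python lists instead of real ONNX nodes. The helper names below are hypothetical stand-ins for the four operators in the chain, not ShapeONNX functions; the point is only that carrying the values forward makes the final shape static, while the metadata view keeps just the length of each shape tensor.

```python
# Toy model of the chain above. Each helper returns the *values* of a
# shape tensor, which is the information metadata-only inference drops.

def shape_op(tensor_shape: list[int]) -> list[int]:
    """Shape: the output values are the input's dimensions."""
    return list(tensor_shape)

def slice_op(values: list[int], start: int, end: int) -> list[int]:
    """Slice on a 1D shape tensor: slice the values themselves."""
    return values[start:end]

def concat_op(*parts: list[int]) -> list[int]:
    """Concat of 1D shape tensors: concatenate their values."""
    return [v for part in parts for v in part]

def constant_of_shape(values: list[int]) -> list[int]:
    """ConstantOfShape: the output *shape* equals the input *values*."""
    return list(values)

input_shape = [1, 48, 2, 2]
shape_values = shape_op(input_shape)        # [1, 48, 2, 2]
spatial = slice_op(shape_values, 2, 4)      # [2, 2]
target = concat_op(spatial, [64])           # [2, 2, 64]
output_shape = constant_of_shape(target)    # [2, 2, 64]

# Metadata-only view: just the length of each 1D shape tensor.
metadata = [[len(v)] for v in (shape_values, spatial, target)]
print(metadata)      # [[4], [2], [3]]
print(output_shape)  # [2, 2, 64]
```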
Why This Matters for Verification
Neural network verification tools prove properties like “for all inputs in this region, the output is classified correctly.” They work by:
Bound Propagation: Compute output bounds layer-by-layer from input constraints
Constraint Generation: Generate logical formulas for each operation
SMT Solving: Use formal methods to prove or find counterexamples
All three require exact tensor dimensions:
Bound propagation needs to allocate interval bounds for every neuron. A Conv layer with input [1, 64, 32, 32], a 3×3 kernel, and 64 output channels (stride 1, no padding) has 64 * 30 * 30 = 57,600 output neurons. Dynamic shapes? Can’t allocate bounds.
Constraint generation encodes operations as mathematical relations. Gemm: (M, K) × (K, N) = (M, N) requires knowing M, K, N to generate the right number of constraints. Dynamic K? Can’t generate constraints.
Memory allocation for symbolic execution needs exact sizes. Verification tools often process models with 100+ layers and millions of neurons. Dynamic shapes mean unpredictable memory requirements.
Verification tools require static shapes to function.
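To make the allocation argument concrete, here is a minimal sketch of how a verifier might size its interval bounds from a static Conv output shape. The function name and the choice of 64 output channels, stride 1, and no padding are assumptions for illustration; real tools use their own data structures.

```python
import math

def conv_output_shape(input_shape, out_channels, kernel, stride=1, pad=0):
    """Output shape of a 2D convolution (dilation 1, symmetric padding)."""
    n, _, h, w = input_shape
    out_h = (h + 2 * pad - kernel) // stride + 1
    out_w = (w + 2 * pad - kernel) // stride + 1
    return [n, out_channels, out_h, out_w]

out_shape = conv_output_shape([1, 64, 32, 32], out_channels=64, kernel=3)
num_neurons = math.prod(out_shape)

# One (lower, upper) interval per output neuron. With a dynamic dimension
# in out_shape, num_neurons would be unknown and this could not be sized.
lower_bounds = [float("-inf")] * num_neurons
upper_bounds = [float("inf")] * num_neurons

print(out_shape)    # [1, 64, 30, 30]
print(num_neurons)  # 57600
```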
Shape Operator Chains in VNN-COMP 2024
ONNX’s design to support dynamic shapes means it returns -1 for dimensions it can’t determine statically. This is correct for dynamic models, but verification tools require concrete values. Here are examples from VNN-COMP 2024 where ShapeONNX’s value tracking resolves dynamic shapes:
Vision Transformers: The patch embedding layer reshapes inputs based on computed dimensions. Pattern:
Shape(input) → Gather(patch_size) → Div(stride) → Reshape(...)
ONNX result: Reshape target shape is [-1, -1, 768] (dynamic, as designed).
ShapeONNX result: Reshape target shape is [1, 196, 768] (static, enabling verification).
GANs with Dynamic Generation: Generator creates images with sizes computed from latent vectors:
Shape(latent) → Mul(upscale_factor) → ConstantOfShape
ONNX result: ConstantOfShape creates tensor with shape [-1, -1] (dynamic shape).
ShapeONNX result: ConstantOfShape creates tensor with shape [4, 4] (static shape, verification proceeds).
Custom Architectures with Computed Shapes: Research models often have novel shape manipulations:
Shape(x) → Slice → Concat(constants) → Reshape
ONNX result: Reshape target is [-1, -1, -1, -1] (dynamic).
ShapeONNX result: Reshape target is [1, 256, 14, 14] (static).
In VNN-COMP 2024, we tested 136 models across 23 benchmarks. For models with shape operator chains, ONNX returns dynamic shapes by design, while ShapeONNX resolves them to static values. ShapeONNX achieves this resolution efficiently—processing 136 diverse models in ~9.5 seconds (~70ms average), with complex models like Vision Transformers resolved in under 200ms. Both behaviors are correct for their intended use cases—ONNX for flexible deployment, ShapeONNX for verification that requires concrete dimensions.
ONNX’s Design Trade-offs
ONNX’s approach to shape inference reflects its design for dynamic, portable models. These are architectural choices, not limitations:
Shape operator output: ONNX knows it’s a 1D tensor. Doesn’t track the actual dimension values.
Operations on shape tensors: Gather, Slice, Concat on shape tensors produce new shape tensors. ONNX tracks metadata, not values.
ConstantOfShape: Requires knowing the input tensor’s values (the target shape), not just its shape (that it’s 1D).
Reshape with -1: Inferring the -1 dimension requires knowing total elements (ONNX knows this) and other dimensions (which may come from shape tensors—ONNX doesn’t know the values).
Verification-specific patterns: Models exported specifically for verification often have more shape manipulations than training models, since verification tools need normalized forms.
The gap is fundamental: ONNX represents shapes as tensors, but doesn’t track tensor values for shapes. ShapeONNX bridges this gap with explicit shape value propagation.
Part 2: The Core Innovation - Dual-Track Shape Representation
Traditional Approach (ONNX)
ONNX maintains a single mapping:
shapes: dict[str, list[int]] # tensor_name → shape metadata
Example:
shapes = {
"input": [1, 48, 2, 2], # Input tensor has this shape
"shape_output": [4], # Shape op outputs a 1D tensor with 4 elements
"gather_output": [2], # Gather op outputs a 1D tensor with 2 elements
"constant": [None, None], # ConstantOfShape outputs... unknown!
}
This works perfectly for data tensors (activations, weights). But for shape tensors (outputs of Shape, Gather on shapes, etc.), it only tells us the container type, not the contents.
ShapeONNX Approach: Dual-Track Representation
ShapeONNX maintains two dictionaries:
@dataclass(frozen=True)
class ShapeInferenceContext:
    """Dual-track shape representation."""
    data_shapes: dict[str, int | list[int]]  # Tensor metadata
    explicit_shapes: dict[str, int | list[int]]  # Constant shape values
    initializers: dict[str, TensorProto]
    verbose: bool = False
From shapeonnx/infer_shape.py:19-34.
data_shapes: Traditional shape metadata
"relu_output" → [1, 64, 32, 32]: ReLU output is a 4D tensor
"shape_output" → [4]: Shape operator output is a 1D tensor with 4 elements
explicit_shapes: Actual constant values in shape tensors
"shape_output" → [1, 64, 32, 32]: The Shape operator extracted these dimension values
"gather_output" → [32, 32]: Gather sliced the shape tensor to these values
Why This Works:
Data tensors (ReLU, Conv, MatMul) only use data_shapes. They don’t care about values, only dimensions for memory allocation.
Shape tensors (Shape operator outputs, Gather on shapes) populate both dictionaries:
data_shapes: Container type (e.g., [4] for a 4-element vector)
explicit_shapes: Actual values (e.g., [1, 64, 32, 32])
Consumers of shape tensors (ConstantOfShape, Reshape) check explicit_shapes first. If present, use the actual values. If not, fall back to data_shapes.
The Shape Retrieval Pattern
All shape inference functions use this pattern:
def get_shape(
    name: str,
    shapes: dict[str, int | list[int]],
    explicit_shapes: dict[str, int | list[int]],
) -> tuple[int | list[int] | None, bool]:
    """
    Retrieve shape from any available source.
    Returns: (shape, is_explicit)
    """
    if (shape := shapes.get(name)) is not None:
        return shape, False  # Regular data tensor
    if (explicit_shape := explicit_shapes.get(name)) is not None:
        return explicit_shape, True  # Shape tensor with known values
    raise RuntimeError(f"Cannot get shape of {name}")
From shapeonnx/infer_shape.py:110-127.
The is_explicit flag tells downstream operators whether they’re working with:
Regular data: Use the shape for dimension calculations
Shape tensor: The “shape” is actually the values, propagate them explicitly
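A short, self-contained illustration of the flag follows. It repeats the lookup logic shown above without type hints and, as an assumption of this sketch, stores the shape tensor only in explicit_shapes. The same list comes back for both names, but the flag says whether it describes a tensor's dimensions or a shape tensor's contents.

```python
def get_shape(name, shapes, explicit_shapes):
    """Return (shape, is_explicit), checking data shapes first."""
    if (shape := shapes.get(name)) is not None:
        return shape, False
    if (explicit_shape := explicit_shapes.get(name)) is not None:
        return explicit_shape, True
    raise RuntimeError(f"Cannot get shape of {name}")

# Assumed state: one data tensor, one shape tensor (explicit values only).
data_shapes = {"relu_output": [1, 64, 32, 32]}
explicit_shapes = {"shape_output": [1, 64, 32, 32]}

relu = get_shape("relu_output", data_shapes, explicit_shapes)
shape_vec = get_shape("shape_output", data_shapes, explicit_shapes)

print(relu)       # ([1, 64, 32, 32], False): dimensions of a 4D tensor
print(shape_vec)  # ([1, 64, 32, 32], True): values held by a shape tensor
```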
Complete Chain Example
Let’s trace the complete flow through ShapeONNX:
# Starting state: Input tensor with known shape
data_shapes = {"input": [1, 48, 2, 2]}
explicit_shapes = {}
Step 1: Shape Operator
# ONNX node: Shape(input) → shape_output
def infer_shape_op_shape(node, ctx):
    """Shape operator extracts dimension values."""
    shape, is_explicit = get_shape(node.input[0], ctx.data_shapes, ctx.explicit_shapes)
    if not is_explicit:
        # Input is regular data tensor, extract its shape as values
        return [(None, shape)]  # data_shape=None, explicit_shape=[1,48,2,2]
    # Input already has explicit shape, propagate it
    return [(None, shape)]
From shapeonnx/infer_shape.py:1093-1120.
Result:
data_shapes = {
"input": [1, 48, 2, 2],
"shape_output": [4] # It's a 1D tensor with 4 elements
}
explicit_shapes = {
"shape_output": [1, 48, 2, 2] # The actual dimension values!
}
Step 2: Gather Operator on Shape Tensor
# ONNX node: Gather(shape_output, indices=[2,3]) → gather_output
def infer_gather_shape(node, ctx):
    """Gather can slice shape tensors."""
    # Check for explicit shape first (for shape tensors)
    e_shape = get_explicit_shape(node.input[0], ctx.explicit_shapes)
    if e_shape is not None:
        # Get indices from initializer
        indices = get_initializer_value(node.input[1], ctx.initializers)
        # Slice the shape values
        result = [e_shape[i] for i in indices]  # [1,48,2,2][2:4] → [2,2]
        return [(None, result)]  # Explicit shape output
    # Fall back to regular data gather...
From shapeonnx/infer_shape.py:707-753.
Result:
data_shapes = {
"input": [1, 48, 2, 2],
"shape_output": [4],
"gather_output": [2] # 1D tensor with 2 elements
}
explicit_shapes = {
"shape_output": [1, 48, 2, 2],
"gather_output": [2, 2] # Sliced values!
}
Step 3: ConstantOfShape Operator
# ONNX node: ConstantOfShape(gather_output) → constant_tensor
def infer_constantofshape_shape(node, ctx):
    """Create tensor with shape from input."""
    # The input tensor contains the target shape as VALUES
    e_shape = get_explicit_shape(node.input[0], ctx.explicit_shapes)
    if e_shape is not None:
        # We know the exact output shape!
        return [(e_shape, None)]  # data_shape=[2,2], explicit_shape=None
    # Fallback: unknown shape
    return [([-1], None)]
From shapeonnx/infer_shape.py:550-570.
Result:
data_shapes = {
"input": [1, 48, 2, 2],
"shape_output": [4],
"gather_output": [2],
"constant_tensor": [2, 2] # Static shape!
}
explicit_shapes = {
"shape_output": [1, 48, 2, 2],
"gather_output": [2, 2],
}
ONNX would report: constant_tensor has shape [-1, -1] (dynamic).
ShapeONNX reports: constant_tensor has shape [2, 2] (static).
This is the power of dual-track representation: By tracking both metadata and values, we resolve shapes ONNX marks as dynamic to concrete static dimensions.
Part 3: Dynamic to Static Resolution - Case Studies
Case Study 1: Shape Operator - Extracting Dimension Values
The Shape operator is special: it converts tensor metadata into explicit values.
Implementation:
def infer_shape_op_shape(node, ctx):
    """Shape operator extracts dimension values."""
    shape, is_explicit = get_shape(node.input[0], ctx.data_shapes, ctx.explicit_shapes)
    if not is_explicit:
        # Regular data tensor: extract its shape as explicit values
        # Input metadata: [1, 48, 2, 2]
        # Output values: [1, 48, 2, 2]
        return [(None, shape)]
    # Already explicit (rare case: Shape of a shape tensor)
    return [(None, shape)]
From shapeonnx/infer_shape.py:1093-1120.
Key insight: The Shape operator’s output metadata is [4] (a 4-element vector), but its explicit values are [1, 48, 2, 2] (the input tensor’s dimensions).
This is where dual-track representation starts: data becomes values.
Case Study 2: Gather on Shape Tensors - Slicing Dimension Values
Gather can operate on either data tensors or shape tensors. The behavior differs:
On data tensors: Standard indexing (extract elements from activations)
On shape tensors: Slice dimension values
Implementation:
def infer_gather_shape(node, ctx):
    """Gather can slice shape tensors."""
    axis = get_onnx_attrs(node, ctx.initializers)["axis"]
    indices = onnx.numpy_helper.to_array(ctx.initializers[node.input[1]]).tolist()
    # Check for explicit shape first (for shape tensors like from Shape op)
    e_shape = get_explicit_shape(node.input[0], ctx.explicit_shapes)
    if e_shape is not None:
        # Gathering from a shape tensor (explicit shape)
        if axis != 0:
            raise ValueError(f"Invalid axis {axis} for gather from explicit shape")
        if not isinstance(e_shape, list):
            raise RuntimeError(f"Cannot gather from non-list explicit shape {e_shape}")
        # Slice the values
        if isinstance(indices, int):
            e_shape = e_shape[indices]  # Single index: [1,48,2,2][2] → 2
        else:
            e_shape = [e_shape[i] for i in indices]  # List: [1,48,2,2][[2,3]] → [2,2]
        return [(None, e_shape)]
    # Fallback to data shape (for regular data tensors)
    # ... standard Gather shape inference
From shapeonnx/infer_shape.py:707-753.
Example:
# Input: explicit_shapes["shape_output"] = [1, 48, 2, 2]
# Gather(indices=[2, 3], axis=0)
# Output: explicit_shapes["gather_output"] = [2, 2]
This preserves the distinction: We’re not gathering from a [4]-shaped tensor, we’re slicing the values [1, 48, 2, 2] to get [2, 2].
Case Study 3: ConstantOfShape - Static Shape Resolution
ConstantOfShape creates a tensor filled with a constant value. The input tensor contains the target shape as values, not as metadata.
ONNX’s problem: The input’s metadata is [n] (a 1D vector). ONNX doesn’t know the values in that vector, so it can’t determine the output shape.
ShapeONNX’s solution: Check explicit_shapes for the input tensor’s values, use them as the output shape.
Implementation:
def infer_constant_of_shape_shape(node, ctx):
    """Infer shape for ConstantOfShape operator."""
    # Get the ACTUAL VALUES from the shape tensor
    shape = get_explicit_shape(node.input[0], ctx.explicit_shapes)
    if shape is None:
        raise RuntimeError(f"Cannot get explicit shape of {node.input[0]}")
    if shape != [0]:  # Non-empty shape
        # Also compute the constant values if integers
        value = get_onnx_attrs(node, ctx.initializers)["value"]
        if np.issubdtype(value.dtype, np.integer):
            constant = np.full(shape, value, dtype=value.dtype).tolist()
            return [(shape, constant)]  # Both data shape and explicit values
    return [(shape, None)]
From shapeonnx/infer_shape.py:550-570.
Example:
# Input: explicit_shapes["gather_output"] = [2, 2]
# ConstantOfShape(value=0)
# Output: data_shapes["constant_tensor"] = [2, 2] (static!)
# explicit_shapes["constant_tensor"] = [[0, 0], [0, 0]] (if tracking values)
This is the payoff: A chain of shape operations (Shape → Gather → ConstantOfShape) resolves to a concrete static output shape [2, 2] instead of ONNX’s [-1, -1].
Case Study 4: Reshape with -1 - Inferring Unknown Dimensions
Reshape can have one dimension specified as -1, meaning “infer this from the total number of elements.”
ONNX challenge: If the target shape comes from a shape tensor, ONNX doesn’t know the values, so it can’t compute the -1 dimension.
ShapeONNX solution: Use explicit_shapes to get the target shape values, compute the -1 dimension.
Implementation:
def infer_reshape_output_shape(ori_shape, new_shape):
    """Compute -1 dimension in reshape."""
    total = math.prod(ori_shape)  # Total elements
    # Find -1 position
    inferred_idx = -1
    remaining = total
    for idx, dim in enumerate(new_shape):
        if dim == -1:
            inferred_idx = idx
        else:
            remaining //= dim
    result = new_shape.copy()
    if inferred_idx != -1:
        result[inferred_idx] = remaining  # Infer the -1 dimension
    return result

def infer_reshape_shape(node, ctx):
    """Infer shape for Reshape operator."""
    data_shape, _ = get_shape(node.input[0], ctx.data_shapes, ctx.explicit_shapes)
    target_shape = get_explicit_shape(node.input[1], ctx.explicit_shapes)
    if not isinstance(data_shape, list) or not isinstance(target_shape, list):
        return [([0], None)]
    # Resolve -1 dimension using total elements
    shape = infer_reshape_output_shape(data_shape, target_shape)
    return [(shape, None)]
From shapeonnx/infer_shape.py:986-1031.
Example:
# Input tensor: [1, 48, 2, 2] (total: 192 elements)
# Target shape (from explicit_shapes): [1, 48, -1]
# Computed output: [1, 48, 4] (192 / (1*48) = 4)
Without explicit_shapes: ONNX sees target shape metadata [3] but doesn’t know the values [1, 48, -1], can’t compute -1.
With explicit_shapes: ShapeONNX knows target is [1, 48, -1], computes the inferred dimension.
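The -1 resolution can also be checked by hand. The sketch below is a hypothetical variant, named resolve_reshape to keep it distinct from the library function, that adds two checks the excerpt above does not show: at most one -1, and an element count that divides evenly. It is meant to illustrate the arithmetic, not to describe ShapeONNX's behavior on malformed targets.

```python
import math

def resolve_reshape(ori_shape, target):
    """Resolve a single -1 in `target` from the element count of `ori_shape`."""
    total = math.prod(ori_shape)
    if target.count(-1) > 1:
        raise ValueError("at most one dimension may be -1")
    known = math.prod(d for d in target if d != -1)
    if -1 not in target:
        if known != total:
            raise ValueError(f"cannot reshape {ori_shape} to {target}")
        return list(target)
    if known == 0 or total % known != 0:
        raise ValueError(f"cannot reshape {ori_shape} to {target}")
    return [total // known if d == -1 else d for d in target]

print(resolve_reshape([1, 48, 2, 2], [1, 48, -1]))  # [1, 48, 4]
print(resolve_reshape([1, 48, 2, 2], [192]))        # [192]
```

With the target values known, a mismatched target such as [1, 5, -1] is rejected immediately, because 192 is not divisible by 5.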
Complete Operator Chain - Real-World Example
Let’s trace a real pattern from Vision Transformer models:
# PyTorch export of patch embedding
x = input # [1, 3, 224, 224]
shape = x.shape # Extract shape
h, w = shape[2], shape[3] # Get spatial dimensions
n_patches = (h // patch_size) * (w // patch_size)
x_reshaped = x.reshape(1, n_patches, -1) # Flatten patches
ONNX graph:
Input[1,3,224,224] → Shape → shape_vec[4]
shape_vec → Gather(idx=2) → h_val[1]
shape_vec → Gather(idx=3) → w_val[1]
h_val → Div(patch=16) → h_patches[1]
w_val → Div(patch=16) → w_patches[1]
h_patches, w_patches → Mul → n_patches[1]
n_patches → Concat([1, ?, -1]) → target[3]
Input, target → Reshape → output[???]
ONNX result: Reshape output shape is [-1, -1, -1] (all dynamic).
ShapeONNX trace:
Shape: explicit_shapes["shape_vec"] = [1, 3, 224, 224]
Gather(2): explicit_shapes["h_val"] = 224
Gather(3): explicit_shapes["w_val"] = 224
Div(16): explicit_shapes["h_patches"] = 14
Div(16): explicit_shapes["w_patches"] = 14
Mul: explicit_shapes["n_patches"] = 196
Concat: explicit_shapes["target"] = [1, 196, -1]
Reshape: Computes -1 → 768 (total 1*3*224*224 / (1*196) = 768)
ShapeONNX result: Reshape output shape is [1, 196, 768] (fully static).
This enables verification: The verifier can now propagate bounds through 196 patch tokens, each with 768 dimensions, to prove properties about the transformer’s behavior.
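The trace reduces to a few lines of integer arithmetic once the values are tracked. The following sketch replays it with plain Python integers; the variable names mirror the tensor names in the trace, and the patch size of 16 is taken from the Div nodes in the graph.

```python
import math

shape_vec = [1, 3, 224, 224]             # Shape(input)
patch_size = 16

h_val, w_val = shape_vec[2], shape_vec[3]   # Gather(2), Gather(3)
h_patches = h_val // patch_size             # Div
w_patches = w_val // patch_size             # Div
n_patches = h_patches * w_patches           # Mul
target = [1, n_patches, -1]                 # Concat

# Reshape: the -1 entry is the total element count divided by the
# product of the known target dimensions.
total = math.prod(shape_vec)
known = math.prod(d for d in target if d != -1)
output_shape = [total // known if d == -1 else d for d in target]

print(n_patches)     # 196
print(output_shape)  # [1, 196, 768]
```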
Part 4: Architecture and Design Decisions
Immutable Context - Clarity Through Constraints
The ShapeInferenceContext is a frozen dataclass:
@dataclass(frozen=True)
class ShapeInferenceContext:
    data_shapes: dict[str, int | list[int]]
    explicit_shapes: dict[str, int | list[int]]
    initializers: dict[str, TensorProto]
    verbose: bool = False
From shapeonnx/infer_shape.py:19-33.
Why frozen?
Clarity: Can’t accidentally reassign ctx.data_shapes to a new dict. Changes happen inside the dictionaries (mutation), not to the context itself (reassignment).
Safety: Passing ctx around doesn’t risk someone reassigning verbose or initializers mid-inference.
Intent: The frozen annotation signals “this is a container for inference state, not a mutable configuration object.”
State changes happen through dictionary updates:
# This is allowed (mutating dictionary contents):
ctx.data_shapes["new_tensor"] = [1, 64, 32, 32]
# This would fail (reassigning frozen field):
ctx.data_shapes = {} # ❌ FrozenInstanceError
This pattern—immutable container, mutable contents—makes data flow explicit while preserving performance.
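The behavior is easy to confirm with the standard library alone. MiniContext below is a hypothetical, cut-down stand-in for ShapeInferenceContext (two fields, default values added for brevity); the frozen semantics come from dataclasses itself.

```python
from dataclasses import FrozenInstanceError, dataclass, field

@dataclass(frozen=True)
class MiniContext:
    """Hypothetical two-field stand-in for ShapeInferenceContext."""
    data_shapes: dict[str, list[int]] = field(default_factory=dict)
    verbose: bool = False

ctx = MiniContext()

# Mutating the dictionary held by the frozen instance is allowed.
ctx.data_shapes["new_tensor"] = [1, 64, 32, 32]

# Rebinding a field on the frozen instance is rejected.
try:
    ctx.data_shapes = {}
    reassigned = True
except FrozenInstanceError:
    reassigned = False

print(ctx.data_shapes)  # {'new_tensor': [1, 64, 32, 32]}
print(reassigned)       # False
```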
Single-Pass Forward Propagation - O(1) Per Operator
ONNX graphs are topologically sorted: nodes are ordered such that inputs are computed before outputs. This guarantee enables single-pass inference:
def _infer_all_node_shapes(nodes, ctx):
    """Infer shapes for all nodes in the graph."""
    for node in nodes:
        # Process each node exactly once, in order
        infer_func = INFER_SHAPE_FUNC_MAPPING.get(node.op_type)
        if infer_func is None:
            raise NotImplementedError(f"Operator {node.op_type} not supported")
        results = infer_func(node, ctx)
        _process_node_outputs(node, results, ctx)
From shapeonnx/infer_shape.py:1506-1530.
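No dependency analysis: We don’t build a graph, topologically sort, or check dependencies. The ONNX specification already requires nodes to be listed in topological order.
O(1) visits per operator: Each operator is processed exactly once, so the whole pass is linear in the number of nodes. No backtracking, no multi-pass inference.
Performance: 136 VNN-COMP 2024 models in ~9.5 seconds (~70ms per model) for shape inference. In the full test suite, model loading (~15 seconds) and validation (~14.5 seconds) take longer than shape inference itself.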
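A minimal sketch of the single-pass loop, reduced to three operators, is shown here. Everything in it is hypothetical scaffolding: Node stands in for onnx.NodeProto, Gather indices are passed as an attribute rather than as an initializer input, and the way results are written back is only a guess at what _process_node_outputs does. Each inference function returns (data_shape, explicit_shape) pairs, one per output.

```python
from dataclasses import dataclass, field

@dataclass
class Node:
    """Hypothetical stand-in for onnx.NodeProto."""
    op_type: str
    inputs: list[str]
    outputs: list[str]
    attrs: dict = field(default_factory=dict)

def infer_shape_op(node, data_shapes, explicit_shapes):
    return [(None, list(data_shapes[node.inputs[0]]))]

def infer_gather(node, data_shapes, explicit_shapes):
    values = explicit_shapes[node.inputs[0]]
    return [(None, [values[i] for i in node.attrs["indices"]])]

def infer_constant_of_shape(node, data_shapes, explicit_shapes):
    return [(list(explicit_shapes[node.inputs[0]]), None)]

DISPATCH = {
    "Shape": infer_shape_op,
    "Gather": infer_gather,
    "ConstantOfShape": infer_constant_of_shape,
}

def infer_all(nodes, input_shapes):
    """Visit each node once, in the (topological) order given."""
    data_shapes, explicit_shapes = dict(input_shapes), {}
    for node in nodes:
        infer_func = DISPATCH.get(node.op_type)
        if infer_func is None:
            raise NotImplementedError(f"Operator {node.op_type} not supported")
        results = infer_func(node, data_shapes, explicit_shapes)
        for name, (data_shape, explicit_shape) in zip(node.outputs, results):
            if data_shape is not None:
                data_shapes[name] = data_shape
            if explicit_shape is not None:
                explicit_shapes[name] = explicit_shape
    return data_shapes, explicit_shapes

graph = [
    Node("Shape", ["input"], ["shape_output"]),
    Node("Gather", ["shape_output"], ["gather_output"], {"indices": [2, 3]}),
    Node("ConstantOfShape", ["gather_output"], ["constant_tensor"]),
]
data_shapes, explicit_shapes = infer_all(graph, {"input": [1, 48, 2, 2]})
print(data_shapes["constant_tensor"])  # [2, 2]
```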
Direct Dictionary Access - No Abstraction Overhead
ShapeONNX uses plain Python dictionaries for shape storage:
data_shapes: dict[str, int | list[int]]
explicit_shapes: dict[str, int | list[int]]
No wrapper classes. No ShapeTensor objects. No abstraction layers.
Why?
Performance: Dictionary lookups are highly optimized in CPython. Adding wrapper classes adds method call overhead.
Simplicity: The code is easier to read when shapes are just dictionaries, not custom types.
Debugging: Can print ctx.data_shapes and immediately see all tensor shapes. No need to call .to_dict() or iterate over wrapper objects.
Helper functions provide structure without abstraction:
def get_explicit_shape(name, explicit_shapes):
    """Retrieve explicit constant shape value."""
    return explicit_shapes.get(name)

def store_explicit_shape(shape, explicit_shapes, name):
    """Store constant shape value."""
    explicit_shapes[name] = shape
From shapeonnx/infer_shape.py:96-151.
These are lightweight wrappers—no state, just conveniences for common operations.
Attribute Validation at Extraction Time
ONNX operators have attributes (padding mode, stride, etc.). ShapeONNX validates attributes when extracting them, not during shape inference:
def get_attrs_conv(node, initializers):
    """Extract Conv operator attributes."""
    attrs = scan_attrs({
        "auto_pad": "NOTSET",
        "dilations": None,
        "group": 1,
        "kernel_shape": None,
        "pads": None,
        "strides": None,
    }, node.attribute)
    # Validate immediately
    validate_auto_pad(attrs["auto_pad"], "Conv")
    check_pads_symmetric(attrs["pads"])
    # Infer defaults
    infer_kernel_defaults(attrs, attrs["kernel_shape"])
    return attrs
From shapeonnx/onnx_attrs.py:153-174.
Rationale:
Fail fast: If a model has unsupported attributes, error immediately with clear context (“Conv with auto_pad=SAME_UPPER not supported”), not deep in shape computation.
Separation of concerns: Attribute extraction and validation is separate from shape logic. The shape inference function receives validated attributes.
Code locality: Attribute rules (e.g., “padding must be symmetric”) are defined with attribute extraction, not scattered across shape functions.
Example validation:
def check_pads_symmetric(pads):
    """Verify that padding is symmetric."""
    dims = len(pads) // 2
    for i in range(dims):
        if pads[i] != pads[i + dims]:
            raise ValueError(
                f"Asymmetric padding {pads} not supported"
            )
From shapeonnx/onnx_attrs.py:48-59.
This is a known limitation of ShapeONNX (documented in README): asymmetric padding is not supported. By checking at extraction time, we give clear error messages rather than silently computing wrong shapes.
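For reference, ONNX lists pads as all begin values followed by all end values, so a 2D pads attribute is [top, left, bottom, right]. The short example below repeats the check shown above in self-contained form and runs it on three inputs; the sample pad values are illustrative only.

```python
def check_pads_symmetric(pads):
    """Verify that begin and end padding match on every spatial axis."""
    dims = len(pads) // 2
    for i in range(dims):
        if pads[i] != pads[i + dims]:
            raise ValueError(f"Asymmetric padding {pads} not supported")

# Symmetric: 1 pixel on every side.
check_pads_symmetric([1, 1, 1, 1])

# Also symmetric: 1 above/below and 2 left/right (begin == end per axis).
check_pads_symmetric([1, 2, 1, 2])

# Asymmetric: no padding at the start, 1 pixel at the end of each axis.
try:
    check_pads_symmetric([0, 0, 1, 1])
    message = None
except ValueError as exc:
    message = str(exc)

print(message)  # Asymmetric padding [0, 0, 1, 1] not supported
```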
Why Pure Python - Accessibility and Debuggability
ShapeONNX is pure Python with minimal dependencies:
Python 3.10+
onnx 1.17.0
numpy 2.2.4
No C extensions. No Cython. No compiled protobuf wrappers beyond what ONNX provides.
Why?
Debugging: Set a breakpoint, inspect ctx.data_shapes, step through operator logic. No stepping into compiled C code.
Portability: Works on any platform with Python. No platform-specific compilation.
Accessibility: Researchers can read the source, understand the algorithms, and modify for custom operators.
Integration: Verification tools (often research prototypes) can import ShapeONNX without worrying about C++ ABI compatibility or build chains.
Performance is not a bottleneck: Shape inference is lightweight—mostly integer arithmetic and dictionary lookups. The VNN-COMP 2024 benchmark (136 models in ~9.5 seconds) shows pure Python is fast enough.
When performance matters, it’s in verification (SMT solving, bound propagation), not shape inference. ShapeONNX is fast enough to not slow down the critical path.
Part 5: Broadcasting and Complex Operations
NumPy-Style Broadcasting - Aligning Shapes
Binary operators (Add, Mul, etc.) support NumPy-style broadcasting: shapes are right-aligned, then element-wise broadcasted.
Example: [3, 1, 4] + [2, 4] → [3, 2, 4]
Implementation:
def right_align_shapes(shape1, shape2):
    """Right-align two shapes by padding with 1s."""
    max_len = max(len(shape1), len(shape2))
    aligned1 = [1] * (max_len - len(shape1)) + shape1
    aligned2 = [1] * (max_len - len(shape2)) + shape2
    return aligned1, aligned2

def compute_broadcasted_shape(shape1, shape2):
    """Compute broadcasted shape from two aligned shapes."""
    result = []
    for s1, s2 in zip(shape1, shape2, strict=False):
        if s1 != s2 and s1 != 1 and s2 != 1:
            raise RuntimeError(f"Cannot broadcast {shape1} and {shape2}")
        result.append(max(s1, s2))  # 1 broadcasts to any dimension
    return result

def broadcast_shapes(shape1, shape2):
    """Broadcast two shapes using numpy broadcasting rules."""
    aligned1, aligned2 = right_align_shapes(shape1, shape2)
    return compute_broadcasted_shape(aligned1, aligned2)
From shapeonnx/infer_shape.py:173-224.
Example trace:
shape1 = [3, 1, 4]
shape2 = [2, 4]
# Right-align:
aligned1 = [3, 1, 4]
aligned2 = [1, 2, 4] # Padded with 1 on left
# Broadcast element-wise:
result = [max(3,1), max(1,2), max(4,4)] = [3, 2, 4]
This enables shape inference for operators with broadcasting, like Add, Mul, etc.
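A compact, self-contained version of the same rule makes it easy to try a few more cases, including a per-channel bias, a scalar (rank-0) operand, and an incompatible pair. This is a condensed sketch of the logic above folded into one function, not the library code itself.

```python
def broadcast_shapes(shape1, shape2):
    """Right-align two shapes with leading 1s, then broadcast per axis."""
    max_len = max(len(shape1), len(shape2))
    a = [1] * (max_len - len(shape1)) + list(shape1)
    b = [1] * (max_len - len(shape2)) + list(shape2)
    result = []
    for s1, s2 in zip(a, b):
        if s1 != s2 and s1 != 1 and s2 != 1:
            raise RuntimeError(f"Cannot broadcast {shape1} and {shape2}")
        result.append(max(s1, s2))
    return result

print(broadcast_shapes([3, 1, 4], [2, 4]))             # [3, 2, 4]
print(broadcast_shapes([1, 64, 32, 32], [64, 1, 1]))   # [1, 64, 32, 32]
print(broadcast_shapes([], [2, 3]))                    # [2, 3]

try:
    broadcast_shapes([2, 3], [5])   # trailing axes 3 and 5 are incompatible
    error = None
except RuntimeError as exc:
    error = str(exc)
print(error)  # Cannot broadcast [2, 3] and [5]
```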
ConvTranspose - Upsampling Output Shape Computation
ConvTranspose (deconvolution) is the inverse of Conv: it upsamples the spatial dimensions.
Formula:
output_size = (input - 1) * stride - 2*pad + dilation*(kernel-1) + output_padding + 1
Implementation:
def compute_convtranspose_output_hw(
    input_shape, weight_shape, kernel_shape,
    dilations, output_padding, pads, strides
):
    """Compute output height/width for ConvTranspose."""
    dim = len(kernel_shape)
    temp1 = [pads[i] + pads[i + dim] for i in range(dim)]
    temp2 = [dilations[i] * (kernel_shape[i] - 1) for i in range(dim)]
    output_hw = [
        math.ceil(
            (input_shape[i + 2] - 1) * strides[i]
            - temp1[i] + temp2[i] + output_padding[i] + 1
        )
        for i in range(dim)
    ]
    return output_hw
From shapeonnx/infer_shape.py:573-603.
Example: Input [1, 64, 16, 16], kernel 4×4, stride 2, padding 1:
# Per dimension:
output = (16 - 1) * 2 - 2*1 + 1*(4-1) + 0 + 1
= 15 * 2 - 2 + 3 + 1
= 30 - 2 + 3 + 1
= 32
# Output shape: [1, out_channels, 32, 32]
This correctly handles upsampling, critical for GANs and decoder architectures.
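The formula can be checked per dimension with a few lines of Python. The helper below is a hypothetical single-axis version that assumes symmetric padding (so the total padding is 2*pad); it reproduces the worked example and a second common upsampling configuration that uses output_padding.

```python
def convtranspose_out_size(size, kernel, stride=1, pad=0,
                           dilation=1, output_padding=0):
    """Output size along one spatial axis, assuming symmetric padding."""
    return ((size - 1) * stride - 2 * pad
            + dilation * (kernel - 1) + output_padding + 1)

# Worked example from the text: kernel 4, stride 2, padding 1.
a = convtranspose_out_size(16, kernel=4, stride=2, pad=1)

# Another 2x upsampling setup: kernel 3, stride 2, padding 1, output_padding 1.
b = convtranspose_out_size(16, kernel=3, stride=2, pad=1, output_padding=1)

print(a)  # 32
print(b)  # 32
```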
Pooling with ceil_mode - Handling Fractional Output Sizes
MaxPool and AveragePool can use ceil_mode to handle fractional output sizes.
Formula:
output_size = floor((input + 2*pad - dilation*(kernel-1) - 1) / stride + 1) # ceil_mode=False
output_size = ceil((input + 2*pad - dilation*(kernel-1) - 1) / stride + 1) # ceil_mode=True
Implementation:
def compute_pool_output_hw(
    input_shape, kernel_shape, dilations, pads, strides, ceil_mode
):
    """Compute output height/width for pooling operations."""
    dim = len(kernel_shape)
    output_hw = []
    for i in range(dim):
        temp1 = pads[i] + pads[i + dim]
        temp2 = dilations[i] * (kernel_shape[i] - 1)
        size = (input_shape[i + 2] + temp1 - temp2 - 1) / strides[i] + 1
        output_hw.append(math.ceil(size) if ceil_mode else math.floor(size))
    return output_hw
From shapeonnx/infer_shape.py:817-843.
Example: Input 32, kernel 3, stride 2, pad 0:
size = (32 + 0 - 1*(3-1) - 1) / 2 + 1
= (32 - 2 - 1) / 2 + 1
= 29 / 2 + 1
= 14.5 + 1
= 15.5
floor(15.5) = 15 # ceil_mode=False
ceil(15.5) = 16 # ceil_mode=True
This matches PyTorch/ONNX semantics for pooling operations.
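The same calculation as a single-axis helper, again a hypothetical sketch that assumes symmetric padding. It shows that ceil_mode only changes the result when the division leaves a remainder.

```python
import math

def pool_out_size(size, kernel, stride, pad=0, dilation=1, ceil_mode=False):
    """Pooling output size along one spatial axis (symmetric padding)."""
    raw = (size + 2 * pad - dilation * (kernel - 1) - 1) / stride + 1
    return math.ceil(raw) if ceil_mode else math.floor(raw)

# Worked example from the text: the raw size is 15.5.
print(pool_out_size(32, kernel=3, stride=2))                  # 15
print(pool_out_size(32, kernel=3, stride=2, ceil_mode=True))  # 16

# Kernel 2, stride 2 divides evenly, so both modes agree.
print(pool_out_size(32, kernel=2, stride=2))                  # 16
print(pool_out_size(32, kernel=2, stride=2, ceil_mode=True))  # 16
```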
Slice with Negative Indices - Python-Style Indexing
Slice supports negative indices (counting from end) and negative steps (reverse slicing).
Implementation:
def infer_sliced_shape(shape, axes, starts, ends, steps):
    """Infer shape after slicing operation."""
    new_shape = list(shape)
    for axis, start, end, step in zip(axes, starts, ends, steps, strict=True):
        size = shape[axis]
        # Handle negative indices
        start = min(max(start + size if start < 0 else start, 0), size)
        end = min(max(end + size if end < 0 else end, 0), size)
        if step < 0:
            warnings.warn(f"Negative step ({step}) is not fully tested", stacklevel=2)
        # Compute sliced dimension size
        new_shape[axis] = max(0, (end - start + (step - (1 if step > 0 else -1))) // step)
    return new_shape
From shapeonnx/infer_shape.py:1123-1148.
Example:
# shape[2] = 32
# Slice(start=-16, end=-1, step=1)
# Convert negative indices:
start = -16 + 32 = 16
end = -1 + 32 = 31
# Compute size:
new_size = (31 - 16) // 1 = 15
This enables shape-tensor slicing patterns like extracting the last N dimensions.
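For positive steps, the size formula can be cross-checked against Python's own slicing, since clamping and negative-index handling follow the same rules. The sketch below extracts the per-axis arithmetic into a hypothetical helper, sliced_size, and compares it with len(range(size)[start:end:step]) over a small grid; negative steps are left out, in line with the warning in the implementation.

```python
def sliced_size(size, start, end, step):
    """Length of one axis after slicing, using the formula shown above."""
    start = min(max(start + size if start < 0 else start, 0), size)
    end = min(max(end + size if end < 0 else end, 0), size)
    return max(0, (end - start + (step - (1 if step > 0 else -1))) // step)

# Worked example from the text.
print(sliced_size(32, -16, -1, 1))  # 15

# Exhaustive comparison with Python slicing for small sizes, positive steps.
mismatches = [
    (size, start, end, step)
    for size in range(0, 9)
    for start in range(-10, 11)
    for end in range(-10, 11)
    for step in (1, 2, 3)
    if sliced_size(size, start, end, step) != len(range(size)[start:end:step])
]
print(mismatches)  # []
```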
Part 6: Real-World Validation - VNN-COMP 2024 Results
Benchmark Coverage
ShapeONNX was validated on 136 models from VNN-COMP 2024, the International Verification of Neural Networks Competition:
Model Types:
Convolutional Neural Networks (ACAS Xu, Collins, etc.)
ResNets (various depths and widths)
Vision Transformers (ViT, patch embeddings)
GANs with dynamic generation
Graph Neural Networks
Custom verification benchmarks
Opset Range: 17-21 (opset 21 recommended)
Success Rate: 100% (136/136 models processed successfully)
Testing Methodology:
Load model with onnx.load()
Convert to opset 21 with onnx.version_converter
Run ShapeONNX shape inference
Compare with ONNX reference (accounting for shape tensor differences)
Validate consistency (no contradictions in inferred shapes)
Performance Metrics
Hardware Configuration
All benchmarks executed on:
CPU: Intel i5-12400F (6 cores, 12 threads)
RAM: 32GB DDR4
OS: Linux/Windows (cross-platform validated)
Shape Inference Performance
| Metric | Result | Context |
|---|---|---|
| Total Models | 136 | VNN-COMP 2024 |
| Total Time | ~9.5 seconds | Shape inference only |
| Average per Model | ~70 milliseconds | Mean (median is ~68ms) |
| Simple Models | <10ms | ACAS Xu, small CNNs |
| Complex Models | <200ms | ViT, GANs, 100+ layers |
| Peak Memory | <500MB | Largest models |
| Success Rate | 100% | No failures |
Performance Distribution
Shape Inference Time Distribution (136 models):
├─ Fastest: ~3ms (ACAS Xu single network)
├─ Median: ~68ms (typical ResNet/CNN)
├─ Slowest: ~184ms (Vision Transformer with 100+ layers)
└─ Total: 9.5 seconds for full suite
Complete Validation Timing
Full Test Suite (408 tests = 136 models × 3 test types):
├─ Shape inference: ~9.5 seconds (ShapeONNX processing)
├─ Model loading: ~15 seconds (ONNX parsing, not ShapeONNX)
├─ Validation: ~14.5 seconds (comparison with ONNX reference)
└─ Total: ~39 seconds
Why Speed Matters for Verification
Verification tools spend seconds to hours on SMT solving and bound propagation. Adding ~70ms for accurate shape inference is negligible overhead: it stays below 0.01% of any verification run lasting about 12 minutes or longer, which is typical for complex properties. The accuracy gained (concrete static shapes vs dynamic placeholders) far outweighs the minimal performance cost.
Comparison with ONNX’s Shape Inference¶
ShapeONNX provides more accurate shape information than ONNX’s built-in onnx.shape_inference.infer_shapes:
Example 1: Shape Operator Output
# Model: Input[1,48,2,2] → Shape → output
# ONNX result:
shapes["output"] = [4] # It's a 1D tensor with 4 elements
# ShapeONNX result:
data_shapes["output"] = [4] # Metadata (same as ONNX)
explicit_shapes["output"] = [1,48,2,2] # Actual values (better!)
Example 2: ConstantOfShape
# Model: Shape → Gather → ConstantOfShape
# ONNX result:
shapes["constant"] = [-1, -1] # Dynamic (no value information)
# ShapeONNX result:
data_shapes["constant"] = [2, 2] # Static (resolved from explicit_shapes)
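To make Example 2 concrete, here is a minimal sketch of how two dictionaries can carry values through the Shape → Gather → ConstantOfShape chain. The three `infer_*` functions are hypothetical stand-ins written for this illustration; they are not ShapeONNX's actual operator handlers.

```python
data_shapes: dict[str, list[int]] = {"input": [1, 48, 2, 2]}
explicit_shapes: dict[str, list[int]] = {}


def infer_shape_op(src: str, dst: str) -> None:
    # Shape outputs a 1D tensor whose values are the input's dimensions.
    explicit_shapes[dst] = list(data_shapes[src])
    data_shapes[dst] = [len(data_shapes[src])]


def infer_gather_op(src: str, dst: str, indices: list[int]) -> None:
    # Gather on a shape tensor selects concrete dimension values.
    explicit_shapes[dst] = [explicit_shapes[src][i] for i in indices]
    data_shapes[dst] = [len(indices)]


def infer_constant_of_shape_op(src: str, dst: str) -> None:
    # The output shape is the value of the input tensor, which is now known.
    data_shapes[dst] = list(explicit_shapes[src])


infer_shape_op("input", "shape_out")
infer_gather_op("shape_out", "gathered", indices=[2, 3])
infer_constant_of_shape_op("gathered", "constant")

print(data_shapes["gathered"])      # [2]     metadata: 1D tensor, 2 elements
print(explicit_shapes["gathered"])  # [2, 2]  the values themselves
print(data_shapes["constant"])      # [2, 2]  static output shape
```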
Example 3: Dynamic Batch Dimension
# Model: Input with dynamic batch → Conv → output
# ONNX result:
shapes["input"] = [-1, 3, 224, 224] # Batch is dynamic
shapes["output"] = [-1, 64, 112, 112] # Propagates dynamic batch
# ShapeONNX result (with has_batch_dim=True):
data_shapes["input"] = [1, 3, 224, 224] # Batch is static (for verification)
data_shapes["output"] = [1, 64, 112, 112] # Static throughout
Why the difference?
ONNX optimizes for inference: Dynamic batch means the model can handle variable batch sizes at runtime. This is useful for serving.
ShapeONNX optimizes for verification: Verification tools need concrete input shapes to prove properties. We assume batch size 1 (or user-specified) and propagate static shapes.
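As a rough illustration of the batch-size assumption, the sketch below pins a dynamic leading dimension to a fixed value. `make_batch_static` is a hypothetical helper; how `has_batch_dim=True` is handled inside ShapeONNX is not shown in this article.

```python
def make_batch_static(shape: list[int], batch_size: int = 1) -> list[int]:
    """Replace a dynamic leading dimension (-1) with a fixed batch size."""
    if shape and shape[0] == -1:
        return [batch_size, *shape[1:]]
    return list(shape)


print(make_batch_static([-1, 3, 224, 224]))     # [1, 3, 224, 224]
print(make_batch_static([-1, 3, 224, 224], 4))  # [4, 3, 224, 224]
```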
Validation approach: Our test suite compares ShapeONNX with ONNX reference, accounting for:
Shape tensor values vs metadata: ShapeONNX tracks values, ONNX tracks metadata. We accept this difference as expected.
Static vs dynamic batch: With has_batch_dim=True, ShapeONNX uses batch size 1. We verify consistency with ONNX for non-batch dimensions.
Resolved vs unresolved shapes: ShapeONNX resolves ConstantOfShape/Reshape to static shapes. We verify these are consistent with ONNX’s constraints.
Result: 100% consistency with ONNX reference, with more accurate shape information where it matters.
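One way to express "consistent with the ONNX reference" is to treat every dynamic ONNX dimension as a wildcard and require exact agreement everywhere else. The check below is a sketch under that assumption; `is_consistent` is a hypothetical name, and the real test suite may compare shapes differently.

```python
def is_consistent(static_shape: list[int], onnx_shape: list[int]) -> bool:
    """True if static_shape agrees with onnx_shape wherever ONNX is concrete.

    A dimension of -1 in onnx_shape means "unknown" and matches any value.
    """
    if len(static_shape) != len(onnx_shape):
        return False
    return all(o == -1 or s == o for s, o in zip(static_shape, onnx_shape))


print(is_consistent([1, 64, 112, 112], [-1, 64, 112, 112]))  # True
print(is_consistent([2, 2], [-1, -1]))                       # True
print(is_consistent([1, 64, 112, 112], [-1, 64, 56, 56]))    # False
```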
Known Limitations and Edge Cases¶
1. Asymmetric Padding:
# Not supported (ONNX pads order for 2D is [top, left, bottom, right]):
Conv(pads=[1, 1, 2, 2]) # begin != end padding on each axis
# Error message:
ValueError: Asymmetric padding [1, 1, 2, 2] not supported
Rationale: Asymmetric padding is uncommon in verification benchmarks. Supporting it requires more complex output size calculations. Future work.
2. Control Flow Operators:
If, Loop, Scan are not supported
These are rare in verification models (most use static graphs)
Could be added with multi-pass inference for loops
3. Some Operator Attributes:
# Not supported:
BatchNormalization(training_mode=1) # Must be 0 (inference)
Split(num_outputs=3) # Must specify split sizes explicitly
Rationale: Verification models use inference mode and explicit configurations. Unsupported attributes are caught with clear error messages at extraction time.
4. Dynamic Input Shapes:
ShapeONNX assumes static input shapes (or batch=1 with has_batch_dim=True)
Models with truly dynamic inputs (arbitrary image sizes) require user-specified input shapes
This is consistent with verification use cases (verify for specific input dimensions)
These limitations are documented in the README and enforced with clear error messages. For the 136 VNN-COMP 2024 models, none hit these limitations.
Part 7: Integration with Verification Tools¶
SlimONNX Integration - Shape-Dependent Optimization¶
SlimONNX is an ONNX optimizer designed specifically for verification workflows. It uses ShapeONNX for shape-dependent optimizations:
Pipeline:
from shapeonnx import infer_onnx_shape
from slimonnx import optimize_onnx
# Step 1: Infer accurate static shapes
shapes = infer_onnx_shape(
    input_nodes, output_nodes, nodes, initializers,
    has_batch_dim=True
)
# Step 2: Optimize using shape information
optimized_model = optimize_onnx(model, shapes)
Optimizations enabled by ShapeONNX:
Constant Folding: Reshape with static target shape can be evaluated at compile time if input is constant.
Redundant Operation Removal:
# Without ShapeONNX:
Reshape(input[48,256], target=[-1,256]) # Can't optimize, -1 is unknown
# With ShapeONNX:
Reshape(input[48,256], target=[48,256]) # Identity reshape, remove!
Shape Operator Elimination:
# Shape → Gather → ConstantOfShape chain
# ShapeONNX resolves to static shape
# SlimONNX replaces entire chain with Constant initializer
Operator Fusion: BatchNorm+Conv fusion requires knowing output channels (from shape inference).
Impact: SlimONNX can achieve significant graph size reduction on VNN-COMP 2024 models when shape-dependent optimizations are enabled by ShapeONNX.
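The redundant-Reshape case can be sketched in a few lines. This is an illustration, not SlimONNX's implementation: `resolve_reshape_target` and `is_identity_reshape` are hypothetical helpers, they assume the input shape is already static, and they ignore the ONNX convention that a 0 in the target copies the input dimension.

```python
from math import prod


def resolve_reshape_target(input_shape: list[int], target: list[int]) -> list[int]:
    """Fill in a single -1 in a Reshape target using the input's element count."""
    if target.count(-1) > 1:
        raise ValueError("at most one -1 is allowed in a Reshape target")
    if -1 not in target:
        return list(target)
    known = prod(d for d in target if d != -1)
    total = prod(input_shape)
    if known == 0 or total % known != 0:
        raise ValueError(f"cannot reshape {input_shape} into {target}")
    return [total // known if d == -1 else d for d in target]


def is_identity_reshape(input_shape: list[int], target: list[int]) -> bool:
    """A Reshape whose resolved target equals its input shape can be removed."""
    return resolve_reshape_target(input_shape, target) == list(input_shape)


print(resolve_reshape_target([48, 256], [-1, 256]))  # [48, 256]
print(is_identity_reshape([48, 256], [-1, 256]))     # True
print(is_identity_reshape([48, 256], [-1, 128]))     # False
```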
Why Verification Needs Static Shapes¶
Verification tools prove properties like:
∀ input ∈ InputRegion: output(model(input)) ∈ SafeRegion
They work by:
1. Bound Propagation:
Compute interval bounds for each neuron: neuron_i ∈ [lower_i, upper_i]
Requires allocating bounds for every neuron, which needs exact layer dimensions
Dynamic shapes → can’t allocate bound storage
2. Constraint Generation:
Encode operations as SMT formulas
Matrix multiplication: C[i,j] = Σ_k A[i,k] * B[k,j] requires knowing i, j, k ranges
Dynamic K → can’t generate correct number of constraints
3. Symbolic Execution:
Create symbolic variables for neuron activations
Propagate input constraints through the network
Dynamic shapes → number of symbolic variables is unknown
Example: Verifying a Conv layer with input [1, 64, ?, ?]:
ONNX shape inference: Output is [1, 128, ?, ?] (dynamic spatial dims)
Verification tool: “I need to allocate bounds for output neurons… but how many?”
Verification fails
With ShapeONNX:
Input: [1, 64, 32, 32]
Output: [1, 128, 30, 30] (for example, a 3×3 kernel with stride 1 and no padding)
Verification tool: “Allocate bounds for 128 * 30 * 30 = 115,200 neurons.”
Verification proceeds
Static shapes are not optional for verification—they’re required.
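The allocation problem is easy to see in code. The sketch below uses plain Python lists as stand-ins for a verifier's bound storage; `allocate_bounds` is a hypothetical helper, and real tools use their own tensor types.

```python
from math import prod


def allocate_bounds(shape: list[int]) -> tuple[list[float], list[float]]:
    """Allocate flat lower/upper bound buffers, one slot per neuron."""
    if any(d < 0 for d in shape):
        raise ValueError(f"cannot allocate bounds for dynamic shape {shape}")
    n = prod(shape)
    return [float("-inf")] * n, [float("inf")] * n


lower, upper = allocate_bounds([1, 128, 30, 30])
print(len(lower))  # 115200

try:
    allocate_bounds([1, 128, -1, -1])
except ValueError as err:
    print(err)  # cannot allocate bounds for dynamic shape [1, 128, -1, -1]
```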
Example: Shape-Dependent Verification¶
Consider a Vision Transformer’s patch embedding layer:
# PyTorch export:
from einops import rearrange  # rearrange comes from the einops package

def patch_embed(x):  # x: [1, 3, 224, 224]
    # Patchify: divide 224x224 into a 14x14 grid of 16x16 patches
    n_patches = (224 // 16) ** 2  # 196
    x = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
                  p1=16, p2=16)  # [1, 196, 768]
    return x
ONNX graph (simplified):
Input[1,3,224,224]
→ Shape → [1,3,224,224] (values)
→ Slice([2:]) → [224,224]
→ Div(16) → [14,14]
→ Mul → [196]
→ Concat([1,?,768]) → [1,196,768] (target shape)
→ Reshape → [1,196,768]
Without ShapeONNX:
ONNX inference: Reshape target is [1, -1, 768] (dynamic patch count)
Verification tool: “Can’t verify: patch dimension is unknown.”
Result: Verification fails on most ViT models
With ShapeONNX:
Shape inference traces through: 224 → 14 → 196
Reshape target is [1, 196, 768] (static)
Verification tool: “196 patch tokens, 768 dimensions each. Proceeding with bound propagation.”
Result: Verification succeeds
Real impact: VNN-COMP 2024 included Vision Transformer benchmarks. Before ShapeONNX, verification tools crashed on shape inference. After ShapeONNX, 100% success rate.
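The 224 → 14 → 196 trace can be replayed by hand on plain Python lists. This is a sketch of the value flow in the simplified graph, not ShapeONNX code; each line mirrors one operator, and the variable names are chosen for this illustration.

```python
from math import prod

input_shape = [1, 3, 224, 224]
patch = 16

shape_values = list(input_shape)                 # Shape       -> [1, 3, 224, 224]
spatial = shape_values[2:]                       # Slice([2:]) -> [224, 224]
grid = [d // patch for d in spatial]             # Div(16)     -> [14, 14]
n_patches = prod(grid)                           # Mul         -> 196
embed_dim = patch * patch * input_shape[1]       # 16 * 16 * 3 -> 768
target = [input_shape[0], n_patches, embed_dim]  # Concat      -> Reshape target

# A valid Reshape target preserves the element count.
assert prod(target) == prod(input_shape)
print(target)  # [1, 196, 768]
```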
The Verification Pipeline¶
A typical neural network verification workflow:
Export ONNX → ShapeONNX → SlimONNX → Verification Tool
                  ↓           ↓              ↓
            Static shapes Optimized    Formal proof
                            graph
Step 1: Export:
PyTorch/TensorFlow → ONNX
Often has dynamic shapes, Constant nodes, shape operator chains
Step 2: ShapeONNX:
Resolve all shapes to static values
Track shape tensor values through operator chains
Output: Model with data_shapes dictionary containing exact dimensions
Step 3: SlimONNX (optional):
Optimize using static shape information
Remove redundant operations, fuse operators
Output: Simplified model (reduced graph complexity)
Step 4: Verification:
Load optimized model with static shapes
Perform bound propagation, constraint generation, SMT solving
Output: Proof of safety property or counterexample
ShapeONNX is the bridge: It takes models with ONNX’s dynamic shapes and provides the static shapes verification tools require.
Conclusion & Future Directions¶
Current State¶
ShapeONNX is production-ready for neural network verification workflows:
136 VNN-COMP 2024 models: 100% success rate
48 operators: Comprehensive coverage of verification-relevant operations
Dual-track representation: Tracks both metadata and values for shape tensors
Fast performance: ~70ms average per model
Integration: Used by SlimONNX and verification tools
Key capabilities:
Static shape resolution: Converts shapes ONNX marks as dynamic ([-1]) to concrete values ([196]) when statically computable.
Shape tensor tracking: Propagates actual dimension values ([1, 48, 2, 2]) through operator chains, not just metadata ([4]).
Operator chain resolution: Resolves complex patterns like Shape → Gather → Slice → Concat → Reshape to static target shapes.
Verification-ready: Provides the exact static shapes required for bound propagation, constraint generation, and SMT solving.
The Core Innovation¶
The fundamental insight is dual-track representation:
data_shapes: dict[str, list[int]] # Tensor metadata (ONNX standard)
explicit_shapes: dict[str, list[int]] # Constant shape values (ShapeONNX innovation)
Operators check explicit_shapes first, then fall back to data_shapes. This simple pattern enables:
Shape operators to extract and propagate dimension values
ConstantOfShape to create tensors with statically-known shapes
Reshape to resolve -1 dimensions from shape tensor values
Gather/Slice on shape tensors to compute derived shapes
The result: single-pass forward propagation that visits each node once (linear in graph size, with O(1) dictionary lookups per shape) and resolves shapes ONNX can’t, without complex multi-pass analysis or constraint solving.
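The "check explicit_shapes first, then fall back" pattern can be sketched for a ConstantOfShape-style operator. `constant_of_shape_output` is a hypothetical function written for this illustration; it shows why the fallback path can only produce dynamic dimensions.

```python
def constant_of_shape_output(
    shape_input: str,
    data_shapes: dict[str, list[int]],
    explicit_shapes: dict[str, list[int]],
) -> list[int]:
    """Output shape of an operator whose shape comes from a tensor's values."""
    if shape_input in explicit_shapes:
        # Values are tracked: the output shape is fully static.
        return list(explicit_shapes[shape_input])
    # Fallback: only the metadata is known. A 1D tensor with N elements
    # tells us the output rank is N, but not the size of any dimension.
    rank = data_shapes[shape_input][0]
    return [-1] * rank


data_shapes = {"gathered": [2]}
print(constant_of_shape_output("gathered", data_shapes, {"gathered": [2, 2]}))  # [2, 2]
print(constant_of_shape_output("gathered", data_shapes, {}))                    # [-1, -1]
```

The second call reproduces the metadata-only answer: the rank is known, the dimensions are not.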
Future Directions¶
1. Asymmetric Padding Support:
Currently unsupported in Conv/Pool. Requires more complex output size calculations:
output = floor((input + pad_left + pad_right - kernel) / stride) + 1
Uncommon in verification benchmarks but would improve completeness.
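For reference, the per-axis formula with separate begin and end padding looks like this. It is a sketch of the standard convolution output-size arithmetic (with dilation included), not a proposed patch to ShapeONNX; `conv_output_size` is a hypothetical name.

```python
def conv_output_size(
    size: int,
    kernel: int,
    stride: int = 1,
    pad_begin: int = 0,
    pad_end: int = 0,
    dilation: int = 1,
) -> int:
    """Output length of one spatial axis of a convolution."""
    effective_kernel = dilation * (kernel - 1) + 1
    return (size + pad_begin + pad_end - effective_kernel) // stride + 1


print(conv_output_size(32, kernel=3))                                    # 30
print(conv_output_size(224, kernel=7, stride=2, pad_begin=3, pad_end=3))  # 112
print(conv_output_size(32, kernel=3, stride=2, pad_begin=0, pad_end=1))   # 16 (asymmetric)
```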
2. More Operator Coverage:
NonMaxSuppression: Dynamic output size (requires max_output_boxes_per_class)
TopK: K-largest elements (output size depends on K value)
Control flow: If, Loop, Scan (need multi-pass inference)
3. Better Control Flow Handling:
Current approach is single-pass. Supporting Loop requires:
Track shapes across iterations
Detect convergence (shape doesn’t change after iteration N)
Handle dynamic iteration counts (max iterations)
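One possible shape for such a multi-pass scheme is a fixed-point iteration with an iteration cap. The sketch below is speculative and not part of ShapeONNX: `infer_loop_shape` is a hypothetical helper, and the loop body is modelled as a plain function from shape to shape.

```python
from typing import Callable, Optional

Shape = list[int]


def infer_loop_shape(
    initial: Shape,
    body: Callable[[Shape], Shape],
    max_iterations: int = 8,
) -> Optional[Shape]:
    """Apply the loop body's shape rule until the shape stops changing."""
    shape = list(initial)
    for _ in range(max_iterations):
        next_shape = body(shape)
        if next_shape == shape:  # converged: another iteration changes nothing
            return shape
        shape = next_shape
    return None  # still changing after max_iterations: treat as dynamic


# An elementwise body keeps the shape, so inference converges immediately.
print(infer_loop_shape([1, 64], lambda s: list(s)))  # [1, 64]
# A body that adds one row per iteration never stabilizes.
print(infer_loop_shape([1, 64], lambda s: [s[0] + 1, s[1]]))  # None
```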
4. Integration with More Verification Tools:
α,β-CROWN: Complete verifier for ReLU networks
ERAN: ETH’s abstract interpretation verifier
Marabou: SMT-based verifier
ShapeONNX is already used indirectly (via SlimONNX). Direct integration could enable shape-aware verification optimizations.
5. Shape Inference for Quantized Models:
Quantized ONNX models (INT8, UINT8) have similar shape inference needs. ShapeONNX could extend to:
QuantizeLinear/DequantizeLinear operators
QLinearConv/QLinearMatMul
Dynamic quantization patterns
Open Questions¶
Can dual-track representation extend to value tracking?
ShapeONNX tracks shape values. Could we track data values for constant tensors?
# Example: Constant initializer used in Add
constant_values: dict[str, np.ndarray] # Track actual constant data
# Enable compile-time evaluation (constant folding):
# Before: Add(weights, bias)
# After: Constant(weights + bias) # Folded at compile time
This is essentially constant folding, already done by SlimONNX. But tracking constant values during shape inference could enable more aggressive optimizations.
Can we infer shapes without ONNX’s topological ordering?
ShapeONNX relies on ONNX’s guarantee that nodes are topologically sorted. If we needed to handle arbitrary graph order:
Build dependency graph
Topologically sort
Infer in order
This adds complexity but could enable shape inference for non-ONNX formats (TensorFlow GraphDef, etc.).
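The three steps can be sketched with Python's standard-library `graphlib`. The node table below is made up for illustration (node name mapped to its input and output tensor names); it does not reflect ShapeONNX's internal data structures.

```python
from graphlib import TopologicalSorter

# Hypothetical nodes in arbitrary order: name -> (input tensors, output tensors)
nodes = {
    "reshape": (["x", "target"], ["y"]),
    "gather": (["s"], ["target"]),
    "shape": (["x"], ["s"]),
}

# Map each tensor to the node that produces it (graph inputs have no producer).
producer = {out: name for name, (_, outs) in nodes.items() for out in outs}

# Build the dependency graph: node -> set of nodes it depends on.
deps = {
    name: {producer[i] for i in ins if i in producer}
    for name, (ins, _) in nodes.items()
}

# Sort so that every node comes after the nodes producing its inputs.
order = list(TopologicalSorter(deps).static_order())
print(order)  # ['shape', 'gather', 'reshape']
```

Shape inference would then walk `order` front to back, exactly as it does today on an already-sorted ONNX graph.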
How do we handle truly dynamic models?
Some models have inherently dynamic shapes (arbitrary input image size, variable sequence length). ShapeONNX assumes static inputs. Extending to dynamic inputs would require:
Symbolic shape tracking (shape = [batch, height, width] with symbolic variables)
Constraint solving (height * width = 50176 → height = 224, width = 224 if we know one)
Integration with symbolic execution frameworks
This goes beyond shape inference into program analysis—an open research problem.
Try It Yourself¶
ShapeONNX is open source and easy to use:
pip install onnx==1.17.0 numpy==2.2.4
import onnx
from shapeonnx import infer_onnx_shape
from shapeonnx.shapeonnx.utils import (
    get_initializers,
    get_input_nodes,
    get_output_nodes,
    convert_constant_to_initializer,
)
# Load model
model = onnx.load("your_model.onnx")
model = onnx.version_converter.convert_version(model, target_version=21)
# Infer shapes
initializers = get_initializers(model)
input_nodes = get_input_nodes(model, initializers, has_batch_dim=True)
output_nodes = get_output_nodes(model, has_batch_dim=True)
nodes = convert_constant_to_initializer(list(model.graph.node), initializers)
shapes = infer_onnx_shape(
    input_nodes, output_nodes, nodes, initializers,
    has_batch_dim=True, verbose=True
)
# Inspect results
for name, shape in shapes.items():
    print(f"{name}: {shape}")
Repository: ZhongkuiMa/shapeonnx
Related projects:
SlimONNX: ONNX optimizer for verification (uses ShapeONNX)
WraLU: ReLU hull approximation for tighter bounds
WraAct: Convex hull approximation for general activations
Feedback and contributions welcome. ShapeONNX is a tool for the verification community—if you have use cases, models that fail, or operators to add, please open an issue or pull request.
Acknowledgments¶
ShapeONNX was developed for the VNN-COMP 2024 competition and benefited from testing on diverse benchmarks from the verification community. Thanks to the competition organizers and participants for providing challenging models that drove the need for accurate shape inference.
ShapeONNX: Because verification needs to know exactly what shapes are, not just that they’re shapes.