ShapeONNX: Solving ONNX's Dynamic Shape Problem
A dual-track shape inference tool that resolves ONNX's dynamic shapes to concrete static values for neural network verification workflows.
Introduction
The Fundamental Problem:
ONNX was designed to answer: “What shape does this tensor have?” (metadata)
But verification tools need: “What values does this shape tensor contain?” (semantics)
The gap is not a bug in ONNX — it’s a different problem domain.
Example:
```python
import torch

# A Vision Transformer input reaching the verification tool:
x = torch.randn(1, 3, 224, 224)
patches = x.reshape(1, 196, 768)  # 196 patches = (224/16)^2, 768 = 3 * 16 * 16
```
- ONNX tracks: "The Reshape target is a 1D tensor with 3 elements" (metadata: [3])
- ONNX doesn't track: "Those 3 elements contain [1, 196, 768]" (semantics)
- Result: the Reshape output shape becomes [-1, -1, -1] (cannot be inferred)
Verification needs the second: concrete shape values.
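The distinction can be made concrete in a few lines of plain Python. In this sketch the helper names are hypothetical (not ShapeONNX's API): knowing only the Reshape target's metadata ([3]) yields nothing but the output rank, while knowing its values yields the full output shape.

```python
import math


def reshape_from_metadata(target_metadata):
    # Metadata only says "1D tensor with N elements": rank N, every dim unknown.
    (rank,) = target_metadata
    return [-1] * rank


def reshape_from_values(target_values, input_shape):
    # The values are the output dimensions; Reshape must preserve the element count.
    if math.prod(target_values) != math.prod(input_shape):
        raise ValueError("element count mismatch")
    return list(target_values)


print(reshape_from_metadata([3]))                            # [-1, -1, -1]
print(reshape_from_values([1, 196, 768], [1, 3, 224, 224]))  # [1, 196, 768]
```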
Solution: Dual-track representation. Ask two questions separately:
- data_shapes["patches"]: "What shape?" → [1, 196, 768] (metadata)
- explicit_shapes["shape_tensor"]: "What values?" → [1, 196, 768] (semantics)
Result: 136 VNN-COMP 2024 models verified with static shapes, ~70ms average per model.
Part 1: The Dual-Track Representation
Core idea: Two dictionaries instead of one.
```python
# Traditional ONNX: only data_shapes
shapes = {
    "input": [1, 48, 2, 2],
    "shape_output": [4],  # "This is a 1D tensor"
    "reshape_output": [-1, -1],  # "Dynamic, don't know"
}

# ShapeONNX: data_shapes + explicit_shapes
data_shapes = {
    "input": [1, 48, 2, 2],
    "shape_output": [4],
    "reshape_output": [2, 2],
}
explicit_shapes = {
    "shape_output": [1, 48, 2, 2],  # "Contains these values"
    "gather_output": [2, 2],  # "Sliced to these values"
}
```
Decision Table - When to use which dictionary:
| Operator | Input (data/explicit) | Output (data/explicit) | Logic |
|---|---|---|---|
| Shape | data only | (None, data) | Extract dimension values |
| Gather on shapes | explicit | (None, sliced) | Slice shape values |
| Slice on shapes | explicit | (None, sliced) | Slice shape values |
| Concat shapes | explicit | (None, concatenated) | Concatenate values |
| ConstantOfShape | explicit | (explicit, None) | Output shape = input values |
| Reshape | data + explicit | (output, None) | Target shape from explicit |
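The table can be read as a set of per-operator rules. The sketch below encodes them as plain Python functions over the two dictionaries. The function names, the `(data_shape, explicit_value)` return convention, and the simplified attributes (axis 0 only, step 1 for Slice) are assumptions for illustration, not ShapeONNX's actual code; the data shape of a shape tensor is recorded as its 1D length, as in the walkthrough in Part 2.

```python
import math

# Hypothetical handlers (not ShapeONNX's code). Each returns the pair
# (data_shape, explicit_value) for its single output; None means "not tracked".


def shape_op(x, data_shapes, explicit_shapes):
    # Shape: the input's dimensions become the values of a 1D tensor.
    dims = list(data_shapes[x])
    return [len(dims)], dims


def gather_op(x, indices, data_shapes, explicit_shapes):
    # Gather (axis=0) on a 1D shape tensor: pick the indexed values.
    values = [explicit_shapes[x][i] for i in indices]
    return [len(values)], values


def slice_op(x, start, end, data_shapes, explicit_shapes):
    # Slice (axis=0, step 1) on a 1D shape tensor.
    values = explicit_shapes[x][start:end]
    return [len(values)], values


def concat_op(xs, data_shapes, explicit_shapes):
    # Concat (axis=0) of 1D shape tensors: join their values.
    values = [v for x in xs for v in explicit_shapes[x]]
    return [len(values)], values


def constant_of_shape_op(x, data_shapes, explicit_shapes):
    # ConstantOfShape: the input's *values* become the output's *dimensions*.
    return list(explicit_shapes[x]), None


def reshape_op(x, target, data_shapes, explicit_shapes):
    # Reshape: input dimensions from data_shapes, target from explicit_shapes.
    in_dims = data_shapes[x]
    # With allowzero=0 (the default), a 0 copies the matching input dimension.
    dims = [in_dims[i] if v == 0 else v for i, v in enumerate(explicit_shapes[target])]
    # At most one -1 is allowed; infer it from the total element count.
    known = math.prod(d for d in dims if d != -1)
    return [math.prod(in_dims) // known if d == -1 else d for d in dims], None


# Usage: Shape -> Gather -> ConstantOfShape on a [1, 48, 2, 2] input.
data_shapes = {"input": [1, 48, 2, 2]}
explicit_shapes = {}
data_shapes["s"], explicit_shapes["s"] = shape_op("input", data_shapes, explicit_shapes)
data_shapes["g"], explicit_shapes["g"] = gather_op("s", [2, 3], data_shapes, explicit_shapes)
data_shapes["c"], _ = constant_of_shape_op("g", data_shapes, explicit_shapes)
print(data_shapes["c"])  # [2, 2]
```

Keeping each rule as a small pure function makes the table directly testable: every row maps to one function whose inputs and outputs are just lists.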
Retrieval pattern (used by all operators):
```python
def get_shape(name, data_shapes, explicit_shapes):
    # Priority: explicit values first (for shape tensors)
    if name in explicit_shapes:
        return explicit_shapes[name], True  # "I know the values"
    elif name in data_shapes:
        return data_shapes[name], False  # "I know the dimensions"
    else:
        raise RuntimeError(f"Shape unknown: {name}")
```
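The boolean flag is what lets one handler serve both kinds of tensor. Consider an element-wise operator such as Div, which appears in the Vision Transformer walkthrough. This is a minimal sketch with a hypothetical `div_op` handler and a scalar, positive integer divisor: for a shape tensor the values are divided, while for a regular tensor only the (unchanged) dimensions are propagated. `get_shape` is repeated so the block runs on its own.

```python
def get_shape(name, data_shapes, explicit_shapes):
    # Same retrieval pattern as above.
    if name in explicit_shapes:
        return explicit_shapes[name], True
    if name in data_shapes:
        return data_shapes[name], False
    raise RuntimeError(f"Shape unknown: {name}")


def div_op(x, divisor, data_shapes, explicit_shapes):
    # Hypothetical handler for Div by a scalar constant; returns (data, explicit).
    shape, is_explicit = get_shape(x, data_shapes, explicit_shapes)
    if is_explicit:
        # Shape tensor: compute on the values (integer division for positive dims).
        return data_shapes.get(x), [v // divisor for v in shape]
    # Regular tensor: element-wise Div keeps the dimensions unchanged.
    return list(shape), None


data_shapes = {"feat": [1, 768, 14, 14], "spatial_dims": [2]}
explicit_shapes = {"spatial_dims": [224, 224]}
print(div_op("spatial_dims", 16, data_shapes, explicit_shapes))  # ([2], [14, 14])
print(div_op("feat", 16, data_shapes, explicit_shapes))          # ([1, 768, 14, 14], None)
```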
Part 2: How It Works
Vision Transformer patch embedding (step-by-step):
- Input: data_shapes["input"] = [1, 3, 224, 224]
- Shape: extract the dimension values. explicit_shapes["shape_vec"] = [1, 3, 224, 224]; data_shapes["shape_vec"] = [4] (a 1D tensor)
- Gather(indices=[2, 3]): slice the shape values. explicit_shapes["spatial_dims"] = [224, 224]; data_shapes["spatial_dims"] = [2] (a 1D tensor)
- Div(16): compute the patch count per axis. explicit_shapes["patch_count"] = [14, 14] (224 / 16 = 14)
- Mul: multiply the two counts into a scalar. explicit_shapes["n_patches"] = 196
- Concat([1, 196, -1]): build the target shape. explicit_shapes["reshape_target"] = [1, 196, -1]
- Reshape: use the explicit target shape to infer -1. data_shapes["output"] = [1, 196, 768] (inferred: 1 × 3 × 224 × 224 / (1 × 196) = 768)
ONNX would report: [-1, -1, -1] (all dynamic)
ShapeONNX reports: [1, 196, 768] (fully static)
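The walkthrough can be replayed with two plain dictionaries. This sketch hard-codes each step's rule inline rather than calling ShapeONNX; tensor names are taken from the steps above, and the last Concat entry is left as -1 exactly as in the Concat step.

```python
import math

data_shapes = {"input": [1, 3, 224, 224]}
explicit_shapes = {}

# Shape: the input's dimensions become values of a 1D tensor.
explicit_shapes["shape_vec"] = list(data_shapes["input"])
data_shapes["shape_vec"] = [len(explicit_shapes["shape_vec"])]

# Gather(indices=[2, 3]): slice out the spatial dimensions.
explicit_shapes["spatial_dims"] = [explicit_shapes["shape_vec"][i] for i in (2, 3)]
data_shapes["spatial_dims"] = [2]

# Div(16): patches per spatial axis.
explicit_shapes["patch_count"] = [v // 16 for v in explicit_shapes["spatial_dims"]]

# Mul: multiply the two counts into the number of patches.
explicit_shapes["n_patches"] = math.prod(explicit_shapes["patch_count"])

# Concat: build the Reshape target, leaving the embedding size as -1.
explicit_shapes["reshape_target"] = [1, explicit_shapes["n_patches"], -1]

# Reshape: infer -1 from the total element count.
target = explicit_shapes["reshape_target"]
known = math.prod(d for d in target if d != -1)
total = math.prod(data_shapes["input"])
data_shapes["output"] = [total // known if d == -1 else d for d in target]

print(data_shapes["output"])  # [1, 196, 768]
```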
Part 3: Key Design Decisions
Single-pass inference: ONNX graphs are topologically sorted. Process each node once in order. No multi-pass analysis needed.
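Because the node list is topologically sorted, one loop suffices: every input already has a shape by the time its consumer is reached. A minimal sketch, assuming a hypothetical `Node` tuple (a stand-in for `onnx.NodeProto`'s `op_type`/`input`/`output` fields), a handler table keyed by op type, and single-output nodes; it also fails fast if an input appears before it is produced.

```python
from typing import Callable, NamedTuple


class Node(NamedTuple):
    # Hypothetical stand-in for onnx.NodeProto, keeping only the fields used here.
    op_type: str
    inputs: list
    outputs: list


# A handler maps (node, data_shapes, explicit_shapes) to (data_shape, explicit_value).
Handler = Callable[[Node, dict, dict], tuple]


def infer_single_pass(nodes, data_shapes, explicit_shapes, handlers: dict[str, Handler]):
    # Graph inputs and initializers are expected in data_shapes before the loop.
    for node in nodes:  # each node is visited exactly once, in list order
        for name in node.inputs:
            # Empty names mark omitted optional inputs in ONNX, so skip them.
            if name and name not in data_shapes and name not in explicit_shapes:
                raise RuntimeError(
                    f"{node.op_type}: input {name!r} is not known yet; "
                    "is the node list topologically sorted?"
                )
        data, explicit = handlers[node.op_type](node, data_shapes, explicit_shapes)
        out = node.outputs[0]  # single-output nodes only, for brevity
        if data is not None:
            data_shapes[out] = data
        if explicit is not None:
            explicit_shapes[out] = explicit
    return data_shapes, explicit_shapes


handlers = {
    "Relu": lambda n, d, e: (list(d[n.inputs[0]]), None),
    "Shape": lambda n, d, e: ([len(d[n.inputs[0]])], list(d[n.inputs[0]])),
}
nodes = [Node("Relu", ["x"], ["y"]), Node("Shape", ["y"], ["s"])]
d, e = infer_single_pass(nodes, {"x": [1, 8]}, {}, handlers)
print(d["s"], e["s"])  # [2] [1, 8]
```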
Performance: Pure Python 3.10+, no C extensions. Graph traversal overhead is <1% of verification runtime (bound computation dominates).
Limitations (documented and caught early):
- Asymmetric padding: Not supported (rare in verification)
- Control flow (If/Loop): Not supported (most verification models use static graphs)
- Dynamic input shapes: Assumes static inputs or batch_size=1
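Catching these limitations early can be as simple as scanning the graph before inference starts. A sketch under stated assumptions: nodes are given as `(op_type, attributes)` pairs rather than NodeProtos, padding follows ONNX's `pads` attribute layout for Conv and pooling (all begin values, then all end values), and the function name and messages are hypothetical.

```python
def check_supported(nodes):
    """Fail fast on constructs listed above as unsupported (sketch).

    `nodes` is a list of (op_type, attributes) pairs, a simplification of NodeProto.
    """
    for op_type, attrs in nodes:
        if op_type in ("If", "Loop"):
            raise NotImplementedError(f"control flow op {op_type} is not supported")
        pads = attrs.get("pads")
        if pads:
            # ONNX layout: [x1_begin, x2_begin, ..., x1_end, x2_end, ...]
            half = len(pads) // 2
            if list(pads[:half]) != list(pads[half:]):
                raise NotImplementedError(f"asymmetric pads {list(pads)} in {op_type}")


check_supported([("Conv", {"pads": [1, 1, 1, 1]}), ("Relu", {})])  # symmetric: passes
```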
Part 4: Common Questions
Q: Can I use only explicit_shapes?
No. data_shapes tracks regular tensor operations. explicit_shapes tracks shape tensor operations. Both are needed.
Q: What if a shape tensor’s value isn’t computed statically?
ShapeONNX falls back to data_shapes for that tensor and resolves only what it can determine statically.
Q: How does this integrate with verification tools?
Pass the resolved data_shapes dict to your bound propagator. Static dimensions enable:
- Bound allocation for each neuron
- Constraint generation with known dimensions
- Memory allocation without surprises
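With every dimension static, the memory needed for bounds can be computed before propagation starts. A hedged sketch: it assumes one lower and one upper float32 bound per element of each tensor, which is a simplification (tools such as CROWN may also store linear bound coefficients), and it expects `data_shapes` to map names to lists of ints with -1 for unknown dimensions.

```python
import math


def interval_bound_bytes(data_shapes, bytes_per_value=4):
    """Bytes for one lower and one upper bound per element of each tensor (sketch)."""
    sizes = {}
    for name, dims in data_shapes.items():
        if any(d < 0 for d in dims):
            raise ValueError(f"{name} still has a dynamic shape {dims}")
        sizes[name] = 2 * math.prod(dims) * bytes_per_value  # lower + upper
    return sizes


print(interval_bound_bytes({"input": [1, 3, 224, 224], "output": [1, 1000]}))
# {'input': 1204224, 'output': 8000}
```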
Q: How fast is it?
About 70ms of shape inference per model on average (136 models in 9.5 seconds total). Detailed breakdown:
- ONNX loading: ~110ms per model
- Shape inference: ~70ms per model
- Complex models (ViT, GANs): <200ms
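The average is simple arithmetic: 9.5 s over 136 models is about 70 ms, which matches the shape-inference figure rather than loading plus inference. A sketch of how such per-step numbers could be measured separately; `average_ms` is a hypothetical helper, and the callable you pass in (for example `onnx.load` for loading) is up to you.

```python
import time


def average_ms(items, fn):
    """Average wall-clock milliseconds of fn(item) over items (sketch)."""
    start = time.perf_counter()
    for item in items:
        fn(item)
    return (time.perf_counter() - start) * 1000 / len(items)


# The reported average follows from the totals: 9.5 s over 136 models.
print(round(9.5 / 136 * 1000, 1))  # 69.9
```

Timing loading and inference in separate calls keeps the two costs from being folded into one number.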
Part 5: Verification Integration
Typical workflow:
PyTorch/TensorFlow model
↓ export
ONNX model (contains dynamic shapes, shape operators)
↓ ShapeONNX (infers shape values)
Static shape dictionary (e.g., output = [1, 1000])
↓ SlimONNX (optional optimization)
Simplified graph (redundant ops removed)
↓ Verification tool (CROWN, DeepPoly)
Safety proof or counterexample
What ShapeONNX does: Computes the values of shape tensors (the explicit_shapes dict) and uses them to resolve static shapes for all tensors (the data_shapes dict). The verification tool uses these shapes to allocate bounds for each neuron.
Real impact: Vision Transformer models (former bottleneck: dynamic shapes from patch embedding) now verify successfully.
Limitations:
- Assumes batch size = 1 (or user-specified static input shape)
- Truly dynamic shapes (e.g., models that process arbitrary image sizes) require symbolic analysis
- For typical verification use cases (fixed input domain), static shapes are sufficient
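Since inference assumes static inputs, a dynamic batch dimension has to be pinned before shapes are resolved. A sketch, assuming the first dimension is the batch and that unknown dimensions arrive as -1, None, or a symbolic name string (as with ONNX `dim_param`); `pin_input_shape` is a hypothetical helper, and any other dynamic dimension is rejected because it would need symbolic analysis.

```python
def is_dynamic(dim):
    # -1, None, or a symbolic name (like ONNX's dim_param strings) mean "unknown".
    return dim is None or isinstance(dim, str) or dim < 0


def pin_input_shape(shape, batch_size=1):
    """Pin a dynamic batch dimension; refuse any other dynamic dimension."""
    pinned = list(shape)
    if pinned and is_dynamic(pinned[0]):
        pinned[0] = batch_size
    if any(is_dynamic(d) for d in pinned):
        raise ValueError(f"non-batch dimension is dynamic in {shape}; needs symbolic analysis")
    return pinned


print(pin_input_shape([-1, 3, 224, 224]))  # [1, 3, 224, 224]
```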
Try It Yourself
```shell
pip install onnx==1.17.0 numpy==2.2.4
```

```python
import onnx
from shapeonnx import infer_onnx_shape
from shapeonnx.utils import get_initializers, get_input_nodes, get_output_nodes

# Load model
model = onnx.load("your_model.onnx")
model = onnx.version_converter.convert_version(model, target_version=21)

# Infer shapes
initializers = get_initializers(model)
input_nodes = get_input_nodes(model, initializers, has_batch_dim=True)
output_nodes = get_output_nodes(model, has_batch_dim=True)
shapes = infer_onnx_shape(
    input_nodes, output_nodes,
    list(model.graph.node), initializers,
    has_batch_dim=True,
)

# Use static shapes in verification ('output' is the model's output tensor name)
print(f"Output shape: {shapes['output']}")  # [1, 1000] not [-1, -1]
```
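Before handing the result to a verifier, it may be worth asserting that nothing stayed dynamic. A small sketch, assuming `shapes` maps tensor names to lists of ints with -1 marking unknown dimensions (the convention used throughout this post); `dynamic_tensors` is a hypothetical helper, and the hand-written dict below only stands in for a real inference result.

```python
def dynamic_tensors(shapes):
    """Names of tensors whose shape still has an unknown (negative) dimension."""
    return sorted(name for name, dims in shapes.items() if any(d < 0 for d in dims))


# Hand-written shapes standing in for an inference result.
resolved = {"input": [1, 3, 224, 224], "output": [1, 1000]}
assert not dynamic_tensors(resolved)
print(dynamic_tensors({"a": [1, -1], "b": [2]}))  # ['a']
```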
Conclusion
ShapeONNX is not a criticism of ONNX—ONNX’s design for dynamic shapes is correct for deployment. ShapeONNX is a complementary tool for verification workflows that require static dimensions.
Key insight: Track shape tensor values separately from regular tensor metadata.
Result: Resolve shapes ONNX must mark as dynamic to concrete static values, enabling verification on models with computed shapes.
If you’re building verification tools, ShapeONNX bridges the gap between ONNX’s flexible dynamic shapes and verification’s need for concrete dimensions.
Repository: https://github.com/ZhongkuiMa/shapeonnx