
ShapeONNX: Solving ONNX's Dynamic Shape Problem

A dual-track shape inference tool that resolves ONNX's dynamic shapes to concrete static values for neural network verification workflows.

Introduction

The Fundamental Problem:

ONNX was designed to answer: “What shape does this tensor have?” (metadata)

But verification tools need: “What values does this shape tensor contain?” (semantics)

The gap is not a bug in ONNX — it’s a different problem domain.

Example:

import torch

# Verification tool receives a Vision Transformer with:
x = torch.randn(1, 3, 224, 224)
patches = x.reshape(1, 196, 768)  # 196 patches = (224/16)^2, 768 = 3*16*16 values per patch

  • ONNX tracks: “Reshape target is a 1D tensor with 3 elements (metadata: [3])”
  • ONNX doesn’t track: “Those 3 elements contain [1, 196, 768] (semantics)”
  • Result: Reshape output shape becomes [-1, -1, -1] (can’t infer)

Verification needs the second: concrete shape values.

Solution: Dual-track representation. Ask two questions separately:

  • data_shapes["x"]: “What shape?” → [1, 196, 768] (metadata)
  • explicit_shapes["shape_tensor"]: “What values?” → [1, 196, 768] (semantics)

Result: 136 VNN-COMP 2024 models verified with static shapes, ~70ms average per model.

Part 1: The Dual-Track Representation

Core idea: Two dictionaries instead of one.

# Traditional ONNX: only data_shapes
shapes = {
    "input": [1, 48, 2, 2],
    "shape_output": [4],          # "This is a 1D tensor"
    "reshape_output": [-1, -1],   # "Dynamic, don't know"
}

# ShapeONNX: data_shapes + explicit_shapes
data_shapes = {
    "input": [1, 48, 2, 2],
    "shape_output": [4],
    "reshape_output": [2, 2]
}
explicit_shapes = {
    "shape_output": [1, 48, 2, 2],  # "Contains these values"
    "gather_output": [2, 2]          # "Sliced to these values"
}

Decision Table - When to use which dictionary:

| Operator         | Input (data/explicit) | Output (data/explicit) | Logic                       |
|------------------|-----------------------|------------------------|-----------------------------|
| Shape            | data only             | (None, data)           | Extract dimension values    |
| Gather on shapes | explicit              | (None, sliced)         | Slice shape values          |
| Slice on shapes  | explicit              | (None, sliced)         | Slice shape values          |
| Concat shapes    | explicit              | (None, concatenated)   | Concatenate values          |
| ConstantOfShape  | explicit              | (explicit, None)       | Output shape = input values |
| Reshape          | data + explicit       | (output, None)         | Target shape from explicit  |
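
The Shape, Gather, and Concat rows of the table can be sketched as small handlers that read from one dictionary and write to the other. This is a minimal illustration, not ShapeONNX's actual code: the names handle_shape, handle_gather, and handle_concat and their signatures are assumptions made for this example. The sketch also records the 1D shape of each shape tensor in data_shapes, as the walkthrough in Part 2 does.

```python
def handle_shape(inp, out, data_shapes, explicit_shapes):
    """Shape: read the input's dimensions and record them as values."""
    dims = list(data_shapes[inp])
    explicit_shapes[out] = dims        # the values the shape tensor holds
    data_shapes[out] = [len(dims)]     # the shape tensor itself is 1D


def handle_gather(inp, out, indices, data_shapes, explicit_shapes):
    """Gather on a shape tensor: pick values by index."""
    values = explicit_shapes[inp]
    picked = [values[i] for i in indices]
    explicit_shapes[out] = picked
    data_shapes[out] = [len(picked)]


def handle_concat(inputs, out, data_shapes, explicit_shapes):
    """Concat on shape tensors: join the value lists end to end."""
    joined = [v for name in inputs for v in explicit_shapes[name]]
    explicit_shapes[out] = joined
    data_shapes[out] = [len(joined)]


data_shapes = {"input": [1, 48, 2, 2]}
explicit_shapes = {"batch": [1]}  # hypothetical constant shape fragment

handle_shape("input", "shape_output", data_shapes, explicit_shapes)
handle_gather("shape_output", "gather_output", [2, 3], data_shapes, explicit_shapes)
handle_concat(["batch", "gather_output"], "target", data_shapes, explicit_shapes)
print(explicit_shapes["target"])  # [1, 2, 2]
```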

Retrieval pattern (used by all operators):

def get_shape(name, data_shapes, explicit_shapes):
    # Priority: explicit values first (for shape tensors)
    if name in explicit_shapes:
        return explicit_shapes[name], True  # "I know the values"
    elif name in data_shapes:
        return data_shapes[name], False     # "I know the dimensions"
    else:
        raise RuntimeError(f"Shape unknown: {name}")
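
A short usage example may help show why the returned flag matters. The function is repeated so the snippet runs on its own; the dictionary contents are illustrative values taken from Part 1. A shape tensor and a regular tensor can return the same list, and only the flag tells the caller whether it holds values or dimensions.

```python
def get_shape(name, data_shapes, explicit_shapes):
    # Same lookup order as above: explicit values win over dimensions.
    if name in explicit_shapes:
        return explicit_shapes[name], True
    if name in data_shapes:
        return data_shapes[name], False
    raise RuntimeError(f"Shape unknown: {name}")


data_shapes = {"input": [1, 48, 2, 2], "shape_output": [4]}
explicit_shapes = {"shape_output": [1, 48, 2, 2]}

# A shape tensor appears in both dictionaries; its values are returned.
shape_vals, shape_is_explicit = get_shape("shape_output", data_shapes, explicit_shapes)

# A regular tensor only has dimensions.
input_dims, input_is_explicit = get_shape("input", data_shapes, explicit_shapes)

print(shape_vals, shape_is_explicit)  # [1, 48, 2, 2] True
print(input_dims, input_is_explicit)  # [1, 48, 2, 2] False
```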

Part 2: How It Works

Vision Transformer patch embedding (step-by-step):

  1. Input: data_shapes["input"] = [1, 3, 224, 224]

  2. Shape operator: Extract dimension values

    explicit_shapes["shape_vec"] = [1, 3, 224, 224]
    data_shapes["shape_vec"] = [4]  (it's a 1D tensor)

  3. Gather(indices=[2,3]): Slice the shape values

    explicit_shapes["spatial_dims"] = [224, 224]
    data_shapes["spatial_dims"] = [2]  (it's a 1D tensor)

  4. Div(16): Compute patches per side

    explicit_shapes["patch_count"] = [14, 14]  (224/16 = 14)

  5. Mul: Multiply the two counts into a scalar

    explicit_shapes["n_patches"] = 196  (14 * 14 = 196)

  6. Concat([1, 196, -1]): Build target shape

    explicit_shapes["reshape_target"] = [1, 196, -1]

  7. Reshape: Use explicit shape to infer -1

    data_shapes["output"] = [1, 196, 768]  (inferred: total 1*3*224*224 / (1*196) = 768)

ONNX would report: [-1, -1, -1] (all dynamic)

ShapeONNX reports: [1, 196, 768] (fully static)
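
The seven steps can be replayed with plain Python lists to check the arithmetic. This is a sketch of the reasoning, not ShapeONNX's implementation: resolve_reshape is a hypothetical helper, and it handles only the single -1 case (the ONNX Reshape convention where 0 means "copy the input dimension" is left out).

```python
from math import prod


def resolve_reshape(input_dims, target):
    """Replace a single -1 in `target` so that element counts match."""
    total = prod(input_dims)
    if target.count(-1) > 1:
        raise ValueError("at most one -1 is allowed in a reshape target")
    if -1 not in target:
        if prod(target) != total:
            raise ValueError("element count mismatch")
        return list(target)
    known = prod(d for d in target if d != -1)
    if known == 0 or total % known != 0:
        raise ValueError("cannot infer the -1 dimension")
    return [total // known if d == -1 else d for d in target]


data_shapes = {"input": [1, 3, 224, 224]}
explicit_shapes = {}

# Step 2: Shape
explicit_shapes["shape_vec"] = list(data_shapes["input"])
# Step 3: Gather(indices=[2, 3])
explicit_shapes["spatial_dims"] = [explicit_shapes["shape_vec"][i] for i in (2, 3)]
# Step 4: Div(16), integer division on shape values
explicit_shapes["patch_count"] = [d // 16 for d in explicit_shapes["spatial_dims"]]
# Step 5: Mul
explicit_shapes["n_patches"] = prod(explicit_shapes["patch_count"])
# Step 6: Concat
explicit_shapes["reshape_target"] = [1, explicit_shapes["n_patches"], -1]
# Step 7: Reshape
data_shapes["output"] = resolve_reshape(
    data_shapes["input"], explicit_shapes["reshape_target"]
)

print(data_shapes["output"])  # [1, 196, 768]
```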

Part 3: Key Design Decisions

Single-pass inference: ONNX graphs are topologically sorted. Process each node once in order. No multi-pass analysis needed.
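
A single pass over a topologically sorted node list can be sketched as a loop with a per-operator dispatch table. The Node class, the HANDLERS table, and infer_all below are hypothetical stand-ins for illustration; they are not ShapeONNX's API, and only two operators are covered.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    # Hypothetical stand-in for an ONNX NodeProto.
    op_type: str
    inputs: list
    output: str
    attrs: dict = field(default_factory=dict)


def infer_shape_op(node, data_shapes, explicit_shapes):
    dims = list(data_shapes[node.inputs[0]])
    explicit_shapes[node.output] = dims
    data_shapes[node.output] = [len(dims)]


def infer_relu(node, data_shapes, explicit_shapes):
    # Elementwise: the output has the same dimensions as the input.
    data_shapes[node.output] = list(data_shapes[node.inputs[0]])


HANDLERS = {"Shape": infer_shape_op, "Relu": infer_relu}


def infer_all(nodes, input_shapes):
    data_shapes, explicit_shapes = dict(input_shapes), {}
    for node in nodes:  # one visit per node, in graph order
        handler = HANDLERS.get(node.op_type)
        if handler is None:
            raise NotImplementedError(f"No shape rule for {node.op_type}")
        handler(node, data_shapes, explicit_shapes)
    return data_shapes, explicit_shapes


nodes = [
    Node("Relu", ["x"], "y"),
    Node("Shape", ["y"], "y_shape"),
]
data_shapes, explicit_shapes = infer_all(nodes, {"x": [1, 3, 8, 8]})
print(data_shapes["y"], explicit_shapes["y_shape"])
```

Because every node's inputs are produced by earlier nodes, each handler can assume its inputs are already in one of the two dictionaries, which is why no second pass is needed.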

Performance: Pure Python 3.10+, no C extensions. Graph traversal overhead is <1% of verification runtime (bound computation dominates).

Limitations (documented and caught early):

  • Asymmetric padding: Not supported (rare in verification)
  • Control flow (If/Loop): Not supported (most verification models use static graphs)
  • Dynamic input shapes: Assumes static inputs or batch_size=1
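
One way to catch such cases early is to scan the graph before inference starts. The sketch below is an assumption about how this could look, not ShapeONNX's actual check: UNSUPPORTED_OPS, is_asymmetric, and check_supported are hypothetical names, and nodes are simplified to (op_type, attrs) pairs. It relies on the ONNX convention that pads lists all begin values first, then all end values.

```python
UNSUPPORTED_OPS = {"If", "Loop"}  # control-flow operators


def is_asymmetric(pads):
    """ONNX stores pads as [begin_1, ..., begin_n, end_1, ..., end_n]."""
    half = len(pads) // 2
    return list(pads[:half]) != list(pads[half:])


def check_supported(nodes):
    """`nodes` is a list of (op_type, attrs) pairs; raise on the first problem."""
    for index, (op_type, attrs) in enumerate(nodes):
        if op_type in UNSUPPORTED_OPS:
            raise NotImplementedError(f"node {index}: {op_type} is not supported")
        if is_asymmetric(attrs.get("pads", [])):
            raise NotImplementedError(
                f"node {index}: asymmetric padding {attrs['pads']}"
            )


# Symmetric padding and plain operators pass silently.
check_supported([("Conv", {"pads": [1, 1, 1, 1]}), ("Relu", {})])

try:
    check_supported([("Conv", {"pads": [0, 0, 1, 1]})])
except NotImplementedError as err:
    print(err)
```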

Part 4: Common Questions

Q: Can I use only explicit_shapes? No. data_shapes tracks regular tensor operations. explicit_shapes tracks shape tensor operations. Both are needed.

Q: What if a shape tensor’s value isn’t computed statically? Falls back to data_shapes. ShapeONNX resolves what it can statically determine.

Q: How does this integrate with verification tools? Pass the resolved data_shapes dict to your bound propagator. Static dimensions enable:

  • Bound allocation for each neuron
  • Constraint generation with known dimensions
  • Memory allocation without surprises
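
As a rough illustration of the first and third points, a bound propagator could size its buffers directly from the resolved dictionary. The allocate_bounds function is a hypothetical sketch using plain lists; a real verifier would more likely use NumPy or PyTorch arrays.

```python
from math import prod


def allocate_bounds(data_shapes):
    """Return {name: (lower, upper)} with one slot per tensor element."""
    bounds = {}
    for name, dims in data_shapes.items():
        if any(d < 0 for d in dims):
            raise ValueError(f"{name} still has a dynamic dimension: {dims}")
        n = prod(dims)
        bounds[name] = ([float("-inf")] * n, [float("inf")] * n)
    return bounds


data_shapes = {"input": [1, 3, 4, 4], "logits": [1, 10]}
bounds = allocate_bounds(data_shapes)
print({name: len(lower) for name, (lower, _) in bounds.items()})
# {'input': 48, 'logits': 10}
```

The check for negative dimensions shows why static shapes matter here: a single -1 makes the element count, and so the allocation, undefined.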

Q: How fast is it? ~70ms average per model for shape inference (136 models in 9.5 seconds total). Detailed breakdown:

  • ONNX loading: ~110ms per model
  • Shape inference: ~70ms per model
  • Complex models (ViT, GANs): <200ms
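
The per-model average follows from the totals: 9.5 s / 136 is roughly 70 ms. To take a comparable measurement on your own models, a minimal timing harness might look like the following. Here run_inference is a placeholder for whatever callable runs shape inference on one model; it is not part of ShapeONNX, and the dummy workload only stands in for real inference.

```python
import time


def average_ms(run_inference, models, repeats=1):
    """Mean wall-clock time per call, in milliseconds."""
    start = time.perf_counter()
    for _ in range(repeats):
        for model in models:
            run_inference(model)
    elapsed = time.perf_counter() - start
    return 1000.0 * elapsed / (repeats * len(models))


# Sanity check of the quoted average: 9.5 s over 136 models.
print(round(9.5 / 136 * 1000, 1))  # 69.9

# Dummy workload standing in for real shape inference.
avg = average_ms(lambda n: sum(range(n)), [10_000, 20_000], repeats=3)
print(f"{avg:.3f} ms per call")
```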

Part 5: Verification Integration

Typical workflow:

PyTorch/TensorFlow model
  ↓ export
ONNX model (contains dynamic shapes, shape operators)
  ↓ ShapeONNX (infer shape values)
Static shape dictionary (e.g., output = [1, 1000])
  ↓ SlimONNX (optional optimization)
Simplified graph (redundant ops removed)
  ↓ Verification tool (CROWN, DeepPoly)
Safety proof or counterexample

What ShapeONNX does: Computes shape tensor values (the explicit_shapes dict) and uses them to resolve static tensor shapes (the data_shapes dict). The verification tool uses the resolved shapes to allocate bounds for each neuron.

Real impact: Vision Transformer models (former bottleneck: dynamic shapes from patch embedding) now verify successfully.

Limitations:

  • Assumes batch size = 1 (or user-specified static input shape)
  • Truly dynamic shapes (e.g., process arbitrary image sizes) require symbolic analysis
  • For typical verification use cases (fixed input domain), static shapes are sufficient
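
Pinning the batch dimension before inference can be sketched as follows. The make_static helper is hypothetical and only illustrates the assumption stated above: a symbolic name (ONNX's dim_param, a string), a missing value, or -1 in the leading position is replaced by a fixed batch size, and any other dynamic dimension is rejected.

```python
def make_static(dims, batch_size=1):
    """Replace a symbolic or unknown leading dimension with `batch_size`."""
    dims = list(dims)

    def is_dynamic(d):
        return d is None or isinstance(d, str) or d < 0

    if dims and is_dynamic(dims[0]):
        dims[0] = batch_size
    leftover = [d for d in dims if is_dynamic(d)]
    if leftover:
        raise ValueError(f"non-batch dynamic dimensions remain: {dims}")
    return dims


print(make_static(["N", 3, 224, 224]))      # [1, 3, 224, 224]
print(make_static([-1, 10], batch_size=4))  # [4, 10]
```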

Try It Yourself

pip install onnx==1.17.0 numpy==2.2.4

import onnx
from shapeonnx import infer_onnx_shape
from shapeonnx.utils import get_initializers, get_input_nodes, get_output_nodes

# Load model
model = onnx.load("your_model.onnx")
model = onnx.version_converter.convert_version(model, target_version=21)

# Infer shapes
initializers = get_initializers(model)
input_nodes = get_input_nodes(model, initializers, has_batch_dim=True)
output_nodes = get_output_nodes(model, has_batch_dim=True)

shapes = infer_onnx_shape(
    input_nodes, output_nodes,
    list(model.graph.node), initializers,
    has_batch_dim=True
)

# Use static shapes in verification
print(f"Output shape: {shapes['output']}")  # [1, 1000] not [-1, -1]

Conclusion

ShapeONNX is not a criticism of ONNX—ONNX’s design for dynamic shapes is correct for deployment. ShapeONNX is a complementary tool for verification workflows that require static dimensions.

Key insight: Track shape tensor values separately from regular tensor metadata.

Result: Shapes that ONNX must mark as dynamic are resolved to concrete static values, enabling verification on models with computed shapes.

If you’re building verification tools, ShapeONNX bridges the gap between ONNX’s flexible dynamic shapes and verification’s need for concrete dimensions.

Repository: https://github.com/ZhongkuiMa/shapeonnx