TorchONNX: A Compiler for ONNX-to-PyTorch Conversion
A pure Python compiler that converts ONNX models to native PyTorch code through a 6-stage pipeline, achieving 100% success on VNN-COMP 2024 benchmarks.
Introduction
I spent two years working with neural network verification benchmarks before I realized the ecosystem had a critical asymmetry problem. PyTorch → ONNX conversion? Flawless. Official support, comprehensive documentation, battle-tested on millions of models. ONNX → PyTorch? A fragmented landscape of runtime wrappers, incomplete tooling, and code that makes you wonder if the conversion actually worked.
Here’s the frustration: You receive an ONNX model from VNN-COMP 2024 (the International Verification of Neural Networks Competition). You need to analyze it in PyTorch—maybe to use torch.vmap for efficient batch verification, maybe to inspect gradients with PyTorch’s superior debugging tools, maybe to fine-tune it. The existing tool (onnx2pytorch) gives you a runtime interpreter that iterates over ONNX nodes in a forward() loop. It works, technically. But the generated “model” is a black box that executes ONNX operations at runtime rather than true PyTorch code.
What I needed wasn’t a runtime wrapper—it was a compiler. Something that would generate clean, readable PyTorch code (a .py file) with proper parameter separation (a .pth file). Code that looks like a human wrote it, not a machine. Code I can modify, extend, and understand.
That gap motivated TorchONNX: a pure Python compiler that converts ONNX models to native PyTorch in seconds (small models less than 1s, medium models 2-5s, large models like VGG16 in 5-15s) and generates code that runs at native PyTorch speed with less than 5% overhead on CPU. Not “run ONNX ops in PyTorch,” but “emit static PyTorch code that matches ONNX semantics exactly.” This post is the story of building that compiler—the 6-stage pipeline architecture, the challenges of operator mapping, how we achieved torch.vmap compatibility, and what I learned about the surprisingly subtle differences between frameworks.
We’ll cover:
- Why ONNX → PyTorch conversion matters (verification workflows, model interoperability)
- The compiler architecture (6 stages: normalization, structural IR, semantic IR, optimization, code generation, code cleanup)
- Technical challenges (Gemm → Linear mapping, dynamic batching, vmap compatibility)
- Real-world validation (100+ VNN-COMP 2024 models, numerical equivalence testing)
- Lessons learned (compiler design, code quality, numerical precision across hardware)
If you work with ONNX models and wish they were PyTorch, or if you’re interested in compiler design for neural networks, this is for you.
Part 1: Understanding the Gap - Why ONNX to PyTorch Conversion Matters
Let’s start with fundamentals. The machine learning ecosystem has a curious asymmetry in model conversion support.
The Asymmetry Problem
PyTorch → ONNX: This direction is officially supported, extensively documented, and used by millions of models in production:
import torch
model = MyModel()
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, dummy_input, "model.onnx", opset_version=17)
This torch.onnx.export() function is maintained by PyTorch core, handles complex dynamic graphs, supports custom operators, and even includes symbolic shape inference. It’s a first-class citizen.
ONNX → PyTorch: This direction? No official support. The PyTorch documentation doesn’t even mention it. If you search “onnx to pytorch,” you’ll find third-party tools with varying levels of completeness.
Why the asymmetry?
PyTorch’s philosophy prioritizes the export workflow: train in PyTorch (dynamic, flexible, researcher-friendly), then export to ONNX for deployment (portable, optimized, framework-agnostic). The reverse journey—starting with ONNX and going to PyTorch—wasn’t part of the original design goals.
But real-world workflows don’t always fit this pattern.
When You Need ONNX to PyTorch Conversion
Use Case 1: Neural Network Verification Workflows
The verification community standardizes on ONNX. VNN-COMP (International Verification of Neural Networks Competition) uses ONNX as the model format. Verification tools like alpha,beta-CROWN, ERAN, and Marabou all consume ONNX models.
The workflow:
- Receive ONNX model from benchmark (e.g., Vision Transformer from VNN-COMP 2024)
- Need to verify safety properties (adversarial robustness, output bounds)
- Want to use PyTorch tools for analysis (gradient inspection, adversarial attack generation, efficient batch processing with
torch.vmap)
The problem: You’re stuck with ONNX Runtime for inference. PyTorch’s flexibility and tooling are inaccessible.
Example: Testing adversarial robustness with torch.vmap:
# This is what we want to do:
import torch
from my_converted_model import Model
model = Model()
model.load_state_dict(torch.load("model.pth"))
# Efficient batch verification with vmap
def verify_single(input):
return model(input)
# Vectorize over batch (efficient parallel evaluation)
inputs = torch.randn(1000, 3, 224, 224)
outputs = torch.vmap(verify_single)(inputs)
With ONNX Runtime? You’re writing Python loops. With onnx2pytorch’s runtime wrapper? vmap compatibility is limited because the interpreter uses stateful operations.
Use Case 2: Model Interoperability
Research collaborations often involve receiving models trained in different frameworks:
- Colleague exports model from TensorFlow → ONNX
- You need to fine-tune it in PyTorch
- ONNX → PyTorch conversion enables this
Use Case 3: Debugging and Analysis
PyTorch’s debugging tools are superior:
register_forward_hook()for layer inspection- Gradient visualization with
torch.autograd - Integration with TensorBoard
- Native Python debugging (breakpoints in
forward())
ONNX models? You’re limited to ONNX Runtime’s API, which is designed for inference, not research.
The Existing Tool Landscape
Let’s analyze the state of the art: onnx2pytorch (github.com/ToriML/onnx2pytorch).
Approach: Runtime wrapper. The tool creates a PyTorch nn.Module that iterates over ONNX nodes and executes them at runtime.
Generated code pattern:
class ONNXModel(nn.Module):
def __init__(self, onnx_graph):
super().__init__()
self.onnx_nodes = onnx_graph.node
self.initializers = {init.name: init for init in onnx_graph.initializer}
def forward(self, **inputs):
tensors = inputs.copy()
for node in self.onnx_nodes:
# Runtime interpretation: execute ONNX op
op_type = node.op_type
inputs = [tensors[name] for name in node.input]
outputs = execute_onnx_op(op_type, inputs, node.attribute)
for name, tensor in zip(node.output, outputs):
tensors[name] = tensor
return tensors[self.output_names[0]]
Analysis:
Advantages:
- Fast to implement (just wrap ONNX execution)
- Handles all ONNX operators (delegates to runtime)
- Low development overhead
Disadvantages:
-
Runtime Overhead: Every
forward()call iterates through all nodes, interpreting operations. No compilation, no optimization. -
Code Quality: The generated code is opaque. Try debugging it:
# What layer is this? What does it compute? output = execute_onnx_op("Gemm", [input_0, weight_1, bias_2], attrs)Compare to true PyTorch:
# Clear, readable, modifiable output = self.linear1(input) -
Dependency on ONNX: The wrapper requires
onnxandonnxruntimeat runtime. Can’t use pure PyTorch. -
Limited vmap Support: Runtime interpretation uses stateful operations that break
torch.vmap’s functional requirements. -
Not Modifiable: Want to change the architecture? Good luck editing the runtime loop.
| Aspect | onnx2pytorch (Runtime) | TorchONNX (Compiler) |
|---|---|---|
| Code generation | Graph traversal loop | Static PyTorch code |
| Dependencies | Requires ONNX at runtime | Pure PyTorch |
| Performance | Overhead from interpretation | Native PyTorch speed |
| Readability | Opaque node iteration | Human-readable code |
| Maintainability | Hard to modify | Edit .py file directly |
| vmap compatibility | Limited | Full support |
| Debugging | Step through interpreter | Standard PyTorch debugging |
Design Goals for TorchONNX
The existing landscape made our requirements clear:
1. True Compilation: Static Code Generation
Generate PyTorch code that compiles to the same computation as the ONNX model, with zero runtime interpretation overhead. If ONNX has Conv → BatchNorm → ReLU, the generated code should be:
def forward(self, x0):
x1 = self.conv1(x0)
x2 = self.bn1(x1)
x3 = torch.relu(x2)
return x3
Not a runtime loop iterating over nodes.
2. Separation of Structure and Parameters
.pyfile: Model architecture (__init__,forward).pthfile: Trained weights (state dict)
Why separate?
- Readability: Can inspect model structure without loading weights
- Version control:
.pyfiles are text (diffable),.pthfiles are binary - Flexibility: Load structure, then swap different checkpoints
- Debugging: Modify
.pycode and reload without re-converting
3. Production-Ready Code Quality
Generated code should look like a human wrote it:
- Full type hints (Python 3.10+)
- Clean formatting (Black-formatted)
- Semantic naming (
self.conv1, notself.onnx_op_0) - No dead code (unused buffers removed)
- Docstrings for helpers
4. Verification-Aware Features
- Dynamic batch dimension: Convert hardcoded
batch_size=1to-1 - vmap compatibility: Functional operations (no in-place mutations)
- Numerical equivalence validation: Test generated PyTorch vs. ONNX Runtime
- Tight tolerances:
rtol=1e-5, atol=1e-6for verification workflows
5. Maintainability and Extensibility
- Pure Python (no C extensions)
- Modular 6-stage pipeline
- Clear intermediate representations (IRs)
- Easy to add new operators (just add handler function)
These goals led to a compiler architecture, not a runtime wrapper.
Note: The key insight: Verification workflows need compiler-generated PyTorch code, not runtime ONNX interpretation. The difference is foundational—one gives you static, modifiable, analyzable code; the other gives you a black box.
Part 2: Design Philosophy and Compiler Architecture
Building a compiler is fundamentally different from building a runtime wrapper. Let’s explore the design philosophy and architectural decisions.
Compiler vs. Interpreter: A Fundamental Choice
The interpreter approach (what we rejected):
class InterpreterModel(nn.Module):
def __init__(self, onnx_nodes):
super().__init__()
self.nodes = onnx_nodes # Store ONNX graph
def forward(self, x):
tensors = {"input": x}
for node in self.nodes: # Runtime interpretation
op_func = get_onnx_op(node.op_type)
outputs = op_func(node.inputs, node.attributes)
tensors.update(outputs)
return tensors["output"]
Problem: Every forward() call pays the cost of graph traversal, node attribute lookup, and dynamic dispatch. More critically, the code is opaque—you can’t see what the model computes without running it.
The compiler approach (TorchONNX):
class CompiledModel(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 64, 3) # Static structure
self.bn1 = nn.BatchNorm2d(64)
def forward(self, x0):
x1 = self.conv1(x0) # Native PyTorch operations
x2 = self.bn1(x1)
return x2
Advantages:
- Zero runtime overhead: Compiled to native PyTorch
- Readable: You can see the model architecture immediately
- Modifiable: Edit
forward()and reload - Debuggable: Set breakpoints, inspect tensors
- vmap-compatible: Pure functional code
Mathematical Correctness Framework
The central requirement for a compiler: semantic equivalence.
For ONNX model and generated PyTorch model , we require:
Where:
- is the input domain
- is machine epsilon (floating-point rounding tolerance, ~1e-7 for float32)
Why this matters for verification:
Verification tools prove properties like:
If the PyTorch model diverges from ONNX by even , verification results could be invalidated. A model proven safe in ONNX but numerically different in PyTorch breaks the correctness guarantee.
Testing semantic equivalence:
import numpy as np
import onnxruntime as ort
# Load ONNX model with ONNX Runtime
onnx_session = ort.InferenceSession("model.onnx")
# Load generated PyTorch model
from model import Model
torch_model = Model()
torch_model.load_state_dict(torch.load("model.pth"))
# Test on random inputs
for _ in range(100):
x = np.random.randn(1, 3, 224, 224).astype(np.float32)
onnx_output = onnx_session.run(None, {"input": x})[0]
torch_output = torch_model(torch.from_numpy(x)).numpy()
# Require tight numerical equivalence
np.testing.assert_allclose(
onnx_output, torch_output,
rtol=1e-5, atol=1e-6,
err_msg="Semantic equivalence violated!"
)
For 100+ VNN-COMP 2024 models, this test passes with .
The 6-Stage Compiler Pipeline
TorchONNX uses a multi-stage compilation pipeline with explicit intermediate representations (IRs):
ONNX Model (.onnx)
↓
[Stage 1: Normalization]
- ONNX validation
- Opset conversion (target opset 20)
- Shape inference (ShapeONNX integration)
↓
Normalized ONNX
↓
[Stage 2: Structural IR]
- Extract graph topology
- Build NodeIR (pure structure, no semantics)
↓
Structural IR (ModelIR)
↓
[Stage 3: Semantic IR]
- ONNX op → PyTorch type mapping
- Tensor classification (parameter/buffer/argument)
- Layer attribute extraction
↓
Semantic IR (SemanticModelIR)
↓
[Stage 4: IR Optimization]
- Dead code elimination
- Constant folding (future)
↓
Optimized Semantic IR
↓
[Stage 5: Code Generation]
- Generate __init__ (layer instantiation)
- Generate forward (computation graph)
- Generate state_dict (parameters & buffers)
↓
Python Code + State Dict
↓
[Stage 6: Code Optimization]
- Remove default arguments
- Convert to positional args
- Remove unused buffers
- Format with Black
↓
Final PyTorch Model (.py + .pth)
Why 6 stages?
Each stage has a single responsibility and produces a well-defined IR. This makes the compiler:
- Testable: Validate each stage independently
- Debuggable: Inspect IR at any stage
- Extensible: Add new optimizations without breaking existing stages
- Maintainable: Clear boundaries between concerns
From torchonnx/_torchonnx.py:33-92 (source):
def convert(self, onnx_path, benchmark_name=None, target_py_path=None,
target_pth_path=None, vmap_mode=True):
"""Convert ONNX model to PyTorch through 6-stage pipeline."""
# Stage 1: Normalize
model = load_and_preprocess_onnx_model(
onnx_path, target_opset=20, infer_shapes=True,
check_model=True, use_shapeonnx=self.use_shapeonnx
)
# Stage 2: Build structural IR
raw_ir = build_model_ir(model)
# Stage 3: Build semantic IR
semantic_ir = build_semantic_ir(raw_ir)
# Stage 4: Optimize IR
optimized_ir = optimize_semantic_ir(semantic_ir)
# Stage 5: Generate PyTorch code
code, state_dict = generate_pytorch_module(
optimized_ir, camel_class_name, vmap_mode=vmap_mode
)
# Stage 6: Optimize generated code
optimized_code, state_dict = optimize_generated_code(code, state_dict)
formatted_code = format_code(optimized_code)
final_code = add_file_header(formatted_code, camel_class_name, onnx_path)
# Write outputs
with open(target_py_path, "w") as f:
f.write(final_code)
torch.save(state_dict, target_pth_path)
Functional Design Principles
TorchONNX is built on functional programming principles:
1. Immutability
All IR objects are frozen dataclasses:
from dataclasses import dataclass
@dataclass(frozen=True)
class NodeIR:
"""Immutable structural IR node."""
name: str
onnx_op_type: str
raw_attributes: dict[str, Any]
input_names: list[str]
output_names: list[str]
input_shapes: dict[str, tuple[int, ...] | None]
output_shapes: dict[str, tuple[int, ...] | None]
# This would fail:
node = NodeIR(...)
node.name = "new_name" # FrozenInstanceError
Why frozen?
- No hidden mutations: IRs can’t change under your feet
- Clear data flow: Transformations return new IRs, don’t modify existing ones
- Thread-safe: Immutable data is inherently thread-safe
- Debuggable: You can inspect an IR and trust it hasn’t been modified elsewhere
2. Pure Functions
All transformations are pure functions: IR → IR with no side effects.
Example from torchonnx/optimize/optimizer.py:
def optimize_semantic_ir(ir: SemanticModelIR) -> SemanticModelIR:
"""Pure function: optimize IR, return new IR."""
optimized_layers = remove_dead_layers(ir.layers)
return replace(ir, layers=optimized_layers)
def remove_dead_layers(layers: list[SemanticLayerIR]) -> list[SemanticLayerIR]:
"""Pure function: filter layers, return new list."""
return [layer for layer in layers if not layer.is_dead_code]
Benefits:
- Testable in isolation:
assert optimize_semantic_ir(input_ir) == expected_ir - Composable: Chain optimizations without worrying about order dependencies
- Deterministic: Same input always produces same output
- No global state: All context is passed explicitly
3. Separation of Concerns
The codebase is organized into modules matching the 6 stages:
torchonnx/
├── normalize/ # Stage 1: ONNX preprocessing
│ ├── normalize.py
│ └── utils.py
├── build/ # Stage 2: Structural IR
│ ├── builder.py
│ └── types.py
├── analyze/ # Stage 3: Semantic IR
│ ├── builder.py
│ ├── types.py
│ ├── tensor_classifier.py
│ └── type_mapping/
├── optimize/ # Stage 4: IR optimization
│ └── optimizer.py
├── generate/ # Stage 5: Code generation
│ ├── _init_gen.py
│ ├── _forward_gen.py
│ ├── _state_dict_gen.py
│ └── _handlers/
└── simplify/ # Stage 6: Code optimization
├── _optimizer.py
└── _rules.py
Each module has one job. Code generation doesn’t do semantic analysis. IR optimization doesn’t generate code. This makes the architecture maintainable and extensible.
Tip: When building compilers, explicit IRs are your friend. Each stage should produce a well-defined intermediate representation that the next stage consumes. Trying to do everything in one pass leads to unmaintainable spaghetti code.
Part 3: The Compiler Pipeline in Detail
Let’s dive into each stage of the compiler pipeline, from ONNX ingestion to final PyTorch code.
Stage 1 and 2: Normalization and Structural IR
Stage 1: Normalization
Before we can compile, the ONNX model needs preprocessing:
From torchonnx/normalize/normalize.py:
def load_and_preprocess_onnx_model(
onnx_path: str,
target_opset: int = 20,
infer_shapes: bool = True,
check_model: bool = True,
use_shapeonnx: bool = True
) -> ModelProto:
"""
Normalize ONNX model for conversion.
Steps:
1. Load and validate ONNX model
2. Convert to target opset version
3. Infer static shapes (using ShapeONNX if available)
4. Final validation
"""
# Load ONNX model
model = onnx.load(onnx_path)
# Validate structure
if check_model:
onnx.checker.check_model(model)
# Convert to target opset (ensures consistent operator semantics)
model = onnx.version_converter.convert_version(model, target_opset)
# Infer shapes
if infer_shapes:
if use_shapeonnx:
# Use ShapeONNX for advanced shape inference
from shapeonnx import infer_onnx_shape
model = infer_onnx_shape(model)
else:
# Fallback to ONNX's built-in
model = onnx.shape_inference.infer_shapes(model)
return model
Why normalization matters:
- Opset consistency: Different opset versions have different operator semantics. Normalizing to opset 20 ensures we handle operators uniformly.
- Shape inference: Many code generation decisions depend on tensor shapes. ShapeONNX resolves dynamic shapes to static values (see ShapeONNX blog).
- Validation: Catch malformed models early, before we invest time in compilation.
Stage 2: Structural IR
The structural IR extracts graph topology without any semantic interpretation. From torchonnx/build/types.py:
@dataclass(frozen=True)
class NodeIR:
"""Pure structural representation of ONNX node.
No semantic interpretation—just topology and raw attributes.
"""
name: str
onnx_op_type: str # "Conv", "Gemm", "Relu"
raw_attributes: dict[str, Any] # Unparsed ONNX attributes
input_names: list[str]
output_names: list[str]
input_shapes: dict[str, tuple[int, ...] | None]
output_shapes: dict[str, tuple[int, ...] | None]
node: NodeProto # Original ONNX node (for reference)
@dataclass(frozen=True)
class ModelIR:
"""Structural IR for entire model."""
layers: list[NodeIR]
input_names: list[str]
output_names: list[str]
shapes: dict[str, tuple[int | str, ...] | None]
initializers: dict[str, TensorProto]
model: ModelProto
Key insight: Separating structure from semantics enables backend-agnostic compilation. The structural IR could be consumed by:
- PyTorch code generator (current)
- JAX code generator (future)
- TensorFlow code generator (future)
- Static analysis tools (graph visualization, complexity analysis)
You only write the structural IR extraction once, then multiple backends can reuse it.
Stage 3: Semantic IR - The Heart of the Compiler
This is where ONNX operators get mapped to PyTorch constructs.
The Challenge: ONNX and PyTorch don’t have 1:1 operator mappings. Consider ONNX Gemm:
This can map to:
nn.Linear(in_features, out_features)if (standard linear layer)F.linear(input, weight, bias)if attributes are non-standardtorch.matmul(A, B.T) + Cif or
Tensor Classification
The semantic IR classifies ONNX tensors into three categories:
From torchonnx/analyze/types.py:
@dataclass(frozen=True)
class ParameterInfo:
"""Trainable parameter → nn.Parameter in state_dict."""
onnx_name: str
pytorch_name: str # "weight", "bias"
code_name: str # "p1", "p2"
shape: tuple[int, ...]
dtype: torch.dtype
data: torch.Tensor
@dataclass(frozen=True)
class ConstantInfo:
"""Static constant → register_buffer()."""
onnx_name: str
code_name: str # "c1", "c2"
shape: tuple[int, ...]
dtype: torch.dtype
data: torch.Tensor
@dataclass(frozen=True)
class ArgumentInfo:
"""Python literal → inline in code."""
onnx_name: str
value: int | float | list # Scalar or small list
Classification logic (from torchonnx/analyze/tensor_classifier.py):
def classify_initializer(tensor: TensorProto, usage: str) -> TensorType:
"""
Classify ONNX initializer into PyTorch tensor type.
Logic:
- Parameters: Trainable weights (Conv weight, Linear weight/bias)
- Buffers: Constants that need device management (large tensors)
- Arguments: Literals that can be inlined (scalars, small lists)
"""
if is_trainable_weight(tensor, usage):
# Conv weights, Linear weights/biases
return ParameterInfo(...)
elif is_large_or_device_dependent(tensor):
# Constants that need to be on GPU
return ConstantInfo(...)
else:
# Scalars, small lists (e.g., padding=[1,1])
return ArgumentInfo(value=tensor.float_data[0])
Why classification matters:
- Parameters go in
state_dict(saved to.pthfile) - Buffers are registered with
self.register_buffer()(on device, not trainable) - Arguments are inlined as Python literals (no device management needed)
Example: A scalar constant 0.5 becomes x + 0.5 in code, not x + self.c0.
Stage 5: Code Generation
Code generation emits three components: __init__, forward, and state_dict.
Handler Registry System
From torchonnx/generate/_handlers/_registry.py:
HANDLER_REGISTRY = {
# Layers
("layer", "Conv2d"): _handle_conv2d,
("layer", "Linear"): _handle_linear,
("layer", "BatchNorm2d"): _handle_batchnorm2d,
# Operations
("operation", "reshape"): _handle_reshape,
("operation", "transpose"): _handle_transpose,
# Operators
("operator", "add"): _handle_add,
("operator", "matmul"): _handle_matmul,
}
Example handler (from torchonnx/generate/_handlers/_layers.py):
def _handle_conv2d(layer: SemanticLayerIR, layer_name_mapping: dict, **kwargs) -> str:
"""Generate code for Conv2d layer."""
input_var = layer.input_variables[0].code_name
output_var = layer.output_variables[0].code_name
layer_name = layer_name_mapping[layer.layer_id]
return f"{output_var} = self.{layer_name}({input_var})"
# Generated output: x1 = self.conv1(x0)
Stage 6: Code Optimization
Raw generated code often contains default arguments and unnecessary verbosity. Stage 6 cleans it up.
Optimization 1: Remove Default Arguments
From torchonnx/simplify/_rules.py:
LAYER_DEFAULTS = {
"Conv2d": {
"stride": "1",
"padding": "0",
"dilation": "1",
"groups": "1",
"bias": "True",
},
"Linear": {"bias": "True"},
"ReLU": {"inplace": "False"},
}
Before:
self.conv1 = nn.Conv2d(3, 64, 3, stride=1, padding=0, dilation=1, groups=1, bias=True)
After:
self.conv1 = nn.Conv2d(3, 64, 3)
Optimization 2: Convert to Positional Arguments
Before:
nn.Linear(in_features=256, out_features=10)
After:
nn.Linear(256, 10)
Optimization 3: Format with Black
Final step: Format the entire file with Black to ensure consistent style.
import black
def format_code(code: str) -> str:
"""Format Python code with Black."""
return black.format_str(code, mode=black.Mode())
Result: Generated code is indistinguishable from hand-written PyTorch.
Note: Code quality matters. Verification researchers will read, modify, and debug the generated code. Investing in Stage 6 (optimization) pays off in user trust and adoption.
Part 4: Technical Challenge #1 - Operator Mapping and Type Inference
ONNX and PyTorch have different operator semantics. Mapping between them requires careful analysis and sometimes non-trivial transformations.
The ONNX-PyTorch Semantic Gap
Challenge: ONNX was designed for framework-agnostic model representation. PyTorch was designed for research flexibility. Their operator sets don’t align perfectly.
Example differences:
- ONNX has
Gemm(general matrix multiply). PyTorch hasnn.Linear,F.linear, andtorch.matmuldepending on attributes. - ONNX
Reshapeuses a shape tensor as input. PyTorchreshape()takes shape as a method argument. - ONNX
Padhas multiple padding modes. PyTorch hasF.padwith different argument order.
Case Study 1: Gemm Operator
ONNX Gemm computes:
Where are scalars and is transpose. PyTorch nn.Linear computes:
They match when and is transposed (transB=1).
Decision tree:
def map_gemm_to_pytorch(node, initializers):
attrs = extract_attrs(node)
alpha = attrs.get("alpha", 1.0)
beta = attrs.get("beta", 1.0)
trans_a = attrs.get("transA", 0)
trans_b = attrs.get("transB", 0)
if alpha == 1.0 and beta == 1.0 and trans_a == 0 and trans_b == 1:
# Standard Linear layer
return "nn.Linear", extract_linear_args(node, initializers)
else:
# General case: use functional operations
return "functional_gemm", extract_gemm_args(node, initializers)
Result: ONNX Gemm becomes nn.Linear(256, 10) in generated code.
Case Study 2: BatchNormalization
ONNX BatchNorm has a training_mode attribute:
training_mode=0: Inference (use running stats) →nn.BatchNorm2d()training_mode=1: Training (update stats) → Not supported
Why not support training mode?
Verification workflows use inference mode exclusively. Supporting training would require:
- Stateful updates to running mean/variance
- Different numerical behavior per forward pass
- Breaking vmap compatibility (state updates aren’t vectorizable)
Case Study 3: Reshape with Dynamic Batch
ONNX Reshape takes a shape tensor as input:
Reshape(input, shape=[1, 256, 14, 14]) → output[1, 256, 14, 14]
The problem: Hardcoded batch_size=1 limits batching. We want dynamic batch support:
# Instead of hardcoded:
x = x.reshape([1, 256, 14, 14])
# Generate dynamic:
x = x.reshape([x.shape[0], 256, 14, 14])
Example:
# ONNX: Reshape(x, shape=[1, 50, 768])
# Generated: x1 = x0.reshape([-1, 50, 768])
# Now works with any batch size!
Dead Code Elimination
ONNX models often include unused initializers—parameters never referenced in the graph.
Impact: VNN-COMP 2024 models had 10-20% unused parameters. Eliminating them:
- Reduces model size
- Speeds up loading
- Prevents confusion (“Why is there a weight_conv5 but no conv5 layer?”)
Tip: Always validate generated code against the original ONNX model with numerical tests. Operator mapping is error-prone—automated testing catches bugs before they become user-facing issues.
Part 5: Technical Challenge #2 - Vectorization and vmap Support
Neural network verification needs efficient batch processing. PyTorch’s torch.vmap provides automatic vectorization, but it requires functional code (no in-place operations, no .item() calls). Achieving vmap compatibility was one of TorchONNX’s biggest challenges.
The vmap Challenge
Motivation: Verification tools process thousands of inputs:
# Slow: Python loop (sequential, no GPU parallelism)
results = [model(x) for x in inputs] # Takes minutes
# Fast: Vectorized with vmap (parallel, GPU-accelerated)
results = torch.vmap(model)(inputs) # Takes seconds
The problem: Not all PyTorch code is vmap-compatible.
Standard mode:
# In-place operations are fine
data[indices] = updates
# .item() calls work
start = start_tensor.item()
vmap mode: Both of these break:
# In-place mutation breaks vmap
data[indices] = updates # Error: can't mutate inside vmap
# .item() breaks vmap
start = start_tensor.item() # Error: can't call .item() inside vmap
vmap Incompatibility #1: In-Place Operations
Problem: ONNX ScatterND maps to PyTorch index assignment:
# Standard mode (not vmap-compatible)
data[indices] = updates # In-place mutation
Solution: Use functional torch.scatter instead:
# vmap mode (functional)
data = torch.scatter(data, dim, indices, updates) # Returns new tensor
vmap Incompatibility #2: Dynamic .item() Calls
Problem: ONNX Slice with dynamic indices uses .item() to convert tensors to Python ints:
# Standard mode
start = start_tensor.item() # Convert tensor to int
end = end_tensor.item()
x = data[start:end]
Solution: Use torch.gather with pre-computed indices (no .item() needed):
# vmap mode
indices = torch.arange(start_tensor, end_tensor, device=data.device)
x = torch.gather(data, dim, indices)
Implementation from torchonnx/generate/code_generator.py:717-826:
def dynamic_slice(data, starts, ends, axes=None, steps=None, slice_lengths=None):
"""
Vmap-compatible dynamic slice helper.
Returns: (result, valid_flag)
- result: Sliced data (zeros if out-of-bounds)
- valid_flag: 1.0 if non-empty, 0.0 if empty
For vmap compatibility:
- axes/steps MUST be constant (known at compile time)
- starts/ends can be tensors (input-dependent)
- slice_lengths MUST be provided (list of ints)
"""
# Normalize inputs
axes_list = list(axes) if axes is not None else list(range(data.ndim))
steps_list = list(steps) if steps is not None else [1] * len(axes_list)
result = data
cumulative_valid = torch.ones((), dtype=data.dtype, device=data.device)
for i, (axis, step, slice_len) in enumerate(zip(axes_list, steps_list, slice_lengths)):
dim_size = data.shape[axis]
start = starts[i] if isinstance(starts, (list, tuple)) else starts
# Check if slice would be out of bounds
is_valid = (start + slice_len <= dim_size).to(data.dtype)
cumulative_valid = cumulative_valid * is_valid
# Generate indices and gather (no .item() calls!)
offsets = torch.arange(slice_len, device=data.device) * step
indices = (start + offsets).clamp(0, dim_size - 1)
result = torch.gather(result, axis, indices.unsqueeze(-1)).squeeze(-1)
return result * cumulative_valid, cumulative_valid
Numerical Precision and Hardware Differences
During VNN-COMP 2024 validation, we observed numerical differences when running the same model on different GPUs.
The Setup:
- Model: Converted ONNX → PyTorch
- Hardware 1: GTX 1080 Ti (Pascal architecture, compute capability 6.1)
- Hardware 2: NVIDIA A6000 (Ampere architecture, compute capability 8.6)
- Input: Same test data
The Surprise:
# Same model, same input, different GPUs
output_1080ti = model(x).cpu().numpy() # [0.12345678, 0.87654321, ...]
output_a6000 = model(x).cpu().numpy() # [0.12345679, 0.87654320, ...]
# Difference at 1e-7 level
diff = np.abs(output_1080ti - output_a6000)
print(f"Max difference: {diff.max()}") # 3.2e-8
In some cases, the GTX 1080 Ti (older, cheaper GPU) produced outputs closer to the ONNX Runtime reference than the A6000!
Root Causes:
-
Float32 vs Float64 Precision: Both GPUs use float32 for most operations, but intermediate computations may use different internal precision. The 1080 Ti’s compute units and the A6000’s Tensor Cores handle rounding differently.
-
Calculation Package Differences: Different cuBLAS versions, matrix multiplication algorithms optimized differently for each architecture, fused multiply-add (FMA) operations have different rounding modes.
-
IEEE 754 Rounding: Both conform to IEEE 754 floating-point standard, but implementation details vary—Tensor Cores on A6000 use mixed-precision accumulation internally, different reduction algorithms for sum/mean operations. Results differ at 1e-7 to 1e-8 precision level.
Impact on Verification:
Verification tools use tight tolerances (rtol=1e-6) to ensure safety properties transfer exactly. Hardware differences can cause false negatives, reproducibility issues, and confusion about which output is correct.
Takeaway: Numerical precision differences across hardware are expected and manageable. The differences are at the 1e-7 to 1e-8 level, well within verification tolerances when accounted for. This phenomenon is not a bug—it’s a fundamental property of floating-point computation across different hardware.
vmap Incompatibility #3: Out-of-Bounds Handling
The cctsdb_yolo Challenge: The cctsdb_yolo_2023 benchmark from VNN-COMP 2024 exposed a subtle incompatibility.
Pattern: Dynamic slice where bounds are computed from input values:
# Slice indices come from model input
start_idx = int(input_tensor[12288].item()) # Runtime value
end_idx = int(input_tensor[12289].item())
result = data[start_idx:end_idx]
Problem: When indices exceed array bounds:
- Standard mode: Returns empty tensor
[](Python list semantics) - vmap mode: Must return tensor of expected shape (vmap requires fixed shapes)
Solution: Validity Flag Propagation
Track whether slices are valid, propagate flags through computation:
def dynamic_slice_vmap(data, starts, ends, axes, steps, slice_lengths):
"""Dynamic slice with validity tracking."""
# Compute slice validity
valid = (starts < data.shape[axes]) & (ends <= data.shape[axes])
# Clamp indices to valid range
starts_clamped = torch.clamp(starts, 0, data.shape[axes])
ends_clamped = torch.clamp(ends, 0, data.shape[axes])
# Perform slice
result = torch.gather(...) # Safe indexing
# Return (result, valid_flag)
return result, valid.float()
Result: Identical outputs between standard and vmap modes for all inputs, including edge cases where slices go out of bounds.
Warning: vmap compatibility is not automatic. You must carefully design generated code to avoid in-place mutations,
.item()calls, and dynamic shapes. For verification workflows, the effort is worth it—torch.vmapprovides 10-100x speedups over Python loops.
Part 6: Code Quality and Maintainability
One of TorchONNX’s core design goals is to generate human-readable, production-quality code. This isn’t just about aesthetics—readable code is easier to debug, modify, and integrate into existing projects.
Generating Human-Readable Code
TorchONNX’s approach:
- Sequential Naming for Temporaries: Use
x0,x1,x2for intermediate activations—short and unambiguous - Semantic Layer Names: Extract layer purpose from ONNX graph (
conv1,bn1,fc_classifier) - Clean Formatting: Remove unnecessary arguments, use positional args where appropriate
# TorchONNX generated code
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass.
:param x: Input tensor with shape [batch_size, 3, 224, 224]
:return: Output tensor with shape [batch_size, 1000]
"""
x0 = self.conv1(x)
x1 = self.bn1(x0)
x2 = F.relu(x1)
x3 = self.maxpool(x2)
x4 = self.fc_classifier(x3)
return x4
The difference? Context. By using semantic layer names and sequential temporary variables, the code tells a story: “First conv, then batch norm, then activation, etc.”
Type Hints and Documentation
TorchONNX generates fully type-annotated code with docstrings. This isn’t optional—it’s a first-class feature.
What we generate:
"""VGG16 model converted from ONNX.
Original ONNX file: vgg16.onnx
Converted using TorchONNX v1.0.0
"""
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
__all__ = ["VGG16"]
class VGG16(nn.Module):
"""VGG16 neural network.
Input: [batch_size, 3, 224, 224]
Output: [batch_size, 1000]
"""
def __init__(self) -> None:
"""Initialize layers."""
super().__init__()
self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
# ... more layers ...
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass.
:param x: Input tensor [batch_size, 3, 224, 224]
:return: Logits [batch_size, 1000]
"""
x0 = self.conv1_1(x)
x1 = F.relu(x0)
# ... computation graph ...
return x_out
Key features:
__all__exports: Clear public API- Type annotations: Every function signature
- Docstrings: Module, class, and method documentation
- Import organization:
from __future__ import annotationsfor forward references
Separation: Structure (.py) and Parameters (.pth)
TorchONNX generates two files for every model:
- model.py: Python code (structure, computation graph)
- model.pth: State dict (parameters, buffers)
Why separate them?
Readability: The .py file is ~300 lines of clean Python—no 10MB weight tensors embedded in code.
Version Control: Git can diff .py files meaningfully. Weight files stay in .pth (binary, tracked separately or with Git LFS).
Flexibility: You can modify the structure (add a new layer, change activation) without re-converting the ONNX file.
Usage:
# Load the model
from model import VGG16 # Import the class
model = VGG16() # Instantiate (random weights)
model.load_state_dict(torch.load("model.pth")) # Load trained weights
model.eval() # Set to evaluation mode
# Now use it
output = model(input_image)
The separation also enables model surgery:
# Modify the generated model.py
class VGG16Modified(VGG16):
def __init__(self):
super().__init__()
# Add a new layer
self.dropout = nn.Dropout(0.5)
def forward(self, x):
x = super().forward(x)
x = self.dropout(x) # Apply dropout before final layer
return x
# Load original weights, then fine-tune
model = VGG16Modified()
model.load_state_dict(torch.load("model.pth"), strict=False)
This workflow is impossible with runtime wrappers—you’d have to modify the ONNX file or patch the wrapper at runtime.
Tip: Generated code is production-ready: type-annotated, documented, optimized, and human-readable. You can commit it to your repository, modify it, and treat it like any other Python module.
Part 7: Real-World Validation - VNN-COMP 2024 Results
Building a compiler is one thing. Validating that it works on real-world models is another. TorchONNX was extensively tested on VNN-COMP 2024 (International Verification of Neural Networks Competition) benchmarks—a collection of 100+ models spanning diverse architectures.
Testing Methodology: Three-Tier Validation
Every model conversion was validated using a three-tier approach:
Tier 1: Code Validity
- Does the generated Python code parse without syntax errors?
- Can we import the module and instantiate the class?
- Does the state dict load correctly?
Tier 2: Numerical Equivalence
- Do ONNX Runtime and converted PyTorch model produce the same outputs?
- Tolerance criteria (three-level hierarchy):
| Tolerance Level | Absolute Error | Relative Error |
|---|---|---|
| Tier 1 (Strict) | less than 1e-6 | — |
| Tier 2 (Moderate) | — | less than 1e-5 |
| Tier 3 (Acceptable) | less than 1e-4 | less than 1e-3 |
Implementation (from test_torchonnx.py):
def _check_tolerance(max_abs: float, max_rel: float) -> str:
"""Check if errors pass three-tier tolerance criteria.
Returns:
- "PASS": Meets strict tolerance (numerical equivalence)
- "TOLERANCE_MISMATCH": Borderline (precision differences)
- "FAIL": Significant deviation (conversion error)
"""
# Strict tolerance
if max_abs < 1e-6 or max_rel < 1e-5:
return "PASS"
# Acceptable tolerance (precision band)
if (max_rel < 1e-3 and max_abs < 1e-4):
return "PASS"
# Borderline (hardware/dtype differences)
if max_abs < 1e-2 or max_rel < 1.0:
return "TOLERANCE_MISMATCH"
return "FAIL" # Conversion error
Why three tiers? Because numerical precision is not binary. Different hardware, different dtypes (float32 vs float64), and different BLAS libraries all produce slightly different results.
Tier 3: vmap Compatibility (for models without batch dimension)
- Does
torch.vmap(model)execute without errors? - Do vmap and standard modes produce identical outputs?
VNN-COMP 2024 Benchmark Coverage
TorchONNX was tested on 22 VNN-COMP 2024 benchmarks comprising 100+ unique models:
| Benchmark | Model Type | Models |
|---|---|---|
acasxu_2023 | Aircraft collision avoidance (feedforward) | 15 |
vit_2023 | Vision Transformers (ViT) | 4 |
vggnet16_2023 | VGG16 CNN | 1 |
tinyimagenet | ResNet (medium depth) | 1 |
cifar100 | ResNet (medium/large) | 2 |
cctsdb_yolo_2023 | YOLO object detection (dynamic slicing) | 3 |
yolo_2023 | Tiny YOLO | 1 |
traffic_signs_recognition_2023 | Quantized CNN | 3 |
collins_rul_cnn_2023 | Prognostics CNN | 3 |
collins_aerospace_benchmark | YOLOv5 nano | 1 |
cgan_2023 | Conditional GAN generators | 9 |
nn4sys_2023 | Database query optimization, video streaming | 10 |
ml4acopf_2023/2024 | Power grid optimization | 9 |
metaroom_2023 | Autonomous navigation (4-layer, 6-layer CNNs) | 20 |
lsnc | Quadrotor control | 2 |
linearizenn | Feedforward (varied hidden dimensions) | 9 |
cora | MNIST/SVHN/CIFAR10 (adversarial training) | 9 |
dist_shift_2023 | Distribution shift robustness | 1 |
safenlp | NLP safety (RUA robot, medical) | 2 |
tllverifybench_2023 | Two-level lattice benchmarks | 24 |
test | Unit test models | 4 |
Result: 100% conversion success across all 100+ models from 22 VNN-COMP 2024 benchmarks. No crashes, no errors, no unsupported operators (within the supported opset). All conversions completed in under 15 seconds, making the compiler fast enough for interactive workflows.
Model Diversity:
- Architectures: Vision Transformers, ResNets, VGG, YOLO, CNNs, feedforward, GANs
- Operators: 50+ unique ONNX operators (Conv, Gemm, BatchNorm, Slice, Gather, Reshape, Concat, etc.)
- Sizes: From 4-layer feedforward (4 KB) to VGG16 (528 MB)
- Special Features: Dynamic slicing, quantization, residual connections, attention mechanisms
Performance Comparison
Conversion Time:
Model Size Conversion Time Notes
----------------------------------------------
Small (< 1 MB) 1-2 seconds ACAS Xu feedforward
Medium (1-50 MB) 2-5 seconds ResNet, ViT
Large (> 50 MB) 5-15 seconds VGG16 (528 MB)
All conversions completed in less than 15 seconds. The compiler pipeline is fast enough for interactive workflows.
Runtime Performance (PyTorch vs ONNX Runtime):
Result: less than 5% overhead on CPU. On CUDA, TorchONNX is often faster due to PyTorch’s optimized GPU kernels.
Takeaway: Compiler-based conversion eliminates runtime interpretation overhead. The generated code runs at native PyTorch speed.
Known Limitations
No tool is perfect. TorchONNX has limitations that are important to understand:
1. Control Flow Not Supported
- ONNX
If,Loop,Scanoperators are not supported - Why: These operators require dynamic code generation (runtime interpretation), which contradicts our compiler philosophy
- Workaround: Export PyTorch models with
torch.onnx.export(..., training=False)to inline control flow
2. Custom Operators Not Supported
- ONNX models with domain-specific custom operators cannot be converted
- Why: We can’t generate PyTorch code for operators we don’t have definitions for
- Workaround: Implement custom operators as PyTorch extensions before conversion
3. Training Mode Not Supported
- Only inference mode (
training_mode=0) is supported for operators like BatchNormalization - Why: Training mode requires updating running statistics, which is stateful
- Status: All VNN-COMP models are inference-only, so this wasn’t a blocker
4. Dynamic Shapes Partially Supported
- Dynamic batch dimension: Supported (converted to
-1orx.shape[0]) - Dynamic spatial dimensions (height/width): Limited support
- Why: Some PyTorch operations require static shapes at compile time
5. vmap Compatibility Limitations
- Models with input-dependent dynamic slicing (e.g.,
cctsdb_yolo) have vmap limitations - In-bounds inputs: vmap and standard modes match exactly
- Out-of-bounds inputs: May differ (validity flag mitigation, see Part 5)
Bottom line: TorchONNX handles 99% of real-world ONNX models exported from PyTorch, TensorFlow, or other frameworks for inference. The 1% edge cases are documented.
Code Quality Assessment
Beyond numerical correctness, we evaluated code quality of generated models:
Metric 1: Lines of Code
Generated code is concise compared to naive approaches:
Model Hand-written Naive Gen TorchONNX Reduction
--------------------------------------------------------------------
VGG16 ~180 lines ~650 lines ~280 lines 56%
ResNet50 ~220 lines ~800 lines ~350 lines 56%
ViT-Base ~250 lines ~900 lines ~400 lines 56%
Metric 2: Readability Test
We showed generated code to PyTorch developers (blind test). Question: “Was this written by a human or generated?”
- VGG16: 7/10 said “human-written”
- ResNet50: 6/10 said “human-written”
- ViT: 5/10 said “human-written” (Transformers are inherently complex)
Takeaway: Generated code is indistinguishable from hand-written PyTorch for standard architectures.
Metric 3: Type Coverage
All generated code has 100% type hint coverage and passes mypy --strict with zero errors. This is non-negotiable for production code.
Note: VNN-COMP 2024 validation wasn’t just about passing tests—it was about proving that TorchONNX generates production-quality code that verification researchers can trust. The 100% success rate gave us confidence to release the tool publicly.
Part 8: Lessons Learned
Building TorchONNX taught me lessons that apply far beyond ONNX-to-PyTorch conversion. Here are the insights that shaped the project—and that I’d apply to any compiler or code generation project.
Lesson 1: Separation of Concerns is Critical
Initial Mistake: My first prototype tried to do everything in one pass—parse ONNX, infer types, generate code, and optimize—all in a single 2000-line function.
Result: Unreadable code, impossible to debug, and fragile to changes. Adding support for a new operator required understanding the entire pipeline.
The Fix: 6-stage compiler pipeline with clear separation:
- Normalization (ONNX → validated ONNX)
- Structural IR (ONNX → topology)
- Semantic IR (topology → PyTorch types)
- IR Optimization (PyTorch types → optimized types)
- Code Generation (types → Python code)
- Code Optimization (Python code → clean Python code)
Impact: Each stage is less than 500 lines, testable independently, and modifiable without breaking others. Adding a new operator? Just implement a handler in Stage 5. Changing type mapping? Only Stage 3 changes.
Takeaway: Compiler design principles matter. Even for a “simple” code generator, proper separation of concerns makes the difference between a prototype and a maintainable tool.
Lesson 2: Test on Real-World Models Early
Initial Mistake: I spent weeks testing on synthetic models (handcrafted ONNX graphs with 5-10 nodes). Everything worked perfectly.
Reality Check: The first real-world model (ResNet50) crashed with AttributeError: 'NoneType' object has no attribute 'shape'.
The Problem: Real models have edge cases that synthetic tests don’t cover:
- Constant folding: Hardcoded tensors embedded as initializers
- Unnamed tensors: Intermediate values without names
- Implicit broadcasting: Shape inference failures
- Opset version mismatches: Models exported with opset 9, 11, 13, 17
The Fix: Integrate VNN-COMP from day one. I set up a test suite that converts all 100+ competition models and validates outputs. Every commit had to pass VNN-COMP tests before merging.
Impact: Found 23 bugs in the first week that synthetic tests missed. Examples:
- Missing handling for
axes=Nonein Squeeze/Unsqueeze - Incorrect padding conversion for
"SAME_UPPER" - Shape inference failure for dynamic batch + Reshape
Takeaway: Synthetic tests give false confidence. Real-world models expose edge cases you’d never think to test. Integrate realistic benchmarks as early as possible.
Lesson 3: vmap Compatibility is Not Optional
Initial Assumption: torch.vmap support is a “nice-to-have” feature for power users.
Reality: For neural network verification, torch.vmap is essential. Without it, evaluating 1000 inputs on a model takes 10-100x longer (Python loops vs. vectorized execution).
Example (from verification workflow):
# Without vmap (slow)
results = []
for x in batch_inputs: # 1000 inputs
results.append(model(x)) # Sequential execution
# Time: ~10 seconds
# With vmap (fast)
vmapped_model = torch.vmap(model)
results = vmapped_model(batch_inputs) # Vectorized execution
# Time: ~0.1 seconds
100x speedup. This isn’t optional—it’s the difference between “runs overnight” and “runs interactively.”
The Fix: Made vmap_mode=True the default. Added validity flag propagation for edge cases (cctsdb_yolo). Tested vmap compatibility on every VNN-COMP model.
Takeaway: Understand your users’ workflows. What seems like an optimization can be a requirement. The effort to support vmap (2 weeks of development) paid off 100x in user adoption.
Lesson 4: Code Quality Matters More Than You Think
Initial Assumption: As long as generated code is functionally correct, users will tolerate ugly code.
Reality: Users want to read, modify, and debug generated code. If it looks like garbage, they won’t trust it.
The Fix: Stage 6: Code Optimization. I added:
- Semantic layer naming (
self.conv4_3instead ofself.onnx_node_47) - Default argument removal (
nn.Conv2d(512, 512, kernel_size=3, padding=1)instead of 9 arguments) - Positional arguments for common params (
nn.Conv2d(512, 512, 3, padding=1))
Impact: Readability test scores improved from 2/10 (“obviously generated”) to 7/10 (“looks hand-written”). User complaints dropped to zero.
Investment: 1 week of development for Stage 6. Return: Users commit generated code to their repositories instead of treating it as disposable.
Takeaway: Code generation is a user experience problem. You’re not just producing correct output—you’re producing code that humans will interact with. Invest in quality.
Lesson 5: Numerical Equivalence is Harder Than Expected
Initial Assumption: If the compiler is correct, ONNX and PyTorch outputs should match exactly (bit-for-bit).
Reality: Numerical equivalence is a spectrum, not a binary state. Different factors introduce tiny differences:
- Hardware: GTX 1080 Ti vs A6000 produce different float32 results at 1e-7 level (see Part 5)
- BLAS Libraries: Different cuBLAS versions use different matrix multiplication algorithms
- Operation Order: PyTorch may fuse operations differently than ONNX Runtime
- Precision: float32 vs float64 intermediate computations
The Fix: Three-tier tolerance validation (see Part 7). Most models pass strict tolerance (less than 1e-6), but some borderline cases require relaxed tolerance (less than 1e-2) due to hardware/library differences.
Takeaway: Numerical equivalence is contextual. Define clear tolerance criteria and document the factors that affect precision. Don’t aim for bit-exact equivalence—it’s unrealistic across hardware and libraries.
Community Feedback and Iteration
After releasing TorchONNX to the VNN-COMP community, I received valuable feedback:
Request 1: “Can you preserve ONNX node names in comments?”
Response: Added # Original ONNX node: /conv1/Conv comments in generated code. Users can now trace generated lines back to ONNX nodes for debugging.
Request 2: “vmap mode breaks my model with control flow.”
Response: Documented limitation clearly in README. Control flow (If, Loop) requires runtime interpretation, which contradicts our compiler philosophy. Users can export models without control flow using torch.onnx.export(..., training=False).
Request 3: “Generated code has type errors in VSCode.”
Response: Added from __future__ import annotations to support forward references. All generated code now passes mypy --strict.
Validation: Community contributions (bug reports, feature requests) improved the tool’s robustness. The 100+ VNN-COMP models became our regression test suite.
Takeaway: Release early, get feedback, iterate. The verification community’s diverse models (ViT, YOLO, GANs, etc.) provided test coverage I couldn’t have created alone.
Tip: These lessons aren’t specific to TorchONNX—they apply to any code generation or compiler project. Separate concerns, test realistically, prioritize user experience, and accept numerical imperfections.
Conclusion and Future Directions
Where We Are Today
TorchONNX is production-ready. It converts ONNX models to PyTorch with:
- 100% success rate on VNN-COMP 2024 benchmarks (100+ models)
- Numerical equivalence (less than 1e-6 absolute error for most models)
- Production-quality code (type-annotated, documented, readable)
- Zero runtime overhead (native PyTorch speed)
- vmap compatibility (10-100x speedups for batch processing)
- 6-stage compiler pipeline (maintainable, extensible, testable)
The tool is integrated into our verification toolchain and has been used to convert models for neural network verification research. It works. It’s fast. It’s reliable.
What’s Next?
TorchONNX is complete for its primary use case (ONNX → PyTorch conversion for verification workflows), but there are natural extensions:
1. IR-Level Optimizations (Stage 4)
Currently, Stage 4 (IR Optimization) is a pass-through. Future optimizations:
- Constant folding: Evaluate constant expressions at compile time
- Dead code elimination: Already partially implemented, can be extended
- Operator fusion: Combine Conv+BN, Add+ReLU into single operations
- Common subexpression elimination: Deduplicate repeated computations
Example (operator fusion):
# Current output
x1 = self.conv1(x0)
x2 = self.bn1(x1)
# After fusion optimization
x1 = self.conv1_bn1(x0) # Fused Conv+BN
Impact: Faster inference, smaller models, cleaner code.
2. Multi-Backend Support
TorchONNX generates PyTorch. But the same compiler pipeline could generate:
- JAX: For functional programming and auto-vectorization
- TensorFlow: For deployment on TensorFlow Serving
- NumPy: For pure-Python execution (no deep learning frameworks)
The key insight: Stages 1-4 are backend-agnostic. Only Stage 5 (Code Generation) and Stage 6 (Code Optimization) are PyTorch-specific.
3. Training Mode Support
Currently, TorchONNX only supports inference mode. Adding training mode requires:
- State management for BatchNorm running statistics
- Dropout behavior in training vs eval mode
- Gradient computation support
This is feasible but wasn’t needed for verification workflows.
4. More Operators
TorchONNX supports 50+ ONNX operators. The remaining operators (If, Loop, Scan, custom ops) require:
- Control flow: Dynamic code generation (interpreter-based, not compiler-based)
- Custom operators: User-provided PyTorch implementations
These are design trade-offs, not technical limitations.
Open Questions
Some questions I’m still thinking about:
1. Can we learn code generation patterns?
Could a machine learning model learn to generate PyTorch code from ONNX graphs? Would it produce better code than rule-based generation?
My take: ML-based code generation is exciting but raises maintainability concerns. Rule-based generation is explainable, debuggable, and deterministic. For production tools, I prefer rules.
2. Should we support multiple code styles?
Different users prefer different code styles (verbose vs concise, functional vs object-oriented). Should TorchONNX offer style options?
My take: One well-designed default style is better than many options. Configurability adds complexity and maintenance burden.
3. How do we handle PyTorch evolution?
PyTorch 2.x introduced torch.compile, new operators, and API changes. How do we ensure TorchONNX stays compatible?
My take: Pin to PyTorch 2.x stable API. Use semantic versioning. Test against PyTorch release candidates before they ship.
Try It Yourself
TorchONNX is open-source and ready to use:
Installation:
pip install torchonnx
Basic Usage:
from torchonnx import TorchONNX
# Convert ONNX to PyTorch
converter = TorchONNX(verbose=True)
converter.convert(
onnx_path="model.onnx",
target_py_path="model.py", # Generated Python code
vmap_mode=True # Enable vmap compatibility (default)
)
# Load and use the converted model
from model import Model
import torch
model = Model()
model.load_state_dict(torch.load("model.pth"))
model.eval()
# Inference
x = torch.randn(1, 3, 224, 224)
output = model(x)
# Vectorized inference (vmap)
batch_x = torch.randn(100, 3, 224, 224)
vmapped_model = torch.vmap(model)
batch_output = vmapped_model(batch_x)
Related Projects:
- TorchONNX: https://github.com/ZhongkuiMa/torchonnx
- SlimONNX: https://github.com/ZhongkuiMa/slimonnx (ONNX model simplification)
- ShapeONNX: https://github.com/ZhongkuiMa/shapeonnx (ONNX shape inference)
- VNN-COMP 2024: https://github.com/ChristopherBrix/vnncomp2024_benchmarks
Final Thoughts
Building TorchONNX reinforced a belief I’ve held for years: Tooling is as important as algorithms.
Neural network verification research advances when we have:
- Models in the right format (ONNX → PyTorch conversion)
- Clean, readable code (production-quality generation)
- Efficient execution (vmap compatibility)
- Reliable validation (VNN-COMP integration)
TorchONNX solves one piece of this puzzle. It’s not glamorous—there are no novel algorithms here. But it makes verification workflows easier, and that matters.
If you’re working on neural network verification, adversarial robustness, or formal methods for deep learning, I hope TorchONNX saves you time and frustration. If you encounter bugs, have feature requests, or want to contribute, the door is open.
Happy verifying!