TorchONNX: A Compiler for ONNX-to-PyTorch Conversion
A pure Python compiler that converts ONNX models to native PyTorch code through a 6-stage pipeline, achieving 100% success on VNN-COMP 2024 benchmarks.
30-Second Overview
Problem: ONNX→PyTorch conversion tools are runtime wrappers (interpret ONNX ops at inference time). Slow, opaque, unmaintainable.
Solution: TorchONNX is a compiler that generates clean PyTorch code from ONNX in seconds.
Result: 100+ VNN-COMP 2024 models converted, human-readable code, <5% performance overhead, full vmap support.
The Problem: Why Existing Tools Fail
The Asymmetry
PyTorch → ONNX: Official support, works flawlessly. ONNX → PyTorch: No official support, fragmented landscape.
Existing approach (onnx2pytorch):
class ONNXModel(nn.Module):
    def forward(self, x):
        tensors = {"input": x}
        for node in self.onnx_nodes:  # Runtime interpretation
            outputs = execute_onnx_op(node.op_type, ...)
            tensors.update(outputs)
        return tensors["output"]
Problems:
- Runtime overhead: Every forward() iterates through all nodes
- Opaque code: Can’t debug or modify
- Dependency on ONNX: Requires onnxruntime at runtime
- vmap incompatible: Stateful operations break functional requirements
The Speedup That Changes Everything: vmap
Verification evaluates thousands of inputs to find adversarial counterexamples. Batch processing is critical.
Without vmap (runtime wrapper):
# Python loop, sequential
for x in batch_inputs:  # 1000 inputs
    bounds = model(x)  # Interpreted ONNX ops
# Time: 10 seconds
With vmap (compiled code):
# Vectorized execution, parallel GPU
vmapped_model = torch.vmap(model)
bounds = vmapped_model(batch_inputs) # 1000 inputs at once
# Time: 0.1 seconds ← 100x speedup!
Why runtime wrappers can’t be vmapped: They’re stateful (interpret nodes sequentially, store intermediate values). vmap requires functional code (no mutations, pure transformations).
TorchONNX’s generated code is functional, so vmap works on it without modification. A speedup in the 10-100x range makes verification interactive instead of overnight.
The Solution: Compiler Architecture
TorchONNX converts ONNX to clean PyTorch code:
ONNX Model
↓ normalize & infer shapes
Validated ONNX
↓ extract topology (nodes, edges)
Structural IR
↓ map to PyTorch (Conv→Conv2d, Gemm→Linear)
Semantic IR
↓ optimize (dead code, simplify)
Optimized IR
↓ generate code (__init__, forward, state_dict)
Final: model.py + model.pth
Why stages? Each stage is independently testable and reusable. Add a new optimizer pass? Just extend Stage 4.
Technical Challenge: Maintaining vmap Compatibility
The hard part: ONNX operations often use patterns that break vmap (in-place mutations, dynamic indexing with .item()).
Example: Dynamic Slicing
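# ONNX Slice with runtime indices (breaks vmap)
start = start_tensor.item()  # .item() cannot be called inside vmap
end = end_tensor.item()
output = data[start:end]
# TorchONNX solution: use torch.gather instead
# No .item() calls, fully functional
# (start and end are assumed here to be plain ints already, e.g. resolved at conversion time)
indices = torch.arange(start, end)
# torch.gather needs an index with the same number of dimensions as data,
# so the 1-D index above only suits 1-D data
output = torch.gather(data, dim, indices)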
This is what separates TorchONNX from naive approaches. The generated code must be functional or vmap fails.
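To make the idea concrete, here is a minimal runnable sketch of a slice that stays vmap-compatible. It is not taken from TorchONNX's generated code: it assumes the slice bounds are known as plain integers when the code is generated, and it uses torch.index_select rather than torch.gather because index_select accepts a 1-D index directly. The names START, END, and slice_features are hypothetical.

```python
import torch

START, END = 1, 4  # assumption: bounds are known when the code is generated


def slice_features(x: torch.Tensor) -> torch.Tensor:
    """Per-sample equivalent of x[START:END] that never calls .item()."""
    indices = torch.arange(START, END, device=x.device)
    return torch.index_select(x, 0, indices)


batch = torch.arange(24.0).reshape(4, 6)  # 4 samples with 6 features each
sliced = torch.vmap(slice_features)(batch)  # one call for the whole batch
```

Because the function takes a tensor, returns a tensor, and touches no outside state, torch.vmap can batch it directly.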
The Solution: Compiler Architecture (5 Transformations)
TorchONNX converts ONNX to PyTorch through five transformations; a sixth stage then formats the generated output:
ONNX Model
↓ validate + infer shapes
Normalized ONNX
↓ extract structure: nodes, edges, attributes
Structural IR
↓ map to PyTorch: Conv→Conv2d, Gemm→Linear, etc.
Semantic IR
↓ optimize: dead code, simplify
Optimized IR
↓ generate code: __init__ + forward + load weights
Final: model.py + model.pth (human-readable)
Why separate transformations?
- Each stage has one job (clear responsibility)
- Easy to test (IR at each stage is inspectable)
- Easy to extend (add optimizer pass without touching code gen)
- Reusable (structural IR could feed JAX/TensorFlow generators too)
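As an illustration of adding an optimizer pass without touching code generation, the sketch below implements dead-code elimination over a toy IR. The Node class, the pass signature, and run_passes are assumptions made for this example, not TorchONNX's actual IR or API.

```python
from dataclasses import dataclass
from typing import Callable, List, Tuple


@dataclass(frozen=True)
class Node:
    """Hypothetical IR node: one operation producing one named value."""
    name: str
    op: str
    inputs: Tuple[str, ...]


Graph = Tuple[Node, ...]
Pass = Callable[[Graph, Tuple[str, ...]], Graph]


def eliminate_dead_nodes(nodes: Graph, outputs: Tuple[str, ...]) -> Graph:
    """Keep only the nodes whose results reach a graph output."""
    by_name = {node.name: node for node in nodes}
    live = set()
    stack = list(outputs)
    while stack:
        name = stack.pop()
        if name in live or name not in by_name:
            continue  # already visited, or a graph input
        live.add(name)
        stack.extend(by_name[name].inputs)
    return tuple(node for node in nodes if node.name in live)


def run_passes(nodes: Graph, outputs: Tuple[str, ...], passes: List[Pass]) -> Graph:
    """Apply each pass in order; a new pass is just one more list entry."""
    for optimizer_pass in passes:
        nodes = optimizer_pass(nodes, outputs)
    return nodes


graph = (
    Node("conv1", "Conv2d", ("input",)),
    Node("relu1", "relu", ("conv1",)),
    Node("unused", "reshape", ("conv1",)),
    Node("fc1", "Linear", ("relu1",)),
)
optimized = run_passes(graph, ("fc1",), [eliminate_dead_nodes])
print([node.name for node in optimized])  # ['conv1', 'relu1', 'fc1']
```

The pass sees only IR nodes, so a code generator downstream would not need to change when the pass list grows.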
Validation Results
VNN-COMP 2024 Coverage:
- 100+ models from 22 benchmarks
- Architectures: Vision Transformers, ResNets, VGG, YOLO, GANs, feedforward
- Success rate: 100% (all models converted, all validated)
Performance:
- Conversion time: 1-15 seconds (depending on model size)
- Runtime: <5% overhead on CPU, often faster on GPU
- Code generation: Produces hand-written quality PyTorch
Code Quality:
- Full type hints (mypy --strict passes)
- Docstrings for all methods
- Black-formatted
- Semantic layer names (self.conv1, not self.onnx_op_0)
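One plausible way to produce semantic layer names is a per-type counter. The sketch below is an assumption about how such naming could work, not TorchONNX's actual scheme; the PREFIXES table and the semantic_names function are hypothetical.

```python
from collections import Counter
from typing import Dict, List

# Hypothetical mapping from PyTorch layer type to attribute prefix.
PREFIXES: Dict[str, str] = {"Conv2d": "conv", "Linear": "fc", "BatchNorm2d": "bn"}


def semantic_names(layer_types: List[str]) -> List[str]:
    """Number layers per type: the first Conv2d is conv1, the second conv2."""
    counts: Counter = Counter()
    names = []
    for layer_type in layer_types:
        prefix = PREFIXES.get(layer_type, layer_type.lower())
        counts[prefix] += 1
        names.append(f"{prefix}{counts[prefix]}")
    return names


names = semantic_names(["Conv2d", "BatchNorm2d", "Conv2d", "Linear"])
print(names)  # ['conv1', 'bn1', 'conv2', 'fc1']
```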
Implementation Highlights
Immutable IR: All intermediate representations are frozen dataclasses. Prevents accidental mutations, enables debugging.
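A minimal sketch of what a frozen-dataclass IR node looks like in practice. The LayerNode class and its fields are illustrative assumptions, not TorchONNX's actual definitions.

```python
from dataclasses import FrozenInstanceError, dataclass, replace
from typing import Tuple


@dataclass(frozen=True)
class LayerNode:
    """Hypothetical semantic IR node; field names are illustrative."""
    name: str
    torch_type: str
    inputs: Tuple[str, ...] = ()


conv = LayerNode(name="conv1", torch_type="Conv2d", inputs=("input",))

mutation_error = None
try:
    conv.name = "renamed"  # assignment on a frozen instance raises
except FrozenInstanceError as exc:
    mutation_error = exc

# A pass that wants a change builds a new node and leaves the old one intact.
renamed = replace(conv, name="stem_conv")
```

Since every stage returns new objects, the IR from an earlier stage can still be inspected unchanged after later stages have run.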
Handler Registry: Operator mapping is pluggable.
HANDLER_REGISTRY = {
    ("layer", "Conv2d"): _handle_conv2d,
    ("operation", "reshape"): _handle_reshape,
    ("operator", "add"): _handle_add,
}
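The sketch below shows how a registry of this shape could be used to dispatch code generation. The handler signatures, the emitted strings, and emit_line are assumptions for illustration; only the registry keys and handler names come from the snippet above.

```python
from typing import Callable, Dict, Tuple

# Assumed handler signature: (output name, input names) -> one line of forward().
Handler = Callable[[str, Tuple[str, ...]], str]


def _handle_conv2d(name: str, inputs: Tuple[str, ...]) -> str:
    # Layers are modules built in __init__, so forward() only calls them.
    return f"{name} = self.{name}({inputs[0]})"


def _handle_reshape(name: str, inputs: Tuple[str, ...]) -> str:
    return f"{name} = {inputs[0]}.reshape({inputs[1]})"


def _handle_add(name: str, inputs: Tuple[str, ...]) -> str:
    return f"{name} = {inputs[0]} + {inputs[1]}"


HANDLER_REGISTRY: Dict[Tuple[str, str], Handler] = {
    ("layer", "Conv2d"): _handle_conv2d,
    ("operation", "reshape"): _handle_reshape,
    ("operator", "add"): _handle_add,
}


def emit_line(kind: str, op: str, name: str, inputs: Tuple[str, ...]) -> str:
    """Look up the handler for (kind, op) and let it emit one line of code."""
    try:
        handler = HANDLER_REGISTRY[(kind, op)]
    except KeyError:
        raise NotImplementedError(f"no handler registered for {kind}/{op}") from None
    return handler(name, inputs)


print(emit_line("operator", "add", "sum1", ("conv1", "skip")))  # sum1 = conv1 + skip
```

Supporting a new operator then means writing one handler and adding one dictionary entry.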
Separation of Concerns:
- Stage 1-2: Extract structure (no algorithm knowledge)
- Stage 3: Map to PyTorch (algorithm-specific)
- Stage 4: Optimize IR (generic, reusable)
- Stage 5: Generate code (template-based)
- Stage 6: Polish output (formatting)
Each stage can be extended independently.
Known Limitations
- Control flow not supported: If/Loop/Scan require runtime interpretation
- Custom operators: Must implement as PyTorch extensions first
- Training mode: Only inference mode supported (all verification models are inference-only)
- Dynamic shapes: Limited support (batch dimension OK, spatial dimensions limited)
Usage Example
import torch
from torchonnx import TorchONNX
# Convert ONNX to PyTorch
converter = TorchONNX(verbose=True)
converter.convert(
    onnx_path="model.onnx",
    target_py_path="model.py",
    vmap_mode=True,
)
# Load and use
from model import Model
model = Model()
model.load_state_dict(torch.load("model.pth"))
model.eval()
# Inference
x = torch.randn(1, 3, 224, 224)
output = model(x)
# Vectorized inference (vmap)
batch_x = torch.randn(1000, 3, 224, 224)
vmapped = torch.vmap(model)
batch_output = vmapped(batch_x) # 10-100x faster
Lessons Learned
- Separate concerns: Infrastructure (sorting, traversal) is different from algorithm (bound propagation). Keeping them separate enables reuse.
- Test early on real models: Synthetic tests give false confidence. Integrate real benchmarks (VNN-COMP) from day one.
- Code generation is UX: Readers will debug and modify generated code. Invest in readability (semantic names, type hints, formatting).
- Numerical precision is context: Hardware, BLAS versions, and operation order all introduce 1e-7 to 1e-8 differences. This is normal; manage tolerances accordingly.
- vmap matters: For verification, vmap support is non-negotiable. The 10-100x speedup justifies the effort.
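The point about numerical precision can be sketched as a tolerance-based comparison between reference and converted outputs. The helper name and the tolerance values below are assumptions chosen for illustration, not the thresholds TorchONNX uses.

```python
import math
from typing import Sequence


def outputs_match(
    reference: Sequence[float],
    converted: Sequence[float],
    rel_tol: float = 1e-5,  # assumed tolerance, tune per workload
    abs_tol: float = 1e-6,  # assumed tolerance, tune per workload
) -> bool:
    """Compare two flat output vectors element by element within a tolerance."""
    if len(reference) != len(converted):
        return False
    return all(
        math.isclose(r, c, rel_tol=rel_tol, abs_tol=abs_tol)
        for r, c in zip(reference, converted)
    )


reference = [0.5, 1.25, -3.0]
close_enough = outputs_match(reference, [0.50000003, 1.25000002, -3.00000001])
too_far = outputs_match(reference, [0.5, 1.25, -3.001])
```

Here close_enough is True because the differences sit around 1e-8, while too_far is False because a 1e-3 difference exceeds both tolerances.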
Conclusion
TorchONNX demonstrates that compiler-based conversion beats runtime wrappers for verification workflows. Clean, readable, modifiable code beats black boxes.
The 6-stage pipeline separates concerns, making the codebase maintainable and extensible. Add a new operator? Implement a handler. Optimize the IR? Add to Stage 4. No cross-cutting concerns.
If you work with ONNX models and need PyTorch code, TorchONNX is ready to use.
Repository: https://github.com/ZhongkuiMa/torchonnx
Related projects:
- ShapeONNX: Shape inference with static shape resolution
- SlimONNX: ONNX optimizer for verification
- PropDAG: Graph traversal framework for bound propagation