PropDAG: A Template Framework for Bound Propagation on Neural Network DAGs
A reusable template framework for bound propagation on directed acyclic graphs, separating algorithm logic from graph traversal infrastructure.
Introduction
Every neural network verification tool reimplements the same graph traversal boilerplate:
- Topological sorting (to process nodes in dependency order)
- Graph traversal (forward/backward passes)
- Cache management (store intermediate bounds, clear when no longer needed)
- Reference counting (for skip connections and branching)
The tragedy: This infrastructure takes 70% of implementation effort, while the actual research contribution (the bound propagation algorithm) takes 30%.
PropDAG solves this: A zero-dependency Python framework that handles graph infrastructure, letting researchers focus on algorithm logic.
Result: Save ~200 lines of boilerplate per algorithm. Write once, reuse across multiple verification approaches.
Part 1: The Problem (Cost of Reimplementation)
Consider implementing a bound propagation algorithm (e.g., interval arithmetic):
Without PropDAG (~200 lines):
# Graph representation
class Model:
    def __init__(self, nodes):
        self.nodes = nodes

    # Topological sort (manually)
    def _toposort(self):
        visited = set()
        order = []

        def dfs(node):
            if node in visited:
                return
            visited.add(node)
            for dep in node.dependencies:
                dfs(dep)
            order.append(node)

        for node in self.nodes:
            dfs(node)
        return order

    # Graph traversal (manually)
    def run(self):
        nodes = self._toposort()
        cache_count = {n: len(n.consumers) for n in nodes}
        for node in nodes:
            # YOUR ALGORITHM HERE (interval propagation)
            node.forward()
            # Reference counting (manually)
            for dep in node.dependencies:
                cache_count[dep] -= 1
                if cache_count[dep] == 0:
                    dep.clear_cache()
With PropDAG (~15 lines):
from propdag import TNode, TModel

class MyNode(TNode):
    def forward(self):
        # YOUR ALGORITHM HERE (interval propagation)
        ...

model = TModel(nodes, sort_strategy="bfs")
model.run()
Savings: 185 lines of boilerplate. Same for CROWN, DeepPoly, or any other algorithm.
Part 2: The Solution (Template Method Pattern)
PropDAG uses the template method pattern: Framework defines the loop structure, users fill in the algorithm-specific methods.
Abstract interface:
from abc import ABC

class TNode(ABC):
    @property
    def pre_nodes(self) -> list["TNode"]:
        """Input dependencies."""
        return self._pre_nodes

    @property
    def next_nodes(self) -> list["TNode"]:
        """Output consumers."""
        return self._next_nodes

    def forward(self):
        """Implement your bound propagation logic here."""
        raise NotImplementedError

    def backward(self):
        """Optional: back-substitute to tighten bounds."""
        raise NotImplementedError

    def clear_fwd_cache(self):
        """Clear intermediate results after all consumers finish."""
        raise NotImplementedError

class TModel:
    def __init__(self, nodes, sort_strategy="bfs"):
        # Automatically sorts nodes topologically
        if sort_strategy == "bfs":
            self._nodes = topo_sort_bfs(nodes)
        else:
            self._nodes = topo_sort_dfs(nodes)

    def run(self):
        """Execute: forward pass in topological order + reference counting."""
        cache_counter = {n: len(n.next_nodes) for n in self._nodes}
        for node in self._nodes:
            node.forward()  # Your algorithm
            # Auto-clear caches once every consumer has run
            for pre_node in node.pre_nodes:
                cache_counter[pre_node] -= 1
                if cache_counter[pre_node] == 0:
                    pre_node.clear_fwd_cache()
User implementation (e.g., interval arithmetic):
class IntervalNode(TNode):
    def forward(self):
        if len(self.pre_nodes) == 0:
            # Input: use input bounds
            return
        # Propagate intervals from predecessors
        lower = max(p.bounds.lower for p in self.pre_nodes)
        upper = min(p.bounds.upper for p in self.pre_nodes)
        self.bounds = Interval(lower, upper)

    def clear_fwd_cache(self):
        self.bounds = None
That’s it. 20 lines of application logic. Graph infrastructure is implicit.
Part 3: A Complete Example - DeepPoly vs Interval Arithmetic
To show why PropDAG matters, compare two bound propagation algorithms:
Interval Arithmetic (simple):
class IntervalNode(TNode):
    def forward(self):
        if len(self.pre_nodes) == 0:
            return
        # Interval intersection: [l1, u1] ∩ [l2, u2] = [max(l1, l2), min(u1, u2)]
        lower = max(p.bounds.lower for p in self.pre_nodes)
        upper = min(p.bounds.upper for p in self.pre_nodes)
        self.bounds = Interval(lower, upper)
DeepPoly (affine forms, tighter bounds):
class DeepPolyNode(TNode):
    def forward(self):
        if len(self.pre_nodes) == 0:
            return
        # Affine form: y = sum(c_i * x_i) + c0
        # Represent bounds as linear combinations of input variables
        coeff_sum = sum(p.affine_coeff for p in self.pre_nodes)
        self.affine = Affine(coeff=coeff_sum, bias=self.bias)
        self.bounds = self.affine.concretize()  # Convert to intervals

    def backward(self):
        # Tighter: back-substitute to eliminate intermediate variables
        for pre in self.pre_nodes:
            pre.affine = self.affine.substitute(pre.affine)
Without PropDAG: Implement both → rewrite topological sort, reference counting, traversal logic twice → 400 lines.
With PropDAG: Implement just the algorithmic difference → 30 lines of forward() and backward().
The framework handles the boring infrastructure. You focus on bounds.
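The interval snippets above only combine predecessor bounds. As a rough sketch of what a layer-level transfer function could look like, assume an affine layer y = Wx + b with weights stored as nested Python lists. The helper names affine_interval and relu_interval are hypothetical and not part of PropDAG. The standard interval rule splits each weight by sign: a non-negative weight maps the input lower bound to the output lower bound, and a negative weight swaps them.

```python
# Illustrative helpers only; these names are hypothetical, not PropDAG API.

def affine_interval(weights, bias, lower, upper):
    """Bound y = W x + b given elementwise bounds lower <= x <= upper."""
    out_lower, out_upper = [], []
    for row, b in zip(weights, bias):
        lo = hi = b
        for w, l, u in zip(row, lower, upper):
            if w >= 0:
                lo += w * l  # smallest contribution uses the lower input bound
                hi += w * u
            else:
                lo += w * u  # negative weight flips which bound is extreme
                hi += w * l
        out_lower.append(lo)
        out_upper.append(hi)
    return out_lower, out_upper


def relu_interval(lower, upper):
    """ReLU is monotone, so it can be applied to each bound directly."""
    return [max(0.0, l) for l in lower], [max(0.0, u) for u in upper]


W = [[1.0, -1.0], [2.0, 0.5]]
b = [0.0, -1.0]
lo, hi = affine_interval(W, b, lower=[0.0, 0.0], upper=[1.0, 1.0])
print(lo, hi)                  # [-1.0, -1.0] [1.0, 1.5]
print(relu_interval(lo, hi))   # ([0.0, 0.0], [1.0, 1.5])
```

Logic of this kind would live inside a node's forward(); the traversal and cache handling stay with the framework.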
Part 4: Common Pitfalls When Rolling Your Own
Pitfall 1: Reference Counting Off-by-One
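# Wrong: counter starts one too low, so it hits 0 while a consumer is still waiting
counter = {n: len(n.next_nodes) - 1 for n in nodes}  # BUG: should be len(n.next_nodes)
for node in nodes:
    node.forward()  # Process this node
    for dep in node.pre_nodes:
        counter[dep] -= 1
        if counter[dep] == 0:
            dep.clear_cache()  # Cleared too early: second consumer sees None!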
Why it kills you: A node with 2 consumers (e.g., Conv → [Add, MatMul]). First consumer decrements counter to 0 and clears cache. Second consumer crashes because it expected bounds to be there.
PropDAG prevents this: Reference counting is built-in and correct.
Pitfall 2: Not Handling Skip Connections
# Naive approach: assume 1 consumer per node
# But ResNet blocks have skip connections!
Input ──→ Conv1 ──→ Conv2 ──→ Add
            │                  ↑
            └──────────────────┘   (skip: Conv1's output stays cached!)
Why it kills you: Conv1’s bounds are needed by both Conv2 (via Conv1→Conv2) and Add (via skip). If you clear Conv1’s cache after Conv2, Add crashes.
PropDAG prevents this: Adjacency list tracks all consumers. Reference counting ensures caches persist until all consumers finish.
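To make the skip-connection case concrete, here is a small self-contained sketch. It uses stand-in classes (Node, connect, and run are hypothetical names that mimic the run loop shown in Part 2, not PropDAG's actual implementation) and logs when each cache is released on the ResNet-style graph above.

```python
# Stand-in classes for illustration; not PropDAG's real implementation.

class Node:
    def __init__(self, name):
        self.name = name
        self.pre_nodes = []
        self.next_nodes = []
        self.cached = False  # True while this node's bounds are held


def connect(src, dst):
    src.next_nodes.append(dst)
    dst.pre_nodes.append(src)


def run(nodes_in_topo_order, log):
    # One pending count per node: how many consumers still need its bounds.
    remaining = {n: len(n.next_nodes) for n in nodes_in_topo_order}
    for node in nodes_in_topo_order:
        # Every predecessor must still hold its cached bounds at this point.
        assert all(p.cached for p in node.pre_nodes), node.name
        node.cached = True
        log.append(f"forward {node.name}")
        for pre in node.pre_nodes:
            remaining[pre] -= 1
            if remaining[pre] == 0:
                pre.cached = False
                log.append(f"clear {pre.name}")


inp, conv1, conv2, add = (Node(n) for n in ("input", "conv1", "conv2", "add"))
connect(inp, conv1)
connect(conv1, conv2)
connect(conv1, add)  # skip connection
connect(conv2, add)

log = []
run([inp, conv1, conv2, add], log)
print(log)
```

In this sketch conv1 has two consumers, so its cache is released only after add has run, not after conv2.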
Pitfall 3: Backward Pass Corruption
Why it kills you: In algorithms like DeepPoly, backward pass tightens bounds by back-substituting symbolic expressions. If you don’t clear caches between forward/backward, old intermediate values corrupt the tightening.
PropDAG prevents this: Forces you to implement both clear_fwd_cache() and clear_bwd_cache(). No way to accidentally skip cleanup.
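One way to picture the separation between the two passes is the sketch below. It is an assumption-heavy illustration: the hook names clear_fwd_cache() and clear_bwd_cache() come from the text, but TraceNode, run_forward_backward, and the scheduling of the backward pass (reverse topological order, with counts taken over predecessors) are hypothetical and may not match how PropDAG actually drives backward().

```python
# Hypothetical driver: mirrors the forward loop for a backward pass.
# Hook names follow the text; the scheduling itself is an assumption.

class TraceNode:
    def __init__(self, name, log):
        self.name, self.log = name, log
        self.pre_nodes, self.next_nodes = [], []

    def forward(self):
        self.log.append(f"fwd {self.name}")

    def backward(self):
        self.log.append(f"bwd {self.name}")

    def clear_fwd_cache(self):
        self.log.append(f"clear_fwd {self.name}")

    def clear_bwd_cache(self):
        self.log.append(f"clear_bwd {self.name}")


def run_forward_backward(nodes):
    """Run both passes; nodes must already be in topological order."""
    pending = {n: len(n.next_nodes) for n in nodes}
    for node in nodes:
        node.forward()
        for pre in node.pre_nodes:
            pending[pre] -= 1
            if pending[pre] == 0:
                pre.clear_fwd_cache()

    # Backward: consumers run first, so count predecessors instead.
    pending = {n: len(n.pre_nodes) for n in nodes}
    for node in reversed(nodes):
        node.backward()
        for nxt in node.next_nodes:
            pending[nxt] -= 1
            if pending[nxt] == 0:
                nxt.clear_bwd_cache()


log = []
a, b, c = (TraceNode(n, log) for n in "abc")
for src, dst in ((a, b), (b, c)):
    src.next_nodes.append(dst)
    dst.pre_nodes.append(src)

run_forward_backward([a, b, c])
print(log)
```

Keeping the two counters separate means forward state and backward state are released independently, so neither pass reads values left behind by the other.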
Part 5: Key Design Decisions
DAG Representation (adjacency list):
- Each node stores pre_nodes (dependencies) and next_nodes (consumers)
- Naturally handles skip connections (ResNets, attention, etc.)
Topological Sorting (BFS or DFS):
- BFS: Process level-by-level (better for parallelization)
- DFS: Process one path fully (better cache locality)
- Users choose via the sort_strategy parameter
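The two strategies can be sketched as follows. These are hypothetical stand-ins for the topo_sort_bfs and topo_sort_dfs helpers referenced in Part 2, written under the assumption that every node exposes pre_nodes and next_nodes lists; PropDAG's real implementations may differ. The BFS variant is Kahn's algorithm, and the DFS variant is a post-order walk over dependencies.

```python
from collections import deque

# Hypothetical stand-ins; PropDAG's own helpers may be implemented differently.

def topo_sort_bfs(nodes):
    """Kahn's algorithm: emit a node once all of its dependencies are emitted."""
    indegree = {n: len(n.pre_nodes) for n in nodes}
    queue = deque(n for n in nodes if indegree[n] == 0)
    order = []
    while queue:
        node = queue.popleft()
        order.append(node)
        for nxt in node.next_nodes:
            indegree[nxt] -= 1
            if indegree[nxt] == 0:
                queue.append(nxt)
    if len(order) != len(nodes):
        raise ValueError("graph contains a cycle")
    return order


def topo_sort_dfs(nodes):
    """Post-order DFS over dependencies: a node is emitted after its inputs."""
    visited, order = set(), []

    def visit(node):
        if node in visited:
            return
        visited.add(node)
        for dep in node.pre_nodes:
            visit(dep)
        order.append(node)

    for node in nodes:
        visit(node)
    return order


class N:
    """Minimal node used only for this demo."""
    def __init__(self, name):
        self.name, self.pre_nodes, self.next_nodes = name, [], []


a, b, c, d = (N(x) for x in "abcd")
for src, dst in ((a, b), (b, d), (a, c), (c, d)):
    src.next_nodes.append(dst)
    dst.pre_nodes.append(src)

shuffled = [d, c, b, a]
print([n.name for n in topo_sort_bfs(shuffled)])
print([n.name for n in topo_sort_dfs(shuffled)])
```

Both functions return an order in which every node appears after all of its pre_nodes. The recursive DFS is the shorter of the two, but on very deep graphs it can hit Python's recursion limit, which an iterative version would avoid.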
Reference Counting:
- Track how many consumers each node has
- When counter reaches 0, cache is safely cleared
- Keeps peak memory down by not retaining intermediate bounds longer than needed
Zero Dependencies:
- Pure Python 3.10+
- Doesn’t force PyTorch/NumPy—users bring their own tensor library
- Graph traversal is ~1% of verification runtime (bound computation dominates)
Part 6: Real-World Usage
Algorithm 1: Interval Arithmetic (simple bounds):
class IntervalNode(TNode):
    def forward(self):
        # Combine predecessor intervals: [max of lowers, min of uppers]
        bounds = [p.bounds for p in self.pre_nodes]
        if not bounds:
            return
        self.bounds = Interval(
            lower=max(b.lower for b in bounds),
            upper=min(b.upper for b in bounds),
        )
Algorithm 2: DeepPoly (affine forms):
class DeepPolyNode(TNode):
    def forward(self):
        # Propagate affine expressions: y = Σ wᵢ·xᵢ + b
        expressions = [p.expr for p in self.pre_nodes]
        self.expr = sum(expressions) + self.bias
Algorithm 3: Symbolic Bounds (with backward substitution):
class SymbolicNode(TNode):
    def forward(self):
        # Propagate symbolic expressions
        # (f is a placeholder for this node's transfer function)
        self.symbolic = f(self.pre_nodes[0].symbolic)

    def backward(self):
        # Eliminate intermediate variables by substitution
        # (simplify is likewise a placeholder)
        for pre in self.pre_nodes:
            pre.symbolic = simplify(self.symbolic.substitute(pre.symbolic))
All three use the same TModel.run(). No reimplementation of graph infrastructure.
Part 7: Try It Yourself
Installation:
pip install propdag
30-second example (interval arithmetic on a 3-node network):
from propdag import TNode, TModel

class MyNode(TNode):
    def __init__(self, name, bounds=None):
        self.name = name
        self.bounds = bounds
        self._pre_nodes = []
        self._next_nodes = []

    def forward(self):
        # Toy rule: copy the bounds of the single predecessor
        if self.bounds is None and len(self.pre_nodes) > 0:
            self.bounds = self.pre_nodes[0].bounds

    def clear_fwd_cache(self):
        pass  # Keep bounds so they can be printed after the run

# Build a simple DAG: Input → Conv → ReLU
input_node = MyNode("input", bounds=(0, 1))
conv_node = MyNode("conv")
relu_node = MyNode("relu")

# pre_nodes / next_nodes are read-only properties in the interface shown
# in Part 2, so extend the underlying lists instead of assigning to them
input_node.next_nodes.append(conv_node)
conv_node.pre_nodes.append(input_node)
conv_node.next_nodes.append(relu_node)
relu_node.pre_nodes.append(conv_node)

# Run
model = TModel([input_node, conv_node, relu_node])
model.run()
print(f"ReLU output bounds: {relu_node.bounds}")
For real algorithms: Replace forward() with your bound propagation logic (DeepPoly, CROWN, symbolic bounds, etc.). Graph traversal is automatic.
Conclusion
PropDAG is infrastructure for neural network verification research. It eliminates the boilerplate that every verification tool needs:
- Saves: ~200 lines per algorithm (topological sort, traversal, reference counting)
- Cost: ~500 lines framework code (written once, reused forever)
- Benefit: Focus on research (algorithm) instead of engineering (graph management)
The template method pattern separates concerns: framework controls traversal order, users control bound propagation logic.
If you’re building verification tools, PropDAG removes infrastructure friction and lets you iterate on algorithms faster.
Repository: https://github.com/ZhongkuiMa/propdag