PropDAG: A Template Framework for Bound Propagation on Neural Network DAGs¶
Introduction¶
Neural network verification research has a recurring problem: every new bound propagation algorithm requires reimplementing the same graph traversal infrastructure.
When I implemented multiple verification algorithms, I found myself writing nearly identical code for topological sorting, graph traversal, and cache management. The novel algorithm logic—the actual research contribution—was a small fraction of the codebase. The rest was infrastructure that every tool needs but nobody wants to write.
PropDAG solves this by providing a reusable template framework for bound propagation on directed acyclic graphs (DAGs). As a zero-dependency, pure Python framework, PropDAG adds negligible overhead—graph traversal infrastructure is typically <1% of verification runtime, with bound computation dominating total time. It handles the infrastructure (graph traversal, topological sorting, cache clearing) so researchers can focus on algorithm implementation.
The core idea: separate what to compute (algorithm logic) from how to traverse the graph (infrastructure). PropDAG provides abstract base classes (TNode, TModel, TCache) that users extend with their specific propagation logic.
Key features:
Zero dependencies: Pure Python 3.10+, no PyTorch/NumPy required
Template-based: Extend abstract classes, don’t reimplement infrastructure
DAG-native: First-class support for residual connections and branching
Flexible traversal: Choose BFS or DFS topological sorting
Bidirectional: Support both forward and backward propagation modes
Educational: Toy implementations with verbose logging
This post walks through PropDAG’s design, starting with the problem it solves.
Part 1: The Infrastructure Problem¶
1.1 What PropDAG Addresses¶
Bound propagation algorithms (used in neural network verification) all require the same infrastructure:
Graph representation: Nodes are layers/operations, edges are data dependencies
Topological sorting: Process nodes in dependency order
Graph traversal: Visit nodes forward (input→output) or backward (output→input)
Cache management: Store intermediate results, clear when no longer needed
Every verification tool reimplements these components. The algorithm-specific code (how to propagate bounds through a ReLU or Conv layer) is relatively small—most code is infrastructure.
PropDAG’s approach: Implement infrastructure once as a reusable framework. Users extend abstract base classes with algorithm-specific logic.
1.2 Design Goals¶
PropDAG was designed with these principles:
Separation of concerns: Framework handles graph traversal, users handle bound computation
Reusability: Write infrastructure once, use for multiple algorithms
Flexibility: Support different traversal strategies (BFS/DFS) and propagation modes (forward/backward)
Zero dependencies: Don’t force users to adopt a specific tensor library (PyTorch, NumPy, JAX)
Simplicity: Clear abstractions, educational examples
The result is a lightweight framework (~500 lines of template code) that provides graph infrastructure without dictating algorithm implementation.
Part 2: Template Method Pattern and Abstract Base Classes¶
2.1 The Template Method Pattern¶
PropDAG uses the template method pattern: the framework defines the algorithm skeleton, subclasses fill in specific steps.
The pattern in PropDAG:
# Framework (PropDAG) provides the loop
class TModel:
    def run(self):
        sorted_nodes = topological_sort(self.nodes)
        for node in sorted_nodes:
            node.forward()  # User implements this

# User provides the computation
class MyNode(TNode):
    def forward(self):
        # Algorithm-specific bound propagation logic
        ...
Why this works: The framework controls traversal order (via topological sort), ensuring nodes are processed after their dependencies. Users only implement the propagation logic for each node type.
2.2 Abstract Base Classes¶
PropDAG defines three abstract base classes:
TNode: Node-level Computation
from abc import ABC

class TNode(ABC):
    # Abstract methods users must implement
    def forward(self):
        """Compute bounds for this node (forward propagation)."""
        raise NotImplementedError

    def backward(self):
        """Back-substitute symbolic expressions (backward propagation)."""
        raise NotImplementedError

    def build_rlx(self):
        """Build linear relaxations for non-linear operations."""
        raise NotImplementedError

    def clear_fwd_cache(self):
        """Clear forward computation cache."""
        raise NotImplementedError

    def clear_bwd_cache(self):
        """Clear backward computation cache."""
        raise NotImplementedError

    # Additional methods for symbolic bound propagation
    def init_symbnd(self): ...
    def fwdprop_symbnd(self): ...
    def bwdprop_symbnd(self): ...
    def cal_and_update_cur_node_bnd(self): ...

    # Properties (graph structure)
    @property
    def pre_nodes(self) -> list["TNode"]:
        """Predecessor nodes (inputs to this operation)."""
        return self._pre_nodes

    @property
    def next_nodes(self) -> list["TNode"]:
        """Successor nodes (consumers of this operation)."""
        return self._next_nodes
TModel: Graph-level Orchestration
class TModel(ABC):
    def __init__(self, nodes, sort_strategy="bfs", verbose=False):
        # Topologically sort nodes (BFS or DFS)
        if sort_strategy == "bfs":
            self._nodes = topo_sort_forward_bfs(nodes, verbose)
        elif sort_strategy == "dfs":
            self._nodes = topo_sort_forward_dfs(nodes, verbose)

    def run(self):
        """Execute forward pass, optionally with backward substitution."""
        cache_counter = {node: len(node.next_nodes) for node in self._nodes}
        for node in self._nodes:
            node.forward()

            # Optionally run backward pass
            if self.arguments.prop_mode == PropMode.BACKWARD:
                self.backsub(node)

            # Clear caches when no longer needed (reference counting)
            for pre_node in node.pre_nodes:
                cache_counter[pre_node] -= 1
                if cache_counter[pre_node] == 0:
                    pre_node.clear_fwd_cache()
TCache: Cache Management Interface
from dataclasses import dataclass

@dataclass(slots=True)
class TCache:
    """Base class for caching computation results.

    Users extend this to define what to cache (bounds, symbolic
    expressions, relaxations, etc.).
    """
TArgument: Configuration
@dataclass(frozen=True, slots=True)
class TArgument:
    """Immutable configuration for propagation mode."""

    prop_mode: PropMode = PropMode.BACKWARD
The frozen dataclass ensures immutability (prevents accidental state changes during propagation).
2.3 How Users Extend PropDAG¶
Users create concrete implementations in four steps:
1. Subclass TNode and implement the abstract methods
2. Optionally subclass TModel (the base implementation often suffices)
3. Define a TCache subclass for the caching strategy
4. Build the node graph and call model.run()
Example from PropDAG’s toy implementation:
class ForwardToyNode(TNode):
    def forward(self):
        if len(self.pre_nodes) == 0:  # Input node
            print(f"{self.name}: Skip input node")
            return
        self.build_rlx()                    # Build relaxation
        self.fwdprop_symbnd()               # Propagate symbolic bounds
        self.cal_and_update_cur_node_bnd()  # Calculate concrete bounds

    def build_rlx(self):
        print(f"{self.name}: Calculate relaxation if this is non-linear node")

    def fwdprop_symbnd(self):
        print(f"{self.name}: Forward propagate symbolic bounds")
        self.cache.symbnds[self.name] = ("symbolic bounds",)

    def cal_and_update_cur_node_bnd(self):
        print(f"{self.name}: Calculate and cache scalar bounds")
        self.cache.bnds[self.name] = ("scalar bounds",)
The framework handles topological sorting and traversal; the user only implements propagation logic.
Part 3: Topological Sorting and Graph Traversal¶
3.1 Why Topological Sorting Matters¶
Neural networks are directed acyclic graphs (DAGs). To propagate bounds correctly, we must process nodes in topological order: dependencies before dependents.
For a simple chain Input → Conv → ReLU → Output, the order is obvious. It is less obvious for graphs with skip connections (e.g., ResNets):
ResNet block:

Input ──────────────┐
  │                 │
  └→ Conv1 → Conv2 ─┴→ Add → ReLU → Output

For this particular block the topological order happens to be unique (Input, Conv1, Conv2, Add, ReLU, Output) because the skip edge creates no independent branch. DAGs with genuinely parallel branches, like the toy DAG in Part 5, admit several valid orders, and any of them is safe to process.
Topological sorting guarantees that when we process Add, both inputs (from Conv2 and Input) have already been computed.
3.2 BFS vs DFS Topological Sort¶
PropDAG implements two topological sorting algorithms:
BFS (Breadth-First Search) - Kahn’s Algorithm:
def topo_sort_forward_bfs(nodes, verbose=False):
    """BFS topological sort."""
    # Compute in-degrees (number of predecessors)
    in_degrees = {node: len(set(node.pre_nodes)) for node in nodes}

    # Start with nodes having no predecessors (input nodes)
    queue = [node for node in nodes if in_degrees[node] == 0]
    sorted_nodes = []

    while queue:
        node = queue.pop(0)  # FIFO
        sorted_nodes.append(node)
        # Decrement successors' in-degrees
        for next_node in node.next_nodes:
            in_degrees[next_node] -= 1
            if in_degrees[next_node] == 0:
                queue.append(next_node)

    # Cycle detection
    if len(sorted_nodes) != len(nodes):
        raise ValueError("Graph has a cycle")

    return sorted_nodes
DFS (Depth-First Search):
def topo_sort_forward_dfs(nodes, verbose=False):
    """DFS topological sort."""
    visited = set()
    temp_mark = set()  # For cycle detection
    sorted_nodes = []

    def dfs(node):
        if node in temp_mark:
            raise ValueError("Graph has a cycle")  # Back edge
        if node not in visited:
            temp_mark.add(node)
            for next_node in node.next_nodes:
                dfs(next_node)
            temp_mark.remove(node)
            visited.add(node)
            sorted_nodes.append(node)  # Post-order

    # Start from input nodes
    for node in nodes:
        if len(node.pre_nodes) == 0:
            dfs(node)

    return sorted_nodes[::-1]  # Reverse post-order
BFS vs DFS trade-offs:
BFS: Processes nodes level-by-level. Better for parallelization (all nodes at one depth can be processed concurrently).
DFS: Processes one path completely before backtracking. Better cache locality (complete one branch before moving to another).
PropDAG lets users choose via configuration:
model = TModel(nodes, sort_strategy="bfs") # or "dfs"
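To see the two strategies diverge, here is a small standalone sketch that runs both sorts on the diamond-shaped toy DAG from Section 5.2. The Node class and helper functions are illustrative stand-ins, not PropDAG's actual API:

class Node:
    def __init__(self, name):
        self.name, self.pre_nodes, self.next_nodes = name, [], []

    def __repr__(self):
        return self.name

def bfs_order(nodes):
    """Kahn's algorithm: process nodes level by level."""
    in_deg = {n: len(set(n.pre_nodes)) for n in nodes}
    queue = [n for n in nodes if in_deg[n] == 0]
    order = []
    while queue:
        node = queue.pop(0)
        order.append(node)
        for nxt in node.next_nodes:
            in_deg[nxt] -= 1
            if in_deg[nxt] == 0:
                queue.append(nxt)
    return order

def dfs_order(nodes):
    """Reverse post-order DFS: finish one branch before the next."""
    visited, order = set(), []

    def visit(node):
        if node in visited:
            return
        visited.add(node)
        for nxt in node.next_nodes:
            visit(nxt)
        order.append(node)  # post-order

    for n in nodes:
        if not n.pre_nodes:
            visit(n)
    return order[::-1]

n1, n2, n3, n4, n5, n6 = nodes = [Node(f"Node-{i}") for i in range(1, 7)]
for src, dst in [(n1, n2), (n1, n3), (n2, n4), (n3, n4), (n3, n5), (n4, n6), (n5, n6)]:
    src.next_nodes.append(dst)
    dst.pre_nodes.append(src)

print(bfs_order(nodes))  # [Node-1, Node-2, Node-3, Node-4, Node-5, Node-6]
print(dfs_order(nodes))  # [Node-1, Node-3, Node-5, Node-2, Node-4, Node-6]

Both outputs are valid topological orders; they differ only in how ties between independent nodes are broken.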
3.3 Forward and Backward Traversal¶
PropDAG supports two propagation directions:
Forward traversal: Input → Output (compute bounds layer by layer)
for node in sorted_nodes:  # Topological order
    node.forward()
Backward traversal: Output → Input (back-substitute symbolic expressions)
# Compute backward topological sort for each node
backward_sorts = topo_sort_backward(nodes)

for node in sorted_nodes:
    node.forward()  # Forward pass
    # Backward pass: substitute from this node to inputs
    for bwd_node in backward_sorts[node]:
        bwd_node.backward()
The PropMode enum controls which mode is used:
from enum import IntEnum

class PropMode(IntEnum):
    FORWARD = 1   # Forward-only propagation
    BACKWARD = 2  # Forward + backward substitution
3.4 Cache Management via Reference Counting¶
Naive caching (store all intermediate results) wastes memory. PropDAG uses reference counting: clear a node’s cache when all consumers have finished.
From TModel.run():
cache_counter = {node: len(node.next_nodes) for node in self._nodes}

for node in self._nodes:
    node.forward()
    # Decrement predecessors' reference counts
    for pre_node in node.pre_nodes:
        cache_counter[pre_node] -= 1
        if cache_counter[pre_node] == 0:
            pre_node.clear_fwd_cache()  # No more consumers
This ensures nodes with multiple consumers (e.g., skip connection branches) stay cached until all consumers finish, then are cleared to free memory.
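A quick trace shows the policy in action on the ResNet block from Section 3.1. This is an illustrative standalone sketch (the Node class is a stand-in, not PropDAG's TNode); note how Input stays cached until Add consumes it, while Conv1 is freed as soon as Conv2 finishes:

class Node:
    def __init__(self, name):
        self.name, self.pre_nodes, self.next_nodes = name, [], []

    def clear_fwd_cache(self):
        print(f"    clear {self.name}'s cache")

inp, c1, c2, add, relu, out = nodes = [
    Node(n) for n in ["Input", "Conv1", "Conv2", "Add", "ReLU", "Output"]
]
for src, dst in [(inp, c1), (c1, c2), (c2, add), (inp, add), (add, relu), (relu, out)]:
    src.next_nodes.append(dst)
    dst.pre_nodes.append(src)

cache_counter = {n: len(n.next_nodes) for n in nodes}  # Input starts at 2
for node in nodes:  # nodes are already in topological order
    print(f"forward {node.name}")
    for pre in node.pre_nodes:
        cache_counter[pre] -= 1
        if cache_counter[pre] == 0:
            pre.clear_fwd_cache()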
Part 4: Propagation Modes and Design Decisions¶
4.1 Forward and Backward Modes¶
PropDAG supports two propagation modes via the PropMode enum:
PropMode.FORWARD: Forward-only propagation
Compute bounds from inputs to outputs sequentially
Each node processes inputs from predecessors
Faster but potentially looser bounds
PropMode.BACKWARD: Forward + backward substitution
Forward pass: compute concrete bounds
Backward pass: back-substitute symbolic expressions to eliminate intermediate variables
Typically yields tighter bounds at higher computational cost
The mode is configured via TArgument:
# Forward-only
args = TArgument(prop_mode=PropMode.FORWARD)
model = TModel(nodes, arguments=args)
model.run()

# Forward + backward
args = TArgument(prop_mode=PropMode.BACKWARD)
model = TModel(nodes, arguments=args)
model.run()
4.2 Immutability via Frozen Dataclasses¶
TArgument uses Python’s frozen dataclasses:
@dataclass(frozen=True, slots=True)
class TArgument:
    prop_mode: PropMode = PropMode.BACKWARD
Why frozen?
Prevents bugs: Can’t accidentally modify configuration mid-propagation
Thread-safe: Multiple threads can read without locks
Debuggable: State doesn’t change, easier to trace issues
The slots=True optimization reduces memory overhead by avoiding __dict__ storage.
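A minimal demonstration of what frozen buys, using an illustrative Config class rather than PropDAG's actual TArgument:

from dataclasses import FrozenInstanceError, dataclass, replace

@dataclass(frozen=True, slots=True)  # slots=True requires Python 3.10+
class Config:
    prop_mode: int = 2

cfg = Config()
try:
    cfg.prop_mode = 1  # attempted mutation mid-propagation
except FrozenInstanceError:
    print("blocked: configuration is immutable")

# The sanctioned way to vary configuration: build a new object
cfg_forward = replace(cfg, prop_mode=1)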
4.3 DAG Representation¶
PropDAG represents graphs via adjacency lists: each node stores references to predecessors and successors:
class TNode:
    @property
    def pre_nodes(self) -> list["TNode"]:
        """Nodes that provide inputs to this node."""
        return self._pre_nodes

    @property
    def next_nodes(self) -> list["TNode"]:
        """Nodes that consume outputs from this node."""
        return self._next_nodes
This naturally handles skip connections and branching. For a ResNet block:
# Input node has two successors (branching)
input_node.next_nodes = [conv1_node, add_node]
# Add node has two predecessors (skip connection)
add_node.pre_nodes = [conv2_node, input_node]
No special handling needed—the list structure accommodates multiple inputs/outputs.
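Since every edge must be recorded on both endpoints, it is easy to set one list and forget the other. A small helper (hypothetical, not part of PropDAG's API; relu_node and output_node are assumed to exist alongside the nodes above) registers both sides at once:

def connect(src, dst):
    """Hypothetical helper: record an edge on both endpoints."""
    src.next_nodes.append(dst)
    dst.pre_nodes.append(src)

# Wiring the ResNet block becomes a flat list of edges:
for src, dst in [(input_node, conv1_node), (conv1_node, conv2_node),
                 (conv2_node, add_node), (input_node, add_node),
                 (add_node, relu_node), (relu_node, output_node)]:
    connect(src, dst)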
Part 5: Learning from Toy Examples¶
5.1 Toy Implementations¶
PropDAG includes toy/ implementations that demonstrate the framework with verbose logging:
ForwardToyNode: Shows forward-only propagation
class ForwardToyNode(TNode):
    def forward(self):
        if len(self.pre_nodes) == 0:  # Input node
            print(f"{self.name}: Skip input node")
            return
        self.cache.cur_node = self
        self.build_rlx()
        self.fwdprop_symbnd()
        self.cal_and_update_cur_node_bnd()

    def build_rlx(self):
        print(f"{self.name}: Calculate relaxation if this is non-linear node")

    def fwdprop_symbnd(self):
        if len(self.pre_nodes) == 0:
            print(f"{self.name}: Prepare symbolic bounds of {self.name}")
        else:
            pre_names = [n.name for n in self.pre_nodes]
            print(f"{self.name}: Forward propagate symbolic bounds of {pre_names}")
        print(f"{self.name}: Cache symbolic bounds")
        self.cache.symbnds[self.name] = ("symbolic bounds",)

    def cal_and_update_cur_node_bnd(self):
        print(f"{self.name}: Calculate scalar bounds")
        print(f"{self.name}: Cache scalar bounds")
        self.cache.bnds[self.name] = ("scalar bounds",)
BackwardToyNode: Shows backward substitution
class BackwardToyNode(TNode):
    def forward(self):
        if len(self.pre_nodes) == 0:
            print(f"{self.name}: Skip input node")
            return
        self.cache.cur_node = self
        self.build_rlx()
        self.init_symbnd()

    def backward(self):
        self.bwdprop_symbnd()
        self.cal_and_update_cur_node_bnd()

    def bwdprop_symbnd(self):
        if self == self.cache.cur_node:
            print(f"{self.name}: Prepare symbolic bounds of {self.name}")
        else:
            print(f"{self.name}: Backsubstitute symbolic bounds of {self.cache.cur_node.name}")
        print(f"{self.name}: Cache substitution")
        self.cache.symbnds[self.name] = (f"substitution of {self.name}",)
5.2 Example DAG Structure¶
The toy examples use a 6-node DAG with branching and merging:
          Node-1
         /      \
    Node-2    Node-3
         \    /     \
         Node-4   Node-5
              \     /
              Node-6
Running the example shows the framework’s execution flow:
from propdag import BackwardToyNode, PropMode, ToyArgument, ToyCache, ToyModel

cache = ToyCache()
cache.bnds["Node-1"] = ("input bounds",)
arguments = ToyArgument(prop_mode=PropMode.BACKWARD)

# Build node graph (adjacency list construction)
node1 = BackwardToyNode("Node-1", cache, arguments)
# ... create nodes 2-6 ...

# Set graph edges
node1.next_nodes = [node2, node3]
node2.pre_nodes = [node1]
# ... etc ...

model = ToyModel([node1, node2, node3, node4, node5, node6], verbose=True)
model.run()
Output shows topological ordering, forward/backward passes, and cache management.
5.3 ToyCache Implementation¶
The toy cache demonstrates a simple caching strategy:
from collections import OrderedDict
from dataclasses import dataclass, field

@dataclass(slots=True)
class ToyCache(TCache):
    """Cache for toy model computations."""

    cur_node: TNode | None = None
    symbnds: dict[str, tuple] = field(default_factory=OrderedDict)
    bnds: dict[str, tuple] = field(default_factory=OrderedDict)
    rlxs: dict[str, tuple] = field(default_factory=OrderedDict)
This stores:
symbnds: Symbolic bound expressions
bnds: Concrete numerical bounds
rlxs: Linear relaxations for non-linear operations
Users can define custom cache implementations with different storage strategies.
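For example, an interval-propagation algorithm might cache only concrete lower/upper vectors. A sketch (the field layout is illustrative; TCache itself imposes no structure):

from dataclasses import dataclass, field

@dataclass(slots=True)
class IntervalCache(TCache):
    # node name -> (lower bounds, upper bounds)
    bnds: dict[str, tuple[list[float], list[float]]] = field(default_factory=dict)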
Part 6: Design Philosophy and Key Decisions¶
6.1 Template Method vs Callbacks¶
Why template method pattern instead of callbacks?
Callback approach (considered but rejected):
def propagate(graph, forward_fn, backward_fn):
    sorted_nodes = topological_sort(graph)
    for node in sorted_nodes:
        forward_fn(node)  # Pass callback
Problems:
No node-specific state encapsulation
No polymorphism (different node types need different logic)
Hard to extend with new methods
Template method approach (used by PropDAG):
class TNode(ABC):
    def forward(self): ...
    def backward(self): ...

class TModel:
    def run(self):
        for node in self.nodes:
            node.forward()  # Polymorphic dispatch
Benefits:
State encapsulation (each node has its own data)
Polymorphism (different node types override methods differently)
Extensibility (add new abstract methods without breaking existing code)
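The payoff is concrete: the traversal loop never changes while each node type supplies its own logic. A runnable sketch with illustrative class names:

from abc import ABC, abstractmethod

class Layer(ABC):
    @abstractmethod
    def forward(self): ...

class LinearLayer(Layer):
    def forward(self):
        print("propagate bounds through the affine map")

class ReLULayer(Layer):
    def forward(self):
        print("build ReLU relaxation, then propagate")

# The framework's loop is oblivious to the concrete types:
for node in [LinearLayer(), ReLULayer()]:
    node.forward()  # polymorphic dispatch selects the right logic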
6.2 Zero Dependencies¶
PropDAG has zero dependencies (pure Python 3.10+). No PyTorch, NumPy, JAX, or other libraries required.
Why?
User choice: Verification researchers have strong preferences for tensor libraries. PropDAG doesn’t force a choice.
No conflicts: Avoids dependency version conflicts
Lightweight: Small installation, no heavy frameworks
The trade-off: PropDAG’s pure Python is slower than NumPy. But this doesn’t matter—in practice, graph traversal (topological sorting, reference counting) accounts for <1% of verification runtime. The computational bottleneck is bound propagation (SMT solving, constraint generation), which users implement in their TNode.forward() methods. PropDAG’s pure Python infrastructure is fast enough to be imperceptible.
Users bring their own tensor library for the actual numerical computations in node implementations.
6.3 Educational Design¶
PropDAG prioritizes learnability:
Verbose toy examples: Show execution flow step-by-step
Clear abstractions: Abstract methods explicitly document what users must implement
Minimal API surface: Three main classes (TNode, TModel, TCache), simple to understand
The toy examples serve as templates: users copy-paste and modify for their algorithms.
6.4 Why Frozen Dataclasses¶
TArgument uses @dataclass(frozen=True, slots=True):
Frozen (immutability):
- Prevents accidental configuration changes during propagation
- Thread-safe (no locks needed)
- Easier debugging (state doesn't change mid-execution)
Slots:
- Reduces memory overhead (no __dict__ per instance)
- Faster attribute access
The slight inconvenience (must create new objects to change values) is worth the correctness guarantees.
Part 7: What PropDAG Provides¶
7.1 Core Infrastructure¶
PropDAG provides:
Topological sorting: BFS and DFS implementations with cycle detection
Graph traversal: Forward and backward passes with configurable modes
Cache management: Reference counting for automatic cache clearing
Abstract interfaces: Clear contracts for what users must implement
Template classes: Base implementations of TModel and TCache
Users get infrastructure for free and focus on algorithm implementation.
7.2 What Users Implement¶
To use PropDAG, users must:
Extend TNode: Implement abstract methods for their algorithm
Define TCache: Specify what to cache (bounds, symbolic expressions, etc.)
Build node graph: Create node instances and set pre_nodes/next_nodes
Run: Call model.run()
Minimal example:
class MyNode(TNode):
    def forward(self):
        # Your bound propagation logic
        ...

    def backward(self):
        # Your backward substitution logic
        ...

    # Implement other abstract methods as needed

# Build graph
nodes = [MyNode(...) for _ in range(n)]

# Set edges
nodes[0].next_nodes = [nodes[1]]
nodes[1].pre_nodes = [nodes[0]]
# ... etc ...

model = TModel(nodes, sort_strategy="bfs")
model.run()
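To make "your bound propagation logic" concrete, here is a tiny interval-arithmetic instance of the pattern. It is a self-contained sketch with a stand-in Node base class, so it runs without PropDAG; the same forward() bodies would slot into TNode subclasses:

class Node:
    def __init__(self, name):
        self.name, self.pre_nodes, self.next_nodes = name, [], []
        self.bounds = None  # (lower, upper) vectors

class InputNode(Node):
    def forward(self):
        pass  # bounds are set externally

class AffineNode(Node):
    def __init__(self, name, weight, bias):
        super().__init__(name)
        self.weight, self.bias = weight, bias

    def forward(self):
        lo, hi = self.pre_nodes[0].bounds
        new_lo, new_hi = [], []
        for row, b in zip(self.weight, self.bias):
            # A positive weight pulls in the lower input bound for the
            # output's lower bound; a negative weight pulls in the upper.
            new_lo.append(b + sum(w * (l if w >= 0 else h)
                                  for w, l, h in zip(row, lo, hi)))
            new_hi.append(b + sum(w * (h if w >= 0 else l)
                                  for w, l, h in zip(row, lo, hi)))
        self.bounds = (new_lo, new_hi)

class ReLUNode(Node):
    def forward(self):
        lo, hi = self.pre_nodes[0].bounds
        self.bounds = ([max(l, 0.0) for l in lo], [max(h, 0.0) for h in hi])

x = InputNode("x")
x.bounds = ([-1.0, -1.0], [1.0, 1.0])
z = AffineNode("z", weight=[[1.0, -1.0], [1.0, 1.0]], bias=[0.0, 0.0])
y = ReLUNode("y")
z.pre_nodes, y.pre_nodes = [x], [z]

for node in [x, z, y]:  # topological order
    node.forward()
print(y.bounds)  # ([0.0, 0.0], [2.0, 2.0])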
7.3 Design Principles Evident in Code¶
Several design principles are evident from the codebase structure:
Separation of concerns: TModel handles graph-level logic (sorting, traversal), TNode handles node-level logic (bound computation).
Dependency inversion: TModel depends on abstract TNode interface, not concrete implementations. Users can swap node implementations without changing TModel.
Open-closed principle: Framework is closed for modification (users don’t edit TModel), but open for extension (users extend TNode).
Single responsibility: Each class has one job (TNode: compute bounds, TModel: orchestrate traversal, TCache: store results).
These aren’t explicitly documented but are evident from the code structure.
Conclusion¶
Where PropDAG Fits¶
PropDAG is an infrastructure framework for neural network verification research. It provides graph traversal infrastructure but doesn’t implement specific verification algorithms.
Think of it as:
Not a complete verification tool (doesn’t verify properties)
Not an algorithm library (doesn’t implement CROWN, DeepPoly, etc.)
Yes: an infrastructure layer that eliminates boilerplate
Users still implement the interesting parts (bound propagation logic), but get graph traversal for free.
What PropDAG Provides¶
Concrete deliverables:
Abstract base classes (TNode, TModel, TCache)
Topological sorting (BFS/DFS with cycle detection)
Graph traversal orchestration (forward/backward modes)
Reference counting cache management
Toy examples with verbose logging
Design benefits:
Zero dependencies (pure Python)
Negligible overhead (graph traversal <1% of verification runtime)
Clear abstractions (template method pattern)
Educational examples (toy implementations)
Flexible configuration (BFS/DFS, forward/backward modes)
Try It Yourself¶
PropDAG is open-source:
git clone https://github.com/ZhongkuiMa/propdag.git
cd propdag
Quick start (from toy examples):
from propdag import BackwardToyNode, PropMode, ToyArgument, ToyCache, ToyModel

# Create cache and arguments
cache = ToyCache()
cache.bnds["Node-1"] = ("input bounds",)
arguments = ToyArgument(prop_mode=PropMode.BACKWARD)

# Create nodes
node1 = BackwardToyNode("Node-1", cache, arguments)
node2 = BackwardToyNode("Node-2", cache, arguments)
# ... etc ...

# Build graph (set pre_nodes, next_nodes)
node1.next_nodes = [node2]
node2.pre_nodes = [node1]

# Run
model = ToyModel([node1, node2], verbose=True)
model.run()
For real algorithms: Extend TNode with your bound propagation logic, define your TCache strategy, and build the node graph from your neural network.
Final Thoughts¶
PropDAG solves a simple problem: researchers shouldn’t reimplement graph traversal for every verification algorithm.
The solution: template method pattern + abstract base classes. PropDAG provides infrastructure (sorting, traversal, cache management). Users provide algorithm logic (bound propagation).
The framework is small (~500 lines), has zero dependencies, and provides clear abstractions for extending with new algorithms.
If you’re building neural network verification tools, PropDAG can eliminate infrastructure code and let you focus on research.