Motivation
There is a clear gap in the state of the art for embedding trained ML models (NN surrogates, RL policies, learned residuals) into system simulation:
- ORT-as-blackbox (DNV `mlfmu`, ONNX2FMU, Simulink Co-Execution, MuJoCo policy deployment): simple, but every model evaluation crosses an ORT `Session.Run` boundary. Implicit/Newton solvers evaluate the model several times per step (residuals plus finite-difference Jacobian columns), so this overhead multiplies. No autodiff, no FPI friendliness.
- Native operator mapping (Simulink `importNetworkFromONNX`, conceptually L4CasADi over PyTorch): fast and AD-capable, but framework-specific, with no generic ONNX path.
To our knowledge, no open-source project currently lowers ONNX ops onto an AD-capable JIT IR. That is precisely where fastsim's existing JIT (`src/jit/`) fits naturally: it already has an SSA operator DAG, symbolic reverse-mode autodiff, hash-consing CSE, algebraic simplification, DCE, and FMA fusion. An ONNX importer that emits into the same graph gets all of that for free.
This issue proposes an ONNXFunction block that:
- Loads the `.onnx` file at construction and eagerly lowers it to the JIT SSA graph (shapes are statically known from `ValueInfoProto`)
- Executes in the Rust hot path with zero Python boundary and zero ORT dependency
- Composes like any other feedthrough block (connects into the Simulation DAG, participates in FPI)
- Gets AD for free via the existing backward pass — relevant for implicit solvers that need Jacobians through NN surrogates
Competitive context
| Approach | AD | Solver-step overhead | Dependency footprint | Scope |
|---|---|---|---|---|
| ORT blackbox (`mlfmu`, ONNX2FMU) | no | `Session.Run` per step | full ONNX Runtime C++ lib | any ONNX |
| Simulink native | yes | none | MATLAB (closed source) | framework-locked |
| L4CasADi | yes | none | PyTorch dependency, not ONNX | PyTorch only |
| fastsim `ONNXFunction` (proposed) | yes (free from JIT AD) | none | pure Rust, no ORT | ONNX op whitelist |
No known project occupies the "portable ONNX + native lowering + AD" quadrant. This is a genuinely novel position.
Proposed architecture
User-facing API
```python
# Import locations for Simulation, Connection and DynamicalSystem are assumed.
from fastsim import Simulation, Connection
from fastsim.blocks import ONNXFunction, ODE, DynamicalSystem

# plant_rhs, f_phys, x0, n (NN input width) and m (NN output width) are user-defined.

# Pure feedthrough ML block in the DAG
policy = ONNXFunction("policy.onnx")  # input/output ports auto-inferred from the model
plant = ODE(plant_rhs, initial_value=x0)

sim = Simulation(
    blocks=[plant, policy],
    connections=[
        Connection(plant[:n], policy[:n]),    # plant state → NN input
        Connection(policy[:m], plant.u[:m]),  # NN output → plant input
    ],
)

# Optional: hybrid physics+NN dynamics
residual = ONNXFunction("residual.onnx")
sys = DynamicalSystem(
    f=lambda x, u, t: f_phys(x, u, t) + residual(x, u),  # residual() returns a symbolic array while tracing
    initial_value=x0,
)
```
Eager lowering (no lazy-rejit)
Unlike Python callables, whose input shapes only resolve at connection time, ONNX models declare tensor shapes in `ValueInfoProto`. When every dimension is a concrete `dim_value`, we trace once at construction, with no need for the `LazyTraced` shape-keyed cache path. Dynamic axes (`dim_param`) could fall back to the lazy path, but that is out of scope for the MVP.
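A minimal sketch of that construction-time check, assuming a hypothetical `Dim` enum that mirrors the `dim_value` / `dim_param` choice in ONNX's `TensorShapeProto.Dimension`. `resolve_static_shape` and its messages are illustrative names, not existing fastsim code; the sketch implements the v1 policy of rejecting any axis that is not a concrete size:

```rust
/// Hypothetical mirror of ONNX `TensorShapeProto.Dimension`: each axis carries
/// a concrete `dim_value`, a symbolic `dim_param`, or nothing at all.
#[derive(Debug, Clone)]
pub enum Dim {
    Value(i64),
    Param(String),
    Unknown,
}

/// Resolve a tensor's shape once, at construction. Any non-concrete axis is a
/// hard error in v1 (no lazy-rejit fallback).
pub fn resolve_static_shape(tensor: &str, dims: &[Dim]) -> Result<Vec<usize>, String> {
    dims.iter()
        .enumerate()
        .map(|(axis, dim)| match dim {
            Dim::Value(v) if *v >= 0 => Ok(*v as usize),
            Dim::Value(v) => Err(format!("tensor '{tensor}': axis {axis} has negative size {v}")),
            Dim::Param(p) => Err(format!(
                "tensor '{tensor}': axis {axis} is dynamic (dim_param = '{p}'); \
                 freeze shapes before importing"
            )),
            Dim::Unknown => Err(format!("tensor '{tensor}': axis {axis} has no static size")),
        })
        .collect()
}
```

Collecting into `Result` stops at the first bad axis, so the error names exactly one tensor and axis.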
Op dispatch
ONNX op → SSA-graph builder function. Four categories:
Category 1 — direct mapping to existing graph ops (zero new primitives):
Add, Sub, Mul, Div, Pow, Mod, Min, Max, Neg, Abs, Sqrt, Exp, Log, Sin, Cos, Tan, Tanh, Sinh, Cosh, Erf, MatMul, Gemm (α·A·B + β·C), Equal, Greater, Less, Where, Clip, ReduceSum, ReduceMean, ReduceMax, ReduceMin, ReduceProd, Concat
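To make "direct mapping" concrete for the matrix ops in this list, here is a sketch of lowering `Gemm` (with `transA = transB = 0`) into scalar `Mul`/`Add` nodes on a flat SSA graph. `Graph`, `Op` and `lower_gemm` are hypothetical stand-ins for `src/jit/graph.rs`, reduced to the four ops the example needs:

```rust
/// Hypothetical, stripped-down flat SSA graph (the real one has CSE, more ops, etc.).
#[derive(Debug, Clone, Copy)]
pub enum Op {
    Input(usize),
    Const(f64),
    Add(usize, usize),
    Mul(usize, usize),
}

#[derive(Default)]
pub struct Graph {
    pub nodes: Vec<Op>,
}

impl Graph {
    pub fn push(&mut self, op: Op) -> usize {
        self.nodes.push(op);
        self.nodes.len() - 1
    }

    /// Evaluate every node in SSA order (operands always precede their users).
    pub fn eval(&self, inputs: &[f64]) -> Vec<f64> {
        let mut v: Vec<f64> = Vec::with_capacity(self.nodes.len());
        for op in &self.nodes {
            let x = match *op {
                Op::Input(i) => inputs[i],
                Op::Const(c) => c,
                Op::Add(a, b) => v[a] + v[b],
                Op::Mul(a, b) => v[a] * v[b],
            };
            v.push(x);
        }
        v
    }
}

/// Lower ONNX `Gemm` with transA = transB = 0: Y = alpha * A·B + beta * C.
/// `a` (m×k) and `b` (k×n) are row-major node ids; `c` may be empty (no bias),
/// a single scalar, length n (bias broadcast over rows) or length m·n.
pub fn lower_gemm(
    g: &mut Graph,
    a: &[usize],
    b: &[usize],
    c: &[usize],
    (m, k, n): (usize, usize, usize),
    alpha: f64,
    beta: f64,
) -> Vec<usize> {
    assert!(k > 0 && a.len() == m * k && b.len() == k * n);
    let alpha_id = g.push(Op::Const(alpha));
    let beta_id = g.push(Op::Const(beta));
    let mut out = Vec::with_capacity(m * n);
    for i in 0..m {
        for j in 0..n {
            // Row i of A times column j of B, as a chain of multiply-adds.
            let mut acc = g.push(Op::Mul(a[i * k], b[j]));
            for kk in 1..k {
                let prod = g.push(Op::Mul(a[i * k + kk], b[kk * n + j]));
                acc = g.push(Op::Add(acc, prod));
            }
            let mut y = g.push(Op::Mul(alpha_id, acc));
            let c_id = match c.len() {
                0 => None,
                1 => Some(c[0]),
                len if len == n => Some(c[j]),
                len if len == m * n => Some(c[i * n + j]),
                len => panic!("unsupported Gemm bias length {len}"),
            };
            if let Some(c_id) = c_id {
                let scaled = g.push(Op::Mul(beta_id, c_id));
                y = g.push(Op::Add(y, scaled));
            }
            out.push(y);
        }
    }
    out
}
```

The builder multiplies by `alpha` and `beta` unconditionally to stay short; a real builder could skip them when they equal 1.0, or leave that to the JIT's algebraic simplification.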
Category 2 — synthesizable from existing graph ops (no new primitives, just builder composition):
Relu(x) → max(x, 0)
Sigmoid(x) → 1 / (1 + exp(-x))
Tanh(x) — direct
Gelu(x) → 0.5·x·(1 + erf(x/√2)) (we already have Erf as a graph unary)
Softplus(x) → log1p(exp(x)), emitted in the overflow-safe form max(x, 0) + log1p(exp(-|x|))
LeakyRelu(x, α) → Select(x > 0, x, α·x)
Elu(x, α) → Select(x > 0, x, α·(exp(x) - 1))
Softmax(x) → exp(x - max(x)) / ReduceSum(exp(x - max(x))) along the softmax axis (subtracting the max keeps exp from overflowing)
LayerNormalization → (x - μ)/√(σ² + ε)·γ + β (all ops present)
BatchNormalization — fold constants where possible; runtime form is the same normalization
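The Category 2 rewrites above can be sketched as pure builder composition. To keep the example self-contained it uses a hypothetical scalar expression tree `E` instead of graph node ids, so repeated subterms are cloned here where the real graph would share them via hash-consing. Only primitives from the existing op list appear; Gelu is left out because Rust's standard library has no `erf` to evaluate it with, and Softmax/LayerNorm need vector reductions:

```rust
/// Hypothetical scalar expression tree standing in for graph node ids.
#[derive(Debug, Clone)]
pub enum E {
    X,
    C(f64),
    Add(Box<E>, Box<E>),
    Sub(Box<E>, Box<E>),
    Mul(Box<E>, Box<E>),
    Div(Box<E>, Box<E>),
    Max(Box<E>, Box<E>),
    Neg(Box<E>),
    Abs(Box<E>),
    Exp(Box<E>),
    Log1p(Box<E>),
    Gt(Box<E>, Box<E>),             // 1.0 if lhs > rhs, else 0.0
    Select(Box<E>, Box<E>, Box<E>), // a where cond != 0, else b
}

fn b(e: E) -> Box<E> {
    Box::new(e)
}

pub fn relu(x: E) -> E {
    E::Max(b(x), b(E::C(0.0)))
}

pub fn sigmoid(x: E) -> E {
    E::Div(b(E::C(1.0)), b(E::Add(b(E::C(1.0)), b(E::Exp(b(E::Neg(b(x))))))))
}

/// Overflow-safe softplus: max(x, 0) + log1p(exp(-|x|)).
pub fn softplus(x: E) -> E {
    let tail = E::Log1p(b(E::Exp(b(E::Neg(b(E::Abs(b(x.clone()))))))));
    E::Add(b(relu(x)), b(tail))
}

pub fn leaky_relu(x: E, alpha: f64) -> E {
    let neg_branch = E::Mul(b(E::C(alpha)), b(x.clone()));
    E::Select(b(E::Gt(b(x.clone()), b(E::C(0.0)))), b(x), b(neg_branch))
}

pub fn elu(x: E, alpha: f64) -> E {
    let neg_branch = E::Mul(b(E::C(alpha)), b(E::Sub(b(E::Exp(b(x.clone()))), b(E::C(1.0)))));
    E::Select(b(E::Gt(b(x.clone()), b(E::C(0.0)))), b(x), b(neg_branch))
}

/// Reference evaluator for checking the rewrites numerically.
pub fn eval(e: &E, x: f64) -> f64 {
    match e {
        E::X => x,
        E::C(c) => *c,
        E::Add(l, r) => eval(l, x) + eval(r, x),
        E::Sub(l, r) => eval(l, x) - eval(r, x),
        E::Mul(l, r) => eval(l, x) * eval(r, x),
        E::Div(l, r) => eval(l, x) / eval(r, x),
        E::Max(l, r) => eval(l, x).max(eval(r, x)),
        E::Neg(a) => -eval(a, x),
        E::Abs(a) => eval(a, x).abs(),
        E::Exp(a) => eval(a, x).exp(),
        E::Log1p(a) => eval(a, x).ln_1p(),
        E::Gt(l, r) => if eval(l, x) > eval(r, x) { 1.0 } else { 0.0 },
        E::Select(c, a, z) => if eval(c, x) != 0.0 { eval(a, x) } else { eval(z, x) },
    }
}
```

Because `softplus` uses the stable form, `softplus(1000.0)` evaluates to `1000.0` instead of overflowing through `exp(1000.0)`.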
Category 3 — shape-metadata only (the JIT core is 1D-flat, so these are importer-side index shuffles with no runtime cost):
Reshape, Transpose, Squeeze, Unsqueeze, Flatten, Identity
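Since the JIT core is flat, these ops never touch the graph. A sketch of what they reduce to on a hypothetical importer-side `TensorView` (the struct proposed in the implementation plan): `Reshape`, `Squeeze` and `Unsqueeze` only rewrite the shape, and `Transpose` permutes the node ids. ONNX `Reshape`'s special `0`/`-1` target values and multi-axis `Squeeze` are ignored here:

```rust
pub type NodeId = usize;

#[derive(Debug, Clone, PartialEq)]
pub struct TensorView {
    pub shape: Vec<usize>,
    pub node_ids: Vec<NodeId>, // row-major
}

/// Row-major strides: the last axis is contiguous.
fn strides(shape: &[usize]) -> Vec<usize> {
    let mut s = vec![1; shape.len()];
    for d in (0..shape.len().saturating_sub(1)).rev() {
        s[d] = s[d + 1] * shape[d + 1];
    }
    s
}

impl TensorView {
    /// Reshape: same node ids in the same order, new shape (static target only).
    pub fn reshape(&self, shape: &[usize]) -> Result<TensorView, String> {
        let n: usize = shape.iter().product();
        if n != self.node_ids.len() {
            return Err(format!("cannot reshape {:?} to {:?}", self.shape, shape));
        }
        Ok(TensorView { shape: shape.to_vec(), node_ids: self.node_ids.clone() })
    }

    /// ONNX Transpose: output axis d has size shape[perm[d]].
    pub fn transpose(&self, perm: &[usize]) -> TensorView {
        let in_strides = strides(&self.shape);
        let out_shape: Vec<usize> = perm.iter().map(|&p| self.shape[p]).collect();
        let out_strides = strides(&out_shape);
        let node_ids = (0..self.node_ids.len())
            .map(|flat| {
                // Decompose the output flat index, re-linearize with input strides.
                let src: usize = (0..out_shape.len())
                    .map(|d| (flat / out_strides[d]) % out_shape[d] * in_strides[perm[d]])
                    .sum();
                self.node_ids[src]
            })
            .collect();
        TensorView { shape: out_shape, node_ids }
    }

    pub fn squeeze(&self, axis: usize) -> Result<TensorView, String> {
        if self.shape.get(axis) != Some(&1) {
            return Err(format!("axis {} of {:?} is not size 1", axis, self.shape));
        }
        let mut shape = self.shape.clone();
        shape.remove(axis);
        Ok(TensorView { shape, node_ids: self.node_ids.clone() })
    }

    pub fn unsqueeze(&self, axis: usize) -> TensorView {
        let mut shape = self.shape.clone();
        shape.insert(axis, 1);
        TensorView { shape, node_ids: self.node_ids.clone() }
    }
}
```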
Category 4 — genuine new work (post-MVP):
Conv / ConvTranspose — unrolling into MACs over existing ops; needs windowing logic in importer
MaxPool / AveragePool — windowing in importer
Slice, Gather, ScatterND — require proper index semantics; pair with Tier-1 JIT work on N-D indexing
LSTM / GRU / RNN — stateful; belong in DynamicalSystem/ODE, not in a stateless ONNXFunction. Deferred indefinitely.
What does not go in
- Control flow: `If`, `Loop`, `Scan` break the SSA model. Hard error at import with a clear message.
- Detection-specific: `NonMaxSuppression`, `TopK`, etc.
- Training-only: `Gradient`, `TrainingInfo`. Irrelevant for inference-time integration. For `Dropout` / `BatchNormalization` in training mode, we expect the user to run `onnx-simplifier` / `onnxoptimizer` first (standard pre-processing).
- Fallback runtime (`tract`, `ort`): no. Pulling in a parallel ONNX runtime defeats the whole point (zero ORT boundary, AD through everything, single optimization pipeline). If a model uses unsupported ops, import fails with a hard error; the user can either (a) simplify the model or (b) use the `mlfmu`/ONNX2FMU route instead. fastsim is not a universal ONNX runtime; it is a simulation-native ONNX importer for the op subset that simulation models actually use.
JIT coverage audit (baseline)
The existing JIT already provides everything we need except the frontend. From `src/jit/graph.rs`:
- Binary (10): Add, Sub, Mul, Div, Pow, Mod, Min, Max, Atan2, Hypot
- Unary (32): all trig/hyperbolic functions and their inverses, plus Exp, Log, Log2, Log10, Log1p, Expm1, Abs, Sqrt, Cbrt, Sign, Floor, Ceil, Round, Trunc, Erf, Erfc, Lgamma, Tgamma, Digamma, Neg
- Cmp (6): Gt, Ge, Lt, Le, Eq, Ne
- Control: Select (branchless), Fma
Python/numpy surface (`src/jit/tracer.rs`), relevant for composing ONNX-style ops symbolically:
- Linear algebra: `@`, `np.dot`, `np.matmul`, `np.vdot`, `np.inner`, `np.cross`
- Reductions: `np.sum`, `prod`, `min`, `max`, `mean`, `var`, `std`
- Array assembly: `np.stack`, `hstack`, `vstack`, `concatenate`, `asarray`, `array`
- Control: `np.clip`, `np.where`, `fastsim.where_`, `clip`
AD (`src/jit/autodiff.rs`):
- Reverse-mode, symbolic, memoized
- `jacobian`, `jacobian_wrt_slot(name)`, `jacobian_wrt_flat_range`: all emit into the same graph and go through the same optimize/DCE passes
Conclusion: the real work is the ONNX frontend + an N-D shape metadata layer on the importer side. Not a backend change.
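To illustrate why an importer that emits into the same graph gets AD "for free", here is a toy version of symbolic reverse mode: adjoints are appended as ordinary nodes to the same graph, so they are evaluated (and, in the real JIT, optimized) like any forward node. The `Graph`/`Op` types and the `grad` method are hypothetical and far smaller than `src/jit/autodiff.rs`:

```rust
/// Hypothetical miniature SSA graph; operands always have smaller ids than users.
#[derive(Debug, Clone, Copy)]
pub enum Op {
    Input(usize),
    Const(f64),
    Add(usize, usize),
    Mul(usize, usize),
    Tanh(usize),
}

#[derive(Default)]
pub struct Graph {
    pub nodes: Vec<Op>,
}

impl Graph {
    pub fn push(&mut self, op: Op) -> usize {
        self.nodes.push(op);
        self.nodes.len() - 1
    }

    pub fn eval(&self, inputs: &[f64]) -> Vec<f64> {
        let mut v: Vec<f64> = Vec::with_capacity(self.nodes.len());
        for op in &self.nodes {
            let x = match *op {
                Op::Input(i) => inputs[i],
                Op::Const(c) => c,
                Op::Add(a, b) => v[a] + v[b],
                Op::Mul(a, b) => v[a] * v[b],
                Op::Tanh(a) => v[a].tanh(),
            };
            v.push(x);
        }
        v
    }

    /// Add `g` into the adjoint slot of `target`, emitting an Add node when the
    /// slot already holds a contribution (fan-out).
    fn accumulate(&mut self, adj: &mut [Option<usize>], target: usize, g: usize) {
        adj[target] = Some(match adj[target] {
            None => g,
            Some(prev) => self.push(Op::Add(prev, g)),
        });
    }

    /// Walk forward nodes in reverse SSA order, emitting adjoint nodes into the
    /// same graph. Entry i of the result is the node holding d(out)/d(node i),
    /// or None if node i does not influence `out`.
    pub fn grad(&mut self, out: usize) -> Vec<Option<usize>> {
        let mut adj: Vec<Option<usize>> = vec![None; out + 1];
        adj[out] = Some(self.push(Op::Const(1.0)));
        for id in (0..=out).rev() {
            let g = match adj[id] {
                Some(g) => g,
                None => continue,
            };
            let op = self.nodes[id];
            match op {
                Op::Input(_) | Op::Const(_) => {}
                Op::Add(a, b) => {
                    self.accumulate(&mut adj, a, g);
                    self.accumulate(&mut adj, b, g);
                }
                Op::Mul(a, b) => {
                    let ga = self.push(Op::Mul(g, b));
                    self.accumulate(&mut adj, a, ga);
                    let gb = self.push(Op::Mul(g, a));
                    self.accumulate(&mut adj, b, gb);
                }
                Op::Tanh(a) => {
                    // d tanh(u)/du = 1 - tanh(u)^2, reusing the forward node `id`.
                    let t2 = self.push(Op::Mul(id, id));
                    let minus_one = self.push(Op::Const(-1.0));
                    let neg_t2 = self.push(Op::Mul(minus_one, t2));
                    let one = self.push(Op::Const(1.0));
                    let dtanh = self.push(Op::Add(one, neg_t2));
                    let ga = self.push(Op::Mul(g, dtanh));
                    self.accumulate(&mut adj, a, ga);
                }
            }
        }
        adj
    }
}
```

The `tanh` adjoint reuses the forward `tanh` node rather than recomputing it, and fan-out is handled by summing adjoint contributions with extra `Add` nodes.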
Implementation plan
ONNX proto parsing
- Crate: `prost` + generated Rust types from `onnx.proto3`
- Path: `src/jit/onnx/proto.rs` (auto-generated)
- Alternative: a hand-written minimal parser if we want zero build-time codegen. Not recommended; `prost` is well-behaved.
Importer scaffolding (`src/jit/onnx/importer.rs`)
- `ModelProto` → `GraphProto` → topological op walk (the ONNX IR spec requires `GraphProto.node` to be topologically sorted, so one in-order pass suffices)
- Initializer tensors → `Graph::constant` nodes (folded at build time)
- `ValueInfoProto` → shape table keyed by tensor name
- Op dispatcher: `HashMap<op_type, fn(&mut Graph, &NodeProto, &ShapeCtx) -> PyResult<()>>`
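A sketch of the dispatcher and the single in-order walk. `NodeProto` here is a hand-written hypothetical struct with the same field names as the ONNX message (`op_type`, `name`, `input`, `output`), tensors are scalars, and builders return `Result<(), String>` rather than the `PyResult<()>` (plus `&ShapeCtx` argument) proposed above:

```rust
use std::collections::HashMap;

#[derive(Debug, Clone, Copy)]
pub enum Op {
    Input(usize),
    Const(f64),
    Add(usize, usize),
    Mul(usize, usize),
    Max(usize, usize),
}

/// Hypothetical stand-in for the prost-generated ONNX `NodeProto`.
pub struct NodeProto {
    pub op_type: String,
    pub name: String,
    pub input: Vec<String>,
    pub output: Vec<String>,
}

/// Flat graph plus a tensor-name → node-id table (scalar tensors only).
#[derive(Default)]
pub struct Graph {
    pub nodes: Vec<Op>,
    pub env: HashMap<String, usize>,
}

impl Graph {
    pub fn push(&mut self, op: Op) -> usize {
        self.nodes.push(op);
        self.nodes.len() - 1
    }
    fn get(&self, name: &str) -> Result<usize, String> {
        self.env.get(name).copied().ok_or_else(|| format!("undefined tensor '{name}'"))
    }
    pub fn eval(&self, inputs: &[f64]) -> Vec<f64> {
        let mut v: Vec<f64> = Vec::with_capacity(self.nodes.len());
        for op in &self.nodes {
            let x = match *op {
                Op::Input(i) => inputs[i],
                Op::Const(c) => c,
                Op::Add(a, b) => v[a] + v[b],
                Op::Mul(a, b) => v[a] * v[b],
                Op::Max(a, b) => v[a].max(v[b]),
            };
            v.push(x);
        }
        v
    }
}

// The proposed signature also takes `&ShapeCtx` and returns `PyResult<()>`.
type Builder = fn(&mut Graph, &NodeProto) -> Result<(), String>;

fn binary(g: &mut Graph, n: &NodeProto, make: fn(usize, usize) -> Op) -> Result<(), String> {
    let (a, b) = (g.get(&n.input[0])?, g.get(&n.input[1])?);
    let id = g.push(make(a, b));
    g.env.insert(n.output[0].clone(), id);
    Ok(())
}
fn build_add(g: &mut Graph, n: &NodeProto) -> Result<(), String> { binary(g, n, Op::Add) }
fn build_mul(g: &mut Graph, n: &NodeProto) -> Result<(), String> { binary(g, n, Op::Mul) }

/// Category 2 example: Relu(x) → Max(x, 0).
fn build_relu(g: &mut Graph, n: &NodeProto) -> Result<(), String> {
    let x = g.get(&n.input[0])?;
    let zero = g.push(Op::Const(0.0));
    let id = g.push(Op::Max(x, zero));
    g.env.insert(n.output[0].clone(), id);
    Ok(())
}

fn dispatch_table() -> HashMap<&'static str, Builder> {
    let mut table: HashMap<&'static str, Builder> = HashMap::new();
    table.insert("Add", build_add);
    table.insert("Mul", build_mul);
    table.insert("Relu", build_relu);
    table
}

/// Nodes arrive topologically sorted, so one in-order pass suffices.
pub fn import(g: &mut Graph, nodes: &[NodeProto]) -> Result<(), String> {
    let table = dispatch_table();
    for n in nodes {
        let build = table
            .get(n.op_type.as_str())
            .ok_or_else(|| format!("unsupported ONNX op '{}' at node '{}'", n.op_type, n.name))?;
        build(g, n)?;
    }
    Ok(())
}
```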
N-D shape layer (internal to the importer only; the JIT core stays 1D-flat)
- `TensorView { shape: Vec<usize>, node_ids: Vec<NodeId> }`: row-major flattening
- Broadcasting logic (N-D numpy rules) over `TensorView`
- `Reshape` / `Transpose` / `Squeeze` / `Unsqueeze` become pure metadata ops (shuffle `node_ids`, keep the graph unchanged)
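A sketch of the broadcasting part, using the `TensorView` layout listed above (everything else is hypothetical). Broadcasting only repeats existing node ids, so an elementwise op on a broadcast pair costs exactly one new graph node per output element:

```rust
pub type NodeId = usize;

#[derive(Debug, Clone, PartialEq)]
pub struct TensorView {
    pub shape: Vec<usize>,
    pub node_ids: Vec<NodeId>, // row-major
}

/// NumPy-style (ONNX multidirectional) broadcasting: align trailing axes; each
/// pair of sizes must match or one of them must be 1.
pub fn broadcast_shape(a: &[usize], b: &[usize]) -> Result<Vec<usize>, String> {
    let rank = a.len().max(b.len());
    let mut out = vec![0; rank];
    for i in 0..rank {
        // i counts from the trailing axis; missing leading axes act as size 1.
        let da = if i < a.len() { a[a.len() - 1 - i] } else { 1 };
        let db = if i < b.len() { b[b.len() - 1 - i] } else { 1 };
        out[rank - 1 - i] = match (da, db) {
            _ if da == db => da,
            (1, d) | (d, 1) => d,
            _ => return Err(format!("cannot broadcast {:?} with {:?}", a, b)),
        };
    }
    Ok(out)
}

impl TensorView {
    /// Node ids of `self` expanded to `out_shape`, row-major. Size-1 and missing
    /// leading axes get stride 0, so their node ids are simply repeated.
    pub fn broadcast_to(&self, out_shape: &[usize]) -> Vec<NodeId> {
        let offset = out_shape.len() - self.shape.len();
        let mut in_strides = vec![0; out_shape.len()];
        let mut s = 1;
        for d in (0..self.shape.len()).rev() {
            in_strides[d + offset] = if self.shape[d] == 1 { 0 } else { s };
            s *= self.shape[d];
        }
        let total: usize = out_shape.iter().product();
        (0..total)
            .map(|flat| {
                let (mut rem, mut src) = (flat, 0);
                for d in (0..out_shape.len()).rev() {
                    src += (rem % out_shape[d]) * in_strides[d];
                    rem /= out_shape[d];
                }
                self.node_ids[src]
            })
            .collect()
    }
}

/// Operand pairs for an elementwise binary op; the caller emits one node per pair.
pub fn elementwise_pairs(
    a: &TensorView,
    b: &TensorView,
) -> Result<(Vec<usize>, Vec<(NodeId, NodeId)>), String> {
    let shape = broadcast_shape(&a.shape, &b.shape)?;
    let pairs: Vec<(NodeId, NodeId)> =
        a.broadcast_to(&shape).into_iter().zip(b.broadcast_to(&shape)).collect();
    Ok((shape, pairs))
}
```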
Op builders (`src/jit/onnx/ops/*.rs`, one module per category)
- `elementwise.rs`: Add, Sub, Mul, Div, Pow, Min, Max, etc.
- `activations.rs`: Relu, Sigmoid, Tanh, Gelu, Softplus, Softmax, LeakyRelu, Elu
- `linalg.rs`: MatMul, Gemm
- `shape.rs`: Reshape, Transpose, Squeeze, Unsqueeze, Flatten
- `reductions.rs`: ReduceSum, ReduceMean, ReduceMax, ReduceMin, ReduceProd
- `norm.rs`: LayerNorm, BatchNorm (constant-folded)
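For `norm.rs`, inference-mode BatchNormalization is affine per channel. When scale, bias, mean and variance all come from initializers, the importer can fold them once: scale = γ / √(var + ε) and shift = β - mean · scale, leaving one multiply-add per element (a natural `Fma` candidate). A sketch of that folding, plus a plain reference LayerNorm that a test could compare the lowered graph against; both helpers are hypothetical:

```rust
/// Fold inference-mode BatchNormalization constants, per channel:
/// gamma * (x - mean) / sqrt(var + eps) + beta  ==  scale * x + shift.
pub fn fold_batchnorm(
    gamma: &[f64],
    beta: &[f64],
    mean: &[f64],
    var: &[f64],
    eps: f64,
) -> (Vec<f64>, Vec<f64>) {
    let scale: Vec<f64> = gamma
        .iter()
        .zip(var)
        .map(|(g, v)| g / (v + eps).sqrt())
        .collect();
    let shift: Vec<f64> = beta
        .iter()
        .zip(mean)
        .zip(&scale)
        .map(|((b, m), s)| b - m * s)
        .collect();
    (scale, shift)
}

/// Reference LayerNorm over one normalized slice (population variance).
pub fn layer_norm(x: &[f64], gamma: &[f64], beta: &[f64], eps: f64) -> Vec<f64> {
    let n = x.len() as f64;
    let mu = x.iter().sum::<f64>() / n;
    let var = x.iter().map(|v| (v - mu).powi(2)).sum::<f64>() / n;
    let inv_std = 1.0 / (var + eps).sqrt();
    x.iter()
        .zip(gamma.iter().zip(beta))
        .map(|(v, (g, b))| (v - mu) * inv_std * g + b)
        .collect()
}
```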
Block constructor (`src/blocks/constructors/onnx.rs` + PyO3 binding)
- `ONNXFunction::new(path: &str)` → loads the model, runs the importer, returns a feedthrough block
- Input/output port dimensions inferred from `ValueInfoProto`
- Role: `{ is_alg: true, is_dyn: false, is_src: false, is_rec: false }` (stateless feedthrough)
- Slot names derived from ONNX input names for introspection / debugging
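A sketch of how port widths could fall out of the static shapes: each ONNX graph input becomes a named, contiguous range of the block's flat input vector, which is also the kind of range a slot-based Jacobian would address. `Role` copies the flags listed above; `Slot` and `layout_slots` are hypothetical:

```rust
use std::ops::Range;

/// Block role flags as listed in the proposal (stateless feedthrough).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Role {
    pub is_alg: bool,
    pub is_dyn: bool,
    pub is_src: bool,
    pub is_rec: bool,
}

pub const ONNX_FUNCTION_ROLE: Role =
    Role { is_alg: true, is_dyn: false, is_src: false, is_rec: false };

/// One ONNX graph input (or output) mapped onto the block's flat port vector.
#[derive(Debug, Clone, PartialEq)]
pub struct Slot {
    pub name: String,
    pub shape: Vec<usize>,
    pub range: Range<usize>,
}

/// Lay slots out back to back in declaration order; returns the slots and the
/// total flat port width.
pub fn layout_slots(values: &[(&str, Vec<usize>)]) -> (Vec<Slot>, usize) {
    let mut offset = 0;
    let mut slots = Vec::with_capacity(values.len());
    for (name, shape) in values {
        let width: usize = shape.iter().product();
        slots.push(Slot { name: name.to_string(), shape: shape.clone(), range: offset..offset + width });
        offset += width;
    }
    (slots, offset)
}
```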
Error diagnostics
- Unsupported op → clear error naming the op type, the ONNX node name, and a suggestion (run `onnx-simplifier`, or file an issue to add the op)
- Dynamic axes → fall back to the lazy-rejit path (v2) or hard error naming the offending tensor (v1)
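One way to make "clear error" concrete is a pre-flight scan that runs before any lowering, so the user sees every unsupported op type in one message rather than fixing them one at a time. The whitelist is deliberately abbreviated and the wording is illustrative:

```rust
use std::collections::BTreeSet;

/// Abbreviated whitelist for the sketch; the real one would cover Categories 1-3.
const SUPPORTED: &[&str] = &[
    "Add", "Sub", "Mul", "Div", "MatMul", "Gemm", "Relu", "Sigmoid", "Tanh", "Softmax",
    "Reshape", "Transpose", "Flatten",
];
const CONTROL_FLOW: &[&str] = &["If", "Loop", "Scan"];

/// Check `(op_type, node_name)` pairs in graph order; report each unsupported
/// op type once, at the first node that uses it.
pub fn preflight(nodes: &[(&str, &str)]) -> Result<(), String> {
    let mut seen = BTreeSet::new();
    let mut problems = Vec::new();
    for &(op, name) in nodes {
        if SUPPORTED.contains(&op) || !seen.insert(op) {
            continue;
        }
        let hint = if CONTROL_FLOW.contains(&op) {
            "control flow cannot be lowered into the SSA graph"
        } else {
            "run onnx-simplifier, or file an issue to add the op"
        };
        problems.push(format!("  {op} (first at node '{name}'): {hint}"));
    }
    if problems.is_empty() {
        Ok(())
    } else {
        Err(format!(
            "ONNX model uses {} unsupported op type(s):\n{}",
            problems.len(),
            problems.join("\n")
        ))
    }
}
```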
Tests
- Round-trip: train a tiny MLP in PyTorch → export to ONNX → `ONNXFunction` → compare to the ORT reference within tolerance
- Hybrid: plant ODE + ONNX residual in a closed loop; check solver convergence
- AD: compile `jacobian_wrt_slot` through an ONNX block and compare to finite differences
- Unsupported op: assert a clean error
MVP scope
v1 (this issue): Categories 1-3, fully inference-capable on MLP/FFN-style models (the dominant shape of surrogates, policies, learned residuals in control applications).
v2 (follow-up): Conv/ConvTranspose + pooling for spatial surrogates. Depends on completing the JIT's N-D shape Tier-1 work separately.
Deferred indefinitely: stateful ops (LSTM, GRU, RNN), control flow (If, Loop, Scan), detection ops, training ops.
Dependencies on other JIT work
This issue interlocks with the broader JIT op-surface expansion:
- N-D shape metadata layer (importer-side for this issue; Tracer-side generally) — a prerequisite for a clean implementation here, so we should sequence those together.
- Slice / Gather op surface in the JIT: reusable between ONNX Slice/Gather and user Python code that indexes arrays.
Open design questions
- Parameter surface: ONNX initializers are model weights. Fold all of them as `Const` nodes (no runtime update), or expose a subset as fastsim `Param` nodes so users can retune without re-importing? Default: fold all. Users can reload the model for weight changes.
- Shape fallback: if a model has `dim_param` dynamic axes, do we error at construction or fall back to lazy-rejit? Leaning toward an error with a clear message (users can freeze shapes with `onnx-shape-inference` + `onnx-simplifier`).
- `opset_version` policy: which opset versions do we commit to supporting? Propose opset ≥ 13 for MVP (covers practically everything shipped since 2020).
- Integer / bool tensors: fastsim's JIT is f64-only. Do we reject int/bool tensors, or auto-cast at boundaries? Propose: accept at the Python/block interface, cast internally, reject if the model has true integer-only subgraphs (rare in inference models).
References
Reference implementations & background:
- DNV Open Source: `mlfmu`, an ONNX-to-FMU wrapper (ORT-based). https://github.com/dnv-opensource/mlfmu
- HSBI: `NonLinearSystemNeuralNetworkFMU.jl`, ONNX replacing nonlinear blocks inside OpenModelica FMUs. https://github.com/AMIT-HSBI/NonLinearSystemNeuralNetworkFMU.jl
- Salzmann et al., Real-time Neural-MPC — arXiv:2203.07747; the L4CasADi approach that bypasses ONNX by mapping PyTorch directly into CasADi. Useful design reference for the "native lowering" paradigm.
- Thummerer et al., NeuralFMU — arXiv:2109.04351 + MDPI Electronics 11(19):3202. Structural integration of FMUs into Neural-ODE topologies — closest thing to what we want, but Julia-side and FMU-centric.
- Modelica Conference 2024: A Tool for the Implementation of ONNX Models in FMUs (ONNX2FMU). https://ecp.ep.liu.se/index.php/modelica/article/view/1351
ONNX spec:
- `onnx.proto` definition: https://github.com/onnx/onnx/blob/main/onnx/onnx.proto3
Rust ONNX runtimes (reference only, not dependencies):
- `tract` (Sonos): https://github.com/sonos/tract (pure Rust, no AD)
- `ort` (pykeio): https://github.com/pykeio/ort (ONNX Runtime C++ wrapper)