ONNXFunction block: native ONNX inference via JIT lowering (no ORT) #13

Description

@milanofthe

Motivation

There is a clear gap in the state of the art for embedding trained ML models (NN surrogates, RL policies, learned residuals) into system simulation:

  • ORT-as-blackbox (DNV mlfmu, ONNX2FMU, Simulink Co-Execution, MuJoCo policy deployment): simple, but each solver step incurs an ORT Session.Run boundary. Inside implicit/Newton solvers this multiplies. No autodiff, no FPI friendliness.
  • Native operator mapping (Simulink importNetworkFromONNX, conceptually L4CasADi over PyTorch): fast and AD-capable, but framework-specific — no generic ONNX path.

No open-source project currently lowers ONNX ops onto an AD-capable JIT IR. That is precisely where fastsim's existing JIT (src/jit/) can sit naturally: we already have an SSA operator DAG, symbolic reverse-mode autodiff, hash-consing CSE, algebraic simplification, DCE, and FMA fusion. An ONNX importer that emits into the same graph gets all of that for free.

This issue proposes an ONNXFunction block that:

  • Loads .onnx at construction, eagerly lowers to the JIT SSA graph (shapes are statically known from ValueInfoProto)
  • Executes in the Rust hot path with zero Python boundary and zero ORT dependency
  • Composes like any other feedthrough block (connects into the Simulation DAG, participates in FPI)
  • Gets AD for free via the existing backward pass — relevant for implicit solvers that need Jacobians through NN surrogates

Competitive context

| Approach | AD | Solver-step overhead | Dependency footprint | Scope |
| --- | --- | --- | --- | --- |
| ORT blackbox (mlfmu, ONNX2FMU) | no | Session.Run per step | full ONNX Runtime C++ lib | any ONNX |
| Simulink native | yes | none | MATLAB closed | framework-locked |
| L4CasADi | yes | none | PyTorch dep, not ONNX | PyTorch only |
| fastsim ONNXFunction (proposed) | yes (free from JIT AD) | none | pure Rust, no ORT | ONNX op whitelist |

No known project occupies the "portable ONNX + native lowering + AD" quadrant. This is a genuinely novel position.

Proposed architecture

User-facing API

from fastsim.blocks import ONNXFunction, ODE, Scope

# Pure feedthrough ML block in the DAG
policy = ONNXFunction("policy.onnx")  # input/output ports auto-inferred from model
plant  = ODE(plant_rhs, initial_value=x0)
sim = Simulation(
    blocks=[plant, policy],
    connections=[
        Connection(plant[:n], policy[:n]),    # plant state → NN input
        Connection(policy[:m], plant.u[:m]),  # NN output → plant input
    ],
)

# Optional: hybrid physics+NN dynamics
residual = ONNXFunction("residual.onnx")
sys = DynamicalSystem(
    f=lambda x, u, t: f_phys(x, u, t) + residual(x, u),  # residual() returns symbolic array in trace
    initial_value=x0,
)

Eager lowering (no lazy-rejit)

Unlike Python callables whose input shapes only resolve at connection time, ONNX models carry full shape information in ValueInfoProto. We trace once at construction — no need for the LazyTraced shape-keyed cache path. Dynamic axes (dim_param) can fall back to the lazy path, but that is not MVP.
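
A minimal sketch of the construction-time shape decision, assuming the dimensions of each tensor have already been read out of ValueInfoProto into plain Python values (an int for dim_value, a str for dim_param, None when the dimension is unset). The names resolve_static_shape and DynamicAxisError are hypothetical and not part of fastsim or onnx.

```python
from typing import Sequence, Tuple, Union

# int: dim_value, str: dim_param (symbolic axis), None: dimension not set
Dim = Union[int, str, None]


class DynamicAxisError(ValueError):
    """Hypothetical error for tensors whose shape is not fully static."""


def resolve_static_shape(tensor_name: str, dims: Sequence[Dim]) -> Tuple[int, ...]:
    """Return the shape as ints, or raise naming the tensor and the axis.

    Assumption: only strictly positive integer dimensions count as static.
    """
    shape = []
    for axis, dim in enumerate(dims):
        is_static = isinstance(dim, int) and not isinstance(dim, bool) and dim > 0
        if not is_static:
            raise DynamicAxisError(
                f"tensor '{tensor_name}' has a non-static axis {axis} ({dim!r}); "
                "freeze the shape before import"
            )
        shape.append(dim)
    return tuple(shape)
```

With this split, the eager path runs whenever every graph input resolves, and the error message carries enough context for the user to freeze the offending axis.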

Op dispatch

ONNX op → SSA-graph builder function. Four categories, the last of which is post-MVP:

Category 1 — direct mapping to existing graph ops (zero new primitives):
Add, Sub, Mul, Div, Pow, Mod, Min, Max, Neg, Abs, Sqrt, Exp, Log, Sin, Cos, Tan, Tanh, Sinh, Cosh, Erf, MatMul, Gemm (α·A·B + β·C), Equal, Greater, Less, Where, Clip, ReduceSum, ReduceMean, ReduceMax, ReduceMin, ReduceProd, Concat
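
To illustrate what "direct mapping" could look like for Gemm on a 1D-flat scalar graph, here is a rough sketch. The Graph class below is a toy stand-in with hash-consing, not fastsim's actual graph API, and gemm() is a hypothetical builder that unrolls α·A·B + β·C into scalar mul/add nodes over row-major node-id lists.

```python
class Graph:
    """Toy SSA graph: node = (op, args); structurally equal nodes are shared."""

    def __init__(self):
        self.nodes = []    # node id -> (op, args)
        self._cache = {}   # hash-consing table: (op, args) -> node id

    def _emit(self, op, *args):
        key = (op, args)
        if key not in self._cache:
            self._cache[key] = len(self.nodes)
            self.nodes.append(key)
        return self._cache[key]

    def const(self, value):
        return self._emit("const", float(value))

    def input(self, index):
        return self._emit("input", index)

    def add(self, a, b):
        return self._emit("add", a, b)

    def mul(self, a, b):
        return self._emit("mul", a, b)

    def evaluate(self, node_ids, inputs):
        """Evaluate in emission order; arguments always refer to earlier nodes."""
        values = []
        for op, args in self.nodes:
            if op == "const":
                values.append(args[0])
            elif op == "input":
                values.append(inputs[args[0]])
            elif op == "add":
                values.append(values[args[0]] + values[args[1]])
            else:  # "mul"
                values.append(values[args[0]] * values[args[1]])
        return [values[i] for i in node_ids]


def gemm(g, a_ids, b_ids, c_ids, m, k, n, alpha=1.0, beta=1.0):
    """Emit Y = alpha*A@B + beta*C for row-major A (m x k), B (k x n), C (m x n)."""
    al, be = g.const(alpha), g.const(beta)
    out = []
    for i in range(m):
        for j in range(n):
            acc = g.mul(be, c_ids[i * n + j])
            for p in range(k):
                prod = g.mul(a_ids[i * k + p], b_ids[p * n + j])
                acc = g.add(acc, g.mul(al, prod))
            out.append(acc)
    return out


# Dense layer y = x @ W + b with a 1 x 2 input and constant weights.
g = Graph()
x = [g.input(0), g.input(1)]
w = [g.const(v) for v in (1.0, 2.0, 3.0, 4.0)]  # 2 x 2, row-major
b = [g.const(v) for v in (0.5, -0.5)]
y = gemm(g, x, w, b, m=1, k=2, n=2)
result = g.evaluate(y, [10.0, 20.0])
```

In a real importer the multiplications by α = 1 and β = 1 would presumably be removed by the existing algebraic simplification pass, and the mul/add pairs would be candidates for FMA fusion.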

Category 2 — synthesizable from existing graph ops (no new primitives, just builder composition):

  • Relu(x) → max(x, 0)
  • Sigmoid(x) → 1 / (1 + exp(-x))
  • Tanh(x) → direct (already a graph unary, see Category 1; listed here only for completeness among the activations)
  • Gelu(x) → 0.5·x·(1 + erf(x/√2)) (we already have Erf as a graph unary)
  • Softplus(x) → log1p(exp(x))
  • LeakyRelu(x, α) → Select(x > 0, x, α·x)
  • Elu(x, α) → Select(x > 0, x, α·(exp(x) - 1))
  • Softmax(x) → Exp(x - max(x)) / ReduceSum(...) (use stable form)
  • LayerNormalization → (x - μ)/√(σ² + ε)·γ + β (all ops present)
  • BatchNormalization — fold constants where possible; runtime form is the same normalization
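
The compositions above can be checked numerically with a few lines of plain Python. This is only a scalar reference sketch using the math module, not the graph builders themselves. Two deliberate deviations from the list, both assumptions on my side: softplus uses the overflow-safe rewrite max(x, 0) + log1p(exp(-|x|)) instead of the literal log1p(exp(x)), and elu uses expm1 (which the JIT unary set already has) instead of exp(x) - 1.

```python
import math


def relu(x):
    return max(x, 0.0)


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def gelu(x):
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def softplus(x):
    # Overflow-safe rewrite of log1p(exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def leaky_relu(x, alpha=0.01):
    # Corresponds to Select(x > 0, x, alpha * x) in the graph.
    return x if x > 0 else alpha * x


def elu(x, alpha=1.0):
    # Corresponds to Select(x > 0, x, alpha * (exp(x) - 1)).
    return x if x > 0 else alpha * math.expm1(x)


def softmax(xs):
    # Stable form: subtract the maximum before exponentiating.
    m = max(xs)
    exps = [math.exp(v - m) for v in xs]
    total = sum(exps)
    return [e / total for e in exps]


def layer_norm(xs, gamma, beta, eps=1e-5):
    # (x - mu) / sqrt(var + eps) * gamma + beta, with the biased variance.
    mu = sum(xs) / len(xs)
    var = sum((v - mu) ** 2 for v in xs) / len(xs)
    inv = 1.0 / math.sqrt(var + eps)
    return [(v - mu) * inv * g + b for v, g, b in zip(xs, gamma, beta)]
```

Reference functions like these could double as oracles in the importer tests, alongside the ORT comparison.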

Category 3 — shape-metadata only (internally we are 1D-flat; these are Tracer-side reshapes, no runtime cost):
Reshape, Transpose, Squeeze, Unsqueeze, Flatten, Identity
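
A small sketch of why these ops are free at runtime, assuming a tensor is represented as a shape plus a row-major list of scalar node ids (the TensorView idea from the implementation plan, modelled here as a hypothetical Python dataclass). Reshape keeps the id list as-is; Transpose only permutes it. Neither creates graph nodes.

```python
from dataclasses import dataclass
from itertools import product
from typing import List, Tuple


@dataclass
class TensorView:
    shape: Tuple[int, ...]
    node_ids: List[int]  # row-major list of scalar node ids


def reshape(view: TensorView, new_shape: Tuple[int, ...]) -> TensorView:
    """Row-major reshape: same ids in the same order, new shape metadata."""
    size = 1
    for d in new_shape:
        size *= d
    if size != len(view.node_ids):
        raise ValueError(f"cannot reshape {view.shape} into {tuple(new_shape)}")
    return TensorView(tuple(new_shape), list(view.node_ids))


def transpose(view: TensorView, perm: Tuple[int, ...]) -> TensorView:
    """Permute axes by shuffling ids; new axis i is old axis perm[i]."""
    old_shape = view.shape
    new_shape = tuple(old_shape[p] for p in perm)
    # Row-major strides of the old layout.
    strides = [1] * len(old_shape)
    for axis in range(len(old_shape) - 2, -1, -1):
        strides[axis] = strides[axis + 1] * old_shape[axis + 1]
    ids = []
    for index in product(*(range(d) for d in new_shape)):
        offset = sum(index[i] * strides[perm[i]] for i in range(len(perm)))
        ids.append(view.node_ids[offset])
    return TensorView(new_shape, ids)


v = TensorView((2, 3), [10, 11, 12, 13, 14, 15])
t = transpose(v, (1, 0))
```

Squeeze, Unsqueeze and Flatten then reduce to reshape with a computed target shape, and Identity returns the view unchanged.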

Category 4 — genuine new work (post-MVP):

  • Conv / ConvTranspose — unrolling into MACs over existing ops; needs windowing logic in importer
  • MaxPool / AveragePool — windowing in importer
  • Slice, Gather, ScatterND — require proper index semantics; pair with Tier-1 JIT work on N-D indexing
  • LSTM / GRU / RNN — stateful; belong in DynamicalSystem/ODE, not in a stateless ONNXFunction. Deferred indefinitely.

What does not go in

  • Control flow: If, Loop, Scan break the SSA model. Hard error at import with a clear message.
  • Detection-specific: NonMaxSuppression, TopK, etc.
  • Training-only: Gradient, TrainingInfo. Irrelevant for inference-time integration.
  • Dropout / BatchNormalization in training mode: expect the user to run onnx-simplifier / onnxoptimizer first (standard pre-processing).
  • Fallback runtime (tract, ort): no. Pulling in a parallel ONNX runtime defeats the whole point (zero ORT boundary, AD through everything, single optimization pipeline). If a model uses unsupported ops, hard error — user can either (a) simplify the model, (b) use the mlfmu/ONNX2FMU route instead. Fastsim is not a universal ONNX runtime; it is a simulation-native ONNX importer for the productive op subset.

JIT coverage audit (baseline)

The existing JIT already provides everything we need except the frontend. From src/jit/graph.rs:

  • Binary (10): Add, Sub, Mul, Div, Pow, Mod, Min, Max, Atan2, Hypot
  • Unary (32): all trig/hyperbolic + inverses, Exp, Log, Log2, Log10, Log1p, Expm1, Abs, Sqrt, Cbrt, Sign, Floor, Ceil, Round, Trunc, Erf, Erfc, Lgamma, Tgamma, Digamma, Neg
  • Cmp (6): Gt, Ge, Lt, Le, Eq, Ne
  • Control: Select (branchless), Fma

Python/numpy surface (src/jit/tracer.rs) — relevant for composing ONNX-style ops symbolically:

  • Linear algebra: @, np.dot, np.matmul, np.vdot, np.inner, np.cross
  • Reductions: np.sum, prod, min, max, mean, var, std
  • Array assembly: np.stack, hstack, vstack, concatenate, asarray, array
  • Control: np.clip, np.where, fastsim.where_, clip

AD (src/jit/autodiff.rs):

  • Reverse-mode symbolic, memoized
  • jacobian, jacobian_wrt_slot(name), jacobian_wrt_flat_range — all emit into the same graph, subject to the same optimize/DCE passes
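
For the planned AD-versus-finite-difference check, a rough pure-Python sketch of the comparison is shown below. It does not call any fastsim API: the hand-derived Jacobian of a single tanh dense layer stands in for whatever jacobian_wrt_slot would produce, and all function names are hypothetical.

```python
import math


def dense_tanh(x, weights, bias):
    """y_i = tanh(sum_j W[i][j] * x[j] + b[i])."""
    return [
        math.tanh(sum(w * v for w, v in zip(row, x)) + b)
        for row, b in zip(weights, bias)
    ]


def analytic_jacobian(x, weights, bias):
    """dy_i/dx_j = (1 - y_i^2) * W[i][j]."""
    y = dense_tanh(x, weights, bias)
    return [[(1.0 - yi * yi) * w for w in row] for yi, row in zip(y, weights)]


def fd_jacobian(f, x, h=1e-6):
    """Central finite differences, one input column at a time."""
    n_out = len(f(x))
    jac = [[0.0] * len(x) for _ in range(n_out)]
    for j in range(len(x)):
        xp, xm = list(x), list(x)
        xp[j] += h
        xm[j] -= h
        fp, fm = f(xp), f(xm)
        for i in range(n_out):
            jac[i][j] = (fp[i] - fm[i]) / (2.0 * h)
    return jac


def max_abs_diff(a, b):
    return max(abs(u - v) for ra, rb in zip(a, b) for u, v in zip(ra, rb))


W = [[0.5, -1.0], [2.0, 0.25]]
B = [0.1, -0.2]
X = [0.3, -0.7]
err = max_abs_diff(
    analytic_jacobian(X, W, B),
    fd_jacobian(lambda x: dense_tanh(x, W, B), X),
)
```

Central differences keep the truncation error at O(h²), so a tolerance around 1e-7 is a reasonable starting point for smooth activations; Relu-style kinks need test points away from zero.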

Conclusion: the real work is the ONNX frontend + an N-D shape metadata layer on the importer side. Not a backend change.

Implementation plan

  1. ONNX proto parsing

    • Crate: prost + generated Rust types from onnx.proto3
    • Path: src/jit/onnx/proto.rs (auto-generated)
    • Alternative: hand-written minimal parser if we want zero build-time codegen. Not recommended; prost is well-behaved.
  2. Importer scaffolding (src/jit/onnx/importer.rs)

    • ModelProto → GraphProto → topological op walk
    • Initializer tensors → Graph::constant nodes (folded at build time)
    • ValueInfoProto → shape table keyed by tensor name
    • Op dispatcher: HashMap<op_type, fn(&mut Graph, &NodeProto, &ShapeCtx) -> PyResult<()>>
  3. N-D shape layer (internal to the importer only — JIT core stays 1D-flat)

    • TensorView { shape: Vec<usize>, node_ids: Vec<NodeId> }: row-major flattening
    • Broadcasting logic (N-D numpy rules) over TensorView
    • Reshape / Transpose / Squeeze / Unsqueeze become pure metadata ops (shuffle node_ids, keep graph unchanged)
  4. Op builders (src/jit/onnx/ops/*.rs, one module per category)

    • elementwise.rs: Add, Sub, Mul, Div, Pow, Min, Max, etc.
    • activations.rs: Relu, Sigmoid, Tanh, Gelu, Softplus, Softmax, LeakyRelu, Elu
    • linalg.rs: MatMul, Gemm
    • shape.rs: Reshape, Transpose, Squeeze, Unsqueeze, Flatten
    • reductions.rs: ReduceSum, ReduceMean, ReduceMax, ReduceMin, ReduceProd
    • norm.rs: LayerNorm, BatchNorm (constant-folded)
  5. Block constructor (src/blocks/constructors/onnx.rs + PyO3 binding)

    • ONNXFunction::new(path: &str) → loads model, runs importer, returns a feedthrough block
    • Input/output port dimensions inferred from ValueInfoProto
    • Role: { is_alg: true, is_dyn: false, is_src: false, is_rec: false } (stateless feedthrough)
    • Slot names derived from ONNX input names for introspection / debugging
  6. Error diagnostics

    • Unsupported op → clear error naming the op type, the ONNX node name, and a suggestion (run onnx-simplifier, or file an issue to add the op)
    • Dynamic axes → fall back to lazy-rejit path (v2) or hard error naming the offending tensor (v1)
  7. Tests

    • Round-trip: train a tiny MLP in PyTorch → export ONNX → ONNXFunction → compare to ORT reference within tolerance
    • Hybrid: plant ODE + ONNX residual in a closed loop, check solver convergence
    • AD: compile jacobian_wrt_slot through an ONNX block, compare to FD
    • Unsupported op: assert clean error
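
The dispatcher and error-diagnostics steps of the plan could be prototyped along these lines. This is a Python stand-in for the Rust HashMap dispatcher: Node is a hypothetical reduction of NodeProto to four fields, the handlers evaluate scalars directly instead of emitting graph nodes, and UnsupportedOpError is an assumed name.

```python
from typing import Callable, Dict, List, NamedTuple


class Node(NamedTuple):
    """Stand-in for the NodeProto fields used by this sketch."""
    op_type: str
    name: str
    inputs: List[str]
    outputs: List[str]


class UnsupportedOpError(NotImplementedError):
    """Hypothetical error raised for ops outside the whitelist."""


def _add(values, node):
    values[node.outputs[0]] = values[node.inputs[0]] + values[node.inputs[1]]


def _mul(values, node):
    values[node.outputs[0]] = values[node.inputs[0]] * values[node.inputs[1]]


def _relu(values, node):
    values[node.outputs[0]] = max(values[node.inputs[0]], 0.0)


# op_type -> handler; the whitelist is exactly the key set of this table.
DISPATCH: Dict[str, Callable[[dict, Node], None]] = {
    "Add": _add,
    "Mul": _mul,
    "Relu": _relu,
}


def run_import(nodes: List[Node], values: Dict[str, float]) -> Dict[str, float]:
    """Walk nodes in the given (assumed topological) order and dispatch each."""
    for node in nodes:
        handler = DISPATCH.get(node.op_type)
        if handler is None:
            raise UnsupportedOpError(
                f"unsupported ONNX op '{node.op_type}' at node '{node.name}'; "
                "try running onnx-simplifier, or file an issue to add the op"
            )
        handler(values, node)
    return values


model = [
    Node("Add", "add0", ["x", "b"], ["s"]),
    Node("Relu", "relu0", ["s"], ["y"]),
]
outputs = run_import(model, {"x": -3.0, "b": 1.0})
```

Failing on the first unknown op keeps the message short; collecting all unsupported ops before raising would be a friendlier variant for large models.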

MVP scope

v1 (this issue): Categories 1-3, fully inference-capable on MLP/FFN-style models (the dominant shape of surrogates, policies, learned residuals in control applications).

v2 (follow-up): Conv/ConvTranspose + pooling for spatial surrogates. Depends on completing the JIT's N-D shape Tier-1 work separately.

Deferred indefinitely: stateful ops (LSTM, GRU, RNN), control flow (If, Loop, Scan), detection ops, training ops.

Dependencies on other JIT work

This issue interlocks with the broader JIT op-surface expansion:

  • N-D shape metadata layer (importer-side for this issue; Tracer-side generally) — a prerequisite for a clean implementation here, so we should sequence those together.
  • Slice / Gather op surface in the JIT: reusable between ONNX Slice/Gather and user Python code that indexes arrays.
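
As a rough sketch of what the N-D shape metadata layer has to provide for elementwise ops, the following implements numpy-style broadcasting over static shapes and row-major node-id lists. Both function names are hypothetical; broadcasting here only repeats existing ids and creates no new nodes.

```python
from itertools import product
from typing import List, Sequence, Tuple


def broadcast_shape(a: Sequence[int], b: Sequence[int]) -> Tuple[int, ...]:
    """Result shape under numpy broadcasting rules, or ValueError."""
    rank = max(len(a), len(b))
    a = (1,) * (rank - len(a)) + tuple(a)
    b = (1,) * (rank - len(b)) + tuple(b)
    out = []
    for da, db in zip(a, b):
        if da == db or db == 1:
            out.append(da)
        elif da == 1:
            out.append(db)
        else:
            raise ValueError(f"shapes {a} and {b} are not broadcastable")
    return tuple(out)


def broadcast_ids(
    node_ids: List[int], shape: Sequence[int], target: Sequence[int]
) -> List[int]:
    """Expand a row-major id list of `shape` to `target` by repeating ids."""
    target = tuple(target)
    if broadcast_shape(shape, target) != target:
        raise ValueError(f"cannot broadcast {tuple(shape)} to {target}")
    padded = (1,) * (len(target) - len(shape)) + tuple(shape)
    # Row-major strides of the source; stride 0 along broadcast (size-1) axes.
    strides = [1] * len(padded)
    for axis in range(len(padded) - 2, -1, -1):
        strides[axis] = strides[axis + 1] * padded[axis + 1]
    strides = [0 if d == 1 else s for d, s in zip(padded, strides)]
    return [
        node_ids[sum(i * s for i, s in zip(index, strides))]
        for index in product(*(range(d) for d in target))
    ]
```

An elementwise builder would broadcast both operands to the common shape and then zip the two id lists, emitting one scalar node per element.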

Open design questions

  • Parameter surface: ONNX initializers are model weights. Fold all of them as Const nodes (no runtime update), or expose a subset as fastsim Param nodes so users can retune without re-importing? Default: fold all. User can reload the model for weight changes.
  • Shape fallback: if a model has dim_param dynamic axes, do we error at construction or fall back to lazy-rejit? Leaning toward error with a clear message (users can freeze shapes with onnx-shape-inference + onnx-simplifier).
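  • opset_version policy: which opset versions do we commit to supporting? Propose opset ≥ 13 for MVP (covers practically everything shipped since 2020). Note that LayerNormalization and Gelu only exist as single ops from opset 17 and opset 20 respectively; models exported at lower opsets carry them as decomposed subgraphs, which the Category 1 ops must then cover.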
  • Integer / bool tensors: fastsim's JIT is f64-only. Do we reject int/bool tensors, or auto-cast at boundaries? Propose: accept at the Python/block interface, cast internally, reject if the model has true integer-only subgraphs (rare in inference models).

References

Reference implementations & background:

  • DNV Open Source: mlfmu — ONNX-to-FMU wrapper (ORT-based). https://github.com/dnv-opensource/mlfmu
  • HSBI: NonLinearSystemNeuralNetworkFMU.jl — ONNX replacing nonlinear blocks inside OpenModelica FMUs. https://github.com/AMIT-HSBI/NonLinearSystemNeuralNetworkFMU.jl
  • Salzmann et al., Real-time Neural-MPC — arXiv:2203.07747; the L4CasADi approach that bypasses ONNX by mapping PyTorch directly into CasADi. Useful design reference for the "native lowering" paradigm.
  • Thummerer et al., NeuralFMU — arXiv:2109.04351 + MDPI Electronics 11(19):3202. Structural integration of FMUs into Neural-ODE topologies — closest thing to what we want, but Julia-side and FMU-centric.
  • Modelica Conference 2024: A Tool for the Implementation of ONNX Models in FMUs (ONNX2FMU). https://ecp.ep.liu.se/index.php/modelica/article/view/1351

ONNX spec:

Rust ONNX runtimes (reference only — not dependencies):
