[JAX] Extend tensor inspect utility to dump out tensors in identifiable names #3086
tdophung wants to merge 7 commits into
…evice JIT
InspectPrimitive defines five operands (x, x_min, x_max, x_mean, x_std)
and one output but never overrode BasePrimitive.partition or
BasePrimitive.shardy_sharding_rule. The default sharding rule
"... -> ..." declares one operand, so any caller wrapped in
custom_partitioning -- i.e. any multi-device jax.jit -- fails at trace
time with:
ValueError: Sharding rule has 1 operands, but the operation has 5 operands
The existing test_debug_inspect_ffi only exercises a single-device
jax.jit with no mesh, where custom_partitioning skips the rule check,
so the bug was latent since the primitive was added. Add an identity
partition (output sharding mirrors x; the four scalar stats are
replicated) and a matching shardy rule ("..., , , , -> ...") so
inspect_array can be used inside multi-device JITs for diagnostics.
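
The operand-count mismatch can be illustrated with a small stand-alone sketch. The helper `count_rule_operands` is hypothetical (it is not part of JAX or Transformer Engine) and assumes only that operands in a rule string are comma-separated on the left of `->`, with an empty entry standing for a scalar operand:

```python
def count_rule_operands(rule: str) -> int:
    """Count the operands declared on the left-hand side of a rule string.

    Hypothetical helper for illustration only; a real parser would also
    validate factor names and the result side.
    """
    lhs, _, _ = rule.partition("->")
    # An empty entry between commas stands for a rank-0 (scalar) operand.
    return len(lhs.split(","))


default_rule = "... -> ..."         # the inherited default: one operand
inspect_rule = "..., , , , -> ..."  # x plus four scalar statistics

default_count = count_rule_operands(default_rule)
inspect_count = count_rule_operands(inspect_rule)
```

Under that assumption the default rule declares one operand while the new rule declares five, which matches the five operands of InspectPrimitive.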
…ames
Previously the FFI hardcoded the output path to my_tensor_gpu{N}.bin and
ignored the `name` argument on the Python side, so every probe call in
a program overwrote the same files; the only surviving on-disk dumps
were whichever probe happened to fire last per rank. That made
multi-probe debugging (e.g. wiring TE_MOE_INSPECT through several fwd
and bwd steps of an MoE block) impossible to do offline -- only the
live printf log could be correlated, and only by shape/dtype.
Pass `name` through as an XLA FFI string attribute. On the C++ side it
gets sanitised to a POSIX-safe filename component
({[A-Za-z0-9._-]} preserved, everything else mapped to `_`) and used
as a suffix:
my_tensor_gpu{device}_{sanitized_name}.bin
my_tensor_gpu{device}_{sanitized_name}_meta.json
The unsanitised name is echoed verbatim in the JSON metadata and in the
printed log line so probe identity survives the rename.
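
A rough Python model of the naming scheme above, for readers who want to predict the on-disk filenames from a probe name. The real sanitisation lives in the C++ FFI; the function names `sanitize_probe_name` and `dump_paths` are hypothetical, and this sketch works per Unicode code point, which may differ from a byte-wise C++ loop for non-ASCII input:

```python
import re


def sanitize_probe_name(name: str) -> str:
    """Keep [A-Za-z0-9._-]; map every other character to '_'."""
    return re.sub(r"[^A-Za-z0-9._-]", "_", name)


def dump_paths(device: int, name: str) -> tuple[str, str]:
    """Return the (binary dump, JSON metadata) filenames for one probe."""
    safe = sanitize_probe_name(name)
    return (
        f"my_tensor_gpu{device}_{safe}.bin",
        f"my_tensor_gpu{device}_{safe}_meta.json",
    )


bin_path, meta_path = dump_paths(0, "moe/fwd step:1")
```

Two distinct probe names can still collide after sanitisation (for example `a/b` and `a:b`), so probe names that differ only in punctuation are best avoided.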
On the Python side `name` is carried as a custom_vjp nondiff arg, threaded
into the InspectPrimitive bind as a static kwarg, and surfaced through
abstract / lowering / impl / partition / shardy_sharding_rule.
…n resolve it
The prior commit that threaded a probe `name` through the
InspectPrimitive declared it as a keyword-only argument (`*, name`)
on `impl` / `partition` and left `impl_static_args = ()`. With
JAX's `custom_partitioning`, that breaks at trace time:
TypeError: keyword arguments could not be resolved to positions
`register_primitive` wraps `cls.impl` as
`custom_partitioning(cls.impl, static_argnums=cls.impl_static_args)`,
and the wrapper's `__call__` runs `_resolve_kwargs(self.fun, args,
kwargs)` to push bind-time kwargs back into positional slots. A
keyword-only parameter has no positional slot to resolve into, so any
`outer_primitive.bind(x, ..., name=name)` call from inside a
`jax.jit` aborts before the FFI ever runs.
Follow the established TE pattern (e.g. `ActLuPrimitive`,
`FusedMoEAuxLossBwdPrimitive`):
* Set `impl_static_args = (5,)` to declare the position of `name`.
* Drop the `*` separator on `impl` so `name` is
positional-or-keyword; `_resolve_kwargs` can now push the
bind kwarg to position 5.
* Move static args to the head of `partition(name, mesh, ...)`
and `shardy_sharding_rule(name, mesh, ...)` per
`custom_partitioning`'s convention when `static_argnums`
is set on the wrapped impl.
* `abstract` and `lowering` keep `*, name` because they are
called by JAX directly with bind kwargs and never go through
`_resolve_kwargs`.
The user-facing API (`inspect_array(x, name)`) and the
`outer_primitive.bind(x, ..., name=name)` call site are unchanged.
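
The failure and the fix can be reproduced without JAX. The sketch below uses a stand-in `resolve_kwargs` that is only modeled on the behavior described above (bind the call against the signature and reject anything left over as a keyword); it is an assumption for illustration, not JAX's actual code, and the two `impl_*` functions are hypothetical:

```python
import inspect


def resolve_kwargs(fun, args, kwargs):
    """Stand-in: push keyword arguments into positional slots or fail."""
    bound = inspect.signature(fun).bind(*args, **kwargs)
    bound.apply_defaults()
    if bound.kwargs:
        # Keyword-only parameters always land here: they have no position.
        raise TypeError("keyword arguments could not be resolved to positions")
    return bound.args


def impl_keyword_only(x, x_min, x_max, x_mean, x_std, *, name):
    return name


def impl_positional(x, x_min, x_max, x_mean, x_std, name):
    return name


operands = (1.0, 0.0, 2.0, 1.0, 0.5)

try:
    resolve_kwargs(impl_keyword_only, operands, {"name": "probe"})
    keyword_only_error = None
except TypeError as err:
    keyword_only_error = str(err)

# Without the `*` separator, `name` resolves to position 5.
resolved = resolve_kwargs(impl_positional, operands, {"name": "probe"})
```

With the `*` removed, `name` ends up at index 5 of the resolved positional tuple, which is the slot that `impl_static_args = (5,)` refers to.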
for more information, see https://pre-commit.ci
Signed-off-by: tdophung <tdophung@nvidia.com>
Greptile Summary
This PR wires a
Confidence Score: 4/5
Safe to merge; the core name-threading logic, JAX custom_vjp wiring, and C++ FFI binding are all correct. Two issues flagged in a prior review thread (unescaped name in JSON, unused algorithm include) remain open but are not regressions.
The change correctly routes name from Python through custom_vjp nondiff args, custom_partitioning static args, and the XLA FFI attribute layer. The filename sanitization prevents path traversal. The two open items from the prior thread are non-blocking for the core feature.
The two open prior-thread findings both live in transformer_engine/jax/csrc/extensions/inspect.cpp.
Important Files Changed
Reviews (2): Last reviewed commit: "Merge branch 'main' into teddy/inspect-a..."
meta_file << "{";
// Echo the original (un-sanitized) probe name so analysis tools can
// recover the semantic label even when the filename had to mangle it.
meta_file << "\"name\": \"" << name << "\", ";
Unescaped name produces malformed JSON
The original name is streamed directly into the JSON value without escaping, so any name containing ", \, or ASCII control characters (newline, tab, etc.) silently produces invalid JSON that no standard parser can read. For example, inspect_array(x, 'attention"key"') writes "name": "attention"key"", which is structurally broken. The sanitized safe_name is already available for the filename; either reuse it for the JSON name field, or add a simple per-character JSON-escape loop before writing.
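
The per-character escape loop suggested here can be sketched in Python to show the rule being asked for. The actual fix would be written in C++ in inspect.cpp; `json_escape` and `naive_meta` below are hypothetical names, and the loop covers only the characters named in the comment (`"`, `\`, and ASCII control characters):

```python
import json


def json_escape(text: str) -> str:
    """Escape a string for use inside a double-quoted JSON value."""
    out = []
    for ch in text:
        if ch == '"':
            out.append('\\"')
        elif ch == "\\":
            out.append("\\\\")
        elif ord(ch) < 0x20:
            # Control characters must be escaped; \uXXXX is always valid.
            out.append(f"\\u{ord(ch):04x}")
        else:
            out.append(ch)
    return "".join(out)


def naive_meta(name: str) -> str:
    """Mimic the unescaped write: the name is concatenated verbatim."""
    return '{"name": "' + name + '"}'


probe = 'attention"key"'

try:
    json.loads(naive_meta(probe))
    naive_is_valid = True
except json.JSONDecodeError:
    naive_is_valid = False

escaped_meta = '{"name": "' + json_escape(probe) + '"}'
round_tripped = json.loads(escaped_meta)["name"]
```

Escaping keeps the original probe name recoverable from the metadata, whereas reusing the sanitized name would lose the characters that were mapped to `_`.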
#include <algorithm>
#include <fstream>
#include <iostream>
#include <string>
#include <string_view>
The <algorithm> header was added in this diff but is never referenced; SanitizeProbeName uses a hand-written character loop with no std:: algorithm calls.
Suggested change:
- #include <algorithm>
  #include <fstream>
  #include <iostream>
  #include <string>
  #include <string_view>
jberchtold-nvidia left a comment
LGTM, thanks! Didn't realize you could pass strings as FFI args, good to know!
/te-ci
Description
Previously, every tensor was dumped under the same my_tensor_gpu*.bin name, so later dumps overwrote earlier ones. This made debugging harder: developers had to run the program multiple times, each run capturing a different tensor.
With this change, a single run is enough: each tensor is dumped to a separate file whose name can be set through the API.
Fixes # (issue)
Type of change
Changes
Checklist: