feat(jax): add JAX-MD interface#5590
Conversation
📝 WalkthroughWalkthroughAdds a JAX-MD integration module for DeePMD JAX models, plus third-party documentation, a water NVE example, and tests covering direct and dense-neighbor evaluation paths. ChangesJAX-MD integration
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Suggested labels
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 2
🧹 Nitpick comments (1)
source/tests/jax/test_jax_md.py (1)
108-112: 📐 Maintainability & Code Quality | 🔵 Trivial | 💤 Low valueOptional: prefer
pytest.importorskip("jax_md")for the optional dependency.Functionally equivalent to the
find_spec+unittest.skipIfcombination, but more idiomatic in a pytest suite and removes the need for theunittest/find_specimports used only here.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@source/tests/jax/test_jax_md.py` around lines 108 - 112, The optional jax_md test currently uses a unittest-style skip check with find_spec, which is less idiomatic in this pytest suite. Update test_actual_jax_md_neighbor_list to use pytest.importorskip("jax_md") inside the test instead of the `@unittest.skipIf`(find_spec(...)) decorator, and remove the now-unused unittest and find_spec imports from this test module.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/jax/jax_md.py`:
- Around line 170-175: The neighbor-list creation path in neighbor_list
currently allows callers to override the format via kwargs.setdefault, which can
defer failures until _jax_md_neighbor_to_lower_inputs sees an incompatible
shape. Update neighbor_list to explicitly validate the requested
partition.NeighborListFormat before calling partition.neighbor_list, and reject
any non-Dense format with a clear error so unsupported sparse formats fail fast.
- Around line 179-197: Reject scalar metrics in as_jax_md and
_jax_md_neighbor_to_lower_inputs by adding an early shape validation before
ghost-coordinate math. Ensure displacement_or_metric passed into
neighbor_list/_jax_md_neighbor_to_lower_inputs returns vector displacements with
the same trailing shape as coordinate differences, and raise a clear error if a
scalar metric is provided. Use the existing symbols as_jax_md, neighbor_list,
and _jax_md_neighbor_to_lower_inputs to place the guard where the displacement
function is first consumed.
---
Nitpick comments:
In `@source/tests/jax/test_jax_md.py`:
- Around line 108-112: The optional jax_md test currently uses a unittest-style
skip check with find_spec, which is less idiomatic in this pytest suite. Update
test_actual_jax_md_neighbor_list to use pytest.importorskip("jax_md") inside the
test instead of the `@unittest.skipIf`(find_spec(...)) decorator, and remove the
now-unused unittest and find_spec imports from this test module.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 93bc6f18-93c7-4fea-8770-6d47044ae7d4
📒 Files selected for processing (6)
deepmd/jax/jax_md.pydoc/third-party/index.rstdoc/third-party/jaxmd.mdexamples/water/jax_md/README.mdexamples/water/jax_md/run_jax_md.pysource/tests/jax/test_jax_md.py
| kwargs.setdefault("format", partition.NeighborListFormat.Dense) | ||
| return partition.neighbor_list( | ||
| displacement_or_metric, | ||
| box, | ||
| r_cutoff=jax_model.get_rcut(), | ||
| **kwargs, |
There was a problem hiding this comment.
🎯 Functional Correctness | 🟡 Minor
🧩 Analysis chain
🏁 Script executed:
cat -n deepmd/jax/jax_md.pyRepository: deepmodeling/deepmd-kit
Length of output: 17803
Fail fast when callers request a non-dense neighbor-list format.
kwargs.setdefault permits sparse formats to be passed to partition.neighbor_list, but _jax_md_neighbor_to_lower_inputs fails later during evaluation if the shape is not (N, max_occupancy). Rejecting unsupported formats in neighbor_list provides earlier, clearer feedback.
Proposed validation
- kwargs.setdefault("format", partition.NeighborListFormat.Dense)
+ neighbor_format = kwargs.setdefault("format", partition.NeighborListFormat.Dense)
+ if neighbor_format != partition.NeighborListFormat.Dense:
+ raise ValueError("Only dense JAX-MD neighbor lists are supported.")
return partition.neighbor_list(📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| kwargs.setdefault("format", partition.NeighborListFormat.Dense) | |
| return partition.neighbor_list( | |
| displacement_or_metric, | |
| box, | |
| r_cutoff=jax_model.get_rcut(), | |
| **kwargs, | |
| neighbor_format = kwargs.setdefault("format", partition.NeighborListFormat.Dense) | |
| if neighbor_format != partition.NeighborListFormat.Dense: | |
| raise ValueError("Only dense JAX-MD neighbor lists are supported.") | |
| return partition.neighbor_list( | |
| displacement_or_metric, | |
| box, | |
| r_cutoff=jax_model.get_rcut(), | |
| **kwargs, |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@deepmd/jax/jax_md.py` around lines 170 - 175, The neighbor-list creation path
in neighbor_list currently allows callers to override the format via
kwargs.setdefault, which can defer failures until
_jax_md_neighbor_to_lower_inputs sees an incompatible shape. Update
neighbor_list to explicitly validate the requested partition.NeighborListFormat
before calling partition.neighbor_list, and reject any non-Dense format with a
clear error so unsupported sparse formats fail fast.
| def as_jax_md( | ||
| model: str | Path | Any, | ||
| displacement_or_metric: Callable[..., Array], | ||
| box: Array | Sequence[float], | ||
| atom_types: Sequence[int | str] | Array, | ||
| **kwargs: Any, | ||
| ) -> tuple[Any, EnergyFn]: | ||
| """Return ``(neighbor_fn, energy_fn)`` in the usual JAX-MD style.""" | ||
| jax_model = load_model(model) | ||
| potential = energy_fn( | ||
| jax_model, | ||
| atom_types, | ||
| box=_normalize_box(box), | ||
| displacement_fn=displacement_or_metric, | ||
| fparam=kwargs.pop("fparam", None), | ||
| aparam=kwargs.pop("aparam", None), | ||
| charge_spin=kwargs.pop("charge_spin", None), | ||
| ) | ||
| nlist_fn = neighbor_list(jax_model, displacement_or_metric, box, **kwargs) |
There was a problem hiding this comment.
🎯 Functional Correctness | 🟠 Major
🧩 Analysis chain
🏁 Script executed:
cat -n deepmd/jax/jax_md.py | head -400 | tail -40Repository: deepmodeling/deepmd-kit
Length of output: 1935
🏁 Script executed:
sed -n '170,200p' deepmd/jax/jax_md.pyRepository: deepmodeling/deepmd-kit
Length of output: 1105
🏁 Script executed:
ast-grep run --lang python --pattern 'def neighbor_list($_, $displacement_or_metric, $$$): $$$' deepmd/jax/jax_md.pyRepository: deepmodeling/deepmd-kit
Length of output: 161
🏁 Script executed:
grep -n "to_lower_interface" deepmd/jax/jax_md.pyRepository: deepmodeling/deepmd-kit
Length of output: 161
🏁 Script executed:
grep -n "lower" deepmd/jax/jax_md.py | head -20Repository: deepmodeling/deepmd-kit
Length of output: 862
🏁 Script executed:
sed -n '340,360p' deepmd/jax/jax_md.pyRepository: deepmodeling/deepmd-kit
Length of output: 1100
🏁 Script executed:
sed -n '340,410p' deepmd/jax/jax_md.pyRepository: deepmodeling/deepmd-kit
Length of output: 2881
🏁 Script executed:
grep -n "_jax_md_neighbor_to_lower_inputs" deepmd/jax/jax_md.pyRepository: deepmodeling/deepmd-kit
Length of output: 294
🏁 Script executed:
sed -n '310,345p' deepmd/jax/jax_md.pyRepository: deepmodeling/deepmd-kit
Length of output: 1199
🏁 Script executed:
sed -n '270,320p' deepmd/jax/jax_md.pyRepository: deepmodeling/deepmd-kit
Length of output: 1953
🏁 Script executed:
sed -n '200,270p' deepmd/jax/jax_md.pyRepository: deepmodeling/deepmd-kit
Length of output: 3008
🏁 Script executed:
python3 - <<'PY'
import jax.numpy as jnp
import jax
# Simulate JAX-MD displacement vs metric behavior
def disp_fn(c, n):
return c - n # Vector: shape (3,)
def metric_fn(c, n):
return jnp.linalg.norm(c - n) # Scalar: shape ()
# Inputs
c = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) # (2, 3)
n = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) # (2, 3)
neighbor_coord = jnp.array([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0], [4.0, 5.0, 6.0], [1.0, 2.0, 3.0]]) # (2, 2, 3) - simulated
central_coord = jnp.broadcast_to(c[:, None, :], neighbor_coord.shape) # (2, 2, 3)
# Test 1: Vector Displacement
disp = jax.vmap(jax.vmap(disp_fn))(central_coord, neighbor_coord)
print(f"Displacement (Vector) shape: {disp.shape}")
ghost = central_coord - disp
print(f"Ghost computed shape: {ghost.shape}")
print(f"Match neighbor_coord shape: {ghost.shape == neighbor_coord.shape}")
# Test 2: Scalar Metric
met = jax.vmap(jax.vmap(metric_fn))(central_coord, neighbor_coord)
print(f"Metric (Scalar) shape: {met.shape}")
ghost_b = central_coord - met
print(f"Ghost (metric) computed shape: {ghost_b.shape}")
print(f"Value mismatch logic: Broadcasting occurred. Shape matches, values incorrect.")
PYRepository: deepmodeling/deepmd-kit
Length of output: 277
🏁 Script executed:
pip install jax jaxlib --quietRepository: deepmodeling/deepmd-kit
Length of output: 161
🏁 Script executed:
python3 - <<'PY'
import jax.numpy as jnp
import jax
# Simulate JAX-MD displacement vs metric behavior
def disp_fn(c, n):
return c - n # Vector: shape (3,)
def metric_fn(c, n):
return jnp.linalg.norm(c - n) # Scalar: shape ()
# Inputs
c = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) # (2, 3)
# Simulated neighbor coordinates: 2 atoms, each with 2 neighbors (indices)
neighbor_coord = jnp.array([
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], # neighbors for atom 0
[[4.0, 5.0, 6.0], [1.0, 2.0, 3.0]] # neighbors for atom 1
]) # shape (2, 2, 3)
central_coord = jnp.broadcast_to(c[:, None, :], neighbor_coord.shape) # (2, 2, 3)
print("central_coord shape:", central_coord.shape)
print("neighbor_coord shape:", neighbor_coord.shape)
print()
# Test 1: Vector Displacement
disp = jax.vmap(jax.vmap(disp_fn))(central_coord, neighbor_coord)
print(f"Displacement (Vector) shape: {disp.shape}")
ghost = central_coord - disp
print(f"Ghost computed shape: {ghost.shape}")
print(f"Ghost values (first row):\n{ghost[0]}")
print()
# Test 2: Scalar Metric
met = jax.vmap(jax.vmap(metric_fn))(central_coord, neighbor_coord)
print(f"Metric (Scalar) shape: {met.shape}")
ghost_b = central_coord - met
print(f"Ghost (metric) computed shape: {ghost_b.shape}")
print(f"Ghost values (metric, first row) - INCORRECT:\n{ghost_b[0]}")
print()
print("Analysis:")
if ghost_b.shape == neighbor_coord.shape:
print("❌ Shapes match, so a raw shape check fails to catch the metric error.")
print("The formula 'central_coord - metric' is geometrically invalid.")
PYRepository: deepmodeling/deepmd-kit
Length of output: 4258
Reject scalar metrics to prevent shape mismatch errors in ghost coordinate calculation.
_jax_md_neighbor_to_lower_inputs assumes displacement_fn returns vector displacements. If a scalar metric is passed (e.g. from neighbor_list), the operation central_coord - displacement on line 380 raises a ValueError due to shape incompatibility between (N, N, 3) and (N, N). Add a shape check to fail earlier with a clear message.
Proposed guard
displacement = jax.vmap(
jax.vmap(
lambda central, neighbor: displacement_fn(
central, neighbor, **displacement_kwargs
)
)
)(central_coord, neighbor_coord)
+ if displacement.shape != neighbor_coord.shape:
+ raise ValueError(
+ "Dense neighbor evaluation requires a displacement function returning "
+ "vectors with shape (..., 3); scalar metric functions are not supported."
+ )
# JAX-MD displacement functions use the Ra - Rb convention.
ghost_coord = central_coord - displacementAlso applies to: 372-380
🧰 Tools
🪛 ast-grep (0.44.0)
[warning] 186-186: Loading a Keras model from an untrusted file can execute arbitrary code via Lambda layers or custom objects. Load only trusted models and avoid deserializing custom objects from untrusted sources.
Context: load_model(model)
Note: [CWE-502] Deserialization of Untrusted Data.
(keras-load-model-python)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@deepmd/jax/jax_md.py` around lines 179 - 197, Reject scalar metrics in
as_jax_md and _jax_md_neighbor_to_lower_inputs by adding an early shape
validation before ghost-coordinate math. Ensure displacement_or_metric passed
into neighbor_list/_jax_md_neighbor_to_lower_inputs returns vector displacements
with the same trailing shape as coordinate differences, and raise a clear error
if a scalar metric is provided. Use the existing symbols as_jax_md,
neighbor_list, and _jax_md_neighbor_to_lower_inputs to place the guard where the
displacement function is first consumed.
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #5590 +/- ##
==========================================
- Coverage 82.27% 82.24% -0.04%
==========================================
Files 887 888 +1
Lines 100331 100503 +172
Branches 4060 4056 -4
==========================================
+ Hits 82550 82659 +109
- Misses 16320 16384 +64
+ Partials 1461 1460 -1 ☔ View full report in Codecov by Harness. 🚀 New features to boost your workflow:
|
Summary
as_jax_mdhelpersTests
ruff check .PYTHONPATH=. pytest source/tests/jax/test_jax_md.py -qSummary by CodeRabbit
New Features
Documentation
Tests