Skip to content

feat(pt): add custom save behaviors#5589

Open
OutisLi wants to merge 2 commits into
deepmodeling:masterfrom
OutisLi:pr/save
Open

feat(pt): add custom save behaviors#5589
OutisLi wants to merge 2 commits into
deepmodeling:masterfrom
OutisLi:pr/save

Conversation

@OutisLi

@OutisLi OutisLi commented Jun 25, 2026

Copy link
Copy Markdown
Collaborator

Summary by CodeRabbit

  • New Features

    • Added configurable checkpoint output locations for training and best-validation saves.
    • Added checkpoint retention by ratio, with automatic rounding and minimum retention safeguards.
  • Bug Fixes

    • Updated checkpoint path handling so periodic and EMA checkpoints are saved consistently, including when using a custom save directory.
    • Best-checkpoint files now land in the configured validation directory instead of the default location.
  • Documentation

    • Expanded training docs and example config with the new checkpoint settings.

@coderabbitai

coderabbitai Bot commented Jun 25, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

📝 Walkthrough

Walkthrough

Checkpoint retention, checkpoint directory selection, and checkpoint save-path handling are updated for PyTorch training and validation. New configuration fields, helper functions, and tests cover save directory, best-checkpoint directory, and ratio-based retention behavior.

Changes

Checkpoint path and retention updates

Layer / File(s) Summary
Utilities and config contract
deepmd/pt/train/utils.py, deepmd/utils/argcheck.py, doc/train/training-advanced.md, examples/water/dpa4/input.json, source/tests/pt/test_train_utils.py
New checkpoint helpers and training config fields are added for save_dir and ckpt_keep_ratio, with docs, an example config, and helper tests updated.
Trainer retention setup
deepmd/pt/train/training.py, source/tests/pt/test_training.py
Trainer now resolves save_dir, derives the keep count from ckpt_keep_ratio after num_steps is known, and updates regular and EMA retention limits.
Best checkpoint directory wiring
deepmd/pt/train/training.py, deepmd/pt/train/validation.py, deepmd/utils/argcheck.py, examples/water/dpa4/input.json, source/tests/pt/test_training.py, source/tests/pt/test_validation.py
resolve_best_checkpoint_dir is used for full-validation checkpoint directories, FullValidator creates the directory during initialization, and tests cover custom best-checkpoint locations.
Checkpoint save paths and symlinks
deepmd/pt/train/training.py, source/tests/pt/test_training.py
Periodic, final, and zero-step checkpoint writes now use latest_checkpoint_path(..., save_dir), and tests verify the resulting files and symlinks.

Sequence Diagram(s)

sequenceDiagram
  participant Trainer
  participant latest_checkpoint_path
  participant save_dir
  participant checkpoint_file
  Trainer->>latest_checkpoint_path: resolve prefix, step, and save_dir
  latest_checkpoint_path-->>Trainer: checkpoint path
  Trainer->>save_dir: write periodic checkpoint file
  Trainer->>checkpoint_file: update pointer to the resolved path
Loading
sequenceDiagram
  participant Trainer
  participant resolve_best_checkpoint_dir
  participant FullValidator
  participant checkpoint_dir
  Trainer->>resolve_best_checkpoint_dir: resolve validating.save_best_dir or save_ckpt parent
  resolve_best_checkpoint_dir-->>Trainer: checkpoint_dir
  Trainer->>FullValidator: create validator with checkpoint_dir
  FullValidator->>checkpoint_dir: mkdir(parents=True, exist_ok=True)
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested labels

Python

Suggested reviewers

  • njzjz
  • iProzd
  • wanghan-iapcm
🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title is clearly related to the main change: customizable checkpoint save and retention behavior in PyTorch training/validation.
Docstring Coverage ✅ Passed No functions found in the changed files to evaluate docstring coverage. Skipping docstring coverage check.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 4

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/pt/train/utils.py`:
- Around line 283-284: The checkpoint retention calculation in the helper that
returns the keep count is undercounting because it ignores the final off-cadence
checkpoint written by Trainer.run() at num_steps. Update the logic around
total_periodic_ckpts/ckpt_keep_ratio so it accounts for the extra terminal
checkpoint (for example, by including the final step in the total when num_steps
is not an exact multiple of save_freq), and keep the existing max(1, ...)
safeguards intact.

In `@examples/water/dpa4/input.json`:
- Around line 123-124: The save_best_dir setting is unused in this example
because the validation path that triggers best-checkpoint saving is never
enabled. Update the input in this example by either turning on the
validating.full_validation flow so ckpt_best can be created, or remove the
save_best_dir field from the example to avoid misleading users; make the change
in the example configuration where tf32_infer and save_best_dir are defined.

In `@source/tests/pt/test_training.py`:
- Around line 967-968: Add the standard training test timeout guard to the new
validation test so it cannot hang CI; decorate
test_full_validation_save_best_dir with `@TRAINING_TEST_TIMEOUT` alongside the
existing `@patch` on FullValidator.evaluate_all_systems, matching the pattern used
by other training tests that call trainer.run().
- Around line 1228-1231: The checkpoint alias test is incorrectly asserting that
the prefix files are symlinks, which breaks on platforms where
symlink_prefix_files() falls back to copying. Update the test in the
checkpoint-saving area to validate the alias by checking that the Path resolves
to the expected target file, without requiring is_symlink(), using the existing
save_ckpt and ema_save_ckpt references so the test remains cross-platform.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 40ff16a5-1016-4465-ba48-a71de2e87b50

📥 Commits

Reviewing files that changed from the base of the PR and between 5733301 and fc780b1.

📒 Files selected for processing (9)
  • deepmd/pt/train/training.py
  • deepmd/pt/train/utils.py
  • deepmd/pt/train/validation.py
  • deepmd/utils/argcheck.py
  • doc/train/training-advanced.md
  • examples/water/dpa4/input.json
  • source/tests/pt/test_train_utils.py
  • source/tests/pt/test_training.py
  • source/tests/pt/test_validation.py

Comment thread deepmd/pt/train/utils.py
Comment on lines +283 to +284
total_periodic_ckpts = max(1, num_steps // save_freq)
return max(1, ceil(ckpt_keep_ratio * total_periodic_ckpts))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🩺 Stability & Availability | 🟠 Major | ⚡ Quick win

Count the final off-cadence checkpoint in the keep-ratio calculation.

Trainer.run() still writes a checkpoint at display_step_id == self.num_steps, so a run like num_steps=5, save_freq=2 produces 2, 4, 5. Using num_steps // save_freq counts only 2 and makes ckpt_keep_ratio evict one checkpoint too early.

Proposed fix
-    total_periodic_ckpts = max(1, num_steps // save_freq)
+    total_periodic_ckpts = max(1, (num_steps + save_freq - 1) // save_freq)
     return max(1, ceil(ckpt_keep_ratio * total_periodic_ckpts))
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
total_periodic_ckpts = max(1, num_steps // save_freq)
return max(1, ceil(ckpt_keep_ratio * total_periodic_ckpts))
total_periodic_ckpts = max(1, (num_steps + save_freq - 1) // save_freq)
return max(1, ceil(ckpt_keep_ratio * total_periodic_ckpts))
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt/train/utils.py` around lines 283 - 284, The checkpoint retention
calculation in the helper that returns the keep count is undercounting because
it ignores the final off-cadence checkpoint written by Trainer.run() at
num_steps. Update the logic around total_periodic_ckpts/ckpt_keep_ratio so it
accounts for the extra terminal checkpoint (for example, by including the final
step in the total when num_steps is not an exact multiple of save_freq), and
keep the existing max(1, ...) safeguards intact.

Comment on lines +123 to +124
"tf32_infer": false,
"save_best_dir": "ckpt_best"

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win

save_best_dir is a no-op in this example.

This file never enables validating.full_validation, so copying the example as-is will never create anything under ckpt_best. Either enable the validation path that exercises best-checkpoint saving, or drop this field from the example.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/water/dpa4/input.json` around lines 123 - 124, The save_best_dir
setting is unused in this example because the validation path that triggers
best-checkpoint saving is never enabled. Update the input in this example by
either turning on the validating.full_validation flow so ckpt_best can be
created, or remove the save_best_dir field from the example to avoid misleading
users; make the change in the example configuration where tf32_infer and
save_best_dir are defined.

Comment on lines +967 to +968
@patch("deepmd.pt.train.validation.FullValidator.evaluate_all_systems")
def test_full_validation_save_best_dir(self, mocked_eval) -> None:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🩺 Stability & Availability | 🟡 Minor | ⚡ Quick win

Add the standard timeout guard to this training test.

This new case calls trainer.run() but isn't wrapped in @TRAINING_TEST_TIMEOUT, so a regression here can hang CI instead of failing fast. As per coding guidelines, **/tests/**/*training*.py: Set training test timeouts to 60 seconds maximum for validation purposes.

Proposed fix
+    `@TRAINING_TEST_TIMEOUT`
     `@patch`("deepmd.pt.train.validation.FullValidator.evaluate_all_systems")
     def test_full_validation_save_best_dir(self, mocked_eval) -> None:
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
@patch("deepmd.pt.train.validation.FullValidator.evaluate_all_systems")
def test_full_validation_save_best_dir(self, mocked_eval) -> None:
`@TRAINING_TEST_TIMEOUT`
`@patch`("deepmd.pt.train.validation.FullValidator.evaluate_all_systems")
def test_full_validation_save_best_dir(self, mocked_eval) -> None:
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@source/tests/pt/test_training.py` around lines 967 - 968, Add the standard
training test timeout guard to the new validation test so it cannot hang CI;
decorate test_full_validation_save_best_dir with `@TRAINING_TEST_TIMEOUT`
alongside the existing `@patch` on FullValidator.evaluate_all_systems, matching
the pattern used by other training tests that call trainer.run().

Source: Coding guidelines

Comment on lines +1228 to +1231
for prefix in (save_ckpt, ema_save_ckpt):
link = Path(f"{prefix}.pt")
self.assertTrue(link.is_symlink())
self.assertEqual(link.resolve(), (save_dir / f"{prefix}-4.pt").resolve())

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎯 Functional Correctness | 🟠 Major | ⚡ Quick win

Don’t require symlinks in this cross-platform checkpoint test.

symlink_prefix_files() copies files on Windows, so is_symlink() is false there even when the checkpoint alias is correct. That makes this new test fail on a platform the helper already supports.

Proposed fix
         for prefix in (save_ckpt, ema_save_ckpt):
             link = Path(f"{prefix}.pt")
-            self.assertTrue(link.is_symlink())
-            self.assertEqual(link.resolve(), (save_dir / f"{prefix}-4.pt").resolve())
+            target = save_dir / f"{prefix}-4.pt"
+            self.assertTrue(link.exists())
+            if os.name != "nt":
+                self.assertTrue(link.is_symlink())
+                self.assertEqual(link.resolve(), target.resolve())
+            else:
+                self.assertTrue(link.is_file())
+                self.assertEqual(link.read_bytes(), target.read_bytes())
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
for prefix in (save_ckpt, ema_save_ckpt):
link = Path(f"{prefix}.pt")
self.assertTrue(link.is_symlink())
self.assertEqual(link.resolve(), (save_dir / f"{prefix}-4.pt").resolve())
for prefix in (save_ckpt, ema_save_ckpt):
link = Path(f"{prefix}.pt")
target = save_dir / f"{prefix}-4.pt"
self.assertTrue(link.exists())
if os.name != "nt":
self.assertTrue(link.is_symlink())
self.assertEqual(link.resolve(), target.resolve())
else:
self.assertTrue(link.is_file())
self.assertEqual(link.read_bytes(), target.read_bytes())
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@source/tests/pt/test_training.py` around lines 1228 - 1231, The checkpoint
alias test is incorrectly asserting that the prefix files are symlinks, which
breaks on platforms where symlink_prefix_files() falls back to copying. Update
the test in the checkpoint-saving area to validate the alias by checking that
the Path resolves to the expected target file, without requiring is_symlink(),
using the existing save_ckpt and ema_save_ckpt references so the test remains
cross-platform.

@njzjz-bot njzjz-bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding the checkpoint save directory and ratio-based retention knobs. I found a few issues worth fixing before merge:

  1. ckpt_keep_ratio currently under-counts when the final checkpoint is off-cadence. In resolve_keep_ckpt_count(), total_periodic_ckpts = num_steps // save_freq ignores the final checkpoint that Trainer.run() still writes when num_steps % save_freq != 0. For example, num_steps=5, save_freq=2, ckpt_keep_ratio=0.5 produces checkpoints at steps 2, 4, and 5, but the helper returns ceil(0.5 * (5 // 2)) = 1; the documented formula ceil(ckpt_keep_ratio * numb_steps / save_freq) would keep 2. Please account for the terminal checkpoint, e.g. use ceil(num_steps / save_freq) (with the existing minimum-one guard).

  2. The new save_best_dir in examples/water/dpa4/input.json is misleading unless full validation is enabled. Since validating.full_validation defaults to false, ckpt_best will not actually be used by this example. Either enable the full-validation flow in the example or omit save_best_dir there.

  3. The new test_save_dir_redirects_checkpoints_with_local_symlinks assumes Path(...).is_symlink(), but symlink_prefix_files() copies files on Windows. If these tests are expected to be portable, please avoid requiring symlinks in the assertion (or explicitly scope the test/docs to non-Windows behavior).

Reviewed by OpenClaw 2026.6.8 (model: custom-chat-jinzhezeng-group/gpt-5.5).

@codecov

codecov Bot commented Jun 25, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 90.47619% with 4 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.28%. Comparing base (5733301) to head (fc780b1).

Files with missing lines Patch % Lines
deepmd/pt/train/training.py 82.60% 4 Missing ⚠️
Additional details and impacted files
@@           Coverage Diff           @@
##           master    #5589   +/-   ##
=======================================
  Coverage   82.27%   82.28%           
=======================================
  Files         887      887           
  Lines      100331   100361   +30     
  Branches     4060     4058    -2     
=======================================
+ Hits        82550    82581   +31     
+ Misses      16320    16318    -2     
- Partials     1461     1462    +1     

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
  • 📦 JS Bundle Analysis: Save yourself from yourself by tracking and limiting bundle sizes in JS merges.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants