feat(pt): add custom save behaviors#5589
📝 Walkthrough
Checkpoint retention, checkpoint directory selection, and checkpoint save-path handling are updated for PyTorch training and validation. New configuration fields, helper functions, and tests cover save directory, best-checkpoint directory, and ratio-based retention behavior.
Changes
Checkpoint path and retention updates
Sequence Diagram(s)
sequenceDiagram
participant Trainer
participant latest_checkpoint_path
participant save_dir
participant checkpoint_file
Trainer->>latest_checkpoint_path: resolve prefix, step, and save_dir
latest_checkpoint_path-->>Trainer: checkpoint path
Trainer->>save_dir: write periodic checkpoint file
Trainer->>checkpoint_file: update pointer to the resolved path
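As a rough illustration of the first diagram, the sketch below builds a periodic checkpoint path from a prefix, a step, and an optional save directory. The name `sketch_checkpoint_path` and its signature are assumptions made for this example and are not the PR's `latest_checkpoint_path`; only the `{prefix}-{step}.pt` naming is taken from the test code quoted later in this review.

```python
from pathlib import Path
from typing import Optional


def sketch_checkpoint_path(prefix: str, step: int, save_dir: Optional[str] = None) -> Path:
    """Hypothetical helper: return `<save_dir>/<prefix>-<step>.pt`.

    When no save directory is given, the path is relative to the
    current working directory (an assumption for this sketch).
    """
    file_name = f"{prefix}-{step}.pt"
    if save_dir is None:
        return Path(file_name)
    return Path(save_dir) / file_name
```

For example, `sketch_checkpoint_path("model.ckpt", 4, "ckpts")` yields `ckpts/model.ckpt-4.pt`.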
sequenceDiagram
participant Trainer
participant resolve_best_checkpoint_dir
participant FullValidator
participant checkpoint_dir
Trainer->>resolve_best_checkpoint_dir: resolve validating.save_best_dir or save_ckpt parent
resolve_best_checkpoint_dir-->>Trainer: checkpoint_dir
Trainer->>FullValidator: create validator with checkpoint_dir
FullValidator->>checkpoint_dir: mkdir(parents=True, exist_ok=True)
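The second diagram can be read as a two-step fallback followed by a directory creation. The sketch below is a minimal illustration under stated assumptions: `sketch_resolve_best_dir` and `ensure_dir` are hypothetical names, and the fallback rule (use `save_best_dir` if set, otherwise the parent of `save_ckpt`) is inferred from the diagram labels rather than from the PR's code.

```python
from pathlib import Path
from typing import Optional


def sketch_resolve_best_dir(save_best_dir: Optional[str], save_ckpt: str) -> Path:
    """Hypothetical helper: pick the directory for best checkpoints."""
    if save_best_dir:
        # An explicit validating.save_best_dir wins.
        return Path(save_best_dir)
    # Otherwise fall back to the directory that holds the regular checkpoints.
    return Path(save_ckpt).parent


def ensure_dir(path: Path) -> Path:
    """Create the directory the way the diagram describes, then return it."""
    path.mkdir(parents=True, exist_ok=True)
    return path
```

Because `mkdir` is called with `exist_ok=True`, calling `ensure_dir` repeatedly on the same path is harmless.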
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks | ✅ 5 passed
Actionable comments posted: 4
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/pt/train/utils.py`:
- Around line 283-284: The checkpoint retention calculation in the helper that
returns the keep count is undercounting because it ignores the final off-cadence
checkpoint written by Trainer.run() at num_steps. Update the logic around
total_periodic_ckpts/ckpt_keep_ratio so it accounts for the extra terminal
checkpoint (for example, by including the final step in the total when num_steps
is not an exact multiple of save_freq), and keep the existing max(1, ...)
safeguards intact.
In `@examples/water/dpa4/input.json`:
- Around line 123-124: The save_best_dir setting is unused in this example
because the validation path that triggers best-checkpoint saving is never
enabled. Update the input in this example by either turning on the
validating.full_validation flow so ckpt_best can be created, or remove the
save_best_dir field from the example to avoid misleading users; make the change
in the example configuration where tf32_infer and save_best_dir are defined.
In `@source/tests/pt/test_training.py`:
- Around line 967-968: Add the standard training test timeout guard to the new
validation test so it cannot hang CI; decorate
test_full_validation_save_best_dir with `@TRAINING_TEST_TIMEOUT` alongside the
existing `@patch` on FullValidator.evaluate_all_systems, matching the pattern used
by other training tests that call trainer.run().
- Around line 1228-1231: The checkpoint alias test is incorrectly asserting that
the prefix files are symlinks, which breaks on platforms where
symlink_prefix_files() falls back to copying. Update the test in the
checkpoint-saving area to validate the alias by checking that the Path resolves
to the expected target file, without requiring is_symlink(), using the existing
save_ckpt and ema_save_ckpt references so the test remains cross-platform.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 40ff16a5-1016-4465-ba48-a71de2e87b50
📒 Files selected for processing (9)
deepmd/pt/train/training.py
deepmd/pt/train/utils.py
deepmd/pt/train/validation.py
deepmd/utils/argcheck.py
doc/train/training-advanced.md
examples/water/dpa4/input.json
source/tests/pt/test_train_utils.py
source/tests/pt/test_training.py
source/tests/pt/test_validation.py
In deepmd/pt/train/utils.py (lines 283-284):
total_periodic_ckpts = max(1, num_steps // save_freq)
return max(1, ceil(ckpt_keep_ratio * total_periodic_ckpts))
🩺 Stability & Availability | 🟠 Major | ⚡ Quick win
Count the final off-cadence checkpoint in the keep-ratio calculation.
Trainer.run() still writes a checkpoint at display_step_id == self.num_steps, so a run with num_steps=5 and save_freq=2 writes checkpoints at steps 2, 4, and 5. num_steps // save_freq counts only two of those three, so ckpt_keep_ratio evicts one checkpoint too early.
Proposed fix
- total_periodic_ckpts = max(1, num_steps // save_freq)
+ total_periodic_ckpts = max(1, (num_steps + save_freq - 1) // save_freq)
  return max(1, ceil(ckpt_keep_ratio * total_periodic_ckpts))
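The arithmetic behind this comment can be checked with a small standalone sketch. The function names below are hypothetical and do not mirror the signature of the helper in `deepmd/pt/train/utils.py`; the assumption, taken from the review text, is that the trainer writes a checkpoint at every multiple of `save_freq` and once more at the final step.

```python
from math import ceil


def written_steps(num_steps: int, save_freq: int) -> list:
    """Steps at which a checkpoint is assumed to be written."""
    steps = list(range(save_freq, num_steps + 1, save_freq))
    if not steps or steps[-1] != num_steps:
        # Terminal checkpoint at num_steps, even when it is off-cadence.
        steps.append(num_steps)
    return steps


def keep_count_floor(num_steps: int, save_freq: int, ckpt_keep_ratio: float) -> int:
    """Keep count using floor division, as in the code under review."""
    total = max(1, num_steps // save_freq)
    return max(1, ceil(ckpt_keep_ratio * total))


def keep_count_ceil(num_steps: int, save_freq: int, ckpt_keep_ratio: float) -> int:
    """Keep count using ceiling division, as in the proposed fix."""
    total = max(1, (num_steps + save_freq - 1) // save_freq)
    return max(1, ceil(ckpt_keep_ratio * total))
```

With `num_steps=5`, `save_freq=2`, and `ckpt_keep_ratio=0.5`, the floor variant keeps 1 checkpoint while the ceiling variant keeps 2; when `num_steps` is an exact multiple of `save_freq`, the two agree.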
In examples/water/dpa4/input.json (lines 123-124):
"tf32_infer": false,
"save_best_dir": "ckpt_best"
📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win
save_best_dir is a no-op in this example.
This file never enables validating.full_validation, so copying the example as-is will never create anything under ckpt_best. Either enable the validation path that exercises best-checkpoint saving, or drop this field from the example.
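One way to catch this kind of no-op setting is a small check over the parsed input. The sketch below is illustrative only: the function name is hypothetical, and the layout (a top-level `validating` section holding `full_validation` and `save_best_dir`, with `full_validation` defaulting to false) is an assumption based on the option names mentioned in this review, not a verified description of the schema.

```python
def save_best_dir_has_effect(config: dict) -> bool:
    """Hypothetical check: is save_best_dir set AND full validation enabled?"""
    validating = config.get("validating", {})
    has_dir = bool(validating.get("save_best_dir"))
    # Assumed default: full validation is off unless explicitly enabled.
    full_validation = bool(validating.get("full_validation", False))
    return has_dir and full_validation


# Assumed shape of the example as it stands: the directory is set,
# but nothing turns on the flow that would write into it.
example_as_written = {"validating": {"save_best_dir": "ckpt_best"}}
example_enabled = {"validating": {"full_validation": True, "save_best_dir": "ckpt_best"}}
```

Under these assumptions the check returns `False` for the first dictionary and `True` for the second.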
In source/tests/pt/test_training.py (lines 967-968):
@patch("deepmd.pt.train.validation.FullValidator.evaluate_all_systems")
def test_full_validation_save_best_dir(self, mocked_eval) -> None:
🩺 Stability & Availability | 🟡 Minor | ⚡ Quick win
Add the standard timeout guard to this training test.
This new case calls trainer.run() but isn't decorated with @TRAINING_TEST_TIMEOUT, so a regression here can hang CI instead of failing fast. Per the coding guidelines for **/tests/**/*training*.py, training test timeouts should be 60 seconds at most for validation purposes.
Proposed fix
+ @TRAINING_TEST_TIMEOUT
  @patch("deepmd.pt.train.validation.FullValidator.evaluate_all_systems")
  def test_full_validation_save_best_dir(self, mocked_eval) -> None:
Source: Coding guidelines
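To show how a timeout guard stacks on top of `@patch`, here is a minimal sketch. `timeout_guard` is a hypothetical stand-in written for this example; the repository's `TRAINING_TEST_TIMEOUT` is not shown in this review and may work quite differently. The patched target `os.getcwd` is likewise only a placeholder.

```python
import functools
import os
import threading
import time
from unittest.mock import patch


def timeout_guard(seconds: float):
    """Hypothetical decorator: fail with TimeoutError if the call runs too long."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            outcome = {}

            def target():
                try:
                    outcome["value"] = func(*args, **kwargs)
                except BaseException as exc:  # re-raised in the calling thread
                    outcome["error"] = exc

            # Daemon thread: a hung call cannot keep the interpreter alive.
            worker = threading.Thread(target=target, daemon=True)
            worker.start()
            worker.join(seconds)
            if worker.is_alive():
                raise TimeoutError(f"{func.__name__} exceeded {seconds} s")
            if "error" in outcome:
                raise outcome["error"]
            return outcome.get("value")

        return wrapper

    return decorator


@timeout_guard(60)
@patch("os.getcwd", return_value="/mocked")
def demo(mocked_getcwd):
    # The mock is injected by @patch, which sits below the timeout guard.
    return os.getcwd()


@timeout_guard(0.05)
def sleepy():
    time.sleep(0.5)
```

Decorators apply bottom-up, so the guard wraps the already-patched function and the mock argument is still injected as usual. Note that this thread-based sketch only stops waiting; it does not terminate the overrunning call.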
In source/tests/pt/test_training.py (lines 1228-1231):
for prefix in (save_ckpt, ema_save_ckpt):
    link = Path(f"{prefix}.pt")
    self.assertTrue(link.is_symlink())
    self.assertEqual(link.resolve(), (save_dir / f"{prefix}-4.pt").resolve())
🎯 Functional Correctness | 🟠 Major | ⚡ Quick win
Don’t require symlinks in this cross-platform checkpoint test.
symlink_prefix_files() copies files on Windows, so is_symlink() is false there even when the checkpoint alias is correct. That makes this new test fail on a platform the helper already supports.
Proposed fix
  for prefix in (save_ckpt, ema_save_ckpt):
      link = Path(f"{prefix}.pt")
-     self.assertTrue(link.is_symlink())
-     self.assertEqual(link.resolve(), (save_dir / f"{prefix}-4.pt").resolve())
+     target = save_dir / f"{prefix}-4.pt"
+     self.assertTrue(link.exists())
+     if os.name != "nt":
+         self.assertTrue(link.is_symlink())
+         self.assertEqual(link.resolve(), target.resolve())
+     else:
+         self.assertTrue(link.is_file())
+         self.assertEqual(link.read_bytes(), target.read_bytes())
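A platform-neutral way to phrase the assertion is to accept either a symlink that resolves to the target or a regular file with identical contents. The sketch below is a minimal illustration: `make_alias` is a hypothetical stand-in for `symlink_prefix_files()` (whose real behavior is only described, not shown, in this review), and its try-symlink-then-copy fallback is an assumption.

```python
import shutil
from pathlib import Path


def make_alias(target: Path, link: Path) -> None:
    """Hypothetical stand-in: symlink when possible, otherwise copy."""
    try:
        link.symlink_to(target.resolve())
    except (OSError, NotImplementedError):
        # For example, symlink creation can be unavailable on Windows.
        shutil.copyfile(target, link)


def alias_matches_target(link: Path, target: Path) -> bool:
    """True if `link` is a valid alias of `target`, as symlink or as copy."""
    if not link.exists():
        return False
    if link.is_symlink():
        return link.resolve() == target.resolve()
    return link.is_file() and link.read_bytes() == target.read_bytes()
```

A test can then assert `alias_matches_target(link, target)` without branching on `os.name`, at the cost of no longer verifying that a symlink was used where one is available.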
njzjz-bot left a comment
Thanks for adding the checkpoint save directory and ratio-based retention knobs. I found a few issues worth fixing before merge:
- `ckpt_keep_ratio` currently under-counts when the final checkpoint is off-cadence. In `resolve_keep_ckpt_count()`, `total_periodic_ckpts = num_steps // save_freq` ignores the final checkpoint that `Trainer.run()` still writes when `num_steps % save_freq != 0`. For example, `num_steps=5`, `save_freq=2`, `ckpt_keep_ratio=0.5` produces checkpoints at steps 2, 4, and 5, but the helper returns `ceil(0.5 * (5 // 2)) = 1`; the documented formula `ceil(ckpt_keep_ratio * numb_steps / save_freq)` would keep 2. Please account for the terminal checkpoint, e.g. use `ceil(num_steps / save_freq)` (with the existing minimum-one guard).
- The new `save_best_dir` in `examples/water/dpa4/input.json` is misleading unless full validation is enabled. Since `validating.full_validation` defaults to false, `ckpt_best` will not actually be used by this example. Either enable the full-validation flow in the example or omit `save_best_dir` there.
- The new `test_save_dir_redirects_checkpoints_with_local_symlinks` assumes `Path(...).is_symlink()`, but `symlink_prefix_files()` copies files on Windows. If these tests are expected to be portable, please avoid requiring symlinks in the assertion (or explicitly scope the test/docs to non-Windows behavior).
Reviewed by OpenClaw 2026.6.8 (model: custom-chat-jinzhezeng-group/gpt-5.5).
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## master #5589 +/- ##
=======================================
Coverage 82.27% 82.28%
=======================================
Files 887 887
Lines 100331 100361 +30
Branches 4060 4058 -2
=======================================
+ Hits 82550 82581 +31
+ Misses 16320 16318 -2
- Partials 1461 1462 +1