Skip to content

feat: support per-sample loss weighting (sample_weights)#100

Merged
meilame-tayebjee merged 6 commits into
mainfrom
feat/sample-weights
Jul 2, 2026
Merged

feat: support per-sample loss weighting (sample_weights)#100
meilame-tayebjee merged 6 commits into
mainfrom
feat/sample-weights

Conversation

@meilame-tayebjee

Copy link
Copy Markdown
Member

Summary

  • Add optional sample_weights (and val_sample_weights) support so users can weight individual training/validation samples in the loss, defaulting to 1 (no change in behavior) when not provided.
  • Threaded end-to-end: TextClassificationDatasetcollate_fn/batch → TextClassificationModule.step() (and MultiLevelCrossEntropyLoss) → torchTextClassifiers.train().
  • Standard torch.nn.*Loss losses are automatically switched to reduction="none" internally so the weighted mean can be computed; this composes naturally with a loss's own per-class weight= argument (e.g. CrossEntropyLoss(weight=...)).
  • Custom multi-task losses can opt in by adding an optional sample_weights kwarg to forward (documented in from_model()'s docstring).

Test plan

  • uv run pytest tests/ — 41 passed (HuggingFace extra installed via uv sync --extra huggingface)
  • New tests/test_sample_weights.py: dataset default/custom weights, weighted loss vs. manual computation, zero-weight ≡ sample exclusion, MultiLevelCrossEntropyLoss weighting, wrapper-level validation, end-to-end train() run with sample_weights/val_sample_weights
  • ruff check clean on all touched files

🤖 Generated with Claude Code

meilame-tayebjee and others added 6 commits July 2, 2026 14:17
…ataset

Add an optional sample_weights array (defaulting to 1 for every sample)
that flows through __getitem__ and collate_fn into a new "sample_weights"
key in the batch dict, so downstream loss computation can weight samples
individually.

Co-Authored-By: Claude Sonnet 5 <noreply@anthropic.com>
TextClassificationModule now reads sample_weights from the batch and
applies them when computing the loss: losses that accept a
sample_weights kwarg (e.g. MultiLevelCrossEntropyLoss) receive it
directly, otherwise reduction is switched to "none" and the module
computes the weighted mean itself. With all weights equal to 1 this is
mathematically identical to the previous unweighted reduction.

MultiLevelCrossEntropyLoss.forward gains an optional sample_weights
argument, weighting each level's per-sample loss before combining them.

Co-Authored-By: Claude Sonnet 5 <noreply@anthropic.com>
…hts on train()

Add optional sample_weights/val_sample_weights parameters to train(),
validated via a new _check_sample_weights helper (1D array, matching
length, non-negative) and forwarded to the train/val
TextClassificationDataset instances. Also documents the sample_weights
contract for custom losses used via from_model().

Co-Authored-By: Claude Sonnet 5 <noreply@anthropic.com>
Add unit tests for default/custom sample_weights in
TextClassificationDataset's collate output, weighted-loss computation
in TextClassificationModule.step (including zero-weight equivalent to
sample exclusion), MultiLevelCrossEntropyLoss weighting, wrapper-level
validation, and an end-to-end train() run with sample_weights and
val_sample_weights.

Co-Authored-By: Claude Sonnet 5 <noreply@anthropic.com>
Add a "Weighting Individual Samples in the Loss" section to the
architecture overview explaining how sample_weights flows through the
dataset, batching, and loss computation, plus a shorter practical
example in the quickstart guide.

Co-Authored-By: Claude Sonnet 5 <noreply@anthropic.com>
Replace the inspect.signature probing of whether a loss accepts a
sample_weights kwarg with a single, classic convention: any loss used
with TextClassificationModule must expose per-sample (reduction="none")
output. Standard torch.nn.*Loss objects are switched to reduction="none"
automatically; custom losses (e.g. MultiLevelCrossEntropyLoss) just
return an unreduced (batch,) tensor directly. TextClassificationModule
always applies sample_weights and reduces the batch itself, with no
branching. This also keeps torch.nn.CrossEntropyLoss's own per-class
weight= argument untouched and fully compatible.

Co-Authored-By: Claude Sonnet 5 <noreply@anthropic.com>
@meilame-tayebjee meilame-tayebjee merged commit d85249b into main Jul 2, 2026
5 checks passed
@meilame-tayebjee meilame-tayebjee deleted the feat/sample-weights branch July 2, 2026 16:21
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant