feat: support per-sample loss weighting (sample_weights) #100
Merged
…ataset Add an optional sample_weights array (defaulting to 1 for every sample) that flows through __getitem__ and collate_fn into a new "sample_weights" key in the batch dict, so downstream loss computation can weight samples individually. Co-Authored-By: Claude Sonnet 5 <noreply@anthropic.com>
TextClassificationModule now reads sample_weights from the batch and applies them when computing the loss: losses that accept a sample_weights kwarg (e.g. MultiLevelCrossEntropyLoss) receive it directly, otherwise reduction is switched to "none" and the module computes the weighted mean itself. With all weights equal to 1 this is mathematically identical to the previous unweighted reduction. MultiLevelCrossEntropyLoss.forward gains an optional sample_weights argument, weighting each level's per-sample loss before combining them. Co-Authored-By: Claude Sonnet 5 <noreply@anthropic.com>
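The reduction described in this commit can be illustrated with a minimal pure-Python sketch. It assumes the weighted mean is normalized by the sum of the weights, which is consistent with the zero-weight test listed in this PR but is not spelled out in the commit message. `weighted_mean` is a hypothetical helper, and the real module operates on torch tensors rather than lists.

```python
from typing import Sequence


def weighted_mean(losses: Sequence[float], weights: Sequence[float]) -> float:
    """Weighted mean of per-sample losses (hypothetical helper, not the PR's code).

    Assumption: normalization is by sum(weights), not by the batch size.
    """
    if len(losses) != len(weights):
        raise ValueError("losses and weights must have the same length")
    total_weight = sum(weights)
    if total_weight == 0:
        raise ValueError("at least one weight must be positive")
    return sum(l * w for l, w in zip(losses, weights)) / total_weight


per_sample = [0.2, 1.0, 0.6]

# All weights equal to 1: same value as the plain (unweighted) mean.
uniform = weighted_mean(per_sample, [1.0, 1.0, 1.0])
plain = sum(per_sample) / len(per_sample)

# A zero weight removes that sample from the average entirely.
masked = weighted_mean(per_sample, [1.0, 0.0, 1.0])
excluded = weighted_mean([0.2, 0.6], [1.0, 1.0])
```

Under this assumption, `uniform` matches `plain` and `masked` matches `excluded`, which mirrors the two properties the commit and the test plan call out.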
…hts on train() Add optional sample_weights/val_sample_weights parameters to train(), validated via a new _check_sample_weights helper (1D array, matching length, non-negative) and forwarded to the train/val TextClassificationDataset instances. Also documents the sample_weights contract for custom losses used via from_model(). Co-Authored-By: Claude Sonnet 5 <noreply@anthropic.com>
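The three checks named in this commit (1D, matching length, non-negative) might look roughly like the sketch below. This is not the PR's `_check_sample_weights`: the name `check_sample_weights`, its signature, and the error messages are assumptions, and it works on plain Python sequences instead of arrays so that it runs without extra dependencies.

```python
from numbers import Real
from typing import List, Sequence


def check_sample_weights(weights: Sequence[float], n_samples: int) -> List[float]:
    """Hypothetical stand-in for the validation described in the commit."""
    values = list(weights)
    # 1D: every element must be a plain real number, not a nested sequence.
    if not all(isinstance(v, Real) and not isinstance(v, bool) for v in values):
        raise ValueError("sample_weights must be a 1D sequence of numbers")
    # Matching length: one weight per sample.
    if len(values) != n_samples:
        raise ValueError(f"expected {n_samples} weights, got {len(values)}")
    # Non-negative: zero is allowed, negative values are not.
    if any(v < 0 for v in values):
        raise ValueError("sample_weights must be non-negative")
    return [float(v) for v in values]
```

Validating at the wrapper level means a bad array fails before any dataset or dataloader is built, rather than surfacing later as a shape error inside the loss.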
Add unit tests for default/custom sample_weights in TextClassificationDataset's collate output, weighted-loss computation in TextClassificationModule.step (including zero-weight equivalent to sample exclusion), MultiLevelCrossEntropyLoss weighting, wrapper-level validation, and an end-to-end train() run with sample_weights and val_sample_weights. Co-Authored-By: Claude Sonnet 5 <noreply@anthropic.com>
Add a "Weighting Individual Samples in the Loss" section to the architecture overview explaining how sample_weights flows through the dataset, batching, and loss computation, plus a shorter practical example in the quickstart guide. Co-Authored-By: Claude Sonnet 5 <noreply@anthropic.com>
Replace the inspect.signature probing of whether a loss accepts a sample_weights kwarg with a single, classic convention: any loss used with TextClassificationModule must expose per-sample (reduction="none") output. Standard torch.nn.*Loss objects are switched to reduction="none" automatically; custom losses (e.g. MultiLevelCrossEntropyLoss) just return an unreduced (batch,) tensor directly. TextClassificationModule always applies sample_weights and reduces the batch itself, with no branching. This also keeps torch.nn.CrossEntropyLoss's own per-class weight= argument untouched and fully compatible. Co-Authored-By: Claude Sonnet 5 <noreply@anthropic.com>
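The convention in this commit can be sketched as follows. This is an illustration, not the module's actual code: `weighted_loss` is a hypothetical function, detecting a standard loss via its `reduction` attribute is an assumption, and so is normalizing by the sum of the weights.

```python
import torch


def weighted_loss(loss_fn, logits, targets, sample_weights):
    """Sketch of the convention: per-sample losses in, one scalar out."""
    # Standard torch.nn losses expose a `reduction` attribute; switch it so
    # the loss returns one value per sample. Note this mutates loss_fn.
    if hasattr(loss_fn, "reduction"):
        loss_fn.reduction = "none"
    per_sample = loss_fn(logits, targets)  # shape: (batch,)
    # Assumption: normalize by the sum of the weights.
    return (per_sample * sample_weights).sum() / sample_weights.sum()


torch.manual_seed(0)
logits = torch.randn(4, 3)
targets = torch.tensor([0, 2, 1, 2])

# The per-class weight= argument is left exactly as the caller configured it.
loss_fn = torch.nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 0.5]))

# Zero out the third sample, then compare with dropping it from the batch.
mask = torch.tensor([1.0, 1.0, 0.0, 1.0])
masked = weighted_loss(loss_fn, logits, targets, mask)
keep = mask > 0
excluded = weighted_loss(loss_fn, logits[keep], targets[keep], torch.ones(3))
```

Because the per-class `weight=` is applied inside the loss before the per-sample weights are multiplied in, the two weightings stay independent of each other, and there is no branching on the loss's signature.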
Summary
- sample_weights (and val_sample_weights) support so users can weight individual training/validation samples in the loss, defaulting to 1 (no change in behavior) when not provided.
- TextClassificationDataset → collate_fn/batch → TextClassificationModule.step() (and MultiLevelCrossEntropyLoss) → torchTextClassifiers.train().
- torch.nn.*Loss losses are automatically switched to reduction="none" internally so the weighted mean can be computed; this composes naturally with a loss's own per-class weight= argument (e.g. CrossEntropyLoss(weight=...)).
- sample_weights kwarg to forward (documented in from_model()'s docstring).
Test plan
- uv run pytest tests/: 41 passed (HuggingFace extra installed via uv sync --extra huggingface)
- tests/test_sample_weights.py: dataset default/custom weights, weighted loss vs. manual computation, zero-weight ≡ sample exclusion, MultiLevelCrossEntropyLoss weighting, wrapper-level validation, end-to-end train() run with sample_weights/val_sample_weights
- ruff check clean on all touched files

🤖 Generated with Claude Code