From 96f96cc4a678f29208a67272e0030fcda40dce0a Mon Sep 17 00:00:00 2001 From: Francesco Bertolotti Date: Fri, 5 Jun 2026 17:42:06 +0200 Subject: [PATCH] increased a bit tolerance for pytorch/distributed/run_numerics.py Signed-off-by: Francesco Bertolotti --- tests/pytorch/distributed/run_numerics.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/distributed/run_numerics.py b/tests/pytorch/distributed/run_numerics.py index fe02f990b4..9514e2348f 100644 --- a/tests/pytorch/distributed/run_numerics.py +++ b/tests/pytorch/distributed/run_numerics.py @@ -207,8 +207,10 @@ def _get_tolerances(dtype): if dtype == torch.bfloat16: return {"rtol": 1.6e-2, "atol": 1e-5} if dtype == torch.float32: - # TF32 has same mantissa bits as FP16 - return {"rtol": 1e-3, "atol": 1e-5} + # TF32 has same mantissa bits as FP16. The atol is looser than for FP16 + # because near-zero gradient elements can differ by a few 1e-5 between + # the TP-sharded and single-device GEMM reduction orders (observed on A100). + return {"rtol": 1e-3, "atol": 5e-5} raise ValueError(f"Unsupported dtype ({dtype})")