diff --git a/test/python/ext/ep/test_internode_multirank.py b/test/python/ext/ep/test_internode_multirank.py index 901a67a7..eddfb6be 100644 --- a/test/python/ext/ep/test_internode_multirank.py +++ b/test/python/ext/ep/test_internode_multirank.py @@ -259,7 +259,10 @@ def main(): diff = (got - expected).abs().max().item() max_exp = expected.abs().max().item() print(f"[combine r{rank}] max|got-expected|={diff:.4e} max|expected|={max_exp:.4e}", flush=True) - assert diff < 1e-2, f"rank{rank}: combine mismatch max diff {diff}" + # bf16 accumulator has a 7-bit mantissa; intermediate partial sums can + # round at ulp = max_exp * 2**-7. Scale the tolerance with magnitude: 2 ulp (max_exp / 64), floored at 1e-2. + tol = max(1e-2, max_exp * (1.0 / 64)) + assert diff <= tol, f"rank{rank}: combine mismatch max diff {diff} > tol {tol} (max_exp={max_exp})" dist.barrier(group=group) if rank == 0: