From 13babbfff27280f52d275c00bef50c9ef7a4bd04 Mon Sep 17 00:00:00 2001 From: qinghuazhou Date: Tue, 12 May 2026 05:37:42 +0000 Subject: [PATCH] =?UTF-8?q?test/ext/ep:=20HT=20=E2=80=94=20scale=20combine?= =?UTF-8?q?=20tolerance=20with=20bf16=20ulp?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit At 16 nodes (64 ranks) with topk=8, expected combine values reach rank*8 = 504, while intermediate partial sums (rank*7 etc.) cross the bf16 ulp=2 boundary at 256. With the test pattern x = rank*ones and weights = 1, this produces deterministic +/-1 round-off on certain ranks (odd local_rank on nodes >= 9), tripping the previous 1e-2 absolute tolerance even though the kernel is correct. Use tol = max(1e-2, max_exp / 64) which matches the bf16 mantissa precision and scales with the magnitude of the expected combined output. The previous tight bound is preserved for small-scale runs where max_exp < 0.64. --- test/python/ext/ep/test_internode_multirank.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/python/ext/ep/test_internode_multirank.py b/test/python/ext/ep/test_internode_multirank.py index 901a67a7..eddfb6be 100644 --- a/test/python/ext/ep/test_internode_multirank.py +++ b/test/python/ext/ep/test_internode_multirank.py @@ -259,7 +259,10 @@ def main(): diff = (got - expected).abs().max().item() max_exp = expected.abs().max().item() print(f"[combine r{rank}] max|got-expected|={diff:.4e} max|expected|={max_exp:.4e}", flush=True) - assert diff < 1e-2, f"rank{rank}: combine mismatch max diff {diff}" + # bf16 accumulator has 7-bit mantissa; intermediate partial sums can + # round at ulp = max_exp * 2**-7. Use a tolerance that scales with magnitude. + tol = max(1e-2, max_exp * (1.0 / 64)) + assert diff <= tol, f"rank{rank}: combine mismatch max diff {diff} > tol {tol} (max_exp={max_exp})" dist.barrier(group=group) if rank == 0: