test/ext/ep: HT — scale combine tolerance with bf16 ulp

At 16 nodes (64 ranks) with topk=8, expected combine values reach
rank*8 = 504, while intermediate partial sums (rank*7 etc.) cross the
bf16 ulp=2 boundary at 256. With the test pattern x = rank*ones and
weights = 1, this produces deterministic +/-1 round-off on certain
ranks (odd local_rank on nodes >= 9), tripping the previous 1e-2
absolute tolerance even though the kernel is correct.

Use tol = max(1e-2, max_exp / 64) which matches the bf16 mantissa
precision and scales with the magnitude of the expected combined
output. The previous tight bound is preserved for small-scale runs
where max_exp < 0.64.
This commit is contained in:
qinghuazhou
2026-05-12 05:37:42 +00:00
parent 3f459a995d
commit 13babbfff2

View File

@@ -259,7 +259,10 @@ def main():
diff = (got - expected).abs().max().item()
max_exp = expected.abs().max().item()
print(f"[combine r{rank}] max|got-expected|={diff:.4e} max|expected|={max_exp:.4e}", flush=True)
assert diff < 1e-2, f"rank{rank}: combine mismatch max diff {diff}"
# bf16 accumulator has 7-bit mantissa; intermediate partial sums can
# round at ulp = max_exp * 2**-7. Use a tolerance that scales with magnitude.
tol = max(1e-2, max_exp * (1.0 / 64))
assert diff <= tol, f"rank{rank}: combine mismatch max diff {diff} > tol {tol} (max_exp={max_exp})"
dist.barrier(group=group)
if rank == 0: