mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-24 14:54:51 +00:00
test/ext/ep: HT — scale combine tolerance with bf16 ulp
At 16 nodes (64 ranks) with topk=8, expected combine values reach rank*8 = 504, while intermediate partial sums (rank*7 etc.) cross the bf16 ulp=2 boundary at 256. With the test pattern x = rank*ones and weights = 1, this produces deterministic +/-1 round-off on certain ranks (odd local_rank on nodes >= 9), tripping the previous 1e-2 absolute tolerance even though the kernel is correct. Use tol = max(1e-2, max_exp / 64) which matches the bf16 mantissa precision and scales with the magnitude of the expected combined output. The previous tight bound is preserved for small-scale runs where max_exp < 0.64.
This commit is contained in:
@@ -259,7 +259,10 @@ def main():
|
||||
diff = (got - expected).abs().max().item()
|
||||
max_exp = expected.abs().max().item()
|
||||
print(f"[combine r{rank}] max|got-expected|={diff:.4e} max|expected|={max_exp:.4e}", flush=True)
|
||||
assert diff < 1e-2, f"rank{rank}: combine mismatch max diff {diff}"
|
||||
# bf16 accumulator has 7-bit mantissa; intermediate partial sums can
|
||||
# round at ulp = max_exp * 2**-7. Use a tolerance that scales with magnitude.
|
||||
tol = max(1e-2, max_exp * (1.0 / 64))
|
||||
assert diff <= tol, f"rank{rank}: combine mismatch max diff {diff} > tol {tol} (max_exp={max_exp})"
|
||||
|
||||
dist.barrier(group=group)
|
||||
if rank == 0:
|
||||
|
||||
Reference in New Issue
Block a user