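Add debugging code: return nelems from build_bufs, sentinel-fill the result buffer, and print how many elements changed after one execution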

This commit is contained in:
Ubuntu
2026-04-03 21:30:21 +00:00
parent b04fa2daa7
commit a4bb8fb4bf
2 changed files with 212 additions and 3 deletions

View File

@@ -166,9 +166,11 @@ def build_bufs(
else:
input_buf = GpuBuffer(nelems_input, dtype=dtype)
in_place = False
test_buf = cp.zeros(nelems, dtype=dtype)
return input_buf, result_buf, test_buf
return input_buf, result_buf, test_buf, nelems
def main(
@@ -190,7 +192,7 @@ def main(
collective = execution_plan.collective
dtype = parse_dtype(dtype_str)
input_buf, result_buf, test_buf = build_bufs(
input_buf, result_buf, test_buf, nelem = build_bufs(
collective,
size,
in_place,
@@ -212,6 +214,22 @@ def main(
)
mscclpp_group.barrier()
print("size= ", size, "nelem= ", nelem)
# Sentinel fill: choose something unlikely in your pattern
result_buf.fill(cp.float16(123.0))
cp.cuda.runtime.deviceSynchronize()
# Run ONE execution (no graph), then sync
stream = cp.cuda.Stream(non_blocking=True)
with stream:
executor_func(stream)
stream.synchronize()
# Count how many elements changed
changed = cp.count_nonzero(result_buf != cp.float16(123.0)).item()
print("changed elements:", changed, "out of", result_buf.size)
bench_correctness(
collective,
input_buf,