Add detection of torch baseline and debug info

Qinghua Zhou
2026-03-25 01:52:24 +00:00
parent 8e22010560
commit ec011f14ea
2 changed files with 75 additions and 14 deletions


@@ -295,10 +295,10 @@ def main():
    use_torch_baseline = (backend == "nccl")
    if use_torch_baseline:
        try:
            # Quick test: if the NCCL shim is active it may not support all_to_all_single
            tiny_in = torch.zeros(world_size, dtype=torch.float32, device='cuda')
            tiny_out = torch.zeros(world_size, dtype=torch.float32, device='cuda')
            dist.all_to_all_single(tiny_out, tiny_in)
            torch.cuda.synchronize()
        except Exception:
            use_torch_baseline = False
            if rank == 0:
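
The probe above can also be factored into a reusable helper. The following is a minimal sketch, assuming torch.distributed is already initialized and the CUDA device is set; the helper name probe_all_to_all_single is illustrative and not part of this commit.

import torch
import torch.distributed as dist

def probe_all_to_all_single(world_size: int) -> bool:
    """Return True if dist.all_to_all_single works on the active backend.

    A shimmed NCCL backend may register under the "nccl" name without
    implementing every collective; a one-element-per-rank exchange
    surfaces that before the real benchmark runs.
    """
    try:
        tiny_in = torch.zeros(world_size, dtype=torch.float32, device="cuda")
        tiny_out = torch.zeros(world_size, dtype=torch.float32, device="cuda")
        dist.all_to_all_single(tiny_out, tiny_in)
        torch.cuda.synchronize()
        return True
    except Exception:
        return False

A caller would then combine it with the backend check, e.g. use_torch_baseline = (backend == "nccl") and probe_all_to_all_single(world_size).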
@@ -387,8 +387,19 @@ def main():
         m_lat, m_bw = bench_alltoallv(mscclpp_fn, inp_view, out_view, in_splits, out_splits, n_warmup, n_iters)
         if use_torch_baseline:
-            t_lat, t_bw = bench_alltoallv(torch_fn, inp_view, out_view, in_splits, out_splits, n_warmup, n_iters)
-            print_row(fmt_size(avg_msg_size), m_lat, m_bw, t_lat, t_bw)
+            try:
+                t_lat, t_bw = bench_alltoallv(torch_fn, inp_view, out_view, in_splits, out_splits, n_warmup, n_iters)
+                print_row(fmt_size(avg_msg_size), m_lat, m_bw, t_lat, t_bw)
+            except Exception as e:
+                if rank == 0:
+                    print(f" [WARN] torch baseline failed: {e}")
+                    print(f" [INFO] Disabling torch baseline for remaining sizes")
+                use_torch_baseline = False
+                try:
+                    torch.cuda.synchronize()
+                except Exception:
+                    pass
+                print_row(fmt_size(avg_msg_size), m_lat, m_bw)
         else:
             print_row(fmt_size(avg_msg_size), m_lat, m_bw)
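
The guard-and-disable pattern added here can also be pulled out of the per-size loop. A sketch under the assumption that the benchmark callable raises on failure; run_baseline_once and its signature are illustrative, not code from this repository.

import torch

def run_baseline_once(bench, enabled, rank, *args):
    """Run a baseline benchmark once; on failure, warn and disable it.

    Returns (result_or_None, still_enabled), mirroring the fallback
    logic in the hunk above.
    """
    if not enabled:
        return None, False
    try:
        return bench(*args), True
    except Exception as exc:
        if rank == 0:
            print(f" [WARN] torch baseline failed: {exc}")
            print(" [INFO] Disabling torch baseline for remaining sizes")
        try:
            torch.cuda.synchronize()  # drain any partially issued GPU work
        except Exception:
            pass
        return None, False

A caller in the loop would then update use_torch_baseline from the returned flag instead of setting it inline.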
@@ -459,12 +470,22 @@ def main():
         n_warmup, n_iters = 5, 20
         m_lat, m_bw = bench_alltoallv(mscclpp_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
+        avg_bytes = total_bytes // world_size
         if use_torch_baseline:
-            t_lat, t_bw = bench_alltoallv(torch_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
-            avg_bytes = total_bytes // world_size
-            print_row(fmt_size(avg_bytes), m_lat, m_bw, t_lat, t_bw)
+            try:
+                t_lat, t_bw = bench_alltoallv(torch_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
+                print_row(fmt_size(avg_bytes), m_lat, m_bw, t_lat, t_bw)
+            except Exception as e:
+                if rank == 0:
+                    print(f" [WARN] torch baseline failed: {e}")
+                    print(f" [INFO] Disabling torch baseline for remaining workloads")
+                use_torch_baseline = False
+                try:
+                    torch.cuda.synchronize()
+                except Exception:
+                    pass
+                print_row(fmt_size(avg_bytes), m_lat, m_bw)
         else:
-            avg_bytes = total_bytes // world_size
             print_row(fmt_size(avg_bytes), m_lat, m_bw)
     else:
         if rank == 0:
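
For context, the torch_fn baseline that these guards protect typically maps onto dist.all_to_all_single with explicit split sizes. A minimal sketch, assuming in_splits and out_splits are per-rank element counts and the tensors are flat 1-D views; this wrapper is an assumption, not taken from the benchmark script.

import torch
import torch.distributed as dist

def torch_alltoallv(out, inp, in_splits, out_splits):
    """All-to-all-v via torch.distributed on flat 1-D tensors."""
    dist.all_to_all_single(
        out, inp,
        output_split_sizes=out_splits,
        input_split_sizes=in_splits,
    )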