Mirror of https://github.com/microsoft/mscclpp.git (synced 2026-05-12 01:10:22 +00:00)
Add detection of torch baseline and debug info
@@ -295,10 +295,10 @@ def main():
     use_torch_baseline = (backend == "nccl")
     if use_torch_baseline:
         try:
             # Quick test: if the NCCL shim is active it may not support all_to_all_single
             tiny_in = torch.zeros(world_size, dtype=torch.float32, device='cuda')
             tiny_out = torch.zeros(world_size, dtype=torch.float32, device='cuda')
             dist.all_to_all_single(tiny_out, tiny_in)
             torch.cuda.synchronize()
         except Exception:
             use_torch_baseline = False
     if rank == 0:
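The probe above checks whether the active backend can actually service dist.all_to_all_single, which is the call the torch baseline depends on. For reference, a baseline all-to-all-v built on that API typically has the shape sketched below; torch_alltoallv is a hypothetical name and the script's real torch_fn may be implemented differently, but output_split_sizes/input_split_sizes are the standard torch.distributed arguments.

import torch
import torch.distributed as dist

# Hypothetical torch baseline for an all-to-all-v over flat 1-D buffers. The probe
# in the hunk above guards exactly this kind of call; the benchmark's actual
# torch_fn may differ.
def torch_alltoallv(inp, out, in_splits, out_splits):
    dist.all_to_all_single(
        out, inp,
        output_split_sizes=list(out_splits),
        input_split_sizes=list(in_splits),
    )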
@@ -387,8 +387,19 @@ def main():
 
         m_lat, m_bw = bench_alltoallv(mscclpp_fn, inp_view, out_view, in_splits, out_splits, n_warmup, n_iters)
         if use_torch_baseline:
-            t_lat, t_bw = bench_alltoallv(torch_fn, inp_view, out_view, in_splits, out_splits, n_warmup, n_iters)
-            print_row(fmt_size(avg_msg_size), m_lat, m_bw, t_lat, t_bw)
+            try:
+                t_lat, t_bw = bench_alltoallv(torch_fn, inp_view, out_view, in_splits, out_splits, n_warmup, n_iters)
+                print_row(fmt_size(avg_msg_size), m_lat, m_bw, t_lat, t_bw)
+            except Exception as e:
+                if rank == 0:
+                    print(f" [WARN] torch baseline failed: {e}")
+                    print(f" [INFO] Disabling torch baseline for remaining sizes")
+                use_torch_baseline = False
+                try:
+                    torch.cuda.synchronize()
+                except Exception:
+                    pass
+                print_row(fmt_size(avg_msg_size), m_lat, m_bw)
         else:
             print_row(fmt_size(avg_msg_size), m_lat, m_bw)
 
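The try/except added here keeps the benchmark loop alive when the baseline collective fails mid-run: rank 0 logs the error once, the baseline is disabled for all remaining sizes, and a best-effort torch.cuda.synchronize() tries to leave the device in a usable state before the mscclpp-only row is printed. The same pattern can be factored into a small helper; the sketch below is a hypothetical refactoring for illustration, not code from this commit.

import torch

# Hypothetical helper expressing the fallback pattern above: run one baseline
# measurement, and on failure warn (rank 0 only), resync the device if possible,
# and tell the caller to stop using the baseline.
def try_torch_baseline(bench, torch_fn, args, rank, what="sizes"):
    try:
        t_lat, t_bw = bench(torch_fn, *args)
        return (t_lat, t_bw), True                    # result, baseline still enabled
    except Exception as e:
        if rank == 0:
            print(f" [WARN] torch baseline failed: {e}")
            print(f" [INFO] Disabling torch baseline for remaining {what}")
        try:
            torch.cuda.synchronize()                  # best-effort device cleanup
        except Exception:
            pass
        return None, False                            # no result, disable baseline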
@@ -459,12 +470,22 @@ def main():
         n_warmup, n_iters = 5, 20
 
         m_lat, m_bw = bench_alltoallv(mscclpp_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
+        avg_bytes = total_bytes // world_size
         if use_torch_baseline:
-            t_lat, t_bw = bench_alltoallv(torch_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
-            avg_bytes = total_bytes // world_size
-            print_row(fmt_size(avg_bytes), m_lat, m_bw, t_lat, t_bw)
+            try:
+                t_lat, t_bw = bench_alltoallv(torch_fn, inp, out, in_splits, out_splits, n_warmup, n_iters)
+                print_row(fmt_size(avg_bytes), m_lat, m_bw, t_lat, t_bw)
+            except Exception as e:
+                if rank == 0:
+                    print(f" [WARN] torch baseline failed: {e}")
+                    print(f" [INFO] Disabling torch baseline for remaining workloads")
+                use_torch_baseline = False
+                try:
+                    torch.cuda.synchronize()
+                except Exception:
+                    pass
+                print_row(fmt_size(avg_bytes), m_lat, m_bw)
         else:
-            avg_bytes = total_bytes // world_size
             print_row(fmt_size(avg_bytes), m_lat, m_bw)
     else:
         if rank == 0:
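bench_alltoallv itself is not part of this diff. For orientation, a timing helper with the signature used at the call sites above generally follows a warmup-then-timed-loop shape like the sketch below; the call convention of fn and the bandwidth formula are assumptions, not taken from the mscclpp benchmark.

import time
import torch

# Generic sketch of a bench_alltoallv-style helper matching the call sites above.
# Assumes fn(inp, out, in_splits, out_splits) performs one all-to-all-v and that
# in_splits holds per-destination element counts; the real helper may differ.
def bench_alltoallv_sketch(fn, inp, out, in_splits, out_splits, n_warmup, n_iters):
    for _ in range(n_warmup):
        fn(inp, out, in_splits, out_splits)           # warm up kernels and connections
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n_iters):
        fn(inp, out, in_splits, out_splits)
    torch.cuda.synchronize()
    lat_us = (time.perf_counter() - start) / n_iters * 1e6
    sent_bytes = int(sum(in_splits)) * inp.element_size()   # bytes this rank sends
    bw_gb_s = sent_bytes / (lat_us * 1e-6) / 1e9
    return lat_us, bw_gb_s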