Add debug variable MSCCLPP_DEBUG_ALLTOALLV_to print

This commit is contained in:
Qinghua Zhou
2026-04-02 04:39:48 +00:00
parent 36940dbacf
commit 520c890df5
2 changed files with 38 additions and 26 deletions

View File

@@ -14,6 +14,8 @@ import torch.distributed as dist
import os
import sys
import time
_DEBUG = os.environ.get("MSCCLPP_DEBUG_ALLTOALLV", "0") == "1"
import random
import socket
import struct
@@ -166,7 +168,7 @@ def main():
except Exception:
ib_devices = []
if rank == 0:
if rank == 0 and _DEBUG:
print(f" Hostname: {hostname}")
print(f" nRanksPerNode: {n_ranks_per_node}, isMultiNode: {is_multi_node}")
print(f" IB devices: {ib_devices if ib_devices else 'NONE FOUND'}")
@@ -174,7 +176,7 @@ def main():
if is_multi_node and not ib_devices:
print(f" WARNING: Multi-node detected but no IB devices! Cross-node will fail.")
# Also print from rank n_ranks_per_node (first rank on node 1) for comparison
if is_multi_node and rank == n_ranks_per_node:
if is_multi_node and rank == n_ranks_per_node and _DEBUG:
print(f" [Node 1] Hostname: {hostname}, rank={rank}")
print(f" [Node 1] IB devices: {ib_devices if ib_devices else 'NONE FOUND'}")
# ── End diagnostics ────────────────────────────────────────────────
@@ -184,7 +186,7 @@ def main():
# Create MscclppAlltoAllV with existing communicator
alltoallv = MscclppAlltoAllV(communicator=comm)
if rank == 0:
if rank == 0 and _DEBUG:
print(f"MscclppAlltoAllV initialized")
print(f"Algorithm: {alltoallv._algo.name}")
@@ -201,14 +203,15 @@ def main():
)
# ── DEBUG: print tensor sizes before all_to_all_single ──
print(f" [rank {rank}] input_data: numel={input_data.numel()}, shape={input_data.shape}, "
f"dtype={input_data.dtype}, device={input_data.device}, "
f"storage_size={input_data.untyped_storage().size()}, "
f"data_ptr=0x{input_data.data_ptr():x}")
print(f" [rank {rank}] world_size={world_size}, chunk_size={chunk_size}, "
f"expected_total_elems={world_size * chunk_size}, "
f"scratch_buffer_size={alltoallv._scratch_size}")
sys.stdout.flush()
if _DEBUG:
print(f" [rank {rank}] input_data: numel={input_data.numel()}, shape={input_data.shape}, "
f"dtype={input_data.dtype}, device={input_data.device}, "
f"storage_size={input_data.untyped_storage().size()}, "
f"data_ptr=0x{input_data.data_ptr():x}")
print(f" [rank {rank}] world_size={world_size}, chunk_size={chunk_size}, "
f"expected_total_elems={world_size * chunk_size}, "
f"scratch_buffer_size={alltoallv._scratch_size}")
sys.stdout.flush()
dist.barrier()
try:
@@ -224,11 +227,12 @@ def main():
raise
# ── DEBUG: print output tensor sizes ──
print(f" [rank {rank}] output: numel={output.numel()}, shape={output.shape}, "
f"dtype={output.dtype}, device={output.device}, "
f"storage_size={output.untyped_storage().size()}, "
f"data_ptr=0x{output.data_ptr():x}")
sys.stdout.flush()
if _DEBUG:
print(f" [rank {rank}] output: numel={output.numel()}, shape={output.shape}, "
f"dtype={output.dtype}, device={output.device}, "
f"storage_size={output.untyped_storage().size()}, "
f"data_ptr=0x{output.data_ptr():x}")
sys.stdout.flush()
# Verify: each chunk should come from different ranks
try: