Files
mscclpp/examples/torch-integration/customized_comm_with_tuning.py
Binyang Li 25435acf5d Add new algos for GB200 (#747)
- Add new algos (allreduce_rsag, allreduce_rsag_pipeline and
allreduce_rsag_zero_copy) for GB200.
- Add IB stub for non-IB env
- Provide an example of algorithm tuning with different nblocks/nthreads

Perf for allreduce_rsag
```
#                                                              out-of-place                       in-place          
#       size         count      type   redop    root     time   algbw   busbw  #wrong     time   algbw   busbw  #wrong 
#        (B)    (elements)                               (us)  (GB/s)  (GB/s)             (us)  (GB/s)  (GB/s)         
     1048576        262144     float     sum      -1    25.16   41.67   62.51       0    23.73   44.18   66.27       0
     2097152        524288     float     sum      -1    26.06   80.47  120.71       0    25.31   82.86  124.29       0
     4194304       1048576     float     sum      -1    31.09  134.93  202.39       0    30.75  136.39  204.58       0
     8388608       2097152     float     sum      -1    45.52  184.29  276.43       0    45.13  185.87  278.80       0
    16777216       4194304     float     sum      -1    75.73  221.53  332.30       0    75.51  222.18  333.27       0
    33554432       8388608     float     sum      -1   137.25  244.48  366.72       0   137.22  244.54  366.81       0
    67108864      16777216     float     sum      -1   271.34  247.32  370.99       0   270.86  247.76  371.65       0
   134217728      33554432     float     sum      -1   534.25  251.22  376.84       0   534.43  251.14  376.71       0
# Out of bounds values : 0 OK
# Avg bus bandwidth    : 264.454 
#
# Collective test concluded: all_reduce_perf
```

perf for allreduce_rsag_pipeline
```
#                                                              out-of-place                       in-place          
#       size         count      type   redop    root     time   algbw   busbw  #wrong     time   algbw   busbw  #wrong 
#        (B)    (elements)                               (us)  (GB/s)  (GB/s)             (us)  (GB/s)  (GB/s)         
     1048576        262144     float     sum      -1    61.57   17.03   25.55       0    61.51   17.05   25.57       0
     2097152        524288     float     sum      -1    61.31   34.20   51.31       0    61.23   34.25   51.38       0
     4194304       1048576     float     sum      -1    61.62   68.06  102.10       0    61.84   67.83  101.74       0
     8388608       2097152     float     sum      -1    61.97  135.37  203.06       0    61.89  135.53  203.30       0
    16777216       4194304     float     sum      -1    63.15  265.65  398.48       0    62.89  266.76  400.15       0
    33554432       8388608     float     sum      -1   100.63  333.46  500.19       0    99.76  336.34  504.51       0
    67108864      16777216     float     sum      -1   180.04  372.75  559.13       0   179.75  373.34  560.01       0
   134217728      33554432     float     sum      -1   339.60  395.23  592.84       0   338.16  396.91  595.36       0
# Out of bounds values : 0 OK
# Avg bus bandwidth    : 304.665 
#
# Collective test concluded: all_reduce_perf
```

perf for allreduce_rsag_zero_copy
```
#                                                              out-of-place                       in-place          
#       size         count      type   redop    root     time   algbw   busbw  #wrong     time   algbw   busbw  #wrong 
#        (B)    (elements)                               (us)  (GB/s)  (GB/s)             (us)  (GB/s)  (GB/s)         
     1048576        262144     float     sum      -1    14.99   69.93  104.90       0    14.44   72.61  108.92       0
     2097152        524288     float     sum      -1    16.19  129.56  194.33       0    15.85  132.32  198.48       0
     4194304       1048576     float     sum      -1    21.19  197.98  296.97       0    20.64  203.20  304.81       0
     8388608       2097152     float     sum      -1    31.04  270.27  405.41       0    30.68  273.44  410.16       0
    16777216       4194304     float     sum      -1    50.34  333.26  499.89       0    50.15  334.51  501.77       0
    33554432       8388608     float     sum      -1    89.58  374.56  561.84       0    88.65  378.48  567.73       0
    67108864      16777216     float     sum      -1   165.69  405.03  607.54       0   163.64  410.10  615.16       0
   134217728      33554432     float     sum      -1   323.19  415.28  622.93       0   318.01  422.05  633.07       0
# Out of bounds values : 0 OK
# Avg bus bandwidth    : 414.619 
#
# Collective test concluded: all_reduce_perf
```

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
Co-authored-by: chhwang <8018170+chhwang@users.noreply.github.com>
Co-authored-by: Qinghua Zhou <qinghuazhou@microsoft.com>
Co-authored-by: Caio Rocha <caiorocha@microsoft.com>
2026-02-24 16:43:23 -08:00

283 lines
12 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# MSCCLPP_MASTER_ADDR=<master_ip> MSCCLPP_MASTER_PORT=<port> torchrun --nnodes=1 --nproc_per_node=8 customized_comm_with_tuning.py
import os
import torch
import mscclpp.utils as mscclpp_utils
import mscclpp
import mscclpp.ext
import netifaces as ni
import ipaddress
def load_algorithms(scratch_buffer: torch.Tensor, rank: int) -> mscclpp.AlgorithmCollection:
    """Build the default MSCCL++ algorithm collection on top of ``scratch_buffer``.

    Args:
        scratch_buffer: Tensor whose device pointer (``data_ptr()``) and byte
            size (``nbytes``) are handed to the collection builder.
        rank: Rank of the calling process, forwarded to the builder.

    Returns:
        Whatever ``AlgorithmCollectionBuilder.build_default_algorithms`` returns.
    """
    collection_builder = mscclpp.ext.AlgorithmCollectionBuilder()
    return collection_builder.build_default_algorithms(
        scratch_buffer=scratch_buffer.data_ptr(),
        scratch_buffer_size=scratch_buffer.nbytes,
        rank=rank,
    )
def interfaces_for_ip_netifaces(ip: str):
    """Return the name of the local network interface that owns IPv4 address ``ip``.

    Every interface reported by ``netifaces`` is scanned; the first one whose
    ``AF_INET`` address list contains an entry equal to ``ip`` is returned.

    Args:
        ip: Address to look for, parsed with ``ipaddress.ip_address``.

    Returns:
        The matching interface name, or ``None`` when no interface matches.
    """
    wanted = ipaddress.ip_address(ip)
    for iface_name in ni.interfaces():
        families = ni.ifaddresses(iface_name)
        if ni.AF_INET not in families:
            # Interface has no IPv4 configuration at all.
            continue
        for entry in families[ni.AF_INET]:
            if "addr" not in entry:
                continue
            if ipaddress.ip_address(entry["addr"]) == wanted:
                return iface_name
    return None
def to_mscclpp_reduce_op(op: torch.distributed.ReduceOp) -> mscclpp.ReduceOp:
    """Translate a ``torch.distributed`` reduce op into its MSCCL++ counterpart.

    Only ``SUM`` and ``MIN`` are mapped.

    Raises:
        ValueError: If ``op`` is any other reduction.
    """
    torch_ops = torch.distributed.ReduceOp
    if op == torch_ops.SUM:
        return mscclpp.ReduceOp.SUM
    if op == torch_ops.MIN:
        return mscclpp.ReduceOp.MIN
    raise ValueError(f"unsupported op: {op}")
class CustomizedComm:
    """All-reduce communicator that picks a tuned MSCCL++ algorithm per message size.

    On construction the default allreduce algorithms are loaded and ``_tune``
    times every candidate (algorithm, nblocks, nthreads) combination for
    power-of-two message sizes. ``all_reduce`` then looks up the winning
    configuration for the size of the tensor it is given.
    """

    # Smallest and largest message size (in bytes) covered by the tuning sweep.
    # ``get_tuned_config`` clamps to this same range so that lookups for very
    # small or very large messages still land on a tuned entry.
    _MIN_TUNED_SIZE = 1 << 10
    _MAX_TUNED_SIZE = 1 << 27

    def __init__(self, comm: mscclpp.CommGroup):
        """Load the allreduce algorithms and run the tuning sweep.

        Args:
            comm: Communication group; its ``my_rank``, ``nranks``,
                ``nranks_per_node`` and ``communicator`` attributes are used.

        Raises:
            IndexError: If one of the required default algorithms is missing
                from the loaded collection.
        """
        self.comm = comm
        self.rank = comm.my_rank
        self.world_size = comm.nranks
        self.local_rank = comm.my_rank % comm.nranks_per_node
        self.n_ranks_per_node = comm.nranks_per_node
        # Scratch space handed to the algorithm builder, exposed to torch via DLPack.
        dlpack = mscclpp.RawGpuBuffer(1 << 27).to_dlpack(data_type=str(torch.float16))
        self.scratch_buffer = torch.utils.dlpack.from_dlpack(dlpack)
        algorithms = load_algorithms(scratch_buffer=self.scratch_buffer, rank=self.rank)
        self._algorithm_nvls_packet = self._find_allreduce(algorithms, "default_allreduce_nvls_packet")
        self._algorithm_rsag_zero_copy = self._find_allreduce(algorithms, "default_allreduce_rsag_zero_copy")
        self._algorithm_packet = self._find_allreduce(algorithms, "default_allreduce_packet")
        self._tune(n_warmup=5, n_graph_launches=10, n_ops_per_graph=100)

    @staticmethod
    def _find_allreduce(algorithms, name):
        """Return the first allreduce algorithm called ``name``; raise IndexError if absent."""
        for algo in algorithms:
            if algo.collective == "allreduce" and algo.name == name:
                return algo
        # IndexError matches what the previous ``[...][0]`` lookup raised, but
        # now says which algorithm is missing.
        raise IndexError(f"allreduce algorithm {name!r} not found in the loaded algorithm collection")

    def _tune(self, n_warmup, n_graph_launches, n_ops_per_graph):
        """Time every candidate configuration and record the fastest per size.

        Fills ``self.best_configs`` with ``size -> (algo, nblocks, nthreads)``
        for each power-of-two size in the tuned range for which at least one
        candidate ran successfully.

        Args:
            n_warmup: Eager warm-up executions per candidate.
            n_graph_launches: Timed replays of the captured CUDA graph.
            n_ops_per_graph: All-reduce operations captured into each graph.
        """
        min_exp = self._MIN_TUNED_SIZE.bit_length() - 1
        max_exp = self._MAX_TUNED_SIZE.bit_length() - 1
        sizes = [1 << i for i in range(min_exp, max_exp + 1)]
        # Pre-fill with defaults for barrier
        self.best_configs = {self._MIN_TUNED_SIZE: (self._algorithm_nvls_packet, 0, 0)}
        tune_tensor = torch.rand(1 << 27, dtype=torch.float16, device="cuda")
        candidates_nblocks = [4, 8, 16, 24, 32, 48, 64, 128]
        candidates_nthreads = [512, 768, 1024]
        algos = []
        for size in sizes:
            # Packet-based algorithms are only tried up to 4 MiB; the zero-copy
            # algorithm only from 512 KiB upwards.
            algos = []
            if size <= 4 * 1024 * 1024:
                algos.append(self._algorithm_nvls_packet)
                algos.append(self._algorithm_packet)
            if size >= 512 * 1024:
                algos.append(self._algorithm_rsag_zero_copy)
            best_time = float("inf")
            best_config = None
            for algo in algos:
                for nb in candidates_nblocks:
                    # Per-algorithm upper bounds on the block count.
                    if algo.name == "default_allreduce_nvls_packet" and nb > 16:
                        continue
                    if algo.name == "default_allreduce_packet" and nb > 56:
                        continue
                    for nt in candidates_nthreads:
                        avg_time = self._measure_config(
                            algo, tune_tensor, size, nb, nt, n_warmup, n_graph_launches, n_ops_per_graph
                        )
                        if avg_time is None:
                            # The algorithm rejected this launch configuration.
                            continue
                        if avg_time < best_time:
                            best_time = avg_time
                            best_config = (algo, nb, nt)
            if best_config:
                self.best_configs[size] = best_config
                if self.rank == 0:
                    print(
                        f"Size {size}: Best Algo {best_config[0].name} nblocks {best_config[1]} nthreads {best_config[2]} Time {(best_time/(n_graph_launches * n_ops_per_graph))*1000:.2f} us"
                    )
        # reset the algorithms after tuning
        # NOTE(review): ``algos`` holds only the candidates of the last tuned
        # size, so only those are reset here. Confirm whether every algorithm
        # used during tuning should be reset instead.
        torch.cuda.synchronize()
        for algo in algos:
            algo.reset()

    def _measure_config(self, algo, tensor, size, nblocks, nthreads, n_warmup, n_graph_launches, n_ops_per_graph):
        """Time one (algo, nblocks, nthreads) candidate with a captured CUDA graph.

        Returns:
            The elapsed time in milliseconds of ``n_graph_launches`` graph
            replays, averaged over all ranks, or ``None`` when the first trial
            execution returns a non-zero status.
        """
        if self._run_algo(algo, tensor, size, nblocks, nthreads) != 0:
            return None
        for _ in range(n_warmup):
            self._run_algo(algo, tensor, size, nblocks, nthreads)
        self.barrier()
        capture_stream = torch.cuda.Stream()
        capture_stream.wait_stream(torch.cuda.current_stream())
        g = torch.cuda.CUDAGraph()
        # Warmup on capture stream
        with torch.cuda.stream(capture_stream):
            self._run_algo(algo, tensor, size, nblocks, nthreads)
        capture_stream.synchronize()
        with torch.cuda.graph(g, stream=capture_stream):
            for _ in range(n_ops_per_graph):
                self._run_algo(algo, tensor, size, nblocks, nthreads)
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record(capture_stream)
        with torch.cuda.stream(capture_stream):
            for _ in range(n_graph_launches):
                g.replay()
        end_event.record(capture_stream)
        end_event.synchronize()
        elapsed = start_event.elapsed_time(end_event)
        # Share the timing across ranks so that every rank selects the same
        # configuration. The value is replicated world_size times; presumably
        # the allreduce algorithm needs at least one element per rank
        # (TODO confirm).
        time_tensor = torch.full((self.world_size,), elapsed, dtype=torch.float64, device="cuda").to(
            dtype=torch.float32
        )
        torch.cuda.current_stream().wait_stream(capture_stream)
        # TODO: use all_reduce may cause problem if the time elapsed between different algos are too close.
        # May change to broadcast in the future if that becomes an issue.
        self.all_reduce(time_tensor, op=torch.distributed.ReduceOp.SUM)
        return time_tensor[self.rank].item() / self.world_size

    def _run_algo(self, algo, tensor, size, nblocks, nthreads):
        """Run ``algo`` in place on the first ``size`` bytes of ``tensor`` (SUM).

        The operation is enqueued on the current CUDA stream.

        Returns:
            The status returned by ``algo.execute``; callers treat 0 as success.
        """
        return algo.execute(
            comm=self.comm.communicator,
            input_buffer=tensor.data_ptr(),
            output_buffer=tensor.data_ptr(),
            input_size=size,
            output_size=size,
            dtype=mscclpp_utils.torch_dtype_to_mscclpp_dtype(tensor.dtype),
            op=mscclpp.ReduceOp.SUM,
            stream=torch.cuda.current_stream().cuda_stream,
            nblocks=nblocks,
            nthreads_per_block=nthreads,
        )

    def get_tuned_config(self, size):
        """Return the tuned ``(algo, nblocks, nthreads)`` for a message of ``size`` bytes.

        ``size`` is rounded up to the next power of two and clamped to the
        tuned range, so messages larger than the biggest tuned size reuse that
        size's configuration instead of missing the table.

        Returns:
            The stored configuration, or ``None`` if that size has no entry.
        """
        if size < self._MIN_TUNED_SIZE:
            target_size = self._MIN_TUNED_SIZE
        elif size > self._MAX_TUNED_SIZE:
            target_size = self._MAX_TUNED_SIZE
        else:
            target_size = 1 << (size - 1).bit_length()
        return self.best_configs.get(target_size)

    def all_reduce(self, tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM, stream: torch.cuda.Stream = None):
        """All-reduce ``tensor`` in place using the tuned configuration for its size.

        Args:
            tensor: Tensor reduced in place (same buffer is input and output).
            op: Reduction; only ``ReduceOp.SUM`` is accepted.
            stream: CUDA stream to enqueue on; defaults to the current stream.

        A non-zero status from the algorithm is reported with ``print`` and
        does not raise.
        """
        assert op == torch.distributed.ReduceOp.SUM
        config = self.get_tuned_config(tensor.nbytes)
        # Fall back to the nvls_packet algorithm with its own launch defaults
        # when this size has no tuned entry.
        algo, nblocks, nthreads = config if config else (self._algorithm_nvls_packet, 0, 0)
        ret = algo.execute(
            comm=self.comm.communicator,
            input_buffer=tensor.data_ptr(),
            output_buffer=tensor.data_ptr(),
            input_size=tensor.nbytes,
            output_size=tensor.nbytes,
            dtype=mscclpp_utils.torch_dtype_to_mscclpp_dtype(tensor.dtype),
            op=to_mscclpp_reduce_op(op),
            stream=stream.cuda_stream if stream is not None else torch.cuda.current_stream().cuda_stream,
            nblocks=nblocks,
            nthreads_per_block=nthreads,
        )
        if ret != 0:
            print(f"Rank {self.rank}: Algo {algo.name} failed with error {ret}")

    def barrier(self):
        """Synchronize ranks by all-reducing a small throwaway tensor on the current stream."""
        tensor = torch.empty(self.world_size, dtype=torch.float, device=torch.device("cuda"))
        self.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM, stream=torch.cuda.current_stream())

    def benchmark(self, n_warmup=10, n_graph_launches=10, n_iter_per_graph=100):
        """Benchmark ``all_reduce`` for sizes from 5 KiB, doubling while <= 80 MiB.

        For each size a CUDA graph containing ``n_iter_per_graph`` all-reduces
        is captured, replayed ``n_warmup`` times, then timed over
        ``n_graph_launches`` replays. Rank 0 prints the per-operation time and
        algorithm bandwidth.
        """
        low = 5 * 1024
        high = 80 * 1024 * 1024
        sizes = []
        curr = low
        while curr <= high:
            sizes.append(curr)
            curr *= 2
        if self.rank == 0:
            print(f"{'Size (Bytes)':<20} {'Time (us)':<20} {'AlgoBW (GB/s)':<20}")
        dtype = torch.float16
        capture_stream = torch.cuda.Stream()
        for size in sizes:
            # float16 is two bytes per element, so size // 2 elements is ``size`` bytes.
            tensor = torch.rand(size // 2, dtype=dtype, device="cuda")
            capture_stream.wait_stream(torch.cuda.current_stream())
            # Capture Graph
            g = torch.cuda.CUDAGraph()
            with torch.cuda.graph(g, stream=capture_stream):
                for _ in range(n_iter_per_graph):
                    self.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM)
            # warmup: Execute the graph once to prime the driver
            # NOTE(review): the barrier is assumed to run inside the
            # capture-stream context; the original indentation was lost.
            with torch.cuda.stream(capture_stream):
                for _ in range(n_warmup):
                    g.replay()
                self.barrier()
            capture_stream.synchronize()
            # Benchmark
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            start_event.record(capture_stream)
            with torch.cuda.stream(capture_stream):
                for _ in range(n_graph_launches):
                    g.replay()
            end_event.record(capture_stream)
            end_event.synchronize()
            # Get elapsed time in milliseconds
            elapsed_ms = start_event.elapsed_time(end_event)
            avg_time_ms = elapsed_ms / (n_graph_launches * n_iter_per_graph)
            time_us = avg_time_ms * 1000
            # Bytes per second; divided by 1e9 below to print GB/s.
            alg_bw = size / (avg_time_ms * 1e-3) if avg_time_ms > 0 else 0
            if self.rank == 0:
                print(f"{size:<20} {time_us:<20.2f} {alg_bw / 1e9:<20.2f}")

    def destroy(self):
        """Drop every reference this object holds to algorithms, buffers and the comm group."""
        # The tuning table also references the algorithm objects, so clear it too.
        self.best_configs = {}
        self._algorithm_nvls_packet = None
        self._algorithm_rsag_zero_copy = None
        self._algorithm_packet = None
        self.scratch_buffer = None
        self.comm = None
def init_dist() -> CustomizedComm:
    """Create a ``CustomizedComm`` from the launcher's environment variables.

    Reads ``RANK``, ``WORLD_SIZE``, ``MSCCLPP_MASTER_ADDR`` and
    ``MSCCLPP_MASTER_PORT``, finds the network interface that owns the master
    address, and bootstraps an MSCCL++ ``CommGroup`` with it.

    Raises:
        KeyError: If any of the environment variables is unset.
        ValueError: If no local interface carries the master address.
    """
    env = os.environ
    rank = int(env["RANK"])
    world = int(env["WORLD_SIZE"])
    master_addr = env["MSCCLPP_MASTER_ADDR"]
    master_port = env["MSCCLPP_MASTER_PORT"]

    interface = interfaces_for_ip_netifaces(master_addr)
    if interface is None:
        raise ValueError(f"Cannot find network interface for IP address {master_addr}")

    # The bootstrap string has the form "<interface>:<ip>:<port>".
    group = mscclpp.CommGroup(
        interfaceIpPortTrio=f"{interface}:{master_addr}:{master_port}",
        rank=rank,
        size=world,
    )
    return CustomizedComm(group)
def main():
    """Bind this process to its GPU, run the all-reduce benchmark, then tear down."""
    local = int(os.environ["LOCAL_RANK"])
    # Select the device before any CUDA allocation made by the communicator.
    torch.cuda.set_device(local)

    comm = init_dist()
    bench_settings = {"n_warmup": 5, "n_graph_launches": 10, "n_iter_per_graph": 100}
    comm.benchmark(**bench_settings)

    # Let all ranks finish and drain outstanding GPU work before releasing resources.
    comm.barrier()
    torch.cuda.synchronize()
    comm.destroy()
    print(f"rank {local} All-reduce operation completed successfully.")


if __name__ == "__main__":
    main()