Add new algos for GB200 (#747)

- Add new algos (allreduce_rsag, allreduce_rsag_pipeline and
allreduce_rsag_zero_copy) for GB200.
- Add IB stub for non-IB env
- Provides an example of algorithm tuning with different nblocks/nthreads

Perf for allreduce_rsag
```
#                                                              out-of-place                       in-place          
#       size         count      type   redop    root     time   algbw   busbw  #wrong     time   algbw   busbw  #wrong 
#        (B)    (elements)                               (us)  (GB/s)  (GB/s)             (us)  (GB/s)  (GB/s)         
     1048576        262144     float     sum      -1    25.16   41.67   62.51       0    23.73   44.18   66.27       0
     2097152        524288     float     sum      -1    26.06   80.47  120.71       0    25.31   82.86  124.29       0
     4194304       1048576     float     sum      -1    31.09  134.93  202.39       0    30.75  136.39  204.58       0
     8388608       2097152     float     sum      -1    45.52  184.29  276.43       0    45.13  185.87  278.80       0
    16777216       4194304     float     sum      -1    75.73  221.53  332.30       0    75.51  222.18  333.27       0
    33554432       8388608     float     sum      -1   137.25  244.48  366.72       0   137.22  244.54  366.81       0
    67108864      16777216     float     sum      -1   271.34  247.32  370.99       0   270.86  247.76  371.65       0
   134217728      33554432     float     sum      -1   534.25  251.22  376.84       0   534.43  251.14  376.71       0
# Out of bounds values : 0 OK
# Avg bus bandwidth    : 264.454 
#
# Collective test concluded: all_reduce_perf
```

Perf for allreduce_rsag_pipeline
```
#                                                              out-of-place                       in-place          
#       size         count      type   redop    root     time   algbw   busbw  #wrong     time   algbw   busbw  #wrong 
#        (B)    (elements)                               (us)  (GB/s)  (GB/s)             (us)  (GB/s)  (GB/s)         
     1048576        262144     float     sum      -1    61.57   17.03   25.55       0    61.51   17.05   25.57       0
     2097152        524288     float     sum      -1    61.31   34.20   51.31       0    61.23   34.25   51.38       0
     4194304       1048576     float     sum      -1    61.62   68.06  102.10       0    61.84   67.83  101.74       0
     8388608       2097152     float     sum      -1    61.97  135.37  203.06       0    61.89  135.53  203.30       0
    16777216       4194304     float     sum      -1    63.15  265.65  398.48       0    62.89  266.76  400.15       0
    33554432       8388608     float     sum      -1   100.63  333.46  500.19       0    99.76  336.34  504.51       0
    67108864      16777216     float     sum      -1   180.04  372.75  559.13       0   179.75  373.34  560.01       0
   134217728      33554432     float     sum      -1   339.60  395.23  592.84       0   338.16  396.91  595.36       0
# Out of bounds values : 0 OK
# Avg bus bandwidth    : 304.665 
#
# Collective test concluded: all_reduce_perf
```

Perf for allreduce_rsag_zero_copy
```
#                                                              out-of-place                       in-place          
#       size         count      type   redop    root     time   algbw   busbw  #wrong     time   algbw   busbw  #wrong 
#        (B)    (elements)                               (us)  (GB/s)  (GB/s)             (us)  (GB/s)  (GB/s)         
     1048576        262144     float     sum      -1    14.99   69.93  104.90       0    14.44   72.61  108.92       0
     2097152        524288     float     sum      -1    16.19  129.56  194.33       0    15.85  132.32  198.48       0
     4194304       1048576     float     sum      -1    21.19  197.98  296.97       0    20.64  203.20  304.81       0
     8388608       2097152     float     sum      -1    31.04  270.27  405.41       0    30.68  273.44  410.16       0
    16777216       4194304     float     sum      -1    50.34  333.26  499.89       0    50.15  334.51  501.77       0
    33554432       8388608     float     sum      -1    89.58  374.56  561.84       0    88.65  378.48  567.73       0
    67108864      16777216     float     sum      -1   165.69  405.03  607.54       0   163.64  410.10  615.16       0
   134217728      33554432     float     sum      -1   323.19  415.28  622.93       0   318.01  422.05  633.07       0
# Out of bounds values : 0 OK
# Avg bus bandwidth    : 414.619 
#
# Collective test concluded: all_reduce_perf
```

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
Co-authored-by: chhwang <8018170+chhwang@users.noreply.github.com>
Co-authored-by: Qinghua Zhou <qinghuazhou@microsoft.com>
Co-authored-by: Caio Rocha <caiorocha@microsoft.com>
This commit is contained in:
Binyang Li
2026-02-24 16:43:23 -08:00
committed by GitHub
parent 184dcbf9d7
commit 25435acf5d
11 changed files with 1236 additions and 9 deletions

View File

@@ -51,7 +51,7 @@ jobs:
df -h
- name: Initialize CodeQL
uses: github/codeql-action/init@v3
uses: github/codeql-action/init@v4
with:
languages: ${{ matrix.language }}
@@ -63,10 +63,10 @@ jobs:
run: |
rm -rf build && mkdir build && cd build
cmake -DMSCCLPP_BYPASS_GPU_CHECK=ON -DMSCCLPP_USE_CUDA=ON ..
make -j
make -j4
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v3
uses: github/codeql-action/analyze@v4
with:
category: "/language:${{matrix.language}}/version:${{matrix.version}}"
@@ -96,7 +96,7 @@ jobs:
df -h
- name: Initialize CodeQL
uses: github/codeql-action/init@v3
uses: github/codeql-action/init@v4
with:
languages: ${{ matrix.language }}
@@ -108,9 +108,9 @@ jobs:
run: |
rm -rf build && mkdir build && cd build
CXX=/opt/rocm/bin/hipcc cmake -DMSCCLPP_BYPASS_GPU_CHECK=ON -DMSCCLPP_USE_ROCM=ON ..
make -j
make -j4
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v3
uses: github/codeql-action/analyze@v4
with:
category: "/language:${{matrix.language}}/version:${{matrix.version}}"

View File

@@ -0,0 +1,282 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# MSCCLPP_MASTER_ADDR=<master_ip> MSCCLPP_MASTER_PORT=<port> torchrun --nnodes=1 --nproc_per_node=8 customized_comm_with_tuning.py
import os
import torch
import mscclpp.utils as mscclpp_utils
import mscclpp
import mscclpp.ext
import netifaces as ni
import ipaddress
def load_algorithms(scratch_buffer: torch.tensor, rank: int) -> mscclpp.AlgorithmCollection:
    """Build the default MSCCL++ algorithm collection backed by `scratch_buffer`."""
    builder = mscclpp.ext.AlgorithmCollectionBuilder()
    return builder.build_default_algorithms(
        scratch_buffer=scratch_buffer.data_ptr(),
        scratch_buffer_size=scratch_buffer.nbytes,
        rank=rank,
    )
def interfaces_for_ip_netifaces(ip: str):
    """Return the name of the local network interface bound to `ip`, or None if no match."""
    wanted = ipaddress.ip_address(ip)
    for iface in ni.interfaces():
        # Only IPv4 addresses are considered; interfaces without AF_INET are skipped.
        for link in ni.ifaddresses(iface).get(ni.AF_INET, []):
            addr = link.get("addr")
            if addr is not None and ipaddress.ip_address(addr) == wanted:
                return iface
    return None
def to_mscclpp_reduce_op(op: torch.distributed.ReduceOp) -> mscclpp.ReduceOp:
    """Translate a torch.distributed reduce op into its MSCCL++ equivalent.

    Only SUM and MIN are supported; anything else raises ValueError.
    """
    if op == torch.distributed.ReduceOp.SUM:
        return mscclpp.ReduceOp.SUM
    if op == torch.distributed.ReduceOp.MIN:
        return mscclpp.ReduceOp.MIN
    raise ValueError(f"unsupported op: {op}")
class CustomizedComm:
    """Collective-communication wrapper around an mscclpp.CommGroup.

    On construction it allocates a 128 MiB GPU scratch buffer, loads the
    default MSCCL++ allreduce algorithms, and auto-tunes which algorithm and
    launch geometry (nblocks, nthreads) to use per power-of-two message size.
    """

    def __init__(self, comm: mscclpp.CommGroup):
        self.comm = comm
        self.rank = comm.my_rank
        self.world_size = comm.nranks
        self.local_rank = comm.my_rank % comm.nranks_per_node
        self.n_ranks_per_node = comm.nranks_per_node
        # 128 MiB scratch buffer shared by the MSCCL++ algorithms, exposed to
        # torch through DLPack.
        dlpack = mscclpp.RawGpuBuffer(1 << 27).to_dlpack(data_type=str(torch.float16))
        self.scratch_buffer = torch.utils.dlpack.from_dlpack(dlpack)
        algorithms = load_algorithms(scratch_buffer=self.scratch_buffer, rank=self.rank)
        self._algorithm_nvls_packet = [
            algo
            for algo in algorithms
            if algo.collective == "allreduce" and algo.name == "default_allreduce_nvls_packet"
        ][0]
        self._algorithm_rsag_zero_copy = [
            algo
            for algo in algorithms
            if algo.collective == "allreduce" and algo.name == "default_allreduce_rsag_zero_copy"
        ][0]
        self._algorithm_packet = [
            algo for algo in algorithms if algo.collective == "allreduce" and algo.name == "default_allreduce_packet"
        ][0]
        self._tune(n_warmup=5, n_graph_launches=10, n_ops_per_graph=100)

    def _tune(self, n_warmup, n_graph_launches, n_ops_per_graph):
        """Sweep (algorithm, nblocks, nthreads) per message size and keep the fastest.

        Timings use CUDA graphs to amortize launch overhead, and are averaged
        across ranks so every rank selects the same configuration.
        """
        sizes = [1 << i for i in range(10, 28)]
        # Pre-fill with defaults for barrier
        self.best_configs = {1024: (self._algorithm_nvls_packet, 0, 0)}
        tune_tensor = torch.rand(1 << 27, dtype=torch.float16, device="cuda")
        candidates_nblocks = [4, 8, 16, 24, 32, 48, 64, 128]
        candidates_nthreads = [512, 768, 1024]
        for size in sizes:
            # Candidate algorithms depend on message size: packet-based algos
            # for small messages, zero-copy RSAG for large ones.
            algos = []
            if size <= 4 * 1024 * 1024:
                algos.append(self._algorithm_nvls_packet)
                algos.append(self._algorithm_packet)
            if size >= 512 * 1024:
                algos.append(self._algorithm_rsag_zero_copy)
            best_time = float("inf")
            best_config = None
            for algo in algos:
                for nb in candidates_nblocks:
                    # Per-algorithm upper bounds on the block count.
                    if algo.name == "default_allreduce_nvls_packet" and nb > 16:
                        continue
                    if algo.name == "default_allreduce_packet" and nb > 56:
                        continue
                    for nt in candidates_nthreads:
                        # Probe once; a non-zero status means this config is rejected.
                        if self._run_algo(algo, tune_tensor, size, nb, nt) != 0:
                            continue
                        for _ in range(n_warmup):
                            self._run_algo(algo, tune_tensor, size, nb, nt)
                        self.barrier()
                        capture_stream = torch.cuda.Stream()
                        capture_stream.wait_stream(torch.cuda.current_stream())
                        g = torch.cuda.CUDAGraph()
                        # Warmup on capture stream
                        with torch.cuda.stream(capture_stream):
                            self._run_algo(algo, tune_tensor, size, nb, nt)
                        capture_stream.synchronize()
                        with torch.cuda.graph(g, stream=capture_stream):
                            for _ in range(n_ops_per_graph):
                                self._run_algo(algo, tune_tensor, size, nb, nt)
                        start_event = torch.cuda.Event(enable_timing=True)
                        end_event = torch.cuda.Event(enable_timing=True)
                        start_event.record(capture_stream)
                        with torch.cuda.stream(capture_stream):
                            for _ in range(n_graph_launches):
                                g.replay()
                        end_event.record(capture_stream)
                        end_event.synchronize()
                        elapsed = start_event.elapsed_time(end_event)
                        # Synchronize timing results across all ranks to ensure
                        # consistent algorithm selection. The local time is
                        # replicated world_size times so the allreduce has one
                        # element per rank.
                        time_tensor = torch.full((self.world_size,), elapsed, dtype=torch.float64, device="cuda").to(
                            dtype=torch.float32
                        )
                        torch.cuda.current_stream().wait_stream(capture_stream)
                        # TODO: use all_reduce may cause problem if the time elapsed between different algos are too close.
                        # May change to broadcast in the future if that becomes an issue.
                        self.all_reduce(time_tensor, op=torch.distributed.ReduceOp.SUM)
                        avg_time = time_tensor[self.rank].item() / self.world_size
                        if avg_time < best_time:
                            best_time = avg_time
                            best_config = (algo, nb, nt)
            if best_config:
                self.best_configs[size] = best_config
                if self.rank == 0:
                    print(
                        f"Size {size}: Best Algo {best_config[0].name} nblocks {best_config[1]} nthreads {best_config[2]} Time {(best_time/(n_graph_launches * n_ops_per_graph))*1000:.2f} us"
                    )
        # Reset the algorithms after tuning. Fix: reset every algorithm that was
        # exercised, not just the candidate list of the last size iteration.
        torch.cuda.synchronize()
        for algo in (self._algorithm_nvls_packet, self._algorithm_packet, self._algorithm_rsag_zero_copy):
            algo.reset()

    def _run_algo(self, algo, tensor, size, nblocks, nthreads):
        # Launch `algo` in-place on the first `size` bytes of `tensor`.
        # Returns the algorithm's status code (non-zero = config rejected).
        return algo.execute(
            comm=self.comm.communicator,
            input_buffer=tensor.data_ptr(),
            output_buffer=tensor.data_ptr(),
            input_size=size,
            output_size=size,
            dtype=mscclpp_utils.torch_dtype_to_mscclpp_dtype(tensor.dtype),
            op=mscclpp.ReduceOp.SUM,
            stream=torch.cuda.current_stream().cuda_stream,
            nblocks=nblocks,
            nthreads_per_block=nthreads,
        )

    def get_tuned_config(self, size):
        """Return the (algo, nblocks, nthreads) tuned for the next power of two >= size.

        Sizes are clamped to the tuned range [1 KiB, 256 MiB]; returns None if
        no config was recorded for the bucket.
        """
        if size < 1024:
            target_size = 1024
        elif size > 256 * 1024 * 1024:
            target_size = 256 * 1024 * 1024
        else:
            target_size = 1 << (size - 1).bit_length()
        return self.best_configs.get(target_size)

    def all_reduce(self, tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM, stream: torch.cuda.Stream = None):
        """In-place allreduce of `tensor` using the tuned algorithm for its size."""
        assert op == torch.distributed.ReduceOp.SUM
        config = self.get_tuned_config(tensor.nbytes)
        # Fall back to the nvls-packet algorithm with default launch geometry.
        algo, nblocks, nthreads = config if config else (self._algorithm_nvls_packet, 0, 0)
        ret = algo.execute(
            comm=self.comm.communicator,
            input_buffer=tensor.data_ptr(),
            output_buffer=tensor.data_ptr(),
            input_size=tensor.nbytes,
            output_size=tensor.nbytes,
            dtype=mscclpp_utils.torch_dtype_to_mscclpp_dtype(tensor.dtype),
            op=to_mscclpp_reduce_op(op),
            stream=stream.cuda_stream if stream is not None else torch.cuda.current_stream().cuda_stream,
            nblocks=nblocks,
            nthreads_per_block=nthreads,
        )
        if ret != 0:
            print(f"Rank {self.rank}: Algo {algo.name} failed with error {ret}")

    def barrier(self):
        """Barrier implemented as a small allreduce that every rank must join."""
        tensor = torch.empty(self.world_size, dtype=torch.float, device=torch.device("cuda"))
        self.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM, stream=torch.cuda.current_stream())

    def benchmark(self, n_warmup=10, n_graph_launches=10, n_iter_per_graph=100):
        """Time tuned allreduce across doubling sizes from 5 KiB to 80 MiB and print a table."""
        low = 5 * 1024
        high = 80 * 1024 * 1024
        sizes = []
        curr = low
        while curr <= high:
            sizes.append(curr)
            curr *= 2
        if self.rank == 0:
            print(f"{'Size (Bytes)':<20} {'Time (us)':<20} {'AlgoBW (GB/s)':<20}")
        dtype = torch.float16
        capture_stream = torch.cuda.Stream()
        for size in sizes:
            # fp16: two bytes per element.
            tensor = torch.rand(size // 2, dtype=dtype, device="cuda")
            capture_stream.wait_stream(torch.cuda.current_stream())
            # Capture Graph
            g = torch.cuda.CUDAGraph()
            with torch.cuda.graph(g, stream=capture_stream):
                for _ in range(n_iter_per_graph):
                    self.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM)
            # warmup: Execute the graph once to prime the driver
            with torch.cuda.stream(capture_stream):
                for _ in range(n_warmup):
                    g.replay()
            self.barrier()
            capture_stream.synchronize()
            # Benchmark
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            start_event.record(capture_stream)
            with torch.cuda.stream(capture_stream):
                for _ in range(n_graph_launches):
                    g.replay()
            end_event.record(capture_stream)
            end_event.synchronize()
            # Get elapsed time in milliseconds
            elapsed_ms = start_event.elapsed_time(end_event)
            avg_time_ms = elapsed_ms / (n_graph_launches * n_iter_per_graph)
            time_us = avg_time_ms * 1000
            alg_bw = size / (avg_time_ms * 1e-3) if avg_time_ms > 0 else 0
            if self.rank == 0:
                print(f"{size:<20} {time_us:<20.2f} {alg_bw / 1e9:<20.2f}")

    def destroy(self):
        """Drop algorithm/buffer references so the communicator can shut down."""
        # Fix: clear the algorithm handles this object actually holds (the
        # original cleared a nonexistent _algorithm_nvls_nonzero_copy).
        self._algorithm_nvls_packet = None
        self._algorithm_rsag_zero_copy = None
        self._algorithm_packet = None
        self.scratch_buffer = None
        self.comm = None
def init_dist() -> CustomizedComm:
    """Bootstrap an MSCCL++ communicator from torchrun-style environment variables.

    Requires RANK, WORLD_SIZE, MSCCLPP_MASTER_ADDR and MSCCLPP_MASTER_PORT.
    Raises ValueError if no local interface carries the master address.
    """
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    master_addr = os.environ["MSCCLPP_MASTER_ADDR"]
    master_port = os.environ["MSCCLPP_MASTER_PORT"]
    interface = interfaces_for_ip_netifaces(master_addr)
    if interface is None:
        raise ValueError(f"Cannot find network interface for IP address {master_addr}")
    trio = f"{interface}:{master_addr}:{master_port}"
    group = mscclpp.CommGroup(interfaceIpPortTrio=trio, rank=rank, size=world_size)
    return CustomizedComm(group)
def main():
    """Entry point: bind the GPU, build + tune the communicator, benchmark, tear down."""
    local = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local)
    comm = init_dist()
    comm.benchmark(n_warmup=5, n_graph_launches=10, n_iter_per_graph=100)
    comm.barrier()
    torch.cuda.synchronize()
    comm.destroy()
    print(f"rank {local} All-reduce operation completed successfully.")


if __name__ == "__main__":
    main()

View File

@@ -13,6 +13,9 @@
#include "allreduce/allreduce_nvls_with_copy.hpp"
#include "allreduce/allreduce_nvls_with_copy_2.hpp"
#include "allreduce/allreduce_packet.hpp"
#include "allreduce/allreduce_rsag.hpp"
#include "allreduce/allreduce_rsag_pipeline.hpp"
#include "allreduce/allreduce_rsag_zero_copy.hpp"
#include "logger.hpp"
namespace mscclpp {
@@ -82,6 +85,14 @@ AlgorithmCollection AlgorithmCollectionBuilder::buildDefaultNativeAlgorithms(uin
collection.registerAlgorithm(allreduceNvls->collective(), allreduceNvls->name(), allreduceNvls);
auto allreduceFullmesh = std::make_shared<AllreduceFullmesh>(scratchBuffer, scratchBufferSize)->build();
collection.registerAlgorithm(allreduceFullmesh->collective(), allreduceFullmesh->name(), allreduceFullmesh);
auto allreduceRsag = std::make_shared<AllreduceRsAg>(scratchBuffer, scratchBufferSize)->build();
collection.registerAlgorithm(allreduceRsag->collective(), allreduceRsag->name(), allreduceRsag);
auto allreduceRsagPipeline = std::make_shared<AllreduceRsAgPipeline>(scratchBuffer, scratchBufferSize)->build();
collection.registerAlgorithm(allreduceRsagPipeline->collective(), allreduceRsagPipeline->name(),
allreduceRsagPipeline);
auto allreduceRsagZeroCopy = std::make_shared<AllreduceRsAgZeroCopy>()->build();
collection.registerAlgorithm(allreduceRsagZeroCopy->collective(), allreduceRsagZeroCopy->name(),
allreduceRsagZeroCopy);
auto allgatherFullmesh = std::make_shared<AllgatherFullmesh>(scratchBuffer, scratchBufferSize)->build();
collection.registerAlgorithm(allgatherFullmesh->collective(), allgatherFullmesh->name(), allgatherFullmesh);

View File

@@ -0,0 +1,229 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#include "allreduce/allreduce_rsag.hpp"
#include "allreduce/common.hpp"
#include "collective_utils.hpp"
#include "logger.hpp"
namespace mscclpp {
namespace collective {
// Allreduce using the Reduce-Scatter + All-Gather (RSAG) pattern.
//
// This algorithm performs allreduce in three phases over intra-node peers
// connected via CudaIpc memory channels:
//
// 1. Scatter: Each rank copies its input data into a scratch buffer, then
// signals peers and waits for all peers to do the same.
//
// 2. Reduce-Scatter: Each rank reduces its assigned chunk by reading the
// corresponding chunks from all peers' scratch buffers (via remote memory
// handles) and applying the reduction op. The reduced result is written
// back to both the local result buffer and peers' scratch buffers.
//
// 3. All-Gather: After a second signal/wait barrier, each rank copies the
// reduced chunks produced by other ranks from the scratch buffer into its
// result buffer, completing the allreduce.
//
// Data is processed in int4-sized (16-byte) units for coalesced memory access,
// with special handling for any remainder elements at the tail.
template <ReduceOp OpType, typename T>
__global__ void __launch_bounds__(1024, 1)
allreduceRsAg(T* buff, T* scratch, T* resultBuff, DeviceHandle<BaseMemoryChannel>* memoryChannels,
DeviceHandle<SwitchChannel>* switchChannels, void* remoteMemories, int rank, int nRanksPerNode,
int worldSize, size_t nelems) {
int blockId = blockIdx.x;
uint32_t nPeers = nRanksPerNode - 1;
// Input and output pointers must be 16-byte aligned for the int4 accesses below.
assert((uintptr_t)buff % sizeof(int4) == 0);
assert((uintptr_t)resultBuff % sizeof(int4) == 0);
constexpr uint32_t nelemsPerInt4 = sizeof(int4) / sizeof(T);
// Round the element count up so every rank owns an equal, int4-aligned chunk.
uint32_t alignedNelems = ((nelems + nRanksPerNode - 1) / nRanksPerNode + nelemsPerInt4 - 1) / nelemsPerInt4 *
nelemsPerInt4 * nRanksPerNode;
uint32_t nelemsPerRank = alignedNelems / nRanksPerNode;
uint32_t nInt4PerRank = nelemsPerRank / nelemsPerInt4;
// Index of the (possibly partial) trailing int4 and how many T elements it holds.
uint32_t lastInt4Index = nelems / nelemsPerInt4;
uint32_t remainder = nelems % nelemsPerInt4;
int4* scratch4 = reinterpret_cast<int4*>((char*)scratch);
int4* resultBuff4 = reinterpret_cast<int4*>((char*)resultBuff);
int4* buff4 = reinterpret_cast<int4*>((char*)buff);
// Each block uses its own group of per-peer channels for the signal/wait barriers.
DeviceHandle<BaseMemoryChannel>* memoryChannelsLocal = memoryChannels + blockId * nPeers;
// Split each rank-chunk evenly across blocks; the last block absorbs the remainder.
uint32_t nInt4PerBlock = nInt4PerRank / gridDim.x;
uint32_t remainderForBlock = nInt4PerRank % gridDim.x;
uint32_t offset4 = blockId * nInt4PerBlock;
if (blockId == (int)(gridDim.x - 1)) {
nInt4PerBlock += remainderForBlock;
}
if (nInt4PerBlock == 0) return;
// Phase 1 (scatter): copy this block's slice of every rank-chunk from the
// input buffer into the local scratch buffer.
uint32_t nInt4ForCopy = nInt4PerBlock * nRanksPerNode;
for (uint32_t idx = threadIdx.x; idx < nInt4ForCopy; idx += blockDim.x) {
int rankIdx = idx / nInt4PerBlock;
uint32_t offsetIdx = rankIdx * nInt4PerRank + offset4 + (idx % nInt4PerBlock);
if (offsetIdx > lastInt4Index) continue;
if (offsetIdx == lastInt4Index && remainder != 0) {
// Partial trailing vector: copy element-by-element to avoid reading past the end.
for (uint32_t i = 0; i < remainder; i++) {
((T*)&scratch4[offsetIdx])[i] = ((T*)&buff4[offsetIdx])[i];
}
continue;
}
scratch4[offsetIdx] = buff4[offsetIdx];
}
__syncthreads();
// Barrier with all peers: every rank's scratch is populated before anyone reads it.
if (threadIdx.x < nPeers) {
memoryChannelsLocal[threadIdx.x].signal();
memoryChannelsLocal[threadIdx.x].wait();
}
__syncthreads();
// Phase 2 (reduce-scatter): reduce this rank's chunk with the matching chunk
// read from every peer's scratch, then publish the reduced result both to
// peers' scratch (for the all-gather) and to the local result buffer.
for (uint32_t idx = threadIdx.x; idx < nInt4PerBlock; idx += blockDim.x) {
uint32_t offset = idx + offset4 + rank * nInt4PerRank;
if (offset > lastInt4Index) continue;
int4 tmp = scratch4[offset];
for (uint32_t i = 0; i < nPeers; i++) {
// Start the peer walk at (rank + 1) so traffic is staggered across ranks.
int rankIdx = (rank + i + 1) % nRanksPerNode;
int peerIdx = rankIdx < rank ? rankIdx : rankIdx - 1;
int4 data = mscclpp::read<int4>(((void**)remoteMemories)[peerIdx], offset);
tmp = cal_vector<T, OpType>(data, tmp);
}
for (uint32_t i = 0; i < nPeers; i++) {
int rankIdx = (rank + i + 1) % nRanksPerNode;
int peerIdx = rankIdx < rank ? rankIdx : rankIdx - 1;
mscclpp::write<int4>(((void**)remoteMemories)[peerIdx], offset, tmp);
}
if (offset == lastInt4Index && remainder != 0) {
for (uint32_t i = 0; i < remainder; i++) {
((T*)&resultBuff4[offset])[i] = ((T*)&tmp)[i];
}
continue;
}
resultBuff4[offset] = tmp;
}
__syncthreads();
// Second barrier: all peers have published their reduced chunks.
if (threadIdx.x < nPeers) {
memoryChannelsLocal[threadIdx.x].signal();
memoryChannelsLocal[threadIdx.x].wait();
}
__syncthreads();
// Phase 3 (all-gather): copy every other rank's reduced chunk from scratch
// into the local result buffer; our own chunk was written in phase 2.
for (uint32_t idx = threadIdx.x; idx < nInt4ForCopy; idx += blockDim.x) {
int rankIdx = idx / nInt4PerBlock;
if (rankIdx == rank) continue;
uint32_t offsetIdx = rankIdx * nInt4PerRank + offset4 + (idx % nInt4PerBlock);
if (offsetIdx > lastInt4Index) continue;
if (offsetIdx == lastInt4Index && remainder != 0) {
for (uint32_t i = 0; i < remainder; i++) {
((T*)&resultBuff4[offsetIdx])[i] = ((T*)&scratch4[offsetIdx])[i];
}
continue;
}
resultBuff4[offsetIdx] = scratch4[offsetIdx];
}
}
template <ReduceOp OpType, typename T>
struct AllreduceRsAgAdapter {
  // Launch shim: adapts the generic allreduce call signature to a typed
  // allreduceRsAg kernel launch. The unnamed parameters exist only to match
  // the shared function signature used by dispatch<>.
  static cudaError_t call(const void* input, void* scratch, void* output, void* memoryChannels, void* remoteMemories,
                          DeviceHandle<SwitchChannel>* switchChannel, DeviceHandle<SwitchChannel>*, size_t, size_t,
                          size_t, int rank, int nRanksPerNode, int worldSize, size_t inputSize, cudaStream_t stream,
                          void*, uint32_t, uint32_t, int nBlocks, int nThreadsPerBlock) {
    using ChannelType = DeviceHandle<BaseMemoryChannel>;
    // Fall back to the default launch geometry when the caller does not pin one.
    if (nBlocks == 0 || nThreadsPerBlock == 0) {
      nBlocks = 64;
      nThreadsPerBlock = 1024;
    }
    const size_t elementCount = inputSize / sizeof(T);
    allreduceRsAg<OpType, T><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
        (T*)input, (T*)scratch, (T*)output, (ChannelType*)memoryChannels, switchChannel, remoteMemories, rank,
        nRanksPerNode, worldSize, elementCount);
    return cudaGetLastError();
  }
};
// One-time setup: establishes intra-node connections, semaphores, and the
// exported scratch-memory handles that the kernel uses to access peers.
void AllreduceRsAg::initialize(std::shared_ptr<Communicator> comm) {
this->conns_ = setupConnections(comm);
nChannelsPerConnection_ = 64;
comm_ = comm;
// setup semaphores
this->scratchSemaphores_ = setupMemorySemaphores(comm, this->conns_, nChannelsPerConnection_);
// Register the local scratch buffer over CudaIpc and exchange it with peers.
RegisteredMemory localMemory = comm->registerMemory(scratchBuffer_, scratchBufferSize_, Transport::CudaIpc);
this->remoteScratchMemories_ = setupRemoteMemories(comm, comm->bootstrap()->getRank(), localMemory);
localScratchMemory_ = std::move(localMemory);
this->baseChannels_ = setupBaseMemoryChannels(this->conns_, this->scratchSemaphores_, nChannelsPerConnection_);
this->baseMemoryChannelHandles_ = setupBaseMemoryChannelDeviceHandles(baseChannels_);
// Gather raw pointers to each peer's scratch memory and upload the array to
// the GPU so the kernel can read/write peers directly.
std::vector<void*> remoteMemoryHandles;
for (const auto& remoteMemory : this->remoteScratchMemories_) {
remoteMemoryHandles.push_back(remoteMemory.data());
}
this->remoteMemoryHandles_ = detail::gpuCallocShared<void*>(remoteMemoryHandles.size());
// NOTE(review): the third argument is the element count, which assumes
// gpuMemcpy<T> copies counts of T (not bytes) — confirm against its signature.
gpuMemcpy(this->remoteMemoryHandles_.get(), remoteMemoryHandles.data(), remoteMemoryHandles.size(),
cudaMemcpyHostToDevice);
}
// Validates the request, resolves the typed kernel launcher for (op, dtype),
// and launches the RSAG allreduce on `stream`.
CommResult AllreduceRsAg::allreduceKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output,
                                              size_t inputSize, DataType dtype, ReduceOp op, cudaStream_t stream,
                                              int nBlocks, int nThreadsPerBlock,
                                              const std::unordered_map<std::string, uintptr_t>&) {
  AllreduceFunc launcher = dispatch<AllreduceRsAgAdapter>(op, dtype);
  if (!launcher) {
    WARN(ALGO, "Unsupported operation or data type for allreduce: op=", static_cast<int>(op),
         ", dtype=", static_cast<int>(dtype));
    return CommResult::CommInvalidArgument;
  }
  // The scatter phase stages the whole input in the scratch buffer.
  if (inputSize > this->scratchBufferSize_) {
    WARN(ALGO, "Input size ", inputSize, " exceeds scratch buffer size ", this->scratchBufferSize_);
    return CommResult::CommInvalidArgument;
  }
  auto algoCtx = std::static_pointer_cast<AlgorithmCtx>(ctx);
  const cudaError_t launchStatus =
      launcher(input, this->scratchBuffer_, output, this->baseMemoryChannelHandles_.get(),
               this->remoteMemoryHandles_.get(), nullptr, nullptr, 0, 0, 0, algoCtx->rank, algoCtx->nRanksPerNode,
               algoCtx->workSize, inputSize, stream, nullptr, 0, 0, nBlocks, nThreadsPerBlock);
  if (launchStatus != cudaSuccess) {
    WARN(ALGO, "Allreduce kernel launch failed with error: ", cudaGetErrorString(launchStatus));
    return CommResult::CommUnhandledCudaError;
  }
  return CommResult::CommSuccess;
}
// The RSAG context does not depend on buffers, size, or dtype, so every call
// maps to the same null key and therefore reuses a single shared context.
AlgorithmCtxKey AllreduceRsAg::generateAllreduceContextKey(const void*, void*, size_t, DataType, bool) {
  AlgorithmCtxKey sharedKey{nullptr, nullptr, 0, 0, 0};
  return sharedKey;
}
// Builds the per-algorithm context: topology info from the bootstrap plus the
// semaphores and remote memories prepared in initialize().
std::shared_ptr<void> AllreduceRsAg::initAllreduceContext(std::shared_ptr<Communicator> comm, const void*, void*,
                                                          size_t, DataType) {
  auto bootstrap = comm->bootstrap();
  auto ctx = std::make_shared<AlgorithmCtx>();
  ctx->rank = bootstrap->getRank();
  ctx->workSize = bootstrap->getNranks();  // total number of ranks
  ctx->nRanksPerNode = bootstrap->getNranksPerNode();
  ctx->memorySemaphores = this->scratchSemaphores_;
  ctx->registeredMemories = this->remoteScratchMemories_;
  return ctx;
}
// Wraps this implementation in a NativeAlgorithm registered as
// "default_allreduce_rsag" for the "allreduce" collective. The lambdas keep
// `self` alive for as long as the algorithm object exists.
std::shared_ptr<Algorithm> AllreduceRsAg::build() {
  auto self = std::make_shared<AllreduceRsAg>((uintptr_t)scratchBuffer_, scratchBufferSize_);
  auto initFn = [self](std::shared_ptr<mscclpp::Communicator> comm) { self->initialize(comm); };
  auto launchFn = [self](const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
                         [[maybe_unused]] size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream,
                         int nBlocks, int nThreadsPerBlock,
                         const std::unordered_map<std::string, uintptr_t>& extras) -> CommResult {
    return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock,
                                     extras);
  };
  auto ctxFn = [self](std::shared_ptr<Communicator> comm, const void* input, void* output, size_t inputSize,
                      [[maybe_unused]] size_t outputSize, DataType dtype) {
    return self->initAllreduceContext(comm, input, output, inputSize, dtype);
  };
  auto keyFn = [self](const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize,
                      DataType dtype, bool symmetricMemory) {
    return self->generateAllreduceContextKey(input, output, inputSize, dtype, symmetricMemory);
  };
  return std::make_shared<NativeAlgorithm>("default_allreduce_rsag", "allreduce", initFn, launchFn, ctxFn, keyFn);
}
} // namespace collective
} // namespace mscclpp

View File

@@ -0,0 +1,336 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#include "allreduce/allreduce_rsag_pipeline.hpp"
#include "allreduce/common.hpp"
#include "collective_utils.hpp"
#include "logger.hpp"
namespace mscclpp {
namespace collective {
constexpr int MAX_NBLOCKS_FOR_PUT = 32;
constexpr int MAX_NBLOCKS_FOR_RECV = 32;
constexpr int MAX_NBLOCKS_FOR_REDUCE = 64;
constexpr int REDUCE_COPY_RATIO = 2;
__device__ DeviceSemaphore semaphoreForSend[MAX_NBLOCKS_FOR_REDUCE];
__device__ DeviceSemaphore semaphoreForRecv[MAX_NBLOCKS_FOR_REDUCE];
__device__ DeviceSemaphore semaphoreForReduce[MAX_NBLOCKS_FOR_REDUCE];
// TODO: move it to a common header file
// Vectorized 16-byte load of element group `i` from `buff`; a trailing
// partial group is loaded element-by-element and zero-padded.
template <typename T>
__device__ __forceinline__ int4 loadVec(const T* buff, size_t i, size_t nelems) {
  constexpr size_t kElemsPerVec = sizeof(int4) / sizeof(T);
  const size_t firstElem = i * kElemsPerVec;
  if (firstElem + kElemsPerVec <= nelems) {
    return reinterpret_cast<const int4*>(buff)[i];
  }
  union {
    int4 vec;
    T elems[kElemsPerVec];
  } packed;
  packed.vec = make_int4(0, 0, 0, 0);
  for (size_t j = 0; j < kElemsPerVec && firstElem + j < nelems; ++j) {
    packed.elems[j] = buff[firstElem + j];
  }
  return packed.vec;
}
// Vectorized 16-byte store of `val` into element group `i` of `buff`; a
// trailing partial group is written element-by-element to stay in bounds.
template <typename T>
__device__ __forceinline__ void storeVec(T* buff, size_t i, int4 val, size_t nelems) {
  constexpr size_t kElemsPerVec = sizeof(int4) / sizeof(T);
  const size_t firstElem = i * kElemsPerVec;
  if (firstElem + kElemsPerVec <= nelems) {
    reinterpret_cast<int4*>(buff)[i] = val;
    return;
  }
  union {
    int4 vec;
    T elems[kElemsPerVec];
  } packed;
  packed.vec = val;
  for (size_t j = 0; j < kElemsPerVec && firstElem + j < nelems; ++j) {
    buff[firstElem + j] = packed.elems[j];
  }
}
// Pipelined Reduce-Scatter + All-Gather (RSAG) allreduce.
//
// This is a pipelined variant of the basic RSAG allreduce that overlaps
// communication and computation by splitting the data into chunks processed
// across multiple iterations. Three groups of thread blocks run concurrently
// with different roles, synchronized via device semaphores:
//
// PUT blocks — Read local input chunks and write them into peers' scratch
// buffers via remote memory handles (CudaIpc).
//
// REDUCE blocks — After a signal/wait barrier confirming PUT completion,
// reduce the local chunk with data received from all peers
// in the scratch buffer. Write the reduced result to both
// the local output and peers' scratch (for the AG phase).
//
// RECV blocks — After a signal/wait barrier confirming REDUCE completion,
// copy other ranks' reduced chunks from scratch into the
// local result buffer, completing the all-gather.
//
// Pipelining is achieved by using a circular scratch buffer (pipelineDepth
// stages). PUT blocks wait on a semaphore before reusing a scratch slot,
// allowing the next iteration's PUT to overlap with the current iteration's
// REDUCE and RECV. Each REDUCE block handles a subset of the PUT block's
// data (controlled by REDUCE_COPY_RATIO), enabling finer-grained overlap.
//
// Data is processed in int4-sized (16-byte) units with vectorized load/store
// helpers that handle tail elements.
template <ReduceOp OpType, typename T>
__global__ void __launch_bounds__(1024, 1)
allreduceRsAgPipeline(T* buff, T* scratch, T* resultBuff, DeviceHandle<BaseMemoryChannel>* memoryChannels,
DeviceHandle<SwitchChannel>* switchChannels, void* remoteMemories, int rank,
int nRanksPerNode, int worldSize, size_t nelems, size_t scratchSize, uint32_t nblocksForPut,
uint32_t nblocksForReduce, uint32_t nblocksForRecv) {
uint32_t bid = blockIdx.x;
constexpr uint32_t nStepsPerIter = 4;
uint32_t nInt4 = (nelems * sizeof(T) + sizeof(int4) - 1) / sizeof(int4);
uint32_t nInt4PerIter = nblocksForReduce * blockDim.x * nStepsPerIter;
const uint32_t chunkSize = nInt4PerIter * worldSize;
uint32_t nIters = (nInt4 + chunkSize - 1) / chunkSize;
uint32_t nPeers = nRanksPerNode - 1;
int4* scratch4 = reinterpret_cast<int4*>((char*)scratch);
const uint32_t scratchIterStride = 2 * chunkSize; // one for AS, one for AG
const uint32_t pipelineDepth = scratchSize / sizeof(int4) / scratchIterStride;
assert(pipelineDepth >= 1);
if (bid < nblocksForPut) {
if (threadIdx.x == 0) {
semaphoreForSend[bid].set(pipelineDepth);
}
for (uint32_t iter = 0; iter < nIters; iter++) {
if (threadIdx.x == 0) {
semaphoreForSend[bid].acquire();
}
__syncthreads();
uint32_t threadIdInPut = bid * blockDim.x + threadIdx.x;
for (uint32_t peer = 0; peer < nPeers; peer++) {
int remoteRankId = (rank + peer + 1) % nRanksPerNode;
int peerId = remoteRankId < rank ? remoteRankId : remoteRankId - 1;
// Read chunk[remoteRankId] from local buff, write to peer's scratch[rank] (sender's slot)
uint32_t srcOffset = iter * chunkSize + remoteRankId * nInt4PerIter;
uint32_t dstOffset = (iter % pipelineDepth) * scratchIterStride + rank * nInt4PerIter;
int4 tmp[nStepsPerIter * REDUCE_COPY_RATIO];
#pragma unroll
for (uint32_t step = 0; step < nStepsPerIter * REDUCE_COPY_RATIO; step++) {
uint32_t offset = srcOffset + threadIdInPut + step * blockDim.x * nblocksForPut;
tmp[step] = loadVec(buff, offset, nelems);
}
#pragma unroll
for (uint32_t step = 0; step < nStepsPerIter * REDUCE_COPY_RATIO; step++) {
uint32_t offset = dstOffset + threadIdInPut + step * blockDim.x * nblocksForPut;
mscclpp::write<int4>(((void**)remoteMemories)[peerId], offset, tmp[step]);
}
}
__syncthreads();
if (threadIdx.x < REDUCE_COPY_RATIO) {
semaphoreForReduce[bid * REDUCE_COPY_RATIO + threadIdx.x].release();
}
}
} else if (bid < nblocksForPut + nblocksForReduce) {
uint32_t bidInReduce = bid - nblocksForPut;
DeviceHandle<BaseMemoryChannel>* localMemoryChannels = memoryChannels + bidInReduce * nPeers;
// Map REDUCE blocks to PUT blocks: REDUCE blocks 0,1 handle PUT block 0's data
uint32_t putBlockId = bidInReduce / REDUCE_COPY_RATIO;
uint32_t subBlockId = bidInReduce % REDUCE_COPY_RATIO;
for (uint32_t iter = 0; iter < nIters; iter++) {
if (threadIdx.x == 0) {
semaphoreForReduce[bidInReduce].acquire();
}
uint32_t baseOffset = (iter % pipelineDepth) * scratchIterStride;
uint32_t baseSrcOffset = iter * chunkSize;
// Use same thread mapping as PUT: putBlockId * blockDim.x + threadIdx.x
uint32_t threadIdInPut = putBlockId * blockDim.x + threadIdx.x;
__syncthreads();
if (threadIdx.x < nPeers) {
localMemoryChannels[threadIdx.x].signal();
localMemoryChannels[threadIdx.x].wait();
}
__syncthreads();
#pragma unroll nStepsPerIter
for (uint32_t step = 0; step < nStepsPerIter; step++) {
// Map to PUT's step pattern: each REDUCE block handles nStepsPerIter steps
// subBlockId determines which subset of the REDUCE_COPY_RATIO * nStepsPerIter steps
uint32_t putStep = subBlockId * nStepsPerIter + step;
uint32_t myChunkOffset =
baseSrcOffset + rank * nInt4PerIter + threadIdInPut + putStep * blockDim.x * nblocksForPut;
int4 tmp = loadVec(buff, myChunkOffset, nelems);
// Add data from each peer's slot in scratch (peer sent their chunk[rank] to our scratch[peer])
for (uint32_t peer = 0; peer < nPeers; peer++) {
int remoteRankId = (rank + peer + 1) % nRanksPerNode;
uint32_t peerSlotOffset =
baseOffset + remoteRankId * nInt4PerIter + threadIdInPut + putStep * blockDim.x * nblocksForPut;
int4 data = scratch4[peerSlotOffset];
tmp = cal_vector<T, OpType>(data, tmp);
}
storeVec(resultBuff, myChunkOffset, tmp, nelems);
// Broadcast reduced result to all peers' scratch at SCATTER_AG_OFFSET + rank * nInt4PerIter
uint32_t dstOffset =
baseOffset + chunkSize + rank * nInt4PerIter + threadIdInPut + putStep * blockDim.x * nblocksForPut;
for (uint32_t i = 0; i < nPeers; i++) {
int peerIdx = (rank + i + 1) % nRanksPerNode;
int index = peerIdx < rank ? peerIdx : peerIdx - 1;
mscclpp::write<int4>(((void**)remoteMemories)[index], dstOffset, tmp);
}
}
__syncthreads();
if (threadIdx.x == 0) {
semaphoreForRecv[bidInReduce].release();
}
}
} else if (bid < nblocksForPut + nblocksForReduce + nblocksForRecv) {
uint32_t bidInRecv = bid - nblocksForPut - nblocksForReduce;
DeviceHandle<BaseMemoryChannel>* localMemoryChannels = memoryChannels + (nblocksForReduce + bidInRecv) * nPeers;
for (uint32_t iter = 0; iter < nIters; iter++) {
if (threadIdx.x < REDUCE_COPY_RATIO) {
semaphoreForRecv[bidInRecv * REDUCE_COPY_RATIO + threadIdx.x].acquire();
}
uint32_t baseOffset = scratchIterStride * (iter % pipelineDepth);
uint32_t baseDstOffset = chunkSize * iter;
int threadIdInRecv = bidInRecv * blockDim.x + threadIdx.x;
__syncthreads();
if (threadIdx.x < nPeers) {
localMemoryChannels[threadIdx.x].signal();
localMemoryChannels[threadIdx.x].wait();
}
__syncthreads();
// Copy other ranks' reduced chunks from scratch to result
for (uint32_t peer = 0; peer < nPeers; peer++) {
int remoteRankId = (rank + peer + 1) % nRanksPerNode;
for (uint32_t step = 0; step < nStepsPerIter * REDUCE_COPY_RATIO; step++) {
uint32_t offset = baseOffset + chunkSize + remoteRankId * nInt4PerIter + threadIdInRecv +
step * blockDim.x * nblocksForRecv;
uint32_t dstOffset =
baseDstOffset + remoteRankId * nInt4PerIter + threadIdInRecv + step * blockDim.x * nblocksForRecv;
storeVec(resultBuff, dstOffset, scratch4[offset], nelems);
}
}
__syncthreads();
if (threadIdx.x == 0) {
semaphoreForSend[bidInRecv].release();
}
}
}
}
template <ReduceOp OpType, typename T>
struct AllreduceRsAgPipelineAdapter {
  /// Launch adapter for the pipelined reduce-scatter + all-gather kernel.
  /// The grid is split into PUT / REDUCE / RECV block groups in the ratio
  /// 1 : REDUCE_COPY_RATIO : 1, so nBlocks must be a multiple of
  /// (REDUCE_COPY_RATIO + 2); a caller-provided value is rounded down and
  /// clamped to the valid [minimum, compile-time maximum] range.
  static cudaError_t call(const void* input, void* scratch, void* output, void* memoryChannels, void* remoteMemories,
                          DeviceHandle<SwitchChannel>* switchChannel, DeviceHandle<SwitchChannel>*, size_t, size_t,
                          size_t scratchSize, int rank, int nRanksPerNode, int worldSize, size_t inputSize,
                          cudaStream_t stream, void*, uint32_t, uint32_t, int nBlocks, int nThreadsPerBlock) {
    using ChannelType = DeviceHandle<BaseMemoryChannel>;
    size_t nelems = inputSize / sizeof(T);
    uint32_t nblocksForPut = MAX_NBLOCKS_FOR_PUT;
    uint32_t nblocksForReduce = MAX_NBLOCKS_FOR_REDUCE;
    uint32_t nblocksForRecv = MAX_NBLOCKS_FOR_RECV;
    int maxNblocks = nblocksForPut + nblocksForReduce + nblocksForRecv;
    constexpr int groupSize = REDUCE_COPY_RATIO + 2;  // blocks per PUT/REDUCE/RECV group
    if (nBlocks == 0 || nThreadsPerBlock == 0) {
      // No tuning hint from the caller: use the compile-time defaults.
      nThreadsPerBlock = 1024;
      nBlocks = maxNblocks;
    } else {
      // Round down to a multiple of groupSize so the grid splits evenly into
      // the three block roles.
      nBlocks = nBlocks / groupSize * groupSize;
      if (nBlocks == 0) {
        // A request smaller than one group would otherwise launch an invalid
        // 0-block grid; fall back to the smallest valid configuration.
        WARN(ALGO, "The number of blocks is too small for the allreduce pipeline algorithm, increasing it to ",
             groupSize);
        nBlocks = groupSize;
      } else if (nBlocks > maxNblocks) {
        WARN(ALGO, "The number of blocks is too large for the allreduce pipeline algorithm, reducing it to ",
             maxNblocks);
        nBlocks = maxNblocks;
      }
      nblocksForPut = nBlocks / groupSize;
      nblocksForReduce = nblocksForPut * REDUCE_COPY_RATIO;
      nblocksForRecv = nblocksForPut;
    }
    allreduceRsAgPipeline<OpType, T><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
        (T*)input, (T*)scratch, (T*)output, (ChannelType*)memoryChannels, switchChannel, remoteMemories, rank,
        nRanksPerNode, worldSize, nelems, scratchSize, nblocksForPut, nblocksForReduce, nblocksForRecv);
    return cudaGetLastError();
  }
};
void AllreduceRsAgPipeline::initialize(std::shared_ptr<Communicator> comm) {
  comm_ = comm;
  // One channel per REDUCE block plus one per RECV block, matching the
  // kernel's per-block channel indexing.
  nChannelsPerConnection_ = MAX_NBLOCKS_FOR_REDUCE + MAX_NBLOCKS_FOR_RECV;
  conns_ = setupConnections(comm);
  scratchSemaphores_ = setupMemorySemaphores(comm, conns_, nChannelsPerConnection_);
  // Register the local scratch buffer and exchange handles with all peers.
  RegisteredMemory scratchMemory = comm->registerMemory(scratchBuffer_, scratchBufferSize_, Transport::CudaIpc);
  remoteScratchMemories_ = setupRemoteMemories(comm, comm->bootstrap()->getRank(), scratchMemory);
  localScratchMemory_ = std::move(scratchMemory);
  baseChannels_ = setupBaseMemoryChannels(conns_, scratchSemaphores_, nChannelsPerConnection_);
  baseMemoryChannelHandles_ = setupBaseMemoryChannelDeviceHandles(baseChannels_);
  // Flatten every remote scratch pointer into a GPU-resident array so the
  // kernel can index peers directly.
  std::vector<void*> handlePtrs;
  handlePtrs.reserve(remoteScratchMemories_.size());
  for (const auto& remoteMemory : remoteScratchMemories_) {
    handlePtrs.push_back(remoteMemory.data());
  }
  remoteMemoryHandles_ = detail::gpuCallocShared<void*>(handlePtrs.size());
  gpuMemcpy(remoteMemoryHandles_.get(), handlePtrs.data(), handlePtrs.size(), cudaMemcpyHostToDevice);
}
/// Dispatches and launches the pipelined RSAG allreduce kernel.
/// Returns CommInvalidArgument for unsupported (op, dtype) combinations and
/// CommUnhandledCudaError if the kernel launch fails.
CommResult AllreduceRsAgPipeline::allreduceKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output,
                                                      size_t inputSize, DataType dtype, ReduceOp op,
                                                      cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
                                                      const std::unordered_map<std::string, uintptr_t>&) {
  auto algoCtx = std::static_pointer_cast<AlgorithmCtx>(ctx);
  // Resolve the templated launch adapter for this (op, dtype) pair.
  AllreduceFunc allreduce = dispatch<AllreduceRsAgPipelineAdapter>(op, dtype);
  if (!allreduce) {
    WARN(ALGO, "Unsupported operation or data type for allreduce: op=", static_cast<int>(op),
         ", dtype=", static_cast<int>(dtype));
    return CommResult::CommInvalidArgument;
  }
  // nBlocks/nThreadsPerBlock of 0 lets the adapter pick tuned defaults.
  cudaError_t error = allreduce(input, this->scratchBuffer_, output, this->baseMemoryChannelHandles_.get(),
                                this->remoteMemoryHandles_.get(), nullptr, nullptr, 0, 0, this->scratchBufferSize_,
                                algoCtx->rank, algoCtx->nRanksPerNode, algoCtx->workSize, inputSize, stream, nullptr, 0,
                                0, nBlocks, nThreadsPerBlock);
  if (error != cudaSuccess) {
    WARN(ALGO, "Allreduce kernel launch failed with error: ", cudaGetErrorString(error));
    return CommResult::CommUnhandledCudaError;
  }
  return CommResult::CommSuccess;
}
// The pipeline algorithm only touches its pre-registered scratch buffer, so a
// single cached context serves every call: return a constant key so the
// context lookup always hits.
AlgorithmCtxKey AllreduceRsAgPipeline::generateAllreduceContextKey(const void*, void*, size_t, DataType, bool) {
  return AlgorithmCtxKey{nullptr, nullptr, 0, 0, 0};
}
std::shared_ptr<void> AllreduceRsAgPipeline::initAllreduceContext(std::shared_ptr<Communicator> comm, const void*,
                                                                  void*, size_t, DataType) {
  // All GPU resources live on the builder; the context only records topology
  // info and holds references that keep those resources alive.
  auto context = std::make_shared<AlgorithmCtx>();
  auto bootstrap = comm->bootstrap();
  context->rank = bootstrap->getRank();
  context->workSize = bootstrap->getNranks();
  context->nRanksPerNode = bootstrap->getNranksPerNode();
  context->memorySemaphores = this->scratchSemaphores_;
  context->registeredMemories = this->remoteScratchMemories_;
  return context;
}
std::shared_ptr<Algorithm> AllreduceRsAgPipeline::build() {
  // The builder is shared into each callback so it outlives this call.
  auto self = std::make_shared<AllreduceRsAgPipeline>((uintptr_t)scratchBuffer_, scratchBufferSize_);
  auto initFn = [self](std::shared_ptr<mscclpp::Communicator> comm) { self->initialize(comm); };
  auto kernelFn = [self](const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
                         [[maybe_unused]] size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream,
                         int nBlocks, int nThreadsPerBlock,
                         const std::unordered_map<std::string, uintptr_t>& extras) -> CommResult {
    return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock,
                                     extras);
  };
  auto ctxFn = [self](std::shared_ptr<Communicator> comm, const void* input, void* output, size_t inputSize,
                      [[maybe_unused]] size_t outputSize, DataType dtype) {
    return self->initAllreduceContext(comm, input, output, inputSize, dtype);
  };
  auto keyFn = [self](const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize,
                      DataType dtype, bool symmetricMemory) {
    return self->generateAllreduceContextKey(input, output, inputSize, dtype, symmetricMemory);
  };
  return std::make_shared<NativeAlgorithm>("default_allreduce_rsag_pipeline", "allreduce", initFn, kernelFn, ctxFn,
                                           keyFn);
}
} // namespace collective
} // namespace mscclpp

View File

@@ -0,0 +1,236 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#include "allreduce/allreduce_rsag_zero_copy.hpp"
#include "allreduce/common.hpp"
#include "collective_utils.hpp"
#include "logger.hpp"
namespace mscclpp {
namespace collective {
// Grid-wide barrier state shared by all blocks of allreduceRsAgZeroCopy.
__device__ mscclpp::DeviceSyncer globalSyncer;
// Zero-copy Reduce-Scatter + All-Gather (RSAG) allreduce.
//
// Unlike the standard RSAG which copies input into a scratch buffer first,
// this variant reads directly from peers' input buffers and writes reduced
// results directly to peers' output buffers — eliminating the need for a
// separate scratch buffer and reducing memory traffic.
//
// The algorithm runs in a single kernel with the following steps:
//
// 1. Barrier: Signal and wait on all peers to ensure input buffers are ready.
//
// 2. Reduce-Scatter: Each rank reads its assigned chunk from every peer's
// input buffer (via CudaIpc remote memory handles), reduces all values
// locally, then writes the reduced result to its own output buffer AND
// directly to every peer's output buffer at the same offset.
//
// 3. Global sync + Barrier: A device-wide sync ensures all writes complete,
// followed by a final signal/wait to guarantee all peers have finished
// writing, making the full output buffer valid on every rank.
//
// This approach requires registering both input and output buffers as remote
// memories (2 * nPeers handles), but avoids scratch buffer allocation and
// the extra copy steps of the standard RSAG. The NRanksPerNode template
// parameter enables compile-time unrolling of peer loops (supports 4 or 8).
template <int NRanksPerNode, ReduceOp OpType, typename T>
__global__ void __launch_bounds__(1024, 1)
    allreduceRsAgZeroCopy(T* buff, T* scratch, T* resultBuff, DeviceHandle<BaseMemoryChannel>* memoryChannels,
                          DeviceHandle<SwitchChannel>* switchChannels, void* remoteMemories, int rank, int worldSize,
                          size_t nelems) {
  int blockId = blockIdx.x;
  // Vectorized int4 loads/stores require 16-byte aligned buffers.
  assert((uintptr_t)buff % sizeof(int4) == 0);
  assert((uintptr_t)resultBuff % sizeof(int4) == 0);
  constexpr int NPeers = NRanksPerNode - 1;
  constexpr uint32_t nelemsPerInt4 = sizeof(int4) / sizeof(T);
  // remoteMemories layout: [0, NPeers) are peers' input buffers,
  // [NPeers, 2*NPeers) are peers' output buffers.
  const uint32_t outputRemoteBufferOffset = NRanksPerNode - 1;
  // Round the element count up so every rank owns an equally sized,
  // int4-aligned slice.
  uint32_t alignedNelems = ((nelems + NRanksPerNode - 1) / NRanksPerNode + nelemsPerInt4 - 1) / nelemsPerInt4 *
                           nelemsPerInt4 * NRanksPerNode;
  uint32_t nelemsPerRank = alignedNelems / NRanksPerNode;
  uint32_t nInt4PerRank = nelemsPerRank / nelemsPerInt4;
  uint32_t nInt4Total = (nelems + nelemsPerInt4 - 1) / nelemsPerInt4;
  int4* resultBuff4 = reinterpret_cast<int4*>((char*)resultBuff);
  int4* buff4 = reinterpret_cast<int4*>((char*)buff);
  DeviceHandle<BaseMemoryChannel>* memoryChannelsLocal = memoryChannels + blockId * NPeers;
  // Split this rank's slice evenly across blocks; the last block also takes
  // the remainder. A block may end up with zero work for small inputs, but it
  // must still participate in every barrier below: the previous early
  // `return` on nInt4PerBlock == 0 skipped globalSyncer.sync(gridDim.x) and
  // deadlocked the remaining blocks.
  uint32_t nInt4PerBlock = nInt4PerRank / gridDim.x;
  uint32_t remainderForBlock = nInt4PerRank % gridDim.x;
  uint32_t offset4 = blockId * nInt4PerBlock;
  if (blockId == (int)(gridDim.x - 1)) {
    nInt4PerBlock += remainderForBlock;
  }
  // Entry barrier: ensure every peer's input buffer is ready to be read.
  if (threadIdx.x < NPeers) {
    memoryChannelsLocal[threadIdx.x].relaxedSignal();
    memoryChannelsLocal[threadIdx.x].relaxedWait();
  }
  __syncthreads();
  int4 data[NPeers];
  for (uint32_t idx = threadIdx.x; idx < nInt4PerBlock; idx += blockDim.x) {
    uint32_t offset = idx + offset4 + rank * nInt4PerRank;
    if (offset >= nInt4Total) continue;  // skip the alignment padding of the last slice
    int4 tmp = buff4[offset];
    // Gather this rank's chunk from every peer's input buffer.
#pragma unroll
    for (int i = 0; i < NPeers; i++) {
      int rankIdx = (rank + i + 1) % NRanksPerNode;
      int peerIdx = rankIdx < rank ? rankIdx : rankIdx - 1;
      data[i] = mscclpp::read<int4>(((void**)remoteMemories)[peerIdx], offset);
    }
#pragma unroll
    for (int i = 0; i < NPeers; i++) {
      tmp = cal_vector<T, OpType>(data[i], tmp);
    }
    // Scatter the reduced value straight into every peer's output buffer.
#pragma unroll
    for (int i = 0; i < NPeers; i++) {
      int rankIdx = (rank + i + 1) % NRanksPerNode;
      int peerIdx = rankIdx < rank ? rankIdx : rankIdx - 1;
      mscclpp::write<int4>(((void**)remoteMemories)[outputRemoteBufferOffset + peerIdx], offset, tmp);
    }
    resultBuff4[offset] = tmp;
  }
  // Use device barrier gives better performance here.
  globalSyncer.sync(gridDim.x);
  // Exit barrier: guarantee all peers finished writing into our output buffer.
  if (blockIdx.x == 0 && threadIdx.x < NPeers) {
    memoryChannelsLocal[threadIdx.x].signal();
    memoryChannelsLocal[threadIdx.x].wait();
  }
}
template <ReduceOp OpType, typename T>
struct AllreduceRsAgZeroCopyAdapter {
  /// Launch adapter for the zero-copy RSAG allreduce kernel: picks tuned grid
  /// defaults when the caller passes 0 and dispatches to the
  /// NRanksPerNode-templated instantiation (4 or 8 ranks per node only).
  static cudaError_t call(const void* input, void* scratch, void* output, void* memoryChannels, void* remoteMemories,
                          DeviceHandle<SwitchChannel>* switchChannel, DeviceHandle<SwitchChannel>*, size_t, size_t,
                          size_t, int rank, int nRanksPerNode, int worldSize, size_t inputSize, cudaStream_t stream,
                          void*, uint32_t, uint32_t, int nBlocks, int nThreadsPerBlock) {
    using ChannelType = DeviceHandle<BaseMemoryChannel>;
    // The kernel consumes one channel group per block and initialize() creates
    // 128 channels per connection, so the grid must never exceed 128 blocks.
    constexpr int maxNblocks = 128;
    size_t nelems = inputSize / sizeof(T);
    if (nBlocks == 0 || nThreadsPerBlock == 0) {
      // Tuned defaults: 64 blocks, 128 for messages >= 64MB.
      nThreadsPerBlock = 1024;
      nBlocks = 64;
      if (inputSize >= (1 << 26)) {
        nBlocks = maxNblocks;
      }
    } else if (nBlocks > maxNblocks) {
      // Cap caller-provided grids to avoid indexing past the channel array.
      WARN(ALGO, "The number of blocks is too large for the allreduce zero-copy algorithm, reducing it to ",
           maxNblocks);
      nBlocks = maxNblocks;
    }
    if (nRanksPerNode == 4) {
      allreduceRsAgZeroCopy<4, OpType, T>
          <<<nBlocks, nThreadsPerBlock, 0, stream>>>((T*)input, (T*)scratch, (T*)output, (ChannelType*)memoryChannels,
                                                     switchChannel, remoteMemories, rank, worldSize, nelems);
    } else if (nRanksPerNode == 8) {
      allreduceRsAgZeroCopy<8, OpType, T>
          <<<nBlocks, nThreadsPerBlock, 0, stream>>>((T*)input, (T*)scratch, (T*)output, (ChannelType*)memoryChannels,
                                                     switchChannel, remoteMemories, rank, worldSize, nelems);
    } else {
      THROW(ALGO, Error, ErrorCode::InvalidUsage, "Unsupported number of ranks per node: ", nRanksPerNode);
    }
    return cudaGetLastError();
  }
};
// Sets up connections, semaphores, and per-block barrier channels for the
// zero-copy algorithm. Buffer registration is deferred to
// initAllreduceContext() because the user's input/output buffers are not
// known yet.
void AllreduceRsAgZeroCopy::initialize(std::shared_ptr<Communicator> comm) {
  this->conns_ = setupConnections(comm);
  // One channel group per kernel block; 128 matches the largest grid the
  // launch adapter uses.
  nChannelsPerConnection_ = 128;
  comm_ = comm;
  // setup semaphores
  this->semaphores_ = setupMemorySemaphores(comm, this->conns_, nChannelsPerConnection_);
  this->baseChannels_ = setupBaseMemoryChannels(this->conns_, this->semaphores_, nChannelsPerConnection_);
  this->baseMemoryChannelHandles_ = setupBaseMemoryChannelDeviceHandles(baseChannels_);
}
/// Dispatches and launches the zero-copy RSAG allreduce kernel. No scratch
/// buffer is used; the remote memory handles come from the per-buffer context.
/// Returns CommInvalidArgument for unsupported (op, dtype) combinations and
/// CommUnhandledCudaError if the kernel launch fails.
CommResult AllreduceRsAgZeroCopy::allreduceKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output,
                                                      size_t inputSize, DataType dtype, ReduceOp op,
                                                      cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
                                                      const std::unordered_map<std::string, uintptr_t>&) {
  auto algoCtx = std::static_pointer_cast<AlgorithmCtx>(ctx);
  // Resolve the templated launch adapter for this (op, dtype) pair.
  AllreduceFunc allreduce = dispatch<AllreduceRsAgZeroCopyAdapter>(op, dtype);
  if (!allreduce) {
    WARN(ALGO, "Unsupported operation or data type for allreduce: op=", static_cast<int>(op),
         ", dtype=", static_cast<int>(dtype));
    return CommResult::CommInvalidArgument;
  }
  // nBlocks/nThreadsPerBlock of 0 lets the adapter pick tuned defaults.
  cudaError_t error =
      allreduce(input, nullptr, output, this->baseMemoryChannelHandles_.get(), algoCtx->remoteMemoryHandles.get(),
                nullptr, nullptr, 0, 0, 0, algoCtx->rank, algoCtx->nRanksPerNode, algoCtx->workSize, inputSize, stream,
                nullptr, 0, 0, nBlocks, nThreadsPerBlock);
  if (error != cudaSuccess) {
    WARN(ALGO, "Allreduce kernel launch failed with error: ", cudaGetErrorString(error));
    return CommResult::CommUnhandledCudaError;
  }
  return CommResult::CommSuccess;
}
AlgorithmCtxKey AllreduceRsAgZeroCopy::generateAllreduceContextKey(const void* inputBuffer, void* outputBuffer,
                                                                   size_t size, DataType, bool symmetricMemory) {
  // For non-symmetric algorithms, we use both input and output buffer pointers in the key.
  // NOTE(review): `tag` is a function-local static incremented without any
  // synchronization; concurrent callers would race on it — confirm key
  // generation is serialized.
  static int tag = 0;
  if (symmetricMemory) {
    // Symmetric memory: key on the base address and size of the underlying
    // allocations so the registered context is reused across calls that hit
    // the same buffers.
    size_t inputBytes, outputBytes;
    CUdeviceptr inputBasePtr, outputBasePtr;
    MSCCLPP_CUTHROW(cuMemGetAddressRange(&inputBasePtr, &inputBytes, (CUdeviceptr)inputBuffer));
    MSCCLPP_CUTHROW(cuMemGetAddressRange(&outputBasePtr, &outputBytes, (CUdeviceptr)outputBuffer));
    return AlgorithmCtxKey{(void*)inputBasePtr, (void*)outputBasePtr, inputBytes, outputBytes, 0};
  }
  // Non-symmetric: ++tag makes every key unique, deliberately forcing a fresh
  // context (and fresh buffer registration) on each call.
  return AlgorithmCtxKey{(void*)inputBuffer, outputBuffer, size, size, ++tag};
}
// Builds a per-buffer context: registers the user's input and output buffers
// over CudaIpc, exchanges them with all peers, and flattens the remote device
// pointers into a GPU-resident array for the kernel.
std::shared_ptr<void> AllreduceRsAgZeroCopy::initAllreduceContext(std::shared_ptr<Communicator> comm, const void* input,
                                                                  void* output, size_t size, DataType) {
  auto ctx = std::make_shared<AlgorithmCtx>();
  ctx->rank = comm->bootstrap()->getRank();
  ctx->workSize = comm->bootstrap()->getNranks();
  ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode();
  ctx->memorySemaphores = this->semaphores_;
  // register input and output memories
  RegisteredMemory inputMemory = comm->registerMemory((void*)input, size, Transport::CudaIpc);
  RegisteredMemory outputMemory = comm->registerMemory(output, size, Transport::CudaIpc);
  // NOTE(review): these builder-level vectors grow by one entry per context
  // created and are never pruned — confirm contexts are created rarely
  // (e.g. once per buffer pair), otherwise registrations accumulate.
  this->inputMemories_.push_back(inputMemory);
  this->outputMemories_.push_back(outputMemory);
  auto remoteInputMemories = setupRemoteMemories(comm, ctx->rank, inputMemory);
  auto remoteOutputMemories = setupRemoteMemories(comm, ctx->rank, outputMemory);
  // Order matters: remote input handles first, then remote output handles —
  // the kernel addresses outputs at offset NRanksPerNode - 1 into this array.
  ctx->registeredMemories.insert(ctx->registeredMemories.end(), remoteInputMemories.begin(), remoteInputMemories.end());
  ctx->registeredMemories.insert(ctx->registeredMemories.end(), remoteOutputMemories.begin(),
                                 remoteOutputMemories.end());
  // Copy the raw device pointers of all remote buffers to the GPU so the
  // kernel can index them by peer.
  std::vector<void*> remoteMemoryHandles;
  for (const auto& remoteMemory : ctx->registeredMemories) {
    remoteMemoryHandles.push_back(remoteMemory.data());
  }
  ctx->remoteMemoryHandles = detail::gpuCallocShared<void*>(remoteMemoryHandles.size());
  gpuMemcpy(ctx->remoteMemoryHandles.get(), remoteMemoryHandles.data(), remoteMemoryHandles.size(),
            cudaMemcpyHostToDevice);
  // store local registered memories to ctx for lifetime management
  ctx->registeredMemories.push_back(inputMemory);
  ctx->registeredMemories.push_back(outputMemory);
  return ctx;
}
std::shared_ptr<Algorithm> AllreduceRsAgZeroCopy::build() {
  // The builder is shared into each callback so it outlives this call.
  auto self = std::make_shared<AllreduceRsAgZeroCopy>();
  auto initFn = [self](std::shared_ptr<mscclpp::Communicator> comm) { self->initialize(comm); };
  auto kernelFn = [self](const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
                         [[maybe_unused]] size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream,
                         int nBlocks, int nThreadsPerBlock,
                         const std::unordered_map<std::string, uintptr_t>& extras) -> CommResult {
    return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock,
                                     extras);
  };
  auto ctxFn = [self](std::shared_ptr<Communicator> comm, const void* input, void* output, size_t inputSize,
                      [[maybe_unused]] size_t outputSize, DataType dtype) {
    return self->initAllreduceContext(comm, input, output, inputSize, dtype);
  };
  auto keyFn = [self](const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize,
                      DataType dtype, bool symmetricMemory) {
    return self->generateAllreduceContextKey(input, output, inputSize, dtype, symmetricMemory);
  };
  return std::make_shared<NativeAlgorithm>("default_allreduce_rsag_zero_copy", "allreduce", initFn, kernelFn, ctxFn,
                                           keyFn);
}
} // namespace collective
} // namespace mscclpp

View File

@@ -0,0 +1,43 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#ifndef MSCCLPP_EXT_ALLREDUCE_RSAG_HPP_
#define MSCCLPP_EXT_ALLREDUCE_RSAG_HPP_
#include <mscclpp/algorithm.hpp>
namespace mscclpp {
namespace collective {
/// Builder for the scratch-buffer based reduce-scatter + all-gather allreduce.
/// Registers a caller-provided scratch buffer once and reuses it (and a single
/// cached context) for every call.
class AllreduceRsAg : public mscclpp::AlgorithmBuilder {
 public:
  /// @param scratchBuffer device pointer to a pre-allocated scratch buffer.
  /// @param scratchBufferSize size of the scratch buffer in bytes.
  AllreduceRsAg(uintptr_t scratchBuffer, size_t scratchBufferSize)
      : scratchBuffer_((void*)scratchBuffer), scratchBufferSize_(scratchBufferSize){};
  std::shared_ptr<mscclpp::Algorithm> build() override;

 private:
  void initialize(std::shared_ptr<Communicator> comm);
  CommResult allreduceKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
                                 DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
                                 const std::unordered_map<std::string, uintptr_t>& extras);
  std::shared_ptr<void> initAllreduceContext(std::shared_ptr<Communicator> comm, const void*, void* output, size_t,
                                             DataType);
  AlgorithmCtxKey generateAllreduceContextKey(const void*, void*, size_t, DataType, bool);
  void* scratchBuffer_;            // caller-owned device scratch buffer (not freed here)
  size_t scratchBufferSize_;       // scratch buffer size in bytes
  std::shared_ptr<Communicator> comm_;
  int nChannelsPerConnection_;     // channels created per peer connection
  std::vector<Connection> conns_;
  std::vector<std::shared_ptr<MemoryDevice2DeviceSemaphore>> scratchSemaphores_;
  std::vector<RegisteredMemory> remoteScratchMemories_;  // peers' scratch buffers
  RegisteredMemory localScratchMemory_;
  std::vector<BaseMemoryChannel> baseChannels_;
  std::shared_ptr<DeviceHandle<BaseMemoryChannel>> baseMemoryChannelHandles_;
  std::shared_ptr<void*> remoteMemoryHandles_;  // GPU array of remote scratch pointers
};
} // namespace collective
} // namespace mscclpp
#endif // MSCCLPP_EXT_ALLREDUCE_RSAG_HPP_

View File

@@ -0,0 +1,43 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#ifndef MSCCLPP_EXT_ALLREDUCE_RSAG_PIPELINE_HPP_
#define MSCCLPP_EXT_ALLREDUCE_RSAG_PIPELINE_HPP_
#include <mscclpp/algorithm.hpp>
namespace mscclpp {
namespace collective {
/// Builder for the pipelined reduce-scatter + all-gather allreduce. Overlaps
/// PUT / REDUCE / RECV phases across dedicated block groups using a
/// caller-provided scratch buffer as the staging area.
class AllreduceRsAgPipeline : public mscclpp::AlgorithmBuilder {
 public:
  /// @param scratchBuffer device pointer to a pre-allocated scratch buffer.
  /// @param scratchBufferSize size of the scratch buffer in bytes.
  AllreduceRsAgPipeline(uintptr_t scratchBuffer, size_t scratchBufferSize)
      : scratchBuffer_((void*)scratchBuffer), scratchBufferSize_(scratchBufferSize){};
  std::shared_ptr<mscclpp::Algorithm> build() override;

 private:
  void initialize(std::shared_ptr<Communicator> comm);
  CommResult allreduceKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
                                 DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
                                 const std::unordered_map<std::string, uintptr_t>& extras);
  std::shared_ptr<void> initAllreduceContext(std::shared_ptr<Communicator> comm, const void*, void* output, size_t,
                                             DataType);
  AlgorithmCtxKey generateAllreduceContextKey(const void*, void*, size_t, DataType, bool);
  void* scratchBuffer_;            // caller-owned device scratch buffer (not freed here)
  size_t scratchBufferSize_;       // scratch buffer size in bytes
  std::shared_ptr<Communicator> comm_;
  int nChannelsPerConnection_;     // channels created per peer connection
  std::vector<Connection> conns_;
  std::vector<std::shared_ptr<MemoryDevice2DeviceSemaphore>> scratchSemaphores_;
  std::vector<RegisteredMemory> remoteScratchMemories_;  // peers' scratch buffers
  RegisteredMemory localScratchMemory_;
  std::vector<BaseMemoryChannel> baseChannels_;
  std::shared_ptr<DeviceHandle<BaseMemoryChannel>> baseMemoryChannelHandles_;
  std::shared_ptr<void*> remoteMemoryHandles_;  // GPU array of remote scratch pointers
};
} // namespace collective
} // namespace mscclpp
#endif // MSCCLPP_EXT_ALLREDUCE_RSAG_PIPELINE_HPP_

View File

@@ -0,0 +1,39 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#ifndef MSCCLPP_EXT_ALLREDUCE_RSAG_ZERO_COPY_HPP_
#define MSCCLPP_EXT_ALLREDUCE_RSAG_ZERO_COPY_HPP_
#include <mscclpp/algorithm.hpp>
namespace mscclpp {
namespace collective {
/// Builder for the zero-copy reduce-scatter + all-gather allreduce. Registers
/// the user's input/output buffers directly (no scratch buffer) and reads
/// from / writes to peers' buffers over CudaIpc.
class AllreduceRsAgZeroCopy : public mscclpp::AlgorithmBuilder {
 public:
  AllreduceRsAgZeroCopy() = default;
  std::shared_ptr<mscclpp::Algorithm> build() override;

 private:
  void initialize(std::shared_ptr<Communicator> comm);
  CommResult allreduceKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
                                 DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
                                 const std::unordered_map<std::string, uintptr_t>& extras);
  std::shared_ptr<void> initAllreduceContext(std::shared_ptr<Communicator> comm, const void*, void* output, size_t,
                                             DataType);
  AlgorithmCtxKey generateAllreduceContextKey(const void*, void*, size_t, DataType, bool);
  std::shared_ptr<Communicator> comm_;
  int nChannelsPerConnection_;     // channels created per peer connection
  std::vector<Connection> conns_;
  std::vector<std::shared_ptr<MemoryDevice2DeviceSemaphore>> semaphores_;
  std::vector<RegisteredMemory> inputMemories_;   // locally registered user input buffers
  std::vector<RegisteredMemory> outputMemories_;  // locally registered user output buffers
  std::vector<BaseMemoryChannel> baseChannels_;
  std::shared_ptr<DeviceHandle<BaseMemoryChannel>> baseMemoryChannelHandles_;
};
} // namespace collective
} // namespace mscclpp
#endif // MSCCLPP_EXT_ALLREDUCE_RSAG_ZERO_COPY_HPP_

View File

@@ -84,6 +84,7 @@ class AlgorithmCtx {
std::shared_ptr<DeviceHandle<PortChannel>> portChannelDeviceHandles;
std::vector<std::shared_ptr<MemoryDevice2DeviceSemaphore>> memorySemaphores;
std::vector<std::shared_ptr<Host2DeviceSemaphore>> hostSemaphores;
std::shared_ptr<void*> remoteMemoryHandles;
std::unordered_map<std::string, std::shared_ptr<void>> extras;
};

View File

@@ -71,7 +71,15 @@ static std::shared_ptr<Algorithm> selectSingleNodeAllreduceBlackwell(
if (messageSize <= (1 << 21)) { // <= 2MB
return algoMap.at("default_allreduce_packet");
}
return nullptr;
if (config.inCaptureMode) {
// CUDA graph mode: setup new connections each time (zero-copy for graph)
return algoMap.at("default_allreduce_rsag_zero_copy");
}
// Non-graph mode: use non-zero-copy algorithms
if (messageSize <= (1 << 23)) { // <= 8MB
return algoMap.at("default_allreduce_rsag");
}
return algoMap.at("default_allreduce_rsag_pipeline");
}
// Symmetric memory path: can use cached memory handles
@@ -83,8 +91,7 @@ static std::shared_ptr<Algorithm> selectSingleNodeAllreduceBlackwell(
return algoMap.at("default_allreduce_nvls");
}
INFO(MSCCLPP_NCCL, "No suitable kernel for Blackwell architecture, fallback to nccl/rccl");
return nullptr;
return algoMap.at("default_allreduce_rsag_zero_copy");
}
std::shared_ptr<Algorithm> selectSingleNodeAllreduce(