mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-20 14:59:29 +00:00
Add new algos for GB200 (#747)
- Add new algos (allreduce_rsag, allreduce_rsag_pipeline and
allreduce_rsag_zero_copy) for GB200.
- Add IB stub for non-IB env
- Provides example for algorithm tunning with different nblocks/nthreads
Perf for allreduce_rsag
```
# out-of-place in-place
# size count type redop root time algbw busbw #wrong time algbw busbw #wrong
# (B) (elements) (us) (GB/s) (GB/s) (us) (GB/s) (GB/s)
1048576 262144 float sum -1 25.16 41.67 62.51 0 23.73 44.18 66.27 0
2097152 524288 float sum -1 26.06 80.47 120.71 0 25.31 82.86 124.29 0
4194304 1048576 float sum -1 31.09 134.93 202.39 0 30.75 136.39 204.58 0
8388608 2097152 float sum -1 45.52 184.29 276.43 0 45.13 185.87 278.80 0
16777216 4194304 float sum -1 75.73 221.53 332.30 0 75.51 222.18 333.27 0
33554432 8388608 float sum -1 137.25 244.48 366.72 0 137.22 244.54 366.81 0
67108864 16777216 float sum -1 271.34 247.32 370.99 0 270.86 247.76 371.65 0
134217728 33554432 float sum -1 534.25 251.22 376.84 0 534.43 251.14 376.71 0
# Out of bounds values : 0 OK
# Avg bus bandwidth : 264.454
#
# Collective test concluded: all_reduce_perf
```
perf for allreduce_rsag_pipeline
```
# out-of-place in-place
# size count type redop root time algbw busbw #wrong time algbw busbw #wrong
# (B) (elements) (us) (GB/s) (GB/s) (us) (GB/s) (GB/s)
1048576 262144 float sum -1 61.57 17.03 25.55 0 61.51 17.05 25.57 0
2097152 524288 float sum -1 61.31 34.20 51.31 0 61.23 34.25 51.38 0
4194304 1048576 float sum -1 61.62 68.06 102.10 0 61.84 67.83 101.74 0
8388608 2097152 float sum -1 61.97 135.37 203.06 0 61.89 135.53 203.30 0
16777216 4194304 float sum -1 63.15 265.65 398.48 0 62.89 266.76 400.15 0
33554432 8388608 float sum -1 100.63 333.46 500.19 0 99.76 336.34 504.51 0
67108864 16777216 float sum -1 180.04 372.75 559.13 0 179.75 373.34 560.01 0
134217728 33554432 float sum -1 339.60 395.23 592.84 0 338.16 396.91 595.36 0
# Out of bounds values : 0 OK
# Avg bus bandwidth : 304.665
#
# Collective test concluded: all_reduce_perf
```
perf for allreduce_rsag_zero_copy
```
# out-of-place in-place
# size count type redop root time algbw busbw #wrong time algbw busbw #wrong
# (B) (elements) (us) (GB/s) (GB/s) (us) (GB/s) (GB/s)
1048576 262144 float sum -1 14.99 69.93 104.90 0 14.44 72.61 108.92 0
2097152 524288 float sum -1 16.19 129.56 194.33 0 15.85 132.32 198.48 0
4194304 1048576 float sum -1 21.19 197.98 296.97 0 20.64 203.20 304.81 0
8388608 2097152 float sum -1 31.04 270.27 405.41 0 30.68 273.44 410.16 0
16777216 4194304 float sum -1 50.34 333.26 499.89 0 50.15 334.51 501.77 0
33554432 8388608 float sum -1 89.58 374.56 561.84 0 88.65 378.48 567.73 0
67108864 16777216 float sum -1 165.69 405.03 607.54 0 163.64 410.10 615.16 0
134217728 33554432 float sum -1 323.19 415.28 622.93 0 318.01 422.05 633.07 0
# Out of bounds values : 0 OK
# Avg bus bandwidth : 414.619
#
# Collective test concluded: all_reduce_perf
```
---------
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
Co-authored-by: chhwang <8018170+chhwang@users.noreply.github.com>
Co-authored-by: Qinghua Zhou <qinghuazhou@microsoft.com>
Co-authored-by: Caio Rocha <caiorocha@microsoft.com>
This commit is contained in:
12
.github/workflows/codeql-analysis.yml
vendored
12
.github/workflows/codeql-analysis.yml
vendored
@@ -51,7 +51,7 @@ jobs:
|
||||
df -h
|
||||
|
||||
- name: Initialize CodeQL
|
||||
uses: github/codeql-action/init@v3
|
||||
uses: github/codeql-action/init@v4
|
||||
with:
|
||||
languages: ${{ matrix.language }}
|
||||
|
||||
@@ -63,10 +63,10 @@ jobs:
|
||||
run: |
|
||||
rm -rf build && mkdir build && cd build
|
||||
cmake -DMSCCLPP_BYPASS_GPU_CHECK=ON -DMSCCLPP_USE_CUDA=ON ..
|
||||
make -j
|
||||
make -j4
|
||||
|
||||
- name: Perform CodeQL Analysis
|
||||
uses: github/codeql-action/analyze@v3
|
||||
uses: github/codeql-action/analyze@v4
|
||||
with:
|
||||
category: "/language:${{matrix.language}}/version:${{matrix.version}}"
|
||||
|
||||
@@ -96,7 +96,7 @@ jobs:
|
||||
df -h
|
||||
|
||||
- name: Initialize CodeQL
|
||||
uses: github/codeql-action/init@v3
|
||||
uses: github/codeql-action/init@v4
|
||||
with:
|
||||
languages: ${{ matrix.language }}
|
||||
|
||||
@@ -108,9 +108,9 @@ jobs:
|
||||
run: |
|
||||
rm -rf build && mkdir build && cd build
|
||||
CXX=/opt/rocm/bin/hipcc cmake -DMSCCLPP_BYPASS_GPU_CHECK=ON -DMSCCLPP_USE_ROCM=ON ..
|
||||
make -j
|
||||
make -j4
|
||||
|
||||
- name: Perform CodeQL Analysis
|
||||
uses: github/codeql-action/analyze@v3
|
||||
uses: github/codeql-action/analyze@v4
|
||||
with:
|
||||
category: "/language:${{matrix.language}}/version:${{matrix.version}}"
|
||||
|
||||
282
examples/torch-integration/customized_comm_with_tuning.py
Normal file
282
examples/torch-integration/customized_comm_with_tuning.py
Normal file
@@ -0,0 +1,282 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# MSCCLPP_MASTER_ADDR=<master_ip> MSCCLPP_MASTER_PORT=<port> torchrun --nnodes=1 --nproc_per_node=8 customized_comm_with_tuning.py
|
||||
|
||||
import os
|
||||
import torch
|
||||
import mscclpp.utils as mscclpp_utils
|
||||
import mscclpp
|
||||
import mscclpp.ext
|
||||
import netifaces as ni
|
||||
import ipaddress
|
||||
|
||||
|
||||
def load_algorithms(scratch_buffer: torch.tensor, rank: int) -> mscclpp.AlgorithmCollection:
|
||||
collection_builder = mscclpp.ext.AlgorithmCollectionBuilder()
|
||||
return collection_builder.build_default_algorithms(
|
||||
scratch_buffer=scratch_buffer.data_ptr(), scratch_buffer_size=scratch_buffer.nbytes, rank=rank
|
||||
)
|
||||
|
||||
|
||||
def interfaces_for_ip_netifaces(ip: str):
|
||||
target = ipaddress.ip_address(ip)
|
||||
for interface in ni.interfaces():
|
||||
addresses = ni.ifaddresses(interface)
|
||||
if ni.AF_INET in addresses:
|
||||
for link in addresses[ni.AF_INET]:
|
||||
if "addr" in link:
|
||||
addr = ipaddress.ip_address(link["addr"])
|
||||
if addr == target:
|
||||
return interface
|
||||
return None
|
||||
|
||||
|
||||
def to_mscclpp_reduce_op(op: torch.distributed.ReduceOp) -> mscclpp.ReduceOp:
|
||||
if op == torch.distributed.ReduceOp.SUM:
|
||||
return mscclpp.ReduceOp.SUM
|
||||
elif op == torch.distributed.ReduceOp.MIN:
|
||||
return mscclpp.ReduceOp.MIN
|
||||
else:
|
||||
raise ValueError(f"unsupported op: {op}")
|
||||
|
||||
|
||||
class CustomizedComm:
|
||||
def __init__(self, comm: mscclpp.CommGroup):
|
||||
self.comm = comm
|
||||
self.rank = comm.my_rank
|
||||
self.world_size = comm.nranks
|
||||
self.local_rank = comm.my_rank % comm.nranks_per_node
|
||||
self.n_ranks_per_node = comm.nranks_per_node
|
||||
dlpack = mscclpp.RawGpuBuffer(1 << 27).to_dlpack(data_type=str(torch.float16))
|
||||
self.scratch_buffer = torch.utils.dlpack.from_dlpack(dlpack)
|
||||
algorithms = load_algorithms(scratch_buffer=self.scratch_buffer, rank=self.rank)
|
||||
self._algorithm_nvls_packet = [
|
||||
algo
|
||||
for algo in algorithms
|
||||
if algo.collective == "allreduce" and algo.name == "default_allreduce_nvls_packet"
|
||||
][0]
|
||||
self._algorithm_rsag_zero_copy = [
|
||||
algo
|
||||
for algo in algorithms
|
||||
if algo.collective == "allreduce" and algo.name == "default_allreduce_rsag_zero_copy"
|
||||
][0]
|
||||
self._algorithm_packet = [
|
||||
algo for algo in algorithms if algo.collective == "allreduce" and algo.name == "default_allreduce_packet"
|
||||
][0]
|
||||
self._tune(n_warmup=5, n_graph_launches=10, n_ops_per_graph=100)
|
||||
|
||||
def _tune(self, n_warmup, n_graph_launches, n_ops_per_graph):
|
||||
sizes = [1 << i for i in range(10, 28)]
|
||||
# Pre-fill with defaults for barrier
|
||||
self.best_configs = {1024: (self._algorithm_nvls_packet, 0, 0)}
|
||||
|
||||
tune_tensor = torch.rand(1 << 27, dtype=torch.float16, device="cuda")
|
||||
candidates_nblocks = [4, 8, 16, 24, 32, 48, 64, 128]
|
||||
candidates_nthreads = [512, 768, 1024]
|
||||
|
||||
for size in sizes:
|
||||
algos = []
|
||||
if size <= 4 * 1024 * 1024:
|
||||
algos.append(self._algorithm_nvls_packet)
|
||||
algos.append(self._algorithm_packet)
|
||||
if size >= 512 * 1024:
|
||||
algos.append(self._algorithm_rsag_zero_copy)
|
||||
|
||||
best_time = float("inf")
|
||||
best_config = None
|
||||
|
||||
for algo in algos:
|
||||
for nb in candidates_nblocks:
|
||||
if algo.name == "default_allreduce_nvls_packet" and nb > 16:
|
||||
continue
|
||||
if algo.name == "default_allreduce_packet" and nb > 56:
|
||||
continue
|
||||
for nt in candidates_nthreads:
|
||||
if self._run_algo(algo, tune_tensor, size, nb, nt) != 0:
|
||||
continue
|
||||
|
||||
for _ in range(n_warmup):
|
||||
self._run_algo(algo, tune_tensor, size, nb, nt)
|
||||
self.barrier()
|
||||
|
||||
capture_stream = torch.cuda.Stream()
|
||||
capture_stream.wait_stream(torch.cuda.current_stream())
|
||||
|
||||
g = torch.cuda.CUDAGraph()
|
||||
# Warmup on capture stream
|
||||
with torch.cuda.stream(capture_stream):
|
||||
self._run_algo(algo, tune_tensor, size, nb, nt)
|
||||
capture_stream.synchronize()
|
||||
|
||||
with torch.cuda.graph(g, stream=capture_stream):
|
||||
for _ in range(n_ops_per_graph):
|
||||
self._run_algo(algo, tune_tensor, size, nb, nt)
|
||||
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
start_event.record(capture_stream)
|
||||
with torch.cuda.stream(capture_stream):
|
||||
for _ in range(n_graph_launches):
|
||||
g.replay()
|
||||
end_event.record(capture_stream)
|
||||
end_event.synchronize()
|
||||
|
||||
elapsed = start_event.elapsed_time(end_event)
|
||||
|
||||
# Synchronize timing results across all ranks to ensure consistent algorithm selection
|
||||
# replicate n times such due to algo limitations
|
||||
time_tensor = torch.full((self.world_size,), elapsed, dtype=torch.float64, device="cuda").to(
|
||||
dtype=torch.float32
|
||||
)
|
||||
torch.cuda.current_stream().wait_stream(capture_stream)
|
||||
# TODO: use all_reduce may cause problem if the time elapsed between different algos are too close.
|
||||
# May change to broadcast in the future if that becomes an issue.
|
||||
self.all_reduce(time_tensor, op=torch.distributed.ReduceOp.SUM)
|
||||
avg_time = time_tensor[self.rank].item() / self.world_size
|
||||
|
||||
if avg_time < best_time:
|
||||
best_time = avg_time
|
||||
best_config = (algo, nb, nt)
|
||||
|
||||
if best_config:
|
||||
self.best_configs[size] = best_config
|
||||
if self.rank == 0:
|
||||
print(
|
||||
f"Size {size}: Best Algo {best_config[0].name} nblocks {best_config[1]} nthreads {best_config[2]} Time {(best_time/(n_graph_launches * n_ops_per_graph))*1000:.2f} us"
|
||||
)
|
||||
# reset the algorithms after tuning
|
||||
torch.cuda.synchronize()
|
||||
for algo in algos:
|
||||
algo.reset()
|
||||
|
||||
def _run_algo(self, algo, tensor, size, nblocks, nthreads):
|
||||
return algo.execute(
|
||||
comm=self.comm.communicator,
|
||||
input_buffer=tensor.data_ptr(),
|
||||
output_buffer=tensor.data_ptr(),
|
||||
input_size=size,
|
||||
output_size=size,
|
||||
dtype=mscclpp_utils.torch_dtype_to_mscclpp_dtype(tensor.dtype),
|
||||
op=mscclpp.ReduceOp.SUM,
|
||||
stream=torch.cuda.current_stream().cuda_stream,
|
||||
nblocks=nblocks,
|
||||
nthreads_per_block=nthreads,
|
||||
)
|
||||
|
||||
def get_tuned_config(self, size):
|
||||
if size < 1024:
|
||||
target_size = 1024
|
||||
elif size > 256 * 1024 * 1024:
|
||||
target_size = 256 * 1024 * 1024
|
||||
else:
|
||||
target_size = 1 << (size - 1).bit_length()
|
||||
return self.best_configs.get(target_size)
|
||||
|
||||
def all_reduce(self, tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM, stream: torch.cuda.Stream = None):
|
||||
assert op == torch.distributed.ReduceOp.SUM
|
||||
config = self.get_tuned_config(tensor.nbytes)
|
||||
algo, nblocks, nthreads = config if config else (self._algorithm_nvls_packet, 0, 0)
|
||||
ret = algo.execute(
|
||||
comm=self.comm.communicator,
|
||||
input_buffer=tensor.data_ptr(),
|
||||
output_buffer=tensor.data_ptr(),
|
||||
input_size=tensor.nbytes,
|
||||
output_size=tensor.nbytes,
|
||||
dtype=mscclpp_utils.torch_dtype_to_mscclpp_dtype(tensor.dtype),
|
||||
op=to_mscclpp_reduce_op(op),
|
||||
stream=stream.cuda_stream if stream is not None else torch.cuda.current_stream().cuda_stream,
|
||||
nblocks=nblocks,
|
||||
nthreads_per_block=nthreads,
|
||||
)
|
||||
if ret != 0:
|
||||
print(f"Rank {self.rank}: Algo {algo.name} failed with error {ret}")
|
||||
|
||||
def barrier(self):
|
||||
tensor = torch.empty(self.world_size, dtype=torch.float, device=torch.device("cuda"))
|
||||
self.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM, stream=torch.cuda.current_stream())
|
||||
|
||||
def benchmark(self, n_warmup=10, n_graph_launches=10, n_iter_per_graph=100):
|
||||
low = 5 * 1024
|
||||
high = 80 * 1024 * 1024
|
||||
sizes = []
|
||||
curr = low
|
||||
while curr <= high:
|
||||
sizes.append(curr)
|
||||
curr *= 2
|
||||
|
||||
if self.rank == 0:
|
||||
print(f"{'Size (Bytes)':<20} {'Time (us)':<20} {'AlgoBW (GB/s)':<20}")
|
||||
|
||||
dtype = torch.float16
|
||||
capture_stream = torch.cuda.Stream()
|
||||
|
||||
for size in sizes:
|
||||
tensor = torch.rand(size // 2, dtype=dtype, device="cuda")
|
||||
capture_stream.wait_stream(torch.cuda.current_stream())
|
||||
# Capture Graph
|
||||
g = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(g, stream=capture_stream):
|
||||
for _ in range(n_iter_per_graph):
|
||||
self.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM)
|
||||
|
||||
# warmup: Execute the graph once to prime the driver
|
||||
with torch.cuda.stream(capture_stream):
|
||||
for _ in range(n_warmup):
|
||||
g.replay()
|
||||
self.barrier()
|
||||
capture_stream.synchronize()
|
||||
|
||||
# Benchmark
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
start_event.record(capture_stream)
|
||||
with torch.cuda.stream(capture_stream):
|
||||
for _ in range(n_graph_launches):
|
||||
g.replay()
|
||||
end_event.record(capture_stream)
|
||||
end_event.synchronize()
|
||||
|
||||
# Get elapsed time in milliseconds
|
||||
elapsed_ms = start_event.elapsed_time(end_event)
|
||||
avg_time_ms = elapsed_ms / (n_graph_launches * n_iter_per_graph)
|
||||
time_us = avg_time_ms * 1000
|
||||
|
||||
alg_bw = size / (avg_time_ms * 1e-3) if avg_time_ms > 0 else 0
|
||||
if self.rank == 0:
|
||||
print(f"{size:<20} {time_us:<20.2f} {alg_bw / 1e9:<20.2f}")
|
||||
|
||||
def destroy(self):
|
||||
self._algorithm_nvls_nonzero_copy = None
|
||||
self._algorithm_nvls_packet = None
|
||||
self.scratch_buffer = None
|
||||
self.comm = None
|
||||
|
||||
|
||||
def init_dist() -> CustomizedComm:
|
||||
rank = int(os.environ["RANK"])
|
||||
world = int(os.environ["WORLD_SIZE"])
|
||||
master_addr = os.environ["MSCCLPP_MASTER_ADDR"]
|
||||
master_port = os.environ["MSCCLPP_MASTER_PORT"]
|
||||
interface = interfaces_for_ip_netifaces(master_addr)
|
||||
if interface is None:
|
||||
raise ValueError(f"Cannot find network interface for IP address {master_addr}")
|
||||
interfaceIpPortTrio = f"{interface}:{master_addr}:{master_port}"
|
||||
mscclpp_group = mscclpp.CommGroup(interfaceIpPortTrio=interfaceIpPortTrio, rank=rank, size=world)
|
||||
return CustomizedComm(mscclpp_group)
|
||||
|
||||
|
||||
def main():
|
||||
local = int(os.environ["LOCAL_RANK"])
|
||||
torch.cuda.set_device(local)
|
||||
comm = init_dist()
|
||||
comm.benchmark(n_warmup=5, n_graph_launches=10, n_iter_per_graph=100)
|
||||
comm.barrier()
|
||||
torch.cuda.synchronize()
|
||||
comm.destroy()
|
||||
print(f"rank {local} All-reduce operation completed successfully.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -13,6 +13,9 @@
|
||||
#include "allreduce/allreduce_nvls_with_copy.hpp"
|
||||
#include "allreduce/allreduce_nvls_with_copy_2.hpp"
|
||||
#include "allreduce/allreduce_packet.hpp"
|
||||
#include "allreduce/allreduce_rsag.hpp"
|
||||
#include "allreduce/allreduce_rsag_pipeline.hpp"
|
||||
#include "allreduce/allreduce_rsag_zero_copy.hpp"
|
||||
#include "logger.hpp"
|
||||
|
||||
namespace mscclpp {
|
||||
@@ -82,6 +85,14 @@ AlgorithmCollection AlgorithmCollectionBuilder::buildDefaultNativeAlgorithms(uin
|
||||
collection.registerAlgorithm(allreduceNvls->collective(), allreduceNvls->name(), allreduceNvls);
|
||||
auto allreduceFullmesh = std::make_shared<AllreduceFullmesh>(scratchBuffer, scratchBufferSize)->build();
|
||||
collection.registerAlgorithm(allreduceFullmesh->collective(), allreduceFullmesh->name(), allreduceFullmesh);
|
||||
auto allreduceRsag = std::make_shared<AllreduceRsAg>(scratchBuffer, scratchBufferSize)->build();
|
||||
collection.registerAlgorithm(allreduceRsag->collective(), allreduceRsag->name(), allreduceRsag);
|
||||
auto allreduceRsagPipeline = std::make_shared<AllreduceRsAgPipeline>(scratchBuffer, scratchBufferSize)->build();
|
||||
collection.registerAlgorithm(allreduceRsagPipeline->collective(), allreduceRsagPipeline->name(),
|
||||
allreduceRsagPipeline);
|
||||
auto allreduceRsagZeroCopy = std::make_shared<AllreduceRsAgZeroCopy>()->build();
|
||||
collection.registerAlgorithm(allreduceRsagZeroCopy->collective(), allreduceRsagZeroCopy->name(),
|
||||
allreduceRsagZeroCopy);
|
||||
|
||||
auto allgatherFullmesh = std::make_shared<AllgatherFullmesh>(scratchBuffer, scratchBufferSize)->build();
|
||||
collection.registerAlgorithm(allgatherFullmesh->collective(), allgatherFullmesh->name(), allgatherFullmesh);
|
||||
|
||||
229
src/ext/collectives/allreduce/allreduce_rsag.cu
Normal file
229
src/ext/collectives/allreduce/allreduce_rsag.cu
Normal file
@@ -0,0 +1,229 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "allreduce/allreduce_rsag.hpp"
|
||||
#include "allreduce/common.hpp"
|
||||
#include "collective_utils.hpp"
|
||||
#include "logger.hpp"
|
||||
|
||||
namespace mscclpp {
|
||||
namespace collective {
|
||||
|
||||
// Allreduce using the Reduce-Scatter + All-Gather (RSAG) pattern.
|
||||
//
|
||||
// This algorithm performs allreduce in three phases over intra-node peers
|
||||
// connected via CudaIpc memory channels:
|
||||
//
|
||||
// 1. Scatter: Each rank copies its input data into a scratch buffer, then
|
||||
// signals peers and waits for all peers to do the same.
|
||||
//
|
||||
// 2. Reduce-Scatter: Each rank reduces its assigned chunk by reading the
|
||||
// corresponding chunks from all peers' scratch buffers (via remote memory
|
||||
// handles) and applying the reduction op. The reduced result is written
|
||||
// back to both the local result buffer and peers' scratch buffers.
|
||||
//
|
||||
// 3. All-Gather: After a second signal/wait barrier, each rank copies the
|
||||
// reduced chunks produced by other ranks from the scratch buffer into its
|
||||
// result buffer, completing the allreduce.
|
||||
//
|
||||
// Data is processed in int4-sized (16-byte) units for coalesced memory access,
|
||||
// with special handling for any remainder elements at the tail.
|
||||
template <ReduceOp OpType, typename T>
|
||||
__global__ void __launch_bounds__(1024, 1)
|
||||
allreduceRsAg(T* buff, T* scratch, T* resultBuff, DeviceHandle<BaseMemoryChannel>* memoryChannels,
|
||||
DeviceHandle<SwitchChannel>* switchChannels, void* remoteMemories, int rank, int nRanksPerNode,
|
||||
int worldSize, size_t nelems) {
|
||||
int blockId = blockIdx.x;
|
||||
uint32_t nPeers = nRanksPerNode - 1;
|
||||
|
||||
assert((uintptr_t)buff % sizeof(int4) == 0);
|
||||
assert((uintptr_t)resultBuff % sizeof(int4) == 0);
|
||||
|
||||
constexpr uint32_t nelemsPerInt4 = sizeof(int4) / sizeof(T);
|
||||
uint32_t alignedNelems = ((nelems + nRanksPerNode - 1) / nRanksPerNode + nelemsPerInt4 - 1) / nelemsPerInt4 *
|
||||
nelemsPerInt4 * nRanksPerNode;
|
||||
uint32_t nelemsPerRank = alignedNelems / nRanksPerNode;
|
||||
uint32_t nInt4PerRank = nelemsPerRank / nelemsPerInt4;
|
||||
uint32_t lastInt4Index = nelems / nelemsPerInt4;
|
||||
uint32_t remainder = nelems % nelemsPerInt4;
|
||||
|
||||
int4* scratch4 = reinterpret_cast<int4*>((char*)scratch);
|
||||
int4* resultBuff4 = reinterpret_cast<int4*>((char*)resultBuff);
|
||||
int4* buff4 = reinterpret_cast<int4*>((char*)buff);
|
||||
DeviceHandle<BaseMemoryChannel>* memoryChannelsLocal = memoryChannels + blockId * nPeers;
|
||||
|
||||
uint32_t nInt4PerBlock = nInt4PerRank / gridDim.x;
|
||||
uint32_t remainderForBlock = nInt4PerRank % gridDim.x;
|
||||
uint32_t offset4 = blockId * nInt4PerBlock;
|
||||
if (blockId == (int)(gridDim.x - 1)) {
|
||||
nInt4PerBlock += remainderForBlock;
|
||||
}
|
||||
if (nInt4PerBlock == 0) return;
|
||||
uint32_t nInt4ForCopy = nInt4PerBlock * nRanksPerNode;
|
||||
|
||||
for (uint32_t idx = threadIdx.x; idx < nInt4ForCopy; idx += blockDim.x) {
|
||||
int rankIdx = idx / nInt4PerBlock;
|
||||
uint32_t offsetIdx = rankIdx * nInt4PerRank + offset4 + (idx % nInt4PerBlock);
|
||||
if (offsetIdx > lastInt4Index) continue;
|
||||
if (offsetIdx == lastInt4Index && remainder != 0) {
|
||||
for (uint32_t i = 0; i < remainder; i++) {
|
||||
((T*)&scratch4[offsetIdx])[i] = ((T*)&buff4[offsetIdx])[i];
|
||||
}
|
||||
continue;
|
||||
}
|
||||
scratch4[offsetIdx] = buff4[offsetIdx];
|
||||
}
|
||||
__syncthreads();
|
||||
if (threadIdx.x < nPeers) {
|
||||
memoryChannelsLocal[threadIdx.x].signal();
|
||||
memoryChannelsLocal[threadIdx.x].wait();
|
||||
}
|
||||
__syncthreads();
|
||||
for (uint32_t idx = threadIdx.x; idx < nInt4PerBlock; idx += blockDim.x) {
|
||||
uint32_t offset = idx + offset4 + rank * nInt4PerRank;
|
||||
if (offset > lastInt4Index) continue;
|
||||
int4 tmp = scratch4[offset];
|
||||
for (uint32_t i = 0; i < nPeers; i++) {
|
||||
int rankIdx = (rank + i + 1) % nRanksPerNode;
|
||||
int peerIdx = rankIdx < rank ? rankIdx : rankIdx - 1;
|
||||
int4 data = mscclpp::read<int4>(((void**)remoteMemories)[peerIdx], offset);
|
||||
tmp = cal_vector<T, OpType>(data, tmp);
|
||||
}
|
||||
for (uint32_t i = 0; i < nPeers; i++) {
|
||||
int rankIdx = (rank + i + 1) % nRanksPerNode;
|
||||
int peerIdx = rankIdx < rank ? rankIdx : rankIdx - 1;
|
||||
mscclpp::write<int4>(((void**)remoteMemories)[peerIdx], offset, tmp);
|
||||
}
|
||||
if (offset == lastInt4Index && remainder != 0) {
|
||||
for (uint32_t i = 0; i < remainder; i++) {
|
||||
((T*)&resultBuff4[offset])[i] = ((T*)&tmp)[i];
|
||||
}
|
||||
continue;
|
||||
}
|
||||
resultBuff4[offset] = tmp;
|
||||
}
|
||||
__syncthreads();
|
||||
if (threadIdx.x < nPeers) {
|
||||
memoryChannelsLocal[threadIdx.x].signal();
|
||||
memoryChannelsLocal[threadIdx.x].wait();
|
||||
}
|
||||
__syncthreads();
|
||||
for (uint32_t idx = threadIdx.x; idx < nInt4ForCopy; idx += blockDim.x) {
|
||||
int rankIdx = idx / nInt4PerBlock;
|
||||
if (rankIdx == rank) continue;
|
||||
uint32_t offsetIdx = rankIdx * nInt4PerRank + offset4 + (idx % nInt4PerBlock);
|
||||
if (offsetIdx > lastInt4Index) continue;
|
||||
if (offsetIdx == lastInt4Index && remainder != 0) {
|
||||
for (uint32_t i = 0; i < remainder; i++) {
|
||||
((T*)&resultBuff4[offsetIdx])[i] = ((T*)&scratch4[offsetIdx])[i];
|
||||
}
|
||||
continue;
|
||||
}
|
||||
resultBuff4[offsetIdx] = scratch4[offsetIdx];
|
||||
}
|
||||
}
|
||||
|
||||
template <ReduceOp OpType, typename T>
|
||||
struct AllreduceRsAgAdapter {
|
||||
static cudaError_t call(const void* input, void* scratch, void* output, void* memoryChannels, void* remoteMemories,
|
||||
DeviceHandle<SwitchChannel>* switchChannel, DeviceHandle<SwitchChannel>*, size_t, size_t,
|
||||
size_t, int rank, int nRanksPerNode, int worldSize, size_t inputSize, cudaStream_t stream,
|
||||
void*, uint32_t, uint32_t, int nBlocks, int nThreadsPerBlock) {
|
||||
using ChannelType = DeviceHandle<BaseMemoryChannel>;
|
||||
size_t nelems = inputSize / sizeof(T);
|
||||
if (nBlocks == 0 || nThreadsPerBlock == 0) {
|
||||
nThreadsPerBlock = 1024;
|
||||
nBlocks = 64;
|
||||
}
|
||||
allreduceRsAg<OpType, T><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
|
||||
(T*)input, (T*)scratch, (T*)output, (ChannelType*)memoryChannels, switchChannel, remoteMemories, rank,
|
||||
nRanksPerNode, worldSize, nelems);
|
||||
return cudaGetLastError();
|
||||
}
|
||||
};
|
||||
|
||||
void AllreduceRsAg::initialize(std::shared_ptr<Communicator> comm) {
|
||||
this->conns_ = setupConnections(comm);
|
||||
nChannelsPerConnection_ = 64;
|
||||
comm_ = comm;
|
||||
// setup semaphores
|
||||
this->scratchSemaphores_ = setupMemorySemaphores(comm, this->conns_, nChannelsPerConnection_);
|
||||
RegisteredMemory localMemory = comm->registerMemory(scratchBuffer_, scratchBufferSize_, Transport::CudaIpc);
|
||||
this->remoteScratchMemories_ = setupRemoteMemories(comm, comm->bootstrap()->getRank(), localMemory);
|
||||
localScratchMemory_ = std::move(localMemory);
|
||||
|
||||
this->baseChannels_ = setupBaseMemoryChannels(this->conns_, this->scratchSemaphores_, nChannelsPerConnection_);
|
||||
this->baseMemoryChannelHandles_ = setupBaseMemoryChannelDeviceHandles(baseChannels_);
|
||||
std::vector<void*> remoteMemoryHandles;
|
||||
for (const auto& remoteMemory : this->remoteScratchMemories_) {
|
||||
remoteMemoryHandles.push_back(remoteMemory.data());
|
||||
}
|
||||
this->remoteMemoryHandles_ = detail::gpuCallocShared<void*>(remoteMemoryHandles.size());
|
||||
gpuMemcpy(this->remoteMemoryHandles_.get(), remoteMemoryHandles.data(), remoteMemoryHandles.size(),
|
||||
cudaMemcpyHostToDevice);
|
||||
}
|
||||
|
||||
CommResult AllreduceRsAg::allreduceKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output,
|
||||
size_t inputSize, DataType dtype, ReduceOp op, cudaStream_t stream,
|
||||
int nBlocks, int nThreadsPerBlock,
|
||||
const std::unordered_map<std::string, uintptr_t>&) {
|
||||
auto algoCtx = std::static_pointer_cast<AlgorithmCtx>(ctx);
|
||||
AllreduceFunc allreduce = dispatch<AllreduceRsAgAdapter>(op, dtype);
|
||||
if (!allreduce) {
|
||||
WARN(ALGO, "Unsupported operation or data type for allreduce: op=", static_cast<int>(op),
|
||||
", dtype=", static_cast<int>(dtype));
|
||||
return CommResult::CommInvalidArgument;
|
||||
}
|
||||
if (inputSize > this->scratchBufferSize_) {
|
||||
WARN(ALGO, "Input size ", inputSize, " exceeds scratch buffer size ", this->scratchBufferSize_);
|
||||
return CommResult::CommInvalidArgument;
|
||||
}
|
||||
std::pair<int, int> numBlocksAndThreads = {nBlocks, nThreadsPerBlock};
|
||||
cudaError_t error = allreduce(input, this->scratchBuffer_, output, this->baseMemoryChannelHandles_.get(),
|
||||
this->remoteMemoryHandles_.get(), nullptr, nullptr, 0, 0, 0, algoCtx->rank,
|
||||
algoCtx->nRanksPerNode, algoCtx->workSize, inputSize, stream, nullptr, 0, 0,
|
||||
numBlocksAndThreads.first, numBlocksAndThreads.second);
|
||||
if (error != cudaSuccess) {
|
||||
WARN(ALGO, "Allreduce kernel launch failed with error: ", cudaGetErrorString(error));
|
||||
return CommResult::CommUnhandledCudaError;
|
||||
}
|
||||
return CommResult::CommSuccess;
|
||||
}
|
||||
|
||||
AlgorithmCtxKey AllreduceRsAg::generateAllreduceContextKey(const void*, void*, size_t, DataType, bool) {
|
||||
return AlgorithmCtxKey{nullptr, nullptr, 0, 0, 0};
|
||||
}
|
||||
|
||||
std::shared_ptr<void> AllreduceRsAg::initAllreduceContext(std::shared_ptr<Communicator> comm, const void*, void*,
|
||||
size_t, DataType) {
|
||||
auto ctx = std::make_shared<AlgorithmCtx>();
|
||||
ctx->rank = comm->bootstrap()->getRank();
|
||||
ctx->workSize = comm->bootstrap()->getNranks();
|
||||
ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode();
|
||||
|
||||
ctx->memorySemaphores = this->scratchSemaphores_;
|
||||
ctx->registeredMemories = this->remoteScratchMemories_;
|
||||
return ctx;
|
||||
}
|
||||
|
||||
std::shared_ptr<Algorithm> AllreduceRsAg::build() {
|
||||
auto self = std::make_shared<AllreduceRsAg>((uintptr_t)scratchBuffer_, scratchBufferSize_);
|
||||
return std::make_shared<NativeAlgorithm>(
|
||||
"default_allreduce_rsag", "allreduce",
|
||||
[self](std::shared_ptr<mscclpp::Communicator> comm) { self->initialize(comm); },
|
||||
[self](const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
|
||||
[[maybe_unused]] size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks,
|
||||
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras) -> CommResult {
|
||||
return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock,
|
||||
extras);
|
||||
},
|
||||
[self](std::shared_ptr<Communicator> comm, const void* input, void* output, size_t inputSize,
|
||||
[[maybe_unused]] size_t outputSize,
|
||||
DataType dtype) { return self->initAllreduceContext(comm, input, output, inputSize, dtype); },
|
||||
[self](const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, DataType dtype,
|
||||
bool symmetricMemory) {
|
||||
return self->generateAllreduceContextKey(input, output, inputSize, dtype, symmetricMemory);
|
||||
});
|
||||
}
|
||||
} // namespace collective
|
||||
} // namespace mscclpp
|
||||
336
src/ext/collectives/allreduce/allreduce_rsag_pipeline.cu
Normal file
336
src/ext/collectives/allreduce/allreduce_rsag_pipeline.cu
Normal file
@@ -0,0 +1,336 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "allreduce/allreduce_rsag_pipeline.hpp"
|
||||
#include "allreduce/common.hpp"
|
||||
#include "collective_utils.hpp"
|
||||
#include "logger.hpp"
|
||||
|
||||
namespace mscclpp {
|
||||
namespace collective {
|
||||
constexpr int MAX_NBLOCKS_FOR_PUT = 32;
|
||||
constexpr int MAX_NBLOCKS_FOR_RECV = 32;
|
||||
constexpr int MAX_NBLOCKS_FOR_REDUCE = 64;
|
||||
constexpr int REDUCE_COPY_RATIO = 2;
|
||||
__device__ DeviceSemaphore semaphoreForSend[MAX_NBLOCKS_FOR_REDUCE];
|
||||
__device__ DeviceSemaphore semaphoreForRecv[MAX_NBLOCKS_FOR_REDUCE];
|
||||
__device__ DeviceSemaphore semaphoreForReduce[MAX_NBLOCKS_FOR_REDUCE];
|
||||
|
||||
// TODO: move it to a common header file
|
||||
template <typename T>
|
||||
__device__ __forceinline__ int4 loadVec(const T* buff, size_t i, size_t nelems) {
|
||||
constexpr size_t ElemsPerInt4 = sizeof(int4) / sizeof(T);
|
||||
size_t offset = i * ElemsPerInt4;
|
||||
if (offset + ElemsPerInt4 <= nelems) {
|
||||
return reinterpret_cast<const int4*>(buff)[i];
|
||||
} else {
|
||||
union {
|
||||
int4 i;
|
||||
T t[ElemsPerInt4];
|
||||
} vec;
|
||||
vec.i = make_int4(0, 0, 0, 0);
|
||||
for (size_t j = 0; j < ElemsPerInt4 && offset + j < nelems; ++j) {
|
||||
vec.t[j] = buff[offset + j];
|
||||
}
|
||||
return vec.i;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ void storeVec(T* buff, size_t i, int4 val, size_t nelems) {
|
||||
constexpr size_t ElemsPerInt4 = sizeof(int4) / sizeof(T);
|
||||
size_t offset = i * ElemsPerInt4;
|
||||
if (offset + ElemsPerInt4 <= nelems) {
|
||||
reinterpret_cast<int4*>(buff)[i] = val;
|
||||
} else {
|
||||
union {
|
||||
int4 i;
|
||||
T t[ElemsPerInt4];
|
||||
} vec;
|
||||
vec.i = val;
|
||||
for (size_t j = 0; j < ElemsPerInt4 && offset + j < nelems; ++j) {
|
||||
buff[offset + j] = vec.t[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Pipelined Reduce-Scatter + All-Gather (RSAG) allreduce.
|
||||
//
|
||||
// This is a pipelined variant of the basic RSAG allreduce that overlaps
|
||||
// communication and computation by splitting the data into chunks processed
|
||||
// across multiple iterations. Three groups of thread blocks run concurrently
|
||||
// with different roles, synchronized via device semaphores:
|
||||
//
|
||||
// PUT blocks — Read local input chunks and write them into peers' scratch
|
||||
// buffers via remote memory handles (CudaIpc).
|
||||
//
|
||||
// REDUCE blocks — After a signal/wait barrier confirming PUT completion,
|
||||
// reduce the local chunk with data received from all peers
|
||||
// in the scratch buffer. Write the reduced result to both
|
||||
// the local output and peers' scratch (for the AG phase).
|
||||
//
|
||||
// RECV blocks — After a signal/wait barrier confirming REDUCE completion,
|
||||
// copy other ranks' reduced chunks from scratch into the
|
||||
// local result buffer, completing the all-gather.
|
||||
//
|
||||
// Pipelining is achieved by using a circular scratch buffer (pipelineDepth
|
||||
// stages). PUT blocks wait on a semaphore before reusing a scratch slot,
|
||||
// allowing the next iteration's PUT to overlap with the current iteration's
|
||||
// REDUCE and RECV. Each REDUCE block handles a subset of the PUT block's
|
||||
// data (controlled by REDUCE_COPY_RATIO), enabling finer-grained overlap.
|
||||
//
|
||||
// Data is processed in int4-sized (16-byte) units with vectorized load/store
|
||||
// helpers that handle tail elements.
|
||||
|
||||
// Kernel implementing the pipelined RSAG allreduce described above.
//
// Grid layout: blocks [0, nblocksForPut) are PUT blocks, the next
// nblocksForReduce blocks are REDUCE blocks, and the final nblocksForRecv are
// RECV blocks. The groups hand work to each other through the device-global
// semaphores (semaphoreForSend/Reduce/Recv) declared at the top of this file.
template <ReduceOp OpType, typename T>
__global__ void __launch_bounds__(1024, 1)
    allreduceRsAgPipeline(T* buff, T* scratch, T* resultBuff, DeviceHandle<BaseMemoryChannel>* memoryChannels,
                          DeviceHandle<SwitchChannel>* switchChannels, void* remoteMemories, int rank,
                          int nRanksPerNode, int worldSize, size_t nelems, size_t scratchSize, uint32_t nblocksForPut,
                          uint32_t nblocksForReduce, uint32_t nblocksForRecv) {
  uint32_t bid = blockIdx.x;
  constexpr uint32_t nStepsPerIter = 4;
  // All offsets below are in int4 (16-byte) units; loadVec/storeVec handle
  // the tail elements of the buffer.
  uint32_t nInt4 = (nelems * sizeof(T) + sizeof(int4) - 1) / sizeof(int4);
  uint32_t nInt4PerIter = nblocksForReduce * blockDim.x * nStepsPerIter;
  const uint32_t chunkSize = nInt4PerIter * worldSize;
  uint32_t nIters = (nInt4 + chunkSize - 1) / chunkSize;
  uint32_t nPeers = nRanksPerNode - 1;
  int4* scratch4 = reinterpret_cast<int4*>((char*)scratch);
  const uint32_t scratchIterStride = 2 * chunkSize;  // one for AS, one for AG
  // Number of iterations that can be in flight before a scratch slot is reused.
  const uint32_t pipelineDepth = scratchSize / sizeof(int4) / scratchIterStride;
  assert(pipelineDepth >= 1);

  if (bid < nblocksForPut) {
    // ---- PUT blocks: copy local input chunks into every peer's scratch. ----
    if (threadIdx.x == 0) {
      // Prime the send semaphore so the first pipelineDepth iterations can
      // start without waiting on RECV.
      semaphoreForSend[bid].set(pipelineDepth);
    }
    for (uint32_t iter = 0; iter < nIters; iter++) {
      if (threadIdx.x == 0) {
        // Wait until RECV has recycled the scratch slot for this iteration.
        semaphoreForSend[bid].acquire();
      }
      __syncthreads();
      uint32_t threadIdInPut = bid * blockDim.x + threadIdx.x;
      for (uint32_t peer = 0; peer < nPeers; peer++) {
        int remoteRankId = (rank + peer + 1) % nRanksPerNode;
        int peerId = remoteRankId < rank ? remoteRankId : remoteRankId - 1;
        // Read chunk[remoteRankId] from local buff, write to peer's scratch[rank] (sender's slot)
        uint32_t srcOffset = iter * chunkSize + remoteRankId * nInt4PerIter;
        uint32_t dstOffset = (iter % pipelineDepth) * scratchIterStride + rank * nInt4PerIter;
        int4 tmp[nStepsPerIter * REDUCE_COPY_RATIO];
#pragma unroll
        for (uint32_t step = 0; step < nStepsPerIter * REDUCE_COPY_RATIO; step++) {
          uint32_t offset = srcOffset + threadIdInPut + step * blockDim.x * nblocksForPut;
          tmp[step] = loadVec(buff, offset, nelems);
        }
#pragma unroll
        for (uint32_t step = 0; step < nStepsPerIter * REDUCE_COPY_RATIO; step++) {
          uint32_t offset = dstOffset + threadIdInPut + step * blockDim.x * nblocksForPut;
          mscclpp::write<int4>(((void**)remoteMemories)[peerId], offset, tmp[step]);
        }
      }
      __syncthreads();
      if (threadIdx.x < REDUCE_COPY_RATIO) {
        // Wake the REDUCE_COPY_RATIO reduce blocks that consume this block's data.
        semaphoreForReduce[bid * REDUCE_COPY_RATIO + threadIdx.x].release();
      }
    }
  } else if (bid < nblocksForPut + nblocksForReduce) {
    // ---- REDUCE blocks: combine local + peer data, broadcast the result. ----
    uint32_t bidInReduce = bid - nblocksForPut;
    DeviceHandle<BaseMemoryChannel>* localMemoryChannels = memoryChannels + bidInReduce * nPeers;
    // Map REDUCE blocks to PUT blocks: REDUCE blocks 0,1 handle PUT block 0's data
    uint32_t putBlockId = bidInReduce / REDUCE_COPY_RATIO;
    uint32_t subBlockId = bidInReduce % REDUCE_COPY_RATIO;
    for (uint32_t iter = 0; iter < nIters; iter++) {
      if (threadIdx.x == 0) {
        // Wait for the local PUT block to finish writing this iteration.
        semaphoreForReduce[bidInReduce].acquire();
      }
      uint32_t baseOffset = (iter % pipelineDepth) * scratchIterStride;
      uint32_t baseSrcOffset = iter * chunkSize;

      // Use same thread mapping as PUT: putBlockId * blockDim.x + threadIdx.x
      uint32_t threadIdInPut = putBlockId * blockDim.x + threadIdx.x;
      __syncthreads();
      if (threadIdx.x < nPeers) {
        // Cross-rank barrier: every rank's PUT for this iteration has landed.
        localMemoryChannels[threadIdx.x].signal();
        localMemoryChannels[threadIdx.x].wait();
      }
      __syncthreads();
#pragma unroll nStepsPerIter
      for (uint32_t step = 0; step < nStepsPerIter; step++) {
        // Map to PUT's step pattern: each REDUCE block handles nStepsPerIter steps
        // subBlockId determines which subset of the REDUCE_COPY_RATIO * nStepsPerIter steps
        uint32_t putStep = subBlockId * nStepsPerIter + step;
        uint32_t myChunkOffset =
            baseSrcOffset + rank * nInt4PerIter + threadIdInPut + putStep * blockDim.x * nblocksForPut;
        int4 tmp = loadVec(buff, myChunkOffset, nelems);
        // Add data from each peer's slot in scratch (peer sent their chunk[rank] to our scratch[peer])
        for (uint32_t peer = 0; peer < nPeers; peer++) {
          int remoteRankId = (rank + peer + 1) % nRanksPerNode;
          uint32_t peerSlotOffset =
              baseOffset + remoteRankId * nInt4PerIter + threadIdInPut + putStep * blockDim.x * nblocksForPut;
          int4 data = scratch4[peerSlotOffset];
          tmp = cal_vector<T, OpType>(data, tmp);
        }
        storeVec(resultBuff, myChunkOffset, tmp, nelems);
        // Broadcast reduced result to all peers' scratch at SCATTER_AG_OFFSET + rank * nInt4PerIter
        uint32_t dstOffset =
            baseOffset + chunkSize + rank * nInt4PerIter + threadIdInPut + putStep * blockDim.x * nblocksForPut;
        for (uint32_t i = 0; i < nPeers; i++) {
          int peerIdx = (rank + i + 1) % nRanksPerNode;
          int index = peerIdx < rank ? peerIdx : peerIdx - 1;
          mscclpp::write<int4>(((void**)remoteMemories)[index], dstOffset, tmp);
        }
      }
      __syncthreads();
      if (threadIdx.x == 0) {
        // Hand the reduced chunk over to the RECV stage.
        semaphoreForRecv[bidInReduce].release();
      }
    }
  } else if (bid < nblocksForPut + nblocksForReduce + nblocksForRecv) {
    // ---- RECV blocks: all-gather reduced chunks, then recycle scratch. ----
    uint32_t bidInRecv = bid - nblocksForPut - nblocksForReduce;
    // RECV blocks use the channel range after the REDUCE blocks' channels.
    DeviceHandle<BaseMemoryChannel>* localMemoryChannels = memoryChannels + (nblocksForReduce + bidInRecv) * nPeers;
    for (uint32_t iter = 0; iter < nIters; iter++) {
      if (threadIdx.x < REDUCE_COPY_RATIO) {
        // Wait for every REDUCE block feeding this RECV block.
        semaphoreForRecv[bidInRecv * REDUCE_COPY_RATIO + threadIdx.x].acquire();
      }
      uint32_t baseOffset = scratchIterStride * (iter % pipelineDepth);
      uint32_t baseDstOffset = chunkSize * iter;
      int threadIdInRecv = bidInRecv * blockDim.x + threadIdx.x;
      __syncthreads();
      if (threadIdx.x < nPeers) {
        // Cross-rank barrier: peers' reduced chunks are now in our scratch.
        localMemoryChannels[threadIdx.x].signal();
        localMemoryChannels[threadIdx.x].wait();
      }
      __syncthreads();
      // Copy other ranks' reduced chunks from scratch to result
      for (uint32_t peer = 0; peer < nPeers; peer++) {
        int remoteRankId = (rank + peer + 1) % nRanksPerNode;
        for (uint32_t step = 0; step < nStepsPerIter * REDUCE_COPY_RATIO; step++) {
          uint32_t offset = baseOffset + chunkSize + remoteRankId * nInt4PerIter + threadIdInRecv +
                            step * blockDim.x * nblocksForRecv;
          uint32_t dstOffset =
              baseDstOffset + remoteRankId * nInt4PerIter + threadIdInRecv + step * blockDim.x * nblocksForRecv;
          storeVec(resultBuff, dstOffset, scratch4[offset], nelems);
        }
      }
      __syncthreads();
      if (threadIdx.x == 0) {
        // Return the scratch slot to the matching PUT block.
        semaphoreForSend[bidInRecv].release();
      }
    }
  }
}
|
||||
|
||||
template <ReduceOp OpType, typename T>
|
||||
struct AllreduceRsAgPipelineAdapter {
|
||||
static cudaError_t call(const void* input, void* scratch, void* output, void* memoryChannels, void* remoteMemories,
|
||||
DeviceHandle<SwitchChannel>* switchChannel, DeviceHandle<SwitchChannel>*, size_t, size_t,
|
||||
size_t scratchSize, int rank, int nRanksPerNode, int worldSize, size_t inputSize,
|
||||
cudaStream_t stream, void*, uint32_t, uint32_t, int nBlocks, int nThreadsPerBlock) {
|
||||
using ChannelType = DeviceHandle<BaseMemoryChannel>;
|
||||
size_t nelems = inputSize / sizeof(T);
|
||||
uint32_t nblocksForPut = MAX_NBLOCKS_FOR_PUT;
|
||||
uint32_t nblocksForReduce = MAX_NBLOCKS_FOR_REDUCE;
|
||||
uint32_t nblocksForRecv = MAX_NBLOCKS_FOR_RECV;
|
||||
int maxNblocks = nblocksForPut + nblocksForReduce + nblocksForRecv;
|
||||
if (nBlocks == 0 || nThreadsPerBlock == 0) {
|
||||
nThreadsPerBlock = 1024;
|
||||
nBlocks = maxNblocks;
|
||||
} else {
|
||||
nBlocks = nBlocks / (REDUCE_COPY_RATIO + 2) * (REDUCE_COPY_RATIO + 2);
|
||||
if (nBlocks > maxNblocks) {
|
||||
WARN(ALGO, "The number of blocks is too large for the allreduce pipeline algorithm, reducing it to ",
|
||||
maxNblocks);
|
||||
nBlocks = maxNblocks;
|
||||
}
|
||||
nblocksForPut = nBlocks / (REDUCE_COPY_RATIO + 2);
|
||||
nblocksForReduce = nblocksForPut * REDUCE_COPY_RATIO;
|
||||
nblocksForRecv = nblocksForPut;
|
||||
}
|
||||
allreduceRsAgPipeline<OpType, T><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
|
||||
(T*)input, (T*)scratch, (T*)output, (ChannelType*)memoryChannels, switchChannel, remoteMemories, rank,
|
||||
nRanksPerNode, worldSize, nelems, scratchSize, nblocksForPut, nblocksForReduce, nblocksForRecv);
|
||||
return cudaGetLastError();
|
||||
}
|
||||
};
|
||||
|
||||
// One-time setup: connections, semaphores, scratch registration/exchange, and
// the device-side channel handles plus remote-pointer table used by the kernel.
void AllreduceRsAgPipeline::initialize(std::shared_ptr<Communicator> comm) {
  this->conns_ = setupConnections(comm);
  // One channel per REDUCE block plus one per RECV block: the kernel indexes
  // channels by block id within those two groups.
  nChannelsPerConnection_ = MAX_NBLOCKS_FOR_REDUCE + MAX_NBLOCKS_FOR_RECV;
  comm_ = comm;
  // setup semaphores
  this->scratchSemaphores_ = setupMemorySemaphores(comm, this->conns_, nChannelsPerConnection_);
  RegisteredMemory localMemory = comm->registerMemory(scratchBuffer_, scratchBufferSize_, Transport::CudaIpc);
  this->remoteScratchMemories_ = setupRemoteMemories(comm, comm->bootstrap()->getRank(), localMemory);
  // Keep the local registration alive for as long as peers may access it.
  localScratchMemory_ = std::move(localMemory);

  this->baseChannels_ = setupBaseMemoryChannels(this->conns_, this->scratchSemaphores_, nChannelsPerConnection_);
  this->baseMemoryChannelHandles_ = setupBaseMemoryChannelDeviceHandles(baseChannels_);
  // Flatten peers' scratch pointers into a device-visible array for the kernel.
  std::vector<void*> remoteMemoryHandles;
  for (const auto& remoteMemory : this->remoteScratchMemories_) {
    remoteMemoryHandles.push_back(remoteMemory.data());
  }
  this->remoteMemoryHandles_ = detail::gpuCallocShared<void*>(remoteMemoryHandles.size());
  // NOTE(review): third argument passes the element count — assumes gpuMemcpy
  // is typed on void* and scales by sizeof internally; confirm its signature.
  gpuMemcpy(this->remoteMemoryHandles_.get(), remoteMemoryHandles.data(), remoteMemoryHandles.size(),
            cudaMemcpyHostToDevice);
}
|
||||
|
||||
// Dispatches the op/dtype-specialized pipeline launcher and converts CUDA
// launch failures into CommResult codes.
//
// Returns CommInvalidArgument for unsupported (op, dtype) combinations and
// CommUnhandledCudaError if the kernel launch itself fails.
CommResult AllreduceRsAgPipeline::allreduceKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output,
                                                      size_t inputSize, DataType dtype, ReduceOp op,
                                                      cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
                                                      const std::unordered_map<std::string, uintptr_t>&) {
  auto algoCtx = std::static_pointer_cast<AlgorithmCtx>(ctx);
  AllreduceFunc allreduce = dispatch<AllreduceRsAgPipelineAdapter>(op, dtype);
  if (!allreduce) {
    WARN(ALGO, "Unsupported operation or data type for allreduce: op=", static_cast<int>(op),
         ", dtype=", static_cast<int>(dtype));
    return CommResult::CommInvalidArgument;
  }
  // Pass the launch shape straight through (no intermediate pair needed); the
  // adapter substitutes defaults when nBlocks or nThreadsPerBlock is zero.
  cudaError_t error = allreduce(input, this->scratchBuffer_, output, this->baseMemoryChannelHandles_.get(),
                                this->remoteMemoryHandles_.get(), nullptr, nullptr, 0, 0, this->scratchBufferSize_,
                                algoCtx->rank, algoCtx->nRanksPerNode, algoCtx->workSize, inputSize, stream, nullptr, 0,
                                0, nBlocks, nThreadsPerBlock);
  if (error != cudaSuccess) {
    WARN(ALGO, "Allreduce kernel launch failed with error: ", cudaGetErrorString(error));
    return CommResult::CommUnhandledCudaError;
  }
  return CommResult::CommSuccess;
}
|
||||
|
||||
// The pipeline algorithm stages all data through the pre-registered scratch
// buffer, so the context does not depend on the user's buffers: every call
// shares one cached context (constant all-zero key).
AlgorithmCtxKey AllreduceRsAgPipeline::generateAllreduceContextKey(const void*, void*, size_t, DataType, bool) {
  return AlgorithmCtxKey{nullptr, nullptr, 0, 0, 0};
}
|
||||
|
||||
// Builds the (single, shared) algorithm context: rank/topology info plus
// references to the semaphores and remote scratch registrations created in
// initialize(), keeping them alive for the context's lifetime.
std::shared_ptr<void> AllreduceRsAgPipeline::initAllreduceContext(std::shared_ptr<Communicator> comm, const void*,
                                                                  void*, size_t, DataType) {
  auto ctx = std::make_shared<AlgorithmCtx>();
  ctx->rank = comm->bootstrap()->getRank();
  ctx->workSize = comm->bootstrap()->getNranks();  // total rank count
  ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode();

  // Share state owned by the builder so the context holds it alive.
  ctx->memorySemaphores = this->scratchSemaphores_;
  ctx->registeredMemories = this->remoteScratchMemories_;
  return ctx;
}
|
||||
|
||||
// Wires the pipelined RSAG allreduce into a NativeAlgorithm. The shared_ptr
// `self` is captured by value in every callback so the builder's state
// (channels, semaphores, registered memories) outlives this function.
std::shared_ptr<Algorithm> AllreduceRsAgPipeline::build() {
  auto self = std::make_shared<AllreduceRsAgPipeline>((uintptr_t)scratchBuffer_, scratchBufferSize_);
  return std::make_shared<NativeAlgorithm>(
      "default_allreduce_rsag_pipeline", "allreduce",
      // Initialization callback: set up connections/semaphores/channels.
      [self](std::shared_ptr<mscclpp::Communicator> comm) { self->initialize(comm); },
      // Kernel callback: launch the allreduce for one call.
      [self](const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
             [[maybe_unused]] size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks,
             int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras) -> CommResult {
        return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock,
                                         extras);
      },
      // Context factory callback.
      [self](std::shared_ptr<Communicator> comm, const void* input, void* output, size_t inputSize,
             [[maybe_unused]] size_t outputSize,
             DataType dtype) { return self->initAllreduceContext(comm, input, output, inputSize, dtype); },
      // Context cache-key callback.
      [self](const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, DataType dtype,
             bool symmetricMemory) {
        return self->generateAllreduceContextKey(input, output, inputSize, dtype, symmetricMemory);
      });
}
|
||||
} // namespace collective
|
||||
} // namespace mscclpp
|
||||
236
src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu
Normal file
236
src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu
Normal file
@@ -0,0 +1,236 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "allreduce/allreduce_rsag_zero_copy.hpp"
|
||||
#include "allreduce/common.hpp"
|
||||
#include "collective_utils.hpp"
|
||||
#include "logger.hpp"
|
||||
|
||||
namespace mscclpp {
|
||||
namespace collective {
|
||||
|
||||
// Device-wide barrier used by allreduceRsAgZeroCopy to make all blocks' output
// writes visible before the final cross-rank signal.
__device__ mscclpp::DeviceSyncer globalSyncer;
|
||||
// Zero-copy Reduce-Scatter + All-Gather (RSAG) allreduce.
|
||||
//
|
||||
// Unlike the standard RSAG which copies input into a scratch buffer first,
|
||||
// this variant reads directly from peers' input buffers and writes reduced
|
||||
// results directly to peers' output buffers — eliminating the need for a
|
||||
// separate scratch buffer and reducing memory traffic.
|
||||
//
|
||||
// The algorithm runs in a single kernel with the following steps:
|
||||
//
|
||||
// 1. Barrier: Signal and wait on all peers to ensure input buffers are ready.
|
||||
//
|
||||
// 2. Reduce-Scatter: Each rank reads its assigned chunk from every peer's
|
||||
// input buffer (via CudaIpc remote memory handles), reduces all values
|
||||
// locally, then writes the reduced result to its own output buffer AND
|
||||
// directly to every peer's output buffer at the same offset.
|
||||
//
|
||||
// 3. Global sync + Barrier: A device-wide sync ensures all writes complete,
|
||||
// followed by a final signal/wait to guarantee all peers have finished
|
||||
// writing, making the full output buffer valid on every rank.
|
||||
//
|
||||
// This approach requires registering both input and output buffers as remote
|
||||
// memories (2 * nPeers handles), but avoids scratch buffer allocation and
|
||||
// the extra copy steps of the standard RSAG. The NRanksPerNode template
|
||||
// parameter enables compile-time unrolling of peer loops (supports 4 or 8).
|
||||
|
||||
template <int NRanksPerNode, ReduceOp OpType, typename T>
|
||||
__global__ void __launch_bounds__(1024, 1)
|
||||
allreduceRsAgZeroCopy(T* buff, T* scratch, T* resultBuff, DeviceHandle<BaseMemoryChannel>* memoryChannels,
|
||||
DeviceHandle<SwitchChannel>* switchChannels, void* remoteMemories, int rank, int worldSize,
|
||||
size_t nelems) {
|
||||
int blockId = blockIdx.x;
|
||||
|
||||
assert((uintptr_t)buff % sizeof(int4) == 0);
|
||||
assert((uintptr_t)resultBuff % sizeof(int4) == 0);
|
||||
|
||||
constexpr int NPeers = NRanksPerNode - 1;
|
||||
constexpr uint32_t nelemsPerInt4 = sizeof(int4) / sizeof(T);
|
||||
const uint32_t outputRemoteBufferOffset = NRanksPerNode - 1;
|
||||
uint32_t alignedNelems = ((nelems + NRanksPerNode - 1) / NRanksPerNode + nelemsPerInt4 - 1) / nelemsPerInt4 *
|
||||
nelemsPerInt4 * NRanksPerNode;
|
||||
uint32_t nelemsPerRank = alignedNelems / NRanksPerNode;
|
||||
uint32_t nInt4PerRank = nelemsPerRank / nelemsPerInt4;
|
||||
uint32_t nInt4Total = (nelems + nelemsPerInt4 - 1) / nelemsPerInt4;
|
||||
|
||||
int4* resultBuff4 = reinterpret_cast<int4*>((char*)resultBuff);
|
||||
int4* buff4 = reinterpret_cast<int4*>((char*)buff);
|
||||
DeviceHandle<BaseMemoryChannel>* memoryChannelsLocal = memoryChannels + blockId * NPeers;
|
||||
|
||||
uint32_t nInt4PerBlock = nInt4PerRank / gridDim.x;
|
||||
uint32_t remainderForBlock = nInt4PerRank % gridDim.x;
|
||||
uint32_t offset4 = blockId * nInt4PerBlock;
|
||||
if (blockId == (int)(gridDim.x - 1)) {
|
||||
nInt4PerBlock += remainderForBlock;
|
||||
}
|
||||
if (nInt4PerBlock == 0) return;
|
||||
|
||||
if (threadIdx.x < NPeers) {
|
||||
memoryChannelsLocal[threadIdx.x].relaxedSignal();
|
||||
memoryChannelsLocal[threadIdx.x].relaxedWait();
|
||||
}
|
||||
__syncthreads();
|
||||
int4 data[NPeers];
|
||||
for (uint32_t idx = threadIdx.x; idx < nInt4PerBlock; idx += blockDim.x) {
|
||||
uint32_t offset = idx + offset4 + rank * nInt4PerRank;
|
||||
if (offset >= nInt4Total) continue;
|
||||
int4 tmp = buff4[offset];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NPeers; i++) {
|
||||
int rankIdx = (rank + i + 1) % NRanksPerNode;
|
||||
int peerIdx = rankIdx < rank ? rankIdx : rankIdx - 1;
|
||||
data[i] = mscclpp::read<int4>(((void**)remoteMemories)[peerIdx], offset);
|
||||
}
|
||||
for (int i = 0; i < NPeers; i++) {
|
||||
tmp = cal_vector<T, OpType>(data[i], tmp);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NPeers; i++) {
|
||||
int rankIdx = (rank + i + 1) % NRanksPerNode;
|
||||
int peerIdx = rankIdx < rank ? rankIdx : rankIdx - 1;
|
||||
mscclpp::write<int4>(((void**)remoteMemories)[outputRemoteBufferOffset + peerIdx], offset, tmp);
|
||||
}
|
||||
resultBuff4[offset] = tmp;
|
||||
}
|
||||
// Use device barrier gives better performance here.
|
||||
globalSyncer.sync(gridDim.x);
|
||||
if (blockIdx.x == 0 && threadIdx.x < NPeers) {
|
||||
memoryChannelsLocal[threadIdx.x].signal();
|
||||
memoryChannelsLocal[threadIdx.x].wait();
|
||||
}
|
||||
}
|
||||
|
||||
// Host-side launch adapter for allreduceRsAgZeroCopy.
// Specializes the kernel on nRanksPerNode (4 or 8 only) so peer loops unroll
// at compile time; throws InvalidUsage for any other topology. When the caller
// does not specify a launch shape, defaults to 64 blocks of 1024 threads,
// growing to 128 blocks for inputs of 64MB or more.
template <ReduceOp OpType, typename T>
struct AllreduceRsAgZeroCopyAdapter {
  static cudaError_t call(const void* input, void* scratch, void* output, void* memoryChannels, void* remoteMemories,
                          DeviceHandle<SwitchChannel>* switchChannel, DeviceHandle<SwitchChannel>*, size_t, size_t,
                          size_t, int rank, int nRanksPerNode, int worldSize, size_t inputSize, cudaStream_t stream,
                          void*, uint32_t, uint32_t, int nBlocks, int nThreadsPerBlock) {
    using ChannelType = DeviceHandle<BaseMemoryChannel>;
    size_t nelems = inputSize / sizeof(T);
    if (nBlocks == 0 || nThreadsPerBlock == 0) {
      nThreadsPerBlock = 1024;
      nBlocks = 64;
      if (inputSize >= (1 << 26)) {  // >= 64MB
        nBlocks = 128;
      }
    }
    if (nRanksPerNode == 4) {
      allreduceRsAgZeroCopy<4, OpType, T>
          <<<nBlocks, nThreadsPerBlock, 0, stream>>>((T*)input, (T*)scratch, (T*)output, (ChannelType*)memoryChannels,
                                                     switchChannel, remoteMemories, rank, worldSize, nelems);
    } else if (nRanksPerNode == 8) {
      allreduceRsAgZeroCopy<8, OpType, T>
          <<<nBlocks, nThreadsPerBlock, 0, stream>>>((T*)input, (T*)scratch, (T*)output, (ChannelType*)memoryChannels,
                                                     switchChannel, remoteMemories, rank, worldSize, nelems);
    } else {
      THROW(ALGO, Error, ErrorCode::InvalidUsage, "Unsupported number of ranks per node: ", nRanksPerNode);
    }
    return cudaGetLastError();
  }
};
|
||||
|
||||
// One-time setup for the zero-copy algorithm: connections, semaphores, and
// device-side channel handles. Buffer registration is deferred to
// initAllreduceContext() because input/output pointers are per-call.
void AllreduceRsAgZeroCopy::initialize(std::shared_ptr<Communicator> comm) {
  this->conns_ = setupConnections(comm);
  // One channel per block, up to the adapter's maximum grid size (128 blocks).
  nChannelsPerConnection_ = 128;
  comm_ = comm;
  // setup semaphores
  this->semaphores_ = setupMemorySemaphores(comm, this->conns_, nChannelsPerConnection_);
  this->baseChannels_ = setupBaseMemoryChannels(this->conns_, this->semaphores_, nChannelsPerConnection_);
  this->baseMemoryChannelHandles_ = setupBaseMemoryChannelDeviceHandles(baseChannels_);
}
|
||||
|
||||
// Dispatches the op/dtype-specialized zero-copy launcher and converts CUDA
// launch failures into CommResult codes. The scratch argument is nullptr:
// this algorithm reads peers' inputs and writes peers' outputs directly via
// the context's remote memory handles.
CommResult AllreduceRsAgZeroCopy::allreduceKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output,
                                                      size_t inputSize, DataType dtype, ReduceOp op,
                                                      cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
                                                      const std::unordered_map<std::string, uintptr_t>&) {
  auto algoCtx = std::static_pointer_cast<AlgorithmCtx>(ctx);
  AllreduceFunc allreduce = dispatch<AllreduceRsAgZeroCopyAdapter>(op, dtype);
  if (!allreduce) {
    WARN(ALGO, "Unsupported operation or data type for allreduce: op=", static_cast<int>(op),
         ", dtype=", static_cast<int>(dtype));
    return CommResult::CommInvalidArgument;
  }
  // Pass the launch shape straight through (no intermediate pair needed); the
  // adapter substitutes defaults when nBlocks or nThreadsPerBlock is zero.
  cudaError_t error =
      allreduce(input, nullptr, output, this->baseMemoryChannelHandles_.get(), algoCtx->remoteMemoryHandles.get(),
                nullptr, nullptr, 0, 0, 0, algoCtx->rank, algoCtx->nRanksPerNode, algoCtx->workSize, inputSize, stream,
                nullptr, 0, 0, nBlocks, nThreadsPerBlock);
  if (error != cudaSuccess) {
    WARN(ALGO, "Allreduce kernel launch failed with error: ", cudaGetErrorString(error));
    return CommResult::CommUnhandledCudaError;
  }
  return CommResult::CommSuccess;
}
|
||||
|
||||
// Builds the context-cache key for an (input, output, size) triple.
//
// Symmetric-memory path: key on the underlying allocation base/size (via
// cuMemGetAddressRange), so one context is reused for any view into the same
// allocations. Non-symmetric path: every call gets a fresh key (++tag), i.e.
// a new context — and new buffer registrations — per invocation.
// NOTE(review): `static int tag` is not thread-safe and grows without bound;
// presumably context setup is single-threaded — confirm with callers.
AlgorithmCtxKey AllreduceRsAgZeroCopy::generateAllreduceContextKey(const void* inputBuffer, void* outputBuffer,
                                                                   size_t size, DataType, bool symmetricMemory) {
  // For non-symmetric algorithms, we use both input and output buffer pointers in the key.
  static int tag = 0;
  if (symmetricMemory) {
    size_t inputBytes, outputBytes;
    CUdeviceptr inputBasePtr, outputBasePtr;
    MSCCLPP_CUTHROW(cuMemGetAddressRange(&inputBasePtr, &inputBytes, (CUdeviceptr)inputBuffer));
    MSCCLPP_CUTHROW(cuMemGetAddressRange(&outputBasePtr, &outputBytes, (CUdeviceptr)outputBuffer));
    return AlgorithmCtxKey{(void*)inputBasePtr, (void*)outputBasePtr, inputBytes, outputBytes, 0};
  }
  return AlgorithmCtxKey{(void*)inputBuffer, outputBuffer, size, size, ++tag};
}
|
||||
|
||||
// Creates a per-(input, output) context: registers both buffers over CudaIpc,
// exchanges the registrations with peers, and builds a device-visible pointer
// table laid out as [NPeers input handles][NPeers output handles] — the layout
// allreduceRsAgZeroCopy expects via its outputRemoteBufferOffset.
std::shared_ptr<void> AllreduceRsAgZeroCopy::initAllreduceContext(std::shared_ptr<Communicator> comm, const void* input,
                                                                  void* output, size_t size, DataType) {
  auto ctx = std::make_shared<AlgorithmCtx>();
  ctx->rank = comm->bootstrap()->getRank();
  ctx->workSize = comm->bootstrap()->getNranks();
  ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode();

  ctx->memorySemaphores = this->semaphores_;

  // register input and output memories
  RegisteredMemory inputMemory = comm->registerMemory((void*)input, size, Transport::CudaIpc);
  RegisteredMemory outputMemory = comm->registerMemory(output, size, Transport::CudaIpc);
  // NOTE(review): these member vectors grow by one entry per context and are
  // never pruned — presumably contexts are few and cached; confirm policy.
  this->inputMemories_.push_back(inputMemory);
  this->outputMemories_.push_back(outputMemory);

  auto remoteInputMemories = setupRemoteMemories(comm, ctx->rank, inputMemory);
  auto remoteOutputMemories = setupRemoteMemories(comm, ctx->rank, outputMemory);
  // Remote input handles first, then remote output handles (kernel layout).
  ctx->registeredMemories.insert(ctx->registeredMemories.end(), remoteInputMemories.begin(), remoteInputMemories.end());
  ctx->registeredMemories.insert(ctx->registeredMemories.end(), remoteOutputMemories.begin(),
                                 remoteOutputMemories.end());
  // Flatten the remote registrations into a device array of raw pointers.
  std::vector<void*> remoteMemoryHandles;
  for (const auto& remoteMemory : ctx->registeredMemories) {
    remoteMemoryHandles.push_back(remoteMemory.data());
  }
  ctx->remoteMemoryHandles = detail::gpuCallocShared<void*>(remoteMemoryHandles.size());
  gpuMemcpy(ctx->remoteMemoryHandles.get(), remoteMemoryHandles.data(), remoteMemoryHandles.size(),
            cudaMemcpyHostToDevice);

  // store local registered memories to ctx for lifetime management
  ctx->registeredMemories.push_back(inputMemory);
  ctx->registeredMemories.push_back(outputMemory);
  return ctx;
}
|
||||
|
||||
// Wires the zero-copy RSAG allreduce into a NativeAlgorithm. The shared_ptr
// `self` is captured by value in every callback so the builder's state
// (channels, semaphores, registered memories) outlives this function.
std::shared_ptr<Algorithm> AllreduceRsAgZeroCopy::build() {
  auto self = std::make_shared<AllreduceRsAgZeroCopy>();
  return std::make_shared<NativeAlgorithm>(
      "default_allreduce_rsag_zero_copy", "allreduce",
      // Initialization callback: set up connections/semaphores/channels.
      [self](std::shared_ptr<mscclpp::Communicator> comm) { self->initialize(comm); },
      // Kernel callback: launch the allreduce for one call.
      [self](const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
             [[maybe_unused]] size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks,
             int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras) -> CommResult {
        return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock,
                                         extras);
      },
      // Context factory callback (registers the buffers for this call).
      [self](std::shared_ptr<Communicator> comm, const void* input, void* output, size_t inputSize,
             [[maybe_unused]] size_t outputSize,
             DataType dtype) { return self->initAllreduceContext(comm, input, output, inputSize, dtype); },
      // Context cache-key callback.
      [self](const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, DataType dtype,
             bool symmetricMemory) {
        return self->generateAllreduceContextKey(input, output, inputSize, dtype, symmetricMemory);
      });
}
|
||||
} // namespace collective
|
||||
} // namespace mscclpp
|
||||
43
src/ext/collectives/include/allreduce/allreduce_rsag.hpp
Normal file
43
src/ext/collectives/include/allreduce/allreduce_rsag.hpp
Normal file
@@ -0,0 +1,43 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef MSCCLPP_EXT_ALLREDUCE_RSAG_HPP_
|
||||
#define MSCCLPP_EXT_ALLREDUCE_RSAG_HPP_
|
||||
|
||||
#include <mscclpp/algorithm.hpp>
|
||||
|
||||
namespace mscclpp {
|
||||
namespace collective {
|
||||
|
||||
class AllreduceRsAg : public mscclpp::AlgorithmBuilder {
|
||||
public:
|
||||
AllreduceRsAg(uintptr_t scratchBuffer, size_t scratchBufferSize)
|
||||
: scratchBuffer_((void*)scratchBuffer), scratchBufferSize_(scratchBufferSize){};
|
||||
std::shared_ptr<mscclpp::Algorithm> build() override;
|
||||
|
||||
private:
|
||||
void initialize(std::shared_ptr<Communicator> comm);
|
||||
CommResult allreduceKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
|
||||
DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
|
||||
const std::unordered_map<std::string, uintptr_t>& extras);
|
||||
|
||||
std::shared_ptr<void> initAllreduceContext(std::shared_ptr<Communicator> comm, const void*, void* output, size_t,
|
||||
DataType);
|
||||
AlgorithmCtxKey generateAllreduceContextKey(const void*, void*, size_t, DataType, bool);
|
||||
void* scratchBuffer_;
|
||||
size_t scratchBufferSize_;
|
||||
std::shared_ptr<Communicator> comm_;
|
||||
int nChannelsPerConnection_;
|
||||
std::vector<Connection> conns_;
|
||||
std::vector<std::shared_ptr<MemoryDevice2DeviceSemaphore>> scratchSemaphores_;
|
||||
std::vector<RegisteredMemory> remoteScratchMemories_;
|
||||
RegisteredMemory localScratchMemory_;
|
||||
|
||||
std::vector<BaseMemoryChannel> baseChannels_;
|
||||
std::shared_ptr<DeviceHandle<BaseMemoryChannel>> baseMemoryChannelHandles_;
|
||||
std::shared_ptr<void*> remoteMemoryHandles_;
|
||||
};
|
||||
} // namespace collective
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // MSCCLPP_EXT_ALLREDUCE_RSAG_HPP_
|
||||
@@ -0,0 +1,43 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef MSCCLPP_EXT_ALLREDUCE_RSAG_PIPELINE_HPP_
|
||||
#define MSCCLPP_EXT_ALLREDUCE_RSAG_PIPELINE_HPP_
|
||||
|
||||
#include <mscclpp/algorithm.hpp>
|
||||
|
||||
namespace mscclpp {
|
||||
namespace collective {
|
||||
|
||||
class AllreduceRsAgPipeline : public mscclpp::AlgorithmBuilder {
|
||||
public:
|
||||
AllreduceRsAgPipeline(uintptr_t scratchBuffer, size_t scratchBufferSize)
|
||||
: scratchBuffer_((void*)scratchBuffer), scratchBufferSize_(scratchBufferSize){};
|
||||
std::shared_ptr<mscclpp::Algorithm> build() override;
|
||||
|
||||
private:
|
||||
void initialize(std::shared_ptr<Communicator> comm);
|
||||
CommResult allreduceKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
|
||||
DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
|
||||
const std::unordered_map<std::string, uintptr_t>& extras);
|
||||
|
||||
std::shared_ptr<void> initAllreduceContext(std::shared_ptr<Communicator> comm, const void*, void* output, size_t,
|
||||
DataType);
|
||||
AlgorithmCtxKey generateAllreduceContextKey(const void*, void*, size_t, DataType, bool);
|
||||
void* scratchBuffer_;
|
||||
size_t scratchBufferSize_;
|
||||
std::shared_ptr<Communicator> comm_;
|
||||
int nChannelsPerConnection_;
|
||||
std::vector<Connection> conns_;
|
||||
std::vector<std::shared_ptr<MemoryDevice2DeviceSemaphore>> scratchSemaphores_;
|
||||
std::vector<RegisteredMemory> remoteScratchMemories_;
|
||||
RegisteredMemory localScratchMemory_;
|
||||
|
||||
std::vector<BaseMemoryChannel> baseChannels_;
|
||||
std::shared_ptr<DeviceHandle<BaseMemoryChannel>> baseMemoryChannelHandles_;
|
||||
std::shared_ptr<void*> remoteMemoryHandles_;
|
||||
};
|
||||
} // namespace collective
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // MSCCLPP_EXT_ALLREDUCE_RSAG_PIPELINE_HPP_
|
||||
@@ -0,0 +1,39 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef MSCCLPP_EXT_ALLREDUCE_RSAG_ZERO_COPY_HPP_
|
||||
#define MSCCLPP_EXT_ALLREDUCE_RSAG_ZERO_COPY_HPP_
|
||||
|
||||
#include <mscclpp/algorithm.hpp>
|
||||
|
||||
namespace mscclpp {
|
||||
namespace collective {
|
||||
|
||||
// Builder for the zero-copy single-node RSAG allreduce: registers the user's
// input/output buffers directly (no scratch buffer), so per-buffer contexts
// are created by initAllreduceContext().
class AllreduceRsAgZeroCopy : public mscclpp::AlgorithmBuilder {
 public:
  AllreduceRsAgZeroCopy() = default;
  // Assembles the NativeAlgorithm with init/kernel/context/key callbacks.
  std::shared_ptr<mscclpp::Algorithm> build() override;

 private:
  // One-time communicator setup: connections, semaphores, channels.
  void initialize(std::shared_ptr<Communicator> comm);
  // Launches the zero-copy allreduce kernel for a single collective call.
  CommResult allreduceKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
                                 DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
                                 const std::unordered_map<std::string, uintptr_t>& extras);

  // Creates a per-(input, output) context with registered buffers.
  std::shared_ptr<void> initAllreduceContext(std::shared_ptr<Communicator> comm, const void*, void* output, size_t,
                                             DataType);
  // Produces the context-cache key for a given buffer/size combination.
  AlgorithmCtxKey generateAllreduceContextKey(const void*, void*, size_t, DataType, bool);
  std::shared_ptr<Communicator> comm_;
  int nChannelsPerConnection_;
  std::vector<Connection> conns_;
  std::vector<std::shared_ptr<MemoryDevice2DeviceSemaphore>> semaphores_;
  // Local registrations kept alive for buffers handed to peers.
  std::vector<RegisteredMemory> inputMemories_;
  std::vector<RegisteredMemory> outputMemories_;

  std::vector<BaseMemoryChannel> baseChannels_;
  std::shared_ptr<DeviceHandle<BaseMemoryChannel>> baseMemoryChannelHandles_;
};
|
||||
} // namespace collective
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // MSCCLPP_EXT_ALLREDUCE_RSAG_ZERO_COPY_HPP_
|
||||
@@ -84,6 +84,7 @@ class AlgorithmCtx {
|
||||
std::shared_ptr<DeviceHandle<PortChannel>> portChannelDeviceHandles;
|
||||
std::vector<std::shared_ptr<MemoryDevice2DeviceSemaphore>> memorySemaphores;
|
||||
std::vector<std::shared_ptr<Host2DeviceSemaphore>> hostSemaphores;
|
||||
std::shared_ptr<void*> remoteMemoryHandles;
|
||||
std::unordered_map<std::string, std::shared_ptr<void>> extras;
|
||||
};
|
||||
|
||||
|
||||
@@ -71,7 +71,15 @@ static std::shared_ptr<Algorithm> selectSingleNodeAllreduceBlackwell(
|
||||
if (messageSize <= (1 << 21)) { // <= 2MB
|
||||
return algoMap.at("default_allreduce_packet");
|
||||
}
|
||||
return nullptr;
|
||||
if (config.inCaptureMode) {
|
||||
// CUDA graph mode: setup new connections each time (zero-copy for graph)
|
||||
return algoMap.at("default_allreduce_rsag_zero_copy");
|
||||
}
|
||||
// Non-graph mode: use non-zero-copy algorithms
|
||||
if (messageSize <= (1 << 23)) { // <= 8MB
|
||||
return algoMap.at("default_allreduce_rsag");
|
||||
}
|
||||
return algoMap.at("default_allreduce_rsag_pipeline");
|
||||
}
|
||||
|
||||
// Symmetric memory path: can use cached memory handles
|
||||
@@ -83,8 +91,7 @@ static std::shared_ptr<Algorithm> selectSingleNodeAllreduceBlackwell(
|
||||
return algoMap.at("default_allreduce_nvls");
|
||||
}
|
||||
|
||||
INFO(MSCCLPP_NCCL, "No suitable kernel for Blackwell architecture, fallback to nccl/rccl");
|
||||
return nullptr;
|
||||
return algoMap.at("default_allreduce_rsag_zero_copy");
|
||||
}
|
||||
|
||||
std::shared_ptr<Algorithm> selectSingleNodeAllreduce(
|
||||
|
||||
Reference in New Issue
Block a user