Fix multicast handle leak, cuMemMap offset handling, and rename NVLS allreduce algorithms (#759)

## Summary

This PR addresses a multicast resource leak, fixes `cuMemMap` offset
handling for multicast handles, renames NVLS allreduce algorithm classes
for clarity, and adds a new unit test for `SwitchChannel`.

### Bug Fixes

#### 1. Fix multicast allocation handle leak in `createMulticast()`
(`gpu_ipc_mem.cc`)

`GpuIpcMemHandle::createMulticast()` called
`cuMulticastCreate(&allocHandle, ...)` but never released the local
`allocHandle` after exporting it to shareable handles (POSIX FD /
Fabric). This caused a reference count leak — the multicast object was
never freed even after all mappings and imported handles were released.

Per the [CUDA Driver API docs for
`cuMemRelease`](https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__VA.html):
> *"The memory allocation will be freed when all outstanding mappings to
> the memory are unmapped and when all outstanding references to the
> handle (including its shareable counterparts) are also released."*

The fix adds `cuMemRelease(allocHandle)` after export, matching the
existing pattern used for regular allocations in
`GpuIpcMemHandle::create()`.

**Impact:** Without this fix, repeated creation/destruction of NVLS
connections causes OOM after ~120 iterations when allocating 1GB
multicast buffers on H100.
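The lifecycle can be illustrated with a toy reference-count model (plain Python, not the CUDA API; the `McHandle` class and its method names are illustrative only):

```python
class McHandle:
    """Toy model of a driver allocation handle with shareable exports."""

    def __init__(self):
        self.refs = 1  # cuMulticastCreate returns a handle holding one reference

    def export_shareable(self):
        self.refs += 1  # the exported POSIX FD / fabric handle holds its own reference
        return self

    def release(self):
        self.refs -= 1


# Buggy path: create + export, never release the local handle.
h = McHandle()
fd = h.export_shareable()
fd.release()        # importers and mappings eventually go away...
print(h.refs)       # -> 1: one reference remains forever, the object leaks

# Fixed path: release the local handle right after export,
# matching the pattern in GpuIpcMemHandle::create().
h2 = McHandle()
fd2 = h2.export_shareable()
h2.release()
fd2.release()
print(h2.refs)      # -> 0: the multicast object can now be freed
```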

#### 2. Fix `cuMemMap` offset for multicast handles (`gpu_ipc_mem.cc`)

`cuMemMap` requires `offset=0` for multicast handles. Previously, the
code attempted to map at a non-zero offset within the multicast object,
leading to errors when binding multiple buffers to the same
`NvlsConnection`. The fix maps the entire range `[0, mcOffset +
bufferSize)` and returns the pointer offset by `mcOffset`. This only
consumes extra virtual address space; no additional physical memory is
used.
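The new mapping arithmetic can be sketched as follows (plain Python standing in for the driver calls; `mc_offset` and `buffer_size` mirror the variables in `mapMulticast`):

```python
def map_multicast(mc_offset: int, buffer_size: int):
    """Return (map_size, returned_ptr_offset) under the fixed scheme.

    cuMemMap only accepts offset=0 for multicast handles, so instead of
    mapping [mc_offset, mc_offset + buffer_size) directly, the whole prefix
    [0, mc_offset + buffer_size) is mapped and the caller receives a pointer
    shifted by mc_offset. Only virtual address space grows for the prefix;
    no additional physical memory is committed.
    """
    map_size = mc_offset + buffer_size  # size passed to cuMemMap / cuMemSetAccess
    returned_ptr_offset = mc_offset     # caller sees mcPtr + mcOffset
    return map_size, returned_ptr_offset


# Second buffer bound at mc_offset = 1 MiB within the multicast object:
print(map_multicast(1 << 20, 1 << 20))  # (2097152, 1048576)
```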

### Refactoring

#### 3. Rename NVLS allreduce algorithm classes

Renamed for clarity:
- `AllreduceNvls` → `AllreduceNvlsZeroCopy`
- `AllreduceNvlsWithCopy` → `AllreduceNvlsWarpPipeline`
- `AllreduceNvlsWithCopy2` → `AllreduceNvlsBlockPipeline`

Updated all references in builder, selector, docs, and examples.

#### 4. Move `nvlsConnections` setup to `initialize()`

Previously each algorithm created NVLS connections inside
`initAllreduceContext()` and stored them in `AlgorithmCtx::nvlsConnections`.
That member is removed; connections are now per-algorithm members
(`nvlsConnections_`), created once in each algorithm's `initialize()` and
reused across contexts.
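The effect of the move can be sketched with a minimal Python model (hypothetical names; `setup_nvls_connections` stands in for `setupNvlsConnections`):

```python
calls = []


def setup_nvls_connections():
    calls.append("setup")  # stands in for the expensive collective setup
    return object()


class Algo:
    """After the refactor: connections are created once in initialize()."""

    def __init__(self):
        self.nvls_connections = None

    def initialize(self):
        self.nvls_connections = setup_nvls_connections()

    def init_context(self):
        # Previously each new context re-ran setup_nvls_connections();
        # now contexts reuse the member created in initialize().
        return {"nvls_connections": self.nvls_connections}


algo = Algo()
algo.initialize()
ctx1 = algo.init_context()
ctx2 = algo.init_context()
print(len(calls))                                             # 1: setup ran once
print(ctx1["nvls_connections"] is ctx2["nvls_connections"])   # True
```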

### Tests

#### 5. Add `TwoChannelsSameConnection` test

New unit test that creates two `SwitchChannel` instances from the same
`NvlsConnection`, performs reduce operations on both, and verifies
correctness. This exercises the multi-bind path that triggered the
`cuMemMap` offset fix.
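The multi-bind scenario the test exercises can be sketched with toy bookkeeping (not the real `SwitchChannel` API; the model only tracks offsets):

```python
class NvlsConnectionModel:
    """Toy allocator: each bind consumes the next slot in the multicast object."""

    def __init__(self, buffer_size: int):
        self.buffer_size = buffer_size
        self.next_offset = 0

    def bind(self):
        mc_offset = self.next_offset
        self.next_offset += self.buffer_size
        # The old code passed mc_offset directly to cuMemMap, which rejects
        # nonzero offsets for multicast handles, so the second bind failed;
        # the fixed code maps [0, mc_offset + buffer_size) at offset 0.
        map_size = mc_offset + self.buffer_size
        return mc_offset, map_size


conn = NvlsConnectionModel(buffer_size=4096)
print(conn.bind())  # (0, 4096)     first SwitchChannel: mcOffset is zero
print(conn.bind())  # (4096, 8192)  second SwitchChannel: nonzero mcOffset
```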

### Files Changed

- `src/core/gpu_ipc_mem.cc` — multicast handle leak fix + cuMemMap
offset fix
- `src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu` (renamed)
- `src/ext/collectives/allreduce/allreduce_nvls_warp_pipeline.cu`
(renamed)
- `src/ext/collectives/allreduce/allreduce_nvls_block_pipeline.cu`
(renamed)
- `src/ext/collectives/allreduce/allreduce_nvls_packet.cu` —
nvlsConnections fix
- `src/ext/collectives/include/allreduce/*.hpp` — renamed headers
- `src/ext/collectives/algorithm_collection_builder.cc` — updated
references
- `src/ext/nccl/algorithm_selector.cc` — updated algorithm names
- `test/mp_unit/switch_channel_tests.cu` — new test
- `docs/guide/mscclpp-torch-integration.md` — updated names
- `examples/torch-integration/customized_comm_with_default_algo.py` —
updated names
---

Commit `bf946ea51e` (parent `3751f0299b`) by Binyang Li, committed via GitHub on 2026-03-09 10:22:45 -07:00.
16 changed files with 313 additions and 120 deletions.

---

@@ -129,7 +129,7 @@ class CustomizedComm:
self._algo_large = [
algo for algo in algorithms
if algo.collective == "allreduce"
and algo.name == "default_allreduce_nvls_with_copy"
and algo.name == "default_allreduce_nvls_warp_pipeline"
][0]
def all_reduce(self, tensor: torch.Tensor, stream=None):
@@ -479,9 +479,9 @@ The default algorithms use a fixed heuristic to select algorithms based on messa
### How It Works
1. **Candidate selection** — For each power-of-two message size from 1 KB to 128 MB, the tuner picks the applicable algorithms:
- All sizes (when NVLS is supported): `default_allreduce_nvls_zero_copy`
- Small messages (≤ 4 MB): `default_allreduce_nvls_packet`, `default_allreduce_packet`
- Large messages (≥ 512 KB): `default_allreduce_rsag_zero_copy`
- Overlapping sizes get all three candidates.
2. **Grid search** — Each candidate is run with every combination of block counts (`4, 8, 16, … 128`) and thread counts (`512, 768, 1024`). Results are captured in a CUDA graph and timed.
@@ -489,6 +489,36 @@ The default algorithms use a fixed heuristic to select algorithms based on messa
4. **Runtime dispatch** — `get_tuned_config()` rounds the actual message size up to the next power of two and returns the winning `(algorithm, nblocks, nthreads)` triple.
### Symmetric Memory Allocation
Algorithms like `default_allreduce_nvls_zero_copy` require **symmetric memory** — memory where the buffer offset is the same for each rank, allocated via `mscclpp.RawGpuBuffer` (`cuMemAlloc`). Regular `torch.rand()` or `torch.empty()` allocations cannot be used with these algorithms because they do not guarantee the same offset across ranks. Instead, allocate a single large buffer and reuse it for all message sizes:
```python
# Allocate symmetric memory via RawGpuBuffer and wrap as a PyTorch tensor
tune_tensor = mscclpp.RawGpuBuffer(1 << 27).to_dlpack(data_type=str(torch.float16))
tune_tensor = torch.utils.dlpack.from_dlpack(tune_tensor)
tune_tensor.normal_()
```
When executing an algorithm with symmetric memory, pass `symmetric_memory=True`:
```python
def _run_algo(self, algo, tensor, size, nblocks, nthreads):
return algo.execute(
comm=self.comm.communicator,
input_buffer=tensor.data_ptr(),
output_buffer=tensor.data_ptr(),
input_size=size,
output_size=size,
dtype=mscclpp_utils.torch_dtype_to_mscclpp_dtype(tensor.dtype),
op=mscclpp.ReduceOp.SUM,
stream=torch.cuda.current_stream().cuda_stream,
nblocks=nblocks,
nthreads_per_block=nthreads,
symmetric_memory=True,
)
```
### Loading Candidate Algorithms
The same `load_algorithms` helper from Approach 1 is reused. The tuner extracts multiple algorithm objects:
@@ -510,23 +540,35 @@ self._algorithm_packet = [
algo for algo in algorithms
if algo.collective == "allreduce" and algo.name == "default_allreduce_packet"
][0]
# NVLS zero-copy is only available on supported hardware
if mscclpp.is_nvls_supported():
self._algorithm_nvls_zero_copy = [
algo for algo in algorithms
if algo.collective == "allreduce" and algo.name == "default_allreduce_nvls_zero_copy"
][0]
```
### The Tuning Loop
The tuning loop iterates over message sizes, candidate algorithms, and kernel launch parameters. CUDA graphs are used for accurate timing:
The tuning loop iterates over message sizes, candidate algorithms, and kernel launch parameters. CUDA graphs are used for accurate timing. Note the use of `RawGpuBuffer` for symmetric memory:
```python
def _tune(self, n_warmup, n_graph_launches, n_ops_per_graph):
sizes = [1 << i for i in range(10, 28)]
self.best_configs = {1024: (self._algorithm_nvls_packet, 0, 0)}
tune_tensor = torch.rand(1 << 27, dtype=torch.float16, device="cuda")
# Use RawGpuBuffer for symmetric memory allocation
tune_tensor = mscclpp.RawGpuBuffer(1 << 27).to_dlpack(data_type=str(torch.float16))
tune_tensor = torch.utils.dlpack.from_dlpack(tune_tensor)
tune_tensor.normal_()
candidates_nblocks = [4, 8, 16, 24, 32, 48, 64, 128]
candidates_nthreads = [512, 768, 1024]
for size in sizes:
algos = []
if mscclpp.is_nvls_supported():
algos.append(self._algorithm_nvls_zero_copy)
if size <= 4 * 1024 * 1024:
algos.append(self._algorithm_nvls_packet)
algos.append(self._algorithm_packet)
@@ -562,7 +604,7 @@ def _tune(self, n_warmup, n_graph_launches, n_ops_per_graph):
### Dispatching with Tuned Configuration
At runtime, round the message size to the next power of two and look up the best configuration:
At runtime, round the message size to the next power of two and look up the best configuration. When the tensor is allocated from `RawGpuBuffer` (`cuMemAlloc`) and the buffer offset is the same for each rank, pass `symmetric_memory=True` to the `execute()` call (see the [Symmetric Memory Allocation](#symmetric-memory-allocation) section above):
```python
def get_tuned_config(self, size):
@@ -591,6 +633,28 @@ def all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, stream=None):
)
```
### Benchmarking with Symmetric Memory
When benchmarking tuned configurations, use the same `RawGpuBuffer` allocation pattern. Create one large buffer and slice it for each message size:
```python
def benchmark(self, n_warmup=10, n_graph_launches=10, n_iter_per_graph=100):
# Allocate a single large RawGpuBuffer (symmetric memory) and reuse for all sizes
dtype = torch.float16
bench_buf = mscclpp.RawGpuBuffer(1 << 27).to_dlpack(data_type=str(dtype))
bench_buf = torch.utils.dlpack.from_dlpack(bench_buf)
bench_buf.normal_()
for size in sizes:
n_elements = size // bench_buf.element_size()
tensor = bench_buf[:n_elements]
# Capture CUDA graph, warmup, and time...
with torch.cuda.graph(g, stream=capture_stream):
for _ in range(n_iter_per_graph):
self.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM)
```
### Running the Tuning Example
```bash

---

@@ -61,7 +61,7 @@ class CustomizedComm:
self._algorithm_nvls_nonzero_copy = [
algo
for algo in algorithms
if algo.collective == "allreduce" and algo.name == "default_allreduce_nvls_with_copy"
if algo.collective == "allreduce" and algo.name == "default_allreduce_nvls_warp_pipeline"
][0]
def all_reduce(self, tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM, stream: torch.cuda.Stream = None):

---

@@ -64,6 +64,12 @@ class CustomizedComm:
self._algorithm_packet = [
algo for algo in algorithms if algo.collective == "allreduce" and algo.name == "default_allreduce_packet"
][0]
if mscclpp.is_nvls_supported():
self._algorithm_nvls_zero_copy = [
algo
for algo in algorithms
if algo.collective == "allreduce" and algo.name == "default_allreduce_nvls_zero_copy"
][0]
self._tune(n_warmup=5, n_graph_launches=10, n_ops_per_graph=100)
def _tune(self, n_warmup, n_graph_launches, n_ops_per_graph):
@@ -71,12 +77,16 @@ class CustomizedComm:
# Pre-fill with defaults for barrier
self.best_configs = {1024: (self._algorithm_nvls_packet, 0, 0)}
tune_tensor = torch.rand(1 << 27, dtype=torch.float16, device="cuda")
tune_tensor = mscclpp.RawGpuBuffer(1 << 27).to_dlpack(data_type=str(torch.float16))
tune_tensor = torch.utils.dlpack.from_dlpack(tune_tensor)
tune_tensor.normal_()
candidates_nblocks = [4, 8, 16, 24, 32, 48, 64, 128]
candidates_nthreads = [512, 768, 1024]
for size in sizes:
algos = []
if mscclpp.is_nvls_supported():
algos.append(self._algorithm_nvls_zero_copy)
if size <= 4 * 1024 * 1024:
algos.append(self._algorithm_nvls_packet)
algos.append(self._algorithm_packet)
@@ -150,7 +160,7 @@ class CustomizedComm:
for algo in algos:
algo.reset()
def _run_algo(self, algo, tensor, size, nblocks, nthreads):
def _run_algo(self, algo: mscclpp.Algorithm, tensor, size, nblocks, nthreads):
return algo.execute(
comm=self.comm.communicator,
input_buffer=tensor.data_ptr(),
@@ -162,6 +172,7 @@ class CustomizedComm:
stream=torch.cuda.current_stream().cuda_stream,
nblocks=nblocks,
nthreads_per_block=nthreads,
symmetric_memory=True,
)
def get_tuned_config(self, size):
@@ -188,6 +199,7 @@ class CustomizedComm:
stream=stream.cuda_stream if stream is not None else torch.cuda.current_stream().cuda_stream,
nblocks=nblocks,
nthreads_per_block=nthreads,
symmetric_memory=True,
)
if ret != 0:
print(f"Rank {self.rank}: Algo {algo.name} failed with error {ret}")
@@ -211,8 +223,16 @@ class CustomizedComm:
dtype = torch.float16
capture_stream = torch.cuda.Stream()
# Allocate a single large RawGpuBuffer (symmetric memory) and reuse it for all sizes.
# Cannot allocate per-size tensors with symmetric memory.
bench_buf = mscclpp.RawGpuBuffer(1 << 27).to_dlpack(data_type=str(dtype))
bench_buf = torch.utils.dlpack.from_dlpack(bench_buf)
bench_buf.normal_()
for size in sizes:
tensor = torch.rand(size // 2, dtype=dtype, device="cuda")
n_elements = size // bench_buf.element_size()
tensor = bench_buf[:n_elements]
capture_stream.wait_stream(torch.cuda.current_stream())
# Capture Graph
g = torch.cuda.CUDAGraph()

---

@@ -249,8 +249,13 @@ UniqueGpuIpcMemHandle GpuIpcMemHandle::createMulticast([[maybe_unused]] size_t b
}
if (handle->typeFlags == GpuIpcMemHandle::Type::None) {
cuMemRelease(allocHandle);
THROW(GPU, Error, ErrorCode::SystemError, "createMulticast failed: neither POSIX FD nor FABRIC handle was created");
}
// Release the local allocation handle. The exported POSIX FD / Fabric handle keeps the
// multicast object alive. Each importer will get its own handle via cuMemImportFromShareableHandle.
MSCCLPP_CUTHROW(cuMemRelease(allocHandle));
return handle;
#else // !(CUDA_NVLS_API_AVAILABLE)
THROW(GPU, Error, ErrorCode::InvalidUsage,
@@ -418,41 +423,45 @@ std::shared_ptr<void> GpuIpcMem::mapMulticast([[maybe_unused]] int numDevices, [
// This will block until all devices call cuMulticastAddDevice()
MSCCLPP_CUTHROW(cuMulticastBindAddr(allocHandle_, mcOffset, bufferAddr, bufferSize, 0));
// cuMemMap requires offset to be 0 for multicast handles, so we map the entire range
// [0, mcOffset + bufferSize) and return a pointer at mcPtr + mcOffset. This only consumes
// extra virtual address space for the mcOffset region; no additional physical memory is used.
size_t mapSize = mcOffset + bufferSize;
CUdeviceptr mcPtr;
MSCCLPP_CUTHROW(cuMemAddressReserve(&mcPtr, bufferSize, minMcGran, 0U, 0));
MSCCLPP_CUTHROW(cuMemMap(mcPtr, bufferSize, 0, allocHandle_, 0));
MSCCLPP_CUTHROW(cuMemAddressReserve(&mcPtr, mapSize, minMcGran, 0U, 0));
MSCCLPP_CUTHROW(cuMemMap(mcPtr, mapSize, 0, allocHandle_, 0));
CUmemAccessDesc accessDesc = {};
accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
accessDesc.location.id = deviceId;
accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
MSCCLPP_CUTHROW(cuMemSetAccess(mcPtr, bufferSize, &accessDesc, 1));
MSCCLPP_CUTHROW(cuMemSetAccess(mcPtr, mapSize, &accessDesc, 1));
// Return shared_ptr with custom deleter that unmaps and unbinds
CUmemGenericAllocationHandle allocHandle = allocHandle_;
return std::shared_ptr<void>(
reinterpret_cast<void*>(mcPtr), [self = shared_from_this(), mcOffset, bufferSize, allocHandle](void* ptr) {
CUresult res;
const char* errStr;
return std::shared_ptr<void>(reinterpret_cast<void*>(mcPtr + mcOffset), [self = shared_from_this(), mcPtr, mapSize,
mcOffset, bufferSize, allocHandle](void*) {
CUresult res;
const char* errStr;
res = cuMemUnmap((CUdeviceptr)ptr, bufferSize);
if (res != CUDA_SUCCESS) {
(void)cuGetErrorString(res, &errStr);
WARN(GPU, "Failed to unmap CUDA memory at pointer ", (void*)ptr, ": ", errStr);
}
res = cuMemUnmap(mcPtr, mapSize);
if (res != CUDA_SUCCESS) {
(void)cuGetErrorString(res, &errStr);
WARN(GPU, "Failed to unmap CUDA memory at pointer ", (void*)mcPtr, ": ", errStr);
}
res = cuMemAddressFree((CUdeviceptr)ptr, bufferSize);
if (res != CUDA_SUCCESS) {
(void)cuGetErrorString(res, &errStr);
WARN(GPU, "Failed to free CUDA memory at pointer ", (void*)ptr, ": ", errStr);
}
res = cuMemAddressFree(mcPtr, mapSize);
if (res != CUDA_SUCCESS) {
(void)cuGetErrorString(res, &errStr);
WARN(GPU, "Failed to free CUDA memory at pointer ", (void*)mcPtr, ": ", errStr);
}
int deviceId;
CUdevice device;
if (cudaGetDevice(&deviceId) == cudaSuccess && cuDeviceGet(&device, deviceId) == CUDA_SUCCESS) {
(void)cuMulticastUnbind(allocHandle, device, mcOffset, bufferSize);
}
});
int deviceId;
CUdevice device;
if (cudaGetDevice(&deviceId) == cudaSuccess && cuDeviceGet(&device, deviceId) == CUDA_SUCCESS) {
(void)cuMulticastUnbind(allocHandle, device, mcOffset, bufferSize);
}
});
#else // !(CUDA_NVLS_API_AVAILABLE)
THROW(GPU, Error, ErrorCode::InvalidUsage,
"NVLS is not supported on this device (requires CUDA version >= 12.3 and Linux kernel version >= 5.6.0)");

---

@@ -8,10 +8,10 @@
#include "allgather/allgather_fullmesh_2.hpp"
#include "allreduce/allreduce_allpair_packet.hpp"
#include "allreduce/allreduce_fullmesh.hpp"
#include "allreduce/allreduce_nvls.hpp"
#include "allreduce/allreduce_nvls_block_pipeline.hpp"
#include "allreduce/allreduce_nvls_packet.hpp"
#include "allreduce/allreduce_nvls_with_copy.hpp"
#include "allreduce/allreduce_nvls_with_copy_2.hpp"
#include "allreduce/allreduce_nvls_warp_pipeline.hpp"
#include "allreduce/allreduce_nvls_zero_copy.hpp"
#include "allreduce/allreduce_packet.hpp"
#include "allreduce/allreduce_rsag.hpp"
#include "allreduce/allreduce_rsag_pipeline.hpp"
@@ -72,12 +72,14 @@ AlgorithmCollection AlgorithmCollectionBuilder::buildDefaultNativeAlgorithms(uin
auto allreduceNvlsPacket =
std::make_shared<AllreduceNvlsPacket>(scratchBuffer, scratchBufferSize, flagBuffer, flagBufferSize)->build();
collection.registerAlgorithm(allreduceNvlsPacket->collective(), allreduceNvlsPacket->name(), allreduceNvlsPacket);
auto allreduceNvlsWithCopy = std::make_shared<AllreduceNvlsWithCopy>(scratchBuffer, scratchBufferSize)->build();
collection.registerAlgorithm(allreduceNvlsWithCopy->collective(), allreduceNvlsWithCopy->name(),
allreduceNvlsWithCopy);
auto allreduceNvlsWithCopy2 = std::make_shared<AllreduceNvlsWithCopy2>(scratchBuffer, scratchBufferSize)->build();
collection.registerAlgorithm(allreduceNvlsWithCopy2->collective(), allreduceNvlsWithCopy2->name(),
allreduceNvlsWithCopy2);
auto allreduceNvlsWarpPipeline =
std::make_shared<AllreduceNvlsWarpPipeline>(scratchBuffer, scratchBufferSize)->build();
collection.registerAlgorithm(allreduceNvlsWarpPipeline->collective(), allreduceNvlsWarpPipeline->name(),
allreduceNvlsWarpPipeline);
auto allreduceNvlsBlockPipeline =
std::make_shared<AllreduceNvlsBlockPipeline>(scratchBuffer, scratchBufferSize)->build();
collection.registerAlgorithm(allreduceNvlsBlockPipeline->collective(), allreduceNvlsBlockPipeline->name(),
allreduceNvlsBlockPipeline);
auto allreducePkt =
std::make_shared<AllreducePacket>(scratchBuffer, scratchBufferSize, flagBuffer, flagBufferSize)->build();
collection.registerAlgorithm(allreducePkt->collective(), allreducePkt->name(), allreducePkt);

---

@@ -3,7 +3,7 @@
#include <mscclpp/algorithm.hpp>
#include "allreduce/allreduce_nvls_with_copy_2.hpp"
#include "allreduce/allreduce_nvls_block_pipeline.hpp"
#include "allreduce/common.hpp"
#include "collective_utils.hpp"
#include "debug.h"
@@ -15,11 +15,12 @@ __device__ DeviceSemaphore deviceSemaphore[NUM_SEMAPHORES];
template <typename T>
__global__ void __launch_bounds__(1024, 1)
allreduceNvlsWithCopy2([[maybe_unused]] const void* src, [[maybe_unused]] void* scratch, [[maybe_unused]] void* dst,
[[maybe_unused]] DeviceHandle<BaseMemoryChannel>* memoryChannels,
[[maybe_unused]] DeviceHandle<SwitchChannel>* switchChannels, [[maybe_unused]] size_t size,
[[maybe_unused]] size_t scratchBufferSize, [[maybe_unused]] int rank,
[[maybe_unused]] int nRanksPerNode) {
allreduceNvlsBlockPipeline([[maybe_unused]] const void* src, [[maybe_unused]] void* scratch,
[[maybe_unused]] void* dst,
[[maybe_unused]] DeviceHandle<BaseMemoryChannel>* memoryChannels,
[[maybe_unused]] DeviceHandle<SwitchChannel>* switchChannels,
[[maybe_unused]] size_t size, [[maybe_unused]] size_t scratchBufferSize,
[[maybe_unused]] int rank, [[maybe_unused]] int nRanksPerNode) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
constexpr int alignment = 16;
int nPeers = nRanksPerNode - 1;
@@ -146,7 +147,7 @@ __global__ void __launch_bounds__(1024, 1)
}
template <ReduceOp OpType, typename T>
struct NvlsWithCopy2Adapter {
struct NvlsBlockPipelineAdapter {
static cudaError_t call(const void* input, void* scratch, void* output, void* memoryChannels, void*,
DeviceHandle<SwitchChannel>* nvlsChannels, DeviceHandle<SwitchChannel>*, size_t, size_t,
size_t scratchBufferSize, int rank, int nRanksPerNode, int, size_t inputSize,
@@ -162,7 +163,7 @@ struct NvlsWithCopy2Adapter {
#endif
{
using ChannelType = DeviceHandle<BaseMemoryChannel>;
allreduceNvlsWithCopy2<T>
allreduceNvlsBlockPipeline<T>
<<<nBlocks, nThreadsPerBlock, 0, stream>>>(input, scratch, output, (ChannelType*)memoryChannels,
nvlsChannels, inputSize, scratchBufferSize, rank, nRanksPerNode);
return cudaGetLastError();
@@ -170,7 +171,7 @@ struct NvlsWithCopy2Adapter {
}
};
void AllreduceNvlsWithCopy2::initialize(std::shared_ptr<Communicator> comm) {
void AllreduceNvlsBlockPipeline::initialize(std::shared_ptr<Communicator> comm) {
nSwitchChannels_ = 8;
int nBaseChannels = 64;
this->conns_ = setupConnections(comm);
@@ -180,14 +181,15 @@ void AllreduceNvlsWithCopy2::initialize(std::shared_ptr<Communicator> comm) {
// setup base memory channels
this->baseChannels_ = setupBaseMemoryChannels(this->conns_, memorySemaphores, nBaseChannels);
this->memoryChannelsDeviceHandle_ = setupBaseMemoryChannelDeviceHandles(this->baseChannels_);
this->nvlsConnections_ = setupNvlsConnections(comm, nvlsBufferSize_, nSwitchChannels_);
}
CommResult AllreduceNvlsWithCopy2::allreduceKernelFunc(const std::shared_ptr<void> ctx_void, const void* input,
void* output, size_t inputSize, DataType dtype, ReduceOp op,
cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
const std::unordered_map<std::string, uintptr_t>&) {
CommResult AllreduceNvlsBlockPipeline::allreduceKernelFunc(const std::shared_ptr<void> ctx_void, const void* input,
void* output, size_t inputSize, DataType dtype, ReduceOp op,
cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
const std::unordered_map<std::string, uintptr_t>&) {
auto ctx = std::static_pointer_cast<AlgorithmCtx>(ctx_void);
AllreduceFunc allreduce = dispatch<NvlsWithCopy2Adapter>(op, dtype);
AllreduceFunc allreduce = dispatch<NvlsBlockPipelineAdapter>(op, dtype);
if (!allreduce) {
WARN("Unsupported operation or data type for allreduce, dtype=%d", static_cast<int>(dtype));
return CommResult::CommInvalidArgument;
@@ -201,35 +203,35 @@ CommResult AllreduceNvlsWithCopy2::allreduceKernelFunc(const std::shared_ptr<voi
ctx->rank, ctx->nRanksPerNode, ctx->workSize, inputSize, stream, nullptr, 0, 0,
blockAndThreadNum.first, blockAndThreadNum.second);
if (error != cudaSuccess) {
WARN("AllreduceNvlsWithCopy failed with error: %s", cudaGetErrorString(error));
WARN("AllreduceNvlsBlockPipeline failed with error: %s", cudaGetErrorString(error));
return CommResult::CommUnhandledCudaError;
}
return CommResult::CommSuccess;
}
AlgorithmCtxKey AllreduceNvlsWithCopy2::generateAllreduceContextKey(const void*, void*, size_t, DataType, bool) {
AlgorithmCtxKey AllreduceNvlsBlockPipeline::generateAllreduceContextKey(const void*, void*, size_t, DataType, bool) {
return AlgorithmCtxKey{nullptr, nullptr, 0, 0, 0};
}
std::shared_ptr<void> AllreduceNvlsWithCopy2::initAllreduceContext(std::shared_ptr<Communicator> comm, const void*,
void*, size_t, DataType) {
std::shared_ptr<void> AllreduceNvlsBlockPipeline::initAllreduceContext(std::shared_ptr<Communicator> comm, const void*,
void*, size_t, DataType) {
auto ctx = std::make_shared<AlgorithmCtx>();
ctx->rank = comm->bootstrap()->getRank();
ctx->workSize = comm->bootstrap()->getNranks();
ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode();
// setup channels
ctx->nvlsConnections = setupNvlsConnections(comm, nvlsBufferSize_, nSwitchChannels_);
ctx->switchChannels =
setupNvlsChannels(ctx->nvlsConnections, this->scratchBuffer_, scratchBufferSize_, nSwitchChannels_);
setupNvlsChannels(this->nvlsConnections_, this->scratchBuffer_, scratchBufferSize_, nSwitchChannels_);
ctx->switchChannelDeviceHandles = setupNvlsChannelDeviceHandles(ctx->switchChannels);
return ctx;
}
std::shared_ptr<Algorithm> AllreduceNvlsWithCopy2::build() {
auto self = std::make_shared<AllreduceNvlsWithCopy2>(reinterpret_cast<uintptr_t>(scratchBuffer_), scratchBufferSize_);
std::shared_ptr<Algorithm> AllreduceNvlsBlockPipeline::build() {
auto self =
std::make_shared<AllreduceNvlsBlockPipeline>(reinterpret_cast<uintptr_t>(scratchBuffer_), scratchBufferSize_);
return std::make_shared<NativeAlgorithm>(
"default_allreduce_nvls_with_copy2", "allreduce",
"default_allreduce_nvls_block_pipeline", "allreduce",
[self](std::shared_ptr<Communicator> comm) { self->initialize(comm); },
[self](const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
[[maybe_unused]] size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks,
@@ -247,4 +249,4 @@ std::shared_ptr<Algorithm> AllreduceNvlsWithCopy2::build() {
}
} // namespace collective
} // namespace mscclpp
} // namespace mscclpp

---

@@ -75,7 +75,10 @@ struct AllreduceNvlsPacketAdapter {
}
};
void AllreduceNvlsPacket::initialize(std::shared_ptr<Communicator>) {}
void AllreduceNvlsPacket::initialize(std::shared_ptr<Communicator> comm) {
int nSwitchChannels = 1;
this->nvlsConnections_ = setupNvlsConnections(comm, nvlsBufferSize_, nSwitchChannels);
}
AlgorithmCtxKey AllreduceNvlsPacket::generateAllreduceContextKey(const void*, void*, size_t, DataType, bool) {
return AlgorithmCtxKey{nullptr, nullptr, 0, 0, 0};
@@ -90,9 +93,8 @@ std::shared_ptr<void> AllreduceNvlsPacket::initAllreduceContext(std::shared_ptr<
// setup channels
int nSwitchChannels = 1;
ctx->nvlsConnections = setupNvlsConnections(comm, nvlsBufferSize_, nSwitchChannels);
ctx->switchChannels =
setupNvlsChannels(ctx->nvlsConnections, this->scratchBuffer_, this->scratchBufferSize_, nSwitchChannels);
setupNvlsChannels(this->nvlsConnections_, this->scratchBuffer_, this->scratchBufferSize_, nSwitchChannels);
ctx->switchChannelDeviceHandles = setupNvlsChannelDeviceHandles(ctx->switchChannels);
return ctx;
}

---

@@ -3,7 +3,7 @@
#include <mscclpp/algorithm.hpp>
#include "allreduce/allreduce_nvls_with_copy.hpp"
#include "allreduce/allreduce_nvls_warp_pipeline.hpp"
#include "allreduce/common.hpp"
#include "collective_utils.hpp"
#include "debug.h"
@@ -13,11 +13,12 @@ namespace collective {
template <typename T>
__global__ void __launch_bounds__(1024, 1)
allreduce10([[maybe_unused]] const void* src, [[maybe_unused]] void* scratch, [[maybe_unused]] void* dst,
[[maybe_unused]] DeviceHandle<BaseMemoryChannel>* memoryChannels,
[[maybe_unused]] DeviceHandle<SwitchChannel>* multicast, [[maybe_unused]] size_t size,
[[maybe_unused]] size_t scratchBufferSize, [[maybe_unused]] int rank,
[[maybe_unused]] int nRanksPerNode) {
allreduceNvlsWarpPipeline([[maybe_unused]] const void* src, [[maybe_unused]] void* scratch,
[[maybe_unused]] void* dst,
[[maybe_unused]] DeviceHandle<BaseMemoryChannel>* memoryChannels,
[[maybe_unused]] DeviceHandle<SwitchChannel>* multicast, [[maybe_unused]] size_t size,
[[maybe_unused]] size_t scratchBufferSize, [[maybe_unused]] int rank,
[[maybe_unused]] int nRanksPerNode) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
constexpr int alignment = 16;
int nPeers = nRanksPerNode - 1;
@@ -109,7 +110,7 @@ __global__ void __launch_bounds__(1024, 1)
}
template <ReduceOp OpType, typename T>
struct NvlsWithCopyAdapter {
struct NvlsWarpPipelineAdapter {
static cudaError_t call(const void* input, void* scratch, void* output, void* memoryChannels, void*,
DeviceHandle<SwitchChannel>* nvlsChannels, DeviceHandle<SwitchChannel>*, size_t, size_t,
size_t scratchBufferSize, int rank, int nRanksPerNode, int, size_t inputSize,
@@ -125,15 +126,15 @@ struct NvlsWithCopyAdapter {
#endif
{
using ChannelType = DeviceHandle<BaseMemoryChannel>;
allreduce10<T><<<nBlocks, nThreadsPerBlock, 0, stream>>>(input, scratch, output, (ChannelType*)memoryChannels,
nvlsChannels, inputSize, scratchBufferSize, rank,
nRanksPerNode);
allreduceNvlsWarpPipeline<T>
<<<nBlocks, nThreadsPerBlock, 0, stream>>>(input, scratch, output, (ChannelType*)memoryChannels,
nvlsChannels, inputSize, scratchBufferSize, rank, nRanksPerNode);
return cudaGetLastError();
}
}
};
void AllreduceNvlsWithCopy::initialize(std::shared_ptr<Communicator> comm) {
void AllreduceNvlsWarpPipeline::initialize(std::shared_ptr<Communicator> comm) {
nSwitchChannels_ = 8;
int nBaseChannels = 64;
this->conns_ = setupConnections(comm);
@@ -143,14 +144,15 @@ void AllreduceNvlsWithCopy::initialize(std::shared_ptr<Communicator> comm) {
// setup base memory channels
this->baseChannels_ = setupBaseMemoryChannels(this->conns_, memorySemaphores, nBaseChannels);
this->memoryChannelsDeviceHandle_ = setupBaseMemoryChannelDeviceHandles(this->baseChannels_);
this->nvlsConnections_ = setupNvlsConnections(comm, nvlsBufferSize_, nSwitchChannels_);
}
CommResult AllreduceNvlsWithCopy::allreduceKernelFunc(const std::shared_ptr<void> ctx_void, const void* input,
void* output, size_t inputSize, DataType dtype, ReduceOp op,
cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
const std::unordered_map<std::string, uintptr_t>&) {
CommResult AllreduceNvlsWarpPipeline::allreduceKernelFunc(const std::shared_ptr<void> ctx_void, const void* input,
void* output, size_t inputSize, DataType dtype, ReduceOp op,
cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
const std::unordered_map<std::string, uintptr_t>&) {
auto ctx = std::static_pointer_cast<AlgorithmCtx>(ctx_void);
AllreduceFunc allreduce = dispatch<NvlsWithCopyAdapter>(op, dtype);
AllreduceFunc allreduce = dispatch<NvlsWarpPipelineAdapter>(op, dtype);
if (!allreduce) {
WARN("Unsupported operation or data type for allreduce, dtype=%d", static_cast<int>(dtype));
return CommResult::CommInvalidArgument;
@@ -164,35 +166,35 @@ CommResult AllreduceNvlsWithCopy::allreduceKernelFunc(const std::shared_ptr<void
ctx->rank, ctx->nRanksPerNode, ctx->workSize, inputSize, stream, nullptr, 0, 0,
blockAndThreadNum.first, blockAndThreadNum.second);
if (error != cudaSuccess) {
WARN("AllreduceNvlsWithCopy failed with error: %s", cudaGetErrorString(error));
WARN("AllreduceNvlsWarpPipeline failed with error: %s", cudaGetErrorString(error));
return CommResult::CommUnhandledCudaError;
}
return CommResult::CommSuccess;
}
-AlgorithmCtxKey AllreduceNvlsWithCopy::generateAllreduceContextKey(const void*, void*, size_t, DataType, bool) {
+AlgorithmCtxKey AllreduceNvlsWarpPipeline::generateAllreduceContextKey(const void*, void*, size_t, DataType, bool) {
   return AlgorithmCtxKey{nullptr, nullptr, 0, 0, 0};
 }
-std::shared_ptr<void> AllreduceNvlsWithCopy::initAllreduceContext(std::shared_ptr<Communicator> comm, const void*,
-                                                                  void*, size_t, DataType) {
+std::shared_ptr<void> AllreduceNvlsWarpPipeline::initAllreduceContext(std::shared_ptr<Communicator> comm, const void*,
+                                                                      void*, size_t, DataType) {
   auto ctx = std::make_shared<AlgorithmCtx>();
   ctx->rank = comm->bootstrap()->getRank();
   ctx->workSize = comm->bootstrap()->getNranks();
   ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode();
   // setup channels
-  ctx->nvlsConnections = setupNvlsConnections(comm, nvlsBufferSize_, nSwitchChannels_);
   ctx->switchChannels =
-      setupNvlsChannels(ctx->nvlsConnections, this->scratchBuffer_, scratchBufferSize_, nSwitchChannels_);
+      setupNvlsChannels(this->nvlsConnections_, this->scratchBuffer_, scratchBufferSize_, nSwitchChannels_);
   ctx->switchChannelDeviceHandles = setupNvlsChannelDeviceHandles(ctx->switchChannels);
   return ctx;
 }
-std::shared_ptr<Algorithm> AllreduceNvlsWithCopy::build() {
-  auto self = std::make_shared<AllreduceNvlsWithCopy>(reinterpret_cast<uintptr_t>(scratchBuffer_), scratchBufferSize_);
+std::shared_ptr<Algorithm> AllreduceNvlsWarpPipeline::build() {
+  auto self =
+      std::make_shared<AllreduceNvlsWarpPipeline>(reinterpret_cast<uintptr_t>(scratchBuffer_), scratchBufferSize_);
   return std::make_shared<NativeAlgorithm>(
-      "default_allreduce_nvls_with_copy", "allreduce",
+      "default_allreduce_nvls_warp_pipeline", "allreduce",
       [self](std::shared_ptr<Communicator> comm) { self->initialize(comm); },
       [self](const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
              [[maybe_unused]] size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks,
@@ -209,4 +211,4 @@ std::shared_ptr<Algorithm> AllreduceNvlsWithCopy::build() {
       });
 }
 }  // namespace collective
 }  // namespace mscclpp

View File

@@ -3,7 +3,7 @@
 #include <mscclpp/core.hpp>
-#include "allreduce/allreduce_nvls.hpp"
+#include "allreduce/allreduce_nvls_zero_copy.hpp"
 #include "allreduce/common.hpp"
 #include "collective_utils.hpp"
 #include "debug.h"
@@ -11,6 +11,8 @@
 namespace mscclpp {
 namespace collective {
+constexpr int MAX_NBLOCKS = 32;
 template <typename T>
 __global__ void __launch_bounds__(1024, 1)
     allreduceNvls([[maybe_unused]] mscclpp::DeviceHandle<mscclpp::BaseMemoryChannel>* memoryChannels,
@@ -105,6 +107,8 @@ void AllreduceNvls::initialize(std::shared_ptr<mscclpp::Communicator> comm) {
   // setup base memory channels
   this->baseChannels_ = setupBaseMemoryChannels(this->conns_, memorySemaphores, nSwitchChannels_);
   this->memoryChannelsDeviceHandle_ = setupBaseMemoryChannelDeviceHandles(this->baseChannels_);
+  this->nvlsConnections_ = setupNvlsConnections(comm, nvlsBufferSize_, nSwitchChannels_);
+  this->nvlsOutConnections_ = setupNvlsConnections(comm, nvlsBufferSize_, nSwitchChannels_);
 }
@@ -134,12 +138,18 @@ CommResult AllreduceNvls::allreduceKernelFunc(const std::shared_ptr<void> ctx_vo
   }
   std::pair<int, int> numBlocksAndThreads = {nBlocks, nThreadsPerBlock};
   if (numBlocksAndThreads.first == 0 || numBlocksAndThreads.second == 0) {
-    numBlocksAndThreads = {::min(ctx->nRanksPerNode, nSwitchChannels_), 1024};
-    // For GB200 devices, using more blocks to improve the performances when nRanksPerNode <= 8
-    if (computeCapabilityMajor_ == 10 && ctx->nRanksPerNode <= 8) {
-      numBlocksAndThreads.first = ::min(32, nSwitchChannels_);
+    numBlocksAndThreads = {::min(ctx->nRanksPerNode, MAX_NBLOCKS), 1024};
+    // For GB200 devices with MNNVLS (Multi-Node NVLink Sharp), scale the number of blocks inversely with
+    // the number of GPUs. Empirically, 32 blocks works well for 4 GPUs and 16 for 8 GPUs, which
+    // follows the formula 128 / nGPUs, clamped to [1, MAX_NBLOCKS].
+    if (computeCapabilityMajor_ == 10) {
+      numBlocksAndThreads.first = ::max(1, ::min(128 / ctx->workSize, MAX_NBLOCKS));
     }
   }
+  if (numBlocksAndThreads.first > MAX_NBLOCKS) {
+    WARN("Number of blocks exceeds maximum supported value of %d", MAX_NBLOCKS);
+    return CommResult::CommInvalidArgument;
+  }
   cudaError_t error =
       allreduce(nullptr, nullptr, nullptr, this->memoryChannelsDeviceHandle_.get(), nullptr, nvlsChannels,
                 nvlsOutChannels, channelInOffset, channelOutOffset, 0, ctx->rank, ctx->nRanksPerNode, ctx->workSize,
@@ -174,13 +184,11 @@ std::shared_ptr<void> AllreduceNvls::initAllreduceContext(std::shared_ptr<mscclp
   MSCCLPP_CUTHROW(cuMemGetAddressRange(&recvBasePtr, &recvBytes, (CUdeviceptr)output));
   // setup channels
-  ctx->nvlsConnections = setupNvlsConnections(comm, nvlsBufferSize_, nSwitchChannels_);
-  ctx->switchChannels = setupNvlsChannels(ctx->nvlsConnections, (void*)sendBasePtr, sendBytes, nSwitchChannels_);
+  ctx->switchChannels = setupNvlsChannels(this->nvlsConnections_, (void*)sendBasePtr, sendBytes, nSwitchChannels_);
   if (input != output) {
-    auto nvlsOutConnections = setupNvlsConnections(comm, nvlsBufferSize_, nSwitchChannels_);
+    auto nvlsOutConnections = this->nvlsOutConnections_;
     std::vector<mscclpp::SwitchChannel> outChannels =
-        setupNvlsChannels(nvlsOutConnections, (void*)recvBasePtr, recvBytes, nSwitchChannels_);
-    ctx->nvlsConnections.insert(ctx->nvlsConnections.end(), nvlsOutConnections.begin(), nvlsOutConnections.end());
+        setupNvlsChannels(this->nvlsOutConnections_, (void*)recvBasePtr, recvBytes, nSwitchChannels_);
     ctx->switchChannels.insert(ctx->switchChannels.end(), outChannels.begin(), outChannels.end());
   }
@@ -191,7 +199,7 @@ std::shared_ptr<void> AllreduceNvls::initAllreduceContext(std::shared_ptr<mscclp
 std::shared_ptr<mscclpp::Algorithm> AllreduceNvls::build() {
   auto self = std::make_shared<AllreduceNvls>();
   return std::make_shared<mscclpp::NativeAlgorithm>(
-      "default_allreduce_nvls", "allreduce",
+      "default_allreduce_nvls_zero_copy", "allreduce",
       [self](std::shared_ptr<mscclpp::Communicator> comm) { self->initialize(comm); },
       [self](const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
              [[maybe_unused]] size_t outputSize, mscclpp::DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks,

View File

@@ -1,14 +1,17 @@
 // Copyright (c) Microsoft Corporation.
 // Licensed under the MIT License.
+#ifndef MSCCLPP_EXT_ALLREDUCE_NVLS_BLOCK_PIPELINE_HPP_
+#define MSCCLPP_EXT_ALLREDUCE_NVLS_BLOCK_PIPELINE_HPP_
 #include <mscclpp/algorithm.hpp>
 namespace mscclpp {
 namespace collective {
-class AllreduceNvlsWithCopy : public AlgorithmBuilder {
+class AllreduceNvlsBlockPipeline : public AlgorithmBuilder {
  public:
-  AllreduceNvlsWithCopy(uintptr_t scratchBuffer, size_t scratchBufferSize)
+  AllreduceNvlsBlockPipeline(uintptr_t scratchBuffer, size_t scratchBufferSize)
       : scratchBuffer_(reinterpret_cast<void*>(scratchBuffer)), scratchBufferSize_(scratchBufferSize){};
   std::shared_ptr<Algorithm> build() override;
@@ -29,6 +32,9 @@ class AllreduceNvlsWithCopy : public AlgorithmBuilder {
   std::shared_ptr<DeviceHandle<BaseMemoryChannel>> memoryChannelsDeviceHandle_;
   std::vector<BaseMemoryChannel> baseChannels_;
   std::vector<Connection> conns_;
+  std::vector<std::shared_ptr<NvlsConnection>> nvlsConnections_;
 };
 }  // namespace collective
 }  // namespace mscclpp
+#endif  // MSCCLPP_EXT_ALLREDUCE_NVLS_BLOCK_PIPELINE_HPP_

View File

@@ -33,6 +33,7 @@ class AllreduceNvlsPacket : public mscclpp::AlgorithmBuilder {
   const int maxBlockNum_ = 16;
   uintptr_t flagBuffer_;
   size_t flagBufferSize_;
+  std::vector<std::shared_ptr<NvlsConnection>> nvlsConnections_;
 };
 }  // namespace collective
 }  // namespace mscclpp

View File

@@ -1,17 +1,17 @@
 // Copyright (c) Microsoft Corporation.
 // Licensed under the MIT License.
-#ifndef MSCCLPP_EXT_ALLREDUCE_NVLS_WITH_COPY_2_HPP_
-#define MSCCLPP_EXT_ALLREDUCE_NVLS_WITH_COPY_2_HPP_
+#ifndef MSCCLPP_EXT_ALLREDUCE_NVLS_WARP_PIPELINE_HPP_
+#define MSCCLPP_EXT_ALLREDUCE_NVLS_WARP_PIPELINE_HPP_
 #include <mscclpp/algorithm.hpp>
 namespace mscclpp {
 namespace collective {
-class AllreduceNvlsWithCopy2 : public AlgorithmBuilder {
+class AllreduceNvlsWarpPipeline : public AlgorithmBuilder {
  public:
-  AllreduceNvlsWithCopy2(uintptr_t scratchBuffer, size_t scratchBufferSize)
+  AllreduceNvlsWarpPipeline(uintptr_t scratchBuffer, size_t scratchBufferSize)
       : scratchBuffer_(reinterpret_cast<void*>(scratchBuffer)), scratchBufferSize_(scratchBufferSize){};
   std::shared_ptr<Algorithm> build() override;
@@ -32,8 +32,9 @@ class AllreduceNvlsWithCopy2 : public AlgorithmBuilder {
   std::shared_ptr<DeviceHandle<BaseMemoryChannel>> memoryChannelsDeviceHandle_;
   std::vector<BaseMemoryChannel> baseChannels_;
   std::vector<Connection> conns_;
+  std::vector<std::shared_ptr<NvlsConnection>> nvlsConnections_;
 };
 }  // namespace collective
 }  // namespace mscclpp
-#endif  // MSCCLPP_EXT_ALLREDUCE_NVLS_WITH_COPY_2_HPP_
+#endif  // MSCCLPP_EXT_ALLREDUCE_NVLS_WARP_PIPELINE_HPP_

View File

@@ -1,6 +1,9 @@
 // Copyright (c) Microsoft Corporation.
 // Licensed under the MIT License.
+#ifndef MSCCLPP_ALLREDUCE_NVLS_ZERO_COPY_HPP_
+#define MSCCLPP_ALLREDUCE_NVLS_ZERO_COPY_HPP_
 #include <mscclpp/algorithm.hpp>
 namespace mscclpp {
@@ -22,13 +25,20 @@ class AllreduceNvls : public AlgorithmBuilder {
                 DataType);
   AlgorithmCtxKey generateAllreduceContextKey(const void*, void*, size_t, DataType, bool);
-  const size_t nvlsBufferSize_ = (1 << 30);
+  // Large buffer size because cuMemMap requires offset=0 for multicast handles, so the entire
+  // user allocation must be mapped. This only reserves virtual address space; no physical memory
+  // is consumed beyond what is actually bound.
+  const size_t nvlsBufferSize_ = (1UL << 34);
   uint32_t nSwitchChannels_;
   std::shared_ptr<DeviceHandle<BaseMemoryChannel>> memoryChannelsDeviceHandle_;
   std::vector<BaseMemoryChannel> baseChannels_;
   std::vector<Connection> conns_;
+  std::vector<std::shared_ptr<NvlsConnection>> nvlsConnections_;
+  std::vector<std::shared_ptr<NvlsConnection>> nvlsOutConnections_;
   int computeCapabilityMajor_{0};
 };
 }  // namespace collective
 }  // namespace mscclpp
+#endif  // MSCCLPP_ALLREDUCE_NVLS_ZERO_COPY_HPP_

View File

@@ -78,7 +78,6 @@ class AlgorithmCtx {
   std::vector<MemoryChannel> memoryChannels;
   std::vector<SwitchChannel> switchChannels;
   std::vector<PortChannel> portChannels;
-  std::vector<std::shared_ptr<NvlsConnection>> nvlsConnections;
   std::shared_ptr<DeviceHandle<MemoryChannel>> memoryChannelDeviceHandles;
   std::shared_ptr<DeviceHandle<SwitchChannel>> switchChannelDeviceHandles;
   std::shared_ptr<DeviceHandle<PortChannel>> portChannelDeviceHandles;

View File

@@ -88,7 +88,7 @@ static std::shared_ptr<Algorithm> selectSingleNodeAllreduceBlackwell(
     return algoMap.at("default_allreduce_packet");
   }
   if (useNvlsWithZeroCopy) {
-    return algoMap.at("default_allreduce_nvls");
+    return algoMap.at("default_allreduce_nvls_zero_copy");
   }
   return algoMap.at("default_allreduce_rsag_zero_copy");
@@ -123,14 +123,14 @@ std::shared_ptr<Algorithm> selectSingleNodeAllreduce(
   }
   // Large messages with NVLS zero-copy support
   if (nvlsSupported && useNvlsWithZeroCopy) {
-    return algoMap.at("default_allreduce_nvls");
+    return algoMap.at("default_allreduce_nvls_zero_copy");
   }
   // Large messages with NVLS but without zero-copy
   if (nvlsSupported) {
     if (messageSize < (1 << 24)) {  // < 16MB
-      return algoMap.at("default_allreduce_nvls_with_copy");
+      return algoMap.at("default_allreduce_nvls_warp_pipeline");
     }
-    return algoMap.at("default_allreduce_nvls_with_copy2");
+    return algoMap.at("default_allreduce_nvls_block_pipeline");
   }
 #if defined(__HIP_PLATFORM_AMD__)
   // AMD platform: use fullmesh algorithm

View File

@@ -68,3 +68,70 @@ TEST_F(SwitchChannelTest, SimpleAllReduce) {
   }
   ASSERT_EQ(result, expected) << "Expected " << expected << " but got " << result << " for rank " << gEnv->rank;
 }
+
+__constant__ mscclpp::SwitchChannelDeviceHandle gConstSwitchChan1;
+__constant__ mscclpp::SwitchChannelDeviceHandle gConstSwitchChan2;
+
+__global__ void kernelSwitchReduceTwo() {
+#if (CUDA_NVLS_API_AVAILABLE) && (__CUDA_ARCH__ >= 900)
+  auto val1 = gConstSwitchChan1.reduce<mscclpp::f32x1>(0);
+  gConstSwitchChan1.broadcast(0, val1);
+  auto val2 = gConstSwitchChan2.reduce<mscclpp::f32x1>(0);
+  gConstSwitchChan2.broadcast(0, val2);
+#endif  // (CUDA_NVLS_API_AVAILABLE) && (__CUDA_ARCH__ >= 900)
+}
+
+TEST_F(SwitchChannelTest, TwoChannelsSameConnection) {
+  if (gEnv->rank >= numRanksToUse) return;
+
+  std::vector<int> ranks;
+  for (int i = 0; i < numRanksToUse; i++) {
+    ranks.push_back(i);
+  }
+
+  const size_t bufSize = 1024;
+  auto buffer1 = mscclpp::GpuBuffer<float>(bufSize / sizeof(float));
+  auto buffer2 = mscclpp::GpuBuffer<float>(bufSize / sizeof(float));
+  float data1 = (gEnv->rank + 1.0f) * 1.0f;
+  float data2 = (gEnv->rank + 1.0f) * 10.0f;
+  MSCCLPP_CUDATHROW(cudaMemcpy(buffer1.data(), &data1, sizeof(data1), cudaMemcpyHostToDevice));
+  MSCCLPP_CUDATHROW(cudaMemcpy(buffer2.data(), &data2, sizeof(data2), cudaMemcpyHostToDevice));
+
+  // Connection size must be large enough for two granularity-aligned buffers.
+  // The multicast granularity is typically 2MB, so we need at least 2 * 2MB.
+  const size_t connSize = buffer1.bytes() + buffer2.bytes();
+  auto nvlsConnection = mscclpp::connectNvlsCollective(communicator, ranks, connSize);
+
+  // Bind two separate buffers to the same connection
+  auto switchChannel1 = nvlsConnection->bindAllocatedMemory(CUdeviceptr(buffer1.data()), bufSize);
+  auto switchChannel2 = nvlsConnection->bindAllocatedMemory(CUdeviceptr(buffer2.data()), bufSize);
+
+  auto deviceHandle1 = switchChannel1.deviceHandle();
+  auto deviceHandle2 = switchChannel2.deviceHandle();
+  MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gConstSwitchChan1, &deviceHandle1, sizeof(deviceHandle1)));
+  MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gConstSwitchChan2, &deviceHandle2, sizeof(deviceHandle2)));
+  MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
+  communicator->bootstrap()->barrier();
+
+  if (gEnv->rank == 0) {
+    kernelSwitchReduceTwo<<<1, 1>>>();
+    MSCCLPP_CUDATHROW(cudaGetLastError());
+    MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
+  }
+  communicator->bootstrap()->barrier();
+
+  float result1, result2;
+  MSCCLPP_CUDATHROW(cudaMemcpy(&result1, buffer1.data(), sizeof(result1), cudaMemcpyDeviceToHost));
+  MSCCLPP_CUDATHROW(cudaMemcpy(&result2, buffer2.data(), sizeof(result2), cudaMemcpyDeviceToHost));
+
+  float expected1 = 0.0f;
+  float expected2 = 0.0f;
+  for (int i = 0; i < numRanksToUse; i++) {
+    expected1 += (i + 1.0f) * 1.0f;
+    expected2 += (i + 1.0f) * 10.0f;
+  }
+  ASSERT_EQ(result1, expected1) << "Channel1: expected " << expected1 << " but got " << result1;
+  ASSERT_EQ(result2, expected2) << "Channel2: expected " << expected2 << " but got " << result2;
+}