Move the alltoallv kernel to the src directory; Utilize the kernel in mscclpp-test

This commit is contained in:
Qinghua Zhou
2026-02-06 02:57:34 +00:00
parent ac3e770c42
commit c42579e900
6 changed files with 440 additions and 486 deletions

View File

@@ -1,372 +1,22 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
// AllToAllV implementation for MSCCLPP
// This kernel handles variable element counts per rank for alltoallv operations.
// Unlike NCCL's ncclGroupStart/ncclGroupEnd approach, mscclpp uses explicit
// put/signal/wait operations on PortChannels.
// AllToAllV Python bindings for MSCCLPP
// This file provides Python bindings for the alltoallv algorithm.
// The actual implementation is in src/ext/collectives/alltoallv/
#include <Python.h>
#include <pybind11/pybind11.h>
#include <mscclpp/algorithm.hpp>
#include <mscclpp/core.hpp>
#include <mscclpp/port_channel.hpp>
#include <mscclpp/port_channel_device.hpp>
#include <mscclpp/concurrency_device.hpp>
// Include the implementation header
#include "alltoallv/alltoallv_fullmesh.hpp"
namespace py = pybind11;
#if defined(__HIP_PLATFORM_AMD__)
#define WARP_SIZE 64
#else
#define WARP_SIZE 32
#endif
// Device syncer for synchronization across blocks
__device__ mscclpp::DeviceSyncer alltoallvDeviceSyncer;
/**
* AllToAllV kernel implementation
*
* This kernel performs an all-to-all exchange with variable-length data per rank.
* Each rank sends sendCounts[i] elements to rank i at sendDispls[i] offset,
* and receives recvCounts[i] elements from rank i at recvDispls[i] offset.
*
* Since mscclpp doesn't support ncclGroupStart/ncclGroupEnd, we implement
* the exchange using explicit put/signal/wait operations on PortChannels.
* The communication pattern uses a ring-based approach to avoid deadlocks.
*
* @param portChannels Array of PortChannel handles for each peer (worldSize-1 channels)
* @param rank Current rank
* @param worldSize Total number of ranks
* @param sendBuff Source buffer containing data to send
* @param recvBuff Destination buffer for received data
* @param sendCounts Array of send counts for each rank (in bytes)
* @param sendDispls Array of send displacements for each rank (in bytes)
* @param recvCounts Array of receive counts for each rank (in bytes)
* @param recvDispls Array of receive displacements for each rank (in bytes)
*/
__global__ void __launch_bounds__(1024)
alltoallv_kernel(mscclpp::DeviceHandle<mscclpp::PortChannel>* portChannels,
int rank,
int worldSize,
const void* sendBuff,
void* recvBuff,
const size_t* sendCounts,
const size_t* sendDispls,
const size_t* recvCounts,
const size_t* recvDispls) {
// First, copy local data (rank's own portion) from send to recv buffer
// This doesn't require any communication
if (threadIdx.x == 0 && blockIdx.x == 0) {
if (sendCounts[rank] > 0) {
// Local copy: sendBuff[sendDispls[rank]] -> recvBuff[recvDispls[rank]]
const char* src = (const char*)sendBuff + sendDispls[rank];
char* dst = (char*)recvBuff + recvDispls[rank];
memcpy(dst, src, sendCounts[rank]);
}
}
// NOTE(review): __syncthreads() only orders threads within this block; with
// gridDim.x > 1 other blocks are not ordered after the local copy above.
// In practice only thread (0,0) touches the channels below, so the copy and
// all communication are serialized on a single thread anyway.
__syncthreads();
// Ring-based exchange pattern to avoid deadlocks
// In each step i, rank sends to (rank + i) % worldSize and receives from (rank - i + worldSize) % worldSize
for (int step = 1; step < worldSize; step++) {
int sendPeer = (rank + step) % worldSize;
int recvPeer = (rank - step + worldSize) % worldSize;
// Get channel indices (portChannels excludes self, so adjust index)
int sendChanIdx = sendPeer < rank ? sendPeer : sendPeer - 1;
int recvChanIdx = recvPeer < rank ? recvPeer : recvPeer - 1;
// Each warp handles one peer
// NOTE(review): despite the comment above, wid/lid are computed but every
// channel operation below is guarded by wid == 0 && lid == 0, i.e. only
// global thread 0 communicates; other warps idle until the syncer.
int wid = threadIdx.x / WARP_SIZE;
int lid = threadIdx.x % WARP_SIZE;
// Send data to sendPeer if there's data to send
if (wid == 0 && lid == 0) {
if (sendCounts[sendPeer] > 0) {
// putWithSignal: copy data and signal completion
// src offset: sendDispls[sendPeer] in our sendBuff
// dst offset: recvDispls[rank] in peer's recvBuff (where our data should go)
portChannels[sendChanIdx].putWithSignal(
recvDispls[rank], // dst offset in peer's recv buffer (where we write)
sendDispls[sendPeer], // src offset in our send buffer
sendCounts[sendPeer] // size in bytes
);
}
}
// Sync all threads before flushing
alltoallvDeviceSyncer.sync(gridDim.x);
// Flush to ensure data is sent
if (wid == 0 && lid == 0) {
if (sendCounts[sendPeer] > 0) {
portChannels[sendChanIdx].flush();
}
}
// Wait for data from recvPeer if we're expecting data
// Each channel is signaled/waited at most once per kernel call: every step
// targets a distinct sendPeer/recvPeer, so sendChanIdx/recvChanIdx never repeat.
if (wid == 0 && lid == 0) {
if (recvCounts[recvPeer] > 0) {
portChannels[recvChanIdx].wait();
}
}
// Sync all threads before next step
alltoallvDeviceSyncer.sync(gridDim.x);
}
}
/**
* Simplified AllToAllV kernel for single-block execution
*
* This version is optimized for cases where all communication can be
* handled within a single thread block.
*/
__global__ void __launch_bounds__(1024)
alltoallv_simple_kernel(mscclpp::DeviceHandle<mscclpp::PortChannel>* portChannels,
int rank,
int worldSize,
const void* sendBuff,
void* recvBuff,
const size_t* sendCounts,
const size_t* sendDispls,
const size_t* recvCounts,
const size_t* recvDispls) {
// Expects a single-block launch: the host side launches <<<1, nThreads>>>
// with one warp per peer (see alltoallvKernelFunc), and __syncthreads() is
// the only barrier used.
int tid = threadIdx.x;
int nPeers = worldSize - 1;
// Step 1: Copy local data
if (tid == 0 && sendCounts[rank] > 0) {
const char* src = (const char*)sendBuff + sendDispls[rank];
char* dst = (char*)recvBuff + recvDispls[rank];
memcpy(dst, src, sendCounts[rank]);
}
__syncthreads();
// Step 2: Each warp handles one peer for sending
// We have worldSize-1 peers, assign one warp per peer
int warpId = tid / WARP_SIZE;
int laneId = tid % WARP_SIZE;
// Only lane 0 of each warp issues channel operations; channel index warpId
// maps to peer rank (warpId if below rank, warpId+1 otherwise) because the
// channel array excludes self.
if (warpId < nPeers && laneId == 0) {
// Determine which peer this warp handles
int peer = warpId < rank ? warpId : warpId + 1;
int chanIdx = warpId;
if (sendCounts[peer] > 0) {
portChannels[chanIdx].putWithSignal(
recvDispls[rank], // dst offset in peer's buffer
sendDispls[peer], // src offset in our buffer
sendCounts[peer] // size
);
}
}
// Barriers are placed outside the divergent lane-0 branches so all threads
// reach them.
__syncthreads();
// Step 3: Flush all pending operations
if (warpId < nPeers && laneId == 0) {
int peer = warpId < rank ? warpId : warpId + 1;
if (sendCounts[peer] > 0) {
portChannels[warpId].flush();
}
}
__syncthreads();
// Step 4: Wait for all incoming data
// recvCounts[peer] mirrors the peer's sendCounts for us, so we only wait on
// channels that will actually be signaled.
if (warpId < nPeers && laneId == 0) {
int peer = warpId < rank ? warpId : warpId + 1;
if (recvCounts[peer] > 0) {
portChannels[warpId].wait();
}
}
__syncthreads();
}
// Context to hold all necessary state for alltoallv execution
// Context holding the per-(buffer, size) state for one alltoallv execution.
// Created by initAlltoallvContext() and cached by the algorithm framework.
//
// Note: the previous version carried four raw size_t* members
// (d_sendCounts/d_sendDispls/d_recvCounts/d_recvDispls) that were never
// allocated, initialized, or read anywhere in this file — the counts and
// displacements are passed per-call as device pointers via the `extras`
// map instead (see alltoallvKernelFunc). They have been removed.
struct AllToAllVContext {
  int rank;           // this rank's index in the communicator
  int worldSize;      // total number of ranks
  int nRanksPerNode;  // ranks per node (informational)
  // Keeps local/remote buffer registrations alive for the context's lifetime.
  std::vector<mscclpp::RegisteredMemory> registeredMemories;
  // Device-resident array of worldSize-1 PortChannel handles (one per peer).
  std::shared_ptr<mscclpp::DeviceHandle<mscclpp::PortChannel>> portChannelDeviceHandles;
};
// Builder for the (legacy, in-file) alltoallv algorithm. Owns the full-mesh
// connections and the proxy service that drives the port channels.
class AllToAllVAlgoBuilder : public mscclpp::AlgorithmBuilder {
public:
AllToAllVAlgoBuilder() = default;
// Stop the background proxy thread before members are destroyed.
~AllToAllVAlgoBuilder() {
if (proxyService_) {
proxyService_->stopProxy();
}
}
// Assemble a NativeAlgorithm whose lambdas dispatch to the private methods
// below.
// NOTE(review): build() creates a FRESH builder instance `self` and wires the
// lambdas to it — the object build() was called on is never initialized or
// used at runtime. The shared_ptr `self` keeps that fresh instance alive for
// as long as the returned algorithm holds the lambdas.
std::shared_ptr<mscclpp::Algorithm> build() override {
auto self = std::make_shared<AllToAllVAlgoBuilder>();
std::shared_ptr<mscclpp::Algorithm> alltoallvAlgo = std::make_shared<mscclpp::NativeAlgorithm>(
"alltoallv", "alltoallv",
// Initialize function
[self](std::shared_ptr<mscclpp::Communicator> comm) { self->initialize(comm); },
// Kernel execution function
[self](const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize, size_t outputSize,
mscclpp::DataType dtype, [[maybe_unused]] mscclpp::ReduceOp op, cudaStream_t stream, int nBlocks,
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras) {
return self->alltoallvKernelFunc(ctx, input, output, inputSize, outputSize, dtype, stream, extras);
},
// Context initialization function
[self](std::shared_ptr<mscclpp::Communicator> comm, const void* input, void* output, size_t inputSize,
size_t outputSize,
mscclpp::DataType dtype) { return self->initAlltoallvContext(comm, input, output, inputSize, outputSize, dtype); },
// Context key generation function
[self](const void* input, void* output, size_t inputSize, size_t outputSize, mscclpp::DataType dtype) {
return self->generateAlltoallvContextKey(input, output, inputSize, outputSize, dtype);
});
return alltoallvAlgo;
}
private:
std::vector<mscclpp::Connection> conns_;
std::shared_ptr<mscclpp::ProxyService> proxyService_;
// Set by initialize(); uninitialized until then — do not read before
// initialize() has run.
int worldSize_;
// Connect to every other rank over CudaIpc (full mesh) and start the proxy.
void initialize(std::shared_ptr<mscclpp::Communicator> comm) {
std::vector<std::shared_future<mscclpp::Connection>> connectionFutures;
worldSize_ = comm->bootstrap()->getNranks();
// Issue all connect requests first so the handshakes overlap; resolve below.
for (int i = 0; i < worldSize_; i++) {
if (i == comm->bootstrap()->getRank()) continue;
connectionFutures.push_back(comm->connect(mscclpp::Transport::CudaIpc, i));
}
std::vector<mscclpp::Connection> connections;
std::transform(connectionFutures.begin(), connectionFutures.end(), std::back_inserter(connections),
[](const auto& future) { return future.get(); });
this->conns_ = std::move(connections);
proxyService_ = std::make_shared<mscclpp::ProxyService>();
proxyService_->startProxy(true);
}
// Launch the alltoallv kernel. Counts/displacements arrive as device
// pointers through the `extras` map (keys: sendCounts/sendDispls/
// recvCounts/recvDispls); returns CommInternalError if any key is missing
// or the launch fails.
mscclpp::CommResult alltoallvKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output,
size_t inputSize, size_t outputSize,
[[maybe_unused]] mscclpp::DataType dtype,
cudaStream_t stream,
const std::unordered_map<std::string, uintptr_t>& extras) {
auto algoCtx = std::static_pointer_cast<AllToAllVContext>(ctx);
int rank = algoCtx->rank;
int worldSize = algoCtx->worldSize;
// Extract send/recv counts and displacements from extras
// The caller should pass these as device pointers via extras map
auto it_sendCounts = extras.find("sendCounts");
auto it_sendDispls = extras.find("sendDispls");
auto it_recvCounts = extras.find("recvCounts");
auto it_recvDispls = extras.find("recvDispls");
if (it_sendCounts == extras.end() || it_sendDispls == extras.end() ||
it_recvCounts == extras.end() || it_recvDispls == extras.end()) {
return mscclpp::CommResult::CommInternalError;
}
const size_t* d_sendCounts = reinterpret_cast<const size_t*>(it_sendCounts->second);
const size_t* d_sendDispls = reinterpret_cast<const size_t*>(it_sendDispls->second);
const size_t* d_recvCounts = reinterpret_cast<const size_t*>(it_recvCounts->second);
const size_t* d_recvDispls = reinterpret_cast<const size_t*>(it_recvDispls->second);
// Reset device syncer
// NOTE(review): the return value of cudaMemcpyToSymbolAsync is not checked;
// a failure here would surface only as a corrupted sync state later.
mscclpp::DeviceSyncer syncer = {};
cudaMemcpyToSymbolAsync(alltoallvDeviceSyncer, &syncer, sizeof(mscclpp::DeviceSyncer), 0,
cudaMemcpyHostToDevice, stream);
// Use simple kernel for small world sizes, multi-block for larger
if (worldSize <= 16) {
// One warp per peer, clamped to [32, 1024] threads.
int nThreads = (worldSize - 1) * WARP_SIZE;
if (nThreads < 32) nThreads = 32;
if (nThreads > 1024) nThreads = 1024;
alltoallv_simple_kernel<<<1, nThreads, 0, stream>>>(
algoCtx->portChannelDeviceHandles.get(),
rank, worldSize,
input, output,
d_sendCounts, d_sendDispls,
d_recvCounts, d_recvDispls);
} else {
alltoallv_kernel<<<1, 1024, 0, stream>>>(
algoCtx->portChannelDeviceHandles.get(),
rank, worldSize,
input, output,
d_sendCounts, d_sendDispls,
d_recvCounts, d_recvDispls);
}
if (cudaGetLastError() == cudaSuccess) {
return mscclpp::CommResult::CommSuccess;
}
return mscclpp::CommResult::CommInternalError;
}
// Register the input/output buffers, exchange the output registration with
// every peer (tag 0), and build one PortChannel per peer. The channel pairs
// our input (send) memory with the peer's output (recv) memory.
std::shared_ptr<void> initAlltoallvContext(std::shared_ptr<mscclpp::Communicator> comm, const void* input,
void* output, size_t inputSize, size_t outputSize,
mscclpp::DataType dtype) {
auto ctx = std::make_shared<AllToAllVContext>();
ctx->rank = comm->bootstrap()->getRank();
ctx->worldSize = comm->bootstrap()->getNranks();
ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode();
// Register memories for input and output buffers
mscclpp::RegisteredMemory inputBufRegMem =
comm->registerMemory((void*)input, inputSize, mscclpp::Transport::CudaIpc);
mscclpp::RegisteredMemory outputBufRegMem =
comm->registerMemory(output, outputSize, mscclpp::Transport::CudaIpc);
// Exchange output buffer registration with all peers
// remoteRegMemories[i] corresponds to conns_[i]: both vectors are built by
// iterating peers in ascending rank order, skipping self.
std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteRegMemories;
for (int i = 0; i < ctx->worldSize; i++) {
if (i == ctx->rank) continue;
comm->sendMemory(outputBufRegMem, i, 0);
remoteRegMemories.push_back(comm->recvMemory(i, 0));
}
// Setup port channels for each peer
std::vector<mscclpp::DeviceHandle<mscclpp::PortChannel>> portChannels;
mscclpp::MemoryId inputMemoryId = this->proxyService_->addMemory(inputBufRegMem);
for (size_t i = 0; i < this->conns_.size(); i++) {
auto remoteMemory = remoteRegMemories[i].get();
mscclpp::MemoryId remoteMemoryId = this->proxyService_->addMemory(remoteMemory);
portChannels.push_back(mscclpp::deviceHandle(this->proxyService_->portChannel(
this->proxyService_->buildAndAddSemaphore(*comm, this->conns_[i]), remoteMemoryId, inputMemoryId)));
}
// Allocate and copy port channels to device
ctx->portChannelDeviceHandles =
mscclpp::detail::gpuCallocShared<mscclpp::DeviceHandle<mscclpp::PortChannel>>(portChannels.size());
mscclpp::gpuMemcpy(ctx->portChannelDeviceHandles.get(), portChannels.data(), portChannels.size(),
cudaMemcpyHostToDevice);
// Keep registered memory references to prevent deallocation
std::transform(remoteRegMemories.begin(), remoteRegMemories.end(), std::back_inserter(ctx->registeredMemories),
[](const auto& fut) { return fut.get(); });
ctx->registeredMemories.push_back(inputBufRegMem);
ctx->registeredMemories.push_back(outputBufRegMem);
return ctx;
}
// Cache key: contexts are reused per (input ptr, output ptr, sizes); dtype
// is deliberately omitted since it does not appear in the key.
mscclpp::AlgorithmCtxKey generateAlltoallvContextKey(const void* input, void* output, size_t inputSize,
size_t outputSize, mscclpp::DataType dtype) {
return {(void*)input, output, inputSize, outputSize, 0};
}
};
// Factory: builds the alltoallv algorithm using the fullmesh builder that now
// lives in src/ext/collectives.
//
// Fix: the previous text contained two definitions of `alltoallvAlgoBuilder`
// (the stale line instantiating the removed in-file AllToAllVAlgoBuilder was
// left next to its replacement), which is a redefinition error. Only the
// AlltoallvFullmesh builder is kept.
std::shared_ptr<mscclpp::Algorithm> createAlltoallvAlgorithm() {
  auto alltoallvAlgoBuilder = std::make_shared<mscclpp::collective::AlltoallvFullmesh>();
  return alltoallvAlgoBuilder->build();
}

View File

@@ -0,0 +1,197 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#include "alltoallv/alltoallv_fullmesh.hpp"
#include "alltoallv/alltoallv_kernel.hpp"
#include <mscclpp/core.hpp>
#include <mscclpp/port_channel.hpp>
#include <mscclpp/port_channel_device.hpp>
#include <mscclpp/gpu_utils.hpp>
#include <algorithm>
namespace mscclpp {
namespace collective {
#if defined(__HIP_PLATFORM_AMD__)
#define ALLTOALLV_WARP_SIZE 64
#else
#define ALLTOALLV_WARP_SIZE 32
#endif
// Context to hold all necessary state for alltoallv execution
// Context to hold all necessary state for alltoallv execution.
// One instance is created per (input, output, sizes) tuple by
// initAlltoallvContext() and cached via generateAlltoallvContextKey().
struct AllToAllVContext {
// This rank's index in the communicator.
int rank;
// Total number of ranks.
int worldSize;
// Ranks per node (informational; not read by the kernels in this file).
int nRanksPerNode;
// Keeps local and remote buffer registrations alive for the context's lifetime.
std::vector<RegisteredMemory> registeredMemories;
// Device-resident array of worldSize-1 PortChannel handles, one per peer.
std::shared_ptr<DeviceHandle<PortChannel>> portChannelDeviceHandles;
};
// Stop the background proxy thread (started by initialize()) before the
// builder's members are destroyed. Safe when initialize() was never called:
// proxyService_ is null in that case.
AlltoallvFullmesh::~AlltoallvFullmesh() {
if (proxyService_) {
proxyService_->stopProxy();
}
}
// Assemble a NativeAlgorithm whose four lambdas dispatch to this builder's
// private methods (initialize / alltoallvKernelFunc / initAlltoallvContext /
// generateAlltoallvContextKey).
// NOTE(review): `self` aliases `this` with a NO-OP deleter, so the lambdas do
// NOT extend this builder's lifetime — the caller must keep the builder alive
// for as long as the returned algorithm is used, or the lambdas dangle.
// Confirm this ownership contract against the AlgorithmBuilder API.
std::shared_ptr<Algorithm> AlltoallvFullmesh::build() {
auto self = std::shared_ptr<AlltoallvFullmesh>(this, [](AlltoallvFullmesh*) {});
std::shared_ptr<Algorithm> alltoallvAlgo = std::make_shared<NativeAlgorithm>(
"alltoallv", "alltoallv_fullmesh",
// Initialize function
[self](std::shared_ptr<Communicator> comm) { self->initialize(comm); },
// Kernel execution function
[self](const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
size_t outputSize, DataType dtype, [[maybe_unused]] ReduceOp op, cudaStream_t stream,
int nBlocks, int nThreadsPerBlock,
const std::unordered_map<std::string, uintptr_t>& extras) {
return self->alltoallvKernelFunc(ctx, input, output, inputSize, outputSize, dtype, stream,
nBlocks, nThreadsPerBlock, extras);
},
// Context initialization function
[self](std::shared_ptr<Communicator> comm, const void* input, void* output, size_t inputSize,
size_t outputSize, DataType dtype) {
return self->initAlltoallvContext(comm, input, output, inputSize, outputSize, dtype);
},
// Context key generation function
[self](const void* input, void* output, size_t inputSize, size_t outputSize, DataType dtype) {
return self->generateAlltoallvContextKey(input, output, inputSize, outputSize, dtype);
});
return alltoallvAlgo;
}
void AlltoallvFullmesh::initialize(std::shared_ptr<Communicator> comm) {
std::vector<std::shared_future<Connection>> connectionFutures;
worldSize_ = comm->bootstrap()->getNranks();
int rank = comm->bootstrap()->getRank();
for (int i = 0; i < worldSize_; i++) {
if (i == rank) continue;
connectionFutures.push_back(comm->connect(Transport::CudaIpc, i));
}
std::vector<Connection> connections;
std::transform(connectionFutures.begin(), connectionFutures.end(), std::back_inserter(connections),
[](const auto& future) { return future.get(); });
this->conns_ = std::move(connections);
proxyService_ = std::make_shared<ProxyService>();
proxyService_->startProxy(true);
}
// Launch the alltoallv kernel on `stream`.
//
// The variable counts and displacements arrive as DEVICE pointers through the
// `extras` map under the keys "sendCounts", "sendDispls", "recvCounts" and
// "recvDispls". Returns CommInternalError when any key is missing or the
// kernel launch fails, CommSuccess otherwise. nBlocks/nThreadsPerBlock are
// ignored: the launch configuration is derived from worldSize.
CommResult AlltoallvFullmesh::alltoallvKernelFunc(
const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
size_t outputSize, [[maybe_unused]] DataType dtype, cudaStream_t stream,
[[maybe_unused]] int nBlocks, [[maybe_unused]] int nThreadsPerBlock,
const std::unordered_map<std::string, uintptr_t>& extras) {
  auto algoCtx = std::static_pointer_cast<AllToAllVContext>(ctx);
  const int rank = algoCtx->rank;
  const int worldSize = algoCtx->worldSize;

  // Pull the four device pointers out of the extras map; record any miss.
  bool missing = false;
  auto lookup = [&extras, &missing](const char* key) -> const size_t* {
    auto it = extras.find(key);
    if (it == extras.end()) {
      missing = true;
      return nullptr;
    }
    return reinterpret_cast<const size_t*>(it->second);
  };
  const size_t* devSendCounts = lookup("sendCounts");
  const size_t* devSendDispls = lookup("sendDispls");
  const size_t* devRecvCounts = lookup("recvCounts");
  const size_t* devRecvDispls = lookup("recvDispls");
  if (missing) {
    return CommResult::CommInternalError;
  }

  if (worldSize <= 16) {
    // Parallel warp-per-peer kernel: one warp per peer, clamped to [32, 1024]
    // threads, single block (the kernel relies on __syncthreads()).
    int nThreads = (worldSize - 1) * ALLTOALLV_WARP_SIZE;
    nThreads = std::max(nThreads, 32);
    nThreads = std::min(nThreads, 1024);
    alltoallvKernel<<<1, nThreads, 0, stream>>>(
        algoCtx->portChannelDeviceHandles.get(), rank, worldSize, input, output,
        devSendCounts, devSendDispls, devRecvCounts, devRecvDispls);
  } else {
    // Serialized ring kernel: a single thread drives all channel operations.
    alltoallvRingKernel<<<1, 32, 0, stream>>>(
        algoCtx->portChannelDeviceHandles.get(), rank, worldSize, input, output,
        devSendCounts, devSendDispls, devRecvCounts, devRecvDispls);
  }

  return cudaGetLastError() == cudaSuccess ? CommResult::CommSuccess
                                           : CommResult::CommInternalError;
}
// Build the per-(buffer, size) context: register the local send/recv buffers,
// exchange the recv-buffer registration with every peer (tag 0), and create
// one PortChannel per peer pairing our send memory with the peer's recv
// memory. The resulting handles are copied to device memory for the kernels.
std::shared_ptr<void> AlltoallvFullmesh::initAlltoallvContext(
std::shared_ptr<Communicator> comm, const void* input, void* output, size_t inputSize,
size_t outputSize, [[maybe_unused]] DataType dtype) {
auto ctx = std::make_shared<AllToAllVContext>();
ctx->rank = comm->bootstrap()->getRank();
ctx->worldSize = comm->bootstrap()->getNranks();
ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode();
// Register memories for input and output buffers
RegisteredMemory inputBufRegMem = comm->registerMemory((void*)input, inputSize, Transport::CudaIpc);
RegisteredMemory outputBufRegMem = comm->registerMemory(output, outputSize, Transport::CudaIpc);
// Exchange output buffer registration with all peers
// remoteRegMemories[i] lines up with conns_[i]: both are built by iterating
// peers in ascending rank order and skipping self (see initialize()).
std::vector<std::shared_future<RegisteredMemory>> remoteRegMemories;
for (int i = 0; i < ctx->worldSize; i++) {
if (i == ctx->rank) continue;
comm->sendMemory(outputBufRegMem, i, 0);
remoteRegMemories.push_back(comm->recvMemory(i, 0));
}
// Setup port channels for each peer
std::vector<DeviceHandle<PortChannel>> portChannels;
MemoryId inputMemoryId = this->proxyService_->addMemory(inputBufRegMem);
for (size_t i = 0; i < this->conns_.size(); i++) {
auto remoteMemory = remoteRegMemories[i].get();
MemoryId remoteMemoryId = this->proxyService_->addMemory(remoteMemory);
// Channel args: (semaphore, remote=peer's recv buffer, local=our send
// buffer) — matches the kernels' putWithSignal(dstOff, srcOff, size)
// usage. NOTE(review): confirm the (dst, src) argument order against the
// ProxyService::portChannel API.
portChannels.push_back(deviceHandle(this->proxyService_->portChannel(
this->proxyService_->buildAndAddSemaphore(*comm, this->conns_[i]), remoteMemoryId, inputMemoryId)));
}
// Allocate and copy port channels to device
ctx->portChannelDeviceHandles = detail::gpuCallocShared<DeviceHandle<PortChannel>>(portChannels.size());
gpuMemcpy(ctx->portChannelDeviceHandles.get(), portChannels.data(), portChannels.size(),
cudaMemcpyHostToDevice);
// Keep registered memory references to prevent deallocation
std::transform(remoteRegMemories.begin(), remoteRegMemories.end(),
std::back_inserter(ctx->registeredMemories),
[](const auto& fut) { return fut.get(); });
ctx->registeredMemories.push_back(inputBufRegMem);
ctx->registeredMemories.push_back(outputBufRegMem);
return ctx;
}
// Contexts are cached per (send buffer, recv buffer, input size, output size);
// dtype is deliberately left out of the key since it does not appear in it.
// The final field is 0 as in the original (tag/op slot — confirm against the
// AlgorithmCtxKey definition).
AlgorithmCtxKey AlltoallvFullmesh::generateAlltoallvContextKey(
const void* input, void* output, size_t inputSize, size_t outputSize,
[[maybe_unused]] DataType dtype) {
  AlgorithmCtxKey key{const_cast<void*>(input), output, inputSize, outputSize, 0};
  return key;
}
#undef ALLTOALLV_WARP_SIZE
} // namespace collective
} // namespace mscclpp

View File

@@ -0,0 +1,55 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once
#include <mscclpp/algorithm.hpp>
#include <mscclpp/core.hpp>
#include <mscclpp/port_channel.hpp>
namespace mscclpp {
namespace collective {
/**
* AllToAllV collective operation builder.
*
* This class builds an AllToAllV algorithm that handles variable element counts
* per rank, similar to MPI_Alltoallv. Unlike NCCL's ncclGroupStart/ncclGroupEnd
* approach, mscclpp uses explicit put/signal/wait operations on PortChannels.
*
* The implementation uses a ring-based exchange pattern to avoid deadlocks.
*
* Usage:
* auto builder = std::make_shared<AlltoallvFullmesh>();
* auto algorithm = builder->build();
* // Then execute with extras containing sendCounts, sendDispls, recvCounts, recvDispls
*/
class AlltoallvFullmesh : public AlgorithmBuilder {
public:
AlltoallvFullmesh() = default;
~AlltoallvFullmesh();
std::shared_ptr<Algorithm> build() override;
private:
void initialize(std::shared_ptr<Communicator> comm);
CommResult alltoallvKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output,
size_t inputSize, size_t outputSize, DataType dtype, cudaStream_t stream,
int nBlocks, int nThreadsPerBlock,
const std::unordered_map<std::string, uintptr_t>& extras);
std::shared_ptr<void> initAlltoallvContext(std::shared_ptr<Communicator> comm, const void* input,
void* output, size_t inputSize, size_t outputSize,
DataType dtype);
AlgorithmCtxKey generateAlltoallvContextKey(const void* input, void* output, size_t inputSize,
size_t outputSize, DataType dtype);
std::vector<Connection> conns_;
std::shared_ptr<ProxyService> proxyService_;
int worldSize_;
};
} // namespace collective
} // namespace mscclpp

View File

@@ -0,0 +1,153 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once
#include <mscclpp/port_channel_device.hpp>
#include <mscclpp/concurrency_device.hpp>
#include <mscclpp/copy_device.hpp>
namespace mscclpp {
namespace collective {
#if defined(__HIP_PLATFORM_AMD__)
#define ALLTOALLV_WARP_SIZE 64
#else
#define ALLTOALLV_WARP_SIZE 32
#endif
/**
* AllToAllV kernel implementation using parallel warp-based communication.
*
* Each warp handles communication with one peer. All sends happen in parallel,
* followed by flushes and waits.
*
* @param portChannels Array of PortChannel handles for each peer (worldSize-1 channels)
* @param rank Current rank
* @param worldSize Total number of ranks
* @param sendBuff Source buffer containing data to send
* @param recvBuff Destination buffer for received data
* @param sendCounts Array of send counts for each rank (in bytes)
* @param sendDispls Array of send displacements for each rank (in bytes)
* @param recvCounts Array of receive counts for each rank (in bytes)
* @param recvDispls Array of receive displacements for each rank (in bytes)
*/
__global__ void __launch_bounds__(1024)
alltoallvKernel(DeviceHandle<PortChannel>* portChannels,
int rank,
int worldSize,
const void* sendBuff,
void* recvBuff,
const size_t* sendCounts,
const size_t* sendDispls,
const size_t* recvCounts,
const size_t* recvDispls) {
// Expects a single-block launch (host side uses <<<1, nThreads>>> with one
// warp per peer); __syncthreads() is the only barrier used.
int tid = threadIdx.x;
int nPeers = worldSize - 1;
// Step 1: Copy local data (rank's own portion)
if (tid == 0 && sendCounts[rank] > 0) {
const char* src = (const char*)sendBuff + sendDispls[rank];
char* dst = (char*)recvBuff + recvDispls[rank];
memcpy(dst, src, sendCounts[rank]);
}
__syncthreads();
// Step 2: Each warp handles one peer for sending
// Only lane 0 of each warp issues channel operations. Channel index warpId
// maps to peer rank warpId (if below rank) or warpId+1, because the channel
// array excludes self.
int warpId = tid / ALLTOALLV_WARP_SIZE;
int laneId = tid % ALLTOALLV_WARP_SIZE;
if (warpId < nPeers && laneId == 0) {
// Determine which peer this warp handles
int peer = warpId < rank ? warpId : warpId + 1;
int chanIdx = warpId;
if (sendCounts[peer] > 0) {
portChannels[chanIdx].putWithSignal(
recvDispls[rank], // dst offset in peer's buffer
sendDispls[peer], // src offset in our buffer
sendCounts[peer] // size
);
}
}
// Barriers sit outside the divergent lane-0 branches so all threads reach them.
__syncthreads();
// Step 3: Flush all pending operations
if (warpId < nPeers && laneId == 0) {
int peer = warpId < rank ? warpId : warpId + 1;
if (sendCounts[peer] > 0) {
portChannels[warpId].flush();
}
}
__syncthreads();
// Step 4: Wait for all incoming data
// recvCounts[peer] mirrors the peer's send count toward us, so we only wait
// on channels that will actually be signaled.
if (warpId < nPeers && laneId == 0) {
int peer = warpId < rank ? warpId : warpId + 1;
if (recvCounts[peer] > 0) {
portChannels[warpId].wait();
}
}
__syncthreads();
}
/**
* Ring-based AllToAllV kernel for serialized communication.
*
* Uses step-by-step ring pattern to exchange data, sending to (rank+step) and
* receiving from (rank-step) in each step. Single thread handles all communication
* to avoid race conditions.
*
* This kernel is more robust but slower than the parallel version.
*/
__global__ void __launch_bounds__(1024)
alltoallvRingKernel(DeviceHandle<PortChannel>* portChannels,
int rank,
int worldSize,
const void* sendBuff,
void* recvBuff,
const size_t* sendCounts,
const size_t* sendDispls,
const size_t* recvCounts,
const size_t* recvDispls) {
// Copy local data first
if (threadIdx.x == 0) {
if (sendCounts[rank] > 0) {
const char* src = (const char*)sendBuff + sendDispls[rank];
char* dst = (char*)recvBuff + recvDispls[rank];
memcpy(dst, src, sendCounts[rank]);
}
}
__syncthreads();
// Ring-based exchange - single thread handles communication
// In step i we send to (rank+i) % worldSize and wait on (rank-i) % worldSize;
// every peer is targeted exactly once, so each channel is signaled/waited at
// most once per kernel call. Channel indices skip self (peer < rank keeps its
// index, peer > rank shifts down by one).
if (threadIdx.x == 0) {
for (int step = 1; step < worldSize; step++) {
int sendPeer = (rank + step) % worldSize;
int recvPeer = (rank - step + worldSize) % worldSize;
int sendChanIdx = sendPeer < rank ? sendPeer : sendPeer - 1;
int recvChanIdx = recvPeer < rank ? recvPeer : recvPeer - 1;
// Send data to sendPeer
if (sendCounts[sendPeer] > 0) {
portChannels[sendChanIdx].putWithSignal(
recvDispls[rank],
sendDispls[sendPeer],
sendCounts[sendPeer]
);
portChannels[sendChanIdx].flush();
}
// Wait for data from recvPeer
// Skipped when recvCounts[recvPeer] == 0 — the peer will not signal in
// that case (mirrors the send-side guard above).
if (recvCounts[recvPeer] > 0) {
portChannels[recvChanIdx].wait();
}
}
}
}
#undef ALLTOALLV_WARP_SIZE
} // namespace collective
} // namespace mscclpp

View File

@@ -4,13 +4,16 @@
FetchContent_Declare(json URL https://github.com/nlohmann/json/releases/download/v3.11.2/json.tar.xz)
FetchContent_MakeAvailable(json)
# Include path for collective algorithm headers (alltoallv, etc.)
set(COLLECTIVES_INC_DIR ${PROJECT_SOURCE_DIR}/src/ext/collectives/include)
# Declares one mscclpp-test executable: compiles `sources` plus common.cc,
# links the common test libraries + MPI + nlohmann_json, and places the binary
# under bin/mscclpp-test.
#
# Fix: the previous text contained two consecutive target_include_directories
# calls for the same target — the stale call (without COLLECTIVES_INC_DIR) was
# left next to its replacement. Only the updated call is kept.
function(add_mscclpp_test_executable name sources)
  if(MSCCLPP_USE_ROCM)
    # ROCm builds compile the .cu sources as plain C++ (hipified via headers).
    set_source_files_properties(${sources} PROPERTIES LANGUAGE CXX)
  endif()
  add_executable(${name} ${sources} common.cc)
  target_link_libraries(${name} ${TEST_LIBS_COMMON} MPI::MPI_CXX nlohmann_json::nlohmann_json)
  # COLLECTIVES_INC_DIR exposes the shared collective kernel headers
  # (alltoallv, etc.) to the tests.
  target_include_directories(${name} ${TEST_INC_COMMON} ${TEST_INC_INTERNAL} PRIVATE ${COLLECTIVES_INC_DIR})
  set_target_properties(${name} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin/mscclpp-test")
endfunction()

View File

@@ -3,6 +3,7 @@
// AllToAllV test - tests variable-length alltoall operations
// This test validates the alltoallv kernel that handles variable element counts per rank.
// Uses the kernel implementations from src/ext/collectives/include/alltoallv/
#include <cstdint>
#include <cstring>
@@ -11,15 +12,11 @@
#include "common.hpp"
#if defined(__HIP_PLATFORM_AMD__)
#define WARP_SIZE 64
#else
#define WARP_SIZE 32
#endif
// Include the alltoallv kernel implementations from src/ext/collectives
#include "alltoallv/alltoallv_kernel.hpp"
template <class T>
using DeviceHandle = mscclpp::DeviceHandle<T>;
__constant__ DeviceHandle<mscclpp::PortChannel> constPortChansV[16];
__device__ mscclpp::DeviceSyncer deviceSyncerV;
static void* localRecvBuffV;
@@ -31,117 +28,8 @@ static size_t* d_sendDispls;
static size_t* d_recvCounts;
static size_t* d_recvDispls;
/**
* AllToAllV kernel implementation
*
* Each rank sends sendCounts[i] bytes to rank i at sendDispls[i] offset,
* and receives recvCounts[i] bytes from rank i at recvDispls[i] offset.
*
* Uses ring-based exchange pattern to avoid deadlocks.
*/
__global__ void __launch_bounds__(1024)
alltoallv0(int rank, int worldSize,
const void* sendBuff, void* recvBuff,
const size_t* sendCounts, const size_t* sendDispls,
const size_t* recvCounts, const size_t* recvDispls) {
// Warp-per-peer variant reading channels from the __constant__ array
// constPortChansV. Expects a single-block launch (see runColl:
// <<<1, nThreads>>>); __syncthreads() is the only barrier used.
int tid = threadIdx.x;
int nPeers = worldSize - 1;
// Step 1: Copy local data (rank's own portion)
if (tid == 0 && sendCounts[rank] > 0) {
const char* src = (const char*)sendBuff + sendDispls[rank];
char* dst = (char*)recvBuff + recvDispls[rank];
memcpy(dst, src, sendCounts[rank]);
}
__syncthreads();
// Step 2: Each warp handles one peer for sending
// Only lane 0 of each warp issues channel operations; channel index warpId
// maps to peer warpId (below rank) or warpId+1, since the array excludes self.
int warpId = tid / WARP_SIZE;
int laneId = tid % WARP_SIZE;
if (warpId < nPeers && laneId == 0) {
// Determine which peer this warp handles
int peer = warpId < rank ? warpId : warpId + 1;
int chanIdx = warpId;
if (sendCounts[peer] > 0) {
constPortChansV[chanIdx].putWithSignal(
recvDispls[rank], // dst offset in peer's buffer
sendDispls[peer], // src offset in our buffer
sendCounts[peer] // size
);
}
}
__syncthreads();
// Step 3: Flush all pending operations
if (warpId < nPeers && laneId == 0) {
int peer = warpId < rank ? warpId : warpId + 1;
if (sendCounts[peer] > 0) {
constPortChansV[warpId].flush();
}
}
__syncthreads();
// Step 4: Wait for all incoming data
// Only waits on channels whose peer actually sent (recvCounts[peer] > 0),
// mirroring the send-side guard.
if (warpId < nPeers && laneId == 0) {
int peer = warpId < rank ? warpId : warpId + 1;
if (recvCounts[peer] > 0) {
constPortChansV[warpId].wait();
}
}
__syncthreads();
}
/**
* Ring-based AllToAllV kernel for larger world sizes
*
* Uses step-by-step ring pattern to exchange data, sending to (rank+step) and
* receiving from (rank-step) in each step. Single block to avoid concurrent
* access to the same port channels.
*/
__global__ void __launch_bounds__(1024)
alltoallv1(int rank, int worldSize,
const void* sendBuff, void* recvBuff,
const size_t* sendCounts, const size_t* sendDispls,
const size_t* recvCounts, const size_t* recvDispls) {
// Serialized ring variant reading channels from the __constant__ array
// constPortChansV; a single thread drives all channel operations.
// Copy local data first
if (threadIdx.x == 0) {
if (sendCounts[rank] > 0) {
const char* src = (const char*)sendBuff + sendDispls[rank];
char* dst = (char*)recvBuff + recvDispls[rank];
memcpy(dst, src, sendCounts[rank]);
}
}
__syncthreads();
// Ring-based exchange - single thread handles the communication
// to avoid race conditions on port channels
// Each step targets a distinct peer, so each channel is signaled/waited at
// most once per call. Channel indices skip self (peer > rank shifts down).
if (threadIdx.x == 0) {
for (int step = 1; step < worldSize; step++) {
int sendPeer = (rank + step) % worldSize;
int recvPeer = (rank - step + worldSize) % worldSize;
int sendChanIdx = sendPeer < rank ? sendPeer : sendPeer - 1;
int recvChanIdx = recvPeer < rank ? recvPeer : recvPeer - 1;
// Send data to sendPeer (non-blocking put with signal)
if (sendCounts[sendPeer] > 0) {
constPortChansV[sendChanIdx].putWithSignal(
recvDispls[rank], // dst offset in peer's buffer
sendDispls[sendPeer], // src offset in our buffer
sendCounts[sendPeer] // size
);
constPortChansV[sendChanIdx].flush();
}
// Wait for data from recvPeer
// Skipped when the peer has nothing for us — it will not signal then.
if (recvCounts[recvPeer] > 0) {
constPortChansV[recvChanIdx].wait();
}
}
}
}
// Device array for port channels (used by library kernels)
static DeviceHandle<mscclpp::PortChannel>* d_portChannels;
class AllToAllVTestColl : public BaseTestColl {
public:
@@ -174,17 +62,23 @@ void AllToAllVTestColl::runColl(const TestArgs& args, cudaStream_t stream) {
CUDATHROW(cudaMemcpyToSymbol(deviceSyncerV, &syncer, sizeof(mscclpp::DeviceSyncer)));
if (kernelNum == 0) {
int nThreads = (worldSize - 1) * WARP_SIZE;
// Use parallel warp-based kernel from library
int nThreads = (worldSize - 1) * 32; // One warp per peer
#if defined(__HIP_PLATFORM_AMD__)
nThreads = (worldSize - 1) * 64;
#endif
if (nThreads < 32) nThreads = 32;
if (nThreads > 1024) nThreads = 1024;
alltoallv0<<<1, nThreads, 0, stream>>>(
mscclpp::collective::alltoallvKernel<<<1, nThreads, 0, stream>>>(
d_portChannels,
rank, worldSize,
localSendBuffV, localRecvBuffV,
d_sendCounts, d_sendDispls,
d_recvCounts, d_recvDispls);
} else if (kernelNum == 1) {
// Single block, single thread for ring-based serialized communication
alltoallv1<<<1, 32, 0, stream>>>(
// Use ring-based kernel from library
mscclpp::collective::alltoallvRingKernel<<<1, 32, 0, stream>>>(
d_portChannels,
rank, worldSize,
localSendBuffV, localRecvBuffV,
d_sendCounts, d_sendDispls,
@@ -275,8 +169,8 @@ void AllToAllVTestColl::setupCollTest(size_t size) {
// Kernel 0: parallel warp-per-peer kernel; kernel 1: serialized ring kernel
// (both from src/ext/collectives/include/alltoallv/alltoallv_kernel.hpp).
//
// Fix: the previous text contained both the old entries ("alltoallv0",
// "alltoallv1") and their replacements in the same initializer list —
// duplicate kernel ids 0/1 plus a missing comma between the stale and new
// entries. Only the updated entries are kept.
std::vector<KernelRestriction> AllToAllVTestColl::getKernelRestrictions() {
  return {
      {0, "alltoallvKernel", true, 1, 4 * worldSize_},
      {1, "alltoallvRingKernel", true, 1, 4 * worldSize_},
  };
}
@@ -318,17 +212,19 @@ void AllToAllVTestEngine::allocateBuffer() {
CUDATHROW(cudaMalloc(&d_sendDispls, args_.totalRanks * sizeof(size_t)));
CUDATHROW(cudaMalloc(&d_recvCounts, args_.totalRanks * sizeof(size_t)));
CUDATHROW(cudaMalloc(&d_recvDispls, args_.totalRanks * sizeof(size_t)));
// Allocate device array for port channels
CUDATHROW(cudaMalloc(&d_portChannels, args_.totalRanks * sizeof(DeviceHandle<mscclpp::PortChannel>)));
}
// Build port channels to every peer over the mesh connections and publish the
// handles to device memory (d_portChannels) for the library kernels.
//
// Fix: the previous text contained both the removed __constant__-memory path
// (constPortChansV size check + cudaMemcpyToSymbol) and its replacement
// cudaMemcpy to d_portChannels. Only the device-memory path is kept, matching
// the d_portChannels allocation in allocateBuffer().
void AllToAllVTestEngine::setupConnections() {
  std::vector<DeviceHandle<mscclpp::PortChannel>> portChannels;
  setupMeshConnections(portChannels, sendBuff_.get(), args_.maxBytes, recvBuff_.get(), args_.maxBytes);
  // d_portChannels holds args_.totalRanks entries (see allocateBuffer());
  // presumably a full mesh yields at most totalRanks-1 channels, so this fits
  // — confirm against setupMeshConnections.
  CUDATHROW(cudaMemcpy(d_portChannels, portChannels.data(),
                       sizeof(DeviceHandle<mscclpp::PortChannel>) * portChannels.size(),
                       cudaMemcpyHostToDevice));
}
// The test harness queries send buffers as a vector; this engine uses exactly one.
std::vector<void*> AllToAllVTestEngine::getSendBuff() { return {sendBuff_.get()}; }