From c42579e900aa7ab51b2a588c421f4c8a9e2077ee Mon Sep 17 00:00:00 2001 From: Qinghua Zhou Date: Fri, 6 Feb 2026 02:57:34 +0000 Subject: [PATCH] Move the alltoallv kernel to the src directory; Utilize the kernel in mscclpp-test --- .../torch-integration/alltoallv_kernel.cu | 364 +----------------- .../alltoallv/alltoallv_fullmesh.cu | 197 ++++++++++ .../include/alltoallv/alltoallv_fullmesh.hpp | 55 +++ .../include/alltoallv/alltoallv_kernel.hpp | 153 ++++++++ test/mscclpp-test/CMakeLists.txt | 5 +- test/mscclpp-test/alltoallv_test.cu | 152 ++------ 6 files changed, 440 insertions(+), 486 deletions(-) create mode 100644 src/ext/collectives/alltoallv/alltoallv_fullmesh.cu create mode 100644 src/ext/collectives/include/alltoallv/alltoallv_fullmesh.hpp create mode 100644 src/ext/collectives/include/alltoallv/alltoallv_kernel.hpp diff --git a/examples/torch-integration/alltoallv_kernel.cu b/examples/torch-integration/alltoallv_kernel.cu index 07a518d9..ea017243 100644 --- a/examples/torch-integration/alltoallv_kernel.cu +++ b/examples/torch-integration/alltoallv_kernel.cu @@ -1,372 +1,22 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -// AllToAllV implementation for MSCCLPP -// This kernel handles variable element counts per rank for alltoallv operations. -// Unlike NCCL's ncclGroupStart/ncclGroupEnd approach, mscclpp uses explicit -// put/signal/wait operations on PortChannels. +// AllToAllV Python bindings for MSCCLPP +// This file provides Python bindings for the alltoallv algorithm. 
+// The actual implementation is in src/ext/collectives/alltoallv/ #include #include #include -#include -#include -#include -#include + +// Include the implementation header +#include "alltoallv/alltoallv_fullmesh.hpp" namespace py = pybind11; -#if defined(__HIP_PLATFORM_AMD__) -#define WARP_SIZE 64 -#else -#define WARP_SIZE 32 -#endif - -// Device syncer for synchronization across blocks -__device__ mscclpp::DeviceSyncer alltoallvDeviceSyncer; - -/** - * AllToAllV kernel implementation - * - * This kernel performs an all-to-all exchange with variable-length data per rank. - * Each rank sends sendCounts[i] elements to rank i at sendDispls[i] offset, - * and receives recvCounts[i] elements from rank i at recvDispls[i] offset. - * - * Since mscclpp doesn't support ncclGroupStart/ncclGroupEnd, we implement - * the exchange using explicit put/signal/wait operations on PortChannels. - * The communication pattern uses a ring-based approach to avoid deadlocks. - * - * @param portChannels Array of PortChannel handles for each peer (worldSize-1 channels) - * @param rank Current rank - * @param worldSize Total number of ranks - * @param sendBuff Source buffer containing data to send - * @param recvBuff Destination buffer for received data - * @param sendCounts Array of send counts for each rank (in bytes) - * @param sendDispls Array of send displacements for each rank (in bytes) - * @param recvCounts Array of receive counts for each rank (in bytes) - * @param recvDispls Array of receive displacements for each rank (in bytes) - */ -__global__ void __launch_bounds__(1024) - alltoallv_kernel(mscclpp::DeviceHandle* portChannels, - int rank, - int worldSize, - const void* sendBuff, - void* recvBuff, - const size_t* sendCounts, - const size_t* sendDispls, - const size_t* recvCounts, - const size_t* recvDispls) { - // First, copy local data (rank's own portion) from send to recv buffer - // This doesn't require any communication - if (threadIdx.x == 0 && blockIdx.x == 0) { - if 
(sendCounts[rank] > 0) { - // Local copy: sendBuff[sendDispls[rank]] -> recvBuff[recvDispls[rank]] - const char* src = (const char*)sendBuff + sendDispls[rank]; - char* dst = (char*)recvBuff + recvDispls[rank]; - memcpy(dst, src, sendCounts[rank]); - } - } - __syncthreads(); - - // Ring-based exchange pattern to avoid deadlocks - // In each step i, rank sends to (rank + i) % worldSize and receives from (rank - i + worldSize) % worldSize - for (int step = 1; step < worldSize; step++) { - int sendPeer = (rank + step) % worldSize; - int recvPeer = (rank - step + worldSize) % worldSize; - - // Get channel indices (portChannels excludes self, so adjust index) - int sendChanIdx = sendPeer < rank ? sendPeer : sendPeer - 1; - int recvChanIdx = recvPeer < rank ? recvPeer : recvPeer - 1; - - // Each warp handles one peer - int wid = threadIdx.x / WARP_SIZE; - int lid = threadIdx.x % WARP_SIZE; - - // Send data to sendPeer if there's data to send - if (wid == 0 && lid == 0) { - if (sendCounts[sendPeer] > 0) { - // putWithSignal: copy data and signal completion - // src offset: sendDispls[sendPeer] in our sendBuff - // dst offset: recvDispls[rank] in peer's recvBuff (where our data should go) - portChannels[sendChanIdx].putWithSignal( - recvDispls[rank], // dst offset in peer's recv buffer (where we write) - sendDispls[sendPeer], // src offset in our send buffer - sendCounts[sendPeer] // size in bytes - ); - } - } - - // Sync all threads before flushing - alltoallvDeviceSyncer.sync(gridDim.x); - - // Flush to ensure data is sent - if (wid == 0 && lid == 0) { - if (sendCounts[sendPeer] > 0) { - portChannels[sendChanIdx].flush(); - } - } - - // Wait for data from recvPeer if we're expecting data - if (wid == 0 && lid == 0) { - if (recvCounts[recvPeer] > 0) { - portChannels[recvChanIdx].wait(); - } - } - - // Sync all threads before next step - alltoallvDeviceSyncer.sync(gridDim.x); - } -} - -/** - * Simplified AllToAllV kernel for single-block execution - * - * This version is 
optimized for cases where all communication can be - * handled within a single thread block. - */ -__global__ void __launch_bounds__(1024) - alltoallv_simple_kernel(mscclpp::DeviceHandle* portChannels, - int rank, - int worldSize, - const void* sendBuff, - void* recvBuff, - const size_t* sendCounts, - const size_t* sendDispls, - const size_t* recvCounts, - const size_t* recvDispls) { - int tid = threadIdx.x; - int nPeers = worldSize - 1; - - // Step 1: Copy local data - if (tid == 0 && sendCounts[rank] > 0) { - const char* src = (const char*)sendBuff + sendDispls[rank]; - char* dst = (char*)recvBuff + recvDispls[rank]; - memcpy(dst, src, sendCounts[rank]); - } - __syncthreads(); - - // Step 2: Each warp handles one peer for sending - // We have worldSize-1 peers, assign one warp per peer - int warpId = tid / WARP_SIZE; - int laneId = tid % WARP_SIZE; - - if (warpId < nPeers && laneId == 0) { - // Determine which peer this warp handles - int peer = warpId < rank ? warpId : warpId + 1; - int chanIdx = warpId; - - if (sendCounts[peer] > 0) { - portChannels[chanIdx].putWithSignal( - recvDispls[rank], // dst offset in peer's buffer - sendDispls[peer], // src offset in our buffer - sendCounts[peer] // size - ); - } - } - __syncthreads(); - - // Step 3: Flush all pending operations - if (warpId < nPeers && laneId == 0) { - int peer = warpId < rank ? warpId : warpId + 1; - if (sendCounts[peer] > 0) { - portChannels[warpId].flush(); - } - } - __syncthreads(); - - // Step 4: Wait for all incoming data - if (warpId < nPeers && laneId == 0) { - int peer = warpId < rank ? 
warpId : warpId + 1; - if (recvCounts[peer] > 0) { - portChannels[warpId].wait(); - } - } - __syncthreads(); -} - -// Context to hold all necessary state for alltoallv execution -struct AllToAllVContext { - int rank; - int worldSize; - int nRanksPerNode; - - std::vector registeredMemories; - std::shared_ptr> portChannelDeviceHandles; - - // Device memory for counts and displacements - size_t* d_sendCounts; - size_t* d_sendDispls; - size_t* d_recvCounts; - size_t* d_recvDispls; -}; - -class AllToAllVAlgoBuilder : public mscclpp::AlgorithmBuilder { - public: - AllToAllVAlgoBuilder() = default; - ~AllToAllVAlgoBuilder() { - if (proxyService_) { - proxyService_->stopProxy(); - } - } - - std::shared_ptr build() override { - auto self = std::make_shared(); - std::shared_ptr alltoallvAlgo = std::make_shared( - "alltoallv", "alltoallv", - // Initialize function - [self](std::shared_ptr comm) { self->initialize(comm); }, - // Kernel execution function - [self](const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, size_t outputSize, - mscclpp::DataType dtype, [[maybe_unused]] mscclpp::ReduceOp op, cudaStream_t stream, int nBlocks, - int nThreadsPerBlock, const std::unordered_map& extras) { - return self->alltoallvKernelFunc(ctx, input, output, inputSize, outputSize, dtype, stream, extras); - }, - // Context initialization function - [self](std::shared_ptr comm, const void* input, void* output, size_t inputSize, - size_t outputSize, - mscclpp::DataType dtype) { return self->initAlltoallvContext(comm, input, output, inputSize, outputSize, dtype); }, - // Context key generation function - [self](const void* input, void* output, size_t inputSize, size_t outputSize, mscclpp::DataType dtype) { - return self->generateAlltoallvContextKey(input, output, inputSize, outputSize, dtype); - }); - return alltoallvAlgo; - } - - private: - std::vector conns_; - std::shared_ptr proxyService_; - int worldSize_; - - void initialize(std::shared_ptr comm) { - std::vector> 
connectionFutures; - worldSize_ = comm->bootstrap()->getNranks(); - for (int i = 0; i < worldSize_; i++) { - if (i == comm->bootstrap()->getRank()) continue; - connectionFutures.push_back(comm->connect(mscclpp::Transport::CudaIpc, i)); - } - std::vector connections; - std::transform(connectionFutures.begin(), connectionFutures.end(), std::back_inserter(connections), - [](const auto& future) { return future.get(); }); - this->conns_ = std::move(connections); - proxyService_ = std::make_shared(); - proxyService_->startProxy(true); - } - - mscclpp::CommResult alltoallvKernelFunc(const std::shared_ptr ctx, const void* input, void* output, - size_t inputSize, size_t outputSize, - [[maybe_unused]] mscclpp::DataType dtype, - cudaStream_t stream, - const std::unordered_map& extras) { - auto algoCtx = std::static_pointer_cast(ctx); - int rank = algoCtx->rank; - int worldSize = algoCtx->worldSize; - - // Extract send/recv counts and displacements from extras - // The caller should pass these as device pointers via extras map - auto it_sendCounts = extras.find("sendCounts"); - auto it_sendDispls = extras.find("sendDispls"); - auto it_recvCounts = extras.find("recvCounts"); - auto it_recvDispls = extras.find("recvDispls"); - - if (it_sendCounts == extras.end() || it_sendDispls == extras.end() || - it_recvCounts == extras.end() || it_recvDispls == extras.end()) { - return mscclpp::CommResult::CommInternalError; - } - - const size_t* d_sendCounts = reinterpret_cast(it_sendCounts->second); - const size_t* d_sendDispls = reinterpret_cast(it_sendDispls->second); - const size_t* d_recvCounts = reinterpret_cast(it_recvCounts->second); - const size_t* d_recvDispls = reinterpret_cast(it_recvDispls->second); - - // Reset device syncer - mscclpp::DeviceSyncer syncer = {}; - cudaMemcpyToSymbolAsync(alltoallvDeviceSyncer, &syncer, sizeof(mscclpp::DeviceSyncer), 0, - cudaMemcpyHostToDevice, stream); - - // Use simple kernel for small world sizes, multi-block for larger - if (worldSize <= 
16) { - int nThreads = (worldSize - 1) * WARP_SIZE; - if (nThreads < 32) nThreads = 32; - if (nThreads > 1024) nThreads = 1024; - - alltoallv_simple_kernel<<<1, nThreads, 0, stream>>>( - algoCtx->portChannelDeviceHandles.get(), - rank, worldSize, - input, output, - d_sendCounts, d_sendDispls, - d_recvCounts, d_recvDispls); - } else { - alltoallv_kernel<<<1, 1024, 0, stream>>>( - algoCtx->portChannelDeviceHandles.get(), - rank, worldSize, - input, output, - d_sendCounts, d_sendDispls, - d_recvCounts, d_recvDispls); - } - - if (cudaGetLastError() == cudaSuccess) { - return mscclpp::CommResult::CommSuccess; - } - return mscclpp::CommResult::CommInternalError; - } - - std::shared_ptr initAlltoallvContext(std::shared_ptr comm, const void* input, - void* output, size_t inputSize, size_t outputSize, - mscclpp::DataType dtype) { - auto ctx = std::make_shared(); - ctx->rank = comm->bootstrap()->getRank(); - ctx->worldSize = comm->bootstrap()->getNranks(); - ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode(); - - // Register memories for input and output buffers - mscclpp::RegisteredMemory inputBufRegMem = - comm->registerMemory((void*)input, inputSize, mscclpp::Transport::CudaIpc); - mscclpp::RegisteredMemory outputBufRegMem = - comm->registerMemory(output, outputSize, mscclpp::Transport::CudaIpc); - - // Exchange output buffer registration with all peers - std::vector> remoteRegMemories; - for (int i = 0; i < ctx->worldSize; i++) { - if (i == ctx->rank) continue; - comm->sendMemory(outputBufRegMem, i, 0); - remoteRegMemories.push_back(comm->recvMemory(i, 0)); - } - - // Setup port channels for each peer - std::vector> portChannels; - mscclpp::MemoryId inputMemoryId = this->proxyService_->addMemory(inputBufRegMem); - - for (size_t i = 0; i < this->conns_.size(); i++) { - auto remoteMemory = remoteRegMemories[i].get(); - mscclpp::MemoryId remoteMemoryId = this->proxyService_->addMemory(remoteMemory); - 
portChannels.push_back(mscclpp::deviceHandle(this->proxyService_->portChannel( - this->proxyService_->buildAndAddSemaphore(*comm, this->conns_[i]), remoteMemoryId, inputMemoryId))); - } - - // Allocate and copy port channels to device - ctx->portChannelDeviceHandles = - mscclpp::detail::gpuCallocShared>(portChannels.size()); - mscclpp::gpuMemcpy(ctx->portChannelDeviceHandles.get(), portChannels.data(), portChannels.size(), - cudaMemcpyHostToDevice); - - // Keep registered memory references to prevent deallocation - std::transform(remoteRegMemories.begin(), remoteRegMemories.end(), std::back_inserter(ctx->registeredMemories), - [](const auto& fut) { return fut.get(); }); - ctx->registeredMemories.push_back(inputBufRegMem); - ctx->registeredMemories.push_back(outputBufRegMem); - - return ctx; - } - - mscclpp::AlgorithmCtxKey generateAlltoallvContextKey(const void* input, void* output, size_t inputSize, - size_t outputSize, mscclpp::DataType dtype) { - return {(void*)input, output, inputSize, outputSize, 0}; - } -}; - std::shared_ptr createAlltoallvAlgorithm() { - auto alltoallvAlgoBuilder = std::make_shared(); + auto alltoallvAlgoBuilder = std::make_shared(); return alltoallvAlgoBuilder->build(); } diff --git a/src/ext/collectives/alltoallv/alltoallv_fullmesh.cu b/src/ext/collectives/alltoallv/alltoallv_fullmesh.cu new file mode 100644 index 00000000..5dc16053 --- /dev/null +++ b/src/ext/collectives/alltoallv/alltoallv_fullmesh.cu @@ -0,0 +1,197 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+
+#include "alltoallv/alltoallv_fullmesh.hpp"
+#include "alltoallv/alltoallv_kernel.hpp"
+
+#include
+#include
+#include
+#include
+
+#include
+
+namespace mscclpp {
+namespace collective {
+
+#if defined(__HIP_PLATFORM_AMD__)
+#define ALLTOALLV_WARP_SIZE 64
+#else
+#define ALLTOALLV_WARP_SIZE 32
+#endif
+
+// Per-context state for alltoallv execution: rank info, registered memories kept alive, device channel handles
+struct AllToAllVContext {
+  int rank;
+  int worldSize;
+  int nRanksPerNode;
+
+  std::vector registeredMemories;
+  std::shared_ptr> portChannelDeviceHandles;
+};
+
+AlltoallvFullmesh::~AlltoallvFullmesh() {
+  if (proxyService_) {
+    proxyService_->stopProxy();
+  }
+}
+
+std::shared_ptr AlltoallvFullmesh::build() {
+  auto self = std::make_shared<AlltoallvFullmesh>();  // owning: callbacks keep their builder state alive themselves
+
+  std::shared_ptr alltoallvAlgo = std::make_shared(
+      "alltoallv", "alltoallv_fullmesh",
+      // Initialize function
+      [self](std::shared_ptr comm) { self->initialize(comm); },
+      // Kernel execution function
+      [self](const std::shared_ptr ctx, const void* input, void* output, size_t inputSize,
+             size_t outputSize, DataType dtype, [[maybe_unused]] ReduceOp op, cudaStream_t stream,
+             int nBlocks, int nThreadsPerBlock,
+             const std::unordered_map& extras) {
+        return self->alltoallvKernelFunc(ctx, input, output, inputSize, outputSize, dtype, stream,
+                                         nBlocks, nThreadsPerBlock, extras);
+      },
+      // Context initialization function
+      [self](std::shared_ptr comm, const void* input, void* output, size_t inputSize,
+             size_t outputSize, DataType dtype) {
+        return self->initAlltoallvContext(comm, input, output, inputSize, outputSize, dtype);
+      },
+      // Context key generation function
+      [self](const void* input, void* output, size_t inputSize, size_t outputSize, DataType dtype) {
+        return self->generateAlltoallvContextKey(input, output, inputSize, outputSize, dtype);
+      });
+
+  return alltoallvAlgo;
+}
+
+void AlltoallvFullmesh::initialize(std::shared_ptr comm) {
+  std::vector> connectionFutures;
+  worldSize_ = comm->bootstrap()->getNranks();
+  int rank = comm->bootstrap()->getRank();
+
+  for (int i = 0; i < worldSize_; i++) {
+    if (i == rank) continue;
+    connectionFutures.push_back(comm->connect(Transport::CudaIpc, i));
+  }
+
+  std::vector connections;
+  std::transform(connectionFutures.begin(), connectionFutures.end(), std::back_inserter(connections),
+                 [](const auto& future) { return future.get(); });
+  this->conns_ = std::move(connections);
+
+  proxyService_ = std::make_shared();
+  proxyService_->startProxy(true);
+}
+
+CommResult AlltoallvFullmesh::alltoallvKernelFunc(
+    const std::shared_ptr ctx, const void* input, void* output, size_t inputSize,
+    size_t outputSize, [[maybe_unused]] DataType dtype, cudaStream_t stream,
+    [[maybe_unused]] int nBlocks, [[maybe_unused]] int nThreadsPerBlock,
+    const std::unordered_map& extras) {
+
+  auto algoCtx = std::static_pointer_cast(ctx);
+  int rank = algoCtx->rank;
+  int worldSize = algoCtx->worldSize;
+
+  // Extract send/recv counts and displacements from extras (values are reinterpreted as pointers below)
+  auto it_sendCounts = extras.find("sendCounts");
+  auto it_sendDispls = extras.find("sendDispls");
+  auto it_recvCounts = extras.find("recvCounts");
+  auto it_recvDispls = extras.find("recvDispls");
+
+  if (it_sendCounts == extras.end() || it_sendDispls == extras.end() ||
+      it_recvCounts == extras.end() || it_recvDispls == extras.end()) {
+    return CommResult::CommInternalError;
+  }
+
+  const size_t* d_sendCounts = reinterpret_cast(it_sendCounts->second);
+  const size_t* d_sendDispls = reinterpret_cast(it_sendDispls->second);
+  const size_t* d_recvCounts = reinterpret_cast(it_recvCounts->second);
+  const size_t* d_recvDispls = reinterpret_cast(it_recvDispls->second);
+
+  // Choose kernel based on world size
+  if (worldSize <= 16) {
+    // Use parallel warp-based kernel for small world sizes
+    int nThreads = (worldSize - 1) * ALLTOALLV_WARP_SIZE;
+    if (nThreads < 32) nThreads = 32;
+    if (nThreads > 1024) nThreads = 1024;
+
+    alltoallvKernel<<<1, nThreads, 0, stream>>>(
+        algoCtx->portChannelDeviceHandles.get(),
+        rank, worldSize,
+        input, output,
+        d_sendCounts, d_sendDispls,
+        d_recvCounts, d_recvDispls);
+  } else {
+    // Use ring-based kernel for larger world sizes
+    alltoallvRingKernel<<<1, 32, 0, stream>>>(
+        algoCtx->portChannelDeviceHandles.get(),
+        rank, worldSize,
+        input, output,
+        d_sendCounts, d_sendDispls,
+        d_recvCounts, d_recvDispls);
+  }
+
+  if (cudaGetLastError() == cudaSuccess) {
+    return CommResult::CommSuccess;
+  }
+  return CommResult::CommInternalError;
+}
+
+std::shared_ptr AlltoallvFullmesh::initAlltoallvContext(
+    std::shared_ptr comm, const void* input, void* output, size_t inputSize,
+    size_t outputSize, [[maybe_unused]] DataType dtype) {
+
+  auto ctx = std::make_shared();
+  ctx->rank = comm->bootstrap()->getRank();
+  ctx->worldSize = comm->bootstrap()->getNranks();
+  ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode();
+
+  // Register memories for input and output buffers
+  RegisteredMemory inputBufRegMem = comm->registerMemory((void*)input, inputSize, Transport::CudaIpc);
+  RegisteredMemory outputBufRegMem = comm->registerMemory(output, outputSize, Transport::CudaIpc);
+
+  // Exchange output buffer registration with all peers
+  std::vector> remoteRegMemories;
+  for (int i = 0; i < ctx->worldSize; i++) {
+    if (i == ctx->rank) continue;
+    comm->sendMemory(outputBufRegMem, i, 0);
+    remoteRegMemories.push_back(comm->recvMemory(i, 0));
+  }
+
+  // Setup port channels for each peer
+  std::vector> portChannels;
+  MemoryId inputMemoryId = this->proxyService_->addMemory(inputBufRegMem);
+
+  for (size_t i = 0; i < this->conns_.size(); i++) {
+    auto remoteMemory = remoteRegMemories[i].get();
+    MemoryId remoteMemoryId = this->proxyService_->addMemory(remoteMemory);
+    portChannels.push_back(deviceHandle(this->proxyService_->portChannel(
+        this->proxyService_->buildAndAddSemaphore(*comm, this->conns_[i]), remoteMemoryId, inputMemoryId)));
+  }
+
+  // Allocate and copy port channels to device
+  ctx->portChannelDeviceHandles = detail::gpuCallocShared>(portChannels.size());
+  gpuMemcpy(ctx->portChannelDeviceHandles.get(), portChannels.data(), portChannels.size(),
+            cudaMemcpyHostToDevice);
+
+  // Keep registered memory references to prevent deallocation
+  std::transform(remoteRegMemories.begin(), remoteRegMemories.end(),
+                 std::back_inserter(ctx->registeredMemories),
+                 [](const auto& fut) { return fut.get(); });
+  ctx->registeredMemories.push_back(inputBufRegMem);
+  ctx->registeredMemories.push_back(outputBufRegMem);
+
+  return ctx;
+}
+
+AlgorithmCtxKey AlltoallvFullmesh::generateAlltoallvContextKey(
+    const void* input, void* output, size_t inputSize, size_t outputSize,
+    [[maybe_unused]] DataType dtype) {
+  return {(void*)input, output, inputSize, outputSize, 0};
+}
+
+#undef ALLTOALLV_WARP_SIZE
+
+} // namespace collective
+} // namespace mscclpp
diff --git a/src/ext/collectives/include/alltoallv/alltoallv_fullmesh.hpp b/src/ext/collectives/include/alltoallv/alltoallv_fullmesh.hpp
new file mode 100644
index 00000000..99c09200
--- /dev/null
+++ b/src/ext/collectives/include/alltoallv/alltoallv_fullmesh.hpp
@@ -0,0 +1,55 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include
+#include
+#include
+
+namespace mscclpp {
+namespace collective {
+
+/**
+ * AllToAllV collective operation builder.
+ *
+ * This class builds an AllToAllV algorithm that handles variable element counts
+ * per rank, similar to MPI_Alltoallv. Unlike NCCL's ncclGroupStart/ncclGroupEnd
+ * approach, mscclpp uses explicit put/signal/wait operations on PortChannels.
+ *
+ * Uses a warp-per-peer parallel kernel for worldSize <= 16 and a single-thread ring-ordered kernel otherwise.
+ *
+ * Usage:
+ *   auto builder = std::make_shared<mscclpp::collective::AlltoallvFullmesh>();
+ *   auto algorithm = builder->build();
+ *   // Then execute with extras containing sendCounts, sendDispls, recvCounts, recvDispls
+ */
+class AlltoallvFullmesh : public AlgorithmBuilder {
+ public:
+  AlltoallvFullmesh() = default;
+  ~AlltoallvFullmesh();  // stops the proxy service if one was created
+
+  std::shared_ptr build() override;
+
+ private:
+  void initialize(std::shared_ptr comm);  // connects to every other rank and starts the proxy service
+
+  CommResult alltoallvKernelFunc(const std::shared_ptr ctx, const void* input, void* output,
+                                 size_t inputSize, size_t outputSize, DataType dtype, cudaStream_t stream,
+                                 int nBlocks, int nThreadsPerBlock,
+                                 const std::unordered_map& extras);
+
+  std::shared_ptr initAlltoallvContext(std::shared_ptr comm, const void* input,
+                                       void* output, size_t inputSize, size_t outputSize,
+                                       DataType dtype);
+
+  AlgorithmCtxKey generateAlltoallvContextKey(const void* input, void* output, size_t inputSize,
+                                              size_t outputSize, DataType dtype);
+
+  std::vector conns_;  // one connection per peer, in ascending rank order with self skipped
+  std::shared_ptr proxyService_;
+  int worldSize_;  // number of ranks, read from the bootstrap in initialize()
+};
+
+} // namespace collective
+} // namespace mscclpp
diff --git a/src/ext/collectives/include/alltoallv/alltoallv_kernel.hpp b/src/ext/collectives/include/alltoallv/alltoallv_kernel.hpp
new file mode 100644
index 00000000..e1ec948c
--- /dev/null
+++ b/src/ext/collectives/include/alltoallv/alltoallv_kernel.hpp
@@ -0,0 +1,153 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include
+#include
+#include
+
+namespace mscclpp {
+namespace collective {
+
+#if defined(__HIP_PLATFORM_AMD__)
+#define ALLTOALLV_WARP_SIZE 64
+#else
+#define ALLTOALLV_WARP_SIZE 32
+#endif
+
+/**
+ * AllToAllV kernel implementation using parallel warp-based communication.
+ *
+ * Each warp handles communication with one peer. All sends happen in parallel,
+ * followed by flushes and waits.
+ * Indexing uses threadIdx.x only; lane 0 of warp w drives portChannels[w], warps >= worldSize - 1 stay idle.
+ * @param portChannels Array of PortChannel handles for each peer (worldSize-1 channels)
+ * @param rank Current rank
+ * @param worldSize Total number of ranks
+ * @param sendBuff Source buffer containing data to send
+ * @param recvBuff Destination buffer for received data
+ * @param sendCounts Array of send counts for each rank (in bytes)
+ * @param sendDispls Array of send displacements for each rank (in bytes)
+ * @param recvCounts Array of receive counts for each rank (in bytes)
+ * @param recvDispls Array of receive displacements for each rank (in bytes)
+ */
+__global__ void __launch_bounds__(1024)
+    alltoallvKernel(DeviceHandle* portChannels,
+                    int rank,
+                    int worldSize,
+                    const void* sendBuff,
+                    void* recvBuff,
+                    const size_t* sendCounts,
+                    const size_t* sendDispls,
+                    const size_t* recvCounts,
+                    const size_t* recvDispls) {
+  int tid = threadIdx.x;
+  int nPeers = worldSize - 1;
+
+  // Step 1: Thread 0 copies this rank's own slice from sendBuff to recvBuff (no channel involved)
+  if (tid == 0 && sendCounts[rank] > 0) {
+    const char* src = (const char*)sendBuff + sendDispls[rank];
+    char* dst = (char*)recvBuff + recvDispls[rank];
+    memcpy(dst, src, sendCounts[rank]);
+  }
+  __syncthreads();
+
+  // Step 2: Each warp handles one peer for sending
+  int warpId = tid / ALLTOALLV_WARP_SIZE;
+  int laneId = tid % ALLTOALLV_WARP_SIZE;
+
+  if (warpId < nPeers && laneId == 0) {
+    // Channel index skips self: indices below rank map to the same peer id, the rest to index + 1
+    int peer = warpId < rank ? warpId : warpId + 1;
+    int chanIdx = warpId;
+
+    if (sendCounts[peer] > 0) {
+      portChannels[chanIdx].putWithSignal(
+          recvDispls[rank],  // dst offset in peer's buffer; NOTE(review): read from OUR recvDispls - confirm it matches the peer's
+          sendDispls[peer],  // src offset in our buffer
+          sendCounts[peer]   // size
+      );
+    }
+  }
+  __syncthreads();
+
+  // Step 3: Flush the channels that were used for a put in step 2
+  if (warpId < nPeers && laneId == 0) {
+    int peer = warpId < rank ? warpId : warpId + 1;
+    if (sendCounts[peer] > 0) {
+      portChannels[warpId].flush();
+    }
+  }
+  __syncthreads();
+
+  // Step 4: Wait on each channel whose peer has a non-zero recvCounts entry
+  if (warpId < nPeers && laneId == 0) {
+    int peer = warpId < rank ? warpId : warpId + 1;
+    if (recvCounts[peer] > 0) {
+      portChannels[warpId].wait();
+    }
+  }
+  __syncthreads();
+}
+
+/**
+ * Ring-based AllToAllV kernel for serialized communication.
+ *
+ * Uses step-by-step ring pattern to exchange data, sending to (rank+step) and
+ * receiving from (rank-step) in each step. Single thread handles all communication
+ * to avoid race conditions.
+ *
+ * This kernel is more robust but slower than the parallel version.
+ */
+__global__ void __launch_bounds__(1024)
+    alltoallvRingKernel(DeviceHandle* portChannels,
+                        int rank,
+                        int worldSize,
+                        const void* sendBuff,
+                        void* recvBuff,
+                        const size_t* sendCounts,
+                        const size_t* sendDispls,
+                        const size_t* recvCounts,
+                        const size_t* recvDispls) {
+  // Copy this rank's own slice first (thread 0 only)
+  if (threadIdx.x == 0) {
+    if (sendCounts[rank] > 0) {
+      const char* src = (const char*)sendBuff + sendDispls[rank];
+      char* dst = (char*)recvBuff + recvDispls[rank];
+      memcpy(dst, src, sendCounts[rank]);
+    }
+  }
+  __syncthreads();
+
+  // Ring-based exchange - thread 0 walks steps 1..worldSize-1 in order
+  if (threadIdx.x == 0) {
+    for (int step = 1; step < worldSize; step++) {
+      int sendPeer = (rank + step) % worldSize;
+      int recvPeer = (rank - step + worldSize) % worldSize;
+
+      int sendChanIdx = sendPeer < rank ? sendPeer : sendPeer - 1;
+      int recvChanIdx = recvPeer < rank ? recvPeer : recvPeer - 1;
+
+      // Send data to sendPeer, then flush that channel before continuing
+      if (sendCounts[sendPeer] > 0) {
+        portChannels[sendChanIdx].putWithSignal(
+            recvDispls[rank],
+            sendDispls[sendPeer],
+            sendCounts[sendPeer]
+        );
+        portChannels[sendChanIdx].flush();
+      }
+
+      // Wait for data from recvPeer (skipped when recvCounts[recvPeer] is zero)
+      if (recvCounts[recvPeer] > 0) {
+        portChannels[recvChanIdx].wait();
+      }
+    }
+  }
+}
+
+#undef ALLTOALLV_WARP_SIZE
+
+} // namespace collective
+} // namespace mscclpp
diff --git a/test/mscclpp-test/CMakeLists.txt b/test/mscclpp-test/CMakeLists.txt index d249b4d7..8b9c63fa 100644 --- a/test/mscclpp-test/CMakeLists.txt +++ b/test/mscclpp-test/CMakeLists.txt @@ -4,13 +4,16 @@ FetchContent_Declare(json URL https://github.com/nlohmann/json/releases/download/v3.11.2/json.tar.xz) FetchContent_MakeAvailable(json) +# Include path for collective algorithm headers (alltoallv, etc.) +set(COLLECTIVES_INC_DIR ${PROJECT_SOURCE_DIR}/src/ext/collectives/include) + function(add_mscclpp_test_executable name sources) if(MSCCLPP_USE_ROCM) set_source_files_properties(${sources} PROPERTIES LANGUAGE CXX) endif() add_executable(${name} ${sources} common.cc) target_link_libraries(${name} ${TEST_LIBS_COMMON} MPI::MPI_CXX nlohmann_json::nlohmann_json) - target_include_directories(${name} ${TEST_INC_COMMON} ${TEST_INC_INTERNAL}) + target_include_directories(${name} ${TEST_INC_COMMON} ${TEST_INC_INTERNAL} PRIVATE ${COLLECTIVES_INC_DIR}) set_target_properties(${name} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin/mscclpp-test") endfunction() diff --git a/test/mscclpp-test/alltoallv_test.cu b/test/mscclpp-test/alltoallv_test.cu index 8467a7a4..f5910bd0 100644 --- a/test/mscclpp-test/alltoallv_test.cu +++ b/test/mscclpp-test/alltoallv_test.cu @@ -3,6 +3,7 @@ // AllToAllV test - tests variable-length alltoall operations // This test validates the alltoallv kernel that handles variable element counts per rank.
+// Uses the kernel implementations from src/ext/collectives/include/alltoallv/ #include #include @@ -11,15 +12,11 @@ #include "common.hpp" -#if defined(__HIP_PLATFORM_AMD__) -#define WARP_SIZE 64 -#else -#define WARP_SIZE 32 -#endif +// Include the alltoallv kernel implementations from src/ext/collectives +#include "alltoallv/alltoallv_kernel.hpp" template using DeviceHandle = mscclpp::DeviceHandle; -__constant__ DeviceHandle constPortChansV[16]; __device__ mscclpp::DeviceSyncer deviceSyncerV; static void* localRecvBuffV; @@ -31,117 +28,8 @@ static size_t* d_sendDispls; static size_t* d_recvCounts; static size_t* d_recvDispls; -/** - * AllToAllV kernel implementation - * - * Each rank sends sendCounts[i] bytes to rank i at sendDispls[i] offset, - * and receives recvCounts[i] bytes from rank i at recvDispls[i] offset. - * - * Uses ring-based exchange pattern to avoid deadlocks. - */ -__global__ void __launch_bounds__(1024) - alltoallv0(int rank, int worldSize, - const void* sendBuff, void* recvBuff, - const size_t* sendCounts, const size_t* sendDispls, - const size_t* recvCounts, const size_t* recvDispls) { - int tid = threadIdx.x; - int nPeers = worldSize - 1; - - // Step 1: Copy local data (rank's own portion) - if (tid == 0 && sendCounts[rank] > 0) { - const char* src = (const char*)sendBuff + sendDispls[rank]; - char* dst = (char*)recvBuff + recvDispls[rank]; - memcpy(dst, src, sendCounts[rank]); - } - __syncthreads(); - - // Step 2: Each warp handles one peer for sending - int warpId = tid / WARP_SIZE; - int laneId = tid % WARP_SIZE; - - if (warpId < nPeers && laneId == 0) { - // Determine which peer this warp handles - int peer = warpId < rank ? 
warpId : warpId + 1; - int chanIdx = warpId; - - if (sendCounts[peer] > 0) { - constPortChansV[chanIdx].putWithSignal( - recvDispls[rank], // dst offset in peer's buffer - sendDispls[peer], // src offset in our buffer - sendCounts[peer] // size - ); - } - } - __syncthreads(); - - // Step 3: Flush all pending operations - if (warpId < nPeers && laneId == 0) { - int peer = warpId < rank ? warpId : warpId + 1; - if (sendCounts[peer] > 0) { - constPortChansV[warpId].flush(); - } - } - __syncthreads(); - - // Step 4: Wait for all incoming data - if (warpId < nPeers && laneId == 0) { - int peer = warpId < rank ? warpId : warpId + 1; - if (recvCounts[peer] > 0) { - constPortChansV[warpId].wait(); - } - } - __syncthreads(); -} - -/** - * Ring-based AllToAllV kernel for larger world sizes - * - * Uses step-by-step ring pattern to exchange data, sending to (rank+step) and - * receiving from (rank-step) in each step. Single block to avoid concurrent - * access to the same port channels. - */ -__global__ void __launch_bounds__(1024) - alltoallv1(int rank, int worldSize, - const void* sendBuff, void* recvBuff, - const size_t* sendCounts, const size_t* sendDispls, - const size_t* recvCounts, const size_t* recvDispls) { - // Copy local data first - if (threadIdx.x == 0) { - if (sendCounts[rank] > 0) { - const char* src = (const char*)sendBuff + sendDispls[rank]; - char* dst = (char*)recvBuff + recvDispls[rank]; - memcpy(dst, src, sendCounts[rank]); - } - } - __syncthreads(); - - // Ring-based exchange - single thread handles the communication - // to avoid race conditions on port channels - if (threadIdx.x == 0) { - for (int step = 1; step < worldSize; step++) { - int sendPeer = (rank + step) % worldSize; - int recvPeer = (rank - step + worldSize) % worldSize; - - int sendChanIdx = sendPeer < rank ? sendPeer : sendPeer - 1; - int recvChanIdx = recvPeer < rank ? 
recvPeer : recvPeer - 1; - - // Send data to sendPeer (non-blocking put with signal) - if (sendCounts[sendPeer] > 0) { - constPortChansV[sendChanIdx].putWithSignal( - recvDispls[rank], // dst offset in peer's buffer - sendDispls[sendPeer], // src offset in our buffer - sendCounts[sendPeer] // size - ); - constPortChansV[sendChanIdx].flush(); - } - - // Wait for data from recvPeer - if (recvCounts[recvPeer] > 0) { - constPortChansV[recvChanIdx].wait(); - } - } - } -} +// Device array for port channels (used by library kernels) +static DeviceHandle* d_portChannels; class AllToAllVTestColl : public BaseTestColl { public: @@ -174,17 +62,23 @@ void AllToAllVTestColl::runColl(const TestArgs& args, cudaStream_t stream) { CUDATHROW(cudaMemcpyToSymbol(deviceSyncerV, &syncer, sizeof(mscclpp::DeviceSyncer))); if (kernelNum == 0) { - int nThreads = (worldSize - 1) * WARP_SIZE; + // Use parallel warp-based kernel from library + int nThreads = (worldSize - 1) * 32; // One warp per peer +#if defined(__HIP_PLATFORM_AMD__) + nThreads = (worldSize - 1) * 64; +#endif if (nThreads < 32) nThreads = 32; if (nThreads > 1024) nThreads = 1024; - alltoallv0<<<1, nThreads, 0, stream>>>( + mscclpp::collective::alltoallvKernel<<<1, nThreads, 0, stream>>>( + d_portChannels, rank, worldSize, localSendBuffV, localRecvBuffV, d_sendCounts, d_sendDispls, d_recvCounts, d_recvDispls); } else if (kernelNum == 1) { - // Single block, single thread for ring-based serialized communication - alltoallv1<<<1, 32, 0, stream>>>( + // Use ring-based kernel from library + mscclpp::collective::alltoallvRingKernel<<<1, 32, 0, stream>>>( + d_portChannels, rank, worldSize, localSendBuffV, localRecvBuffV, d_sendCounts, d_sendDispls, @@ -275,8 +169,8 @@ void AllToAllVTestColl::setupCollTest(size_t size) { std::vector AllToAllVTestColl::getKernelRestrictions() { return { - {0, "alltoallv0", true, 1, 4 * worldSize_}, - {1, "alltoallv1", true, 1, 4 * worldSize_} + {0, "alltoallvKernel", true, 1, 4 * worldSize_}, + {1, 
"alltoallvRingKernel", true, 1, 4 * worldSize_} }; } @@ -318,17 +212,19 @@ void AllToAllVTestEngine::allocateBuffer() { CUDATHROW(cudaMalloc(&d_sendDispls, args_.totalRanks * sizeof(size_t))); CUDATHROW(cudaMalloc(&d_recvCounts, args_.totalRanks * sizeof(size_t))); CUDATHROW(cudaMalloc(&d_recvDispls, args_.totalRanks * sizeof(size_t))); + + // Allocate device array for port channels + CUDATHROW(cudaMalloc(&d_portChannels, args_.totalRanks * sizeof(DeviceHandle))); } void AllToAllVTestEngine::setupConnections() { std::vector> portChannels; setupMeshConnections(portChannels, sendBuff_.get(), args_.maxBytes, recvBuff_.get(), args_.maxBytes); - if (portChannels.size() > sizeof(constPortChansV) / sizeof(DeviceHandle)) { - throw std::runtime_error("Too many port channels for alltoallv test"); - } - CUDATHROW(cudaMemcpyToSymbol(constPortChansV, portChannels.data(), - sizeof(DeviceHandle) * portChannels.size())); + // Copy port channels to device memory for use by library kernels + CUDATHROW(cudaMemcpy(d_portChannels, portChannels.data(), + sizeof(DeviceHandle) * portChannels.size(), + cudaMemcpyHostToDevice)); } std::vector AllToAllVTestEngine::getSendBuff() { return {sendBuff_.get()}; }