From ac3e770c42ff6d037e41b5739e9970932e5c8f2c Mon Sep 17 00:00:00 2001 From: Qinghua Zhou Date: Thu, 5 Feb 2026 07:41:35 +0000 Subject: [PATCH] Add alltoallv kernel and test --- examples/torch-integration/alltoallv.py | 315 ++++++++++++++ .../torch-integration/alltoallv_kernel.cu | 400 ++++++++++++++++++ test/mscclpp-test/CMakeLists.txt | 1 + test/mscclpp-test/alltoallv_test.cu | 347 +++++++++++++++ 4 files changed, 1063 insertions(+) create mode 100644 examples/torch-integration/alltoallv.py create mode 100644 examples/torch-integration/alltoallv_kernel.cu create mode 100644 test/mscclpp-test/alltoallv_test.cu diff --git a/examples/torch-integration/alltoallv.py b/examples/torch-integration/alltoallv.py new file mode 100644 index 00000000..29968a26 --- /dev/null +++ b/examples/torch-integration/alltoallv.py @@ -0,0 +1,315 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# AllToAllV implementation for MSCCLPP +# This module provides a PyTorch-compatible alltoallv operation using MSCCLPP. +# +# Usage: +# MSCCLPP_MASTER_ADDR= MSCCLPP_MASTER_PORT= \ +# torchrun --nnodes=1 --nproc_per_node=8 alltoallv.py +# +# For AMD GPUs: +# MSCCLPP_MASTER_ADDR= MSCCLPP_MASTER_PORT= \ +# GPU_MAX_HW_QUEUES=7 torchrun --nnodes=1 --nproc_per_node=8 alltoallv.py + +import mscclpp +import mscclpp.utils as mscclpp_utils +import torch +import os +import netifaces as ni +import ipaddress +from typing import List, Optional + +_abs_path = os.path.dirname(os.path.abspath(__file__)) + + +def interfaces_for_ip_netifaces(ip: str): + """Find the network interface for a given IP address.""" + target = ipaddress.ip_address(ip) + for interface in ni.interfaces(): + addresses = ni.ifaddresses(interface) + if ni.AF_INET in addresses: + for link in addresses[ni.AF_INET]: + if "addr" in link: + addr = ipaddress.ip_address(link["addr"]) + if addr == target: + return interface + return None + + +class AllToAllVComm: + """ + AllToAllV communication class using MSCCLPP. + + This class provides a customized alltoallv implementation that handles + variable element counts per rank, similar to MPI_Alltoallv or the + batch_all_to_all_v pattern commonly used in MOE (Mixture of Experts) models. + + Unlike NCCL's ncclGroupStart/ncclGroupEnd approach, MSCCLPP uses explicit + put/signal/wait operations on PortChannels for communication. + + Attributes: + comm: MSCCLPP CommGroup instance + rank: Current rank + world_size: Total number of ranks + """ + + def __init__(self, comm: mscclpp.CommGroup): + """ + Initialize AllToAllV communication. + + Args: + comm: MSCCLPP CommGroup instance + """ + self.comm = comm + self.rank = comm.my_rank + self.world_size = comm.nranks + self.local_rank = comm.my_rank % comm.nranks_per_node + self.n_ranks_per_node = comm.nranks_per_node + self.executor = mscclpp.Executor(comm.communicator) + + # Compile and load the native CUDA kernel + mscclpp_native = mscclpp.compile_native( + name="mscclpp_alltoallv", + file=os.path.join(_abs_path, "alltoallv_kernel.cu") + ) + capsule = mscclpp_native.create_alltoallv_algorithm() + self.algorithm = mscclpp.Algorithm.create_from_native_capsule(capsule) + + def alltoallv( + self, + send_tensor: torch.Tensor, + recv_tensor: torch.Tensor, + send_counts: torch.Tensor, + send_displs: torch.Tensor, + recv_counts: torch.Tensor, + recv_displs: torch.Tensor, + stream: Optional[torch.cuda.Stream] = None + ): + """ + Perform alltoallv operation with variable element counts. 
+ + This function exchanges data between all ranks where each rank can send + different amounts of data to each other rank. + + Args: + send_tensor: Source tensor containing all data to be sent + recv_tensor: Destination tensor for received data + send_counts: Tensor of shape [world_size] with byte counts to send to each rank + send_displs: Tensor of shape [world_size] with byte offsets in send_tensor for each rank + recv_counts: Tensor of shape [world_size] with byte counts to receive from each rank + recv_displs: Tensor of shape [world_size] with byte offsets in recv_tensor for each rank + stream: Optional CUDA stream to use for the operation + + Note: + - All count and displacement tensors should be on GPU and contain size_t values + - send_counts[i] is the number of bytes to send to rank i + - send_displs[i] is the byte offset in send_tensor for data going to rank i + - recv_counts[i] is the number of bytes to receive from rank i + - recv_displs[i] is the byte offset in recv_tensor for data from rank i + """ + # Ensure counts and displacements are on GPU and have correct dtype + assert send_counts.device.type == "cuda", "send_counts must be on GPU" + assert send_displs.device.type == "cuda", "send_displs must be on GPU" + assert recv_counts.device.type == "cuda", "recv_counts must be on GPU" + assert recv_displs.device.type == "cuda", "recv_displs must be on GPU" + + # Prepare extras dict with device pointers for counts and displacements + extras = { + "sendCounts": send_counts.data_ptr(), + "sendDispls": send_displs.data_ptr(), + "recvCounts": recv_counts.data_ptr(), + "recvDispls": recv_displs.data_ptr(), + } + + cuda_stream = stream.cuda_stream if stream is not None else 0 + + self.algorithm.execute( + self.comm.communicator, + send_tensor.data_ptr(), + recv_tensor.data_ptr(), + send_tensor.nbytes, + recv_tensor.nbytes, + mscclpp_utils.torch_dtype_to_mscclpp_dtype(send_tensor.dtype), + stream=cuda_stream, + extras=extras + ) + + def alltoallv_by_elements( + self, + send_tensor: torch.Tensor, + recv_tensor: torch.Tensor, + send_counts_elements: torch.Tensor, + recv_counts_elements: torch.Tensor, + stream: Optional[torch.cuda.Stream] = None + ): + """ + Convenience function for alltoallv with element counts (not byte counts). + + This function automatically computes byte counts and displacements from + element counts, making it easier to use with typical MOE patterns. 
+ + Args: + send_tensor: Source tensor containing all data to be sent + recv_tensor: Destination tensor for received data + send_counts_elements: Tensor of shape [world_size] with element counts to send to each rank + recv_counts_elements: Tensor of shape [world_size] with element counts to receive from each rank + stream: Optional CUDA stream to use for the operation + """ + element_size = send_tensor.element_size() + + # Convert element counts to byte counts + send_counts = send_counts_elements.to(torch.int64) * element_size + recv_counts = recv_counts_elements.to(torch.int64) * element_size + + # Compute displacements (exclusive prefix sum) + send_displs = torch.zeros(self.world_size, dtype=torch.int64, device=send_tensor.device) + recv_displs = torch.zeros(self.world_size, dtype=torch.int64, device=recv_tensor.device) + + if self.world_size > 1: + send_displs[1:] = torch.cumsum(send_counts[:-1], dim=0) + recv_displs[1:] = torch.cumsum(recv_counts[:-1], dim=0) + + self.alltoallv( + send_tensor, recv_tensor, + send_counts, send_displs, + recv_counts, recv_displs, + stream + ) + + def barrier_cpu(self): + """CPU barrier to synchronize all ranks.""" + self.comm.barrier() + + +def batch_alltoallv( + comm: AllToAllVComm, + inputs: List[torch.Tensor], + outputs: List[torch.Tensor], + in_sizes: torch.Tensor, + out_sizes: torch.Tensor, + stream: Optional[torch.cuda.Stream] = None +): + """ + Batch all-to-all-v operation for multiple tensors. + + This function replicates the pattern from MOE implementations: + ``` + for k in range(len(inputs)): + ncclGroupStart() + for i in range(world_size): + ncclSend(in_buff + in_offset, in_sizes[i], ...) + ncclRecv(out_buff + out_offset, out_sizes[i], ...) + ncclGroupEnd() + ``` + + Since MSCCLPP doesn't support ncclGroupStart/ncclGroupEnd, we implement + this using explicit alltoallv operations for each tensor in the batch. 
+ + Args: + comm: AllToAllVComm instance + inputs: List of input tensors to send + outputs: List of output tensors to receive + in_sizes: Tensor of shape [world_size] with element counts to send to each rank + out_sizes: Tensor of shape [world_size] with element counts to receive from each rank + stream: Optional CUDA stream + """ + assert len(inputs) == len(outputs), "Input and output lists must have same length" + + # Ensure sizes are on CPU for computing displacements + in_sizes_cpu = in_sizes.cpu().to(torch.int64) + out_sizes_cpu = out_sizes.cpu().to(torch.int64) + + for k in range(len(inputs)): + input_tensor = inputs[k] + output_tensor = outputs[k] + + element_size = input_tensor.element_size() + + # Compute byte counts + send_counts = (in_sizes_cpu * element_size).cuda() + recv_counts = (out_sizes_cpu * element_size).cuda() + + # Compute displacements + send_displs = torch.zeros(comm.world_size, dtype=torch.int64, device="cuda") + recv_displs = torch.zeros(comm.world_size, dtype=torch.int64, device="cuda") + + if comm.world_size > 1: + send_displs_cpu = torch.zeros(comm.world_size, dtype=torch.int64) + recv_displs_cpu = torch.zeros(comm.world_size, dtype=torch.int64) + send_displs_cpu[1:] = torch.cumsum(send_counts.cpu()[:-1], dim=0) + recv_displs_cpu[1:] = torch.cumsum(recv_counts.cpu()[:-1], dim=0) + send_displs = send_displs_cpu.cuda() + recv_displs = recv_displs_cpu.cuda() + + comm.alltoallv( + input_tensor, output_tensor, + send_counts, send_displs, + recv_counts, recv_displs, + stream + ) + + +def main(): + """Test the alltoallv implementation.""" + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + local_rank = int(os.environ.get("LOCAL_RANK", os.environ["RANK"])) + torch.cuda.set_device(local_rank) + + master_addr = os.environ["MSCCLPP_MASTER_ADDR"] + master_port = os.environ["MSCCLPP_MASTER_PORT"] + + interface = interfaces_for_ip_netifaces(master_addr) + if interface is None: + raise ValueError(f"Cannot find network interface for IP address {master_addr}") + + interfaceIpPortTrio = f"{interface}:{master_addr}:{master_port}" + mscclpp_group = mscclpp.CommGroup(interfaceIpPortTrio=interfaceIpPortTrio, rank=rank, size=world_size) + + # Create communicator + comm = AllToAllVComm(mscclpp_group) + + # Test with variable sizes per rank + # Each rank sends different amounts to different peers + # For simplicity: rank i sends (i+1)*100 elements to each peer + elements_per_peer = (rank + 1) * 100 + + # Create send buffer with data + total_send = elements_per_peer * world_size + send_tensor = torch.arange(total_send, device="cuda", dtype=torch.float32) + rank * 10000 + + # Receive buffer needs to accommodate variable amounts from each sender + # Each sender j sends (j+1)*100 elements to us + recv_counts_cpu = torch.tensor([(j + 1) * 100 for j in range(world_size)], dtype=torch.int64) + total_recv = recv_counts_cpu.sum().item() + recv_tensor = torch.zeros(total_recv, device="cuda", dtype=torch.float32) + + # Send counts: we send elements_per_peer to everyone + send_counts_cpu = torch.tensor([elements_per_peer] * world_size, dtype=torch.int64) + + # Move to GPU + send_counts = send_counts_cpu.cuda() + recv_counts = recv_counts_cpu.cuda() + + comm.barrier_cpu() + + # Perform alltoallv + comm.alltoallv_by_elements( + send_tensor, recv_tensor, + send_counts, recv_counts, + stream=torch.cuda.current_stream() + ) + + torch.cuda.synchronize() + comm.barrier_cpu() + + print(f"Rank {rank}: alltoallv completed successfully!") + print(f" Sent {total_send} elements, 
received {total_recv} elements") + + # Cleanup + comm = None + + +if __name__ == "__main__": + main() diff --git a/examples/torch-integration/alltoallv_kernel.cu b/examples/torch-integration/alltoallv_kernel.cu new file mode 100644 index 00000000..07a518d9 --- /dev/null +++ b/examples/torch-integration/alltoallv_kernel.cu @@ -0,0 +1,400 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +// AllToAllV implementation for MSCCLPP +// This kernel handles variable element counts per rank for alltoallv operations. +// Unlike NCCL's ncclGroupStart/ncclGroupEnd approach, mscclpp uses explicit +// put/signal/wait operations on PortChannels. + +#include +#include + +#include +#include +#include +#include +#include + +namespace py = pybind11; + +#if defined(__HIP_PLATFORM_AMD__) +#define WARP_SIZE 64 +#else +#define WARP_SIZE 32 +#endif + +// Device syncer for synchronization across blocks +__device__ mscclpp::DeviceSyncer alltoallvDeviceSyncer; + +/** + * AllToAllV kernel implementation + * + * This kernel performs an all-to-all exchange with variable-length data per rank. + * Each rank sends sendCounts[i] elements to rank i at sendDispls[i] offset, + * and receives recvCounts[i] elements from rank i at recvDispls[i] offset. + * + * Since mscclpp doesn't support ncclGroupStart/ncclGroupEnd, we implement + * the exchange using explicit put/signal/wait operations on PortChannels. + * The communication pattern uses a ring-based approach to avoid deadlocks. + * + * @param portChannels Array of PortChannel handles for each peer (worldSize-1 channels) + * @param rank Current rank + * @param worldSize Total number of ranks + * @param sendBuff Source buffer containing data to send + * @param recvBuff Destination buffer for received data + * @param sendCounts Array of send counts for each rank (in bytes) + * @param sendDispls Array of send displacements for each rank (in bytes) + * @param recvCounts Array of receive counts for each rank (in bytes) + * @param recvDispls Array of receive displacements for each rank (in bytes) + */ +__global__ void __launch_bounds__(1024) + alltoallv_kernel(mscclpp::DeviceHandle* portChannels, + int rank, + int worldSize, + const void* sendBuff, + void* recvBuff, + const size_t* sendCounts, + const size_t* sendDispls, + const size_t* recvCounts, + const size_t* recvDispls) { + // First, copy local data (rank's own portion) from send to recv buffer + // This doesn't require any communication + if (threadIdx.x == 0 && blockIdx.x == 0) { + if (sendCounts[rank] > 0) { + // Local copy: sendBuff[sendDispls[rank]] -> recvBuff[recvDispls[rank]] + const char* src = (const char*)sendBuff + sendDispls[rank]; + char* dst = (char*)recvBuff + recvDispls[rank]; + memcpy(dst, src, sendCounts[rank]); + } + } + __syncthreads(); + + // Ring-based exchange pattern to avoid deadlocks + // In each step i, rank sends to (rank + i) % worldSize and receives from (rank - i + worldSize) % worldSize + for (int step = 1; step < worldSize; step++) { + int sendPeer = (rank + step) % worldSize; + int recvPeer = (rank - step + worldSize) % worldSize; + + // Get channel indices (portChannels excludes self, so adjust index) + int sendChanIdx = sendPeer < rank ? sendPeer : sendPeer - 1; + int recvChanIdx = recvPeer < rank ? 
recvPeer : recvPeer - 1; + + // Each warp handles one peer + int wid = threadIdx.x / WARP_SIZE; + int lid = threadIdx.x % WARP_SIZE; + + // Send data to sendPeer if there's data to send + if (wid == 0 && lid == 0) { + if (sendCounts[sendPeer] > 0) { + // putWithSignal: copy data and signal completion + // src offset: sendDispls[sendPeer] in our sendBuff + // dst offset: recvDispls[rank] in peer's recvBuff (where our data should go) + portChannels[sendChanIdx].putWithSignal( + recvDispls[rank], // dst offset in peer's recv buffer (where we write) + sendDispls[sendPeer], // src offset in our send buffer + sendCounts[sendPeer] // size in bytes + ); + } + } + + // Sync all threads before flushing + alltoallvDeviceSyncer.sync(gridDim.x); + + // Flush to ensure data is sent + if (wid == 0 && lid == 0) { + if (sendCounts[sendPeer] > 0) { + portChannels[sendChanIdx].flush(); + } + } + + // Wait for data from recvPeer if we're expecting data + if (wid == 0 && lid == 0) { + if (recvCounts[recvPeer] > 0) { + portChannels[recvChanIdx].wait(); + } + } + + // Sync all threads before next step + alltoallvDeviceSyncer.sync(gridDim.x); + } +} + +/** + * Simplified AllToAllV kernel for single-block execution + * + * This version is optimized for cases where all communication can be + * handled within a single thread block. + */ +__global__ void __launch_bounds__(1024) + alltoallv_simple_kernel(mscclpp::DeviceHandle* portChannels, + int rank, + int worldSize, + const void* sendBuff, + void* recvBuff, + const size_t* sendCounts, + const size_t* sendDispls, + const size_t* recvCounts, + const size_t* recvDispls) { + int tid = threadIdx.x; + int nPeers = worldSize - 1; + + // Step 1: Copy local data + if (tid == 0 && sendCounts[rank] > 0) { + const char* src = (const char*)sendBuff + sendDispls[rank]; + char* dst = (char*)recvBuff + recvDispls[rank]; + memcpy(dst, src, sendCounts[rank]); + } + __syncthreads(); + + // Step 2: Each warp handles one peer for sending + // We have worldSize-1 peers, assign one warp per peer + int warpId = tid / WARP_SIZE; + int laneId = tid % WARP_SIZE; + + if (warpId < nPeers && laneId == 0) { + // Determine which peer this warp handles + int peer = warpId < rank ? warpId : warpId + 1; + int chanIdx = warpId; + + if (sendCounts[peer] > 0) { + portChannels[chanIdx].putWithSignal( + recvDispls[rank], // dst offset in peer's buffer + sendDispls[peer], // src offset in our buffer + sendCounts[peer] // size + ); + } + } + __syncthreads(); + + // Step 3: Flush all pending operations + if (warpId < nPeers && laneId == 0) { + int peer = warpId < rank ? warpId : warpId + 1; + if (sendCounts[peer] > 0) { + portChannels[warpId].flush(); + } + } + __syncthreads(); + + // Step 4: Wait for all incoming data + if (warpId < nPeers && laneId == 0) { + int peer = warpId < rank ? 
warpId : warpId + 1; + if (recvCounts[peer] > 0) { + portChannels[warpId].wait(); + } + } + __syncthreads(); +} + +// Context to hold all necessary state for alltoallv execution +struct AllToAllVContext { + int rank; + int worldSize; + int nRanksPerNode; + + std::vector registeredMemories; + std::shared_ptr> portChannelDeviceHandles; + + // Device memory for counts and displacements + size_t* d_sendCounts; + size_t* d_sendDispls; + size_t* d_recvCounts; + size_t* d_recvDispls; +}; + +class AllToAllVAlgoBuilder : public mscclpp::AlgorithmBuilder { + public: + AllToAllVAlgoBuilder() = default; + ~AllToAllVAlgoBuilder() { + if (proxyService_) { + proxyService_->stopProxy(); + } + } + + std::shared_ptr build() override { + auto self = std::make_shared(); + std::shared_ptr alltoallvAlgo = std::make_shared( + "alltoallv", "alltoallv", + // Initialize function + [self](std::shared_ptr comm) { self->initialize(comm); }, + // Kernel execution function + [self](const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, size_t outputSize, + mscclpp::DataType dtype, [[maybe_unused]] mscclpp::ReduceOp op, cudaStream_t stream, int nBlocks, + int nThreadsPerBlock, const std::unordered_map& extras) { + return self->alltoallvKernelFunc(ctx, input, output, inputSize, outputSize, dtype, stream, extras); + }, + // Context initialization function + [self](std::shared_ptr comm, const void* input, void* output, size_t inputSize, + size_t outputSize, + mscclpp::DataType dtype) { return self->initAlltoallvContext(comm, input, output, inputSize, outputSize, dtype); }, + // Context key generation function + [self](const void* input, void* output, size_t inputSize, size_t outputSize, mscclpp::DataType dtype) { + return self->generateAlltoallvContextKey(input, output, inputSize, outputSize, dtype); + }); + return alltoallvAlgo; + } + + private: + std::vector conns_; + std::shared_ptr proxyService_; + int worldSize_; + + void initialize(std::shared_ptr comm) { + std::vector> connectionFutures; + worldSize_ = comm->bootstrap()->getNranks(); + for (int i = 0; i < worldSize_; i++) { + if (i == comm->bootstrap()->getRank()) continue; + connectionFutures.push_back(comm->connect(mscclpp::Transport::CudaIpc, i)); + } + std::vector connections; + std::transform(connectionFutures.begin(), connectionFutures.end(), std::back_inserter(connections), + [](const auto& future) { return future.get(); }); + this->conns_ = std::move(connections); + proxyService_ = std::make_shared(); + proxyService_->startProxy(true); + } + + mscclpp::CommResult alltoallvKernelFunc(const std::shared_ptr ctx, const void* input, void* output, + size_t inputSize, size_t outputSize, + [[maybe_unused]] mscclpp::DataType dtype, + cudaStream_t stream, + const std::unordered_map& extras) { + auto algoCtx = std::static_pointer_cast(ctx); + int rank = algoCtx->rank; + int worldSize = algoCtx->worldSize; + + // Extract send/recv counts and displacements from extras + // The caller should pass these as device pointers via extras map + auto it_sendCounts = extras.find("sendCounts"); + auto it_sendDispls = extras.find("sendDispls"); + auto it_recvCounts = extras.find("recvCounts"); + auto it_recvDispls = extras.find("recvDispls"); + + if (it_sendCounts == extras.end() || it_sendDispls == extras.end() || + it_recvCounts == extras.end() || it_recvDispls == extras.end()) { + return mscclpp::CommResult::CommInternalError; + } + + const size_t* d_sendCounts = reinterpret_cast(it_sendCounts->second); + const size_t* d_sendDispls = 
reinterpret_cast(it_sendDispls->second); + const size_t* d_recvCounts = reinterpret_cast(it_recvCounts->second); + const size_t* d_recvDispls = reinterpret_cast(it_recvDispls->second); + + // Reset device syncer + mscclpp::DeviceSyncer syncer = {}; + cudaMemcpyToSymbolAsync(alltoallvDeviceSyncer, &syncer, sizeof(mscclpp::DeviceSyncer), 0, + cudaMemcpyHostToDevice, stream); + + // Use simple kernel for small world sizes, multi-block for larger + if (worldSize <= 16) { + int nThreads = (worldSize - 1) * WARP_SIZE; + if (nThreads < 32) nThreads = 32; + if (nThreads > 1024) nThreads = 1024; + + alltoallv_simple_kernel<<<1, nThreads, 0, stream>>>( + algoCtx->portChannelDeviceHandles.get(), + rank, worldSize, + input, output, + d_sendCounts, d_sendDispls, + d_recvCounts, d_recvDispls); + } else { + alltoallv_kernel<<<1, 1024, 0, stream>>>( + algoCtx->portChannelDeviceHandles.get(), + rank, worldSize, + input, output, + d_sendCounts, d_sendDispls, + d_recvCounts, d_recvDispls); + } + + if (cudaGetLastError() == cudaSuccess) { + return mscclpp::CommResult::CommSuccess; + } + return mscclpp::CommResult::CommInternalError; + } + + std::shared_ptr initAlltoallvContext(std::shared_ptr comm, const void* input, + void* output, size_t inputSize, size_t outputSize, + mscclpp::DataType dtype) { + auto ctx = std::make_shared(); + ctx->rank = comm->bootstrap()->getRank(); + ctx->worldSize = comm->bootstrap()->getNranks(); + ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode(); + + // Register memories for input and output buffers + mscclpp::RegisteredMemory inputBufRegMem = + comm->registerMemory((void*)input, inputSize, mscclpp::Transport::CudaIpc); + mscclpp::RegisteredMemory outputBufRegMem = + comm->registerMemory(output, outputSize, mscclpp::Transport::CudaIpc); + + // Exchange output buffer registration with all peers + std::vector> remoteRegMemories; + for (int i = 0; i < ctx->worldSize; i++) { + if (i == ctx->rank) continue; + comm->sendMemory(outputBufRegMem, i, 0); + remoteRegMemories.push_back(comm->recvMemory(i, 0)); + } + + // Setup port channels for each peer + std::vector> portChannels; + mscclpp::MemoryId inputMemoryId = this->proxyService_->addMemory(inputBufRegMem); + + for (size_t i = 0; i < this->conns_.size(); i++) { + auto remoteMemory = remoteRegMemories[i].get(); + mscclpp::MemoryId remoteMemoryId = this->proxyService_->addMemory(remoteMemory); + portChannels.push_back(mscclpp::deviceHandle(this->proxyService_->portChannel( + this->proxyService_->buildAndAddSemaphore(*comm, this->conns_[i]), remoteMemoryId, inputMemoryId))); + } + + // Allocate and copy port channels to device + ctx->portChannelDeviceHandles = + mscclpp::detail::gpuCallocShared>(portChannels.size()); + mscclpp::gpuMemcpy(ctx->portChannelDeviceHandles.get(), portChannels.data(), portChannels.size(), + cudaMemcpyHostToDevice); + + // Keep registered memory references to prevent deallocation + std::transform(remoteRegMemories.begin(), remoteRegMemories.end(), std::back_inserter(ctx->registeredMemories), + [](const auto& fut) { return fut.get(); }); + ctx->registeredMemories.push_back(inputBufRegMem); + ctx->registeredMemories.push_back(outputBufRegMem); + + return ctx; + } + + mscclpp::AlgorithmCtxKey generateAlltoallvContextKey(const void* input, void* output, size_t inputSize, + size_t outputSize, mscclpp::DataType dtype) { + return {(void*)input, output, inputSize, outputSize, 0}; + } +}; + +std::shared_ptr createAlltoallvAlgorithm() { + auto alltoallvAlgoBuilder = std::make_shared(); + return 
alltoallvAlgoBuilder->build(); +} + +void deletePtr(PyObject* capsule) { + const char* name = PyCapsule_GetName(capsule); + void* p = PyCapsule_GetPointer(capsule, name); + if (p == nullptr) { + PyErr_WriteUnraisable(capsule); + return; + } + auto* ptr = static_cast*>(p); + delete ptr; +} + +PyObject* getCapsule(std::shared_ptr algo) { + auto* ptrCopy = new std::shared_ptr(algo); + PyObject* capsule = PyCapsule_New(ptrCopy, mscclpp::ALGORITHM_NATIVE_CAPSULE_NAME, deletePtr); + if (capsule == nullptr) { + delete ptrCopy; + throw pybind11::error_already_set(); + } + return capsule; +} + +PYBIND11_MODULE(mscclpp_alltoallv, m) { + m.doc() = "AllToAllV implementation for MSCCLPP - handles variable element counts per rank"; + m.def( + "create_alltoallv_algorithm", + []() { return py::reinterpret_steal(getCapsule(createAlltoallvAlgorithm())); }, + "Create an alltoallv algorithm and return it as a PyCapsule usable by MSCCL++ Python bindings"); +} diff --git a/test/mscclpp-test/CMakeLists.txt b/test/mscclpp-test/CMakeLists.txt index eb2b26ca..d249b4d7 100644 --- a/test/mscclpp-test/CMakeLists.txt +++ b/test/mscclpp-test/CMakeLists.txt @@ -18,3 +18,4 @@ add_mscclpp_test_executable(sendrecv_test_perf sendrecv_test.cu) add_mscclpp_test_executable(allgather_test_perf allgather_test.cu) add_mscclpp_test_executable(allreduce_test_perf allreduce_test.cu) add_mscclpp_test_executable(alltoall_test_perf alltoall_test.cu) +add_mscclpp_test_executable(alltoallv_test_perf alltoallv_test.cu) diff --git a/test/mscclpp-test/alltoallv_test.cu b/test/mscclpp-test/alltoallv_test.cu new file mode 100644 index 00000000..8467a7a4 --- /dev/null +++ b/test/mscclpp-test/alltoallv_test.cu @@ -0,0 +1,347 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +// AllToAllV test - tests variable-length alltoall operations +// This test validates the alltoallv kernel that handles variable element counts per rank. + +#include +#include +#include +#include + +#include "common.hpp" + +#if defined(__HIP_PLATFORM_AMD__) +#define WARP_SIZE 64 +#else +#define WARP_SIZE 32 +#endif + +template +using DeviceHandle = mscclpp::DeviceHandle; +__constant__ DeviceHandle constPortChansV[16]; +__device__ mscclpp::DeviceSyncer deviceSyncerV; + +static void* localRecvBuffV; +static void* localSendBuffV; + +// Device arrays for variable counts and displacements +static size_t* d_sendCounts; +static size_t* d_sendDispls; +static size_t* d_recvCounts; +static size_t* d_recvDispls; + +/** + * AllToAllV kernel implementation + * + * Each rank sends sendCounts[i] bytes to rank i at sendDispls[i] offset, + * and receives recvCounts[i] bytes from rank i at recvDispls[i] offset. + * + * Uses ring-based exchange pattern to avoid deadlocks. + */ +__global__ void __launch_bounds__(1024) + alltoallv0(int rank, int worldSize, + const void* sendBuff, void* recvBuff, + const size_t* sendCounts, const size_t* sendDispls, + const size_t* recvCounts, const size_t* recvDispls) { + int tid = threadIdx.x; + int nPeers = worldSize - 1; + + // Step 1: Copy local data (rank's own portion) + if (tid == 0 && sendCounts[rank] > 0) { + const char* src = (const char*)sendBuff + sendDispls[rank]; + char* dst = (char*)recvBuff + recvDispls[rank]; + memcpy(dst, src, sendCounts[rank]); + } + __syncthreads(); + + // Step 2: Each warp handles one peer for sending + int warpId = tid / WARP_SIZE; + int laneId = tid % WARP_SIZE; + + if (warpId < nPeers && laneId == 0) { + // Determine which peer this warp handles + int peer = warpId < rank ? 
warpId : warpId + 1; + int chanIdx = warpId; + + if (sendCounts[peer] > 0) { + constPortChansV[chanIdx].putWithSignal( + recvDispls[rank], // dst offset in peer's buffer + sendDispls[peer], // src offset in our buffer + sendCounts[peer] // size + ); + } + } + __syncthreads(); + + // Step 3: Flush all pending operations + if (warpId < nPeers && laneId == 0) { + int peer = warpId < rank ? warpId : warpId + 1; + if (sendCounts[peer] > 0) { + constPortChansV[warpId].flush(); + } + } + __syncthreads(); + + // Step 4: Wait for all incoming data + if (warpId < nPeers && laneId == 0) { + int peer = warpId < rank ? warpId : warpId + 1; + if (recvCounts[peer] > 0) { + constPortChansV[warpId].wait(); + } + } + __syncthreads(); +} + +/** + * Ring-based AllToAllV kernel for larger world sizes + * + * Uses step-by-step ring pattern to exchange data, sending to (rank+step) and + * receiving from (rank-step) in each step. Single block to avoid concurrent + * access to the same port channels. + */ +__global__ void __launch_bounds__(1024) + alltoallv1(int rank, int worldSize, + const void* sendBuff, void* recvBuff, + const size_t* sendCounts, const size_t* sendDispls, + const size_t* recvCounts, const size_t* recvDispls) { + // Copy local data first + if (threadIdx.x == 0) { + if (sendCounts[rank] > 0) { + const char* src = (const char*)sendBuff + sendDispls[rank]; + char* dst = (char*)recvBuff + recvDispls[rank]; + memcpy(dst, src, sendCounts[rank]); + } + } + __syncthreads(); + + // Ring-based exchange - single thread handles the communication + // to avoid race conditions on port channels + if (threadIdx.x == 0) { + for (int step = 1; step < worldSize; step++) { + int sendPeer = (rank + step) % worldSize; + int recvPeer = (rank - step + worldSize) % worldSize; + + int sendChanIdx = sendPeer < rank ? sendPeer : sendPeer - 1; + int recvChanIdx = recvPeer < rank ? 
recvPeer : recvPeer - 1; + + // Send data to sendPeer (non-blocking put with signal) + if (sendCounts[sendPeer] > 0) { + constPortChansV[sendChanIdx].putWithSignal( + recvDispls[rank], // dst offset in peer's buffer + sendDispls[sendPeer], // src offset in our buffer + sendCounts[sendPeer] // size + ); + constPortChansV[sendChanIdx].flush(); + } + + // Wait for data from recvPeer + if (recvCounts[recvPeer] > 0) { + constPortChansV[recvChanIdx].wait(); + } + } + } +} + +class AllToAllVTestColl : public BaseTestColl { + public: + AllToAllVTestColl() = default; + ~AllToAllVTestColl() override = default; + + void runColl(const TestArgs& args, cudaStream_t stream) override; + void initData(const TestArgs& args, std::vector sendBuff, void* expectedBuff) override; + void getBw(const double deltaSec, double& algBw /*OUT*/, double& busBw /*OUT*/) override; + void setupCollTest(size_t size) override; + std::vector getKernelRestrictions() override; + + private: + // Host-side counts and displacements + std::vector sendCounts_; + std::vector sendDispls_; + std::vector recvCounts_; + std::vector recvDispls_; + size_t totalSendBytes_; + size_t totalRecvBytes_; +}; + +void AllToAllVTestColl::runColl(const TestArgs& args, cudaStream_t stream) { + const int worldSize = args.totalRanks; + const int rank = args.rank; + const int kernelNum = args.kernelNum; + + // Reset device syncer + mscclpp::DeviceSyncer syncer = {}; + CUDATHROW(cudaMemcpyToSymbol(deviceSyncerV, &syncer, sizeof(mscclpp::DeviceSyncer))); + + if (kernelNum == 0) { + int nThreads = (worldSize - 1) * WARP_SIZE; + if (nThreads < 32) nThreads = 32; + if (nThreads > 1024) nThreads = 1024; + alltoallv0<<<1, nThreads, 0, stream>>>( + rank, worldSize, + localSendBuffV, localRecvBuffV, + d_sendCounts, d_sendDispls, + d_recvCounts, d_recvDispls); + } else if (kernelNum == 1) { + // Single block, single thread for ring-based serialized communication + alltoallv1<<<1, 32, 0, stream>>>( + rank, worldSize, + localSendBuffV, localRecvBuffV, + d_sendCounts, d_sendDispls, + d_recvCounts, d_recvDispls); + } +} + +void AllToAllVTestColl::initData(const TestArgs& args, std::vector sendBuff, void* expectedBuff) { + if (sendBuff.size() != 1) throw std::runtime_error("unexpected error"); + const int rank = args.rank; + const int worldSize = args.totalRanks; + + // Create send data: each segment has values identifying source and destination + std::vector sendData(totalSendBytes_ / sizeof(int), 0); + for (int peer = 0; peer < worldSize; peer++) { + size_t offset = sendDispls_[peer] / sizeof(int); + size_t count = sendCounts_[peer] / sizeof(int); + for (size_t i = 0; i < count; i++) { + // Encode: rank * 10000 + peer * 100 + position + sendData[offset + i] = rank * 10000 + peer * 100 + i; + } + } + CUDATHROW(cudaMemcpy(sendBuff[0], sendData.data(), totalSendBytes_, cudaMemcpyHostToDevice)); + + // Create expected data: we receive from each peer + std::vector expectedData(totalRecvBytes_ / sizeof(int), 0); + for (int peer = 0; peer < worldSize; peer++) { + size_t offset = recvDispls_[peer] / sizeof(int); + size_t count = recvCounts_[peer] / sizeof(int); + for (size_t i = 0; i < count; i++) { + // We receive data sent by peer to us + expectedData[offset + i] = peer * 10000 + rank * 100 + i; + } + } + std::memcpy(expectedBuff, expectedData.data(), totalRecvBytes_); + + // Copy counts and displacements to device + CUDATHROW(cudaMemcpy(d_sendCounts, sendCounts_.data(), worldSize * sizeof(size_t), cudaMemcpyHostToDevice)); + CUDATHROW(cudaMemcpy(d_sendDispls, 
sendDispls_.data(), worldSize * sizeof(size_t), cudaMemcpyHostToDevice)); + CUDATHROW(cudaMemcpy(d_recvCounts, recvCounts_.data(), worldSize * sizeof(size_t), cudaMemcpyHostToDevice)); + CUDATHROW(cudaMemcpy(d_recvDispls, recvDispls_.data(), worldSize * sizeof(size_t), cudaMemcpyHostToDevice)); +} + +void AllToAllVTestColl::getBw(const double deltaSec, double& algBw, double& busBw) { + double baseBw = (double)(totalRecvBytes_) / 1.0E9 / deltaSec; + algBw = baseBw; + double factor = ((double)(worldSize_ - 1)) / ((double)worldSize_); + busBw = baseBw * factor; +} + +void AllToAllVTestColl::setupCollTest(size_t size) { + // For alltoallv, we use variable sizes per peer + // For testing: rank i sends (rank + 1) * baseCount elements to each peer + // Each peer j sends (j + 1) * baseCount elements to us + + size_t baseBytes = size / (worldSize_ * worldSize_); // Base unit for variable sizing + if (baseBytes < sizeof(int)) baseBytes = sizeof(int); + baseBytes = (baseBytes / sizeof(int)) * sizeof(int); // Align to int size + + sendCounts_.resize(worldSize_); + sendDispls_.resize(worldSize_); + recvCounts_.resize(worldSize_); + recvDispls_.resize(worldSize_); + + // Each rank sends the same amount to each peer (for simplicity in this test) + // In a real MOE scenario, these would be variable + totalSendBytes_ = 0; + totalRecvBytes_ = 0; + + for (int peer = 0; peer < worldSize_; peer++) { + sendCounts_[peer] = baseBytes; + sendDispls_[peer] = totalSendBytes_; + totalSendBytes_ += sendCounts_[peer]; + + recvCounts_[peer] = baseBytes; + recvDispls_[peer] = totalRecvBytes_; + totalRecvBytes_ += recvCounts_[peer]; + } + + sendCount_ = totalSendBytes_ / typeSize_; + recvCount_ = totalRecvBytes_ / typeSize_; + paramCount_ = sendCount_; + expectedCount_ = recvCount_; + + mscclpp::DeviceSyncer syncer = {}; + CUDATHROW(cudaMemcpyToSymbol(deviceSyncerV, &syncer, sizeof(mscclpp::DeviceSyncer))); +} + +std::vector AllToAllVTestColl::getKernelRestrictions() { + return { + {0, "alltoallv0", true, 1, 4 * worldSize_}, + {1, "alltoallv1", true, 1, 4 * worldSize_} + }; +} + +class AllToAllVTestEngine : public BaseTestEngine { + public: + AllToAllVTestEngine(const TestArgs& args); + ~AllToAllVTestEngine() override = default; + + void allocateBuffer() override; + void setupConnections() override; + + std::vector getSendBuff() override; + void* getRecvBuff() override; + void* getScratchBuff() override; + + private: + void* getExpectedBuff() override; + bool isInPlace() const; + + std::shared_ptr sendBuff_; + std::shared_ptr recvBuff_; + std::shared_ptr expectedBuff_; +}; + +bool AllToAllVTestEngine::isInPlace() const { return false; } + +AllToAllVTestEngine::AllToAllVTestEngine(const TestArgs& args) : BaseTestEngine(args, "alltoallv") { inPlace_ = false; } + +void AllToAllVTestEngine::allocateBuffer() { + sendBuff_ = mscclpp::GpuBuffer(args_.maxBytes / sizeof(int)).memory(); + recvBuff_ = mscclpp::GpuBuffer(args_.maxBytes / sizeof(int)).memory(); + expectedBuff_ = std::shared_ptr(new int[args_.maxBytes / sizeof(int)]); + + localSendBuffV = sendBuff_.get(); + localRecvBuffV = recvBuff_.get(); + + // Allocate device arrays for counts and displacements + CUDATHROW(cudaMalloc(&d_sendCounts, args_.totalRanks * sizeof(size_t))); + CUDATHROW(cudaMalloc(&d_sendDispls, args_.totalRanks * sizeof(size_t))); + CUDATHROW(cudaMalloc(&d_recvCounts, args_.totalRanks * sizeof(size_t))); + CUDATHROW(cudaMalloc(&d_recvDispls, args_.totalRanks * sizeof(size_t))); +} + +void AllToAllVTestEngine::setupConnections() { + std::vector> 
portChannels; + setupMeshConnections(portChannels, sendBuff_.get(), args_.maxBytes, recvBuff_.get(), args_.maxBytes); + + if (portChannels.size() > sizeof(constPortChansV) / sizeof(DeviceHandle)) { + throw std::runtime_error("Too many port channels for alltoallv test"); + } + CUDATHROW(cudaMemcpyToSymbol(constPortChansV, portChannels.data(), + sizeof(DeviceHandle) * portChannels.size())); +} + +std::vector AllToAllVTestEngine::getSendBuff() { return {sendBuff_.get()}; } +void* AllToAllVTestEngine::getExpectedBuff() { return expectedBuff_.get(); } +void* AllToAllVTestEngine::getRecvBuff() { + if (this->isInPlace()) + return sendBuff_.get(); + else + return recvBuff_.get(); +} +void* AllToAllVTestEngine::getScratchBuff() { return nullptr; } + +std::shared_ptr getTestEngine(const TestArgs& args) { + return std::make_shared(args); +} +std::shared_ptr getTestColl() { return std::make_shared(); }
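
Usage note: the example main() in alltoallv.py above only prints element counts and does not check the received payload. Below is a minimal verification sketch for that example, assuming the data layout main() uses (rank j fills its send buffer with torch.arange((j + 1) * 100 * world_size) + j * 10000 and sends equal slices to every peer, so recv segments land in sender-rank order). The helper name verify_recv is hypothetical and not part of the patch.

    import torch

    def verify_recv(rank: int, world_size: int, recv_tensor: torch.Tensor, recv_counts_cpu: torch.Tensor):
        """Check the buffer produced by comm.alltoallv_by_elements in the example main()."""
        recv_cpu = recv_tensor.cpu()
        offset = 0
        for j in range(world_size):
            count = int(recv_counts_cpu[j])  # (j + 1) * 100 elements expected from rank j
            # Rank j sends equal slices to every peer, so the segment destined for this rank
            # starts at element rank * count of j's send buffer, which holds arange(...) + j * 10000.
            expected = torch.arange(rank * count, (rank + 1) * count, dtype=torch.float32) + j * 10000
            assert torch.equal(recv_cpu[offset:offset + count], expected), f"mismatch in data from rank {j}"
            offset += count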
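A second sketch, under the same caveat: alltoallv_by_elements and batch_alltoallv need the receive counts up front, while in the MOE-style workloads mentioned in the docstrings those counts come from the router at runtime. One way to obtain them with the API added in this patch is to first exchange the per-peer send counts with a fixed-size pass (one int64 element to and from each peer), then size the receive buffer from the result. The helper name exchange_counts is hypothetical.

    import torch

    def exchange_counts(comm, send_counts_elements: torch.Tensor) -> torch.Tensor:
        """Return recv counts (CPU int64) by exchanging one count per peer via alltoallv_by_elements."""
        send_counts = send_counts_elements.to(device="cuda", dtype=torch.int64)
        recv_counts = torch.empty(comm.world_size, dtype=torch.int64, device="cuda")
        ones = torch.ones(comm.world_size, dtype=torch.int64, device="cuda")  # one element per peer
        comm.alltoallv_by_elements(send_counts, recv_counts, ones, ones)
        torch.cuda.synchronize()
        return recv_counts.cpu()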