Add alltoallv kernel and test

Qinghua Zhou
2026-02-05 07:41:35 +00:00
parent f0441ee4ea
commit ac3e770c42
4 changed files with 1063 additions and 0 deletions

alltoallv.py

@@ -0,0 +1,315 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# AllToAllV implementation for MSCCLPP
# This module provides a PyTorch-compatible alltoallv operation using MSCCLPP.
#
# Usage:
# MSCCLPP_MASTER_ADDR=<master_ip> MSCCLPP_MASTER_PORT=<port> \
# torchrun --nnodes=1 --nproc_per_node=8 alltoallv.py
#
# For AMD GPUs:
# MSCCLPP_MASTER_ADDR=<master_ip> MSCCLPP_MASTER_PORT=<port> \
# GPU_MAX_HW_QUEUES=7 torchrun --nnodes=1 --nproc_per_node=8 alltoallv.py
import mscclpp
import mscclpp.utils as mscclpp_utils
import torch
import os
import netifaces as ni
import ipaddress
from typing import List, Optional
_abs_path = os.path.dirname(os.path.abspath(__file__))
def interfaces_for_ip_netifaces(ip: str):
"""Find the network interface for a given IP address."""
target = ipaddress.ip_address(ip)
for interface in ni.interfaces():
addresses = ni.ifaddresses(interface)
if ni.AF_INET in addresses:
for link in addresses[ni.AF_INET]:
if "addr" in link:
addr = ipaddress.ip_address(link["addr"])
if addr == target:
return interface
return None
class AllToAllVComm:
"""
AllToAllV communication class using MSCCLPP.
This class provides a customized alltoallv implementation that handles
variable element counts per rank, similar to MPI_Alltoallv or the
batch_all_to_all_v pattern commonly used in MoE (Mixture of Experts) models.
Unlike NCCL's ncclGroupStart/ncclGroupEnd approach, MSCCLPP uses explicit
put/signal/wait operations on PortChannels for communication.
Attributes:
comm: MSCCLPP CommGroup instance
rank: Current rank
world_size: Total number of ranks
"""
def __init__(self, comm: mscclpp.CommGroup):
"""
Initialize AllToAllV communication.
Args:
comm: MSCCLPP CommGroup instance
"""
self.comm = comm
self.rank = comm.my_rank
self.world_size = comm.nranks
self.local_rank = comm.my_rank % comm.nranks_per_node
self.n_ranks_per_node = comm.nranks_per_node
self.executor = mscclpp.Executor(comm.communicator)
# Compile and load the native CUDA kernel
mscclpp_native = mscclpp.compile_native(
name="mscclpp_alltoallv",
file=os.path.join(_abs_path, "alltoallv_kernel.cu")
)
capsule = mscclpp_native.create_alltoallv_algorithm()
self.algorithm = mscclpp.Algorithm.create_from_native_capsule(capsule)
def alltoallv(
self,
send_tensor: torch.Tensor,
recv_tensor: torch.Tensor,
send_counts: torch.Tensor,
send_displs: torch.Tensor,
recv_counts: torch.Tensor,
recv_displs: torch.Tensor,
stream: Optional[torch.cuda.Stream] = None
):
"""
Perform alltoallv operation with variable element counts.
This function exchanges data between all ranks where each rank can send
different amounts of data to each other rank.
Args:
send_tensor: Source tensor containing all data to be sent
recv_tensor: Destination tensor for received data
send_counts: Tensor of shape [world_size] with byte counts to send to each rank
send_displs: Tensor of shape [world_size] with byte offsets in send_tensor for each rank
recv_counts: Tensor of shape [world_size] with byte counts to receive from each rank
recv_displs: Tensor of shape [world_size] with byte offsets in recv_tensor for each rank
stream: Optional CUDA stream to use for the operation
Note:
- All count and displacement tensors should be on GPU and hold non-negative int64 values (the kernel reads them as size_t)
- send_counts[i] is the number of bytes to send to rank i
- send_displs[i] is the byte offset in send_tensor for data going to rank i
- recv_counts[i] is the number of bytes to receive from rank i
- recv_displs[i] is the byte offset in recv_tensor for data from rank i
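Example (illustrative): with world_size = 3 and send_counts = [8, 16, 8]
bytes, the exclusive prefix sum gives send_displs = [0, 8, 24], so the
16 bytes destined for rank 1 start at byte offset 8 of send_tensor.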
"""
# Ensure counts and displacements are on GPU and have correct dtype
assert send_counts.device.type == "cuda", "send_counts must be on GPU"
assert send_displs.device.type == "cuda", "send_displs must be on GPU"
assert recv_counts.device.type == "cuda", "recv_counts must be on GPU"
assert recv_displs.device.type == "cuda", "recv_displs must be on GPU"
# Prepare extras dict with device pointers for counts and displacements
extras = {
"sendCounts": send_counts.data_ptr(),
"sendDispls": send_displs.data_ptr(),
"recvCounts": recv_counts.data_ptr(),
"recvDispls": recv_displs.data_ptr(),
}
cuda_stream = stream.cuda_stream if stream is not None else 0
self.algorithm.execute(
self.comm.communicator,
send_tensor.data_ptr(),
recv_tensor.data_ptr(),
send_tensor.nbytes,
recv_tensor.nbytes,
mscclpp_utils.torch_dtype_to_mscclpp_dtype(send_tensor.dtype),
stream=cuda_stream,
extras=extras
)
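# Illustrative call with explicit byte counts (hypothetical buffer names;
# displacements are exclusive prefix sums of the byte counts):
#
#   send_counts = elem_counts.to(torch.int64) * send_buf.element_size()
#   send_displs = torch.zeros_like(send_counts)
#   send_displs[1:] = torch.cumsum(send_counts[:-1], dim=0)
#   # ... build recv_counts / recv_displs the same way ...
#   comm.alltoallv(send_buf, recv_buf, send_counts, send_displs,
#                  recv_counts, recv_displs, stream=torch.cuda.current_stream())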
def alltoallv_by_elements(
self,
send_tensor: torch.Tensor,
recv_tensor: torch.Tensor,
send_counts_elements: torch.Tensor,
recv_counts_elements: torch.Tensor,
stream: Optional[torch.cuda.Stream] = None
):
"""
Convenience function for alltoallv with element counts (not byte counts).
This function automatically computes byte counts and displacements from
element counts, making it easier to use with typical MoE patterns.
Args:
send_tensor: Source tensor containing all data to be sent
recv_tensor: Destination tensor for received data
send_counts_elements: Tensor of shape [world_size] with element counts to send to each rank
recv_counts_elements: Tensor of shape [world_size] with element counts to receive from each rank
stream: Optional CUDA stream to use for the operation
"""
element_size = send_tensor.element_size()
# Convert element counts to byte counts
send_counts = send_counts_elements.to(torch.int64) * element_size
recv_counts = recv_counts_elements.to(torch.int64) * element_size
# Compute displacements (exclusive prefix sum)
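# e.g., byte counts [8, 12, 4] -> displacements [0, 8, 20]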
send_displs = torch.zeros(self.world_size, dtype=torch.int64, device=send_tensor.device)
recv_displs = torch.zeros(self.world_size, dtype=torch.int64, device=recv_tensor.device)
if self.world_size > 1:
send_displs[1:] = torch.cumsum(send_counts[:-1], dim=0)
recv_displs[1:] = torch.cumsum(recv_counts[:-1], dim=0)
self.alltoallv(
send_tensor, recv_tensor,
send_counts, send_displs,
recv_counts, recv_displs,
stream
)
def barrier_cpu(self):
"""CPU barrier to synchronize all ranks."""
self.comm.barrier()
def batch_alltoallv(
comm: AllToAllVComm,
inputs: List[torch.Tensor],
outputs: List[torch.Tensor],
in_sizes: torch.Tensor,
out_sizes: torch.Tensor,
stream: Optional[torch.cuda.Stream] = None
):
"""
Batch all-to-all-v operation for multiple tensors.
This function replicates the pattern from MOE implementations:
```
for k in range(len(inputs)):
ncclGroupStart()
for i in range(world_size):
ncclSend(in_buff + in_offset, in_sizes[i], ...)
ncclRecv(out_buff + out_offset, out_sizes[i], ...)
ncclGroupEnd()
```
Since MSCCLPP doesn't support ncclGroupStart/ncclGroupEnd, we implement
this using explicit alltoallv operations for each tensor in the batch.
Args:
comm: AllToAllVComm instance
inputs: List of input tensors to send
outputs: List of output tensors to receive
in_sizes: Tensor of shape [world_size] with element counts to send to each rank
out_sizes: Tensor of shape [world_size] with element counts to receive from each rank
stream: Optional CUDA stream
"""
assert len(inputs) == len(outputs), "Input and output lists must have same length"
# Ensure sizes are on CPU for computing displacements
in_sizes_cpu = in_sizes.cpu().to(torch.int64)
out_sizes_cpu = out_sizes.cpu().to(torch.int64)
for k in range(len(inputs)):
input_tensor = inputs[k]
output_tensor = outputs[k]
element_size = input_tensor.element_size()
# Compute byte counts
send_counts = (in_sizes_cpu * element_size).cuda()
recv_counts = (out_sizes_cpu * element_size).cuda()
# Compute displacements
send_displs = torch.zeros(comm.world_size, dtype=torch.int64, device="cuda")
recv_displs = torch.zeros(comm.world_size, dtype=torch.int64, device="cuda")
if comm.world_size > 1:
send_displs_cpu = torch.zeros(comm.world_size, dtype=torch.int64)
recv_displs_cpu = torch.zeros(comm.world_size, dtype=torch.int64)
send_displs_cpu[1:] = torch.cumsum(send_counts.cpu()[:-1], dim=0)
recv_displs_cpu[1:] = torch.cumsum(recv_counts.cpu()[:-1], dim=0)
send_displs = send_displs_cpu.cuda()
recv_displs = recv_displs_cpu.cuda()
comm.alltoallv(
input_tensor, output_tensor,
send_counts, send_displs,
recv_counts, recv_displs,
stream
)
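# Illustrative use of batch_alltoallv (hypothetical sizes; in_sizes/out_sizes
# are per-peer element counts shared by every tensor in the batch):
#
#   feats = [torch.randn(total_in, device="cuda") for _ in range(2)]
#   outs = [torch.empty(total_out, device="cuda") for _ in range(2)]
#   batch_alltoallv(comm, feats, outs, in_sizes, out_sizes)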
def main():
"""Test the alltoallv implementation."""
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
local_rank = int(os.environ.get("LOCAL_RANK", os.environ["RANK"]))
torch.cuda.set_device(local_rank)
master_addr = os.environ["MSCCLPP_MASTER_ADDR"]
master_port = os.environ["MSCCLPP_MASTER_PORT"]
interface = interfaces_for_ip_netifaces(master_addr)
if interface is None:
raise ValueError(f"Cannot find network interface for IP address {master_addr}")
interfaceIpPortTrio = f"{interface}:{master_addr}:{master_port}"
mscclpp_group = mscclpp.CommGroup(interfaceIpPortTrio=interfaceIpPortTrio, rank=rank, size=world_size)
# Create communicator
comm = AllToAllVComm(mscclpp_group)
# Test with variable sizes per rank
# Each rank sends different amounts to different peers
# For simplicity: rank i sends (i+1)*100 elements to each peer
elements_per_peer = (rank + 1) * 100
# Create send buffer with data
total_send = elements_per_peer * world_size
send_tensor = torch.arange(total_send, device="cuda", dtype=torch.float32) + rank * 10000
# Receive buffer needs to accommodate variable amounts from each sender
# Each sender j sends (j+1)*100 elements to us
recv_counts_cpu = torch.tensor([(j + 1) * 100 for j in range(world_size)], dtype=torch.int64)
total_recv = recv_counts_cpu.sum().item()
recv_tensor = torch.zeros(total_recv, device="cuda", dtype=torch.float32)
# Send counts: we send elements_per_peer to everyone
send_counts_cpu = torch.tensor([elements_per_peer] * world_size, dtype=torch.int64)
# Move to GPU
send_counts = send_counts_cpu.cuda()
recv_counts = recv_counts_cpu.cuda()
comm.barrier_cpu()
# Perform alltoallv
comm.alltoallv_by_elements(
send_tensor, recv_tensor,
send_counts, recv_counts,
stream=torch.cuda.current_stream()
)
torch.cuda.synchronize()
comm.barrier_cpu()
print(f"Rank {rank}: alltoallv completed successfully!")
print(f" Sent {total_send} elements, received {total_recv} elements")
# Cleanup
comm = None
if __name__ == "__main__":
main()

alltoallv_kernel.cu

@@ -0,0 +1,400 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
// AllToAllV implementation for MSCCLPP
// This kernel handles variable element counts per rank for alltoallv operations.
// Unlike NCCL's ncclGroupStart/ncclGroupEnd approach, mscclpp uses explicit
// put/signal/wait operations on PortChannels.
#include <Python.h>
#include <pybind11/pybind11.h>
#include <mscclpp/algorithm.hpp>
#include <mscclpp/core.hpp>
#include <mscclpp/port_channel.hpp>
#include <mscclpp/port_channel_device.hpp>
#include <mscclpp/concurrency_device.hpp>
namespace py = pybind11;
#if defined(__HIP_PLATFORM_AMD__)
#define WARP_SIZE 64
#else
#define WARP_SIZE 32
#endif
// Device syncer for synchronization across blocks
__device__ mscclpp::DeviceSyncer alltoallvDeviceSyncer;
/**
* AllToAllV kernel implementation
*
* This kernel performs an all-to-all exchange with variable-length data per rank.
* Each rank sends sendCounts[i] elements to rank i at sendDispls[i] offset,
* and receives recvCounts[i] elements from rank i at recvDispls[i] offset.
*
* Since mscclpp doesn't support ncclGroupStart/ncclGroupEnd, we implement
* the exchange using explicit put/signal/wait operations on PortChannels.
* The communication pattern uses a ring-based approach to avoid deadlocks.
*
* @param portChannels Array of PortChannel handles for each peer (worldSize-1 channels)
* @param rank Current rank
* @param worldSize Total number of ranks
* @param sendBuff Source buffer containing data to send
* @param recvBuff Destination buffer for received data
* @param sendCounts Array of send counts for each rank (in bytes)
* @param sendDispls Array of send displacements for each rank (in bytes)
* @param recvCounts Array of receive counts for each rank (in bytes)
* @param recvDispls Array of receive displacements for each rank (in bytes)
*/
__global__ void __launch_bounds__(1024)
alltoallv_kernel(mscclpp::DeviceHandle<mscclpp::PortChannel>* portChannels,
int rank,
int worldSize,
const void* sendBuff,
void* recvBuff,
const size_t* sendCounts,
const size_t* sendDispls,
const size_t* recvCounts,
const size_t* recvDispls) {
// First, copy local data (rank's own portion) from send to recv buffer
// This doesn't require any communication
if (threadIdx.x == 0 && blockIdx.x == 0) {
if (sendCounts[rank] > 0) {
// Local copy: sendBuff[sendDispls[rank]] -> recvBuff[recvDispls[rank]]
const char* src = (const char*)sendBuff + sendDispls[rank];
char* dst = (char*)recvBuff + recvDispls[rank];
memcpy(dst, src, sendCounts[rank]);
}
}
__syncthreads(); // note: blocks other than 0 do not wait for the local copy; safe because this kernel is launched with a single block
// Ring-based exchange pattern to avoid deadlocks
// In each step i, rank sends to (rank + i) % worldSize and receives from (rank - i + worldSize) % worldSize
for (int step = 1; step < worldSize; step++) {
int sendPeer = (rank + step) % worldSize;
int recvPeer = (rank - step + worldSize) % worldSize;
// Get channel indices (portChannels excludes self, so adjust index)
int sendChanIdx = sendPeer < rank ? sendPeer : sendPeer - 1;
int recvChanIdx = recvPeer < rank ? recvPeer : recvPeer - 1;
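// e.g., rank 2 with worldSize 4 owns channels to peers [0, 1, 3] stored at
// indices [0, 1, 2], so sendPeer 3 maps to channel index 2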
// Only warp 0, lane 0 (thread 0) issues the channel operations for this step
int wid = threadIdx.x / WARP_SIZE;
int lid = threadIdx.x % WARP_SIZE;
// Send data to sendPeer if there's data to send
if (wid == 0 && lid == 0) {
if (sendCounts[sendPeer] > 0) {
// putWithSignal: copy data and signal completion.
// src offset: sendDispls[sendPeer] in our sendBuff.
// dst offset: the peer expects our data at *its* recvDispls[rank]; we pass our
// local recvDispls[rank], which matches only when all ranks compute identical
// displacement tables (true for the symmetric setups tested in this commit).
portChannels[sendChanIdx].putWithSignal(
recvDispls[rank], // dst offset in peer's recv buffer (assumes symmetric displs)
sendDispls[sendPeer], // src offset in our send buffer
sendCounts[sendPeer] // size in bytes
);
}
}
// Sync all threads before flushing
alltoallvDeviceSyncer.sync(gridDim.x);
// Flush to ensure data is sent
if (wid == 0 && lid == 0) {
if (sendCounts[sendPeer] > 0) {
portChannels[sendChanIdx].flush();
}
}
// Wait for data from recvPeer if we're expecting data
if (wid == 0 && lid == 0) {
if (recvCounts[recvPeer] > 0) {
portChannels[recvChanIdx].wait();
}
}
// Sync all threads before next step
alltoallvDeviceSyncer.sync(gridDim.x);
}
}
/**
* Simplified AllToAllV kernel for single-block execution
*
* This version is optimized for cases where all communication can be
* handled within a single thread block.
*/
__global__ void __launch_bounds__(1024)
alltoallv_simple_kernel(mscclpp::DeviceHandle<mscclpp::PortChannel>* portChannels,
int rank,
int worldSize,
const void* sendBuff,
void* recvBuff,
const size_t* sendCounts,
const size_t* sendDispls,
const size_t* recvCounts,
const size_t* recvDispls) {
int tid = threadIdx.x;
int nPeers = worldSize - 1;
// Step 1: Copy local data
if (tid == 0 && sendCounts[rank] > 0) {
const char* src = (const char*)sendBuff + sendDispls[rank];
char* dst = (char*)recvBuff + recvDispls[rank];
memcpy(dst, src, sendCounts[rank]);
}
__syncthreads();
// Step 2: Each warp handles one peer for sending
// We have worldSize-1 peers, assign one warp per peer
int warpId = tid / WARP_SIZE;
int laneId = tid % WARP_SIZE;
if (warpId < nPeers && laneId == 0) {
// Determine which peer this warp handles
int peer = warpId < rank ? warpId : warpId + 1;
int chanIdx = warpId;
if (sendCounts[peer] > 0) {
portChannels[chanIdx].putWithSignal(
recvDispls[rank], // dst offset in peer's buffer (assumes symmetric displs, as above)
sendDispls[peer], // src offset in our buffer
sendCounts[peer] // size in bytes
);
}
}
__syncthreads();
// Step 3: Flush all pending operations
if (warpId < nPeers && laneId == 0) {
int peer = warpId < rank ? warpId : warpId + 1;
if (sendCounts[peer] > 0) {
portChannels[warpId].flush();
}
}
__syncthreads();
// Step 4: Wait for all incoming data
if (warpId < nPeers && laneId == 0) {
int peer = warpId < rank ? warpId : warpId + 1;
if (recvCounts[peer] > 0) {
portChannels[warpId].wait();
}
}
__syncthreads();
}
// Context to hold all necessary state for alltoallv execution
struct AllToAllVContext {
int rank;
int worldSize;
int nRanksPerNode;
std::vector<mscclpp::RegisteredMemory> registeredMemories;
std::shared_ptr<mscclpp::DeviceHandle<mscclpp::PortChannel>> portChannelDeviceHandles;
// Device memory for counts and displacements (currently unused; the kernel receives these pointers through the extras map)
size_t* d_sendCounts;
size_t* d_sendDispls;
size_t* d_recvCounts;
size_t* d_recvDispls;
};
class AllToAllVAlgoBuilder : public mscclpp::AlgorithmBuilder {
public:
AllToAllVAlgoBuilder() = default;
~AllToAllVAlgoBuilder() {
if (proxyService_) {
proxyService_->stopProxy();
}
}
std::shared_ptr<mscclpp::Algorithm> build() override {
auto self = std::make_shared<AllToAllVAlgoBuilder>();
std::shared_ptr<mscclpp::Algorithm> alltoallvAlgo = std::make_shared<mscclpp::NativeAlgorithm>(
"alltoallv", "alltoallv",
// Initialize function
[self](std::shared_ptr<mscclpp::Communicator> comm) { self->initialize(comm); },
// Kernel execution function
[self](const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize, size_t outputSize,
mscclpp::DataType dtype, [[maybe_unused]] mscclpp::ReduceOp op, cudaStream_t stream,
[[maybe_unused]] int nBlocks, [[maybe_unused]] int nThreadsPerBlock,
const std::unordered_map<std::string, uintptr_t>& extras) {
return self->alltoallvKernelFunc(ctx, input, output, inputSize, outputSize, dtype, stream, extras);
},
// Context initialization function
[self](std::shared_ptr<mscclpp::Communicator> comm, const void* input, void* output, size_t inputSize,
size_t outputSize,
mscclpp::DataType dtype) { return self->initAlltoallvContext(comm, input, output, inputSize, outputSize, dtype); },
// Context key generation function
[self](const void* input, void* output, size_t inputSize, size_t outputSize, mscclpp::DataType dtype) {
return self->generateAlltoallvContextKey(input, output, inputSize, outputSize, dtype);
});
return alltoallvAlgo;
}
private:
std::vector<mscclpp::Connection> conns_;
std::shared_ptr<mscclpp::ProxyService> proxyService_;
int worldSize_;
void initialize(std::shared_ptr<mscclpp::Communicator> comm) {
std::vector<std::shared_future<mscclpp::Connection>> connectionFutures;
worldSize_ = comm->bootstrap()->getNranks();
for (int i = 0; i < worldSize_; i++) {
if (i == comm->bootstrap()->getRank()) continue;
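// CudaIpc connections restrict this algorithm to a single node,
// matching the torchrun --nnodes=1 usage in the Python example.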
connectionFutures.push_back(comm->connect(mscclpp::Transport::CudaIpc, i));
}
std::vector<mscclpp::Connection> connections;
std::transform(connectionFutures.begin(), connectionFutures.end(), std::back_inserter(connections),
[](const auto& future) { return future.get(); });
this->conns_ = std::move(connections);
proxyService_ = std::make_shared<mscclpp::ProxyService>();
proxyService_->startProxy(true);
}
mscclpp::CommResult alltoallvKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output,
size_t inputSize, size_t outputSize,
[[maybe_unused]] mscclpp::DataType dtype,
cudaStream_t stream,
const std::unordered_map<std::string, uintptr_t>& extras) {
auto algoCtx = std::static_pointer_cast<AllToAllVContext>(ctx);
int rank = algoCtx->rank;
int worldSize = algoCtx->worldSize;
// Extract send/recv counts and displacements from extras
// The caller should pass these as device pointers via extras map
auto it_sendCounts = extras.find("sendCounts");
auto it_sendDispls = extras.find("sendDispls");
auto it_recvCounts = extras.find("recvCounts");
auto it_recvDispls = extras.find("recvDispls");
if (it_sendCounts == extras.end() || it_sendDispls == extras.end() ||
it_recvCounts == extras.end() || it_recvDispls == extras.end()) {
return mscclpp::CommResult::CommInternalError;
}
const size_t* d_sendCounts = reinterpret_cast<const size_t*>(it_sendCounts->second);
const size_t* d_sendDispls = reinterpret_cast<const size_t*>(it_sendDispls->second);
const size_t* d_recvCounts = reinterpret_cast<const size_t*>(it_recvCounts->second);
const size_t* d_recvDispls = reinterpret_cast<const size_t*>(it_recvDispls->second);
// Reset device syncer
mscclpp::DeviceSyncer syncer = {};
cudaMemcpyToSymbolAsync(alltoallvDeviceSyncer, &syncer, sizeof(mscclpp::DeviceSyncer), 0,
cudaMemcpyHostToDevice, stream);
// Use the per-warp simple kernel for small world sizes and the serialized ring kernel otherwise; both are launched with a single block
if (worldSize <= 16) {
int nThreads = (worldSize - 1) * WARP_SIZE;
if (nThreads < 32) nThreads = 32;
if (nThreads > 1024) nThreads = 1024;
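// e.g., worldSize 8 -> (8 - 1) * WARP_SIZE = 224 threads on NVIDIA
// (WARP_SIZE 32) or 448 on AMD (WARP_SIZE 64), one warp per peer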
alltoallv_simple_kernel<<<1, nThreads, 0, stream>>>(
algoCtx->portChannelDeviceHandles.get(),
rank, worldSize,
input, output,
d_sendCounts, d_sendDispls,
d_recvCounts, d_recvDispls);
} else {
alltoallv_kernel<<<1, 1024, 0, stream>>>(
algoCtx->portChannelDeviceHandles.get(),
rank, worldSize,
input, output,
d_sendCounts, d_sendDispls,
d_recvCounts, d_recvDispls);
}
if (cudaGetLastError() == cudaSuccess) {
return mscclpp::CommResult::CommSuccess;
}
return mscclpp::CommResult::CommInternalError;
}
std::shared_ptr<void> initAlltoallvContext(std::shared_ptr<mscclpp::Communicator> comm, const void* input,
void* output, size_t inputSize, size_t outputSize,
mscclpp::DataType dtype) {
auto ctx = std::make_shared<AllToAllVContext>();
ctx->rank = comm->bootstrap()->getRank();
ctx->worldSize = comm->bootstrap()->getNranks();
ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode();
// Register memories for input and output buffers
mscclpp::RegisteredMemory inputBufRegMem =
comm->registerMemory((void*)input, inputSize, mscclpp::Transport::CudaIpc);
mscclpp::RegisteredMemory outputBufRegMem =
comm->registerMemory(output, outputSize, mscclpp::Transport::CudaIpc);
// Exchange output buffer registration with all peers
std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteRegMemories;
for (int i = 0; i < ctx->worldSize; i++) {
if (i == ctx->rank) continue;
comm->sendMemory(outputBufRegMem, i, 0);
remoteRegMemories.push_back(comm->recvMemory(i, 0));
}
// Setup port channels for each peer
std::vector<mscclpp::DeviceHandle<mscclpp::PortChannel>> portChannels;
mscclpp::MemoryId inputMemoryId = this->proxyService_->addMemory(inputBufRegMem);
for (size_t i = 0; i < this->conns_.size(); i++) {
auto remoteMemory = remoteRegMemories[i].get();
mscclpp::MemoryId remoteMemoryId = this->proxyService_->addMemory(remoteMemory);
portChannels.push_back(mscclpp::deviceHandle(this->proxyService_->portChannel(
this->proxyService_->buildAndAddSemaphore(*comm, this->conns_[i]), remoteMemoryId, inputMemoryId)));
}
// Allocate and copy port channels to device
ctx->portChannelDeviceHandles =
mscclpp::detail::gpuCallocShared<mscclpp::DeviceHandle<mscclpp::PortChannel>>(portChannels.size());
mscclpp::gpuMemcpy(ctx->portChannelDeviceHandles.get(), portChannels.data(), portChannels.size(),
cudaMemcpyHostToDevice);
// Keep registered memory references to prevent deallocation
std::transform(remoteRegMemories.begin(), remoteRegMemories.end(), std::back_inserter(ctx->registeredMemories),
[](const auto& fut) { return fut.get(); });
ctx->registeredMemories.push_back(inputBufRegMem);
ctx->registeredMemories.push_back(outputBufRegMem);
return ctx;
}
mscclpp::AlgorithmCtxKey generateAlltoallvContextKey(const void* input, void* output, size_t inputSize,
size_t outputSize, mscclpp::DataType dtype) {
return {(void*)input, output, inputSize, outputSize, 0};
}
};
std::shared_ptr<mscclpp::Algorithm> createAlltoallvAlgorithm() {
auto alltoallvAlgoBuilder = std::make_shared<AllToAllVAlgoBuilder>();
return alltoallvAlgoBuilder->build();
}
void deletePtr(PyObject* capsule) {
const char* name = PyCapsule_GetName(capsule);
void* p = PyCapsule_GetPointer(capsule, name);
if (p == nullptr) {
PyErr_WriteUnraisable(capsule);
return;
}
auto* ptr = static_cast<std::shared_ptr<mscclpp::Algorithm>*>(p);
delete ptr;
}
PyObject* getCapsule(std::shared_ptr<mscclpp::Algorithm> algo) {
auto* ptrCopy = new std::shared_ptr<mscclpp::Algorithm>(algo);
PyObject* capsule = PyCapsule_New(ptrCopy, mscclpp::ALGORITHM_NATIVE_CAPSULE_NAME, deletePtr);
if (capsule == nullptr) {
delete ptrCopy;
throw pybind11::error_already_set();
}
return capsule;
}
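// Lifetime note: the heap-allocated shared_ptr copy above keeps the Algorithm
// alive for as long as the capsule exists; deletePtr releases it when Python
// garbage-collects the capsule.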
PYBIND11_MODULE(mscclpp_alltoallv, m) {
m.doc() = "AllToAllV implementation for MSCCLPP - handles variable element counts per rank";
m.def(
"create_alltoallv_algorithm",
[]() { return py::reinterpret_steal<py::capsule>(getCapsule(createAlltoallvAlgorithm())); },
"Create an alltoallv algorithm and return it as a PyCapsule usable by MSCCL++ Python bindings");
}

CMakeLists.txt

@@ -18,3 +18,4 @@ add_mscclpp_test_executable(sendrecv_test_perf sendrecv_test.cu)
add_mscclpp_test_executable(allgather_test_perf allgather_test.cu)
add_mscclpp_test_executable(allreduce_test_perf allreduce_test.cu)
add_mscclpp_test_executable(alltoall_test_perf alltoall_test.cu)
add_mscclpp_test_executable(alltoallv_test_perf alltoallv_test.cu)

alltoallv_test.cu

@@ -0,0 +1,347 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
// AllToAllV test - tests variable-length alltoall operations
// This test validates the alltoallv kernel that handles variable element counts per rank.
#include <cstdint>
#include <cstring>
#include <numeric>
#include <mscclpp/concurrency_device.hpp>
#include "common.hpp"
#if defined(__HIP_PLATFORM_AMD__)
#define WARP_SIZE 64
#else
#define WARP_SIZE 32
#endif
template <class T>
using DeviceHandle = mscclpp::DeviceHandle<T>;
__constant__ DeviceHandle<mscclpp::PortChannel> constPortChansV[16];
__device__ mscclpp::DeviceSyncer deviceSyncerV;
static void* localRecvBuffV;
static void* localSendBuffV;
// Device arrays for variable counts and displacements
static size_t* d_sendCounts;
static size_t* d_sendDispls;
static size_t* d_recvCounts;
static size_t* d_recvDispls;
/**
* AllToAllV kernel implementation
*
* Each rank sends sendCounts[i] bytes to rank i at sendDispls[i] offset,
* and receives recvCounts[i] bytes from rank i at recvDispls[i] offset.
*
* Uses ring-based exchange pattern to avoid deadlocks.
*/
__global__ void __launch_bounds__(1024)
alltoallv0(int rank, int worldSize,
const void* sendBuff, void* recvBuff,
const size_t* sendCounts, const size_t* sendDispls,
const size_t* recvCounts, const size_t* recvDispls) {
int tid = threadIdx.x;
int nPeers = worldSize - 1;
// Step 1: Copy local data (rank's own portion)
if (tid == 0 && sendCounts[rank] > 0) {
const char* src = (const char*)sendBuff + sendDispls[rank];
char* dst = (char*)recvBuff + recvDispls[rank];
memcpy(dst, src, sendCounts[rank]);
}
__syncthreads();
// Step 2: Each warp handles one peer for sending
int warpId = tid / WARP_SIZE;
int laneId = tid % WARP_SIZE;
if (warpId < nPeers && laneId == 0) {
// Determine which peer this warp handles
int peer = warpId < rank ? warpId : warpId + 1;
int chanIdx = warpId;
if (sendCounts[peer] > 0) {
constPortChansV[chanIdx].putWithSignal(
recvDispls[rank], // dst offset in peer's buffer (assumes all ranks share identical recvDispls)
sendDispls[peer], // src offset in our buffer
sendCounts[peer] // size
);
}
}
__syncthreads();
// Step 3: Flush all pending operations
if (warpId < nPeers && laneId == 0) {
int peer = warpId < rank ? warpId : warpId + 1;
if (sendCounts[peer] > 0) {
constPortChansV[warpId].flush();
}
}
__syncthreads();
// Step 4: Wait for all incoming data
if (warpId < nPeers && laneId == 0) {
int peer = warpId < rank ? warpId : warpId + 1;
if (recvCounts[peer] > 0) {
constPortChansV[warpId].wait();
}
}
__syncthreads();
}
/**
* Ring-based AllToAllV kernel for larger world sizes
*
* Uses step-by-step ring pattern to exchange data, sending to (rank+step) and
* receiving from (rank-step) in each step. Single block to avoid concurrent
* access to the same port channels.
*/
__global__ void __launch_bounds__(1024)
alltoallv1(int rank, int worldSize,
const void* sendBuff, void* recvBuff,
const size_t* sendCounts, const size_t* sendDispls,
const size_t* recvCounts, const size_t* recvDispls) {
// Copy local data first
if (threadIdx.x == 0) {
if (sendCounts[rank] > 0) {
const char* src = (const char*)sendBuff + sendDispls[rank];
char* dst = (char*)recvBuff + recvDispls[rank];
memcpy(dst, src, sendCounts[rank]);
}
}
__syncthreads();
// Ring-based exchange - single thread handles the communication
// to avoid race conditions on port channels
if (threadIdx.x == 0) {
for (int step = 1; step < worldSize; step++) {
int sendPeer = (rank + step) % worldSize;
int recvPeer = (rank - step + worldSize) % worldSize;
int sendChanIdx = sendPeer < rank ? sendPeer : sendPeer - 1;
int recvChanIdx = recvPeer < rank ? recvPeer : recvPeer - 1;
// Send data to sendPeer (non-blocking put with signal)
if (sendCounts[sendPeer] > 0) {
constPortChansV[sendChanIdx].putWithSignal(
recvDispls[rank], // dst offset in peer's buffer
sendDispls[sendPeer], // src offset in our buffer
sendCounts[sendPeer] // size
);
constPortChansV[sendChanIdx].flush();
}
// Wait for data from recvPeer
if (recvCounts[recvPeer] > 0) {
constPortChansV[recvChanIdx].wait();
}
}
}
}
class AllToAllVTestColl : public BaseTestColl {
public:
AllToAllVTestColl() = default;
~AllToAllVTestColl() override = default;
void runColl(const TestArgs& args, cudaStream_t stream) override;
void initData(const TestArgs& args, std::vector<void*> sendBuff, void* expectedBuff) override;
void getBw(const double deltaSec, double& algBw /*OUT*/, double& busBw /*OUT*/) override;
void setupCollTest(size_t size) override;
std::vector<KernelRestriction> getKernelRestrictions() override;
private:
// Host-side counts and displacements
std::vector<size_t> sendCounts_;
std::vector<size_t> sendDispls_;
std::vector<size_t> recvCounts_;
std::vector<size_t> recvDispls_;
size_t totalSendBytes_;
size_t totalRecvBytes_;
};
void AllToAllVTestColl::runColl(const TestArgs& args, cudaStream_t stream) {
const int worldSize = args.totalRanks;
const int rank = args.rank;
const int kernelNum = args.kernelNum;
// Reset device syncer
mscclpp::DeviceSyncer syncer = {};
CUDATHROW(cudaMemcpyToSymbol(deviceSyncerV, &syncer, sizeof(mscclpp::DeviceSyncer)));
if (kernelNum == 0) {
int nThreads = (worldSize - 1) * WARP_SIZE;
if (nThreads < 32) nThreads = 32;
if (nThreads > 1024) nThreads = 1024;
alltoallv0<<<1, nThreads, 0, stream>>>(
rank, worldSize,
localSendBuffV, localRecvBuffV,
d_sendCounts, d_sendDispls,
d_recvCounts, d_recvDispls);
} else if (kernelNum == 1) {
// Single block with one warp; only thread 0 drives the serialized ring exchange
alltoallv1<<<1, 32, 0, stream>>>(
rank, worldSize,
localSendBuffV, localRecvBuffV,
d_sendCounts, d_sendDispls,
d_recvCounts, d_recvDispls);
}
}
void AllToAllVTestColl::initData(const TestArgs& args, std::vector<void*> sendBuff, void* expectedBuff) {
if (sendBuff.size() != 1) throw std::runtime_error("unexpected error");
const int rank = args.rank;
const int worldSize = args.totalRanks;
// Create send data: each segment has values identifying source and destination
std::vector<int> sendData(totalSendBytes_ / sizeof(int), 0);
for (int peer = 0; peer < worldSize; peer++) {
size_t offset = sendDispls_[peer] / sizeof(int);
size_t count = sendCounts_[peer] / sizeof(int);
for (size_t i = 0; i < count; i++) {
// Encode: rank * 10000 + peer * 100 + position
sendData[offset + i] = rank * 10000 + peer * 100 + i;
}
}
CUDATHROW(cudaMemcpy(sendBuff[0], sendData.data(), totalSendBytes_, cudaMemcpyHostToDevice));
// Create expected data: we receive from each peer
std::vector<int> expectedData(totalRecvBytes_ / sizeof(int), 0);
for (int peer = 0; peer < worldSize; peer++) {
size_t offset = recvDispls_[peer] / sizeof(int);
size_t count = recvCounts_[peer] / sizeof(int);
for (size_t i = 0; i < count; i++) {
// We receive data sent by peer to us
expectedData[offset + i] = peer * 10000 + rank * 100 + i;
}
}
std::memcpy(expectedBuff, expectedData.data(), totalRecvBytes_);
// Copy counts and displacements to device
CUDATHROW(cudaMemcpy(d_sendCounts, sendCounts_.data(), worldSize * sizeof(size_t), cudaMemcpyHostToDevice));
CUDATHROW(cudaMemcpy(d_sendDispls, sendDispls_.data(), worldSize * sizeof(size_t), cudaMemcpyHostToDevice));
CUDATHROW(cudaMemcpy(d_recvCounts, recvCounts_.data(), worldSize * sizeof(size_t), cudaMemcpyHostToDevice));
CUDATHROW(cudaMemcpy(d_recvDispls, recvDispls_.data(), worldSize * sizeof(size_t), cudaMemcpyHostToDevice));
}
void AllToAllVTestColl::getBw(const double deltaSec, double& algBw, double& busBw) {
double baseBw = (double)(totalRecvBytes_) / 1.0E9 / deltaSec;
algBw = baseBw;
double factor = ((double)(worldSize_ - 1)) / ((double)worldSize_);
busBw = baseBw * factor;
}
void AllToAllVTestColl::setupCollTest(size_t size) {
// For alltoallv, counts and displacements are specified per peer.
// baseBytes below is the per-peer payload derived from the requested size.
size_t baseBytes = size / (worldSize_ * worldSize_); // Base unit for variable sizing
if (baseBytes < sizeof(int)) baseBytes = sizeof(int);
baseBytes = (baseBytes / sizeof(int)) * sizeof(int); // Align to int size
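// e.g., size = 1 MiB with worldSize 8 -> baseBytes = 1048576 / 64 = 16384 bytes per peer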
sendCounts_.resize(worldSize_);
sendDispls_.resize(worldSize_);
recvCounts_.resize(worldSize_);
recvDispls_.resize(worldSize_);
// Each rank sends the same amount to each peer (for simplicity in this test);
// in a real MoE scenario these counts would vary per peer
totalSendBytes_ = 0;
totalRecvBytes_ = 0;
for (int peer = 0; peer < worldSize_; peer++) {
sendCounts_[peer] = baseBytes;
sendDispls_[peer] = totalSendBytes_;
totalSendBytes_ += sendCounts_[peer];
recvCounts_[peer] = baseBytes;
recvDispls_[peer] = totalRecvBytes_;
totalRecvBytes_ += recvCounts_[peer];
}
sendCount_ = totalSendBytes_ / typeSize_;
recvCount_ = totalRecvBytes_ / typeSize_;
paramCount_ = sendCount_;
expectedCount_ = recvCount_;
mscclpp::DeviceSyncer syncer = {};
CUDATHROW(cudaMemcpyToSymbol(deviceSyncerV, &syncer, sizeof(mscclpp::DeviceSyncer)));
}
std::vector<KernelRestriction> AllToAllVTestColl::getKernelRestrictions() {
return {
{0, "alltoallv0", true, 1, 4 * worldSize_},
{1, "alltoallv1", true, 1, 4 * worldSize_}
};
}
class AllToAllVTestEngine : public BaseTestEngine {
public:
AllToAllVTestEngine(const TestArgs& args);
~AllToAllVTestEngine() override = default;
void allocateBuffer() override;
void setupConnections() override;
std::vector<void*> getSendBuff() override;
void* getRecvBuff() override;
void* getScratchBuff() override;
private:
void* getExpectedBuff() override;
bool isInPlace() const;
std::shared_ptr<int> sendBuff_;
std::shared_ptr<int> recvBuff_;
std::shared_ptr<int[]> expectedBuff_;
};
bool AllToAllVTestEngine::isInPlace() const { return false; }
AllToAllVTestEngine::AllToAllVTestEngine(const TestArgs& args) : BaseTestEngine(args, "alltoallv") { inPlace_ = false; }
void AllToAllVTestEngine::allocateBuffer() {
sendBuff_ = mscclpp::GpuBuffer<int>(args_.maxBytes / sizeof(int)).memory();
recvBuff_ = mscclpp::GpuBuffer<int>(args_.maxBytes / sizeof(int)).memory();
expectedBuff_ = std::shared_ptr<int[]>(new int[args_.maxBytes / sizeof(int)]);
localSendBuffV = sendBuff_.get();
localRecvBuffV = recvBuff_.get();
// Allocate device arrays for counts and displacements
CUDATHROW(cudaMalloc(&d_sendCounts, args_.totalRanks * sizeof(size_t)));
CUDATHROW(cudaMalloc(&d_sendDispls, args_.totalRanks * sizeof(size_t)));
CUDATHROW(cudaMalloc(&d_recvCounts, args_.totalRanks * sizeof(size_t)));
CUDATHROW(cudaMalloc(&d_recvDispls, args_.totalRanks * sizeof(size_t)));
}
void AllToAllVTestEngine::setupConnections() {
std::vector<DeviceHandle<mscclpp::PortChannel>> portChannels;
setupMeshConnections(portChannels, sendBuff_.get(), args_.maxBytes, recvBuff_.get(), args_.maxBytes);
if (portChannels.size() > sizeof(constPortChansV) / sizeof(DeviceHandle<mscclpp::PortChannel>)) {
throw std::runtime_error("Too many port channels for alltoallv test");
}
CUDATHROW(cudaMemcpyToSymbol(constPortChansV, portChannels.data(),
sizeof(DeviceHandle<mscclpp::PortChannel>) * portChannels.size()));
}
std::vector<void*> AllToAllVTestEngine::getSendBuff() { return {sendBuff_.get()}; }
void* AllToAllVTestEngine::getExpectedBuff() { return expectedBuff_.get(); }
void* AllToAllVTestEngine::getRecvBuff() {
if (this->isInPlace())
return sendBuff_.get();
else
return recvBuff_.get();
}
void* AllToAllVTestEngine::getScratchBuff() { return nullptr; }
std::shared_ptr<BaseTestEngine> getTestEngine(const TestArgs& args) {
return std::make_shared<AllToAllVTestEngine>(args);
}
std::shared_ptr<BaseTestColl> getTestColl() { return std::make_shared<AllToAllVTestColl>(); }