Add alltoallv kernel and test
examples/torch-integration/alltoallv.py (new file, 315 lines)
@@ -0,0 +1,315 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

# AllToAllV implementation for MSCCLPP
# This module provides a PyTorch-compatible alltoallv operation using MSCCLPP.
#
# Usage:
#   MSCCLPP_MASTER_ADDR=<master_ip> MSCCLPP_MASTER_PORT=<port> \
#     torchrun --nnodes=1 --nproc_per_node=8 alltoallv.py
#
# For AMD GPUs:
#   MSCCLPP_MASTER_ADDR=<master_ip> MSCCLPP_MASTER_PORT=<port> \
#     GPU_MAX_HW_QUEUES=7 torchrun --nnodes=1 --nproc_per_node=8 alltoallv.py

import ipaddress
import os
from typing import List, Optional

import netifaces as ni
import torch

import mscclpp
import mscclpp.utils as mscclpp_utils

_abs_path = os.path.dirname(os.path.abspath(__file__))

def interfaces_for_ip_netifaces(ip: str) -> Optional[str]:
    """Find the name of the network interface that owns the given IP address."""
    target = ipaddress.ip_address(ip)
    for interface in ni.interfaces():
        addresses = ni.ifaddresses(interface)
        if ni.AF_INET in addresses:
            for link in addresses[ni.AF_INET]:
                if "addr" in link:
                    addr = ipaddress.ip_address(link["addr"])
                    if addr == target:
                        return interface
    return None

class AllToAllVComm:
    """
    AllToAllV communication class using MSCCLPP.

    This class provides a customized alltoallv implementation that handles
    variable element counts per rank, similar to MPI_Alltoallv or the
    batch_all_to_all_v pattern commonly used in MoE (Mixture of Experts) models.

    Unlike NCCL's ncclGroupStart/ncclGroupEnd approach, MSCCLPP uses explicit
    put/signal/wait operations on PortChannels for communication.

    Attributes:
        comm: MSCCLPP CommGroup instance
        rank: Current rank
        world_size: Total number of ranks
    """

    def __init__(self, comm: mscclpp.CommGroup):
        """
        Initialize AllToAllV communication.

        Args:
            comm: MSCCLPP CommGroup instance
        """
        self.comm = comm
        self.rank = comm.my_rank
        self.world_size = comm.nranks
        self.local_rank = comm.my_rank % comm.nranks_per_node
        self.n_ranks_per_node = comm.nranks_per_node
        self.executor = mscclpp.Executor(comm.communicator)

        # Compile and load the native CUDA kernel
        mscclpp_native = mscclpp.compile_native(
            name="mscclpp_alltoallv",
            file=os.path.join(_abs_path, "alltoallv_kernel.cu"),
        )
        capsule = mscclpp_native.create_alltoallv_algorithm()
        self.algorithm = mscclpp.Algorithm.create_from_native_capsule(capsule)

    def alltoallv(
        self,
        send_tensor: torch.Tensor,
        recv_tensor: torch.Tensor,
        send_counts: torch.Tensor,
        send_displs: torch.Tensor,
        recv_counts: torch.Tensor,
        recv_displs: torch.Tensor,
        stream: Optional[torch.cuda.Stream] = None,
    ):
        """
        Perform an alltoallv operation with variable element counts.

        This function exchanges data between all ranks, where each rank can
        send a different amount of data to every other rank.

        Args:
            send_tensor: Source tensor containing all data to be sent
            recv_tensor: Destination tensor for received data
            send_counts: Tensor of shape [world_size] with byte counts to send to each rank
            send_displs: Tensor of shape [world_size] with byte offsets in send_tensor for each rank
            recv_counts: Tensor of shape [world_size] with byte counts to receive from each rank
            recv_displs: Tensor of shape [world_size] with byte offsets in recv_tensor for each rank
            stream: Optional CUDA stream to use for the operation

        Note:
            - All count and displacement tensors must be on the GPU and contain size_t values
            - send_counts[i] is the number of bytes to send to rank i
            - send_displs[i] is the byte offset in send_tensor for data going to rank i
            - recv_counts[i] is the number of bytes to receive from rank i
            - recv_displs[i] is the byte offset in recv_tensor for data from rank i
        """

        # Ensure counts and displacements are on the GPU
        assert send_counts.device.type == "cuda", "send_counts must be on GPU"
        assert send_displs.device.type == "cuda", "send_displs must be on GPU"
        assert recv_counts.device.type == "cuda", "recv_counts must be on GPU"
        assert recv_displs.device.type == "cuda", "recv_displs must be on GPU"

        # Prepare the extras dict with device pointers for counts and displacements
        extras = {
            "sendCounts": send_counts.data_ptr(),
            "sendDispls": send_displs.data_ptr(),
            "recvCounts": recv_counts.data_ptr(),
            "recvDispls": recv_displs.data_ptr(),
        }

        cuda_stream = stream.cuda_stream if stream is not None else 0

        self.algorithm.execute(
            self.comm.communicator,
            send_tensor.data_ptr(),
            recv_tensor.data_ptr(),
            send_tensor.nbytes,
            recv_tensor.nbytes,
            mscclpp_utils.torch_dtype_to_mscclpp_dtype(send_tensor.dtype),
            stream=cuda_stream,
            extras=extras,
        )

    def alltoallv_by_elements(
        self,
        send_tensor: torch.Tensor,
        recv_tensor: torch.Tensor,
        send_counts_elements: torch.Tensor,
        recv_counts_elements: torch.Tensor,
        stream: Optional[torch.cuda.Stream] = None,
    ):
        """
        Convenience wrapper for alltoallv that takes element counts (not byte counts).

        This function computes byte counts and displacements from the element
        counts automatically, making it easier to use with typical MoE patterns.

        Args:
            send_tensor: Source tensor containing all data to be sent
            recv_tensor: Destination tensor for received data
            send_counts_elements: Tensor of shape [world_size] with element counts to send to each rank
            recv_counts_elements: Tensor of shape [world_size] with element counts to receive from each rank
            stream: Optional CUDA stream to use for the operation
        """
        element_size = send_tensor.element_size()

        # Convert element counts to byte counts
        send_counts = send_counts_elements.to(torch.int64) * element_size
        recv_counts = recv_counts_elements.to(torch.int64) * element_size

        # Compute displacements (exclusive prefix sum)
        send_displs = torch.zeros(self.world_size, dtype=torch.int64, device=send_tensor.device)
        recv_displs = torch.zeros(self.world_size, dtype=torch.int64, device=recv_tensor.device)

        if self.world_size > 1:
            send_displs[1:] = torch.cumsum(send_counts[:-1], dim=0)
            recv_displs[1:] = torch.cumsum(recv_counts[:-1], dim=0)

        self.alltoallv(
            send_tensor, recv_tensor,
            send_counts, send_displs,
            recv_counts, recv_displs,
            stream,
        )
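
    # Minimal usage sketch (hypothetical sizes; assumes a 2-rank job where
    # every rank sends 3 float32 elements to each rank, itself included):
    #
    #   counts = torch.tensor([3, 3], device="cuda")
    #   send = torch.arange(6, dtype=torch.float32, device="cuda")
    #   recv = torch.empty(6, dtype=torch.float32, device="cuda")
    #   comm.alltoallv_by_elements(send, recv, counts, counts)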

    def barrier_cpu(self):
        """CPU barrier to synchronize all ranks."""
        self.comm.barrier()


def batch_alltoallv(
    comm: AllToAllVComm,
    inputs: List[torch.Tensor],
    outputs: List[torch.Tensor],
    in_sizes: torch.Tensor,
    out_sizes: torch.Tensor,
    stream: Optional[torch.cuda.Stream] = None,
):
    """
    Batch all-to-all-v operation for multiple tensors.

    This function replicates the pattern from MoE implementations:
    ```
    for k in range(len(inputs)):
        ncclGroupStart()
        for i in range(world_size):
            ncclSend(in_buff + in_offset, in_sizes[i], ...)
            ncclRecv(out_buff + out_offset, out_sizes[i], ...)
        ncclGroupEnd()
    ```

    Since MSCCLPP doesn't support ncclGroupStart/ncclGroupEnd, we implement
    this using an explicit alltoallv operation for each tensor in the batch.

    Args:
        comm: AllToAllVComm instance
        inputs: List of input tensors to send
        outputs: List of output tensors to receive
        in_sizes: Tensor of shape [world_size] with element counts to send to each rank
        out_sizes: Tensor of shape [world_size] with element counts to receive from each rank
        stream: Optional CUDA stream
    """
    assert len(inputs) == len(outputs), "Input and output lists must have the same length"

    # Move the sizes to the CPU once for computing displacements
    in_sizes_cpu = in_sizes.cpu().to(torch.int64)
    out_sizes_cpu = out_sizes.cpu().to(torch.int64)

    for k in range(len(inputs)):
        input_tensor = inputs[k]
        output_tensor = outputs[k]

        element_size = input_tensor.element_size()

        # Compute byte counts
        send_counts = (in_sizes_cpu * element_size).cuda()
        recv_counts = (out_sizes_cpu * element_size).cuda()

        # Compute displacements (exclusive prefix sum)
        send_displs = torch.zeros(comm.world_size, dtype=torch.int64, device="cuda")
        recv_displs = torch.zeros(comm.world_size, dtype=torch.int64, device="cuda")

        if comm.world_size > 1:
            send_displs_cpu = torch.zeros(comm.world_size, dtype=torch.int64)
            recv_displs_cpu = torch.zeros(comm.world_size, dtype=torch.int64)
            send_displs_cpu[1:] = torch.cumsum(send_counts.cpu()[:-1], dim=0)
            recv_displs_cpu[1:] = torch.cumsum(recv_counts.cpu()[:-1], dim=0)
            send_displs = send_displs_cpu.cuda()
            recv_displs = recv_displs_cpu.cuda()

        comm.alltoallv(
            input_tensor, output_tensor,
            send_counts, send_displs,
            recv_counts, recv_displs,
            stream,
        )
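

# Minimal sketch of driving batch_alltoallv (hypothetical shapes; assumes an
# already-initialized AllToAllVComm named `comm` and uniform per-peer sizes,
# so in_sizes and out_sizes match across ranks):
#
#   sizes = torch.tensor([128] * comm.world_size)
#   total = int(sizes.sum())
#   inputs = [torch.randn(total, device="cuda") for _ in range(2)]
#   outputs = [torch.empty(total, device="cuda") for _ in range(2)]
#   batch_alltoallv(comm, inputs, outputs, sizes, sizes)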


def main():
    """Test the alltoallv implementation."""
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ.get("LOCAL_RANK", os.environ["RANK"]))
    torch.cuda.set_device(local_rank)

    master_addr = os.environ["MSCCLPP_MASTER_ADDR"]
    master_port = os.environ["MSCCLPP_MASTER_PORT"]

    interface = interfaces_for_ip_netifaces(master_addr)
    if interface is None:
        raise ValueError(f"Cannot find network interface for IP address {master_addr}")

    interfaceIpPortTrio = f"{interface}:{master_addr}:{master_port}"
    mscclpp_group = mscclpp.CommGroup(interfaceIpPortTrio=interfaceIpPortTrio, rank=rank, size=world_size)

    # Create the communicator
    comm = AllToAllVComm(mscclpp_group)

    # Test with variable sizes per rank: each rank sends a different amount to
    # its peers. For simplicity, rank i sends (i + 1) * 100 elements to each peer.
    elements_per_peer = (rank + 1) * 100
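    # e.g. with world_size = 4, rank 2 sends 300 elements to every peer
    # (total_send = 1200) and receives 100 + 200 + 300 + 400 = 1000 elements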

    # Create the send buffer with data
    total_send = elements_per_peer * world_size
    send_tensor = torch.arange(total_send, device="cuda", dtype=torch.float32) + rank * 10000

    # The receive buffer must accommodate a variable amount from each sender:
    # sender j sends (j + 1) * 100 elements to us
    recv_counts_cpu = torch.tensor([(j + 1) * 100 for j in range(world_size)], dtype=torch.int64)
    total_recv = recv_counts_cpu.sum().item()
    recv_tensor = torch.zeros(total_recv, device="cuda", dtype=torch.float32)

    # Send counts: we send elements_per_peer to everyone
    send_counts_cpu = torch.tensor([elements_per_peer] * world_size, dtype=torch.int64)

    # Move to GPU
    send_counts = send_counts_cpu.cuda()
    recv_counts = recv_counts_cpu.cuda()

    comm.barrier_cpu()

    # Perform alltoallv
    comm.alltoallv_by_elements(
        send_tensor, recv_tensor,
        send_counts, recv_counts,
        stream=torch.cuda.current_stream(),
    )

    torch.cuda.synchronize()
    comm.barrier_cpu()

    print(f"Rank {rank}: alltoallv completed successfully!")
    print(f"  Sent {total_send} elements, received {total_recv} elements")

    # Cleanup
    comm = None


if __name__ == "__main__":
    main()

examples/torch-integration/alltoallv_kernel.cu (new file, 400 lines)
@@ -0,0 +1,400 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

// AllToAllV implementation for MSCCLPP
// This kernel handles variable element counts per rank for alltoallv operations.
// Unlike NCCL's ncclGroupStart/ncclGroupEnd approach, mscclpp uses explicit
// put/signal/wait operations on PortChannels.

#include <Python.h>
#include <pybind11/pybind11.h>

#include <algorithm>
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include <mscclpp/algorithm.hpp>
#include <mscclpp/core.hpp>
#include <mscclpp/port_channel.hpp>
#include <mscclpp/port_channel_device.hpp>
#include <mscclpp/concurrency_device.hpp>

namespace py = pybind11;

#if defined(__HIP_PLATFORM_AMD__)
#define WARP_SIZE 64
#else
#define WARP_SIZE 32
#endif

// Device syncer for synchronization across blocks
__device__ mscclpp::DeviceSyncer alltoallvDeviceSyncer;

/**
 * AllToAllV kernel implementation
 *
 * This kernel performs an all-to-all exchange with variable-length data per rank.
 * Each rank sends sendCounts[i] bytes to rank i at offset sendDispls[i],
 * and receives recvCounts[i] bytes from rank i at offset recvDispls[i].
 *
 * Since mscclpp doesn't support ncclGroupStart/ncclGroupEnd, we implement
 * the exchange using explicit put/signal/wait operations on PortChannels.
 * The communication pattern uses a ring-based approach to avoid deadlocks.
 *
 * @param portChannels Array of PortChannel handles for each peer (worldSize-1 channels)
 * @param rank Current rank
 * @param worldSize Total number of ranks
 * @param sendBuff Source buffer containing data to send
 * @param recvBuff Destination buffer for received data
 * @param sendCounts Array of send counts for each rank (in bytes)
 * @param sendDispls Array of send displacements for each rank (in bytes)
 * @param recvCounts Array of receive counts for each rank (in bytes)
 * @param recvDispls Array of receive displacements for each rank (in bytes)
 */
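// Example schedule (worldSize = 4, viewed from rank 0):
//   step 1: send to 1, recv from 3
//   step 2: send to 2, recv from 2
//   step 3: send to 3, recv from 1
// Every rank both sends and receives in each step, so no pair of ranks ends
// up waiting on each other.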
__global__ void __launch_bounds__(1024)
    alltoallv_kernel(mscclpp::DeviceHandle<mscclpp::PortChannel>* portChannels,
                     int rank,
                     int worldSize,
                     const void* sendBuff,
                     void* recvBuff,
                     const size_t* sendCounts,
                     const size_t* sendDispls,
                     const size_t* recvCounts,
                     const size_t* recvDispls) {
  // First, copy the local data (this rank's own portion) from the send to the
  // recv buffer; this requires no communication.
  if (threadIdx.x == 0 && blockIdx.x == 0) {
    if (sendCounts[rank] > 0) {
      // Local copy: sendBuff[sendDispls[rank]] -> recvBuff[recvDispls[rank]]
      const char* src = (const char*)sendBuff + sendDispls[rank];
      char* dst = (char*)recvBuff + recvDispls[rank];
      memcpy(dst, src, sendCounts[rank]);
    }
  }
  __syncthreads();

  // Ring-based exchange pattern to avoid deadlocks.
  // In step i, rank sends to (rank + i) % worldSize and receives from
  // (rank - i + worldSize) % worldSize.
  for (int step = 1; step < worldSize; step++) {
    int sendPeer = (rank + step) % worldSize;
    int recvPeer = (rank - step + worldSize) % worldSize;

    // Get the channel indices (portChannels excludes self, so adjust the index)
    int sendChanIdx = sendPeer < rank ? sendPeer : sendPeer - 1;
    int recvChanIdx = recvPeer < rank ? recvPeer : recvPeer - 1;
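
    // e.g. rank 2 of 4 owns channels [0 -> peer 0, 1 -> peer 1, 2 -> peer 3]:
    // peers below this rank keep their index, higher peers shift down by one.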

    // Each warp handles one peer
    int wid = threadIdx.x / WARP_SIZE;
    int lid = threadIdx.x % WARP_SIZE;

    // Send data to sendPeer if there's data to send
    if (wid == 0 && lid == 0) {
      if (sendCounts[sendPeer] > 0) {
        // putWithSignal: copy the data and signal completion
        // src offset: sendDispls[sendPeer] in our sendBuff
        // dst offset: recvDispls[rank] in the peer's recvBuff (where our data should go)
        portChannels[sendChanIdx].putWithSignal(
            recvDispls[rank],      // dst offset in the peer's recv buffer (where we write)
            sendDispls[sendPeer],  // src offset in our send buffer
            sendCounts[sendPeer]   // size in bytes
        );
      }
    }

    // Sync all threads before flushing
    alltoallvDeviceSyncer.sync(gridDim.x);

    // Flush to ensure the data is sent
    if (wid == 0 && lid == 0) {
      if (sendCounts[sendPeer] > 0) {
        portChannels[sendChanIdx].flush();
      }
    }

    // Wait for data from recvPeer if we're expecting data
    if (wid == 0 && lid == 0) {
      if (recvCounts[recvPeer] > 0) {
        portChannels[recvChanIdx].wait();
      }
    }

    // Sync all threads before the next step
    alltoallvDeviceSyncer.sync(gridDim.x);
  }
}

/**
 * Simplified AllToAllV kernel for single-block execution
 *
 * This version is optimized for cases where all communication can be
 * handled within a single thread block.
 */
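// With worldSize = 8, warps 0..6 each own one peer: warp w serves peer w when
// w < rank and peer w + 1 otherwise, so this rank's own slot is skipped.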
__global__ void __launch_bounds__(1024)
    alltoallv_simple_kernel(mscclpp::DeviceHandle<mscclpp::PortChannel>* portChannels,
                            int rank,
                            int worldSize,
                            const void* sendBuff,
                            void* recvBuff,
                            const size_t* sendCounts,
                            const size_t* sendDispls,
                            const size_t* recvCounts,
                            const size_t* recvDispls) {
  int tid = threadIdx.x;
  int nPeers = worldSize - 1;

  // Step 1: Copy the local data
  if (tid == 0 && sendCounts[rank] > 0) {
    const char* src = (const char*)sendBuff + sendDispls[rank];
    char* dst = (char*)recvBuff + recvDispls[rank];
    memcpy(dst, src, sendCounts[rank]);
  }
  __syncthreads();

  // Step 2: Each warp handles one peer for sending.
  // We have worldSize-1 peers; assign one warp per peer.
  int warpId = tid / WARP_SIZE;
  int laneId = tid % WARP_SIZE;

  if (warpId < nPeers && laneId == 0) {
    // Determine which peer this warp handles
    int peer = warpId < rank ? warpId : warpId + 1;
    int chanIdx = warpId;

    if (sendCounts[peer] > 0) {
      portChannels[chanIdx].putWithSignal(
          recvDispls[rank],  // dst offset in the peer's buffer
          sendDispls[peer],  // src offset in our buffer
          sendCounts[peer]   // size in bytes
      );
    }
  }
  __syncthreads();

  // Step 3: Flush all pending operations
  if (warpId < nPeers && laneId == 0) {
    int peer = warpId < rank ? warpId : warpId + 1;
    if (sendCounts[peer] > 0) {
      portChannels[warpId].flush();
    }
  }
  __syncthreads();

  // Step 4: Wait for all incoming data
  if (warpId < nPeers && laneId == 0) {
    int peer = warpId < rank ? warpId : warpId + 1;
    if (recvCounts[peer] > 0) {
      portChannels[warpId].wait();
    }
  }
  __syncthreads();
}

// Context to hold all necessary state for alltoallv execution
struct AllToAllVContext {
  int rank;
  int worldSize;
  int nRanksPerNode;

  std::vector<mscclpp::RegisteredMemory> registeredMemories;
  std::shared_ptr<mscclpp::DeviceHandle<mscclpp::PortChannel>> portChannelDeviceHandles;

  // Device memory for counts and displacements
  size_t* d_sendCounts;
  size_t* d_sendDispls;
  size_t* d_recvCounts;
  size_t* d_recvDispls;
};

class AllToAllVAlgoBuilder : public mscclpp::AlgorithmBuilder {
 public:
  AllToAllVAlgoBuilder() = default;
  ~AllToAllVAlgoBuilder() {
    if (proxyService_) {
      proxyService_->stopProxy();
    }
  }

  std::shared_ptr<mscclpp::Algorithm> build() override {
    auto self = std::make_shared<AllToAllVAlgoBuilder>();
    std::shared_ptr<mscclpp::Algorithm> alltoallvAlgo = std::make_shared<mscclpp::NativeAlgorithm>(
        "alltoallv", "alltoallv",
        // Initialize function
        [self](std::shared_ptr<mscclpp::Communicator> comm) { self->initialize(comm); },
        // Kernel execution function
        [self](const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize, size_t outputSize,
               mscclpp::DataType dtype, [[maybe_unused]] mscclpp::ReduceOp op, cudaStream_t stream, int nBlocks,
               int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras) {
          return self->alltoallvKernelFunc(ctx, input, output, inputSize, outputSize, dtype, stream, extras);
        },
        // Context initialization function
        [self](std::shared_ptr<mscclpp::Communicator> comm, const void* input, void* output, size_t inputSize,
               size_t outputSize, mscclpp::DataType dtype) {
          return self->initAlltoallvContext(comm, input, output, inputSize, outputSize, dtype);
        },
        // Context key generation function
        [self](const void* input, void* output, size_t inputSize, size_t outputSize, mscclpp::DataType dtype) {
          return self->generateAlltoallvContextKey(input, output, inputSize, outputSize, dtype);
        });
    return alltoallvAlgo;
  }

 private:
  std::vector<mscclpp::Connection> conns_;
  std::shared_ptr<mscclpp::ProxyService> proxyService_;
  int worldSize_;

  void initialize(std::shared_ptr<mscclpp::Communicator> comm) {
    std::vector<std::shared_future<mscclpp::Connection>> connectionFutures;
    worldSize_ = comm->bootstrap()->getNranks();
    for (int i = 0; i < worldSize_; i++) {
      if (i == comm->bootstrap()->getRank()) continue;
      connectionFutures.push_back(comm->connect(mscclpp::Transport::CudaIpc, i));
    }
    std::vector<mscclpp::Connection> connections;
    std::transform(connectionFutures.begin(), connectionFutures.end(), std::back_inserter(connections),
                   [](const auto& future) { return future.get(); });
    this->conns_ = std::move(connections);
    proxyService_ = std::make_shared<mscclpp::ProxyService>();
    proxyService_->startProxy(true);
  }

  mscclpp::CommResult alltoallvKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output,
                                          size_t inputSize, size_t outputSize,
                                          [[maybe_unused]] mscclpp::DataType dtype,
                                          cudaStream_t stream,
                                          const std::unordered_map<std::string, uintptr_t>& extras) {
    auto algoCtx = std::static_pointer_cast<AllToAllVContext>(ctx);
    int rank = algoCtx->rank;
    int worldSize = algoCtx->worldSize;

    // Extract the send/recv counts and displacements from extras.
    // The caller passes these as device pointers via the extras map.
    auto it_sendCounts = extras.find("sendCounts");
    auto it_sendDispls = extras.find("sendDispls");
    auto it_recvCounts = extras.find("recvCounts");
    auto it_recvDispls = extras.find("recvDispls");

    if (it_sendCounts == extras.end() || it_sendDispls == extras.end() ||
        it_recvCounts == extras.end() || it_recvDispls == extras.end()) {
      return mscclpp::CommResult::CommInternalError;
    }

    const size_t* d_sendCounts = reinterpret_cast<const size_t*>(it_sendCounts->second);
    const size_t* d_sendDispls = reinterpret_cast<const size_t*>(it_sendDispls->second);
    const size_t* d_recvCounts = reinterpret_cast<const size_t*>(it_recvCounts->second);
    const size_t* d_recvDispls = reinterpret_cast<const size_t*>(it_recvDispls->second);

    // Reset the device syncer
    mscclpp::DeviceSyncer syncer = {};
    cudaMemcpyToSymbolAsync(alltoallvDeviceSyncer, &syncer, sizeof(mscclpp::DeviceSyncer), 0,
                            cudaMemcpyHostToDevice, stream);

    // Use the one-warp-per-peer kernel for small world sizes and the
    // ring-based kernel for larger ones (both run in a single block)
    if (worldSize <= 16) {
      int nThreads = (worldSize - 1) * WARP_SIZE;
      if (nThreads < 32) nThreads = 32;
      if (nThreads > 1024) nThreads = 1024;

      alltoallv_simple_kernel<<<1, nThreads, 0, stream>>>(
          algoCtx->portChannelDeviceHandles.get(),
          rank, worldSize,
          input, output,
          d_sendCounts, d_sendDispls,
          d_recvCounts, d_recvDispls);
    } else {
      alltoallv_kernel<<<1, 1024, 0, stream>>>(
          algoCtx->portChannelDeviceHandles.get(),
          rank, worldSize,
          input, output,
          d_sendCounts, d_sendDispls,
          d_recvCounts, d_recvDispls);
    }

    if (cudaGetLastError() == cudaSuccess) {
      return mscclpp::CommResult::CommSuccess;
    }
    return mscclpp::CommResult::CommInternalError;
  }

  std::shared_ptr<void> initAlltoallvContext(std::shared_ptr<mscclpp::Communicator> comm, const void* input,
                                             void* output, size_t inputSize, size_t outputSize,
                                             mscclpp::DataType dtype) {
    auto ctx = std::make_shared<AllToAllVContext>();
    ctx->rank = comm->bootstrap()->getRank();
    ctx->worldSize = comm->bootstrap()->getNranks();
    ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode();

    // Register memories for the input and output buffers
    mscclpp::RegisteredMemory inputBufRegMem =
        comm->registerMemory((void*)input, inputSize, mscclpp::Transport::CudaIpc);
    mscclpp::RegisteredMemory outputBufRegMem =
        comm->registerMemory(output, outputSize, mscclpp::Transport::CudaIpc);

    // Exchange the output buffer registration with all peers
    std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteRegMemories;
    for (int i = 0; i < ctx->worldSize; i++) {
      if (i == ctx->rank) continue;
      comm->sendMemory(outputBufRegMem, i, 0);
      remoteRegMemories.push_back(comm->recvMemory(i, 0));
    }

    // Set up port channels for each peer
    std::vector<mscclpp::DeviceHandle<mscclpp::PortChannel>> portChannels;
    mscclpp::MemoryId inputMemoryId = this->proxyService_->addMemory(inputBufRegMem);

    for (size_t i = 0; i < this->conns_.size(); i++) {
      auto remoteMemory = remoteRegMemories[i].get();
      mscclpp::MemoryId remoteMemoryId = this->proxyService_->addMemory(remoteMemory);
      portChannels.push_back(mscclpp::deviceHandle(this->proxyService_->portChannel(
          this->proxyService_->buildAndAddSemaphore(*comm, this->conns_[i]), remoteMemoryId, inputMemoryId)));
    }

    // Allocate and copy the port channels to the device
    ctx->portChannelDeviceHandles =
        mscclpp::detail::gpuCallocShared<mscclpp::DeviceHandle<mscclpp::PortChannel>>(portChannels.size());
    mscclpp::gpuMemcpy(ctx->portChannelDeviceHandles.get(), portChannels.data(), portChannels.size(),
                       cudaMemcpyHostToDevice);

    // Keep registered memory references alive to prevent deallocation
    std::transform(remoteRegMemories.begin(), remoteRegMemories.end(), std::back_inserter(ctx->registeredMemories),
                   [](const auto& fut) { return fut.get(); });
    ctx->registeredMemories.push_back(inputBufRegMem);
    ctx->registeredMemories.push_back(outputBufRegMem);

    return ctx;
  }

  mscclpp::AlgorithmCtxKey generateAlltoallvContextKey(const void* input, void* output, size_t inputSize,
                                                       size_t outputSize, mscclpp::DataType dtype) {
    return {(void*)input, output, inputSize, outputSize, 0};
  }
};

std::shared_ptr<mscclpp::Algorithm> createAlltoallvAlgorithm() {
  auto alltoallvAlgoBuilder = std::make_shared<AllToAllVAlgoBuilder>();
  return alltoallvAlgoBuilder->build();
}

void deletePtr(PyObject* capsule) {
  const char* name = PyCapsule_GetName(capsule);
  void* p = PyCapsule_GetPointer(capsule, name);
  if (p == nullptr) {
    PyErr_WriteUnraisable(capsule);
    return;
  }
  auto* ptr = static_cast<std::shared_ptr<mscclpp::Algorithm>*>(p);
  delete ptr;
}

PyObject* getCapsule(std::shared_ptr<mscclpp::Algorithm> algo) {
  auto* ptrCopy = new std::shared_ptr<mscclpp::Algorithm>(algo);
  PyObject* capsule = PyCapsule_New(ptrCopy, mscclpp::ALGORITHM_NATIVE_CAPSULE_NAME, deletePtr);
  if (capsule == nullptr) {
    delete ptrCopy;
    throw pybind11::error_already_set();
  }
  return capsule;
}

PYBIND11_MODULE(mscclpp_alltoallv, m) {
  m.doc() = "AllToAllV implementation for MSCCLPP - handles variable element counts per rank";
  m.def(
      "create_alltoallv_algorithm",
      []() { return py::reinterpret_steal<py::capsule>(getCapsule(createAlltoallvAlgorithm())); },
      "Create an alltoallv algorithm and return it as a PyCapsule usable by the MSCCL++ Python bindings");
}

@@ -18,3 +18,4 @@ add_mscclpp_test_executable(sendrecv_test_perf sendrecv_test.cu)
add_mscclpp_test_executable(allgather_test_perf allgather_test.cu)
add_mscclpp_test_executable(allreduce_test_perf allreduce_test.cu)
add_mscclpp_test_executable(alltoall_test_perf alltoall_test.cu)
add_mscclpp_test_executable(alltoallv_test_perf alltoallv_test.cu)

test/mscclpp-test/alltoallv_test.cu (new file, 347 lines)
@@ -0,0 +1,347 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

// AllToAllV test - tests variable-length alltoall operations.
// This test validates the alltoallv kernels that handle variable element counts per rank.

#include <cstdint>
#include <cstring>
#include <numeric>
#include <stdexcept>
#include <vector>

#include <mscclpp/concurrency_device.hpp>

#include "common.hpp"

#if defined(__HIP_PLATFORM_AMD__)
#define WARP_SIZE 64
#else
#define WARP_SIZE 32
#endif

template <class T>
using DeviceHandle = mscclpp::DeviceHandle<T>;
__constant__ DeviceHandle<mscclpp::PortChannel> constPortChansV[16];
__device__ mscclpp::DeviceSyncer deviceSyncerV;

static void* localRecvBuffV;
static void* localSendBuffV;

// Device arrays for variable counts and displacements
static size_t* d_sendCounts;
static size_t* d_sendDispls;
static size_t* d_recvCounts;
static size_t* d_recvDispls;

/**
 * AllToAllV kernel implementation
 *
 * Each rank sends sendCounts[i] bytes to rank i at offset sendDispls[i],
 * and receives recvCounts[i] bytes from rank i at offset recvDispls[i].
 *
 * Uses a one-warp-per-peer exchange pattern within a single block.
 */
__global__ void __launch_bounds__(1024)
    alltoallv0(int rank, int worldSize,
               const void* sendBuff, void* recvBuff,
               const size_t* sendCounts, const size_t* sendDispls,
               const size_t* recvCounts, const size_t* recvDispls) {
  int tid = threadIdx.x;
  int nPeers = worldSize - 1;

  // Step 1: Copy the local data (this rank's own portion)
  if (tid == 0 && sendCounts[rank] > 0) {
    const char* src = (const char*)sendBuff + sendDispls[rank];
    char* dst = (char*)recvBuff + recvDispls[rank];
    memcpy(dst, src, sendCounts[rank]);
  }
  __syncthreads();

  // Step 2: Each warp handles one peer for sending
  int warpId = tid / WARP_SIZE;
  int laneId = tid % WARP_SIZE;

  if (warpId < nPeers && laneId == 0) {
    // Determine which peer this warp handles
    int peer = warpId < rank ? warpId : warpId + 1;
    int chanIdx = warpId;

    if (sendCounts[peer] > 0) {
      constPortChansV[chanIdx].putWithSignal(
          recvDispls[rank],  // dst offset in the peer's buffer
          sendDispls[peer],  // src offset in our buffer
          sendCounts[peer]   // size in bytes
      );
    }
  }
  __syncthreads();

  // Step 3: Flush all pending operations
  if (warpId < nPeers && laneId == 0) {
    int peer = warpId < rank ? warpId : warpId + 1;
    if (sendCounts[peer] > 0) {
      constPortChansV[warpId].flush();
    }
  }
  __syncthreads();

  // Step 4: Wait for all incoming data
  if (warpId < nPeers && laneId == 0) {
    int peer = warpId < rank ? warpId : warpId + 1;
    if (recvCounts[peer] > 0) {
      constPortChansV[warpId].wait();
    }
  }
  __syncthreads();
}

/**
 * Ring-based AllToAllV kernel for larger world sizes
 *
 * Uses a step-by-step ring pattern to exchange data, sending to (rank + step)
 * and receiving from (rank - step) in each step. Runs in a single block to
 * avoid concurrent access to the same port channels.
 */
__global__ void __launch_bounds__(1024)
    alltoallv1(int rank, int worldSize,
               const void* sendBuff, void* recvBuff,
               const size_t* sendCounts, const size_t* sendDispls,
               const size_t* recvCounts, const size_t* recvDispls) {
  // Copy the local data first
  if (threadIdx.x == 0) {
    if (sendCounts[rank] > 0) {
      const char* src = (const char*)sendBuff + sendDispls[rank];
      char* dst = (char*)recvBuff + recvDispls[rank];
      memcpy(dst, src, sendCounts[rank]);
    }
  }
  __syncthreads();

  // Ring-based exchange - a single thread handles the communication
  // to avoid race conditions on the port channels
  if (threadIdx.x == 0) {
    for (int step = 1; step < worldSize; step++) {
      int sendPeer = (rank + step) % worldSize;
      int recvPeer = (rank - step + worldSize) % worldSize;

      int sendChanIdx = sendPeer < rank ? sendPeer : sendPeer - 1;
      int recvChanIdx = recvPeer < rank ? recvPeer : recvPeer - 1;

      // Send data to sendPeer (non-blocking put with signal)
      if (sendCounts[sendPeer] > 0) {
        constPortChansV[sendChanIdx].putWithSignal(
            recvDispls[rank],      // dst offset in the peer's buffer
            sendDispls[sendPeer],  // src offset in our buffer
            sendCounts[sendPeer]   // size in bytes
        );
        constPortChansV[sendChanIdx].flush();
      }

      // Wait for data from recvPeer
      if (recvCounts[recvPeer] > 0) {
        constPortChansV[recvChanIdx].wait();
      }
    }
  }
}

class AllToAllVTestColl : public BaseTestColl {
 public:
  AllToAllVTestColl() = default;
  ~AllToAllVTestColl() override = default;

  void runColl(const TestArgs& args, cudaStream_t stream) override;
  void initData(const TestArgs& args, std::vector<void*> sendBuff, void* expectedBuff) override;
  void getBw(const double deltaSec, double& algBw /*OUT*/, double& busBw /*OUT*/) override;
  void setupCollTest(size_t size) override;
  std::vector<KernelRestriction> getKernelRestrictions() override;

 private:
  // Host-side counts and displacements
  std::vector<size_t> sendCounts_;
  std::vector<size_t> sendDispls_;
  std::vector<size_t> recvCounts_;
  std::vector<size_t> recvDispls_;
  size_t totalSendBytes_;
  size_t totalRecvBytes_;
};

void AllToAllVTestColl::runColl(const TestArgs& args, cudaStream_t stream) {
  const int worldSize = args.totalRanks;
  const int rank = args.rank;
  const int kernelNum = args.kernelNum;

  // Reset the device syncer
  mscclpp::DeviceSyncer syncer = {};
  CUDATHROW(cudaMemcpyToSymbol(deviceSyncerV, &syncer, sizeof(mscclpp::DeviceSyncer)));

  if (kernelNum == 0) {
    int nThreads = (worldSize - 1) * WARP_SIZE;
    if (nThreads < 32) nThreads = 32;
    if (nThreads > 1024) nThreads = 1024;
    alltoallv0<<<1, nThreads, 0, stream>>>(
        rank, worldSize,
        localSendBuffV, localRecvBuffV,
        d_sendCounts, d_sendDispls,
        d_recvCounts, d_recvDispls);
  } else if (kernelNum == 1) {
    // Single block, one active thread, for ring-based serialized communication
    alltoallv1<<<1, 32, 0, stream>>>(
        rank, worldSize,
        localSendBuffV, localRecvBuffV,
        d_sendCounts, d_sendDispls,
        d_recvCounts, d_recvDispls);
  }
}

void AllToAllVTestColl::initData(const TestArgs& args, std::vector<void*> sendBuff, void* expectedBuff) {
  if (sendBuff.size() != 1) throw std::runtime_error("unexpected error");
  const int rank = args.rank;
  const int worldSize = args.totalRanks;

  // Create the send data: each segment carries values identifying its source and destination
  std::vector<int> sendData(totalSendBytes_ / sizeof(int), 0);
  for (int peer = 0; peer < worldSize; peer++) {
    size_t offset = sendDispls_[peer] / sizeof(int);
    size_t count = sendCounts_[peer] / sizeof(int);
    for (size_t i = 0; i < count; i++) {
      // Encode: rank * 10000 + peer * 100 + position
      sendData[offset + i] = rank * 10000 + peer * 100 + i;
    }
  }
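
  // e.g. rank 1 sending to peer 2 writes values 10200, 10201, ...; rank 2
  // then expects exactly those values from peer 1 (peer * 10000 + rank * 100 + i
  // in the expected-data loop below).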
  CUDATHROW(cudaMemcpy(sendBuff[0], sendData.data(), totalSendBytes_, cudaMemcpyHostToDevice));

  // Create the expected data: what we receive from each peer
  std::vector<int> expectedData(totalRecvBytes_ / sizeof(int), 0);
  for (int peer = 0; peer < worldSize; peer++) {
    size_t offset = recvDispls_[peer] / sizeof(int);
    size_t count = recvCounts_[peer] / sizeof(int);
    for (size_t i = 0; i < count; i++) {
      // We receive the data sent by the peer to us
      expectedData[offset + i] = peer * 10000 + rank * 100 + i;
    }
  }
  std::memcpy(expectedBuff, expectedData.data(), totalRecvBytes_);

  // Copy counts and displacements to the device
  CUDATHROW(cudaMemcpy(d_sendCounts, sendCounts_.data(), worldSize * sizeof(size_t), cudaMemcpyHostToDevice));
  CUDATHROW(cudaMemcpy(d_sendDispls, sendDispls_.data(), worldSize * sizeof(size_t), cudaMemcpyHostToDevice));
  CUDATHROW(cudaMemcpy(d_recvCounts, recvCounts_.data(), worldSize * sizeof(size_t), cudaMemcpyHostToDevice));
  CUDATHROW(cudaMemcpy(d_recvDispls, recvDispls_.data(), worldSize * sizeof(size_t), cudaMemcpyHostToDevice));
}

void AllToAllVTestColl::getBw(const double deltaSec, double& algBw, double& busBw) {
  double baseBw = (double)(totalRecvBytes_) / 1.0E9 / deltaSec;
  algBw = baseBw;
  double factor = ((double)(worldSize_ - 1)) / ((double)worldSize_);
  busBw = baseBw * factor;
}
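
// e.g. with 8 ranks, busBw = algBw * 7/8: the bytes a rank copies to itself
// never cross a link, so only (worldSize - 1) / worldSize of the traffic counts.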

void AllToAllVTestColl::setupCollTest(size_t size) {
  // alltoallv supports variable sizes per peer; this test keeps the per-peer
  // size uniform (see the loop below) so the expected totals stay easy to compute.

  size_t baseBytes = size / (worldSize_ * worldSize_);  // Base unit for per-peer sizing
  if (baseBytes < sizeof(int)) baseBytes = sizeof(int);
  baseBytes = (baseBytes / sizeof(int)) * sizeof(int);  // Align to int size
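
  // e.g. size = 1 MiB with worldSize_ = 8 gives baseBytes = 16384, so each
  // rank sends 16384 bytes to every peer and totalSendBytes_ = 128 KiB.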

  sendCounts_.resize(worldSize_);
  sendDispls_.resize(worldSize_);
  recvCounts_.resize(worldSize_);
  recvDispls_.resize(worldSize_);

  // Each rank sends the same amount to each peer (for simplicity in this test);
  // in a real MoE scenario these would vary per peer
  totalSendBytes_ = 0;
  totalRecvBytes_ = 0;

  for (int peer = 0; peer < worldSize_; peer++) {
    sendCounts_[peer] = baseBytes;
    sendDispls_[peer] = totalSendBytes_;
    totalSendBytes_ += sendCounts_[peer];

    recvCounts_[peer] = baseBytes;
    recvDispls_[peer] = totalRecvBytes_;
    totalRecvBytes_ += recvCounts_[peer];
  }

  sendCount_ = totalSendBytes_ / typeSize_;
  recvCount_ = totalRecvBytes_ / typeSize_;
  paramCount_ = sendCount_;
  expectedCount_ = recvCount_;

  mscclpp::DeviceSyncer syncer = {};
  CUDATHROW(cudaMemcpyToSymbol(deviceSyncerV, &syncer, sizeof(mscclpp::DeviceSyncer)));
}

std::vector<KernelRestriction> AllToAllVTestColl::getKernelRestrictions() {
  return {
      {0, "alltoallv0", true, 1, 4 * worldSize_},
      {1, "alltoallv1", true, 1, 4 * worldSize_},
  };
}

class AllToAllVTestEngine : public BaseTestEngine {
 public:
  AllToAllVTestEngine(const TestArgs& args);
  ~AllToAllVTestEngine() override = default;

  void allocateBuffer() override;
  void setupConnections() override;

  std::vector<void*> getSendBuff() override;
  void* getRecvBuff() override;
  void* getScratchBuff() override;

 private:
  void* getExpectedBuff() override;
  bool isInPlace() const;

  std::shared_ptr<int> sendBuff_;
  std::shared_ptr<int> recvBuff_;
  std::shared_ptr<int[]> expectedBuff_;
};

bool AllToAllVTestEngine::isInPlace() const { return false; }

AllToAllVTestEngine::AllToAllVTestEngine(const TestArgs& args) : BaseTestEngine(args, "alltoallv") { inPlace_ = false; }

void AllToAllVTestEngine::allocateBuffer() {
  sendBuff_ = mscclpp::GpuBuffer<int>(args_.maxBytes / sizeof(int)).memory();
  recvBuff_ = mscclpp::GpuBuffer<int>(args_.maxBytes / sizeof(int)).memory();
  expectedBuff_ = std::shared_ptr<int[]>(new int[args_.maxBytes / sizeof(int)]);

  localSendBuffV = sendBuff_.get();
  localRecvBuffV = recvBuff_.get();

  // Allocate device arrays for counts and displacements
  CUDATHROW(cudaMalloc(&d_sendCounts, args_.totalRanks * sizeof(size_t)));
  CUDATHROW(cudaMalloc(&d_sendDispls, args_.totalRanks * sizeof(size_t)));
  CUDATHROW(cudaMalloc(&d_recvCounts, args_.totalRanks * sizeof(size_t)));
  CUDATHROW(cudaMalloc(&d_recvDispls, args_.totalRanks * sizeof(size_t)));
}

void AllToAllVTestEngine::setupConnections() {
  std::vector<DeviceHandle<mscclpp::PortChannel>> portChannels;
  setupMeshConnections(portChannels, sendBuff_.get(), args_.maxBytes, recvBuff_.get(), args_.maxBytes);

  if (portChannels.size() > sizeof(constPortChansV) / sizeof(DeviceHandle<mscclpp::PortChannel>)) {
    throw std::runtime_error("Too many port channels for the alltoallv test");
  }
  CUDATHROW(cudaMemcpyToSymbol(constPortChansV, portChannels.data(),
                               sizeof(DeviceHandle<mscclpp::PortChannel>) * portChannels.size()));
}

std::vector<void*> AllToAllVTestEngine::getSendBuff() { return {sendBuff_.get()}; }
void* AllToAllVTestEngine::getExpectedBuff() { return expectedBuff_.get(); }
void* AllToAllVTestEngine::getRecvBuff() {
  if (this->isInPlace())
    return sendBuff_.get();
  else
    return recvBuff_.get();
}
void* AllToAllVTestEngine::getScratchBuff() { return nullptr; }

std::shared_ptr<BaseTestEngine> getTestEngine(const TestArgs& args) {
  return std::make_shared<AllToAllVTestEngine>(args);
}
std::shared_ptr<BaseTestColl> getTestColl() { return std::make_shared<AllToAllVTestColl>(); }