mirror of https://github.com/microsoft/mscclpp.git
synced 2026-05-12 01:10:22 +00:00
Move the alltoallv kernel to the src directory; Utilize the kernel in mscclpp-test
@@ -1,372 +1,22 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

// AllToAllV implementation for MSCCLPP
// This kernel handles variable element counts per rank for alltoallv operations.
// Unlike NCCL's ncclGroupStart/ncclGroupEnd approach, mscclpp uses explicit
// put/signal/wait operations on PortChannels.
// AllToAllV Python bindings for MSCCLPP
// This file provides Python bindings for the alltoallv algorithm.
// The actual implementation is in src/ext/collectives/alltoallv/

#include <Python.h>
#include <pybind11/pybind11.h>

#include <mscclpp/algorithm.hpp>
#include <mscclpp/core.hpp>
#include <mscclpp/port_channel.hpp>
#include <mscclpp/port_channel_device.hpp>
#include <mscclpp/concurrency_device.hpp>

// Include the implementation header
#include "alltoallv/alltoallv_fullmesh.hpp"

namespace py = pybind11;

#if defined(__HIP_PLATFORM_AMD__)
#define WARP_SIZE 64
#else
#define WARP_SIZE 32
#endif

// Device syncer for synchronization across blocks
__device__ mscclpp::DeviceSyncer alltoallvDeviceSyncer;

/**
 * AllToAllV kernel implementation
 *
 * This kernel performs an all-to-all exchange with variable-length data per rank.
 * Each rank sends sendCounts[i] elements to rank i at sendDispls[i] offset,
 * and receives recvCounts[i] elements from rank i at recvDispls[i] offset.
 *
 * Since mscclpp doesn't support ncclGroupStart/ncclGroupEnd, we implement
 * the exchange using explicit put/signal/wait operations on PortChannels.
 * The communication pattern uses a ring-based approach to avoid deadlocks.
 *
 * @param portChannels Array of PortChannel handles for each peer (worldSize-1 channels)
 * @param rank Current rank
 * @param worldSize Total number of ranks
 * @param sendBuff Source buffer containing data to send
 * @param recvBuff Destination buffer for received data
 * @param sendCounts Array of send counts for each rank (in bytes)
 * @param sendDispls Array of send displacements for each rank (in bytes)
 * @param recvCounts Array of receive counts for each rank (in bytes)
 * @param recvDispls Array of receive displacements for each rank (in bytes)
 */
__global__ void __launch_bounds__(1024)
    alltoallv_kernel(mscclpp::DeviceHandle<mscclpp::PortChannel>* portChannels, int rank, int worldSize,
                     const void* sendBuff, void* recvBuff, const size_t* sendCounts, const size_t* sendDispls,
                     const size_t* recvCounts, const size_t* recvDispls) {
  // First, copy local data (rank's own portion) from send to recv buffer
  // This doesn't require any communication
  if (threadIdx.x == 0 && blockIdx.x == 0) {
    if (sendCounts[rank] > 0) {
      // Local copy: sendBuff[sendDispls[rank]] -> recvBuff[recvDispls[rank]]
      const char* src = (const char*)sendBuff + sendDispls[rank];
      char* dst = (char*)recvBuff + recvDispls[rank];
      memcpy(dst, src, sendCounts[rank]);
    }
  }
  __syncthreads();

  // Ring-based exchange pattern to avoid deadlocks
  // In each step i, rank sends to (rank + i) % worldSize and receives from (rank - i + worldSize) % worldSize
  for (int step = 1; step < worldSize; step++) {
    int sendPeer = (rank + step) % worldSize;
    int recvPeer = (rank - step + worldSize) % worldSize;

    // Get channel indices (portChannels excludes self, so adjust index)
    int sendChanIdx = sendPeer < rank ? sendPeer : sendPeer - 1;
    int recvChanIdx = recvPeer < rank ? recvPeer : recvPeer - 1;

    // Each warp handles one peer
    int wid = threadIdx.x / WARP_SIZE;
    int lid = threadIdx.x % WARP_SIZE;

    // Send data to sendPeer if there's data to send
    if (wid == 0 && lid == 0) {
      if (sendCounts[sendPeer] > 0) {
        // putWithSignal: copy data and signal completion
        // src offset: sendDispls[sendPeer] in our sendBuff
        // dst offset: recvDispls[rank] in peer's recvBuff (where our data should go)
        portChannels[sendChanIdx].putWithSignal(
            recvDispls[rank],      // dst offset in peer's recv buffer (where we write)
            sendDispls[sendPeer],  // src offset in our send buffer
            sendCounts[sendPeer]   // size in bytes
        );
      }
    }

    // Sync all threads before flushing
    alltoallvDeviceSyncer.sync(gridDim.x);

    // Flush to ensure data is sent
    if (wid == 0 && lid == 0) {
      if (sendCounts[sendPeer] > 0) {
        portChannels[sendChanIdx].flush();
      }
    }

    // Wait for data from recvPeer if we're expecting data
    if (wid == 0 && lid == 0) {
      if (recvCounts[recvPeer] > 0) {
        portChannels[recvChanIdx].wait();
      }
    }

    // Sync all threads before next step
    alltoallvDeviceSyncer.sync(gridDim.x);
  }
}

/**
 * Simplified AllToAllV kernel for single-block execution
 *
 * This version is optimized for cases where all communication can be
 * handled within a single thread block.
 */
__global__ void __launch_bounds__(1024)
    alltoallv_simple_kernel(mscclpp::DeviceHandle<mscclpp::PortChannel>* portChannels, int rank, int worldSize,
                            const void* sendBuff, void* recvBuff, const size_t* sendCounts, const size_t* sendDispls,
                            const size_t* recvCounts, const size_t* recvDispls) {
  int tid = threadIdx.x;
  int nPeers = worldSize - 1;

  // Step 1: Copy local data
  if (tid == 0 && sendCounts[rank] > 0) {
    const char* src = (const char*)sendBuff + sendDispls[rank];
    char* dst = (char*)recvBuff + recvDispls[rank];
    memcpy(dst, src, sendCounts[rank]);
  }
  __syncthreads();

  // Step 2: Each warp handles one peer for sending
  // We have worldSize-1 peers, assign one warp per peer
  int warpId = tid / WARP_SIZE;
  int laneId = tid % WARP_SIZE;

  if (warpId < nPeers && laneId == 0) {
    // Determine which peer this warp handles
    int peer = warpId < rank ? warpId : warpId + 1;
    int chanIdx = warpId;

    if (sendCounts[peer] > 0) {
      portChannels[chanIdx].putWithSignal(
          recvDispls[rank],  // dst offset in peer's buffer
          sendDispls[peer],  // src offset in our buffer
          sendCounts[peer]   // size
      );
    }
  }
  __syncthreads();

  // Step 3: Flush all pending operations
  if (warpId < nPeers && laneId == 0) {
    int peer = warpId < rank ? warpId : warpId + 1;
    if (sendCounts[peer] > 0) {
      portChannels[warpId].flush();
    }
  }
  __syncthreads();

  // Step 4: Wait for all incoming data
  if (warpId < nPeers && laneId == 0) {
    int peer = warpId < rank ? warpId : warpId + 1;
    if (recvCounts[peer] > 0) {
      portChannels[warpId].wait();
    }
  }
  __syncthreads();
}

// Context to hold all necessary state for alltoallv execution
struct AllToAllVContext {
  int rank;
  int worldSize;
  int nRanksPerNode;

  std::vector<mscclpp::RegisteredMemory> registeredMemories;
  std::shared_ptr<mscclpp::DeviceHandle<mscclpp::PortChannel>> portChannelDeviceHandles;

  // Device memory for counts and displacements
  size_t* d_sendCounts;
  size_t* d_sendDispls;
  size_t* d_recvCounts;
  size_t* d_recvDispls;
};

class AllToAllVAlgoBuilder : public mscclpp::AlgorithmBuilder {
 public:
  AllToAllVAlgoBuilder() = default;
  ~AllToAllVAlgoBuilder() {
    if (proxyService_) {
      proxyService_->stopProxy();
    }
  }

  std::shared_ptr<mscclpp::Algorithm> build() override {
    auto self = std::make_shared<AllToAllVAlgoBuilder>();
    std::shared_ptr<mscclpp::Algorithm> alltoallvAlgo = std::make_shared<mscclpp::NativeAlgorithm>(
        "alltoallv", "alltoallv",
        // Initialize function
        [self](std::shared_ptr<mscclpp::Communicator> comm) { self->initialize(comm); },
        // Kernel execution function
        [self](const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize, size_t outputSize,
               mscclpp::DataType dtype, [[maybe_unused]] mscclpp::ReduceOp op, cudaStream_t stream, int nBlocks,
               int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras) {
          return self->alltoallvKernelFunc(ctx, input, output, inputSize, outputSize, dtype, stream, extras);
        },
        // Context initialization function
        [self](std::shared_ptr<mscclpp::Communicator> comm, const void* input, void* output, size_t inputSize,
               size_t outputSize, mscclpp::DataType dtype) {
          return self->initAlltoallvContext(comm, input, output, inputSize, outputSize, dtype);
        },
        // Context key generation function
        [self](const void* input, void* output, size_t inputSize, size_t outputSize, mscclpp::DataType dtype) {
          return self->generateAlltoallvContextKey(input, output, inputSize, outputSize, dtype);
        });
    return alltoallvAlgo;
  }

 private:
  std::vector<mscclpp::Connection> conns_;
  std::shared_ptr<mscclpp::ProxyService> proxyService_;
  int worldSize_;

  void initialize(std::shared_ptr<mscclpp::Communicator> comm) {
    std::vector<std::shared_future<mscclpp::Connection>> connectionFutures;
    worldSize_ = comm->bootstrap()->getNranks();
    for (int i = 0; i < worldSize_; i++) {
      if (i == comm->bootstrap()->getRank()) continue;
      connectionFutures.push_back(comm->connect(mscclpp::Transport::CudaIpc, i));
    }
    std::vector<mscclpp::Connection> connections;
    std::transform(connectionFutures.begin(), connectionFutures.end(), std::back_inserter(connections),
                   [](const auto& future) { return future.get(); });
    this->conns_ = std::move(connections);
    proxyService_ = std::make_shared<mscclpp::ProxyService>();
    proxyService_->startProxy(true);
  }

  mscclpp::CommResult alltoallvKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output,
                                          size_t inputSize, size_t outputSize,
                                          [[maybe_unused]] mscclpp::DataType dtype, cudaStream_t stream,
                                          const std::unordered_map<std::string, uintptr_t>& extras) {
    auto algoCtx = std::static_pointer_cast<AllToAllVContext>(ctx);
    int rank = algoCtx->rank;
    int worldSize = algoCtx->worldSize;

    // Extract send/recv counts and displacements from extras
    // The caller should pass these as device pointers via extras map
    auto it_sendCounts = extras.find("sendCounts");
    auto it_sendDispls = extras.find("sendDispls");
    auto it_recvCounts = extras.find("recvCounts");
    auto it_recvDispls = extras.find("recvDispls");

    if (it_sendCounts == extras.end() || it_sendDispls == extras.end() ||
        it_recvCounts == extras.end() || it_recvDispls == extras.end()) {
      return mscclpp::CommResult::CommInternalError;
    }

    const size_t* d_sendCounts = reinterpret_cast<const size_t*>(it_sendCounts->second);
    const size_t* d_sendDispls = reinterpret_cast<const size_t*>(it_sendDispls->second);
    const size_t* d_recvCounts = reinterpret_cast<const size_t*>(it_recvCounts->second);
    const size_t* d_recvDispls = reinterpret_cast<const size_t*>(it_recvDispls->second);

    // Reset device syncer
    mscclpp::DeviceSyncer syncer = {};
    cudaMemcpyToSymbolAsync(alltoallvDeviceSyncer, &syncer, sizeof(mscclpp::DeviceSyncer), 0,
                            cudaMemcpyHostToDevice, stream);

    // Use simple kernel for small world sizes, multi-block for larger
    if (worldSize <= 16) {
      int nThreads = (worldSize - 1) * WARP_SIZE;
      if (nThreads < 32) nThreads = 32;
      if (nThreads > 1024) nThreads = 1024;

      alltoallv_simple_kernel<<<1, nThreads, 0, stream>>>(
          algoCtx->portChannelDeviceHandles.get(), rank, worldSize, input, output,
          d_sendCounts, d_sendDispls, d_recvCounts, d_recvDispls);
    } else {
      alltoallv_kernel<<<1, 1024, 0, stream>>>(
          algoCtx->portChannelDeviceHandles.get(), rank, worldSize, input, output,
          d_sendCounts, d_sendDispls, d_recvCounts, d_recvDispls);
    }

    if (cudaGetLastError() == cudaSuccess) {
      return mscclpp::CommResult::CommSuccess;
    }
    return mscclpp::CommResult::CommInternalError;
  }

  std::shared_ptr<void> initAlltoallvContext(std::shared_ptr<mscclpp::Communicator> comm, const void* input,
                                             void* output, size_t inputSize, size_t outputSize,
                                             mscclpp::DataType dtype) {
    auto ctx = std::make_shared<AllToAllVContext>();
    ctx->rank = comm->bootstrap()->getRank();
    ctx->worldSize = comm->bootstrap()->getNranks();
    ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode();

    // Register memories for input and output buffers
    mscclpp::RegisteredMemory inputBufRegMem =
        comm->registerMemory((void*)input, inputSize, mscclpp::Transport::CudaIpc);
    mscclpp::RegisteredMemory outputBufRegMem =
        comm->registerMemory(output, outputSize, mscclpp::Transport::CudaIpc);

    // Exchange output buffer registration with all peers
    std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteRegMemories;
    for (int i = 0; i < ctx->worldSize; i++) {
      if (i == ctx->rank) continue;
      comm->sendMemory(outputBufRegMem, i, 0);
      remoteRegMemories.push_back(comm->recvMemory(i, 0));
    }

    // Setup port channels for each peer
    std::vector<mscclpp::DeviceHandle<mscclpp::PortChannel>> portChannels;
    mscclpp::MemoryId inputMemoryId = this->proxyService_->addMemory(inputBufRegMem);

    for (size_t i = 0; i < this->conns_.size(); i++) {
      auto remoteMemory = remoteRegMemories[i].get();
      mscclpp::MemoryId remoteMemoryId = this->proxyService_->addMemory(remoteMemory);
      portChannels.push_back(mscclpp::deviceHandle(this->proxyService_->portChannel(
          this->proxyService_->buildAndAddSemaphore(*comm, this->conns_[i]), remoteMemoryId, inputMemoryId)));
    }

    // Allocate and copy port channels to device
    ctx->portChannelDeviceHandles =
        mscclpp::detail::gpuCallocShared<mscclpp::DeviceHandle<mscclpp::PortChannel>>(portChannels.size());
    mscclpp::gpuMemcpy(ctx->portChannelDeviceHandles.get(), portChannels.data(), portChannels.size(),
                       cudaMemcpyHostToDevice);

    // Keep registered memory references to prevent deallocation
    std::transform(remoteRegMemories.begin(), remoteRegMemories.end(), std::back_inserter(ctx->registeredMemories),
                   [](const auto& fut) { return fut.get(); });
    ctx->registeredMemories.push_back(inputBufRegMem);
    ctx->registeredMemories.push_back(outputBufRegMem);

    return ctx;
  }

  mscclpp::AlgorithmCtxKey generateAlltoallvContextKey(const void* input, void* output, size_t inputSize,
                                                       size_t outputSize, mscclpp::DataType dtype) {
    return {(void*)input, output, inputSize, outputSize, 0};
  }
};

std::shared_ptr<mscclpp::Algorithm> createAlltoallvAlgorithm() {
-  auto alltoallvAlgoBuilder = std::make_shared<AllToAllVAlgoBuilder>();
+  auto alltoallvAlgoBuilder = std::make_shared<mscclpp::collective::AlltoallvFullmesh>();
  return alltoallvAlgoBuilder->build();
}
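The old and new code share one indexing convention worth spelling out: each rank opens worldSize-1 port channels, ordered by peer rank with the local rank skipped, so peer p maps to channel p when p < rank and to p - 1 otherwise (the sendChanIdx/recvChanIdx lines above). A standalone sketch of that mapping, illustrative only and not part of the commit:

#include <cstdio>

// Mirrors the channel-index computation used by the kernels: channels are
// ordered by peer rank with the local rank skipped.
int channelIndex(int rank, int peer) { return peer < rank ? peer : peer - 1; }

int main() {
  const int worldSize = 4;
  for (int rank = 0; rank < worldSize; rank++)
    for (int peer = 0; peer < worldSize; peer++)
      if (peer != rank) printf("rank %d, peer %d -> channel %d\n", rank, peer, channelIndex(rank, peer));
  return 0;
}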
src/ext/collectives/alltoallv/alltoallv_fullmesh.cu (new file, 197 lines)
@@ -0,0 +1,197 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

#include "alltoallv/alltoallv_fullmesh.hpp"
#include "alltoallv/alltoallv_kernel.hpp"

#include <mscclpp/core.hpp>
#include <mscclpp/port_channel.hpp>
#include <mscclpp/port_channel_device.hpp>
#include <mscclpp/gpu_utils.hpp>

#include <algorithm>

namespace mscclpp {
namespace collective {

#if defined(__HIP_PLATFORM_AMD__)
#define ALLTOALLV_WARP_SIZE 64
#else
#define ALLTOALLV_WARP_SIZE 32
#endif

// Context to hold all necessary state for alltoallv execution
struct AllToAllVContext {
  int rank;
  int worldSize;
  int nRanksPerNode;

  std::vector<RegisteredMemory> registeredMemories;
  std::shared_ptr<DeviceHandle<PortChannel>> portChannelDeviceHandles;
};

AlltoallvFullmesh::~AlltoallvFullmesh() {
  if (proxyService_) {
    proxyService_->stopProxy();
  }
}

std::shared_ptr<Algorithm> AlltoallvFullmesh::build() {
  auto self = std::shared_ptr<AlltoallvFullmesh>(this, [](AlltoallvFullmesh*) {});

  std::shared_ptr<Algorithm> alltoallvAlgo = std::make_shared<NativeAlgorithm>(
      "alltoallv", "alltoallv_fullmesh",
      // Initialize function
      [self](std::shared_ptr<Communicator> comm) { self->initialize(comm); },
      // Kernel execution function
      [self](const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
             size_t outputSize, DataType dtype, [[maybe_unused]] ReduceOp op, cudaStream_t stream,
             int nBlocks, int nThreadsPerBlock,
             const std::unordered_map<std::string, uintptr_t>& extras) {
        return self->alltoallvKernelFunc(ctx, input, output, inputSize, outputSize, dtype, stream,
                                         nBlocks, nThreadsPerBlock, extras);
      },
      // Context initialization function
      [self](std::shared_ptr<Communicator> comm, const void* input, void* output, size_t inputSize,
             size_t outputSize, DataType dtype) {
        return self->initAlltoallvContext(comm, input, output, inputSize, outputSize, dtype);
      },
      // Context key generation function
      [self](const void* input, void* output, size_t inputSize, size_t outputSize, DataType dtype) {
        return self->generateAlltoallvContextKey(input, output, inputSize, outputSize, dtype);
      });

  return alltoallvAlgo;
}

void AlltoallvFullmesh::initialize(std::shared_ptr<Communicator> comm) {
  std::vector<std::shared_future<Connection>> connectionFutures;
  worldSize_ = comm->bootstrap()->getNranks();
  int rank = comm->bootstrap()->getRank();

  for (int i = 0; i < worldSize_; i++) {
    if (i == rank) continue;
    connectionFutures.push_back(comm->connect(Transport::CudaIpc, i));
  }

  std::vector<Connection> connections;
  std::transform(connectionFutures.begin(), connectionFutures.end(), std::back_inserter(connections),
                 [](const auto& future) { return future.get(); });
  this->conns_ = std::move(connections);

  proxyService_ = std::make_shared<ProxyService>();
  proxyService_->startProxy(true);
}

CommResult AlltoallvFullmesh::alltoallvKernelFunc(
    const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
    size_t outputSize, [[maybe_unused]] DataType dtype, cudaStream_t stream,
    [[maybe_unused]] int nBlocks, [[maybe_unused]] int nThreadsPerBlock,
    const std::unordered_map<std::string, uintptr_t>& extras) {
  auto algoCtx = std::static_pointer_cast<AllToAllVContext>(ctx);
  int rank = algoCtx->rank;
  int worldSize = algoCtx->worldSize;

  // Extract send/recv counts and displacements from extras
  auto it_sendCounts = extras.find("sendCounts");
  auto it_sendDispls = extras.find("sendDispls");
  auto it_recvCounts = extras.find("recvCounts");
  auto it_recvDispls = extras.find("recvDispls");

  if (it_sendCounts == extras.end() || it_sendDispls == extras.end() ||
      it_recvCounts == extras.end() || it_recvDispls == extras.end()) {
    return CommResult::CommInternalError;
  }

  const size_t* d_sendCounts = reinterpret_cast<const size_t*>(it_sendCounts->second);
  const size_t* d_sendDispls = reinterpret_cast<const size_t*>(it_sendDispls->second);
  const size_t* d_recvCounts = reinterpret_cast<const size_t*>(it_recvCounts->second);
  const size_t* d_recvDispls = reinterpret_cast<const size_t*>(it_recvDispls->second);

  // Choose kernel based on world size
  if (worldSize <= 16) {
    // Use parallel warp-based kernel for small world sizes
    int nThreads = (worldSize - 1) * ALLTOALLV_WARP_SIZE;
    if (nThreads < 32) nThreads = 32;
    if (nThreads > 1024) nThreads = 1024;

    alltoallvKernel<<<1, nThreads, 0, stream>>>(
        algoCtx->portChannelDeviceHandles.get(), rank, worldSize, input, output,
        d_sendCounts, d_sendDispls, d_recvCounts, d_recvDispls);
  } else {
    // Use ring-based kernel for larger world sizes
    alltoallvRingKernel<<<1, 32, 0, stream>>>(
        algoCtx->portChannelDeviceHandles.get(), rank, worldSize, input, output,
        d_sendCounts, d_sendDispls, d_recvCounts, d_recvDispls);
  }

  if (cudaGetLastError() == cudaSuccess) {
    return CommResult::CommSuccess;
  }
  return CommResult::CommInternalError;
}

std::shared_ptr<void> AlltoallvFullmesh::initAlltoallvContext(
    std::shared_ptr<Communicator> comm, const void* input, void* output, size_t inputSize,
    size_t outputSize, [[maybe_unused]] DataType dtype) {
  auto ctx = std::make_shared<AllToAllVContext>();
  ctx->rank = comm->bootstrap()->getRank();
  ctx->worldSize = comm->bootstrap()->getNranks();
  ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode();

  // Register memories for input and output buffers
  RegisteredMemory inputBufRegMem = comm->registerMemory((void*)input, inputSize, Transport::CudaIpc);
  RegisteredMemory outputBufRegMem = comm->registerMemory(output, outputSize, Transport::CudaIpc);

  // Exchange output buffer registration with all peers
  std::vector<std::shared_future<RegisteredMemory>> remoteRegMemories;
  for (int i = 0; i < ctx->worldSize; i++) {
    if (i == ctx->rank) continue;
    comm->sendMemory(outputBufRegMem, i, 0);
    remoteRegMemories.push_back(comm->recvMemory(i, 0));
  }

  // Setup port channels for each peer
  std::vector<DeviceHandle<PortChannel>> portChannels;
  MemoryId inputMemoryId = this->proxyService_->addMemory(inputBufRegMem);

  for (size_t i = 0; i < this->conns_.size(); i++) {
    auto remoteMemory = remoteRegMemories[i].get();
    MemoryId remoteMemoryId = this->proxyService_->addMemory(remoteMemory);
    portChannels.push_back(deviceHandle(this->proxyService_->portChannel(
        this->proxyService_->buildAndAddSemaphore(*comm, this->conns_[i]), remoteMemoryId, inputMemoryId)));
  }

  // Allocate and copy port channels to device
  ctx->portChannelDeviceHandles = detail::gpuCallocShared<DeviceHandle<PortChannel>>(portChannels.size());
  gpuMemcpy(ctx->portChannelDeviceHandles.get(), portChannels.data(), portChannels.size(),
            cudaMemcpyHostToDevice);

  // Keep registered memory references to prevent deallocation
  std::transform(remoteRegMemories.begin(), remoteRegMemories.end(),
                 std::back_inserter(ctx->registeredMemories),
                 [](const auto& fut) { return fut.get(); });
  ctx->registeredMemories.push_back(inputBufRegMem);
  ctx->registeredMemories.push_back(outputBufRegMem);

  return ctx;
}

AlgorithmCtxKey AlltoallvFullmesh::generateAlltoallvContextKey(
    const void* input, void* output, size_t inputSize, size_t outputSize,
    [[maybe_unused]] DataType dtype) {
  return {(void*)input, output, inputSize, outputSize, 0};
}

#undef ALLTOALLV_WARP_SIZE

}  // namespace collective
}  // namespace mscclpp
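alltoallvKernelFunc above looks the four device pointers up in extras by key. A minimal host-side sketch of building that map; the helper name is an assumption, but the keys are exactly the ones the function checks, and each pointer must be a device array of worldSize size_t values (counts and displacements in bytes):

#include <cstddef>
#include <cstdint>
#include <string>
#include <unordered_map>

// Hypothetical helper: pack device pointers under the exact keys that
// AlltoallvFullmesh::alltoallvKernelFunc extracts.
std::unordered_map<std::string, uintptr_t> makeAlltoallvExtras(const size_t* dSendCounts, const size_t* dSendDispls,
                                                               const size_t* dRecvCounts, const size_t* dRecvDispls) {
  return {{"sendCounts", reinterpret_cast<uintptr_t>(dSendCounts)},
          {"sendDispls", reinterpret_cast<uintptr_t>(dSendDispls)},
          {"recvCounts", reinterpret_cast<uintptr_t>(dRecvCounts)},
          {"recvDispls", reinterpret_cast<uintptr_t>(dRecvDispls)}};
}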
src/ext/collectives/include/alltoallv/alltoallv_fullmesh.hpp (new file, 55 lines)
@@ -0,0 +1,55 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

#pragma once

#include <mscclpp/algorithm.hpp>
#include <mscclpp/core.hpp>
#include <mscclpp/port_channel.hpp>

namespace mscclpp {
namespace collective {

/**
 * AllToAllV collective operation builder.
 *
 * This class builds an AllToAllV algorithm that handles variable element counts
 * per rank, similar to MPI_Alltoallv. Unlike NCCL's ncclGroupStart/ncclGroupEnd
 * approach, mscclpp uses explicit put/signal/wait operations on PortChannels.
 *
 * The implementation uses a ring-based exchange pattern to avoid deadlocks.
 *
 * Usage:
 *   auto builder = std::make_shared<AlltoallvFullmesh>();
 *   auto algorithm = builder->build();
 *   // Then execute with extras containing sendCounts, sendDispls, recvCounts, recvDispls
 */
class AlltoallvFullmesh : public AlgorithmBuilder {
 public:
  AlltoallvFullmesh() = default;
  ~AlltoallvFullmesh();

  std::shared_ptr<Algorithm> build() override;

 private:
  void initialize(std::shared_ptr<Communicator> comm);

  CommResult alltoallvKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output,
                                 size_t inputSize, size_t outputSize, DataType dtype, cudaStream_t stream,
                                 int nBlocks, int nThreadsPerBlock,
                                 const std::unordered_map<std::string, uintptr_t>& extras);

  std::shared_ptr<void> initAlltoallvContext(std::shared_ptr<Communicator> comm, const void* input,
                                             void* output, size_t inputSize, size_t outputSize,
                                             DataType dtype);

  AlgorithmCtxKey generateAlltoallvContextKey(const void* input, void* output, size_t inputSize,
                                              size_t outputSize, DataType dtype);

  std::vector<Connection> conns_;
  std::shared_ptr<ProxyService> proxyService_;
  int worldSize_;
};

}  // namespace collective
}  // namespace mscclpp
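The header leaves the counts/displacements layout to the caller. As with MPI_Alltoallv, a contiguous packing takes the displacements as the exclusive prefix sum of the counts; a small sketch under that assumption (the helper is hypothetical, units are bytes as the kernels expect):

#include <cstddef>
#include <vector>

// Hypothetical helper: displs[i] = counts[0] + ... + counts[i-1], i.e. the
// byte offset where the slice for (or from) rank i starts.
std::vector<size_t> displsFromCounts(const std::vector<size_t>& counts) {
  std::vector<size_t> displs(counts.size(), 0);
  for (size_t i = 1; i < counts.size(); i++) displs[i] = displs[i - 1] + counts[i - 1];
  return displs;
}

For counts {16, 0, 32, 8} this yields displacements {0, 16, 16, 48}.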
src/ext/collectives/include/alltoallv/alltoallv_kernel.hpp (new file, 153 lines)
@@ -0,0 +1,153 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

#pragma once

#include <mscclpp/port_channel_device.hpp>
#include <mscclpp/concurrency_device.hpp>
#include <mscclpp/copy_device.hpp>

namespace mscclpp {
namespace collective {

#if defined(__HIP_PLATFORM_AMD__)
#define ALLTOALLV_WARP_SIZE 64
#else
#define ALLTOALLV_WARP_SIZE 32
#endif

/**
 * AllToAllV kernel implementation using parallel warp-based communication.
 *
 * Each warp handles communication with one peer. All sends happen in parallel,
 * followed by flushes and waits.
 *
 * @param portChannels Array of PortChannel handles for each peer (worldSize-1 channels)
 * @param rank Current rank
 * @param worldSize Total number of ranks
 * @param sendBuff Source buffer containing data to send
 * @param recvBuff Destination buffer for received data
 * @param sendCounts Array of send counts for each rank (in bytes)
 * @param sendDispls Array of send displacements for each rank (in bytes)
 * @param recvCounts Array of receive counts for each rank (in bytes)
 * @param recvDispls Array of receive displacements for each rank (in bytes)
 */
__global__ void __launch_bounds__(1024)
    alltoallvKernel(DeviceHandle<PortChannel>* portChannels, int rank, int worldSize, const void* sendBuff,
                    void* recvBuff, const size_t* sendCounts, const size_t* sendDispls,
                    const size_t* recvCounts, const size_t* recvDispls) {
  int tid = threadIdx.x;
  int nPeers = worldSize - 1;

  // Step 1: Copy local data (rank's own portion)
  if (tid == 0 && sendCounts[rank] > 0) {
    const char* src = (const char*)sendBuff + sendDispls[rank];
    char* dst = (char*)recvBuff + recvDispls[rank];
    memcpy(dst, src, sendCounts[rank]);
  }
  __syncthreads();

  // Step 2: Each warp handles one peer for sending
  int warpId = tid / ALLTOALLV_WARP_SIZE;
  int laneId = tid % ALLTOALLV_WARP_SIZE;

  if (warpId < nPeers && laneId == 0) {
    // Determine which peer this warp handles
    int peer = warpId < rank ? warpId : warpId + 1;
    int chanIdx = warpId;

    if (sendCounts[peer] > 0) {
      portChannels[chanIdx].putWithSignal(
          recvDispls[rank],  // dst offset in peer's buffer
          sendDispls[peer],  // src offset in our buffer
          sendCounts[peer]   // size
      );
    }
  }
  __syncthreads();

  // Step 3: Flush all pending operations
  if (warpId < nPeers && laneId == 0) {
    int peer = warpId < rank ? warpId : warpId + 1;
    if (sendCounts[peer] > 0) {
      portChannels[warpId].flush();
    }
  }
  __syncthreads();

  // Step 4: Wait for all incoming data
  if (warpId < nPeers && laneId == 0) {
    int peer = warpId < rank ? warpId : warpId + 1;
    if (recvCounts[peer] > 0) {
      portChannels[warpId].wait();
    }
  }
  __syncthreads();
}

/**
 * Ring-based AllToAllV kernel for serialized communication.
 *
 * Uses step-by-step ring pattern to exchange data, sending to (rank+step) and
 * receiving from (rank-step) in each step. Single thread handles all communication
 * to avoid race conditions.
 *
 * This kernel is more robust but slower than the parallel version.
 */
__global__ void __launch_bounds__(1024)
    alltoallvRingKernel(DeviceHandle<PortChannel>* portChannels, int rank, int worldSize, const void* sendBuff,
                        void* recvBuff, const size_t* sendCounts, const size_t* sendDispls,
                        const size_t* recvCounts, const size_t* recvDispls) {
  // Copy local data first
  if (threadIdx.x == 0) {
    if (sendCounts[rank] > 0) {
      const char* src = (const char*)sendBuff + sendDispls[rank];
      char* dst = (char*)recvBuff + recvDispls[rank];
      memcpy(dst, src, sendCounts[rank]);
    }
  }
  __syncthreads();

  // Ring-based exchange - single thread handles communication
  if (threadIdx.x == 0) {
    for (int step = 1; step < worldSize; step++) {
      int sendPeer = (rank + step) % worldSize;
      int recvPeer = (rank - step + worldSize) % worldSize;

      int sendChanIdx = sendPeer < rank ? sendPeer : sendPeer - 1;
      int recvChanIdx = recvPeer < rank ? recvPeer : recvPeer - 1;

      // Send data to sendPeer
      if (sendCounts[sendPeer] > 0) {
        portChannels[sendChanIdx].putWithSignal(
            recvDispls[rank],      // dst offset in peer's recv buffer
            sendDispls[sendPeer],  // src offset in our send buffer
            sendCounts[sendPeer]   // size in bytes
        );
        portChannels[sendChanIdx].flush();
      }

      // Wait for data from recvPeer
      if (recvCounts[recvPeer] > 0) {
        portChannels[recvChanIdx].wait();
      }
    }
  }
}

#undef ALLTOALLV_WARP_SIZE

}  // namespace collective
}  // namespace mscclpp
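The deadlock-freedom claim for alltoallvRingKernel rests on every rank following the same step schedule: in step s, the rank you wait on is exactly the rank whose put targets you, since (r - s) + s = r (mod worldSize). A standalone sketch that prints the schedule, illustrative only and not part of the commit:

#include <cstdio>

// Prints the per-step pairing used by alltoallvRingKernel: in step s,
// rank r puts to (r + s) % n and waits on (r - s + n) % n.
int main() {
  const int n = 4;
  for (int step = 1; step < n; step++)
    for (int rank = 0; rank < n; rank++)
      printf("step %d: rank %d sends to %d, waits on %d\n", step, rank, (rank + step) % n,
             (rank - step + n) % n);
  return 0;
}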
@@ -4,13 +4,16 @@
FetchContent_Declare(json URL https://github.com/nlohmann/json/releases/download/v3.11.2/json.tar.xz)
FetchContent_MakeAvailable(json)

+# Include path for collective algorithm headers (alltoallv, etc.)
+set(COLLECTIVES_INC_DIR ${PROJECT_SOURCE_DIR}/src/ext/collectives/include)
+
function(add_mscclpp_test_executable name sources)
  if(MSCCLPP_USE_ROCM)
    set_source_files_properties(${sources} PROPERTIES LANGUAGE CXX)
  endif()
  add_executable(${name} ${sources} common.cc)
  target_link_libraries(${name} ${TEST_LIBS_COMMON} MPI::MPI_CXX nlohmann_json::nlohmann_json)
-  target_include_directories(${name} ${TEST_INC_COMMON} ${TEST_INC_INTERNAL})
+  target_include_directories(${name} ${TEST_INC_COMMON} ${TEST_INC_INTERNAL} PRIVATE ${COLLECTIVES_INC_DIR})
  set_target_properties(${name} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin/mscclpp-test")
endfunction()
@@ -3,6 +3,7 @@
// AllToAllV test - tests variable-length alltoall operations
// This test validates the alltoallv kernel that handles variable element counts per rank.
+// Uses the kernel implementations from src/ext/collectives/include/alltoallv/

#include <cstdint>
#include <cstring>

@@ -11,15 +12,11 @@
#include "common.hpp"

-#if defined(__HIP_PLATFORM_AMD__)
-#define WARP_SIZE 64
-#else
-#define WARP_SIZE 32
-#endif
+// Include the alltoallv kernel implementations from src/ext/collectives
+#include "alltoallv/alltoallv_kernel.hpp"

template <class T>
using DeviceHandle = mscclpp::DeviceHandle<T>;
__constant__ DeviceHandle<mscclpp::PortChannel> constPortChansV[16];
__device__ mscclpp::DeviceSyncer deviceSyncerV;

static void* localRecvBuffV;

@@ -31,117 +28,8 @@ static size_t* d_sendDispls;
static size_t* d_recvCounts;
static size_t* d_recvDispls;

/**
 * AllToAllV kernel implementation
 *
 * Each rank sends sendCounts[i] bytes to rank i at sendDispls[i] offset,
 * and receives recvCounts[i] bytes from rank i at recvDispls[i] offset.
 *
 * Uses ring-based exchange pattern to avoid deadlocks.
 */
__global__ void __launch_bounds__(1024)
    alltoallv0(int rank, int worldSize, const void* sendBuff, void* recvBuff, const size_t* sendCounts,
               const size_t* sendDispls, const size_t* recvCounts, const size_t* recvDispls) {
  int tid = threadIdx.x;
  int nPeers = worldSize - 1;

  // Step 1: Copy local data (rank's own portion)
  if (tid == 0 && sendCounts[rank] > 0) {
    const char* src = (const char*)sendBuff + sendDispls[rank];
    char* dst = (char*)recvBuff + recvDispls[rank];
    memcpy(dst, src, sendCounts[rank]);
  }
  __syncthreads();

  // Step 2: Each warp handles one peer for sending
  int warpId = tid / WARP_SIZE;
  int laneId = tid % WARP_SIZE;

  if (warpId < nPeers && laneId == 0) {
    // Determine which peer this warp handles
    int peer = warpId < rank ? warpId : warpId + 1;
    int chanIdx = warpId;

    if (sendCounts[peer] > 0) {
      constPortChansV[chanIdx].putWithSignal(
          recvDispls[rank],  // dst offset in peer's buffer
          sendDispls[peer],  // src offset in our buffer
          sendCounts[peer]   // size
      );
    }
  }
  __syncthreads();

  // Step 3: Flush all pending operations
  if (warpId < nPeers && laneId == 0) {
    int peer = warpId < rank ? warpId : warpId + 1;
    if (sendCounts[peer] > 0) {
      constPortChansV[warpId].flush();
    }
  }
  __syncthreads();

  // Step 4: Wait for all incoming data
  if (warpId < nPeers && laneId == 0) {
    int peer = warpId < rank ? warpId : warpId + 1;
    if (recvCounts[peer] > 0) {
      constPortChansV[warpId].wait();
    }
  }
  __syncthreads();
}

/**
 * Ring-based AllToAllV kernel for larger world sizes
 *
 * Uses step-by-step ring pattern to exchange data, sending to (rank+step) and
 * receiving from (rank-step) in each step. Single block to avoid concurrent
 * access to the same port channels.
 */
__global__ void __launch_bounds__(1024)
    alltoallv1(int rank, int worldSize, const void* sendBuff, void* recvBuff, const size_t* sendCounts,
               const size_t* sendDispls, const size_t* recvCounts, const size_t* recvDispls) {
  // Copy local data first
  if (threadIdx.x == 0) {
    if (sendCounts[rank] > 0) {
      const char* src = (const char*)sendBuff + sendDispls[rank];
      char* dst = (char*)recvBuff + recvDispls[rank];
      memcpy(dst, src, sendCounts[rank]);
    }
  }
  __syncthreads();

  // Ring-based exchange - single thread handles the communication
  // to avoid race conditions on port channels
  if (threadIdx.x == 0) {
    for (int step = 1; step < worldSize; step++) {
      int sendPeer = (rank + step) % worldSize;
      int recvPeer = (rank - step + worldSize) % worldSize;

      int sendChanIdx = sendPeer < rank ? sendPeer : sendPeer - 1;
      int recvChanIdx = recvPeer < rank ? recvPeer : recvPeer - 1;

      // Send data to sendPeer (non-blocking put with signal)
      if (sendCounts[sendPeer] > 0) {
        constPortChansV[sendChanIdx].putWithSignal(
            recvDispls[rank],      // dst offset in peer's buffer
            sendDispls[sendPeer],  // src offset in our buffer
            sendCounts[sendPeer]   // size
        );
        constPortChansV[sendChanIdx].flush();
      }

      // Wait for data from recvPeer
      if (recvCounts[recvPeer] > 0) {
        constPortChansV[recvChanIdx].wait();
      }
    }
  }
}
+// Device array for port channels (used by library kernels)
+static DeviceHandle<mscclpp::PortChannel>* d_portChannels;

class AllToAllVTestColl : public BaseTestColl {
 public:

@@ -174,17 +62,23 @@ void AllToAllVTestColl::runColl(const TestArgs& args, cudaStream_t stream) {
  CUDATHROW(cudaMemcpyToSymbol(deviceSyncerV, &syncer, sizeof(mscclpp::DeviceSyncer)));

  if (kernelNum == 0) {
-    int nThreads = (worldSize - 1) * WARP_SIZE;
+    // Use parallel warp-based kernel from library
+    int nThreads = (worldSize - 1) * 32;  // One warp per peer
+#if defined(__HIP_PLATFORM_AMD__)
+    nThreads = (worldSize - 1) * 64;
+#endif
    if (nThreads < 32) nThreads = 32;
    if (nThreads > 1024) nThreads = 1024;
-    alltoallv0<<<1, nThreads, 0, stream>>>(
+    mscclpp::collective::alltoallvKernel<<<1, nThreads, 0, stream>>>(
+        d_portChannels,
        rank, worldSize,
        localSendBuffV, localRecvBuffV,
        d_sendCounts, d_sendDispls,
        d_recvCounts, d_recvDispls);
  } else if (kernelNum == 1) {
-    // Single block, single thread for ring-based serialized communication
-    alltoallv1<<<1, 32, 0, stream>>>(
+    // Use ring-based kernel from library
+    mscclpp::collective::alltoallvRingKernel<<<1, 32, 0, stream>>>(
+        d_portChannels,
        rank, worldSize,
        localSendBuffV, localRecvBuffV,
        d_sendCounts, d_sendDispls,

@@ -275,8 +169,8 @@ void AllToAllVTestColl::setupCollTest(size_t size) {

std::vector<KernelRestriction> AllToAllVTestColl::getKernelRestrictions() {
  return {
-      {0, "alltoallv0", true, 1, 4 * worldSize_},
-      {1, "alltoallv1", true, 1, 4 * worldSize_}
+      {0, "alltoallvKernel", true, 1, 4 * worldSize_},
+      {1, "alltoallvRingKernel", true, 1, 4 * worldSize_}
  };
}

@@ -318,17 +212,19 @@ void AllToAllVTestEngine::allocateBuffer() {
  CUDATHROW(cudaMalloc(&d_sendDispls, args_.totalRanks * sizeof(size_t)));
  CUDATHROW(cudaMalloc(&d_recvCounts, args_.totalRanks * sizeof(size_t)));
  CUDATHROW(cudaMalloc(&d_recvDispls, args_.totalRanks * sizeof(size_t)));

+  // Allocate device array for port channels
+  CUDATHROW(cudaMalloc(&d_portChannels, args_.totalRanks * sizeof(DeviceHandle<mscclpp::PortChannel>)));
}

void AllToAllVTestEngine::setupConnections() {
  std::vector<DeviceHandle<mscclpp::PortChannel>> portChannels;
  setupMeshConnections(portChannels, sendBuff_.get(), args_.maxBytes, recvBuff_.get(), args_.maxBytes);

-  if (portChannels.size() > sizeof(constPortChansV) / sizeof(DeviceHandle<mscclpp::PortChannel>)) {
-    throw std::runtime_error("Too many port channels for alltoallv test");
-  }
-  CUDATHROW(cudaMemcpyToSymbol(constPortChansV, portChannels.data(),
-                               sizeof(DeviceHandle<mscclpp::PortChannel>) * portChannels.size()));
+  // Copy port channels to device memory for use by library kernels
+  CUDATHROW(cudaMemcpy(d_portChannels, portChannels.data(),
+                       sizeof(DeviceHandle<mscclpp::PortChannel>) * portChannels.size(),
+                       cudaMemcpyHostToDevice));
}

std::vector<void*> AllToAllVTestEngine::getSendBuff() { return {sendBuff_.get()}; }
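The test reads counts and displacements from the d_* device arrays allocated above. A hedged sketch of the staging step, assuming host-side std::vector<size_t> arrays of totalRanks entries (the helper name is an assumption, not part of the commit):

#include <cuda_runtime.h>

#include <cstddef>
#include <vector>

// Hypothetical staging helper: copy one host-side size_t array into its
// pre-allocated device counterpart (e.g. d_sendCounts or d_recvDispls).
void stageSizeArray(const std::vector<size_t>& host, size_t* device) {
  cudaMemcpy(device, host.data(), host.size() * sizeof(size_t), cudaMemcpyHostToDevice);
}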