New DSL implementation (#579)

The PR contains following changes:
Python side:
- Channel based DSL implementation: decouple channel with chunk.
- Users create channel explicitly, only need local_rank, remote_rank and
channel_type
- Adjust executor json file, add remote_buffer fields, different op can
use different channel and remote buffers combination.
- Reimplement operation fusion, data dependency check mechanism
- Add new op such as semaphore, pipeline 
- Clean code and enhance document
C++ side: 
- Support new execution file json format
- Support semaphore and pipeline operation
- code clean, support non-zero copy scenario

---------

Co-authored-by: Caio Rocha <caiorocha@microsoft.com>
Co-authored-by: Changho Hwang <changhohwang@microsoft.com>
This commit is contained in:
Binyang Li
2025-08-09 00:36:20 -07:00
committed by GitHub
parent 1cc1b827f4
commit be6a941fba
109 changed files with 10136 additions and 7182 deletions

View File

@@ -12,6 +12,7 @@
#include "execution_plan.hpp"
namespace mscclpp {
struct ExecutionContextKey {
void* sendBuff;
void* recvBuff;
@@ -25,14 +26,15 @@ struct ExecutionContextKey {
}
};
void* getBuffer(BufferType type, void* sendbuff, void* recvbuff, void* scratch) {
std::pair<void*, size_t> getBufferInfo(BufferType type, void* sendbuff, void* recvbuff, void* scratch,
size_t sendBuffSize, size_t recvBuffSize, size_t scratchBuffSize) {
switch (type) {
case BufferType::INPUT:
return sendbuff;
return std::make_pair(sendbuff, sendBuffSize);
case BufferType::OUTPUT:
return recvbuff;
return std::make_pair(recvbuff, recvBuffSize);
case BufferType::SCRATCH:
return scratch;
return std::make_pair(scratch, scratchBuffSize);
default:
throw Error("Invalid buffer type", ErrorCode::ExecutorError);
}
@@ -113,18 +115,30 @@ struct ExecutionContext {
std::shared_ptr<ProxyService> proxyService;
std::unordered_map<int, std::shared_ptr<Connection>> connections;
std::vector<std::shared_ptr<NvlsConnection>> nvlsConnections;
std::unordered_map<std::pair<BufferType, int>, mscclpp::RegisteredMemory> registeredMemories;
// For registered memories, registeredMemoryAddresses is used for memoryChannel and registeredMemoryIds is used for
// proxy channel
std::vector<mscclpp::RegisteredMemory> registeredMemories;
std::vector<void*> registeredMemoryAddresses;
std::vector<mscclpp::MemoryId> registeredMemoryIds;
// local registered memories to keep resources alive
std::vector<mscclpp::RegisteredMemory> localRegisteredMemories;
std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> memorySemaphores;
std::vector<mscclpp::SemaphoreId> proxySemaphores;
std::vector<mscclpp::MemoryChannel> memoryChannels;
std::vector<mscclpp::PortChannel> portChannels;
std::vector<mscclpp::BaseMemoryChannel> memoryChannels;
std::vector<mscclpp::BasePortChannel> portChannels;
std::vector<mscclpp::SwitchChannel> nvlsChannels;
std::unordered_map<DeviceExecutionPlanKey, std::vector<DeviceExecutionPlan>> deviceExecutionPlans;
std::unordered_map<DeviceExecutionPlanKey, std::shared_ptr<char>> deviceExecutionPlansBuffers;
std::shared_ptr<char> scratchBuffer;
std::shared_ptr<char> smemaphores;
size_t scratchBufferSize;
uint32_t scratchChunkSize;
int nthreadsPerBlock;
DeviceExecutionPlanKey currentDevicePlan;
bool reuseResources;
bool doubleScratchBuff;
};
struct Executor::Impl {
@@ -144,6 +158,11 @@ struct Executor::Impl {
size_t sendMemRange, size_t recvMemRange, const ExecutionPlan& plan) {
ExecutionContextKey key = {sendbuff, recvbuff, sendMemRange, recvMemRange, plan.impl_->name};
DeviceExecutionPlanKey devicePlanKey = {inputMessageSize, outputMessageSize, constSrcOffset, constDstOffset};
// The plan is not related to any specific input/output message size or memory address
if (plan.impl_->reuseResources) {
key = {nullptr, nullptr, 0, 0, plan.impl_->name};
}
if (this->contexts.find(key) != this->contexts.end()) {
auto& devicePlans = this->contexts[key].deviceExecutionPlans;
if (this->contexts[key].currentDevicePlan == devicePlanKey) {
@@ -154,7 +173,7 @@ struct Executor::Impl {
}
plan.impl_->operationsReset();
plan.impl_->lightLoadExecutionPlan(inputMessageSize, outputMessageSize, constSrcOffset, constDstOffset);
this->setupDeviceExecutionPlan(this->contexts[key], devicePlanKey, rank, plan);
this->setupDeviceExecutionPlan(this->contexts[key], devicePlanKey, plan);
this->contexts[key].deviceExecutionPlansBuffers[devicePlanKey] =
GpuBuffer(devicePlans[devicePlanKey].size() * sizeof(DeviceExecutionPlan)).memory();
gpuMemcpy(this->contexts[key].deviceExecutionPlansBuffers[devicePlanKey].get(),
@@ -168,19 +187,21 @@ struct Executor::Impl {
plan.impl_->loadExecutionPlan(inputMessageSize, outputMessageSize, constSrcOffset, constDstOffset);
ExecutionContext context;
size_t maxScratchBufferSize = plan.impl_->getMaxScratchBufferSize(rank);
size_t scratchBufferSize =
std::min(plan.impl_->getScratchBufferSize(rank, sendMemRange, recvMemRange), maxScratchBufferSize);
std::shared_ptr<char> scratchBuffer = GpuBuffer(scratchBufferSize).memory();
context.scratchBuffer = scratchBuffer;
context.reuseResources = plan.impl_->reuseResources;
context.doubleScratchBuff = plan.impl_->doubleScratchBuffer;
size_t scratchBufferSize = plan.impl_->calScratchBufferSize(std::min(sendMemRange, plan.impl_->maxMessageSize),
std::min(recvMemRange, plan.impl_->maxMessageSize));
context.scratchChunkSize = plan.impl_->calMaxScratchChunkSize(scratchBufferSize);
context.scratchBuffer = GpuBuffer(scratchBufferSize).memory();
context.scratchBufferSize = scratchBufferSize;
context.proxyService = std::make_shared<ProxyService>();
context.nthreadsPerBlock = plan.impl_->getNThreadsPerBlock();
this->setupConnections(context, rank, plan, sendMemRange, recvMemRange);
context.nthreadsPerBlock = plan.impl_->nThreadsPerBlock;
this->setupConnections(context, rank, sendMemRange, recvMemRange, scratchBufferSize, plan);
this->setupChannels(context, plan);
this->setupRegisteredMemories(context, sendbuff, recvbuff, sendMemRange, recvMemRange, rank, plan);
this->setupChannels(context, sendbuff, recvbuff, sendMemRange, recvMemRange, rank, plan);
this->setupNvlsChannels(context, sendbuff, recvbuff, sendMemRange, recvMemRange, rank, plan);
this->setupDeviceExecutionPlan(context, devicePlanKey, rank, plan);
this->setupNvlsChannels(context, sendbuff, recvbuff, rank, sendMemRange, recvMemRange, scratchBufferSize, plan);
this->setupSemaphores(context, plan);
this->setupDeviceExecutionPlan(context, devicePlanKey, plan);
context.deviceExecutionPlansBuffers[devicePlanKey] =
GpuBuffer(context.deviceExecutionPlans[devicePlanKey].size() * sizeof(DeviceExecutionPlan)).memory();
gpuMemcpy(context.deviceExecutionPlansBuffers[devicePlanKey].get(),
@@ -192,26 +213,37 @@ struct Executor::Impl {
return context;
}
TransportFlags getTransportFlags(std::vector<ChannelInfo>& infos, int rank) {
TransportFlags getTransportFlags(const BufferInfo& info, int rank) {
TransportFlags flags;
for (ChannelInfo& info : infos) {
if (info.channelType == ChannelType::MEMORY) {
for (const ChannelType& type : info.accessChannelTypes) {
if (type == ChannelType::MEMORY) {
flags |= Transport::CudaIpc;
} else if (info.channelType == ChannelType::PORT) {
for (int peer : info.connectedPeers) {
if (!inSameNode(rank, peer, this->nranksPerNode)) {
flags |= IBs[rank % this->nranksPerNode];
} else
flags |= Transport::CudaIpc;
}
} else if (type == ChannelType::PORT) {
if (!inSameNode(rank, info.accessRank, this->nranksPerNode)) {
flags |= IBs[rank % this->nranksPerNode];
} else
flags |= Transport::CudaIpc;
}
}
return flags;
};
void setupConnections(ExecutionContext& context, int rank, const ExecutionPlan& plan, size_t sendBufferSize,
size_t recvBufferSize) {
std::vector<int> connectedPeers = plan.impl_->getConnectedPeers(rank);
void setupConnections(ExecutionContext& context, int rank, size_t sendBuffSize, size_t recvBuffSize,
size_t scratchBuffSize, const ExecutionPlan& plan) {
auto getBufferSize = [&](BufferType bufferType) {
switch (bufferType) {
case BufferType::INPUT:
return sendBuffSize;
case BufferType::OUTPUT:
return recvBuffSize;
case BufferType::SCRATCH:
return scratchBuffSize;
default:
throw Error("Invalid buffer type", ErrorCode::ExecutorError);
}
};
std::vector<int> connectedPeers = plan.impl_->getConnectedPeers();
std::vector<std::shared_future<std::shared_ptr<mscclpp::Connection>>> connectionFutures;
for (int peer : connectedPeers) {
Transport transport =
@@ -222,63 +254,53 @@ struct Executor::Impl {
context.connections[connectedPeers[i]] = connectionFutures[i].get();
}
std::vector<NvlsInfo> nvlsInfos = plan.impl_->getNvlsInfos(rank, sendBufferSize, recvBufferSize);
std::vector<NvlsInfo> nvlsInfos = plan.impl_->nvlsInfos.at(rank);
for (const NvlsInfo& info : nvlsInfos) {
std::shared_ptr<NvlsConnection> nvlsConnection =
mscclpp::connectNvlsCollective(this->comm, info.ranks, info.bufferSize);
mscclpp::connectNvlsCollective(this->comm, info.ranks, getBufferSize(info.bufferType));
context.nvlsConnections.push_back(nvlsConnection);
}
}
void setupRegisteredMemories(ExecutionContext& context, void* sendbuff, void* recvbuff, size_t sendBufferSize,
size_t recvBufferSize, int rank, const ExecutionPlan& plan) {
auto getBufferInfo = [&](BufferType type) {
switch (type) {
case BufferType::INPUT:
return std::make_pair(sendbuff, sendBufferSize);
case BufferType::OUTPUT:
return std::make_pair(recvbuff, recvBufferSize);
case BufferType::SCRATCH:
return std::make_pair((void*)context.scratchBuffer.get(), context.scratchBufferSize);
default:
throw Error("Invalid buffer type", ErrorCode::ExecutorError);
// Add local src,dst and scratch to registeredMemoryIds
for (auto& bufferType : {BufferType::INPUT, BufferType::OUTPUT, BufferType::SCRATCH}) {
TransportFlags flags = Transport::CudaIpc;
#if defined(USE_IBVERBS)
flags |= IBs[rank % this->nranksPerNode];
#endif
RegisteredMemory localMemory;
auto bufferInfo = getBufferInfo(bufferType, sendbuff, recvbuff, context.scratchBuffer.get(), sendBufferSize,
recvBufferSize, context.scratchBufferSize);
if (bufferInfo.second > 0) {
localMemory = this->comm->registerMemory(bufferInfo.first, bufferInfo.second, flags);
}
};
auto getConnectedPeers = [&](std::vector<ChannelInfo>& infos) {
std::set<int> peers;
for (ChannelInfo& info : infos) {
for (int peer : info.connectedPeers) {
peers.insert(peer);
context.proxyService->addMemory(localMemory);
}
for (const auto& buffer : plan.impl_->getLocalBufferToSend()) {
auto bufferInfo = getBufferInfo(buffer.bufferType, sendbuff, recvbuff, context.scratchBuffer.get(),
sendBufferSize, recvBufferSize, context.scratchBufferSize);
RegisteredMemory memory =
this->comm->registerMemory(bufferInfo.first, bufferInfo.second, getTransportFlags(buffer, rank));
comm->sendMemory(memory, buffer.accessRank);
context.localRegisteredMemories.emplace_back(std::move(memory));
}
for (const auto& bufferInfo : plan.impl_->getRemoteBufferInfos()) {
std::shared_future<RegisteredMemory> remoteRegMemoryFuture = comm->recvMemory(bufferInfo.rank);
context.registeredMemories.emplace_back(std::move(remoteRegMemoryFuture.get()));
for (ChannelType chanType : bufferInfo.accessChannelTypes) {
if (chanType == ChannelType::MEMORY) {
context.registeredMemoryAddresses.push_back(context.registeredMemories.back().data());
} else if (chanType == ChannelType::PORT) {
context.registeredMemoryIds.push_back(context.proxyService->addMemory(context.registeredMemories.back()));
}
}
return std::vector<int>(peers.begin(), peers.end());
};
std::vector<BufferType> bufferTypes = plan.impl_->getConnectedBufferTypes(rank);
for (BufferType bufferType : bufferTypes) {
std::vector<ChannelInfo> channelInfos = plan.impl_->getChannelInfosByDstRank(rank, bufferType);
TransportFlags transportFlags = getTransportFlags(channelInfos, rank);
RegisteredMemory memory =
this->comm->registerMemory(getBufferInfo(bufferType).first, getBufferInfo(bufferType).second, transportFlags);
std::vector<int> connectedPeers = getConnectedPeers(channelInfos);
std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteRegMemoryFutures;
for (int peer : connectedPeers) {
comm->sendMemory(memory, peer);
}
channelInfos = plan.impl_->getChannelInfos(rank, bufferType);
connectedPeers = getConnectedPeers(channelInfos);
for (int peer : connectedPeers) {
remoteRegMemoryFutures.push_back(comm->recvMemory(peer));
}
for (size_t i = 0; i < remoteRegMemoryFutures.size(); i++) {
context.registeredMemories[{bufferType, connectedPeers[i]}] = std::move(remoteRegMemoryFutures[i].get());
}
context.registeredMemories[{bufferType, rank}] = std::move(memory);
}
}
void setupChannels(ExecutionContext& context, void* sendbuff, void* recvbuff, size_t sendBufferSize,
size_t recvBufferSize, int rank, const ExecutionPlan& plan) {
void setupChannels(ExecutionContext& context, const ExecutionPlan& plan) {
const auto channelTypes = {ChannelType::MEMORY, ChannelType::PORT};
std::vector<std::shared_ptr<MemoryDevice2DeviceSemaphore>> memorySemaphores;
std::vector<mscclpp::SemaphoreId> proxySemaphores;
@@ -296,89 +318,91 @@ struct Executor::Impl {
}
};
for (ChannelType channelType : channelTypes) {
std::vector<ChannelInfo> channelInfos = plan.impl_->getChannelInfos(rank, channelType);
std::vector<ChannelInfo> channelInfos = plan.impl_->getChannelInfos(channelType);
processChannelInfos(channelInfos);
// Current semaphore construction requires two-way communication, e.g., to construct a semaphore signaling from
// rank 0 to rank 1, both rank 0 and rank 1 need to send a message to each other. This PR fixes an executor bug
// that fails to conduct two-way communication for constructing such one-way semaphores, and instead hangs
// during the semaphore construction. In the future, we may need to change the implementation to construct
// semaphore via one-way communication.
channelInfos = plan.impl_->getUnpairedChannelInfos(rank, nranks, channelType);
// during the semaphore construction.
channelInfos = plan.impl_->getUnpairedChannelInfos(nranks, channelType);
processChannelInfos(channelInfos);
}
context.memorySemaphores = std::move(memorySemaphores);
context.proxySemaphores = std::move(proxySemaphores);
auto getBufferSize = [&](BufferType type) {
switch (type) {
case BufferType::INPUT:
return sendBufferSize;
case BufferType::OUTPUT:
return recvBufferSize;
case BufferType::SCRATCH:
return context.scratchBufferSize;
default:
throw Error("Invalid buffer type", ErrorCode::ExecutorError);
}
};
for (ChannelType channelType : channelTypes) {
std::vector<ChannelInfo> channelInfos = plan.impl_->getChannelInfos(rank, channelType);
std::vector<ChannelInfo> channelInfos = plan.impl_->getChannelInfos(channelType);
int index = 0;
for (ChannelInfo& info : channelInfos) {
void* src = getBuffer(info.srcBufferType, sendbuff, recvbuff, context.scratchBuffer.get());
size_t bufferSize = getBufferSize(info.srcBufferType);
TransportFlags transport = getTransportFlags(channelInfos, rank);
RegisteredMemory localMemory = this->comm->registerMemory(src, bufferSize, transport);
for (int peer : info.connectedPeers) {
for (size_t i = 0; i < info.connectedPeers.size(); i++) {
if (channelType == ChannelType::MEMORY) {
context.memoryChannels.emplace_back(context.memorySemaphores[index++],
context.registeredMemories[{info.dstBufferType, peer}], localMemory,
nullptr);
context.memoryChannels.emplace_back(context.memorySemaphores[index++]);
} else if (channelType == ChannelType::PORT) {
context.portChannels.emplace_back(context.proxyService->portChannel(
context.proxySemaphores[index++],
context.proxyService->addMemory(context.registeredMemories[{info.dstBufferType, peer}]),
context.proxyService->addMemory(localMemory)));
context.portChannels.emplace_back(context.proxyService->basePortChannel(context.proxySemaphores[index++]));
}
}
}
}
}
void setupNvlsChannels(ExecutionContext& context, void* sendbuff, void* recvbuff, size_t sendBufferSize,
size_t recvBufferSize, int rank, const ExecutionPlan& plan) {
std::vector<NvlsInfo> nvlsInfos = plan.impl_->getNvlsInfos(rank, sendBufferSize, recvBufferSize);
void setupNvlsChannels(ExecutionContext& context, void* sendbuff, void* recvbuff, int rank, size_t sendBuffSize,
size_t recvBuffSize, size_t scratchBuffSize, const ExecutionPlan& plan) {
std::vector<NvlsInfo> nvlsInfos = plan.impl_->nvlsInfos.at(rank);
for (size_t i = 0; i < nvlsInfos.size(); i++) {
std::shared_ptr<NvlsConnection> nvlsConnection = context.nvlsConnections[i];
NvlsInfo info = nvlsInfos[i];
void* buffer = getBuffer(info.bufferType, sendbuff, recvbuff, context.scratchBuffer.get());
SwitchChannel switchChannel = nvlsConnection->bindAllocatedMemory((CUdeviceptr)buffer, info.bufferSize);
auto bufferInfo = getBufferInfo(info.bufferType, sendbuff, recvbuff, context.scratchBuffer.get(), sendBuffSize,
recvBuffSize, scratchBuffSize);
SwitchChannel switchChannel =
nvlsConnection->bindAllocatedMemory((CUdeviceptr)bufferInfo.first, bufferInfo.second);
context.nvlsChannels.push_back(switchChannel);
}
}
void setupDeviceExecutionPlan(ExecutionContext& context, const DeviceExecutionPlanKey& key, int rank,
void setupSemaphores(ExecutionContext& context, const ExecutionPlan& plan) {
std::vector<DeviceSemaphore> semaphores;
for (const SemaphoreInfo& info : plan.impl_->semaphoreInfos) {
DeviceSemaphore semaphore(info.initValue);
semaphores.push_back(semaphore);
}
context.smemaphores = GpuBuffer(semaphores.size() * sizeof(DeviceSemaphore)).memory();
gpuMemcpy(context.smemaphores.get(), (char*)semaphores.data(), semaphores.size() * sizeof(DeviceSemaphore),
cudaMemcpyHostToDevice);
}
void setupDeviceExecutionPlan(ExecutionContext& context, const DeviceExecutionPlanKey& key,
const ExecutionPlan& plan) {
std::vector<DeviceExecutionPlan> deviceExecutionPlans;
for (int threadblock = 0; threadblock < plan.impl_->getThreadblockCount(rank); threadblock++) {
for (int threadblock = 0; threadblock < plan.impl_->getThreadblockCount(); threadblock++) {
DeviceExecutionPlan deviceExecutionPlan = {};
std::vector<Operation> ops = plan.impl_->getOperations(rank, threadblock);
std::vector<Operation> ops = plan.impl_->getOperations(threadblock);
deviceExecutionPlan.nOperations = ops.size();
deviceExecutionPlan.nMemoryChannels = plan.impl_->threadblockMemoryChannelMap.at(rank).at(threadblock).size();
deviceExecutionPlan.nPortChannels = plan.impl_->threadblockPortChannelMap.at(rank).at(threadblock).size();
deviceExecutionPlan.nMemoryChannels = plan.impl_->threadblockMemoryChannels.at(threadblock).size();
deviceExecutionPlan.nPortChannels = plan.impl_->threadblockPortChannels.at(threadblock).size();
int chanIndex = 0;
for (const auto& [index, _] : plan.impl_->threadblockMemoryChannelMap.at(rank).at(threadblock)) {
for (const int index : plan.impl_->threadblockMemoryChannels.at(threadblock)) {
deviceExecutionPlan.channels.memoryChannels[chanIndex++] = mscclpp::deviceHandle(context.memoryChannels[index]);
}
chanIndex = 0;
for (const auto& [index, _] : plan.impl_->threadblockPortChannelMap.at(rank).at(threadblock)) {
for (const int index : plan.impl_->threadblockPortChannels.at(threadblock)) {
deviceExecutionPlan.channels.portChannels[chanIndex++] = mscclpp::deviceHandle(context.portChannels[index]);
}
chanIndex = 0;
for (const auto& [index, _] : plan.impl_->threadblockNvlsChannelMap.at(rank).at(threadblock)) {
for (const int index : plan.impl_->threadblockNvlsChannels.at(threadblock)) {
deviceExecutionPlan.channels.nvlsChannels[chanIndex++] = mscclpp::deviceHandle(context.nvlsChannels[index]);
}
int memIndex = 0;
for (const auto& pair : plan.impl_->threadblockMemoryChannelBuffers.at(threadblock)) {
deviceExecutionPlan.remoteBuffers.memoryChannelBufferPtrs[memIndex] =
context.registeredMemoryAddresses[pair.first];
deviceExecutionPlan.remoteBuffers.memoryChannelBufferTypes[memIndex++] = pair.second;
}
memIndex = 0;
for (const auto& pair : plan.impl_->threadblockPortChannelBuffers.at(threadblock)) {
deviceExecutionPlan.remoteBuffers.portChannelBufferIds[memIndex] = context.registeredMemoryIds[pair.first];
deviceExecutionPlan.remoteBuffers.portChannelBufferTypes[memIndex++] = pair.second;
}
if (ops.size() > MAX_OPERATION) {
throw Error("Executor plan launching " + std::to_string(ops.size()) +
" operations, exceeding device execution plan support (" + std::to_string(MAX_OPERATION) + ")",
@@ -392,13 +416,36 @@ struct Executor::Impl {
context.deviceExecutionPlans[key] = std::move(deviceExecutionPlans);
}
template <typename PacketType>
void launchKernelHelper(ExecutionContext& context, int rank, void* sendbuff, void* recvbuff, DataType dataType,
cudaStream_t stream, uint32_t sharedMemSize, const uint32_t& flag) {
DeviceExecutionPlanKey key = context.currentDevicePlan;
int nthreadblocks = context.deviceExecutionPlans[key].size();
void* scratchBuffer = context.scratchBuffer.get();
size_t scratchOffset = 0;
if (context.doubleScratchBuff && (flag & 0x1) == 0) {
scratchOffset = (context.scratchBufferSize) >> 1;
}
if (context.reuseResources) {
ExecutionKernel::launchKernel<PacketType, true>(
rank, nthreadblocks, context.nthreadsPerBlock, sendbuff, recvbuff, scratchBuffer, scratchOffset,
context.scratchChunkSize, dataType, (DeviceExecutionPlan*)context.deviceExecutionPlansBuffers[key].get(),
(DeviceSemaphore*)context.smemaphores.get(), sharedMemSize, stream, flag);
} else {
ExecutionKernel::launchKernel<PacketType, false>(
rank, nthreadblocks, context.nthreadsPerBlock, sendbuff, recvbuff, scratchBuffer, scratchOffset,
context.scratchChunkSize, dataType, (DeviceExecutionPlan*)context.deviceExecutionPlansBuffers[key].get(),
(DeviceSemaphore*)context.smemaphores.get(), sharedMemSize, stream, flag);
}
}
void launchKernel(ExecutionContext& context, int rank, void* sendbuff, void* recvbuff, DataType dataType,
cudaStream_t stream, PacketType packetType) {
static uint32_t flag = 0;
DeviceExecutionPlanKey key = context.currentDevicePlan;
int nthreadblocks = context.deviceExecutionPlans[key].size();
#if defined(ENABLE_NPKIT)
#if defined(__HIP_PLATFORM_AMD__)
DeviceExecutionPlanKey key = context.currentDevicePlan;
int nthreadblocks = context.deviceExecutionPlans[key].size();
if (nthreadblocks > NPKIT_MAX_NUM_GPU_THREADBLOCKS) {
throw Error("Executor plan launching " + std::to_string(nthreadblocks) +
" thread blocks, exceeding NPKit support (" + std::to_string(NPKIT_MAX_NUM_GPU_THREADBLOCKS) +
@@ -408,20 +455,14 @@ struct Executor::Impl {
#endif
size_t sharedMemSize = sizeof(DeviceExecutionPlan) + NPKIT_SHM_NUM_EVENTS * sizeof(NpKitEvent);
#else
size_t sharedMemSize = sizeof(DeviceExecutionPlan);
uint32_t sharedMemSize = sizeof(DeviceExecutionPlan);
#endif
switch (packetType) {
case PacketType::LL16:
ExecutionKernel::launchKernel<LL16Packet>(
rank, nthreadblocks, context.nthreadsPerBlock, sendbuff, recvbuff, (void*)context.scratchBuffer.get(),
context.scratchBufferSize, dataType, (DeviceExecutionPlan*)context.deviceExecutionPlansBuffers[key].get(),
sharedMemSize, stream, ++flag);
launchKernelHelper<LL16Packet>(context, rank, sendbuff, recvbuff, dataType, stream, sharedMemSize, ++flag);
break;
case PacketType::LL8:
ExecutionKernel::launchKernel<LL8Packet>(
rank, nthreadblocks, context.nthreadsPerBlock, sendbuff, recvbuff, (void*)context.scratchBuffer.get(),
context.scratchBufferSize, dataType, (DeviceExecutionPlan*)context.deviceExecutionPlansBuffers[key].get(),
sharedMemSize, stream, ++flag);
launchKernelHelper<LL8Packet>(context, rank, sendbuff, recvbuff, dataType, stream, sharedMemSize, ++flag);
break;
default:
throw Error("Invalid packet type", ErrorCode::ExecutorError);