mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-04 21:51:32 +00:00
Tune threads per block for mscclpp executor (#345)
This commit is contained in:
@@ -65,6 +65,7 @@ struct ExecutionContext {
|
||||
std::shared_ptr<char> scratchBuffer;
|
||||
size_t scratchBufferSize;
|
||||
std::shared_ptr<char> deviceExecutionPlansBuffer;
|
||||
int nthreadsPerBlock;
|
||||
};
|
||||
|
||||
struct Executor::Impl {
|
||||
@@ -104,6 +105,7 @@ struct Executor::Impl {
|
||||
context.scratchBuffer = scratchBuffer;
|
||||
context.scratchBufferSize = scratchBufferSize;
|
||||
context.proxyService = std::make_shared<ProxyService>();
|
||||
context.nthreadsPerBlock = plan.impl_->getNThreadsPerBlock();
|
||||
this->setupConnections(context, rank, plan);
|
||||
this->setupRegisteredMemories(context, sendbuff, recvbuff, sendBufferSize, recvBufferSize, rank, plan);
|
||||
this->setupChannels(context, sendbuff, recvbuff, sendBufferSize, rank, plan);
|
||||
@@ -295,8 +297,8 @@ struct Executor::Impl {
|
||||
context.deviceExecutionPlans = std::move(deviceExecutionPlans);
|
||||
}
|
||||
|
||||
void launchKernel(ExecutionContext& context, int rank, int nthreadsPerBlock, void* sendbuff, void* recvbuff,
|
||||
DataType dataType, cudaStream_t stream, PacketType packetType) {
|
||||
void launchKernel(ExecutionContext& context, int rank, void* sendbuff, void* recvbuff, DataType dataType,
|
||||
cudaStream_t stream, PacketType packetType) {
|
||||
static uint32_t flag = 0;
|
||||
int nthreadblocks = context.deviceExecutionPlans.size();
|
||||
#if defined(ENABLE_NPKIT)
|
||||
@@ -315,13 +317,13 @@ struct Executor::Impl {
|
||||
switch (packetType) {
|
||||
case PacketType::LL16:
|
||||
ExecutionKernel::launchKernel<LL16Packet>(
|
||||
rank, nthreadblocks, nthreadsPerBlock, sendbuff, recvbuff, (void*)context.scratchBuffer.get(),
|
||||
rank, nthreadblocks, context.nthreadsPerBlock, sendbuff, recvbuff, (void*)context.scratchBuffer.get(),
|
||||
context.scratchBufferSize, dataType, (DeviceExecutionPlan*)context.deviceExecutionPlansBuffer.get(),
|
||||
sharedMemSize, stream, ++flag);
|
||||
break;
|
||||
case PacketType::LL8:
|
||||
ExecutionKernel::launchKernel<LL8Packet>(
|
||||
rank, nthreadblocks, nthreadsPerBlock, sendbuff, recvbuff, (void*)context.scratchBuffer.get(),
|
||||
rank, nthreadblocks, context.nthreadsPerBlock, sendbuff, recvbuff, (void*)context.scratchBuffer.get(),
|
||||
context.scratchBufferSize, dataType, (DeviceExecutionPlan*)context.deviceExecutionPlansBuffer.get(),
|
||||
sharedMemSize, stream, ++flag);
|
||||
break;
|
||||
@@ -334,7 +336,7 @@ struct Executor::Impl {
|
||||
Executor::Executor(std::shared_ptr<Communicator> comm) : impl_(std::make_unique<Impl>(comm)) {}
|
||||
|
||||
void Executor::execute(int rank, void* sendbuff, void* recvbuff, size_t sendBuffSize,
|
||||
[[maybe_unused]] size_t recvBuffSize, DataType dataType, int nthreads, const ExecutionPlan& plan,
|
||||
[[maybe_unused]] size_t recvBuffSize, DataType dataType, const ExecutionPlan& plan,
|
||||
cudaStream_t stream, PacketType packetType) {
|
||||
size_t sendBytes, recvBytes;
|
||||
CUdeviceptr sendBasePtr, recvBasePtr;
|
||||
@@ -345,8 +347,7 @@ void Executor::execute(int rank, void* sendbuff, void* recvbuff, size_t sendBuff
|
||||
|
||||
ExecutionContext context = this->impl_->setupExecutionContext(
|
||||
rank, (void*)sendBasePtr, (void*)recvBasePtr, sendBuffSize, offsetIn, offsetOut, sendBytes, recvBytes, plan);
|
||||
// TODO(binyli): need to flush proxy channel here
|
||||
this->impl_->launchKernel(context, rank, nthreads, sendbuff, recvbuff, dataType, stream, packetType);
|
||||
this->impl_->launchKernel(context, rank, sendbuff, recvbuff, dataType, stream, packetType);
|
||||
}
|
||||
|
||||
Executor::~Executor() = default;
|
||||
|
||||
Reference in New Issue
Block a user