Tune threads per block for mscclpp executor (#345)

This commit is contained in:
Binyang Li
2024-09-18 17:21:47 -07:00
committed by GitHub
parent 0c7311e83f
commit b30bb260e3
12 changed files with 43 additions and 46 deletions

View File

@@ -65,6 +65,7 @@ struct ExecutionContext {
std::shared_ptr<char> scratchBuffer;
size_t scratchBufferSize;
std::shared_ptr<char> deviceExecutionPlansBuffer;
int nthreadsPerBlock;
};
struct Executor::Impl {
@@ -104,6 +105,7 @@ struct Executor::Impl {
context.scratchBuffer = scratchBuffer;
context.scratchBufferSize = scratchBufferSize;
context.proxyService = std::make_shared<ProxyService>();
context.nthreadsPerBlock = plan.impl_->getNThreadsPerBlock();
this->setupConnections(context, rank, plan);
this->setupRegisteredMemories(context, sendbuff, recvbuff, sendBufferSize, recvBufferSize, rank, plan);
this->setupChannels(context, sendbuff, recvbuff, sendBufferSize, rank, plan);
@@ -295,8 +297,8 @@ struct Executor::Impl {
context.deviceExecutionPlans = std::move(deviceExecutionPlans);
}
void launchKernel(ExecutionContext& context, int rank, int nthreadsPerBlock, void* sendbuff, void* recvbuff,
DataType dataType, cudaStream_t stream, PacketType packetType) {
void launchKernel(ExecutionContext& context, int rank, void* sendbuff, void* recvbuff, DataType dataType,
cudaStream_t stream, PacketType packetType) {
static uint32_t flag = 0;
int nthreadblocks = context.deviceExecutionPlans.size();
#if defined(ENABLE_NPKIT)
@@ -315,13 +317,13 @@ struct Executor::Impl {
switch (packetType) {
case PacketType::LL16:
ExecutionKernel::launchKernel<LL16Packet>(
rank, nthreadblocks, nthreadsPerBlock, sendbuff, recvbuff, (void*)context.scratchBuffer.get(),
rank, nthreadblocks, context.nthreadsPerBlock, sendbuff, recvbuff, (void*)context.scratchBuffer.get(),
context.scratchBufferSize, dataType, (DeviceExecutionPlan*)context.deviceExecutionPlansBuffer.get(),
sharedMemSize, stream, ++flag);
break;
case PacketType::LL8:
ExecutionKernel::launchKernel<LL8Packet>(
rank, nthreadblocks, nthreadsPerBlock, sendbuff, recvbuff, (void*)context.scratchBuffer.get(),
rank, nthreadblocks, context.nthreadsPerBlock, sendbuff, recvbuff, (void*)context.scratchBuffer.get(),
context.scratchBufferSize, dataType, (DeviceExecutionPlan*)context.deviceExecutionPlansBuffer.get(),
sharedMemSize, stream, ++flag);
break;
@@ -334,7 +336,7 @@ struct Executor::Impl {
// Constructs an Executor: creates the private implementation object (impl_)
// as a new Impl built from the given communicator handle `comm`.
Executor::Executor(std::shared_ptr<Communicator> comm) : impl_(std::make_unique<Impl>(comm)) {}
void Executor::execute(int rank, void* sendbuff, void* recvbuff, size_t sendBuffSize,
[[maybe_unused]] size_t recvBuffSize, DataType dataType, int nthreads, const ExecutionPlan& plan,
[[maybe_unused]] size_t recvBuffSize, DataType dataType, const ExecutionPlan& plan,
cudaStream_t stream, PacketType packetType) {
size_t sendBytes, recvBytes;
CUdeviceptr sendBasePtr, recvBasePtr;
@@ -345,8 +347,7 @@ void Executor::execute(int rank, void* sendbuff, void* recvbuff, size_t sendBuff
ExecutionContext context = this->impl_->setupExecutionContext(
rank, (void*)sendBasePtr, (void*)recvBasePtr, sendBuffSize, offsetIn, offsetOut, sendBytes, recvBytes, plan);
// TODO(binyli): need to flush proxy channel here
this->impl_->launchKernel(context, rank, nthreads, sendbuff, recvbuff, dataType, stream, packetType);
this->impl_->launchKernel(context, rank, sendbuff, recvbuff, dataType, stream, packetType);
}
// Defaulted destructor, defined out of line in this source file.
// NOTE(review): presumably placed here (rather than in the header) so that
// Impl is a complete type when impl_ is destroyed — confirm against the
// Executor class declaration.
Executor::~Executor() = default;