mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-19 22:39:11 +00:00
Tune threads per block for mscclpp executor (#345)
This commit is contained in:
@@ -488,29 +488,30 @@ NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t
|
||||
std::shared_ptr<mscclpp::ExecutionPlan> plan;
|
||||
if (bytes <= comm->largeMessageSizeBoundary)
|
||||
plan = (sendbuff == recvbuff) ? comm->allReducePacketIPPlan : comm->allReducePacketOPPlan;
|
||||
else
|
||||
else {
|
||||
plan = (sendbuff == recvbuff) ? comm->allReduceIPPlan : comm->allReduceOPPlan;
|
||||
}
|
||||
|
||||
if (plan == nullptr)
|
||||
return ncclAllReduceFallback(sendbuff, recvbuff, count, datatype, reductionOperation, comm, stream);
|
||||
|
||||
switch (datatype) {
|
||||
case ncclFloat16:
|
||||
comm->executor->execute(rank, (half*)sendbuff, (half*)recvbuff, bytes, bytes, mscclpp::DataType::FLOAT16, 1024,
|
||||
*plan, stream, mscclpp::PacketType::LL8);
|
||||
comm->executor->execute(rank, (half*)sendbuff, (half*)recvbuff, bytes, bytes, mscclpp::DataType::FLOAT16, *plan,
|
||||
stream, mscclpp::PacketType::LL8);
|
||||
break;
|
||||
case ncclFloat32:
|
||||
comm->executor->execute(rank, (float*)sendbuff, (float*)recvbuff, bytes, bytes, mscclpp::DataType::FLOAT32,
|
||||
1024, *plan, stream, mscclpp::PacketType::LL8);
|
||||
*plan, stream, mscclpp::PacketType::LL8);
|
||||
break;
|
||||
case ncclBfloat16:
|
||||
comm->executor->execute(rank, (__bfloat16*)sendbuff, (__bfloat16*)recvbuff, bytes, bytes,
|
||||
mscclpp::DataType::BFLOAT16, 1024, *plan, stream, mscclpp::PacketType::LL8);
|
||||
mscclpp::DataType::BFLOAT16, *plan, stream, mscclpp::PacketType::LL8);
|
||||
break;
|
||||
case ncclInt32:
|
||||
case ncclUint32:
|
||||
comm->executor->execute(rank, (int*)sendbuff, (int*)recvbuff, bytes, bytes, mscclpp::DataType::UINT32, 1024,
|
||||
*plan, stream, mscclpp::PacketType::LL8);
|
||||
comm->executor->execute(rank, (int*)sendbuff, (int*)recvbuff, bytes, bytes, mscclpp::DataType::UINT32, *plan,
|
||||
stream, mscclpp::PacketType::LL8);
|
||||
break;
|
||||
default:
|
||||
return ncclInvalidArgument;
|
||||
|
||||
@@ -43,7 +43,7 @@ class Executor {
|
||||
~Executor();
|
||||
|
||||
void execute(int rank, void* sendbuff, void* recvBuff, size_t sendBuffSize, size_t recvBuffSize, DataType dataType,
|
||||
int nthreads, const ExecutionPlan& plan, cudaStream_t stream, PacketType packetType = PacketType::LL16);
|
||||
const ExecutionPlan& plan, cudaStream_t stream, PacketType packetType = PacketType::LL16);
|
||||
|
||||
private:
|
||||
struct Impl;
|
||||
|
||||
@@ -29,11 +29,10 @@ void register_executor(nb::module_& m) {
|
||||
.def(
|
||||
"execute",
|
||||
[](Executor* self, int rank, uintptr_t sendbuff, uintptr_t recvBuff, size_t sendBuffSize, size_t recvBuffSize,
|
||||
DataType dataType, int nthreads, const ExecutionPlan& plan, uintptr_t stream, PacketType packetType) {
|
||||
DataType dataType, const ExecutionPlan& plan, uintptr_t stream, PacketType packetType) {
|
||||
self->execute(rank, reinterpret_cast<void*>(sendbuff), reinterpret_cast<void*>(recvBuff), sendBuffSize,
|
||||
recvBuffSize, dataType, nthreads, plan, (cudaStream_t)stream, packetType);
|
||||
recvBuffSize, dataType, plan, (cudaStream_t)stream, packetType);
|
||||
},
|
||||
nb::arg("rank"), nb::arg("sendbuff"), nb::arg("recvBuff"), nb::arg("sendBuffSize"), nb::arg("recvBuffSize"),
|
||||
nb::arg("dataType"), nb::arg("nthreads"), nb::arg("plan"), nb::arg("stream"),
|
||||
nb::arg("packetType") = PacketType::LL16);
|
||||
nb::arg("dataType"), nb::arg("plan"), nb::arg("stream"), nb::arg("packetType") = PacketType::LL16);
|
||||
}
|
||||
|
||||
@@ -81,10 +81,9 @@ def main(
|
||||
execution_paln_name: str,
|
||||
execution_plan_path: str,
|
||||
size: int,
|
||||
nthreads_per_block: int,
|
||||
dtype: cp.dtype = cp.float16,
|
||||
packet_type: PacketType = PacketType.LL16,
|
||||
seed: int = 42,
|
||||
seed: int = 42 + MPI.COMM_WORLD.rank,
|
||||
):
|
||||
mscclpp_group = mscclpp_comm.CommGroup(MPI.COMM_WORLD)
|
||||
cp.cuda.Device(mscclpp_group.my_rank % mscclpp_group.nranks_per_node).use()
|
||||
@@ -96,12 +95,9 @@ def main(
|
||||
|
||||
cp.random.seed(seed)
|
||||
nelems = size // cp.dtype(dtype).itemsize
|
||||
buffer = cp.random.random(nelems * mscclpp_group.nranks).astype(dtype)
|
||||
sub_arrays = cp.split(buffer, MPI.COMM_WORLD.size)
|
||||
sendbuf = sub_arrays[MPI.COMM_WORLD.rank]
|
||||
expected = cp.zeros_like(sendbuf)
|
||||
for i in range(mscclpp_group.nranks):
|
||||
expected += sub_arrays[i]
|
||||
sendbuf = cp.random.random(nelems).astype(dtype)
|
||||
expected = cp.asnumpy(sendbuf)
|
||||
expected = MPI.COMM_WORLD.allreduce(expected, op=MPI.SUM)
|
||||
mscclpp_group.barrier()
|
||||
|
||||
executor_func = lambda stream: executor.execute(
|
||||
@@ -111,7 +107,6 @@ def main(
|
||||
sendbuf.nbytes,
|
||||
sendbuf.nbytes,
|
||||
dtype_to_mscclpp_dtype(dtype),
|
||||
nthreads_per_block,
|
||||
execution_plan,
|
||||
stream.ptr,
|
||||
packet_type,
|
||||
@@ -130,7 +125,7 @@ def main(
|
||||
print(
|
||||
f"Rank: {MPI.COMM_WORLD.rank} Execution time: {execution_time} us, "
|
||||
f"data size: {sendbuf.nbytes} bytes data type: {dtype().dtype.name} "
|
||||
f"packet type: {packet_type} nthreads_per_block: {nthreads_per_block}"
|
||||
f"packet type: {packet_type}"
|
||||
)
|
||||
executor = None
|
||||
mscclpp_group = None
|
||||
@@ -141,7 +136,6 @@ if __name__ == "__main__":
|
||||
parser.add_argument("-n", "--execution_plan_name", type=str, required=True)
|
||||
parser.add_argument("-path", "--execution_plan_path", type=str, required=True)
|
||||
parser.add_argument("--size", type=str, required=True)
|
||||
parser.add_argument("--nthreads_per_block", type=int, required=True)
|
||||
parser.add_argument("--dtype", type=str, default="float16", help="Choose from float16, float32, int32")
|
||||
parser.add_argument("--packet_type", type=str, default="LL16", help="Choose from LL8, LL16")
|
||||
parser.add_argument("--seed", type=int, default=42)
|
||||
@@ -157,7 +151,6 @@ if __name__ == "__main__":
|
||||
args.execution_plan_name,
|
||||
args.execution_plan_path,
|
||||
buffer_size,
|
||||
args.nthreads_per_block,
|
||||
dtype,
|
||||
packet_type,
|
||||
args.seed,
|
||||
|
||||
@@ -630,7 +630,6 @@ def test_executor(mpi_group: MpiGroup, filename: str):
|
||||
sendbuf.nbytes,
|
||||
sendbuf.nbytes,
|
||||
DataType.float16,
|
||||
512,
|
||||
execution_plan,
|
||||
stream.ptr,
|
||||
)
|
||||
|
||||
@@ -161,6 +161,8 @@ std::vector<Operation> ExecutionPlan::Impl::getOperations(int rank, int threadbl
|
||||
|
||||
int ExecutionPlan::Impl::getThreadblockCount(int rank) const { return this->operations.at(rank).size(); }
|
||||
|
||||
int ExecutionPlan::Impl::getNThreadsPerBlock() const { return this->nThreadsPerBlock; }
|
||||
|
||||
void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize, size_t contsSrcOffset, size_t constDstOffset) {
|
||||
std::ifstream file(this->planPath);
|
||||
json obj = json::parse(file);
|
||||
@@ -171,6 +173,7 @@ void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize, size_t contsSrcOff
|
||||
if (protocol == "LL") {
|
||||
this->isUsingPacket = true;
|
||||
}
|
||||
this->nThreadsPerBlock = obj.value("num_threads_per_block", 1024);
|
||||
const auto& gpus = obj["gpus"];
|
||||
|
||||
for (const auto& gpu : gpus) {
|
||||
|
||||
@@ -65,6 +65,7 @@ struct ExecutionContext {
|
||||
std::shared_ptr<char> scratchBuffer;
|
||||
size_t scratchBufferSize;
|
||||
std::shared_ptr<char> deviceExecutionPlansBuffer;
|
||||
int nthreadsPerBlock;
|
||||
};
|
||||
|
||||
struct Executor::Impl {
|
||||
@@ -104,6 +105,7 @@ struct Executor::Impl {
|
||||
context.scratchBuffer = scratchBuffer;
|
||||
context.scratchBufferSize = scratchBufferSize;
|
||||
context.proxyService = std::make_shared<ProxyService>();
|
||||
context.nthreadsPerBlock = plan.impl_->getNThreadsPerBlock();
|
||||
this->setupConnections(context, rank, plan);
|
||||
this->setupRegisteredMemories(context, sendbuff, recvbuff, sendBufferSize, recvBufferSize, rank, plan);
|
||||
this->setupChannels(context, sendbuff, recvbuff, sendBufferSize, rank, plan);
|
||||
@@ -295,8 +297,8 @@ struct Executor::Impl {
|
||||
context.deviceExecutionPlans = std::move(deviceExecutionPlans);
|
||||
}
|
||||
|
||||
void launchKernel(ExecutionContext& context, int rank, int nthreadsPerBlock, void* sendbuff, void* recvbuff,
|
||||
DataType dataType, cudaStream_t stream, PacketType packetType) {
|
||||
void launchKernel(ExecutionContext& context, int rank, void* sendbuff, void* recvbuff, DataType dataType,
|
||||
cudaStream_t stream, PacketType packetType) {
|
||||
static uint32_t flag = 0;
|
||||
int nthreadblocks = context.deviceExecutionPlans.size();
|
||||
#if defined(ENABLE_NPKIT)
|
||||
@@ -315,13 +317,13 @@ struct Executor::Impl {
|
||||
switch (packetType) {
|
||||
case PacketType::LL16:
|
||||
ExecutionKernel::launchKernel<LL16Packet>(
|
||||
rank, nthreadblocks, nthreadsPerBlock, sendbuff, recvbuff, (void*)context.scratchBuffer.get(),
|
||||
rank, nthreadblocks, context.nthreadsPerBlock, sendbuff, recvbuff, (void*)context.scratchBuffer.get(),
|
||||
context.scratchBufferSize, dataType, (DeviceExecutionPlan*)context.deviceExecutionPlansBuffer.get(),
|
||||
sharedMemSize, stream, ++flag);
|
||||
break;
|
||||
case PacketType::LL8:
|
||||
ExecutionKernel::launchKernel<LL8Packet>(
|
||||
rank, nthreadblocks, nthreadsPerBlock, sendbuff, recvbuff, (void*)context.scratchBuffer.get(),
|
||||
rank, nthreadblocks, context.nthreadsPerBlock, sendbuff, recvbuff, (void*)context.scratchBuffer.get(),
|
||||
context.scratchBufferSize, dataType, (DeviceExecutionPlan*)context.deviceExecutionPlansBuffer.get(),
|
||||
sharedMemSize, stream, ++flag);
|
||||
break;
|
||||
@@ -334,7 +336,7 @@ struct Executor::Impl {
|
||||
// Constructs the executor by creating its private `Impl` from the given
// communicator; all state lives behind `impl_`.
Executor::Executor(std::shared_ptr<Communicator> comm) : impl_(std::make_unique<Impl>(comm)) {}
|
||||
|
||||
void Executor::execute(int rank, void* sendbuff, void* recvbuff, size_t sendBuffSize,
|
||||
[[maybe_unused]] size_t recvBuffSize, DataType dataType, int nthreads, const ExecutionPlan& plan,
|
||||
[[maybe_unused]] size_t recvBuffSize, DataType dataType, const ExecutionPlan& plan,
|
||||
cudaStream_t stream, PacketType packetType) {
|
||||
size_t sendBytes, recvBytes;
|
||||
CUdeviceptr sendBasePtr, recvBasePtr;
|
||||
@@ -345,8 +347,7 @@ void Executor::execute(int rank, void* sendbuff, void* recvbuff, size_t sendBuff
|
||||
|
||||
ExecutionContext context = this->impl_->setupExecutionContext(
|
||||
rank, (void*)sendBasePtr, (void*)recvBasePtr, sendBuffSize, offsetIn, offsetOut, sendBytes, recvBytes, plan);
|
||||
// TODO(binyli): need to flush proxy channel here
|
||||
this->impl_->launchKernel(context, rank, nthreads, sendbuff, recvbuff, dataType, stream, packetType);
|
||||
this->impl_->launchKernel(context, rank, sendbuff, recvbuff, dataType, stream, packetType);
|
||||
}
|
||||
|
||||
// Defaulted destructor, defined out of line here.
// NOTE(review): presumably placed in this translation unit because `impl_`
// owns an `Impl` that is only forward-declared in the header — confirm.
Executor::~Executor() = default;
|
||||
|
||||
@@ -68,6 +68,7 @@ struct ExecutionPlan::Impl {
|
||||
size_t getScratchBufferSize(int rank, size_t inputSize) const;
|
||||
std::vector<Operation> getOperations(int rank, int threadblock) const;
|
||||
int getThreadblockCount(int rank) const;
|
||||
int getNThreadsPerBlock() const;
|
||||
|
||||
void loadExecutionPlan(size_t inputSize, size_t contsSrcOffset, size_t constDstOffset);
|
||||
void lightLoadExecutionPlan(size_t inputSize, size_t contsSrcOffset, size_t constDstOffset);
|
||||
@@ -93,6 +94,7 @@ struct ExecutionPlan::Impl {
|
||||
std::unordered_map<int, uint32_t> scratchChunks;
|
||||
std::unordered_map<int, uint32_t> chunkGroups;
|
||||
size_t inputSize;
|
||||
int nThreadsPerBlock;
|
||||
|
||||
private:
|
||||
size_t getOffset(int rank, size_t inputSize, uint32_t chunkIndex, uint32_t alignment = 16) const;
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
"colletive": "allreduce",
|
||||
"protocol": "Simple",
|
||||
"inplace": true,
|
||||
"num_threads_per_block": 512,
|
||||
"gpus": [
|
||||
{
|
||||
"id": 0,
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
"colletive": "allreduce",
|
||||
"protocol": "LL",
|
||||
"inplace": true,
|
||||
"num_threads_per_block": 768,
|
||||
"gpus": [
|
||||
{
|
||||
"id": 0,
|
||||
|
||||
@@ -56,16 +56,16 @@ mscclpp::PacketType parsePacketType(const char* value) {
|
||||
}
|
||||
|
||||
double benchTime(int rank, std::shared_ptr<mscclpp::Bootstrap> bootstrap, std::shared_ptr<mscclpp::Executor> executor,
|
||||
const mscclpp::ExecutionPlan& plan, std::shared_ptr<char> sendbuff, size_t bufferSize,
|
||||
int nthreadsPerBlock, int niters, int ngrapthIters, mscclpp::PacketType packetType) {
|
||||
const mscclpp::ExecutionPlan& plan, std::shared_ptr<char> sendbuff, size_t bufferSize, int niters,
|
||||
int ngrapthIters, mscclpp::PacketType packetType) {
|
||||
mscclpp::CudaStreamWithFlags stream(cudaStreamNonBlocking);
|
||||
cudaGraph_t graph;
|
||||
cudaGraphExec_t graphExec;
|
||||
mscclpp::Timer timer;
|
||||
MSCCLPP_CUDATHROW(cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal));
|
||||
for (int i = 0; i < niters; i++) {
|
||||
executor->execute(rank, sendbuff.get(), sendbuff.get(), bufferSize, bufferSize, mscclpp::DataType::FLOAT16,
|
||||
nthreadsPerBlock, plan, stream, packetType);
|
||||
executor->execute(rank, sendbuff.get(), sendbuff.get(), bufferSize, bufferSize, mscclpp::DataType::FLOAT16, plan,
|
||||
stream, packetType);
|
||||
}
|
||||
MSCCLPP_CUDATHROW(cudaStreamEndCapture(stream, &graph));
|
||||
MSCCLPP_CUDATHROW(cudaGraphInstantiate(&graphExec, graph, NULL, NULL, 0));
|
||||
@@ -86,11 +86,10 @@ double benchTime(int rank, std::shared_ptr<mscclpp::Bootstrap> bootstrap, std::s
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
if (argc != 7 && argc != 8) {
|
||||
if (argc != 6 && argc != 7) {
|
||||
std::cerr << "Usage: " << argv[0] << " <buffer size>"
|
||||
<< " <execution plan name>"
|
||||
<< " <execution plan path>"
|
||||
<< " <nthreads per block>"
|
||||
<< " <number of iterations>"
|
||||
<< " <number of graph iterations>"
|
||||
<< " (optional) <packet type>" << std::endl;
|
||||
@@ -107,13 +106,12 @@ int main(int argc, char* argv[]) {
|
||||
const size_t bufferSize = parseSize(argv[1]);
|
||||
const std::string executionPlanName = argv[2];
|
||||
const std::string executionPlanPath = argv[3];
|
||||
const int nthreadsPerBlock = std::stoi(argv[4]);
|
||||
const int niters = std::stoi(argv[5]);
|
||||
const int ngraphIters = std::stoi(argv[6]);
|
||||
const int niters = std::stoi(argv[4]);
|
||||
const int ngraphIters = std::stoi(argv[5]);
|
||||
const char* npkitDumpDir = getenv("NPKIT_DUMP_DIR");
|
||||
mscclpp::PacketType packetType = mscclpp::PacketType::LL16;
|
||||
if (argc == 8) {
|
||||
packetType = parsePacketType(argv[7]);
|
||||
if (argc == 7) {
|
||||
packetType = parsePacketType(argv[6]);
|
||||
}
|
||||
|
||||
std::shared_ptr<mscclpp::TcpBootstrap> bootstrap;
|
||||
@@ -133,8 +131,7 @@ int main(int argc, char* argv[]) {
|
||||
std::shared_ptr<char> sendbuff = mscclpp::allocExtSharedCuda<char>(bufferSize);
|
||||
std::vector<int> dataHost(bufferSize / sizeof(int), rank);
|
||||
MSCCLPP_CUDATHROW(cudaMemcpy(sendbuff.get(), dataHost.data(), bufferSize, cudaMemcpyHostToDevice));
|
||||
double deltaSec = benchTime(rank, bootstrap, executor, plan, sendbuff, bufferSize, nthreadsPerBlock, niters,
|
||||
ngraphIters, packetType);
|
||||
double deltaSec = benchTime(rank, bootstrap, executor, plan, sendbuff, bufferSize, niters, ngraphIters, packetType);
|
||||
|
||||
if (npkitDumpDir != nullptr) {
|
||||
NpKit::Dump(npkitDumpDir);
|
||||
|
||||
@@ -59,7 +59,7 @@ TEST_F(ExecutorTest, TwoNodesAllreduce) {
|
||||
const int bufferSize = 1024 * 1024;
|
||||
std::shared_ptr<char> sendbuff = mscclpp::allocExtSharedCuda<char>(bufferSize);
|
||||
mscclpp::CudaStreamWithFlags stream(cudaStreamNonBlocking);
|
||||
executor->execute(gEnv->rank, sendbuff.get(), sendbuff.get(), bufferSize, bufferSize, mscclpp::DataType::FLOAT16, 512,
|
||||
executor->execute(gEnv->rank, sendbuff.get(), sendbuff.get(), bufferSize, bufferSize, mscclpp::DataType::FLOAT16,
|
||||
plan, stream);
|
||||
MSCCLPP_CUDATHROW(cudaStreamSynchronize(stream));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user