mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-20 06:49:29 +00:00
Integrate MSCCL++ DSL to torch workload (#620)
Provides two ways to integrate the MSCCL++ DSL: 1. Integrate with a customized communication group 2. Integrate with the NCCL API. Introduces new Python APIs to make this work: ```python mscclpp.compile # compile the DSL to a JSON-based execution plan mscclpp.ExecutionPlanRegistry.register_plan(plan) # register the compiled plan with the ExecutionPlanRegistry mscclpp.ExecutionPlanRegistry.set_selector(selector) # set the selector; the selector returns the best execution plan based on collective, message size, world size, ... ``` Fix #556 --------- Co-authored-by: Caio Rocha <caiorocha@microsoft.com> Co-authored-by: Changho Hwang <changhohwang@microsoft.com>
This commit is contained in:
@@ -1237,7 +1237,6 @@ class AllreduceNvlsPacket : public mscclpp::AlgorithmBuilder {
|
||||
|
||||
size_t scratchBufferSize_;
|
||||
std::shared_ptr<char> scratchBuffer_;
|
||||
const int nSegmentsForScratchBuffer_ = 2;
|
||||
const size_t nvlsBufferSize_ = (1 << 30);
|
||||
|
||||
std::shared_ptr<uint32_t> deviceFlag_;
|
||||
|
||||
@@ -120,7 +120,8 @@ static inline int mscclppNcclDlopenInit() {
|
||||
return dlopenSuccess;
|
||||
}
|
||||
|
||||
static inline void mscclppNcclDlopenFinalize() {
|
||||
// No need to call this function, handle will be closed at program exit
|
||||
[[maybe_unused]] static inline void mscclppNcclDlopenFinalize() {
|
||||
if (mscclppNcclDlHandle) {
|
||||
dlclose(mscclppNcclDlHandle);
|
||||
}
|
||||
@@ -159,17 +160,6 @@ static bool tryLoadNcclSharedLib() {
|
||||
// Declare the global map to store associations between raw pointer and shared pointer
|
||||
static std::unordered_map<void*, std::shared_ptr<char>> ptrMap;
|
||||
|
||||
struct planKey {
|
||||
size_t minMessageSize;
|
||||
size_t maxMessageSize;
|
||||
bool isInPlace;
|
||||
};
|
||||
|
||||
struct executionPlanInstance {
|
||||
planKey key;
|
||||
std::shared_ptr<mscclpp::ExecutionPlan> plan;
|
||||
};
|
||||
|
||||
struct splitCommInfo {
|
||||
int color;
|
||||
int key;
|
||||
@@ -179,23 +169,16 @@ struct splitCommInfo {
|
||||
struct ncclComm {
|
||||
std::shared_ptr<mscclpp::Communicator> comm;
|
||||
std::shared_ptr<mscclpp::Executor> executor;
|
||||
std::unordered_map<std::string, std::vector<executionPlanInstance>> executionPlans;
|
||||
std::shared_ptr<mscclpp::AlgorithmCollection> algorithmCollection;
|
||||
std::shared_ptr<char> scratchBuffer_;
|
||||
const size_t scratchBufferSize_ = (1 << 27); // 128MB
|
||||
std::shared_ptr<mscclpp::ExecutionPlanRegistry> planRegistry_;
|
||||
int nRanksPerNode;
|
||||
int worldSize;
|
||||
|
||||
void* mscclppNcclComm;
|
||||
};
|
||||
|
||||
static std::pair<std::string, executionPlanInstance> loadExecutionPlan(const std::string& filename, int rank) {
|
||||
std::shared_ptr<mscclpp::ExecutionPlan> plan = std::make_shared<mscclpp::ExecutionPlan>(filename, rank);
|
||||
std::string collective = plan->collective();
|
||||
planKey key{plan->minMessageSize(), plan->maxMessageSize(), plan->isInPlace()};
|
||||
return std::make_pair(collective, executionPlanInstance{key, plan});
|
||||
}
|
||||
|
||||
static ncclResult_t executeWithPlan(std::shared_ptr<mscclpp::Executor> executor, int rank, ncclDataType_t datatype,
|
||||
const void* sendbuff, void* recvbuff, size_t sendBytes, size_t recvBytes,
|
||||
std::shared_ptr<mscclpp::ExecutionPlan> plan, cudaStream_t stream) {
|
||||
@@ -352,6 +335,20 @@ static mscclpp::Algorithm algoSelector(
|
||||
return mscclpp::Algorithm();
|
||||
}
|
||||
|
||||
std::shared_ptr<mscclpp::ExecutionPlanHandle> executionPlanDefaultSelector(
|
||||
const std::vector<std::shared_ptr<mscclpp::ExecutionPlanHandle>> plans, const mscclpp::ExecutionRequest&) {
|
||||
if (plans.empty()) {
|
||||
INFO(MSCCLPP_NCCL, "No execution plans available for selection");
|
||||
return nullptr;
|
||||
}
|
||||
for (auto plan : plans) {
|
||||
if (plan->tags.find("default") == plan->tags.end()) {
|
||||
return plan;
|
||||
}
|
||||
}
|
||||
return plans[0];
|
||||
}
|
||||
|
||||
NCCL_API ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank) {
|
||||
INFO(MSCCLPP_NCCL, "Initializing NCCL communicator for rank %d, world_size=%d", rank, nranks);
|
||||
if (comm == nullptr) {
|
||||
@@ -371,29 +368,13 @@ NCCL_API ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueI
|
||||
|
||||
commPtr->comm = mscclppComm;
|
||||
commPtr->scratchBuffer_ = mscclpp::GpuBuffer<char>(commPtr->scratchBufferSize_).memory();
|
||||
commPtr->executor = std::make_shared<mscclpp::Executor>(mscclppComm);
|
||||
commPtr->executor = std::make_shared<mscclpp::Executor>(mscclppComm, commPtr->scratchBuffer_);
|
||||
commPtr->planRegistry_ = mscclpp::ExecutionPlanRegistry::getInstance();
|
||||
|
||||
commPtr->nRanksPerNode = mscclppComm->bootstrap()->getNranksPerNode();
|
||||
commPtr->worldSize = mscclppComm->bootstrap()->getNranks();
|
||||
|
||||
if (commPtr->worldSize == 1) {
|
||||
*comm = commPtr;
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
const std::string& collectiveDir = mscclpp::env()->executionPlanDir;
|
||||
if (collectiveDir != "") {
|
||||
if (!std::filesystem::is_directory(collectiveDir)) {
|
||||
WARN("The value of the environment variable %s is not a directory", collectiveDir.c_str());
|
||||
return ncclInvalidArgument;
|
||||
}
|
||||
for (const auto& entry : std::filesystem::directory_iterator(collectiveDir)) {
|
||||
if (entry.is_regular_file()) {
|
||||
auto plan = loadExecutionPlan(entry.path(), rank);
|
||||
commPtr->executionPlans[plan.first].push_back(plan.second);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
commPtr->planRegistry_->loadDefaultPlans(rank);
|
||||
commPtr->planRegistry_->setDefaultSelector(executionPlanDefaultSelector);
|
||||
mscclpp::AlgorithmCollectionBuilder::getInstance()->setFallbackAlgorithmSelector(algoSelector);
|
||||
registerCustomizedAlgo();
|
||||
commPtr->algorithmCollection = mscclpp::AlgorithmCollectionBuilder::getInstance()->build();
|
||||
@@ -462,12 +443,12 @@ NCCL_API ncclResult_t ncclCommDestroy(ncclComm_t comm) {
|
||||
}
|
||||
#endif
|
||||
|
||||
if (mscclppNcclDlopenSharedLib == true) {
|
||||
mscclppNcclOps.CommDestroy(*reinterpret_cast<ncclComm_t*>(comm->mscclppNcclComm));
|
||||
mscclppNcclDlopenFinalize();
|
||||
delete static_cast<ncclComm_t*>(comm->mscclppNcclComm);
|
||||
}
|
||||
ncclComm_t* mscclppNcclCommPtr = reinterpret_cast<ncclComm_t*>(comm->mscclppNcclComm);
|
||||
delete comm;
|
||||
if (mscclppNcclCommPtr != nullptr) {
|
||||
mscclppNcclOps.CommDestroy(*reinterpret_cast<ncclComm_t*>(mscclppNcclCommPtr));
|
||||
delete static_cast<ncclComm_t*>(mscclppNcclCommPtr);
|
||||
}
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
@@ -646,18 +627,13 @@ NCCL_API ncclResult_t ncclBroadcast(const void* sendbuff, void* recvbuff, size_t
|
||||
*reinterpret_cast<ncclComm_t*>(comm->mscclppNcclComm), stream);
|
||||
}
|
||||
|
||||
std::vector<executionPlanInstance>& plans = comm->executionPlans["broadcast"];
|
||||
std::shared_ptr<mscclpp::ExecutionPlan> plan;
|
||||
bool inPlace = sendbuff == recvbuff;
|
||||
for (const auto& p : plans) {
|
||||
if (bytes >= p.key.minMessageSize && bytes < p.key.maxMessageSize && inPlace == p.key.isInPlace) {
|
||||
plan = p.plan;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (plan != nullptr) {
|
||||
return executeWithPlan(comm->executor, rank, datatype, sendbuff, recvbuff, bytes, bytes, plan, stream);
|
||||
static std::unordered_map<std::string, std::vector<uint64_t>> hints{{"root", {static_cast<uint64_t>(root)}}};
|
||||
hints["root"][0] = static_cast<uint64_t>(root);
|
||||
auto planHandle = comm->planRegistry_->select("broadcast", comm->comm->bootstrap()->getNranks(),
|
||||
comm->comm->bootstrap()->getNranksPerNode(),
|
||||
comm->comm->bootstrap()->getRank(), sendbuff, recvbuff, bytes, hints);
|
||||
if (planHandle != nullptr) {
|
||||
return executeWithPlan(comm->executor, rank, datatype, sendbuff, recvbuff, bytes, bytes, planHandle->plan, stream);
|
||||
}
|
||||
auto algo = comm->algorithmCollection->selectAlgorithm(
|
||||
"broadcast", sendbuff, recvbuff, count * ncclTypeSize(datatype), datatype,
|
||||
@@ -706,18 +682,11 @@ NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t
|
||||
*reinterpret_cast<ncclComm_t*>(comm->mscclppNcclComm), stream);
|
||||
}
|
||||
|
||||
std::vector<executionPlanInstance>& plans = comm->executionPlans["allreduce"];
|
||||
std::shared_ptr<mscclpp::ExecutionPlan> plan;
|
||||
bool inPlace = sendbuff == recvbuff;
|
||||
for (const auto& p : plans) {
|
||||
if (bytes >= p.key.minMessageSize && bytes < p.key.maxMessageSize && inPlace == p.key.isInPlace) {
|
||||
plan = p.plan;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (plan != nullptr) {
|
||||
return executeWithPlan(comm->executor, rank, datatype, sendbuff, recvbuff, bytes, bytes, plan, stream);
|
||||
auto planHandler = comm->planRegistry_->select("allreduce", comm->comm->bootstrap()->getNranks(),
|
||||
comm->comm->bootstrap()->getNranksPerNode(),
|
||||
comm->comm->bootstrap()->getRank(), sendbuff, recvbuff, bytes, {});
|
||||
if (planHandler != nullptr) {
|
||||
return executeWithPlan(comm->executor, rank, datatype, sendbuff, recvbuff, bytes, bytes, planHandler->plan, stream);
|
||||
}
|
||||
|
||||
auto algo = comm->algorithmCollection->selectAlgorithm(
|
||||
@@ -769,20 +738,12 @@ NCCL_API ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff, si
|
||||
int rank = comm->comm->bootstrap()->getRank();
|
||||
int nRank = comm->comm->bootstrap()->getNranks();
|
||||
|
||||
std::vector<executionPlanInstance>& plans = comm->executionPlans["reducescatter"];
|
||||
std::shared_ptr<mscclpp::ExecutionPlan> plan;
|
||||
void* basePtr = (char*)sendbuff + rank * bytes;
|
||||
bool inPlace = basePtr == recvbuff;
|
||||
const size_t totalBytes = bytes * nRank;
|
||||
for (const auto& p : plans) {
|
||||
if (totalBytes >= p.key.minMessageSize && totalBytes < p.key.maxMessageSize && inPlace == p.key.isInPlace) {
|
||||
plan = p.plan;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (plan != nullptr) {
|
||||
return executeWithPlan(comm->executor, rank, datatype, sendbuff, recvbuff, totalBytes, bytes, plan, stream);
|
||||
auto planHandle = comm->planRegistry_->select("reducescatter", comm->comm->bootstrap()->getNranks(),
|
||||
comm->comm->bootstrap()->getNranksPerNode(),
|
||||
comm->comm->bootstrap()->getRank(), sendbuff, recvbuff, bytes, {});
|
||||
if (planHandle != nullptr) {
|
||||
return executeWithPlan(comm->executor, rank, datatype, sendbuff, recvbuff, bytes * nRank, bytes, planHandle->plan,
|
||||
stream);
|
||||
}
|
||||
|
||||
if (mscclppNcclDlopenSharedLib == true) {
|
||||
@@ -821,20 +782,12 @@ NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t
|
||||
*reinterpret_cast<ncclComm_t*>(comm->mscclppNcclComm), stream);
|
||||
}
|
||||
|
||||
std::vector<executionPlanInstance>& plans = comm->executionPlans["allgather"];
|
||||
std::shared_ptr<mscclpp::ExecutionPlan> plan;
|
||||
void* basePtr = (char*)sendbuff - rank * bytes;
|
||||
bool inPlace = basePtr == recvbuff;
|
||||
const size_t totalBytes = bytes * nRank;
|
||||
for (const auto& p : plans) {
|
||||
if (totalBytes >= p.key.minMessageSize && totalBytes < p.key.maxMessageSize && inPlace == p.key.isInPlace) {
|
||||
plan = p.plan;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (plan != nullptr) {
|
||||
return executeWithPlan(comm->executor, rank, datatype, sendbuff, recvbuff, bytes, totalBytes, plan, stream);
|
||||
auto planHandle = comm->planRegistry_->select("allgather", comm->comm->bootstrap()->getNranks(),
|
||||
comm->comm->bootstrap()->getNranksPerNode(),
|
||||
comm->comm->bootstrap()->getRank(), sendbuff, recvbuff, bytes, {});
|
||||
if (planHandle != nullptr) {
|
||||
return executeWithPlan(comm->executor, rank, datatype, sendbuff, recvbuff, bytes, bytes * nRank, planHandle->plan,
|
||||
stream);
|
||||
}
|
||||
|
||||
auto algo = comm->algorithmCollection->selectAlgorithm(
|
||||
|
||||
@@ -49,7 +49,7 @@ autodoc_default_options = {
|
||||
"show-inheritance": True,
|
||||
}
|
||||
# only mock the C-extension when using the source tree
|
||||
autodoc_mock_imports = ["mscclpp._version", "mscclpp._mscclpp", "cupy", "mpi4py", "numpy", "sortedcontainers"]
|
||||
autodoc_mock_imports = ["mscclpp._version", "mscclpp._mscclpp", "blake3", "cupy", "mpi4py", "numpy", "sortedcontainers"]
|
||||
autodoc_typehints = "description"
|
||||
napoleon_google_docstring = True
|
||||
napoleon_numpy_docstring = True
|
||||
|
||||
126
docs/guide/mscclpp-dsl-integration.md
Normal file
126
docs/guide/mscclpp-dsl-integration.md
Normal file
@@ -0,0 +1,126 @@
|
||||
# MSCCL++ DSL Integration Guide
|
||||
|
||||
MSCCL++ DSL (domain-specific language) enables concise expression of collective algorithms as Python functions.
|
||||
MSCCL++ offers Pythonic utilities to author, JIT-compile, register, and select execution plans. This guide walks through two integration paths: a customized MSCCL++ communicator, and NCCL interposition, which accelerates existing PyTorch `backend="nccl"` workloads without code changes.
|
||||
|
||||
## Initial Setup
|
||||
|
||||
Run the following from the repository root after completing the basic project setup:
|
||||
|
||||
1. Install Python dependencies.
|
||||
```bash
|
||||
pip install -r ./python/<requirements_file>
|
||||
```
|
||||
Replace `<requirements_file>` with the file that matches your environment (e.g., `requirements_cuda11.txt`, `requirements_cuda12.txt`, or `requirements_rocm6.txt`).
|
||||
|
||||
2. Install the module and generate default algorithm plans.
|
||||
```bash
|
||||
pip install . && python3 -m mscclpp --install
|
||||
```
|
||||
|
||||
## Integration Options
|
||||
|
||||
MSCCL++ DSL integrates into your training or inference workload in two ways:
|
||||
1. **Custom MSCCL++ Communicator** — directly manage an MSCCL++ communicator and launch collectives with the MSCCL++ executor.
|
||||
2. **NCCL Interposition** — keep using `backend="nccl"`; MSCCL++ intercepts NCCL calls at runtime for drop-in acceleration.
|
||||
|
||||
Both paths follow the same high-level flow:
|
||||
1. Author (or reuse) a collective algorithm with the MSCCL++ DSL.
|
||||
2. Compile it into an execution plan.
|
||||
3. Register the plan with the MSCCL++ runtime.
|
||||
4. Configure a selector to choose the plan for each collective call.
|
||||
|
||||
Below we show an AllReduce example and then detail each integration option.
|
||||
|
||||
### Example: AllReduce in the MSCCL++ DSL
|
||||
The snippet defines an AllReduce that uses NVLS for intra-node reduce-scatter followed by broadcast.
|
||||
```python
|
||||
def allreduce_nvls(spec: mscclpp.AlgoSpec) -> CollectiveProgram:
|
||||
gpu_size = spec.world_size
|
||||
with CollectiveProgram(
|
||||
spec.name,
|
||||
spec.collective,
|
||||
gpu_size,
|
||||
instances=8,
|
||||
protocol=spec.protocol,
|
||||
num_threads_per_block=spec.num_threads_per_block,
|
||||
min_message_size=spec.min_message_size,
|
||||
max_message_size=spec.max_message_size,
|
||||
) as program:
|
||||
# Creating Channels
|
||||
nvls_chan = SwitchChannel(rank_list=[gpu for gpu in range(gpu_size)], buffer_type=BufferType.input)
|
||||
channels = {}
|
||||
for gpu in range(gpu_size):
|
||||
for peer in range(gpu_size):
|
||||
if peer != gpu:
|
||||
channels[(peer, gpu)] = MemoryChannel(peer, gpu)
|
||||
|
||||
# Synchronization to Ensure all the Gpus are Ready
|
||||
for gpu in range(gpu_size):
|
||||
src_rank = gpu
|
||||
for peer in range(gpu_size):
|
||||
if peer != src_rank:
|
||||
dst_rank = peer
|
||||
channels[(dst_rank, src_rank)].signal(tb=0, relaxed=True)
|
||||
for peer in range(gpu_size):
|
||||
if peer != src_rank:
|
||||
dst_rank = peer
|
||||
channels[(dst_rank, src_rank)].wait(tb=0, relaxed=True, data_sync=SyncType.after)
|
||||
# Reducing and Storing the data
|
||||
for gpu in range(gpu_size):
|
||||
buffer_offset = gpu
|
||||
rank = Rank(gpu)
|
||||
input_buffer = rank.get_input_buffer()
|
||||
nvls_chan.at_rank(gpu).reduce(
|
||||
buffer_offset=buffer_offset, size=1, dst_chunk=input_buffer[gpu : gpu + 1], tb=0
|
||||
)
|
||||
nvls_chan.at_rank(gpu).broadcast(
|
||||
src_chunk=input_buffer[gpu : gpu + 1], buffer_offset=buffer_offset, size=1, tb=0
|
||||
)
|
||||
# Synchronization to Ensure the Gpus finished
|
||||
for gpu in range(gpu_size):
|
||||
src_rank = gpu
|
||||
for peer in range(gpu_size):
|
||||
if peer != src_rank:
|
||||
dst_rank = peer
|
||||
channels[(dst_rank, src_rank)].signal(tb=0, relaxed=True, data_sync=SyncType.before)
|
||||
for peer in range(gpu_size):
|
||||
if peer != src_rank:
|
||||
dst_rank = peer
|
||||
channels[(dst_rank, src_rank)].wait(tb=0, relaxed=True)
|
||||
|
||||
return program
|
||||
```
|
||||
|
||||
### Integrate with MSCCL++ customized communicator
|
||||
Use this option when you want a PyTorch-compatible interface with fine-grained control. You manage the communicator, compile and register DSL plans, and invoke collectives via a thin wrapper. The example below shows an AllReduce built on the MSCCL++ communicator and executor.
|
||||
Example source directory:
|
||||
```
|
||||
examples/torch-integration
|
||||
```
|
||||
Key file: `customized_comm.py`.
|
||||
|
||||
|
||||
#### Launch (single node)
|
||||
```bash
|
||||
MSCCLPP_MASTER_ADDR=<master_ip> MSCCLPP_MASTER_PORT=<port> torchrun --nnodes=1 --nproc_per_node=8 customized_comm.py
|
||||
```
|
||||
|
||||
### Integrate via NCCL Interposition
|
||||
Keep your script as-is: initialize PyTorch with `backend="nccl"`, and MSCCL++ intercepts the NCCL calls for drop-in acceleration.
|
||||
Example source directory:
|
||||
```
|
||||
examples/torch-integration
|
||||
```
|
||||
Key file: `dsl_with_nccl_api.py`.
|
||||
|
||||
#### Launch with interposition
|
||||
To run with NCCL interposition, you preload the MSCCL++ shim so it transparently intercepts NCCL calls made by PyTorch’s nccl backend.
|
||||
```bash
|
||||
LD_PRELOAD=<MSCCLPP_REPO>/build/apps/nccl/libmscclpp_nccl.so torchrun --nnodes=1 --nproc_per_node=8 dsl_with_nccl_api.py
|
||||
```
|
||||
## Notes
|
||||
- When using NCCL interposition, the algorithm selection order is:
|
||||
1. Check for registered DSL plans matching the collective call.
|
||||
2. Check for a customized kernel implementation if no DSL plan fits.
|
||||
3. Fall back to the default NCCL implementation (set `MSCCLPP_NCCL_LIB_PATH` to the original NCCL library).
|
||||
@@ -13,3 +13,4 @@ This section provides advanced topics and best practices for using MSCCL++. It i
|
||||
guide/cpp-examples
|
||||
guide/mscclpp-dsl
|
||||
guide/customized-algorithm-with-nccl-api
|
||||
guide/mscclpp-dsl-integration
|
||||
|
||||
201
examples/torch-integration/customized_comm.py
Normal file
201
examples/torch-integration/customized_comm.py
Normal file
@@ -0,0 +1,201 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# MSCCLPP_MASTER_ADDR=<master_ip> MSCCLPP_MASTER_PORT=<port> torchrun --nnodes=1 --nproc_per_node=8 customized_comm.py
|
||||
|
||||
import os
|
||||
import torch
|
||||
import mscclpp.comm as mscclpp_comm
|
||||
import mscclpp
|
||||
from mscclpp.language.collectives import AllReduce
|
||||
from mscclpp.language.channel import SwitchChannel, MemoryChannel, BufferType, SyncType
|
||||
from mscclpp.language.program import CollectiveProgram
|
||||
from mscclpp.language.rank import Rank
|
||||
import netifaces as ni
|
||||
import ipaddress
|
||||
|
||||
|
||||
def allreduce_nvls(spec: mscclpp.AlgoSpec) -> CollectiveProgram:
    """Describe an AllReduce in the MSCCL++ DSL using a switch channel plus memory channels.

    The program records, in order: a signal/wait handshake between every pair of
    distinct ranks, a switch-channel reduce followed by a broadcast of one chunk per
    rank, and a closing signal/wait handshake.

    Args:
        spec: Algorithm spec; its name, collective, world_size, protocol,
            num_threads_per_block and min/max message sizes are read.

    Returns:
        The populated ``CollectiveProgram``.
    """
    n_gpus = spec.world_size
    # NOTE(review): ``instances`` is fixed at 8 here and is not taken from ``spec`` — confirm intended.
    with CollectiveProgram(
        spec.name,
        spec.collective,
        n_gpus,
        instances=8,
        protocol=spec.protocol,
        num_threads_per_block=spec.num_threads_per_block,
        min_message_size=spec.min_message_size,
        max_message_size=spec.max_message_size,
    ) as program:
        all_ranks = list(range(n_gpus))

        # One switch channel over every rank, and one memory channel per ordered pair of distinct ranks.
        nvls_chan = SwitchChannel(rank_list=all_ranks, buffer_type=BufferType.input)
        channels = {
            (peer, gpu): MemoryChannel(peer, gpu) for gpu in all_ranks for peer in all_ranks if peer != gpu
        }

        # Opening handshake: each rank signals all of its peers, then waits on all of them.
        for gpu in all_ranks:
            others = [peer for peer in all_ranks if peer != gpu]
            for peer in others:
                channels[(peer, gpu)].signal(tb=0, relaxed=True)
            for peer in others:
                channels[(peer, gpu)].wait(tb=0, relaxed=True, data_sync=SyncType.after)

        # Each rank reduces chunk ``gpu`` through the switch channel, then broadcasts it.
        for gpu in all_ranks:
            input_buffer = Rank(gpu).get_input_buffer()
            switch = nvls_chan.at_rank(gpu)
            switch.reduce(buffer_offset=gpu, size=1, dst_chunk=input_buffer[gpu : gpu + 1], tb=0)
            nvls_chan.at_rank(gpu).broadcast(
                src_chunk=input_buffer[gpu : gpu + 1], buffer_offset=gpu, size=1, tb=0
            )

        # Closing handshake: same pattern, with the data sync placed before the signal.
        for gpu in all_ranks:
            others = [peer for peer in all_ranks if peer != gpu]
            for peer in others:
                channels[(peer, gpu)].signal(tb=0, relaxed=True, data_sync=SyncType.before)
            for peer in others:
                channels[(peer, gpu)].wait(tb=0, relaxed=True)

    return program
|
||||
|
||||
|
||||
def setup_plan(registry: mscclpp.ExecutionPlanRegistry, rank: int, world_size: int):
    """Compile the ``allreduce_nvls`` DSL program for ``rank`` and register it with ``registry``.

    Args:
        registry: Registry that receives the compiled plan handle.
        rank: Rank the plan is compiled for.
        world_size: World size written into the algorithm spec.
    """
    # NOTE(review): the collective and ``nranks_per_node`` are hard-coded to 8 regardless of
    # ``world_size`` — this matches the documented 8-process launch; confirm before reusing elsewhere.
    algo_spec = mscclpp.AlgoSpec(
        name="allreduce_nvls",
        collective=AllReduce(8, 1, True),
        nranks_per_node=8,
        world_size=world_size,
        in_place=True,
        instances=2,
        protocol="Simple",
        num_threads_per_block=1024,
        min_message_size=1 << 20,
        max_message_size=48 << 30,
        tags={"nvls": 1},
    )
    registry.register_plan(mscclpp.compile(algo=allreduce_nvls, algo_spec=algo_spec, rank=rank))
|
||||
|
||||
|
||||
def selector(plans, req):
    """Choose the execution plan for a collective request.

    Args:
        plans: Candidate plan handles; each exposes a ``tags`` container.
        req: The request; ``collective`` and ``message_size`` are read.

    Returns:
        ``None`` when the request is not an allreduce, when the message is smaller than
        ``1 << 20`` bytes, or when there are no candidates. Otherwise the first plan
        tagged ``"nvls"``, or the first candidate if none carries that tag.
    """
    if req.collective != "allreduce":
        return None
    if req.message_size < 1 << 20:
        return None
    if not plans:
        # No candidates: report "no plan" rather than raising IndexError on plans[0].
        return None
    return next((p for p in plans if "nvls" in p.tags), plans[0])
|
||||
|
||||
|
||||
def interfaces_for_ip_netifaces(ip: str):
    """Find the network interface that owns the IPv4 address ``ip``.

    Args:
        ip: IP address literal to look for.

    Returns:
        The name of the first interface with a matching ``AF_INET`` address, or
        ``None`` when no interface matches.

    Raises:
        ValueError: If ``ip`` (or a reported interface address) is not a valid IP literal.
    """
    wanted = ipaddress.ip_address(ip)
    for iface in ni.interfaces():
        per_family = ni.ifaddresses(iface)
        if ni.AF_INET not in per_family:
            continue
        for entry in per_family[ni.AF_INET]:
            if "addr" in entry and ipaddress.ip_address(entry["addr"]) == wanted:
                return iface
    return None
|
||||
|
||||
|
||||
def dtype_to_mscclpp_dtype(dtype: torch.dtype) -> mscclpp.DataType:
    """Translate a torch dtype into the matching ``mscclpp.DataType``.

    Supports float16, float32, int32 and bfloat16.

    Raises:
        ValueError: If ``dtype`` is not one of the supported types.
    """
    supported = (
        (torch.float16, mscclpp.DataType.float16),
        (torch.float32, mscclpp.DataType.float32),
        (torch.int32, mscclpp.DataType.int32),
        (torch.bfloat16, mscclpp.DataType.bfloat16),
    )
    for torch_dtype, mscclpp_dtype in supported:
        if dtype == torch_dtype:
            return mscclpp_dtype
    raise ValueError(f"Unknown data type: {dtype}")
|
||||
|
||||
|
||||
class CustomizedComm:
    """Wrapper that runs in-place AllReduce through an MSCCL++ executor.

    Plans are obtained from an ``mscclpp.ExecutionPlanRegistry`` and executed with an
    ``mscclpp.Executor`` built on the wrapped communication group.
    """

    def __init__(self, comm: mscclpp_comm.CommGroup):
        self.comm = comm
        self.rank = comm.my_rank
        self.world_size = comm.nranks
        self.n_ranks_per_node = comm.nranks_per_node
        self.local_rank = comm.my_rank % comm.nranks_per_node
        self.registry = mscclpp.ExecutionPlanRegistry()
        self.executor = mscclpp.Executor(comm.communicator)

    def all_reduce(self, tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM, stream: torch.cuda.Stream = None):
        """Sum-reduce ``tensor`` in place across ranks.

        Args:
            tensor: Tensor used as both send and receive buffer.
            op: Reduction op; only ``ReduceOp.SUM`` is accepted.
            stream: Optional CUDA stream; stream handle 0 is passed when omitted.

        Raises:
            ValueError: If the registry returns no plan for this message size.
        """
        assert op == torch.distributed.ReduceOp.SUM
        nbytes = tensor.numel() * tensor.element_size()
        handle = self.registry.select(
            collective="allreduce",
            world_size=self.world_size,
            n_ranks_per_node=self.n_ranks_per_node,
            send_buffer=tensor.data_ptr(),
            recv_buffer=tensor.data_ptr(),
            message_size=nbytes,
        )
        if handle is None:
            raise ValueError(f"No suitable plan found for collective allreduce with message size {nbytes}")
        # Same pointer and byte count for input and output: the reduction is in place.
        self.executor.execute(
            self.rank,
            tensor.data_ptr(),
            tensor.data_ptr(),
            nbytes,
            nbytes,
            dtype_to_mscclpp_dtype(tensor.dtype),
            handle.plan,
            0 if stream is None else stream.cuda_stream,
        )

    def barrier_cpu(self):
        """Run the wrapped communication group's ``barrier()``."""
        self.comm.barrier()
|
||||
|
||||
|
||||
def init_dist() -> CustomizedComm:
    """Register the DSL plan and selector, then build a ``CustomizedComm``.

    Reads ``RANK``, ``WORLD_SIZE``, ``MSCCLPP_MASTER_ADDR`` and ``MSCCLPP_MASTER_PORT``
    from the environment.

    Raises:
        KeyError: If a required environment variable is missing.
        ValueError: If no network interface carries the master address.
    """
    env = os.environ
    rank = int(env["RANK"])
    world = int(env["WORLD_SIZE"])
    master_addr = env["MSCCLPP_MASTER_ADDR"]
    master_port = env["MSCCLPP_MASTER_PORT"]

    interface = interfaces_for_ip_netifaces(master_addr)
    if interface is None:
        raise ValueError(f"Cannot find network interface for IP address {master_addr}")

    # The plan and selector are registered before the communication group is created.
    registry = mscclpp.ExecutionPlanRegistry()
    setup_plan(registry, rank, world)
    registry.set_selector(selector)

    group = mscclpp_comm.CommGroup(
        interfaceIpPortTrio=f"{interface}:{master_addr}:{master_port}", rank=rank, size=world
    )
    return CustomizedComm(group)
|
||||
|
||||
|
||||
def main():
    """Run one bfloat16 AllReduce over a 24 MiB MSCCL++ GPU buffer, bracketed by CPU barriers."""
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    comm = init_dist()
    comm.barrier_cpu()

    # Keep ``buffer`` bound for the whole function: ``x`` is a DLPack view of it.
    buffer = mscclpp.RawGpuBuffer(24 << 20)
    x = torch.utils.dlpack.from_dlpack(buffer.to_dlpack(data_type=str(torch.bfloat16)))
    x.normal_()

    comm.all_reduce(x, op=torch.distributed.ReduceOp.SUM)
    comm.barrier_cpu()
    comm = None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
117
examples/torch-integration/dsl_with_nccl_api.py
Normal file
117
examples/torch-integration/dsl_with_nccl_api.py
Normal file
@@ -0,0 +1,117 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# LD_PRELOAD=<MSCCLPP_REPO>/build/apps/nccl/libmscclpp_nccl.so torchrun --nnodes=1 --nproc_per_node=8 examples/torch-integration/dsl_with_nccl_api.py
|
||||
|
||||
import os
|
||||
import torch, torch.distributed as dist
|
||||
import mscclpp
|
||||
from mscclpp.language.collectives import AllReduce
|
||||
from mscclpp.language.channel import SwitchChannel, MemoryChannel, BufferType, SyncType
|
||||
from mscclpp.language.program import CollectiveProgram
|
||||
from mscclpp.language.rank import Rank
|
||||
|
||||
|
||||
def allreduce_nvls(spec: mscclpp.AlgoSpec) -> CollectiveProgram:
    """Describe an AllReduce in the MSCCL++ DSL using a switch channel plus memory channels.

    The program is created with ``CollectiveProgram.from_spec(spec)`` and records, in
    order: a signal/wait handshake between every pair of distinct ranks, a
    switch-channel reduce followed by a broadcast of one chunk per rank, and a closing
    signal/wait handshake.

    Args:
        spec: Algorithm spec; ``world_size`` is read here and the spec is passed to ``from_spec``.

    Returns:
        The populated ``CollectiveProgram``.
    """
    n_gpus = spec.world_size
    with CollectiveProgram.from_spec(spec) as program:
        all_ranks = list(range(n_gpus))

        # One switch channel over every rank, and one memory channel per ordered pair of distinct ranks.
        nvls_chan = SwitchChannel(rank_list=all_ranks, buffer_type=BufferType.input)
        channels = {
            (peer, gpu): MemoryChannel(peer, gpu) for gpu in all_ranks for peer in all_ranks if peer != gpu
        }

        # Opening handshake: each rank signals all of its peers, then waits on all of them.
        for gpu in all_ranks:
            others = [peer for peer in all_ranks if peer != gpu]
            for peer in others:
                channels[(peer, gpu)].signal(tb=0, relaxed=True)
            for peer in others:
                channels[(peer, gpu)].wait(tb=0, relaxed=True, data_sync=SyncType.after)

        # Each rank reduces chunk ``gpu`` through the switch channel, then broadcasts it.
        for gpu in all_ranks:
            input_buffer = Rank(gpu).get_input_buffer()
            switch = nvls_chan.at_rank(gpu)
            switch.reduce(buffer_offset=gpu, size=1, dst_chunk=input_buffer[gpu : gpu + 1], tb=0)
            nvls_chan.at_rank(gpu).broadcast(
                src_chunk=input_buffer[gpu : gpu + 1], buffer_offset=gpu, size=1, tb=0
            )

        # Closing handshake: same pattern, with the data sync placed before the signal.
        for gpu in all_ranks:
            others = [peer for peer in all_ranks if peer != gpu]
            for peer in others:
                channels[(peer, gpu)].signal(tb=0, relaxed=True, data_sync=SyncType.before)
            for peer in others:
                channels[(peer, gpu)].wait(tb=0, relaxed=True)

    return program
|
||||
|
||||
|
||||
def setup_plan(registry: mscclpp.ExecutionPlanRegistry, rank: int, world_size: int):
    """Compile the ``allreduce_nvls`` DSL program for ``rank`` and register it with ``registry``.

    Args:
        registry: Registry that receives the compiled plan handle.
        rank: Rank the plan is compiled for.
        world_size: World size written into the algorithm spec.
    """
    # NOTE(review): the collective and ``nranks_per_node`` are hard-coded to 8 regardless of
    # ``world_size`` — this matches the documented 8-process launch; confirm before reusing elsewhere.
    algo_spec = mscclpp.AlgoSpec(
        name="allreduce_nvls",
        collective=AllReduce(8, 1, True),
        nranks_per_node=8,
        world_size=world_size,
        in_place=True,
        instances=2,
        protocol="Simple",
        num_threads_per_block=1024,
        min_message_size=1 << 20,
        max_message_size=48 << 30,
        tags={"nvls": 1},
    )
    registry.register_plan(mscclpp.compile(algo=allreduce_nvls, algo_spec=algo_spec, rank=rank))
|
||||
|
||||
|
||||
def selector(plans, req):
    """Choose the execution plan for a collective request.

    Args:
        plans: Candidate plan handles; each exposes a ``tags`` container.
        req: The request; ``collective`` and ``message_size`` are read.

    Returns:
        ``None`` when the request is not an allreduce, when the message is smaller than
        ``1 << 20`` bytes, or when there are no candidates. Otherwise the first plan
        tagged ``"nvls"``, or the first candidate if none carries that tag.
    """
    if req.collective != "allreduce":
        return None
    if req.message_size < 1 << 20:
        return None
    if not plans:
        # No candidates: report "no plan" rather than raising IndexError on plans[0].
        return None
    return next((p for p in plans if "nvls" in p.tags), plans[0])
|
||||
|
||||
|
||||
def init_dist():
    """Register the DSL plan and selector, then initialize ``torch.distributed`` with the NCCL backend.

    Reads ``RANK``, ``WORLD_SIZE`` and ``LOCAL_RANK`` from the environment.

    Returns:
        Tuple ``(rank, world, local)``.
    """
    env = os.environ
    rank, world, local = (int(env[name]) for name in ("RANK", "WORLD_SIZE", "LOCAL_RANK"))

    # The plan and selector are registered before the process group is created.
    registry = mscclpp.ExecutionPlanRegistry()
    setup_plan(registry, rank, world)
    registry.set_selector(selector)

    dist.init_process_group(backend="nccl")
    return rank, world, local
|
||||
|
||||
|
||||
def main():
    """Run one bfloat16 AllReduce over a 24 MiB MSCCL++ GPU buffer via ``torch.distributed``."""
    local_rank = init_dist()[2]
    torch.cuda.set_device(local_rank)

    # Keep ``buffer`` bound for the whole function: ``x`` is a DLPack view of it.
    buffer = mscclpp.RawGpuBuffer(24 << 20)
    x = torch.utils.dlpack.from_dlpack(buffer.to_dlpack(data_type=str(torch.bfloat16)))
    x.normal_()

    dist.all_reduce(x, op=dist.ReduceOp.SUM)
    dist.barrier()
    dist.destroy_process_group()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -11,30 +11,51 @@
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
/// Data types supported by the executor.
|
||||
enum class DataType {
|
||||
INT32,
|
||||
UINT32,
|
||||
FLOAT16,
|
||||
FLOAT32,
|
||||
BFLOAT16,
|
||||
FP8_E4M3, // Add FP8 E4M3 type
|
||||
FP8_E5M2, // Add FP8 E5M2 type
|
||||
INT32, // 32-bit signed integer.
|
||||
UINT32, // 32-bit unsigned integer.
|
||||
FLOAT16, // IEEE 754 half precision.
|
||||
FLOAT32, // IEEE 754 single precision.
|
||||
BFLOAT16, // bfloat16 precision.
|
||||
FP8_E4M3, // FP8 with E4M3 layout.
|
||||
FP8_E5M2, // FP8 with E5M2 layout.
|
||||
};
|
||||
|
||||
/// Packet formats used by low-latency transport.
|
||||
enum class PacketType {
|
||||
LL8,
|
||||
LL16,
|
||||
LL8, // 8-byte low-latency packet.
|
||||
LL16, // 16-byte low-latency packet.
|
||||
};
|
||||
|
||||
/// Represents a compiled execution plan loaded from disk.
|
||||
///
|
||||
/// An ExecutionPlan encapsulates metadata about a collective algorithm such as its name, the
|
||||
/// collective it implements, and the supported message-size range. The concrete implementation
|
||||
/// is hidden behind the PIMPL pointer.
|
||||
class ExecutionPlan {
|
||||
public:
|
||||
/// Construct an ExecutionPlan by loading the plan file at `planPath`.
|
||||
/// @param planPath Filesystem path to the serialized plan.
|
||||
/// @param rank The rank of the current process.
|
||||
ExecutionPlan(const std::string& planPath, int rank);
|
||||
|
||||
/// Destructor.
|
||||
~ExecutionPlan() = default;
|
||||
|
||||
/// Return the human-readable name of the plan.
|
||||
std::string name() const;
|
||||
|
||||
/// Return the collective implemented by this plan (e.g., "allreduce", "allgather").
|
||||
std::string collective() const;
|
||||
|
||||
/// Minimum message size (in bytes) for which this plan is valid.
|
||||
size_t minMessageSize() const;
|
||||
|
||||
/// Maximum message size (in bytes) for which this plan is valid.
|
||||
size_t maxMessageSize() const;
|
||||
|
||||
/// Whether this plan performs the operation in-place.
|
||||
bool isInPlace() const;
|
||||
|
||||
private:
|
||||
@@ -44,13 +65,125 @@ class ExecutionPlan {
|
||||
friend class Executor;
|
||||
};
|
||||
|
||||
/// Request parameters provided when executing a plan.
|
||||
struct ExecutionRequest {
  int worldSize;            ///< World size of the request.
  int nRanksPerNode;        ///< Number of ranks per node of the request.
  int rank;                 ///< Rank issuing the request.
  const void* inputBuffer;  ///< Input buffer pointer (not owned).
  void* outputBuffer;       ///< Output buffer pointer (not owned).
  size_t messageSize;       ///< Message size of the request.

  /// Name of the requested collective.
  /// Held by reference: the referenced string must outlive this request.
  const std::string& collective;

  /// Caller-provided hints, keyed by name.
  /// Held by reference: the referenced map must outlive this request.
  const std::unordered_map<std::string, std::vector<uint64_t>>& hints;

  /// Whether the request indicates an in-place operation.
  bool isInPlace() const;
};
|
||||
|
||||
/// A handle representing a specific execution plan along with its constraints and metadata.
|
||||
struct ExecutionPlanHandle {
|
||||
/// Constraints that must be satisfied for the plan to be valid.
|
||||
struct Constraint {
|
||||
int worldSize;
|
||||
int nRanksPerNode;
|
||||
};
|
||||
|
||||
std::string id; /// Unique identifier for the handle.
|
||||
Constraint constraint; /// Constraints for plan applicability.
|
||||
std::shared_ptr<ExecutionPlan> plan; /// Backing ExecutionPlan instance.
|
||||
std::unordered_map<std::string, uint64_t> tags; /// Optional tags/metadata used by selector.
|
||||
|
||||
/// Create a new ExecutionPlanHandle.
|
||||
/// @param id Unique id for the handle.
|
||||
/// @param worldSize Required world size for the plan.
|
||||
/// @param nRanksPerNode Required ranks-per-node for the plan.
|
||||
/// @param plan The associated ExecutionPlan.
|
||||
/// @param tags Optional tags used for selection.
|
||||
static std::shared_ptr<ExecutionPlanHandle> create(const std::string& id, int worldSize, int nRanksPerNode,
|
||||
std::shared_ptr<ExecutionPlan> plan,
|
||||
const std::unordered_map<std::string, uint64_t>& tags = {});
|
||||
|
||||
/// Check whether the given ExecutionRequest satisfies this handle's parameters.
|
||||
/// @param request The execution request to evaluate.
|
||||
/// @return True if the request matches the handle parameters, false otherwise.
|
||||
bool match(const ExecutionRequest& request);
|
||||
};
|
||||
|
||||
/// Selector function type used to pick an ExecutionPlanHandle from a list of candidates.
|
||||
using ExecutionPlanSelector = std::function<std::shared_ptr<ExecutionPlanHandle>(
|
||||
const std::vector<std::shared_ptr<ExecutionPlanHandle>> plans, const ExecutionRequest& request)>;
|
||||
|
||||
/// Registry that holds available execution plans and performs selection logic.
|
||||
class ExecutionPlanRegistry {
|
||||
public:
|
||||
/// Retrieve the singleton instance of the registry.
|
||||
static std::shared_ptr<ExecutionPlanRegistry> getInstance();
|
||||
|
||||
/// Destructor.
|
||||
~ExecutionPlanRegistry();
|
||||
|
||||
/// Register a plan handle with the registry.
|
||||
void registerPlan(const std::shared_ptr<ExecutionPlanHandle> planHandle);
|
||||
|
||||
/// Get all plan handles for a given collective name.
|
||||
std::vector<std::shared_ptr<ExecutionPlanHandle>> getPlans(const std::string& collective);
|
||||
|
||||
/// Lookup a plan handle by id.
|
||||
std::shared_ptr<ExecutionPlanHandle> get(const std::string& id);
|
||||
|
||||
/// Select a suitable plan handle for the given parameters.
|
||||
std::shared_ptr<ExecutionPlanHandle> select(const std::string& collective, int worldSize, int nRanksPerNode, int rank,
|
||||
const void* sendBuffer, void* recvBuffer, size_t messageSize,
|
||||
const std::unordered_map<std::string, std::vector<uint64_t>>& hints);
|
||||
|
||||
/// Provide a custom selector function.
|
||||
void setSelector(ExecutionPlanSelector selector);
|
||||
|
||||
/// Set the default selector used when no custom selector is provided.
|
||||
void setDefaultSelector(ExecutionPlanSelector selector);
|
||||
|
||||
/// Load built-in/default plans for the given rank.
|
||||
void loadDefaultPlans(int rank);
|
||||
|
||||
/// Clear all registered plans from the registry.
|
||||
void clear();
|
||||
|
||||
private:
|
||||
struct Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
ExecutionPlanRegistry();
|
||||
};
|
||||
|
||||
/// High-level executor responsible for invoking execution plans on a communicator.
|
||||
class Executor {
|
||||
public:
|
||||
Executor(std::shared_ptr<Communicator> comm);
|
||||
/// Construct an Executor using the provided communicator.
|
||||
/// @param comm Communicator instance used for underlying communication.
|
||||
/// @param defaultScratchBuffer Optional scratch buffer used by some plans (may be nullptr).
|
||||
Executor(std::shared_ptr<Communicator> comm, std::shared_ptr<char> defaultScratchBuffer = nullptr);
|
||||
|
||||
/// Copy construction is disabled for Executor.
|
||||
Executor(const Executor&) = delete;
|
||||
|
||||
/// Copy assignment is disabled for Executor.
|
||||
Executor& operator=(const Executor&) = delete;
|
||||
|
||||
/// Destructor. Cleans up internal resources held by the Executor.
|
||||
~Executor();
|
||||
|
||||
/// Execute a plan.
|
||||
///
|
||||
/// This method dispatches the given plan on the provided CUDA stream.
|
||||
///
|
||||
/// @param rank Rank of the calling process.
|
||||
/// @param sendbuff Pointer to the send buffer.
|
||||
/// @param recvBuff Pointer to the receive buffer.
|
||||
/// @param sendBuffSize Size of the send buffer in bytes.
|
||||
/// @param recvBuffSize Size of the receive buffer in bytes.
|
||||
/// @param dataType Data type of elements in the buffers.
|
||||
/// @param plan The execution plan to run.
|
||||
/// @param stream CUDA stream to execute kernels/operations on.
|
||||
/// @param packetType Packet type used for low-latency transports (default: LL16).
|
||||
void execute(int rank, void* sendbuff, void* recvBuff, size_t sendBuffSize, size_t recvBuffSize, DataType dataType,
|
||||
const ExecutionPlan& plan, cudaStream_t stream, PacketType packetType = PacketType::LL16);
|
||||
|
||||
|
||||
@@ -50,6 +50,11 @@ class ProxyService : public BaseProxyService {
|
||||
/// @return The ID of the memory region.
|
||||
MemoryId addMemory(RegisteredMemory memory);
|
||||
|
||||
/// Get the next available memory ID.
|
||||
/// @param count The number of consecutive IDs required (default: 1).
|
||||
/// @return The first ID of an available range [first, first + count).
|
||||
MemoryId nextMemoryId(uint32_t count = 1) const;
|
||||
|
||||
/// Get a semaphore by ID.
|
||||
/// @param id The ID of the semaphore.
|
||||
/// @return The semaphore.
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
add_subdirectory(mscclpp)
|
||||
add_subdirectory(csrc)
|
||||
add_subdirectory(test)
|
||||
|
||||
add_custom_target(pytest_lib_copy ALL
|
||||
COMMAND ${CMAKE_COMMAND} -E copy_if_different
|
||||
${CMAKE_CURRENT_BINARY_DIR}/mscclpp/_mscclpp.*.so
|
||||
${CMAKE_CURRENT_BINARY_DIR}/csrc/_mscclpp.*.so
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/mscclpp
|
||||
COMMAND ${CMAKE_COMMAND} -E copy_if_different
|
||||
${CMAKE_CURRENT_BINARY_DIR}/test/_ext.*.so
|
||||
|
||||
81
python/csrc/executor_py.cpp
Normal file
81
python/csrc/executor_py.cpp
Normal file
@@ -0,0 +1,81 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/function.h>
|
||||
#include <nanobind/stl/shared_ptr.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
#include <nanobind/stl/unordered_map.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
#include <mscclpp/executor.hpp>
|
||||
#include <mscclpp/gpu.hpp>
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
/// Register the executor-related types (enums, plan handle/registry, plan, executor) on module `m`.
void register_executor(nb::module_& m) {
  // NOTE(review): DataType also declares FP8_E4M3 / FP8_E5M2, which are not exposed here.
  nb::enum_<DataType>(m, "DataType")
      .value("int32", DataType::INT32)
      .value("uint32", DataType::UINT32)
      .value("float16", DataType::FLOAT16)
      .value("float32", DataType::FLOAT32)
      .value("bfloat16", DataType::BFLOAT16);

  nb::enum_<PacketType>(m, "PacketType").value("LL8", PacketType::LL8).value("LL16", PacketType::LL16);

  // Buffer pointers are surfaced to Python as plain integer addresses.
  nb::class_<ExecutionRequest>(m, "ExecutionRequest")
      .def_ro("world_size", &ExecutionRequest::worldSize)
      .def_ro("n_ranks_per_node", &ExecutionRequest::nRanksPerNode)
      .def_prop_ro(
          "input_buffer",
          [](const ExecutionRequest& self) -> uintptr_t { return reinterpret_cast<uintptr_t>(self.inputBuffer); })
      .def_prop_ro(
          "output_buffer",
          [](const ExecutionRequest& self) -> uintptr_t { return reinterpret_cast<uintptr_t>(self.outputBuffer); })
      .def_ro("message_size", &ExecutionRequest::messageSize)
      // `collective` and `hints` are reference members, so they are exposed through getters.
      .def_prop_ro("collective", [](const ExecutionRequest& self) -> const std::string& { return self.collective; })
      .def_prop_ro("hints", [](const ExecutionRequest& self) { return self.hints; });

  nb::class_<ExecutionPlanHandle>(m, "ExecutionPlanHandle")
      .def_ro("id", &ExecutionPlanHandle::id)
      .def_ro("constraint", &ExecutionPlanHandle::constraint)
      .def_ro("plan", &ExecutionPlanHandle::plan)
      .def_ro("tags", &ExecutionPlanHandle::tags)
      .def_static("create", &ExecutionPlanHandle::create, nb::arg("id"), nb::arg("world_size"),
                  nb::arg("nranks_per_node"), nb::arg("plan"),
                  nb::arg("tags") = std::unordered_map<std::string, uint64_t>{});

  nb::class_<ExecutionPlanHandle::Constraint>(m, "ExecutionPlanConstraint")
      .def_ro("world_size", &ExecutionPlanHandle::Constraint::worldSize)
      .def_ro("n_ranks_per_node", &ExecutionPlanHandle::Constraint::nRanksPerNode);

  nb::class_<ExecutionPlanRegistry>(m, "ExecutionPlanRegistry")
      .def_static("get_instance", &ExecutionPlanRegistry::getInstance)
      .def("register_plan", &ExecutionPlanRegistry::registerPlan, nb::arg("planHandle"))
      .def("get_plans", &ExecutionPlanRegistry::getPlans, nb::arg("collective"))
      .def("get", &ExecutionPlanRegistry::get, nb::arg("id"))
      .def("set_selector", &ExecutionPlanRegistry::setSelector, nb::arg("selector"))
      .def("set_default_selector", &ExecutionPlanRegistry::setDefaultSelector, nb::arg("selector"))
      .def("clear", &ExecutionPlanRegistry::clear);

  nb::class_<ExecutionPlan>(m, "ExecutionPlan")
      .def(nb::init<const std::string&, int>(), nb::arg("planPath"), nb::arg("rank"))
      .def_prop_ro("name", [](const ExecutionPlan& self) -> std::string { return self.name(); })
      .def_prop_ro("collective", [](const ExecutionPlan& self) -> std::string { return self.collective(); })
      .def_prop_ro("min_message_size", [](const ExecutionPlan& self) -> size_t { return self.minMessageSize(); })
      .def_prop_ro("max_message_size", [](const ExecutionPlan& self) -> size_t { return self.maxMessageSize(); });

  nb::class_<Executor>(m, "Executor")
      .def(nb::init<std::shared_ptr<Communicator>>(), nb::arg("comm"))
      .def(
          "execute",
          // Buffers and the stream arrive from Python as integer addresses/handles.
          [](Executor* self, int rank, uintptr_t sendbuff, uintptr_t recvBuff, size_t sendBuffSize, size_t recvBuffSize,
             DataType dataType, const ExecutionPlan& plan, uintptr_t stream, PacketType packetType) {
            self->execute(rank, reinterpret_cast<void*>(sendbuff), reinterpret_cast<void*>(recvBuff), sendBuffSize,
                          recvBuffSize, dataType, plan, reinterpret_cast<cudaStream_t>(stream), packetType);
          },
          nb::arg("rank"), nb::arg("send_buff"), nb::arg("recv_buff"), nb::arg("send_buff_size"),
          nb::arg("recv_buff_size"), nb::arg("data_type"), nb::arg("plan"), nb::arg("stream"),
          nb::arg("packet_type") = PacketType::LL16);
}
|
||||
@@ -3,8 +3,20 @@
|
||||
|
||||
"""MSCCL++ Python API."""
|
||||
|
||||
import atexit
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property, wraps
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
import warnings
|
||||
|
||||
from blake3 import blake3
|
||||
|
||||
from mscclpp.language.program import CollectiveProgram
|
||||
from mscclpp.language.utils import AlgoSpec
|
||||
from functools import wraps
|
||||
from mscclpp._version import __version__, __commit_id__
|
||||
|
||||
@@ -49,11 +61,14 @@ from ._mscclpp import (
|
||||
DataType,
|
||||
Executor,
|
||||
ExecutionPlan,
|
||||
ExecutionPlanConstraint,
|
||||
PacketType,
|
||||
RawGpuBuffer,
|
||||
env,
|
||||
is_nvls_supported,
|
||||
npkit,
|
||||
ExecutionPlanHandle as _ExecutionPlanHandle,
|
||||
ExecutionPlanRegistry as _ExecutionPlanRegistry,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@@ -79,6 +94,9 @@ __all__ = [
|
||||
"Executor",
|
||||
"ExecutionPlan",
|
||||
"PacketType",
|
||||
"RawGpuBuffer",
|
||||
"env",
|
||||
"version",
|
||||
"is_nvls_supported",
|
||||
"alloc_shared_physical_cuda",
|
||||
"npkit",
|
||||
@@ -87,10 +105,6 @@ __all__ = [
|
||||
"version",
|
||||
"get_include",
|
||||
"get_lib",
|
||||
### Deprecated ###
|
||||
"ProxyChannel",
|
||||
"SmChannel",
|
||||
"SmDevice2DeviceSemaphore",
|
||||
]
|
||||
|
||||
|
||||
@@ -119,16 +133,193 @@ def deprecated(new_cls):
|
||||
return decorator
|
||||
|
||||
|
||||
@deprecated(PortChannel)
|
||||
class ProxyChannel(PortChannel):
|
||||
pass
|
||||
class ExecutionPlanHandle:
    """Python-side wrapper around a native ``_ExecutionPlanHandle``.

    Each property reads the corresponding attribute of the wrapped handle once and caches it.
    """

    def __init__(self, handle: "_ExecutionPlanHandle"):
        self._handle = handle

    @cached_property
    def id(self) -> str:
        """Identifier of the wrapped handle."""
        return self._handle.id

    @cached_property
    def tags(self) -> frozenset:
        """Tag names of the wrapped handle (the keys of its tag mapping)."""
        return frozenset(self._handle.tags)

    @cached_property
    def plan(self) -> "ExecutionPlan":
        """The execution plan held by the wrapped handle."""
        return self._handle.plan

    @cached_property
    def constraints(self) -> "ExecutionPlanConstraint":
        """Constraint object of the wrapped handle.

        The native binding exposes this attribute as ``constraint`` (singular).
        """
        return self._handle.constraint
|
||||
|
||||
|
||||
@deprecated(MemoryChannel)
|
||||
class SmChannel(MemoryChannel):
|
||||
pass
|
||||
@dataclass(frozen=True)
class ExecutionRequest:
    """Immutable bundle of the arguments of one plan-selection call, handed to the selector."""

    # Name of the requested collective.
    collective: str
    # World size of the request.
    world_size: int
    # Ranks per node of the request.
    n_ranks_per_node: int
    # Send buffer, as an integer (presumably a device address -- confirm with callers).
    send_buffer: int
    # Receive buffer, as an integer (presumably a device address -- confirm with callers).
    recv_buffer: int
    # Message size of the request.
    message_size: int
    # Free-form hints forwarded to the selector.
    hints: dict
|
||||
|
||||
|
||||
@deprecated(MemoryDevice2DeviceSemaphore)
|
||||
class SmDevice2DeviceSemaphore(MemoryDevice2DeviceSemaphore):
|
||||
pass
|
||||
class ExecutionPlanRegistry:
    """Singleton that mirrors the native plan registry on the Python side.

    Registered handles are kept in two Python dictionaries (by id and by collective name) and
    are also forwarded to the native registry. Selectors are stored here as well, so that
    `select` can run them against the Python wrappers.
    """

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(ExecutionPlanRegistry, cls).__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every ExecutionPlanRegistry() call; set the state up only once.
        if not hasattr(self, "_initialized"):
            self._registry = _ExecutionPlanRegistry.get_instance()
            self._id_map = {}
            self._collective_map = {}
            self._selector = None
            self._default_selector = None
            self._initialized = True

    def register_plan(self, plan: "ExecutionPlanHandle"):
        """Record `plan` under its id and collective, then register it natively.

        Registering a second handle with an id that is already known replaces the earlier
        handle instead of leaving two candidates with the same id.
        """
        previous = self._id_map.get(plan.id)
        if previous is not None:
            stale_bucket = self._collective_map.get(previous.plan.collective, [])
            if previous in stale_bucket:
                stale_bucket.remove(previous)
        self._id_map[plan.id] = plan
        self._collective_map.setdefault(plan.plan.collective, []).append(plan)
        return self._registry.register_plan(plan._handle)

    def set_selector(self, selector):
        """Install a custom selector; it takes precedence over the default selector."""
        self._selector = selector
        self._registry.set_selector(selector)

    def set_default_selector(self, selector):
        """Install the fallback selector, used only while no custom selector is set."""
        self._default_selector = selector
        self._registry.set_default_selector(selector)

    def get(self, id: str) -> "ExecutionPlanHandle":
        """Return the handle registered under `id`, or None when the id is unknown."""
        return self._id_map.get(id, None)

    def select(
        self,
        collective: str,
        world_size: int,
        n_ranks_per_node: int,
        send_buffer: int,
        recv_buffer: int,
        message_size: int,
        hints: dict = None,
    ) -> "ExecutionPlanHandle":
        """Run the active selector over the plans registered for `collective`.

        Returns None when no selector is installed or no plan is registered for the
        collective; otherwise returns whatever the selector returns.
        """
        selector = self._selector if self._selector is not None else self._default_selector
        if selector is None or collective not in self._collective_map:
            return None
        req = ExecutionRequest(
            collective=collective,
            world_size=world_size,
            n_ranks_per_node=n_ranks_per_node,
            send_buffer=send_buffer,
            recv_buffer=recv_buffer,
            message_size=message_size,
            # A fresh dict per call: a `{}` default would be shared across calls.
            hints={} if hints is None else hints,
        )
        return selector(self._collective_map[collective], req)

    @classmethod
    def reset_instance(cls):
        """Clear the native registry and the Python-side state, then drop the singleton."""
        if cls._instance is not None:
            cls._instance._registry.clear()
            cls._instance._id_map = {}
            cls._instance._collective_map = {}
            cls._instance._selector = None
            cls._instance._default_selector = None
            cls._instance = None
|
||||
|
||||
|
||||
atexit.register(ExecutionPlanRegistry.reset_instance)
|
||||
|
||||
_execution_plan_registry = ExecutionPlanRegistry()
|
||||
|
||||
|
||||
def _stable_json_bytes(obj: Any) -> bytes:
|
||||
return json.dumps(
|
||||
obj,
|
||||
sort_keys=True,
|
||||
ensure_ascii=False,
|
||||
separators=(",", ":"),
|
||||
).encode("utf-8")
|
||||
|
||||
|
||||
def compile(
    algo,
    algo_spec: "AlgoSpec",
    rank: int,
    **kwargs,
) -> "ExecutionPlanHandle":
    """Compile a MSCCL++ program from a high-level algorithm description.

    The plan is serialized to ``<plan_dir>/<plan_id>.json``, where ``plan_dir`` is
    ``$MSCCLPP_EXECUTION_PLAN_DIR`` or ``~/.cache/mscclpp``. An existing file with the same id
    is reused. The id is a hash of the package version, the spec's name, collective, tags and
    environment fields, the source of `algo`, and (when any are given) the extra keyword
    arguments.

    Args:
        algo: The high-level algorithm description (e.g., a function or class).
        algo_spec (AlgoSpec): Algorithm specification containing collective type,
            world size, ranks per node, instances, protocol, and other configuration.
        rank (int): The rank of the current process.
        **kwargs: Additional keyword arguments passed to the algorithm function.
    Returns:
        ExecutionPlanHandle: The compiled execution plan handle.
    Raises:
        ValueError: If the 'algo' argument is not callable.
    """
    if not callable(algo):
        raise ValueError("The 'algo' argument must be a callable (e.g., a function or class).")
    prog: CollectiveProgram = algo(
        algo_spec,
        **kwargs,
    )
    source = inspect.getsource(algo)

    source_hash = blake3(source.encode("utf-8")).hexdigest()
    id_fields = {
        "version": __version__,
        "algo_name": algo_spec.name,
        "collective": algo_spec.collective.name,
        "tags": sorted(algo_spec.tags.items()),
        "source_hash": source_hash,
        "envs": {
            "nranks_per_node": algo_spec.nranks_per_node,
            "world_size": algo_spec.world_size,
            "instances": algo_spec.instances,
            "protocol": algo_spec.protocol,
        },
    }
    if kwargs:
        # kwargs are passed to `algo` and can change the generated program, so two calls that
        # differ only in kwargs must not share a cached plan. The key is added only when
        # kwargs are present, which keeps ids of kwargs-free calls unchanged.
        id_fields["kwargs"] = {key: repr(value) for key, value in kwargs.items()}
    plan_id = blake3(_stable_json_bytes(id_fields)).hexdigest()

    plan_handle = _execution_plan_registry.get(plan_id)
    if plan_handle is not None:
        return plan_handle

    plan_dir = os.environ.get("MSCCLPP_EXECUTION_PLAN_DIR", Path.home() / ".cache/mscclpp")
    os.makedirs(plan_dir, exist_ok=True)
    filename = f"{plan_id}.json"
    plan_path = os.path.join(plan_dir, filename)
    if not os.path.exists(plan_path):
        # Write to a per-process temp file, then publish it in one atomic step.
        tmp_path = plan_path + f".tmp.{os.getpid()}"
        try:
            # TODO (binyli): Each rank could generate its own execution plan separately. Doesn't need to generate whole plan.
            prog.post_process_operations()
            # ensure_ascii=False may emit non-ASCII characters, so fix the file encoding.
            with open(tmp_path, "w", encoding="utf-8") as f:
                f.write(prog.to_json(indent=None, separators=(",", ":"), ensure_ascii=False))
                f.flush()
                os.fsync(f.fileno())
            os.replace(tmp_path, plan_path)
        except Exception:
            # Remove only our own temp file; never delete a plan another process published.
            Path(tmp_path).unlink(missing_ok=True)
            # If a peer published the same plan in the meantime we can still proceed.
            if not os.path.exists(plan_path):
                raise
    execution_plan = ExecutionPlan(plan_path, rank)
    handle = _ExecutionPlanHandle.create(
        id=plan_id,
        world_size=algo_spec.world_size,
        nranks_per_node=algo_spec.nranks_per_node,
        plan=execution_plan,
        tags=algo_spec.tags,
    )
    return ExecutionPlanHandle(handle)
|
||||
|
||||
77
python/mscclpp/__main__.py
Normal file
77
python/mscclpp/__main__.py
Normal file
@@ -0,0 +1,77 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
from mscclpp.language import default_algos as def_algo
|
||||
from mscclpp.language.collectives import *
|
||||
from mscclpp.language.utils import AlgoSpec
|
||||
|
||||
default_algo_configs = [
|
||||
{
|
||||
"filename": "allreduce_2nodes.json",
|
||||
"function": def_algo.allreduce_2nodes,
|
||||
"spec": AlgoSpec(
|
||||
name="allreduce_2nodes",
|
||||
collective=AllReduce(16, 1, True),
|
||||
nranks_per_node=8,
|
||||
world_size=16,
|
||||
in_place=True,
|
||||
instances=1,
|
||||
protocol="LL",
|
||||
auto_sync=False,
|
||||
num_threads_per_block=1024,
|
||||
reuse_resources=True,
|
||||
use_double_scratch_buffer=True,
|
||||
min_message_size=1 << 10,
|
||||
max_message_size=2 << 20,
|
||||
tags={"default": 1},
|
||||
),
|
||||
"additional_args": [4],
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def create_default_plans():
    """Regenerate every plan listed in `default_algo_configs` into the default plan directory.

    The directory is ``$MSCCLPP_EXECUTION_PLAN_DIR`` or ``~/.cache/mscclpp_default``. It is
    removed and recreated first. A plan that fails to generate is reported on stdout and
    skipped, leaving no file behind for it.
    """
    plan_dir = Path(os.environ.get("MSCCLPP_EXECUTION_PLAN_DIR", Path.home() / ".cache/mscclpp_default"))
    # NOTE(review): this deletes the whole directory, including one supplied through the
    # environment variable -- confirm that is intended.
    if plan_dir.exists():
        shutil.rmtree(plan_dir)
    plan_dir.mkdir(parents=True)

    for config in default_algo_configs:
        filename = config["filename"]
        func = config["function"]
        spec = config["spec"]
        additional_args = config.get("additional_args", [])
        target = plan_dir / filename

        try:
            prog = func(spec, *additional_args)
            # Serialize before opening the file so a failure here cannot leave an empty plan.
            plan_json = prog.to_json()
            with open(target, "w", encoding="utf-8") as f:
                f.write(plan_json)
        except Exception as e:
            # Drop a partially written file rather than leaving a corrupt plan on disk.
            target.unlink(missing_ok=True)
            print(f"Error creating plan for {spec.name}: {e}")
            continue
|
||||
|
||||
|
||||
def main():
    """Command-line entry point: with ``--install``, generate the default plans."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--install", action="store_true", help="flag to install default plans")
    options = cli.parse_args()

    if not options.install:
        return
    create_default_plans()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,43 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/shared_ptr.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
|
||||
#include <mscclpp/executor.hpp>
|
||||
#include <mscclpp/gpu.hpp>
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
void register_executor(nb::module_& m) {
|
||||
nb::enum_<DataType>(m, "DataType")
|
||||
.value("int32", DataType::INT32)
|
||||
.value("uint32", DataType::UINT32)
|
||||
.value("float16", DataType::FLOAT16)
|
||||
.value("float32", DataType::FLOAT32)
|
||||
.value("bfloat16", DataType::BFLOAT16);
|
||||
|
||||
nb::enum_<PacketType>(m, "PacketType").value("LL8", PacketType::LL8).value("LL16", PacketType::LL16);
|
||||
|
||||
nb::class_<ExecutionPlan>(m, "ExecutionPlan")
|
||||
.def(nb::init<const std::string&, int>(), nb::arg("planPath"), nb::arg("rank"))
|
||||
.def("name", &ExecutionPlan::name)
|
||||
.def("collective", &ExecutionPlan::collective)
|
||||
.def("min_message_size", &ExecutionPlan::minMessageSize)
|
||||
.def("max_message_size", &ExecutionPlan::maxMessageSize);
|
||||
|
||||
nb::class_<Executor>(m, "Executor")
|
||||
.def(nb::init<std::shared_ptr<Communicator>>(), nb::arg("comm"))
|
||||
.def(
|
||||
"execute",
|
||||
[](Executor* self, int rank, uintptr_t sendbuff, uintptr_t recvBuff, size_t sendBuffSize, size_t recvBuffSize,
|
||||
DataType dataType, const ExecutionPlan& plan, uintptr_t stream, PacketType packetType) {
|
||||
self->execute(rank, reinterpret_cast<void*>(sendbuff), reinterpret_cast<void*>(recvBuff), sendBuffSize,
|
||||
recvBuffSize, dataType, plan, (cudaStream_t)stream, packetType);
|
||||
},
|
||||
nb::arg("rank"), nb::arg("send_buff"), nb::arg("recv_buff"), nb::arg("send_buff_size"),
|
||||
nb::arg("recv_buff_size"), nb::arg("data_type"), nb::arg("plan"), nb::arg("stream"),
|
||||
nb::arg("packet_type") = PacketType::LL16);
|
||||
}
|
||||
@@ -26,6 +26,11 @@ class MemoryChannel:
|
||||
|
||||
_channel_counts = defaultdict(int)
|
||||
|
||||
@classmethod
|
||||
def reset(cls):
|
||||
"""Reset all channel counts for this channel type."""
|
||||
cls._channel_counts.clear()
|
||||
|
||||
def __init__(self, dst_rank: int, src_rank: int):
|
||||
"""Initialize a new MemoryChannel.
|
||||
|
||||
@@ -453,6 +458,11 @@ class PortChannel:
|
||||
|
||||
_channel_counts = defaultdict(int)
|
||||
|
||||
@classmethod
|
||||
def reset(cls):
|
||||
"""Reset all channel counts for this channel type."""
|
||||
cls._channel_counts.clear()
|
||||
|
||||
def __init__(self, dst_rank: int, src_rank: int):
|
||||
"""Initialize a new PortChannel.
|
||||
|
||||
@@ -741,6 +751,11 @@ class SwitchChannel:
|
||||
|
||||
_channel_counts = defaultdict(int)
|
||||
|
||||
@classmethod
|
||||
def reset(cls):
|
||||
"""Reset all channel counts for this channel type."""
|
||||
cls._channel_counts.clear()
|
||||
|
||||
def __init__(self, rank_list: List[int], buffer_type: BufferType):
|
||||
"""Initialize a new SwitchChannel.
|
||||
|
||||
|
||||
6
python/mscclpp/language/default_algos/__init__.py
Normal file
6
python/mscclpp/language/default_algos/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from mscclpp.language.default_algos.allreduce_2nodes import allreduce_2nodes
|
||||
|
||||
__all__ = ["allreduce_2nodes"]
|
||||
@@ -7,7 +7,7 @@ This implements a hierarchical AllReduce: intra-node allreduce followed by
|
||||
inter-node exchange and final intra-node allreduce.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from mscclpp.language.utils import AlgoSpec
|
||||
from mscclpp.language.channel import *
|
||||
from mscclpp.language.rank import *
|
||||
from mscclpp.language.general import *
|
||||
@@ -15,9 +15,7 @@ from mscclpp.language.program import *
|
||||
from mscclpp.language.collectives import *
|
||||
|
||||
|
||||
def allreduce_example(
|
||||
program_name, gpus_per_node, thread_block_group_size, num_threads_per_block, min_message_size, max_message_size
|
||||
):
|
||||
def allreduce_2nodes(spec: AlgoSpec, thread_block_group_size: int) -> CollectiveProgram:
|
||||
"""
|
||||
Implements a multi-node AllReduce using a hierarchical approach:
|
||||
1. Intra-node allreduce
|
||||
@@ -26,24 +24,11 @@ def allreduce_example(
|
||||
"""
|
||||
# Configuration constants
|
||||
num_nodes = 2
|
||||
gpus_per_node = spec.nranks_per_node
|
||||
total_gpus = num_nodes * gpus_per_node
|
||||
chunks_per_loop = 1
|
||||
packets_per_gpu = 2 # Each GPU handles 2 data packets
|
||||
packets_per_gpu = 2
|
||||
|
||||
# Initialize collective operation
|
||||
collective = AllReduce(total_gpus, chunks_per_loop, True)
|
||||
|
||||
with CollectiveProgram(
|
||||
program_name,
|
||||
collective,
|
||||
total_gpus,
|
||||
protocol="LL",
|
||||
num_threads_per_block=num_threads_per_block,
|
||||
reuse_resources=False,
|
||||
use_double_scratch_buffer=True,
|
||||
min_message_size=min_message_size,
|
||||
max_message_size=max_message_size,
|
||||
):
|
||||
with CollectiveProgram.from_spec(spec) as prog:
|
||||
# Initialize communication channels and buffers
|
||||
intra_node_memory_channels = {}
|
||||
inter_node_port_channels = {}
|
||||
@@ -175,25 +160,4 @@ def allreduce_example(
|
||||
tb_group=thread_block_group,
|
||||
)
|
||||
|
||||
print(JSON())
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--name", type=str, help="name of the program")
|
||||
parser.add_argument("--gpus_per_node", type=int, help="number of gpus per node")
|
||||
parser.add_argument("--tbg_size", type=int, help="number of thread blocks in the thread block group")
|
||||
parser.add_argument("--num_threads_per_block", type=int, default=1024, help="number of threads per block")
|
||||
parser.add_argument("--min_message_size", type=int, default=0, help="minimum message size")
|
||||
parser.add_argument("--max_message_size", type=int, default=2 * 2**20, help="maximum message size")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
allreduce_example(
|
||||
args.name,
|
||||
args.gpus_per_node,
|
||||
args.tbg_size,
|
||||
args.num_threads_per_block,
|
||||
args.min_message_size,
|
||||
args.max_message_size,
|
||||
)
|
||||
return prog
|
||||
@@ -16,5 +16,4 @@ def JSON():
|
||||
str: A JSON string representation of the current MSCCL++ program,
|
||||
including all ranks, operations, channels, and configuration.
|
||||
"""
|
||||
get_program().post_process_operations()
|
||||
return get_program().to_json()
|
||||
|
||||
@@ -6,6 +6,7 @@ from mscclpp.language.internal.optimizer import *
|
||||
from mscclpp.language.internal.buffer_access import *
|
||||
from dataclasses import dataclass, field
|
||||
from collections import OrderedDict
|
||||
from typing import List
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -88,7 +89,7 @@ class ThreadBlock:
|
||||
@dataclass
|
||||
class Channel:
|
||||
channel_type: ChannelType
|
||||
channel_ids: list[int] = field(default_factory=list)
|
||||
channel_ids: List[int] = field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {"channel_type": self.channel_type.value, "channel_ids": self.channel_ids}
|
||||
@@ -96,7 +97,7 @@ class ThreadBlock:
|
||||
@dataclass
|
||||
class RemoteBuffer:
|
||||
access_channel_type: ChannelType
|
||||
remote_buffer_ids: list[int] = field(default_factory=list)
|
||||
remote_buffer_ids: List[int] = field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import List, Set
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
class SyncType(Enum):
|
||||
|
||||
@@ -3,8 +3,12 @@
|
||||
|
||||
from mscclpp.language.collectives import Collective
|
||||
from mscclpp.language.internal.globals import set_program
|
||||
from mscclpp.language.internal.types import BufferType, RemoteBuffer, ChannelType, ReplicationPolicy
|
||||
from mscclpp.language.internal.types import BufferType, RemoteBuffer, ChannelType
|
||||
from mscclpp.language.internal.gpu import Gpu
|
||||
from mscclpp.language.channel import *
|
||||
from mscclpp.language.rank import Semaphore
|
||||
from mscclpp.language.collectives import *
|
||||
from mscclpp.language.utils import AlgoSpec, ReplicationPolicy
|
||||
from typing import List
|
||||
import json
|
||||
|
||||
@@ -108,6 +112,55 @@ class CollectiveProgram:
|
||||
|
||||
self.loop_context = None
|
||||
|
||||
@classmethod
|
||||
def from_spec(cls, spec: AlgoSpec):
|
||||
"""Initialize a new CollectiveProgram from an algorithm specification.
|
||||
|
||||
This constructor provides an alternative way to create a CollectiveProgram
|
||||
using an AlgoSpec object, which contains the complete algorithm specification
|
||||
including collective instance, protocol parameters, and optimization settings.
|
||||
The collective operation is directly provided through the spec's collective attribute.
|
||||
|
||||
Note:
|
||||
spec.nranks_per_node, spec.in_place and spec.tags are not forwarded to the
|
||||
constructor by this method; every other AlgoSpec field is passed through.
|
||||
|
||||
Args:
|
||||
spec (AlgoSpec): Algorithm specification containing all program parameters
|
||||
and configuration settings, including a Collective instance.
|
||||
|
||||
Raises:
|
||||
AssertionError: If protocol is not "Simple" or "LL".
|
||||
|
||||
Example:
|
||||
>>> from mscclpp.language.utils import AlgoSpec
|
||||
>>> from mscclpp.language.collectives import AllReduce
|
||||
>>> collective = AllReduce(num_ranks=4, chunk_factor=1, inplace=False)
|
||||
>>> spec = AlgoSpec(
|
||||
... name="my_allreduce",
|
||||
... collective=collective,
|
||||
... nranks_per_node=4,
|
||||
... world_size=4,
|
||||
... instances=1,
|
||||
... protocol="Simple",
|
||||
... in_place=False
|
||||
... )
|
||||
>>> with CollectiveProgram.from_spec(spec) as prog:
|
||||
... # Define communication operations
|
||||
... pass
|
||||
"""
|
||||
return cls(
|
||||
spec.name,
|
||||
spec.collective,
|
||||
spec.world_size,
|
||||
instances=spec.instances,
|
||||
protocol=spec.protocol,
|
||||
instr_fusion=spec.instr_fusion,
|
||||
auto_sync=spec.auto_sync,
|
||||
replication_policy=spec.replication_policy,
|
||||
reuse_resources=spec.reuse_resources,
|
||||
num_threads_per_block=spec.num_threads_per_block,
|
||||
use_double_scratch_buffer=spec.use_double_scratch_buffer,
|
||||
buffer_alignment=spec.buffer_alignment,
|
||||
min_message_size=spec.min_message_size,
|
||||
max_message_size=spec.max_message_size,
|
||||
)
|
||||
|
||||
def __enter__(self):
|
||||
"""Enter the program context and set this as the active program.
|
||||
|
||||
@@ -115,6 +168,7 @@ class CollectiveProgram:
|
||||
this program as the active program in the global context.
|
||||
"""
|
||||
set_program(self)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
"""Exit the program context and clear the active program.
|
||||
@@ -122,6 +176,10 @@ class CollectiveProgram:
|
||||
This method is called when exiting the 'with' statement and removes
|
||||
this program from the global context.
|
||||
"""
|
||||
MemoryChannel.reset()
|
||||
PortChannel.reset()
|
||||
SwitchChannel.reset()
|
||||
Semaphore.reset()
|
||||
set_program(None)
|
||||
|
||||
def add_channel(self, channel):
|
||||
@@ -175,7 +233,8 @@ class CollectiveProgram:
|
||||
raise RuntimeError("Nested Pipelines are not Supported.")
|
||||
self.loop_context = loop_context
|
||||
|
||||
def to_json(self):
|
||||
def to_json(self, indent=2, **kwargs):
|
||||
self.post_process_operations()
|
||||
json_obj = {
|
||||
"name": self.name,
|
||||
"collective": self.collective.name,
|
||||
@@ -190,4 +249,4 @@ class CollectiveProgram:
|
||||
"max_message_size": self.max_message_size,
|
||||
}
|
||||
|
||||
return json.dumps(json_obj, indent=2)
|
||||
return json.dumps(json_obj, indent=indent, **kwargs)
|
||||
|
||||
@@ -367,6 +367,11 @@ class Semaphore:
|
||||
|
||||
_semaphore_counts = defaultdict(int)
|
||||
|
||||
# Empties the class-level `_semaphore_counts` defaultdict, which is stored on the
# class itself and is therefore common to every Semaphore instance.
@classmethod
|
||||
def reset(cls):
|
||||
"""Reset all semaphore counts."""
|
||||
cls._semaphore_counts.clear()
|
||||
|
||||
def __init__(self, rank: int, initial_value: int):
|
||||
"""Initialize a new Semaphore.
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ def allreduce_example(name, gpu_size, num_threads_per_block, min_message_size, m
|
||||
input_buffer = rank.get_input_buffer()
|
||||
for peer in range(gpu_size):
|
||||
if peer != gpu:
|
||||
channels[(peer, gpu)].put_packet(
|
||||
channels[(peer, gpu)].put_packets(
|
||||
scratch_buffer[peer][gpu : gpu + 1], input_buffer[peer : peer + 1], 0
|
||||
)
|
||||
|
||||
@@ -55,7 +55,7 @@ def allreduce_example(name, gpu_size, num_threads_per_block, min_message_size, m
|
||||
rank.reduce(input_buffer[gpu : gpu + 1], chunks, 0, packet=True)
|
||||
for peer in range(gpu_size):
|
||||
if peer != gpu:
|
||||
channels[(peer, gpu)].put_packet(
|
||||
channels[(peer, gpu)].put_packets(
|
||||
scratch_buffer[peer][gpu_size + gpu : gpu_size + gpu + 1], input_buffer[gpu : gpu + 1], 0
|
||||
)
|
||||
|
||||
@@ -65,7 +65,7 @@ def allreduce_example(name, gpu_size, num_threads_per_block, min_message_size, m
|
||||
input_buffer = rank.get_input_buffer()
|
||||
for peer in range(gpu_size):
|
||||
if peer != gpu:
|
||||
rank.unpack_packet(
|
||||
rank.unpack_packets(
|
||||
input_buffer[peer : peer + 1], scratch_buffer[gpu][gpu_size + peer : gpu_size + peer + 1], 0
|
||||
)
|
||||
|
||||
|
||||
35
python/mscclpp/language/utils.py
Normal file
35
python/mscclpp/language/utils.py
Normal file
@@ -0,0 +1,35 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass, field
|
||||
from mscclpp.language.collectives import Collective
|
||||
|
||||
|
||||
# String-valued enumeration with two members, "interleaved" and "none".
# str(member) returns the raw string value instead of the default
# "ReplicationPolicy.<name>" representation.
class ReplicationPolicy(Enum):
|
||||
interleaved = "interleaved"
|
||||
none = "none"
|
||||
|
||||
def __str__(self):
|
||||
return self.value
|
||||
|
||||
|
||||
# Immutable (frozen) record describing one algorithm configuration.
# The first seven fields (name .. protocol) have no default and must be supplied;
# the remaining fields are optional tuning knobs with the defaults shown below.
# `tags` uses default_factory so that each instance gets its own dict.
@dataclass(frozen=True)
|
||||
class AlgoSpec:
|
||||
name: str
|
||||
collective: Collective
|
||||
nranks_per_node: int
|
||||
world_size: int
|
||||
in_place: bool
|
||||
instances: int
|
||||
protocol: str
|
||||
instr_fusion: bool = True
|
||||
auto_sync: bool = True
|
||||
replication_policy: ReplicationPolicy = ReplicationPolicy.interleaved
|
||||
reuse_resources: bool = False
|
||||
num_threads_per_block: int = 1024
|
||||
use_double_scratch_buffer: bool = False
|
||||
buffer_alignment: int = 16
|
||||
min_message_size: int = 0
|
||||
# Largest value representable in an unsigned 64-bit integer.
max_message_size: int = 2**64 - 1
|
||||
tags: dict = field(default_factory=dict)
|
||||
@@ -6,3 +6,4 @@ pytest
|
||||
numpy
|
||||
matplotlib
|
||||
sortedcontainers @ git+https://github.com/grantjenks/python-sortedcontainers.git@3ac358631f58c1347f1d6d2d92784117db0f38ed
|
||||
blake3
|
||||
@@ -5,4 +5,5 @@ netifaces
|
||||
pytest
|
||||
numpy
|
||||
matplotlib
|
||||
sortedcontainers @ git+https://github.com/grantjenks/python-sortedcontainers.git@3ac358631f58c1347f1d6d2d92784117db0f38ed
|
||||
sortedcontainers @ git+https://github.com/grantjenks/python-sortedcontainers.git@3ac358631f58c1347f1d6d2d92784117db0f38ed
|
||||
blake3
|
||||
@@ -187,7 +187,7 @@ def main(
|
||||
if npkit_dump_dir != "":
|
||||
npkit.init(mscclpp_group.my_rank)
|
||||
execution_plan = ExecutionPlan(execution_plan_path, mscclpp_group.my_rank)
|
||||
collective = execution_plan.collective()
|
||||
collective = execution_plan.collective
|
||||
|
||||
dtype = parse_dtype(dtype_str)
|
||||
input_buf, result_buf, test_buf = build_bufs(
|
||||
|
||||
15
src/env.cpp
15
src/env.cpp
@@ -12,8 +12,8 @@
|
||||
#include "debug.h"
|
||||
|
||||
template <typename T>
|
||||
T readEnv(const std::string &envName, const T &defaultValue) {
|
||||
const char *envCstr = getenv(envName.c_str());
|
||||
T readEnv(const std::string& envName, const T& defaultValue) {
|
||||
const char* envCstr = getenv(envName.c_str());
|
||||
if (envCstr == nullptr) return defaultValue;
|
||||
if constexpr (std::is_same_v<T, int>) {
|
||||
return atoi(envCstr);
|
||||
@@ -24,8 +24,8 @@ T readEnv(const std::string &envName, const T &defaultValue) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void readAndSetEnv(const std::string &envName, T &env) {
|
||||
const char *envCstr = getenv(envName.c_str());
|
||||
void readAndSetEnv(const std::string& envName, T& env) {
|
||||
const char* envCstr = getenv(envName.c_str());
|
||||
if (envCstr == nullptr) return;
|
||||
if constexpr (std::is_same_v<T, int>) {
|
||||
env = atoi(envCstr);
|
||||
@@ -37,13 +37,13 @@ void readAndSetEnv(const std::string &envName, T &env) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void logEnv(const std::string &envName, const T &env) {
|
||||
void logEnv(const std::string& envName, const T& env) {
|
||||
if (!getenv(envName.c_str())) return;
|
||||
INFO(MSCCLPP_ENV, "%s=%d", envName.c_str(), env);
|
||||
}
|
||||
|
||||
template <>
|
||||
void logEnv(const std::string &envName, const std::string &env) {
|
||||
void logEnv(const std::string& envName, const std::string& env) {
|
||||
if (!getenv(envName.c_str())) return;
|
||||
INFO(MSCCLPP_ENV, "%s=%s", envName.c_str(), env.c_str());
|
||||
}
|
||||
@@ -59,7 +59,8 @@ Env::Env()
|
||||
socketFamily(readEnv<std::string>("MSCCLPP_SOCKET_FAMILY", "")),
|
||||
socketIfname(readEnv<std::string>("MSCCLPP_SOCKET_IFNAME", "")),
|
||||
commId(readEnv<std::string>("MSCCLPP_COMM_ID", "")),
|
||||
executionPlanDir(readEnv<std::string>("MSCCLPP_EXECUTION_PLAN_DIR", "")),
|
||||
executionPlanDir(readEnv<std::string>("MSCCLPP_EXECUTION_PLAN_DIR",
|
||||
readEnv<std::string>("HOME", "~") + "/.cache/mscclpp_default")),
|
||||
npkitDumpDir(readEnv<std::string>("MSCCLPP_NPKIT_DUMP_DIR", "")),
|
||||
cudaIpcUseDefaultStream(readEnv<bool>("MSCCLPP_CUDAIPC_USE_DEFAULT_STREAM", false)),
|
||||
ncclSharedLibPath(readEnv<std::string>("MSCCLPP_NCCL_LIB_PATH", "")),
|
||||
|
||||
@@ -8,13 +8,14 @@ namespace mscclpp {
|
||||
|
||||
template <typename PacketType, bool ReuseScratch>
|
||||
void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, void* src, void* dst, void* scratch,
|
||||
uint32_t scratchOffset, uint32_t scrachChunkSize, DataType dataType,
|
||||
DeviceExecutionPlan* plan, DeviceSemaphore* semaphores, uint32_t sharedMemSize,
|
||||
cudaStream_t stream, uint32_t flag) {
|
||||
uint32_t scratchOffset, uint32_t scratchChunkSize, DataType dataType,
|
||||
DeviceExecutionPlan* plan, DeviceSemaphore* semaphores, uint32_t localMemoryIdBegin,
|
||||
uint32_t sharedMemSize, cudaStream_t stream, uint32_t flag) {
|
||||
switch (dataType) {
|
||||
case DataType::INT32:
|
||||
executionKernel<int32_t, PacketType, ReuseScratch><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
|
||||
rank, (int32_t*)src, (int32_t*)dst, (int32_t*)scratch, scratchOffset, scrachChunkSize, plan, semaphores, flag
|
||||
rank, (int32_t*)src, (int32_t*)dst, (int32_t*)scratch, scratchOffset, scratchChunkSize, plan, semaphores,
|
||||
localMemoryIdBegin, flag
|
||||
#if defined(ENABLE_NPKIT)
|
||||
,
|
||||
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
|
||||
@@ -24,8 +25,8 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo
|
||||
break;
|
||||
case DataType::UINT32:
|
||||
executionKernel<uint32_t, PacketType, ReuseScratch><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
|
||||
rank, (uint32_t*)src, (uint32_t*)dst, (uint32_t*)scratch, scratchOffset, scrachChunkSize, plan, semaphores,
|
||||
flag
|
||||
rank, (uint32_t*)src, (uint32_t*)dst, (uint32_t*)scratch, scratchOffset, scratchChunkSize, plan, semaphores,
|
||||
localMemoryIdBegin, flag
|
||||
#if defined(ENABLE_NPKIT)
|
||||
,
|
||||
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
|
||||
@@ -35,7 +36,8 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo
|
||||
break;
|
||||
case DataType::FLOAT16:
|
||||
executionKernel<half, PacketType, ReuseScratch><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
|
||||
rank, (half*)src, (half*)dst, (half*)scratch, scratchOffset, scrachChunkSize, plan, semaphores, flag
|
||||
rank, (half*)src, (half*)dst, (half*)scratch, scratchOffset, scratchChunkSize, plan, semaphores,
|
||||
localMemoryIdBegin, flag
|
||||
#if defined(ENABLE_NPKIT)
|
||||
,
|
||||
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
|
||||
@@ -45,7 +47,8 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo
|
||||
break;
|
||||
case DataType::FLOAT32:
|
||||
executionKernel<float, PacketType, ReuseScratch><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
|
||||
rank, (float*)src, (float*)dst, (float*)scratch, scratchOffset, scrachChunkSize, plan, semaphores, flag
|
||||
rank, (float*)src, (float*)dst, (float*)scratch, scratchOffset, scratchChunkSize, plan, semaphores,
|
||||
localMemoryIdBegin, flag
|
||||
#if defined(ENABLE_NPKIT)
|
||||
,
|
||||
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
|
||||
@@ -55,8 +58,8 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo
|
||||
break;
|
||||
case DataType::BFLOAT16:
|
||||
executionKernel<__bfloat16, PacketType, ReuseScratch><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
|
||||
rank, (__bfloat16*)src, (__bfloat16*)dst, (__bfloat16*)scratch, scratchOffset, scrachChunkSize, plan,
|
||||
semaphores, flag
|
||||
rank, (__bfloat16*)src, (__bfloat16*)dst, (__bfloat16*)scratch, scratchOffset, scratchChunkSize, plan,
|
||||
semaphores, localMemoryIdBegin, flag
|
||||
#if defined(ENABLE_NPKIT)
|
||||
,
|
||||
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
|
||||
@@ -67,11 +70,11 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo
|
||||
}
|
||||
}
|
||||
|
||||
#define INSTANTIATE_LAUNCH(PKT, REUSE) \
|
||||
template void ExecutionKernel::launchKernel<PKT, REUSE>( \
|
||||
int rank, int nthreadblocks, int nthreads, void* src, void* dst, void* scratch, uint32_t scatchOffset, \
|
||||
uint32_t scratchChunkSize, DataType dataType, DeviceExecutionPlan* plan, DeviceSemaphore* semaphores, \
|
||||
uint32_t sharedMemSize, cudaStream_t stream, uint32_t flag);
|
||||
#define INSTANTIATE_LAUNCH(PKT, REUSE) \
|
||||
template void ExecutionKernel::launchKernel<PKT, REUSE>( \
|
||||
int rank, int nthreadblocks, int nthreads, void* src, void* dst, void* scratch, uint32_t scratchOffset, \
|
||||
uint32_t scratchChunkSize, DataType dataType, DeviceExecutionPlan* plan, DeviceSemaphore* semaphores, \
|
||||
uint32_t localMemoryIdBegin, uint32_t sharedMemSize, cudaStream_t stream, uint32_t flag);
|
||||
|
||||
INSTANTIATE_LAUNCH(LL16Packet, true)
|
||||
INSTANTIATE_LAUNCH(LL8Packet, true)
|
||||
|
||||
@@ -4,10 +4,30 @@
|
||||
#include "execution_plan.hpp"
|
||||
|
||||
#include <cassert>
|
||||
#include <cstdlib>
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <iomanip>
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
|
||||
#include "debug.h"
|
||||
|
||||
namespace {
|
||||
|
||||
// Built-in plan descriptors. Each entry is brace-initialized as an mscclpp::AlgoConfig.
// NOTE(review): the meaning of the two integers (8, 16) depends on the member order of
// AlgoConfig, which is declared elsewhere; confirm against that declaration.
static const std::vector<mscclpp::AlgoConfig> defaultAlgoConfigs = {
|
||||
{"allreduce_2nodes.json", "allreduce", 8, 16, {{"default", 1}}}};
|
||||
|
||||
// Hashes `input` with std::hash<std::string> and returns the value formatted as a
// hexadecimal string (std::hex on a fresh stream: lowercase digits, no "0x" prefix,
// no zero padding).
// NOTE(review): std::hash results are implementation-defined, so the returned string
// is presumably only meaningful within one build/process; confirm before persisting it.
std::string simpleHash(const std::string& input) {
|
||||
std::hash<std::string> hasher;
|
||||
size_t hashValue = hasher(input);
|
||||
std::ostringstream oss;
|
||||
oss << std::hex << hashValue;
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
// Derives an identifier from a file path by hashing the path string itself
// (the file contents are not read).
std::string generateFileId(const std::string& filePath) { return simpleHash(filePath); }
|
||||
|
||||
template <typename T, typename Predicate>
|
||||
std::vector<T> filter(const std::vector<T>& vec, Predicate pred) {
|
||||
std::vector<T> filtered;
|
||||
@@ -69,7 +89,7 @@ auto getOpType = [](const std::string& str) {
|
||||
} else if (str == "sem_release") {
|
||||
return mscclpp::OperationType::SEM_RELEASE;
|
||||
} else {
|
||||
throw mscclpp::Error("Invalid operation type", mscclpp::ErrorCode::ExecutorError);
|
||||
throw mscclpp::Error("Invalid operation type: " + str, mscclpp::ErrorCode::ExecutorError);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -684,4 +704,149 @@ size_t ExecutionPlan::maxMessageSize() const { return this->impl_->maxMessageSiz
|
||||
|
||||
bool ExecutionPlan::isInPlace() const { return this->impl_->isInPlace; }
|
||||
|
||||
// Stores the selector in selector_, replacing any previously stored one.
void ExecutionPlanRegistry::Impl::setSelector(ExecutionPlanSelector selector) { selector_ = selector; }
|
||||
|
||||
// Stores the selector in defaultSelector_, replacing any previously stored one.
void ExecutionPlanRegistry::Impl::setDefaultSelector(ExecutionPlanSelector selector) { defaultSelector_ = selector; }
|
||||
|
||||
// Picks an execution plan for `request`.
//
// Steps:
//   1. Gather every handle stored under request.collective whose match() accepts the request.
//   2. If selector_ is set, pass it the candidates; return its result when non-null.
//   3. Otherwise (or if it returned null) do the same with defaultSelector_.
//   4. If neither produced a handle, log at INFO level and return nullptr.
//
// The selectors are invoked even when the candidate list is empty.
// NOTE(review): planMap_[...] is an operator[] lookup; if planMap_ is a standard map type this
// inserts an empty entry for a collective that was never registered. Confirm that is acceptable.
std::shared_ptr<ExecutionPlanHandle> ExecutionPlanRegistry::Impl::select(const ExecutionRequest& request) {
|
||||
std::vector<std::shared_ptr<ExecutionPlanHandle>> plans;
|
||||
for (auto plan : planMap_[request.collective]) {
|
||||
if (plan->match(request)) {
|
||||
plans.push_back(plan);
|
||||
}
|
||||
}
|
||||
// selector_ takes precedence over the default selector.
if (selector_) {
|
||||
auto plan = selector_(plans, request);
|
||||
if (plan) {
|
||||
return plan;
|
||||
}
|
||||
}
|
||||
// Fallback: reached when selector_ is unset or returned a null handle.
if (defaultSelector_) {
|
||||
auto plan = defaultSelector_(plans, request);
|
||||
if (plan) {
|
||||
return plan;
|
||||
}
|
||||
}
|
||||
INFO(MSCCLPP_EXECUTOR, "No suitable execution plan found for collective: %s", request.collective.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Adds `planHandle` to the registry.
//
// Throws Error(ErrorCode::ExecutorError) when the handle is null. Otherwise the handle is
// appended to the list kept under its plan's collective name and recorded in idMap_ under its id.
// NOTE(review): no duplicate-id check is done here, so registering the same id twice presumably
// overwrites the idMap_ entry while leaving both handles in planMap_. Confirm this is intended.
void ExecutionPlanRegistry::Impl::registerPlan(const std::shared_ptr<ExecutionPlanHandle> planHandle) {
|
||||
if (!planHandle) {
|
||||
throw Error("Cannot register a null plan", ErrorCode::ExecutorError);
|
||||
}
|
||||
planMap_[planHandle->plan->collective()].push_back(planHandle);
|
||||
idMap_[planHandle->id] = planHandle;
|
||||
}
|
||||
|
||||
// Loads the built-in plans listed in defaultAlgoConfigs from the directory given by
// env()->executionPlanDir and registers them for `rank`.
//
// Behavior:
//   - If the directory does not exist, logs at INFO level and returns without loading anything.
//   - For each config, a missing plan file or an id that is already in idMap_ is logged and skipped.
//   - The plan id is derived from the plan file path via generateFileId().
//   - Any std::exception raised while constructing or registering a plan is caught and reported
//     with WARN, so a single bad plan does not stop the remaining configs from being processed.
void ExecutionPlanRegistry::Impl::loadDefaultPlans(int rank) {
|
||||
std::string planDir = mscclpp::env()->executionPlanDir;
|
||||
if (!std::filesystem::exists(planDir)) {
|
||||
INFO(MSCCLPP_EXECUTOR, "Plan directory does not exist: %s", planDir.c_str());
|
||||
return;
|
||||
}
|
||||
|
||||
for (const auto& config : defaultAlgoConfigs) {
|
||||
std::string planPath = planDir + "/" + config.filename;
|
||||
INFO(MSCCLPP_EXECUTOR, "Loading plan: %s", planPath.c_str());
|
||||
if (!std::filesystem::exists(planPath)) {
|
||||
INFO(MSCCLPP_EXECUTOR, "Plan file does not exist: %s", planPath.c_str());
|
||||
continue;
|
||||
}
|
||||
std::string planId = generateFileId(planPath);
|
||||
// Skip plans whose id is already known so repeated calls do not register duplicates.
if (idMap_.find(planId) != idMap_.end()) {
|
||||
INFO(MSCCLPP_EXECUTOR, "Plan already registered: %s", planId.c_str());
|
||||
continue;
|
||||
}
|
||||
try {
|
||||
auto executionPlan = std::make_shared<ExecutionPlan>(planPath, rank);
|
||||
auto handle =
|
||||
ExecutionPlanHandle::create(planId, config.worldSize, config.nRanksPerNode, executionPlan, config.tags);
|
||||
registerPlan(handle);
|
||||
INFO(MSCCLPP_EXECUTOR, "Successfully loaded plan: %s for collective: %s", planId.c_str(),
|
||||
config.collective.c_str());
|
||||
} catch (const std::exception& e) {
|
||||
WARN("Failed to load plan %s: %s", planPath.c_str(), e.what());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the shared registry instance. The instance is a function-local static, so it is
// created on the first call and the same pointer is returned on every later call.
// NOTE(review): `new` is used instead of std::make_shared, presumably because the constructor
// is not publicly accessible; confirm against the class declaration.
std::shared_ptr<ExecutionPlanRegistry> ExecutionPlanRegistry::getInstance() {
|
||||
static std::shared_ptr<ExecutionPlanRegistry> instance(new ExecutionPlanRegistry);
|
||||
return instance;
|
||||
}
|
||||
|
||||
// Public entry point for registering a plan handle; forwards to Impl::registerPlan.
void ExecutionPlanRegistry::registerPlan(const std::shared_ptr<ExecutionPlanHandle> planHandle) {
|
||||
impl_->registerPlan(planHandle);
|
||||
}
|
||||
|
||||
// Installs the selector on the registry; forwards to Impl::setSelector.
void ExecutionPlanRegistry::setSelector(ExecutionPlanSelector selector) { impl_->setSelector(selector); }
|
||||
|
||||
// Installs the default selector on the registry; forwards to Impl::setDefaultSelector.
void ExecutionPlanRegistry::setDefaultSelector(ExecutionPlanSelector selector) { impl_->setDefaultSelector(selector); }
|
||||
|
||||
// Packs the call arguments into an ExecutionRequest and delegates to Impl::select.
// Returns whatever Impl::select returns (a plan handle or nullptr).
// Note that the brace-initializer order below differs from the parameter order: `collective`
// is the first parameter but is placed after `messageSize` in the request.
std::shared_ptr<ExecutionPlanHandle> ExecutionPlanRegistry::select(
|
||||
const std::string& collective, int worldSize, int nRanksPerNode, int rank, const void* sendBuffer, void* recvBuffer,
|
||||
size_t messageSize, const std::unordered_map<std::string, std::vector<uint64_t>>& hints) {
|
||||
ExecutionRequest request{worldSize, nRanksPerNode, rank, sendBuffer, recvBuffer, messageSize, collective, hints};
|
||||
return impl_->select(request);
|
||||
}
|
||||
|
||||
// Returns a copy of the handle list registered under `collective`, or an empty vector when
// nothing is registered for it. The find() check avoids creating a new map entry on a miss.
std::vector<std::shared_ptr<ExecutionPlanHandle>> ExecutionPlanRegistry::getPlans(const std::string& collective) {
|
||||
if (impl_->planMap_.find(collective) != impl_->planMap_.end()) {
|
||||
return impl_->planMap_[collective];
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
// Looks up a plan handle by its id. Returns nullptr when the id is not registered.
std::shared_ptr<ExecutionPlanHandle> ExecutionPlanRegistry::get(const std::string& id) {
|
||||
if (impl_->idMap_.find(id) != impl_->idMap_.end()) {
|
||||
return impl_->idMap_[id];
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Constructs the registry with a default-constructed Impl.
ExecutionPlanRegistry::ExecutionPlanRegistry() : impl_(std::make_unique<Impl>()) {}
|
||||
|
||||
// Defaulted out of line. NOTE(review): presumably so that Impl is a complete type where the
// std::unique_ptr<Impl> member is destroyed; confirm against the header.
ExecutionPlanRegistry::~ExecutionPlanRegistry() = default;
|
||||
|
||||
// Resets the registry: removes every registered plan from both lookup maps and
// unsets both the selector and the default selector.
void ExecutionPlanRegistry::clear() {
|
||||
impl_->planMap_.clear();
|
||||
impl_->idMap_.clear();
|
||||
impl_->selector_ = nullptr;
|
||||
impl_->defaultSelector_ = nullptr;
|
||||
}
|
||||
|
||||
// Public entry point for loading the built-in plans for `rank`; forwards to Impl::loadDefaultPlans.
void ExecutionPlanRegistry::loadDefaultPlans(int rank) { impl_->loadDefaultPlans(rank); }
|
||||
|
||||
// Reports whether this request operates in place.
//
// Returns true when the input and output pointers are identical. For the "allgather"
// collective it also returns true when the input pointer equals
// outputBuffer + rank * messageSize (byte offset), i.e. the input already sits at this
// rank's position inside the output buffer. Returns false in every other case.
bool ExecutionRequest::isInPlace() const {
|
||||
if (inputBuffer == outputBuffer) return true;
|
||||
if (collective == "allgather") {
|
||||
size_t rankOffset = rank * messageSize;
|
||||
// Pointer arithmetic is done on char* so that rankOffset is counted in bytes.
const char* expectedInput = static_cast<const char*>(outputBuffer) + rankOffset;
|
||||
return static_cast<const void*>(expectedInput) == inputBuffer;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Factory for ExecutionPlanHandle. Builds a heap-allocated handle from the given id, the
// {worldSize, nRanksPerNode} pair (which initializes the handle's second member), the plan
// and the tags, and returns it wrapped in a shared_ptr.
std::shared_ptr<ExecutionPlanHandle> ExecutionPlanHandle::create(
|
||||
const std::string& id, int worldSize, int nRanksPerNode, std::shared_ptr<ExecutionPlan> plan,
|
||||
const std::unordered_map<std::string, uint64_t>& tags) {
|
||||
std::shared_ptr<ExecutionPlanHandle> handle(new ExecutionPlanHandle{id, {worldSize, nRanksPerNode}, plan, tags});
|
||||
return handle;
|
||||
}
|
||||
|
||||
// Returns true only when this plan satisfies every aspect of `request`:
//   - the constraint's worldSize and nRanksPerNode equal the request's,
//   - the plan's collective name equals the request's,
//   - the plan's in-place flag equals request.isInPlace(),
//   - the effective message size lies within [minMessageSize(), maxMessageSize()] of the plan.
//
// For "allgather" the effective size is messageSize * worldSize; for every other collective
// it is messageSize unchanged.
// NOTE(review): this suggests messageSize is the per-rank contribution for allgather; confirm
// against the callers.
bool ExecutionPlanHandle::match(const ExecutionRequest& request) {
|
||||
bool worldSizeMatch = constraint.worldSize == request.worldSize;
|
||||
bool ranksPerNodeMatch = constraint.nRanksPerNode == request.nRanksPerNode;
|
||||
bool collectiveMatch = plan->collective() == request.collective;
|
||||
bool inPlaceMatch = plan->isInPlace() == request.isInPlace();
|
||||
size_t effectiveSize =
|
||||
(request.collective == "allgather") ? (request.messageSize * request.worldSize) : request.messageSize;
|
||||
bool minSizeMatch = effectiveSize >= plan->minMessageSize();
|
||||
bool maxSizeMatch = effectiveSize <= plan->maxMessageSize();
|
||||
|
||||
bool result = worldSizeMatch && ranksPerNodeMatch && collectiveMatch && inPlaceMatch && minSizeMatch && maxSizeMatch;
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
@@ -114,6 +114,7 @@ struct ExecutionContext {
|
||||
std::shared_ptr<ProxyService> proxyService;
|
||||
std::unordered_map<int, std::shared_ptr<Connection>> connections;
|
||||
std::vector<std::shared_ptr<NvlsConnection>> nvlsConnections;
|
||||
MemoryId localMemoryIdBegin = MemoryId(0);
|
||||
|
||||
// For registered memories, registeredMemoryAddresses is used for memoryChannel and registeredMemoryIds is used for
|
||||
// proxy channel
|
||||
@@ -144,17 +145,24 @@ struct Executor::Impl {
|
||||
int nranksPerNode;
|
||||
int nranks;
|
||||
std::shared_ptr<Communicator> comm;
|
||||
const size_t defaultScratchBufferSize = (1 << 27);
|
||||
std::shared_ptr<char> defaultScratchBuffer;
|
||||
std::shared_ptr<ProxyService> proxyService;
|
||||
std::unordered_map<ExecutionContextKey, ExecutionContext> contexts;
|
||||
|
||||
Impl(std::shared_ptr<Communicator> comm) : comm(comm) {
|
||||
Impl(std::shared_ptr<Communicator> comm, std::shared_ptr<char> defaultScratchBuffer = nullptr)
|
||||
: comm(comm), defaultScratchBuffer(defaultScratchBuffer) {
|
||||
this->nranksPerNode = comm->bootstrap()->getNranksPerNode();
|
||||
this->nranks = comm->bootstrap()->getNranks();
|
||||
this->proxyService = std::make_shared<ProxyService>();
|
||||
this->proxyService->startProxy();
|
||||
}
|
||||
~Impl() = default;
|
||||
|
||||
ExecutionContext setupExecutionContext(int rank, void* sendbuff, void* recvbuff, size_t inputMessageSize,
|
||||
size_t outputMessageSize, size_t constSrcOffset, size_t constDstOffset,
|
||||
size_t sendMemRange, size_t recvMemRange, const ExecutionPlan& plan) {
|
||||
size_t sendMemRange, size_t recvMemRange, const ExecutionPlan& plan,
|
||||
std::shared_ptr<ProxyService> proxyService) {
|
||||
ExecutionContextKey key = {sendbuff, recvbuff, sendMemRange, recvMemRange, plan.impl_->name};
|
||||
DeviceExecutionPlanKey devicePlanKey = {inputMessageSize, outputMessageSize, constSrcOffset, constDstOffset};
|
||||
|
||||
@@ -188,17 +196,14 @@ struct Executor::Impl {
|
||||
ExecutionContext context;
|
||||
context.reuseResources = plan.impl_->reuseResources;
|
||||
context.doubleScratchBuff = plan.impl_->doubleScratchBuffer;
|
||||
size_t scratchBufferSize = plan.impl_->calScratchBufferSize(std::min(sendMemRange, plan.impl_->maxMessageSize),
|
||||
std::min(recvMemRange, plan.impl_->maxMessageSize));
|
||||
context.scratchChunkSize = plan.impl_->calMaxScratchChunkSize(scratchBufferSize);
|
||||
context.scratchBuffer = GpuBuffer(scratchBufferSize).memory();
|
||||
context.scratchBufferSize = scratchBufferSize;
|
||||
context.proxyService = std::make_shared<ProxyService>();
|
||||
context.proxyService = proxyService;
|
||||
context.nthreadsPerBlock = plan.impl_->nThreadsPerBlock;
|
||||
this->setupConnections(context, rank, sendMemRange, recvMemRange, scratchBufferSize, plan);
|
||||
this->setupScratchBuffer(context, sendMemRange, recvMemRange, plan);
|
||||
this->setupConnections(context, rank, sendMemRange, recvMemRange, context.scratchBufferSize, plan);
|
||||
this->setupChannels(context, plan);
|
||||
this->setupRegisteredMemories(context, sendbuff, recvbuff, sendMemRange, recvMemRange, rank, plan);
|
||||
this->setupNvlsChannels(context, sendbuff, recvbuff, rank, sendMemRange, recvMemRange, scratchBufferSize, plan);
|
||||
this->setupNvlsChannels(context, sendbuff, recvbuff, rank, sendMemRange, recvMemRange, context.scratchBufferSize,
|
||||
plan);
|
||||
this->setupSemaphores(context, plan);
|
||||
this->setupDeviceExecutionPlan(context, devicePlanKey, plan);
|
||||
context.deviceExecutionPlansBuffers[devicePlanKey] =
|
||||
@@ -207,7 +212,6 @@ struct Executor::Impl {
|
||||
(char*)context.deviceExecutionPlans[devicePlanKey].data(),
|
||||
context.deviceExecutionPlans[devicePlanKey].size() * sizeof(DeviceExecutionPlan), cudaMemcpyHostToDevice);
|
||||
context.currentDevicePlan = devicePlanKey;
|
||||
context.proxyService->startProxy();
|
||||
this->contexts.insert({key, context});
|
||||
return context;
|
||||
}
|
||||
@@ -227,6 +231,29 @@ struct Executor::Impl {
|
||||
return flags;
|
||||
};
|
||||
|
||||
// Chooses the scratch buffer for `context` and records its size and chunk size.
//
// The required size is computed by the plan from the send/recv sizes, each capped at the
// plan's maxMessageSize. context.scratchChunkSize is derived from that required size.
//
// - When the plan sets reuseResources, the executor-wide defaultScratchBuffer is used. It is
//   allocated on first use if none exists yet. An Error(ErrorCode::ExecutorError) is thrown if
//   the required size exceeds defaultScratchBufferSize. context.scratchBufferSize is then set
//   to defaultScratchBufferSize, not to the required size.
// - Otherwise a dedicated buffer of exactly the required size is allocated for this context.
//
// NOTE(review): if defaultScratchBuffer was supplied from outside rather than allocated here,
// its real capacity is not checked and is assumed to be defaultScratchBufferSize; confirm that
// callers guarantee this.
void setupScratchBuffer(ExecutionContext& context, size_t sendBuffSize, size_t recvBuffSize,
|
||||
const ExecutionPlan& plan) {
|
||||
size_t scratchBufferSize = plan.impl_->calScratchBufferSize(std::min(sendBuffSize, plan.impl_->maxMessageSize),
|
||||
std::min(recvBuffSize, plan.impl_->maxMessageSize));
|
||||
context.scratchChunkSize = plan.impl_->calMaxScratchChunkSize(scratchBufferSize);
|
||||
if (plan.impl_->reuseResources) {
|
||||
// Allocate the shared buffer only on first use.
if (this->defaultScratchBuffer == nullptr) {
|
||||
this->defaultScratchBuffer = GpuBuffer(this->defaultScratchBufferSize).memory();
|
||||
}
|
||||
if (scratchBufferSize > this->defaultScratchBufferSize) {
|
||||
throw Error("Scratch buffer size (" + std::to_string(scratchBufferSize) +
|
||||
" bytes) exceeds default buffer size (" + std::to_string(this->defaultScratchBufferSize) +
|
||||
" bytes). Consider increasing the default scratch buffer size or disabling resource reuse.",
|
||||
ErrorCode::ExecutorError);
|
||||
}
|
||||
context.scratchBufferSize = this->defaultScratchBufferSize;
|
||||
context.scratchBuffer = this->defaultScratchBuffer;
|
||||
} else {
|
||||
context.scratchBufferSize = scratchBufferSize;
|
||||
context.scratchBuffer = GpuBuffer(scratchBufferSize).memory();
|
||||
}
|
||||
}
|
||||
|
||||
void setupConnections(ExecutionContext& context, int rank, size_t sendBuffSize, size_t recvBuffSize,
|
||||
size_t scratchBuffSize, const ExecutionPlan& plan) {
|
||||
auto getBufferSize = [&](BufferType bufferType) {
|
||||
@@ -264,6 +291,7 @@ struct Executor::Impl {
|
||||
void setupRegisteredMemories(ExecutionContext& context, void* sendbuff, void* recvbuff, size_t sendBufferSize,
|
||||
size_t recvBufferSize, int rank, const ExecutionPlan& plan) {
|
||||
// Add local src,dst and scratch to registeredMemoryIds
|
||||
context.localMemoryIdBegin = context.proxyService->nextMemoryId(3);
|
||||
for (auto& bufferType : {BufferType::INPUT, BufferType::OUTPUT, BufferType::SCRATCH}) {
|
||||
TransportFlags flags = Transport::CudaIpc;
|
||||
#if defined(USE_IBVERBS)
|
||||
@@ -440,12 +468,12 @@ struct Executor::Impl {
|
||||
ExecutionKernel::launchKernel<PacketType, true>(
|
||||
rank, nthreadblocks, context.nthreadsPerBlock, sendbuff, recvbuff, scratchBuffer, scratchOffset,
|
||||
context.scratchChunkSize, dataType, (DeviceExecutionPlan*)context.deviceExecutionPlansBuffers[key].get(),
|
||||
(DeviceSemaphore*)context.smemaphores.get(), sharedMemSize, stream, flag);
|
||||
(DeviceSemaphore*)context.smemaphores.get(), context.localMemoryIdBegin, sharedMemSize, stream, flag);
|
||||
} else {
|
||||
ExecutionKernel::launchKernel<PacketType, false>(
|
||||
rank, nthreadblocks, context.nthreadsPerBlock, sendbuff, recvbuff, scratchBuffer, scratchOffset,
|
||||
context.scratchChunkSize, dataType, (DeviceExecutionPlan*)context.deviceExecutionPlansBuffers[key].get(),
|
||||
(DeviceSemaphore*)context.smemaphores.get(), sharedMemSize, stream, flag);
|
||||
(DeviceSemaphore*)context.smemaphores.get(), context.localMemoryIdBegin, sharedMemSize, stream, flag);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -480,7 +508,8 @@ struct Executor::Impl {
|
||||
}
|
||||
};
|
||||
|
||||
Executor::Executor(std::shared_ptr<Communicator> comm) : impl_(std::make_unique<Impl>(comm)) {}
|
||||
Executor::Executor(std::shared_ptr<Communicator> comm, std::shared_ptr<char> defaultScratchBuffer)
|
||||
: impl_(std::make_unique<Impl>(comm, defaultScratchBuffer)) {}
|
||||
|
||||
void Executor::execute(int rank, void* sendbuff, void* recvbuff, size_t sendBuffSize,
|
||||
[[maybe_unused]] size_t recvBuffSize, DataType dataType, const ExecutionPlan& plan,
|
||||
@@ -494,9 +523,9 @@ void Executor::execute(int rank, void* sendbuff, void* recvbuff, size_t sendBuff
|
||||
size_t offsetIn = (char*)sendbuff - (char*)sendBasePtr;
|
||||
size_t offsetOut = (char*)recvbuff - (char*)recvBasePtr;
|
||||
|
||||
ExecutionContext context =
|
||||
this->impl_->setupExecutionContext(rank, (void*)sendBasePtr, (void*)recvBasePtr, sendBuffSize, recvBuffSize,
|
||||
offsetIn, offsetOut, sendMemRange, recvMemRange, plan);
|
||||
ExecutionContext context = this->impl_->setupExecutionContext(
|
||||
rank, (void*)sendBasePtr, (void*)recvBasePtr, sendBuffSize, recvBuffSize, offsetIn, offsetOut, sendMemRange,
|
||||
recvMemRange, plan, this->impl_->proxyService);
|
||||
this->impl_->launchKernel(context, rank, sendbuff, recvbuff, dataType, stream, packetType);
|
||||
}
|
||||
|
||||
|
||||
@@ -68,6 +68,7 @@ IbMr::IbMr(ibv_pd* pd, void* buff, std::size_t size) : buff(buff) {
|
||||
this->mr = IBVerbs::ibv_reg_dmabuf_mr(pd, offsetInDmaBuf, size, (uint64_t)dptr, fd,
|
||||
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ |
|
||||
IBV_ACCESS_RELAXED_ORDERING | IBV_ACCESS_REMOTE_ATOMIC);
|
||||
close(fd);
|
||||
if (this->mr == nullptr) {
|
||||
std::stringstream err;
|
||||
err << "ibv_reg_dmabuf_mr failed (errno " << errno << ")";
|
||||
|
||||
@@ -384,6 +384,7 @@ __shared__ BufferType* portChannelBufferTypes_;
|
||||
__shared__ uint32_t flag_;
|
||||
__shared__ uint32_t scratchChunkSize_;
|
||||
__shared__ uint32_t scratchOffset_;
|
||||
__shared__ MemoryId localMemoryIdBegin_;
|
||||
#if defined(ENABLE_NPKIT)
|
||||
__shared__ NpKitEvent* eventBuffer_;
|
||||
#endif
|
||||
@@ -516,7 +517,7 @@ MSCCLPP_DEVICE_INLINE void handlePut(const Operation& op, void* input, void* out
|
||||
if (tid < count) {
|
||||
uint32_t size = min(outputSizes[tid] - offset, unitSize);
|
||||
MemoryId dstMemoryId = portChannelBufferIds_[op.outputBufferRefs[tid].id];
|
||||
MemoryId srcMemoryId = static_cast<MemoryId>(op.inputBufferRefs[tid].type);
|
||||
MemoryId srcMemoryId = static_cast<MemoryId>(op.inputBufferRefs[tid].type) + localMemoryIdBegin_;
|
||||
uint32_t dstOffset =
|
||||
dstOffsets[tid] + getOffset<ReuseScratch>(portChannelBufferTypes_[op.outputBufferRefs[tid].id], offset);
|
||||
uint32_t srcOffset = srcOffsets[tid] + getOffset<ReuseScratch>(op.inputBufferRefs[tid].type, offset);
|
||||
@@ -674,8 +675,8 @@ MSCCLPP_DEVICE_INLINE void handleReadPutPackets(const Operation& op, void* scrat
|
||||
uint32_t dstOffset = (dstOffsets[chIdx] << 1) + scratchOffset_;
|
||||
uint32_t srcOffset = (srcOffsets[chIdx] << 1) + scratchOffset_;
|
||||
MemoryId dstMemoryId = portChannelBufferIds_[op.outputBufferRefs[chIdx].id];
|
||||
portChannels_[channelIndexes[chIdx]].put(dstMemoryId, dstOffset, static_cast<MemoryId>(BufferType::SCRATCH),
|
||||
srcOffset, size << 1);
|
||||
portChannels_[channelIndexes[chIdx]].put(
|
||||
dstMemoryId, dstOffset, static_cast<MemoryId>(BufferType::SCRATCH) + localMemoryIdBegin_, srcOffset, size << 1);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1027,7 +1028,8 @@ template <typename T, typename PacketType = LL16Packet, bool ReuseScratch = fals
|
||||
__global__ __launch_bounds__(1024, 1) void executionKernel([[maybe_unused]] int rank /*for debug*/, T* input, T* output,
|
||||
T* scratch, uint32_t scratchOffset,
|
||||
uint32_t scratchChunkSize, DeviceExecutionPlan* plan,
|
||||
DeviceSemaphore* semaphores, uint32_t flag
|
||||
DeviceSemaphore* semaphores, uint32_t localMemoryIdBegin,
|
||||
uint32_t flag
|
||||
#if defined(ENABLE_NPKIT)
|
||||
,
|
||||
NpKitEventCollectContext* npKitEventCollectContexts,
|
||||
@@ -1067,6 +1069,7 @@ __global__ __launch_bounds__(1024, 1) void executionKernel([[maybe_unused]] int
|
||||
flag_ = flag;
|
||||
scratchChunkSize_ = scratchChunkSize;
|
||||
scratchOffset_ = scratchOffset;
|
||||
localMemoryIdBegin_ = localMemoryIdBegin;
|
||||
|
||||
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_TIME_SYNC_CPU)
|
||||
#if defined(MSCCLPP_DEVICE_HIP)
|
||||
@@ -1112,13 +1115,13 @@ class ExecutionKernel {
|
||||
template <typename PacketType, bool ReuseScratch>
|
||||
static void launchKernel(int rank, int nthreadblocks, int nthreads, void* src, void* dst, void* scratch,
|
||||
uint32_t scratchOffset, uint32_t scratchChunkSize, DataType dataType,
|
||||
DeviceExecutionPlan* plan, DeviceSemaphore* semaphores, uint32_t sharedMemSize,
|
||||
cudaStream_t stream, uint32_t flag = 0) {
|
||||
DeviceExecutionPlan* plan, DeviceSemaphore* semaphores, uint32_t localMemoryIdBegin,
|
||||
uint32_t sharedMemSize, cudaStream_t stream, uint32_t flag = 0) {
|
||||
switch (dataType) {
|
||||
case DataType::INT32:
|
||||
executionKernel<int32_t, PacketType, ReuseScratch><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
|
||||
rank, (int32_t*)src, (int32_t*)dst, (int32_t*)scratch, scratchOffset, scratchChunkSize, plan, semaphores,
|
||||
flag
|
||||
localMemoryIdBegin, flag
|
||||
#if defined(ENABLE_NPKIT)
|
||||
,
|
||||
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
|
||||
@@ -1129,7 +1132,7 @@ class ExecutionKernel {
|
||||
case DataType::UINT32:
|
||||
executionKernel<uint32_t, PacketType, ReuseScratch><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
|
||||
rank, (uint32_t*)src, (uint32_t*)dst, (uint32_t*)scratch, scratchOffset, scratchChunkSize, plan, semaphores,
|
||||
flag
|
||||
localMemoryIdBegin, flag
|
||||
#if defined(ENABLE_NPKIT)
|
||||
,
|
||||
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
|
||||
@@ -1139,7 +1142,8 @@ class ExecutionKernel {
|
||||
break;
|
||||
case DataType::FLOAT16:
|
||||
executionKernel<half, PacketType, ReuseScratch><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
|
||||
rank, (half*)src, (half*)dst, (half*)scratch, scratchOffset, scratchChunkSize, plan, semaphores, flag
|
||||
rank, (half*)src, (half*)dst, (half*)scratch, scratchOffset, scratchChunkSize, plan, semaphores,
|
||||
localMemoryIdBegin, flag
|
||||
#if defined(ENABLE_NPKIT)
|
||||
,
|
||||
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
|
||||
@@ -1149,7 +1153,8 @@ class ExecutionKernel {
|
||||
break;
|
||||
case DataType::FLOAT32:
|
||||
executionKernel<float, PacketType, ReuseScratch><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
|
||||
rank, (float*)src, (float*)dst, (float*)scratch, scratchOffset, scratchChunkSize, plan, semaphores, flag
|
||||
rank, (float*)src, (float*)dst, (float*)scratch, scratchOffset, scratchChunkSize, plan, semaphores,
|
||||
localMemoryIdBegin, flag
|
||||
#if defined(ENABLE_NPKIT)
|
||||
,
|
||||
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
|
||||
@@ -1160,7 +1165,7 @@ class ExecutionKernel {
|
||||
case DataType::BFLOAT16:
|
||||
executionKernel<__bfloat16, PacketType, ReuseScratch><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
|
||||
rank, (__bfloat16*)src, (__bfloat16*)dst, (__bfloat16*)scratch, scratchOffset, scratchChunkSize, plan,
|
||||
semaphores, flag
|
||||
semaphores, localMemoryIdBegin, flag
|
||||
#if defined(ENABLE_NPKIT)
|
||||
,
|
||||
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
|
||||
@@ -1172,7 +1177,7 @@ class ExecutionKernel {
|
||||
case DataType::FP8_E4M3:
|
||||
executionKernel<__fp8_e4m3, PacketType, ReuseScratch><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
|
||||
rank, (__fp8_e4m3*)src, (__fp8_e4m3*)dst, (__fp8_e4m3*)scratch, scratchOffset, scratchChunkSize, plan,
|
||||
semaphores, flag
|
||||
semaphores, localMemoryIdBegin, flag
|
||||
#if defined(ENABLE_NPKIT)
|
||||
,
|
||||
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
|
||||
@@ -1183,7 +1188,7 @@ class ExecutionKernel {
|
||||
case DataType::FP8_E5M2:
|
||||
executionKernel<__fp8_e5m2, PacketType, ReuseScratch><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
|
||||
rank, (__fp8_e5m2*)src, (__fp8_e5m2*)dst, (__fp8_e5m2*)scratch, scratchOffset, scratchChunkSize, plan,
|
||||
semaphores, flag
|
||||
semaphores, localMemoryIdBegin, flag
|
||||
#if defined(ENABLE_NPKIT)
|
||||
,
|
||||
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
|
||||
@@ -1198,8 +1203,8 @@ class ExecutionKernel {
|
||||
template <typename PacketType, bool ReuseScratch>
|
||||
static void launchKernel(int rank, int nthreadblocks, int nthreads, void* src, void* dst, void* scratch,
|
||||
uint32_t scratchOffset, uint32_t scratchChunkSize, DataType dataType,
|
||||
DeviceExecutionPlan* plan, DeviceSemaphore* semaphores, uint32_t sharedMemSize,
|
||||
cudaStream_t stream, uint32_t flag = 0);
|
||||
DeviceExecutionPlan* plan, DeviceSemaphore* semaphores, uint32_t localMemoryIdBegin,
|
||||
uint32_t sharedMemSize, cudaStream_t stream, uint32_t flag = 0);
|
||||
#endif // !defined(MSCCLPP_DEVICE_HIP)
|
||||
};
|
||||
} // namespace mscclpp
|
||||
|
||||
@@ -66,6 +66,29 @@ struct SemaphoreInfo {
|
||||
int initValue;
|
||||
};
|
||||
|
||||
// Plain aggregate describing one algorithm configuration entry.
// NOTE(review): presumably filled in while loading the default execution plans -- confirm
// against the code that constructs it.
//
// The integer fields carry default member initializers so that a default-constructed
// AlgoConfig never exposes indeterminate values; the type remains an aggregate, so
// brace-initialization in field order keeps working.
struct AlgoConfig {
  // Name of the plan file (presumably resolved relative to a plan directory -- confirm).
  std::string filename;
  // Name of the collective this entry applies to.
  std::string collective;
  // Number of ranks per node; 0 until explicitly set.
  int nRanksPerNode = 0;
  // Total number of ranks; 0 until explicitly set.
  int worldSize = 0;
  // String-keyed integer tags attached to this entry.
  std::unordered_map<std::string, uint64_t> tags;
};
|
||||
|
||||
// Private implementation state of ExecutionPlanRegistry: two selector callbacks plus two
// string-keyed indexes over the registered plan handles.
struct ExecutionPlanRegistry::Impl {
  // --- selector configuration ---
  void setSelector(ExecutionPlanSelector selector);
  void setDefaultSelector(ExecutionPlanSelector selector);

  // --- plan registration and lookup ---
  void registerPlan(const std::shared_ptr<ExecutionPlanHandle> planHandle);
  std::shared_ptr<ExecutionPlanHandle> select(const ExecutionRequest& request);
  std::vector<ExecutionPlanHandle> getPlans(const std::string& collective);
  std::shared_ptr<ExecutionPlanHandle> get(const std::string& id);
  void loadDefaultPlans(int rank);

  // Both selectors start out empty (null).
  ExecutionPlanSelector selector_{nullptr};
  ExecutionPlanSelector defaultSelector_{nullptr};

  // Plan handles grouped under a string key.
  // NOTE(review): presumably keyed by collective name, matching getPlans() -- confirm.
  std::unordered_map<std::string, std::vector<std::shared_ptr<ExecutionPlanHandle>>> planMap_;
  // One plan handle per string key.
  // NOTE(review): presumably keyed by plan id, matching get() -- confirm.
  std::unordered_map<std::string, std::shared_ptr<ExecutionPlanHandle>> idMap_;
};
|
||||
|
||||
struct ExecutionPlan::Impl {
|
||||
public:
|
||||
Impl(const std::string& planPath, int rank);
|
||||
|
||||
@@ -62,6 +62,14 @@ MSCCLPP_API_CPP MemoryId ProxyService::addMemory(RegisteredMemory memory) {
|
||||
return memories_.size() - 1;
|
||||
}
|
||||
|
||||
// Returns the current number of entries in memories_ as a MemoryId.
// Throws Error with ErrorCode::InvalidUsage when `count` is zero; apart from that check,
// `count` does not influence the returned value.
MSCCLPP_API_CPP MemoryId ProxyService::nextMemoryId([[maybe_unused]] uint32_t count) const {
  if (count != 0) {
    // addMemory() returns memories_.size() - 1, so the current size is the id that follows
    // the most recently returned one.
    return memories_.size();
  }
  throw Error("count must be greater than 0", ErrorCode::InvalidUsage);
}
|
||||
|
||||
// Looks up the semaphore stored under `id` in semaphores_ and returns a shared owner of it.
// No range check is performed here; `id` is passed straight to operator[].
MSCCLPP_API_CPP std::shared_ptr<Host2DeviceSemaphore> ProxyService::semaphore(SemaphoreId id) const {
  const auto& entry = semaphores_[id];
  return entry;
}
|
||||
|
||||
Reference in New Issue
Block a user