Torch integration (#692)

Reorganize current native algorithm implementation and DSL algorithm
implementation.
Provide unified API for DSL algo and native algo and provide interface
to tune the algo
Provide interface for pytorch integration with native API and DSL

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
Co-authored-by: chhwang <8018170+chhwang@users.noreply.github.com>
This commit is contained in:
Binyang Li
2026-01-21 20:32:24 -08:00
committed by GitHub
parent 78ce9fac8d
commit a707273701
156 changed files with 6107 additions and 4076 deletions

201
src/core/algorithm.cc Normal file
View File

@@ -0,0 +1,201 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#include <filesystem>
#include <mscclpp/algorithm.hpp>
#include "logger.hpp"
namespace mscclpp {
CollectiveBufferMode CollectiveRequest::bufferMode() const {
if (inputBuffer == outputBuffer) return CollectiveBufferMode::InPlace;
if (collective == "allgather") {
size_t rankOffset = rank * messageSize;
const char* expectedInput = static_cast<const char*>(outputBuffer) + rankOffset;
if (static_cast<const void*>(expectedInput) == inputBuffer) {
return CollectiveBufferMode::InPlace;
}
return CollectiveBufferMode::OutOfPlace;
}
return CollectiveBufferMode::OutOfPlace;
}
NativeAlgorithm::NativeAlgorithm(std::string name, std::string collective, InitFunc initFunc, KernelFunc kernelFunc,
ContextInitFunc contextInitFunc, ContextKeyGenFunc contextKeyGenFunc,
size_t minMessageSize, size_t maxMessageSize, CollectiveBufferMode bufferMode,
std::unordered_map<std::string, uint64_t> tags, Constraint constraint)
: name_(name),
collective_(collective),
initFunc_(initFunc),
kernelLaunchFunc_(kernelFunc),
contextInitFunc_(contextInitFunc),
contextKeyGenFunc_(contextKeyGenFunc),
minMessageSize_(minMessageSize),
maxMessageSize_(maxMessageSize),
bufferMode_(bufferMode),
tags_(tags),
constraint_(constraint) {}
CommResult NativeAlgorithm::execute(std::shared_ptr<Communicator> comm, const void* input, void* output,
size_t inputSize, size_t outputSize, DataType dtype, ReduceOp op,
cudaStream_t stream, std::shared_ptr<Executor>, int nBlocks, int nThreadsPerBlock,
const std::unordered_map<std::string, uintptr_t>& extras) {
if (!initialized_) {
initFunc_(comm);
initialized_ = true;
}
AlgorithmCtxKey ctxKey = contextKeyGenFunc_(input, output, inputSize, outputSize, dtype);
auto it = contexts_.find(ctxKey);
if (it == contexts_.end()) {
auto ctx = contextInitFunc_(comm, input, output, inputSize, outputSize, dtype);
contexts_[ctxKey] = ctx;
}
return kernelLaunchFunc_(contexts_[ctxKey], input, output, inputSize, outputSize, dtype, op, stream, nBlocks,
nThreadsPerBlock, extras);
}
const std::string& NativeAlgorithm::name() const { return name_; }
const std::string& NativeAlgorithm::collective() const { return collective_; }
const std::pair<size_t, size_t>& NativeAlgorithm::messageRange() const {
static std::pair<size_t, size_t> range;
range = {minMessageSize_, maxMessageSize_};
return range;
}
const std::unordered_map<std::string, uint64_t>& NativeAlgorithm::tags() const { return tags_; }
const CollectiveBufferMode& NativeAlgorithm::bufferMode() const { return bufferMode_; }
Algorithm::Constraint NativeAlgorithm::constraint() const { return constraint_; }
void NativeAlgorithm::reset() {
contexts_.clear();
initialized_ = false;
}
void AlgorithmCollection::registerAlgorithm(const std::string collective, const std::string algoName,
std::shared_ptr<Algorithm> algorithm) {
this->algoMapByCollective_[collective][algoName] = algorithm;
}
std::shared_ptr<Algorithm> AlgorithmCollection::selectAlgorithm(const CollectiveRequest& request) {
std::shared_ptr<Algorithm> algo;
if (!algoSelector_ && !fallbackAlgoSelector_) {
THROW(ALGO, Error, ErrorCode::InvalidUsage, "No algorithm selector is set in AlgorithmCollection.");
}
if (algoSelector_) {
algo = algoSelector_(algoMapByCollective_, request);
}
if (!algo) {
algo = fallbackAlgoSelector_(algoMapByCollective_, request);
}
return algo;
}
void AlgorithmCollection::extend(const AlgorithmCollection& other) {
for (const auto& [collective, algoMap] : other.algoMapByCollective_) {
for (const auto& [algoName, algorithm] : algoMap) {
this->registerAlgorithm(collective, algoName, algorithm);
}
}
}
void AlgorithmCollection::setSelectors(AlgoSelectFunc algoSelector, AlgoSelectFunc fallbackAlgoSelector) {
algoSelector_ = algoSelector;
fallbackAlgoSelector_ = fallbackAlgoSelector;
}
std::vector<std::shared_ptr<Algorithm>> AlgorithmCollection::getAllAlgorithms() const {
std::vector<std::shared_ptr<Algorithm>> allAlgos;
for (const auto& [collective, algoMap] : algoMapByCollective_) {
for (const auto& [algoName, algorithm] : algoMap) {
allAlgos.push_back(algorithm);
}
}
return allAlgos;
}
std::unordered_map<std::string, std::shared_ptr<Algorithm>> AlgorithmCollection::getAlgorithmsByCollective(
const std::string& collective) const {
auto it = algoMapByCollective_.find(collective);
if (it != algoMapByCollective_.end()) {
return it->second;
} else {
return {};
}
}
DslAlgorithm::DslAlgorithm(std::string id, ExecutionPlan plan, std::unordered_map<std::string, uint64_t> tags,
Constraint constraint)
: plan_(plan), id_(id), tags_(tags), constraint_(constraint) {}
const std::string& DslAlgorithm::name() const { return plan_.name(); }
const std::string& DslAlgorithm::collective() const { return plan_.collective(); }
const std::pair<size_t, size_t>& DslAlgorithm::messageRange() const {
static std::pair<size_t, size_t> range;
range = {plan_.minMessageSize(), plan_.maxMessageSize()};
return range;
}
const std::unordered_map<std::string, uint64_t>& DslAlgorithm::tags() const { return tags_; }
const CollectiveBufferMode& DslAlgorithm::bufferMode() const {
// TODO: need to fix
static CollectiveBufferMode mode =
plan_.isInPlace() ? CollectiveBufferMode::InPlace : CollectiveBufferMode::OutOfPlace;
return mode;
}
Algorithm::Constraint DslAlgorithm::constraint() const { return constraint_; }
CommResult DslAlgorithm::execute(std::shared_ptr<Communicator> comm, const void* input, void* output, size_t inputSize,
size_t outputSize, DataType dtype, ReduceOp, cudaStream_t stream,
std::shared_ptr<Executor> executor, int, int,
const std::unordered_map<std::string, uintptr_t>&) {
if (!executor) {
THROW(EXEC, Error, ErrorCode::InvalidUsage, "Executor is null in DslAlgorithm::execute");
}
int rank = comm->bootstrap()->getRank();
switch (dtype) {
case DataType::FLOAT16:
executor->execute(rank, (half*)input, (half*)output, inputSize, outputSize, DataType::FLOAT16, plan_, stream);
break;
case DataType::FLOAT32:
executor->execute(rank, (float*)input, (float*)output, inputSize, outputSize, DataType::FLOAT32, plan_, stream);
break;
case DataType::BFLOAT16:
executor->execute(rank, (__bfloat16*)input, (__bfloat16*)output, inputSize, outputSize, DataType::BFLOAT16, plan_,
stream);
break;
#if defined(__FP8_TYPES_EXIST__)
case DataType::FP8_E4M3:
executor->execute(rank, (__fp8_e4m3*)input, (__fp8_e4m3*)output, inputSize, outputSize, DataType::FP8_E4M3, plan_,
stream);
break;
case DataType::FP8_E5M2:
executor->execute(rank, (__fp8_e5m2*)input, (__fp8_e5m2*)output, inputSize, outputSize, DataType::FP8_E5M2, plan_,
stream);
break;
#endif
case DataType::INT32:
case DataType::UINT32:
executor->execute(rank, (int*)input, (int*)output, inputSize, outputSize, DataType::UINT32, plan_, stream);
break;
default:
WARN(ALGO, "Unsupported data type: ", static_cast<int>(dtype), " in DslAlgorithm");
return CommResult::CommInvalidArgument;
}
return CommResult::CommSuccess;
}
std::shared_ptr<Algorithm> DslAlgorithm::build() { return shared_from_this(); }
// TODO: implement this
void DslAlgorithm::reset() {}
} // namespace mscclpp