mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 09:17:06 +00:00
## Summary - Fix ROCm FP8 build failure by using the actual FP8 `DataType` enum constants in allreduce packet tuning. - Fix FP8 E4M3FNUZ test encoding so small negative values do not produce the FNUZ NaN byte (`0x80`). - Align FP8 `DataType` enum constants and Python bindings with torch-style names (`FLOAT8_E4M3FN`, `FLOAT8_E4M3FNUZ`, `FLOAT8_E5M2FNUZ` / `float8_e4m3fn`, `float8_e4m3fnuz`, `float8_e5m2fnuz`). ## Validation - `./tools/lint.sh` - `make -j` from `build/` - `mpirun --allow-run-as-root -np 8 python3 -m pytest python/test/test_fp8_accum.py -q` (`36 passed, 9 skipped`) - `DTYPE=float8_e4m3fnuz ACCUM_DTYPE=float32 torchrun --nnodes=1 --nproc_per_node=8 examples/torch-integration/customized_comm_with_tuning.py` --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
263 lines
10 KiB
C++
263 lines
10 KiB
C++
// Copyright (c) Microsoft Corporation.
|
|
// Licensed under the MIT License.
|
|
|
|
#include <filesystem>
|
|
#include <mscclpp/algorithm.hpp>
|
|
#include <mscclpp/errors.hpp>
|
|
#include <mscclpp/gpu_utils.hpp>
|
|
|
|
#include "logger.hpp"
|
|
|
|
namespace mscclpp {
|
|
|
|
CollectiveBufferMode CollectiveRequest::bufferMode() const {
  // Identical input/output pointers are always in-place.
  if (inputBuffer == outputBuffer) return CollectiveBufferMode::InPlace;
  // For allgather, "in-place" also covers an input that aliases this rank's
  // slot inside the output buffer (output + rank * messageSize).
  if (collective == "allgather") {
    const char* outputBase = static_cast<const char*>(outputBuffer);
    const void* rankSlot = static_cast<const void*>(outputBase + rank * messageSize);
    return (rankSlot == inputBuffer) ? CollectiveBufferMode::InPlace : CollectiveBufferMode::OutOfPlace;
  }
  return CollectiveBufferMode::OutOfPlace;
}
|
|
|
|
/// Constructs a native (hand-written kernel) algorithm descriptor.
/// All callable/ string/ map parameters are taken by value as sinks and moved
/// into the members to avoid redundant copies (the previous version copied them).
/// @param name            Algorithm name reported by name().
/// @param collective      Collective this algorithm implements (e.g. "allgather").
/// @param initFunc        One-time initialization, run lazily on first execute().
/// @param kernelFunc      Kernel launch callback invoked by execute().
/// @param contextInitFunc Builds a per-context state object on a cache miss.
/// @param contextKeyGenFunc Derives the context-cache key from buffers/sizes/dtype.
/// @param minMessageSize / maxMessageSize  Supported message-size range in bytes.
/// @param bufferMode      Whether the algorithm requires in-place / out-of-place buffers.
/// @param tags            Free-form tag map used by selectors.
/// @param constraint      Extra applicability predicate.
NativeAlgorithm::NativeAlgorithm(std::string name, std::string collective, InitFunc initFunc, KernelFunc kernelFunc,
                                 ContextInitFunc contextInitFunc, ContextKeyGenFunc contextKeyGenFunc,
                                 size_t minMessageSize, size_t maxMessageSize, CollectiveBufferMode bufferMode,
                                 std::unordered_map<std::string, uint64_t> tags, Constraint constraint)
    : name_(std::move(name)),
      collective_(std::move(collective)),
      initFunc_(std::move(initFunc)),
      kernelLaunchFunc_(std::move(kernelFunc)),
      contextInitFunc_(std::move(contextInitFunc)),
      contextKeyGenFunc_(std::move(contextKeyGenFunc)),
      minMessageSize_(minMessageSize),
      maxMessageSize_(maxMessageSize),
      bufferMode_(bufferMode),
      tags_(std::move(tags)),
      constraint_(std::move(constraint)) {}
|
|
|
|
/// Runs the native algorithm: lazily initializes it, resolves (or builds) the
/// cached per-buffer context, and launches the kernel.
/// Fix: the original performed three hash lookups of ctxKey (find, then
/// operator[] on insert, then operator[] again for the launch); this version
/// does a single find and reuses the iterator.
/// @return Result of the kernel launch callback.
CommResult NativeAlgorithm::execute(std::shared_ptr<Communicator> comm, const void* input, void* output,
                                    size_t inputSize, size_t outputSize, DataType dtype, ReduceOp op,
                                    cudaStream_t stream, std::shared_ptr<Executor>, int nBlocks, int nThreadsPerBlock,
                                    bool symmetricMemory, const std::unordered_map<std::string, uintptr_t>& extras,
                                    DataType accumDtype) {
  // AUTO means "accumulate in the same precision as the data".
  if (accumDtype == DataType::AUTO) accumDtype = dtype;
  // One-time setup, deferred until first use.
  if (!initialized_) {
    initFunc_(comm);
    initialized_ = true;
  }
  // Contexts are cached per (buffers, sizes, dtype, symmetric-memory) key.
  AlgorithmCtxKey ctxKey = contextKeyGenFunc_(input, output, inputSize, outputSize, dtype, symmetricMemory);
  auto it = contexts_.find(ctxKey);
  if (it == contexts_.end()) {
    it = contexts_.emplace(ctxKey, contextInitFunc_(comm, input, output, inputSize, outputSize, dtype)).first;
  }
  return kernelLaunchFunc_(it->second, input, output, inputSize, outputSize, dtype, op, stream, nBlocks,
                           nThreadsPerBlock, extras, accumDtype);
}
|
|
|
|
// Algorithm identifier supplied at construction.
const std::string& NativeAlgorithm::name() const { return name_; }
|
|
|
|
// Name of the collective this algorithm implements (e.g. "allgather").
const std::string& NativeAlgorithm::collective() const { return collective_; }
|
|
|
|
const std::pair<size_t, size_t>& NativeAlgorithm::messageRange() const {
|
|
static std::pair<size_t, size_t> range;
|
|
range = {minMessageSize_, maxMessageSize_};
|
|
return range;
|
|
}
|
|
|
|
// Overrides the supported message-size range (bytes) set at construction.
void NativeAlgorithm::setMessageSizeRange(size_t minMessageSize, size_t maxMessageSize) {
  minMessageSize_ = minMessageSize, maxMessageSize_ = maxMessageSize;
}
|
|
|
|
// Free-form tag map consulted by algorithm selectors.
const std::unordered_map<std::string, uint64_t>& NativeAlgorithm::tags() const { return tags_; }
|
|
|
|
// Buffer placement requirement (in-place / out-of-place) set at construction.
const CollectiveBufferMode& NativeAlgorithm::bufferMode() const { return bufferMode_; }
|
|
|
|
// Extra applicability predicate evaluated by selectors (may be empty).
Algorithm::Constraint NativeAlgorithm::constraint() const { return constraint_; }
|
|
|
|
// Drops all cached execution contexts; they will be rebuilt on the next execute().
void NativeAlgorithm::reset() { contexts_.clear(); }
|
|
|
|
void AlgorithmCollection::registerAlgorithm(const std::string collective, const std::string algoName,
|
|
std::shared_ptr<Algorithm> algorithm) {
|
|
this->algoMapByCollective_[collective][algoName] = algorithm;
|
|
}
|
|
|
|
std::shared_ptr<Algorithm> AlgorithmCollection::selectAlgorithm(const CollectiveRequest& request) {
|
|
std::shared_ptr<Algorithm> algo;
|
|
if (!algoSelector_ && !fallbackAlgoSelector_) {
|
|
THROW(ALGO, Error, ErrorCode::InvalidUsage, "No algorithm selector is set in AlgorithmCollection.");
|
|
}
|
|
if (algoSelector_) {
|
|
algo = algoSelector_(algoMapByCollective_, request);
|
|
}
|
|
if (!algo) {
|
|
algo = fallbackAlgoSelector_(algoMapByCollective_, request);
|
|
}
|
|
return algo;
|
|
}
|
|
|
|
void AlgorithmCollection::extend(const AlgorithmCollection& other) {
|
|
for (const auto& [collective, algoMap] : other.algoMapByCollective_) {
|
|
for (const auto& [algoName, algorithm] : algoMap) {
|
|
this->registerAlgorithm(collective, algoName, algorithm);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Installs the primary selector and the fallback consulted when the primary
// yields no algorithm. Either may be empty.
void AlgorithmCollection::setSelectors(AlgoSelectFunc algoSelector, AlgoSelectFunc fallbackAlgoSelector) {
  algoSelector_ = std::move(algoSelector);
  fallbackAlgoSelector_ = std::move(fallbackAlgoSelector);
}
|
|
|
|
std::vector<std::shared_ptr<Algorithm>> AlgorithmCollection::getAllAlgorithms() const {
|
|
std::vector<std::shared_ptr<Algorithm>> allAlgos;
|
|
for (const auto& [collective, algoMap] : algoMapByCollective_) {
|
|
for (const auto& [algoName, algorithm] : algoMap) {
|
|
allAlgos.push_back(algorithm);
|
|
}
|
|
}
|
|
return allAlgos;
|
|
}
|
|
|
|
std::unordered_map<std::string, std::shared_ptr<Algorithm>> AlgorithmCollection::getAlgorithmsByCollective(
|
|
const std::string& collective) const {
|
|
auto it = algoMapByCollective_.find(collective);
|
|
if (it != algoMapByCollective_.end()) {
|
|
return it->second;
|
|
} else {
|
|
return {};
|
|
}
|
|
}
|
|
|
|
/// Constructs a DSL-based algorithm backed by an ExecutionPlan.
/// Sink parameters are moved into the members (the previous version copied
/// them). Member-init order mirrors the original list.
/// @param id         Plan identifier stored in id_.
/// @param plan       Execution plan that supplies name/collective/message range.
/// @param tags       Free-form tag map used by selectors.
/// @param constraint Extra applicability predicate.
DslAlgorithm::DslAlgorithm(std::string id, ExecutionPlan plan, std::unordered_map<std::string, uint64_t> tags,
                           Constraint constraint)
    : plan_(std::move(plan)), id_(std::move(id)), tags_(std::move(tags)), constraint_(std::move(constraint)) {}
|
|
|
|
// Name comes from the underlying execution plan.
const std::string& DslAlgorithm::name() const { return plan_.name(); }
|
|
|
|
// Collective name comes from the underlying execution plan.
const std::string& DslAlgorithm::collective() const { return plan_.collective(); }
|
|
|
|
const std::pair<size_t, size_t>& DslAlgorithm::messageRange() const {
|
|
static std::pair<size_t, size_t> range;
|
|
range = {plan_.minMessageSize(), plan_.maxMessageSize()};
|
|
return range;
|
|
}
|
|
|
|
// DSL algorithms take their message range from the execution plan, so runtime
// overrides are rejected.
void DslAlgorithm::setMessageSizeRange(size_t, size_t) {
  THROW(EXEC, Error, ErrorCode::InvalidUsage, "setMessageSizeRange is only supported for native algorithms");
}
|
|
|
|
// Free-form tag map consulted by algorithm selectors.
const std::unordered_map<std::string, uint64_t>& DslAlgorithm::tags() const { return tags_; }
|
|
|
|
/// Returns the plan's buffer mode (in-place vs out-of-place).
/// Fix for the original "TODO: need to fix": the function-local static was
/// initialized exactly once, so after the first DslAlgorithm instance called
/// this, every other instance returned the FIRST plan's mode regardless of its
/// own plan. Recompute on every call; a thread_local scratch slot satisfies the
/// by-reference interface without cross-thread races. The reference is valid
/// until the next call on the same thread.
const CollectiveBufferMode& DslAlgorithm::bufferMode() const {
  static thread_local CollectiveBufferMode mode;
  mode = plan_.isInPlace() ? CollectiveBufferMode::InPlace : CollectiveBufferMode::OutOfPlace;
  return mode;
}
|
|
|
|
// Extra applicability predicate evaluated by selectors (may be empty).
Algorithm::Constraint DslAlgorithm::constraint() const { return constraint_; }
|
|
|
|
// Executes the DSL plan through the Executor, dispatching on the element dtype.
// FP8 handling depends on platform feature macros: __FP8_TYPES_EXIST__ gates
// all FP8 cases, and __FP8_E4M3_IS_FNUZ__ / __FP8_E5M2_IS_FNUZ__ select which
// of the OCP (FN/E5M2) vs FNUZ encodings the native __fp8_* types carry, so
// the unsupported variant throws instead of silently reinterpreting bytes.
// Ignores ReduceOp, block/thread counts, extras, and accumDtype — the plan
// encodes those choices. Returns CommInvalidArgument for unsupported dtypes.
CommResult DslAlgorithm::execute(std::shared_ptr<Communicator> comm, const void* input, void* output, size_t inputSize,
                                 size_t outputSize, DataType dtype, ReduceOp, cudaStream_t stream,
                                 std::shared_ptr<Executor> executor, int, int, bool,
                                 const std::unordered_map<std::string, uintptr_t>&, DataType) {
  if (!executor) {
    THROW(EXEC, Error, ErrorCode::InvalidUsage, "Executor is null in DslAlgorithm::execute");
  }
  int rank = comm->bootstrap()->getRank();
  // Each case forwards to the typed Executor::execute overload; the pointer
  // cast only selects the element type, sizes stay in bytes as passed in.
  switch (dtype) {
    case DataType::FLOAT16:
      executor->execute(rank, (half*)input, (half*)output, inputSize, outputSize, DataType::FLOAT16, plan_, stream);
      break;
    case DataType::FLOAT32:
      executor->execute(rank, (float*)input, (float*)output, inputSize, outputSize, DataType::FLOAT32, plan_, stream);
      break;
    case DataType::BFLOAT16:
      executor->execute(rank, (__bfloat16*)input, (__bfloat16*)output, inputSize, outputSize, DataType::BFLOAT16, plan_,
                        stream);
      break;
#if defined(__FP8_TYPES_EXIST__)
    case DataType::FLOAT8_E4M3FN:
      // On FNUZ platforms (ROCm) the native __fp8_e4m3 type is the FNUZ
      // encoding, so the OCP FN variant cannot be represented natively.
#if defined(__FP8_E4M3_IS_FNUZ__)
      THROW(EXEC, Error, ErrorCode::InvalidUsage,
            "FLOAT8_E4M3FN is not natively supported on this platform; use FLOAT8_E4M3FNUZ");
#else
      executor->execute(rank, (__fp8_e4m3*)input, (__fp8_e4m3*)output, inputSize, outputSize, DataType::FLOAT8_E4M3FN,
                        plan_, stream);
#endif
      break;
    case DataType::FLOAT8_E4M3FNUZ:
#if defined(__FP8_E4M3_IS_FNUZ__)
      executor->execute(rank, (__fp8_e4m3*)input, (__fp8_e4m3*)output, inputSize, outputSize, DataType::FLOAT8_E4M3FNUZ,
                        plan_, stream);
#else
      THROW(EXEC, Error, ErrorCode::InvalidUsage,
            "FLOAT8_E4M3FNUZ is not natively supported on this platform; use FLOAT8_E4M3FN");
#endif
      break;
    case DataType::FLOAT8_E5M2:
#if defined(__FP8_E5M2_IS_FNUZ__)
      THROW(EXEC, Error, ErrorCode::InvalidUsage,
            "FLOAT8_E5M2 is not natively supported on this platform; use FLOAT8_E5M2FNUZ");
#else
      executor->execute(rank, (__fp8_e5m2*)input, (__fp8_e5m2*)output, inputSize, outputSize, DataType::FLOAT8_E5M2,
                        plan_, stream);
#endif
      break;
    case DataType::FLOAT8_E5M2FNUZ:
#if defined(__FP8_E5M2_IS_FNUZ__)
      executor->execute(rank, (__fp8_e5m2*)input, (__fp8_e5m2*)output, inputSize, outputSize, DataType::FLOAT8_E5M2FNUZ,
                        plan_, stream);
#else
      THROW(EXEC, Error, ErrorCode::InvalidUsage,
            "FLOAT8_E5M2FNUZ is not natively supported on this platform; use FLOAT8_E5M2");
#endif
      break;
#endif
    // NOTE(review): this case sits outside the __FP8_TYPES_EXIST__ guard, so
    // __fp8_e4m3b15 is presumably declared unconditionally — confirm.
    case DataType::FLOAT8_E4M3B15:
      executor->execute(rank, (__fp8_e4m3b15*)input, (__fp8_e4m3b15*)output, inputSize, outputSize,
                        DataType::FLOAT8_E4M3B15, plan_, stream);
      break;
    case DataType::INT32:
    case DataType::UINT32:
      // NOTE(review): both INT32 and UINT32 are forwarded as UINT32 — looks
      // intentional (same 32-bit layout for data movement) but verify the plan
      // never reduces signed values through this path.
      executor->execute(rank, (int*)input, (int*)output, inputSize, outputSize, DataType::UINT32, plan_, stream);
      break;
    default:
      WARN(ALGO, "Unsupported data type: ", static_cast<int>(dtype), " in DslAlgorithm");
      return CommResult::CommInvalidArgument;
  }
  return CommResult::CommSuccess;
}
|
|
|
|
// Returns a shared_ptr aliasing this object (requires it to already be
// owned by a shared_ptr, per enable_shared_from_this semantics).
std::shared_ptr<Algorithm> DslAlgorithm::build() { return shared_from_this(); }
|
|
|
|
// TODO: implement this
// Intentionally a no-op for now: DSL algorithms hold no per-execution context
// cache here, unlike NativeAlgorithm::reset.
void DslAlgorithm::reset() {}
|
|
|
|
// Process-wide default flag buffer on the GPU. The raw pointer caches the
// allocation; the weak_ptr tracks whether any caller still holds the shared
// handle so the same storage can be re-handed out.
static uint32_t* gDefaultFlagBuffer = nullptr;
static std::weak_ptr<void> gDefaultFlagBufferWeak;
static size_t gDefaultFlagCount = 128;

// Returns a shared handle to the default flag buffer and its size in bytes.
// The shared_ptr has a no-op deleter: the GPU allocation is deliberately never
// freed (the CUDA driver reclaims it at process exit), and the weak_ptr only
// tracks outstanding handles.
// NOTE(review): no locking here — presumably callers serialize first use;
// concurrent first calls could double-initialize. Confirm against callers.
std::pair<std::shared_ptr<void>, size_t> getFlagBuffer() {
  auto ptr = gDefaultFlagBufferWeak.lock();
  if (!ptr) {
    if (!gDefaultFlagBuffer) {
      // Intentionally never freed — CUDA driver reclaims GPU memory at process exit.
      gDefaultFlagBuffer = static_cast<uint32_t*>(mscclpp::detail::gpuCalloc(gDefaultFlagCount * sizeof(uint32_t)));
      // Flags start at 1 on the device.
      std::vector<uint32_t> initFlags(gDefaultFlagCount, 1);
      // NOTE(review): third argument is gDefaultFlagCount (element count), not
      // a byte count — assumes mscclpp::gpuMemcpy is the templated
      // element-count overload; verify, else only a quarter of the buffer is
      // initialized.
      mscclpp::gpuMemcpy(gDefaultFlagBuffer, initFlags.data(), gDefaultFlagCount, cudaMemcpyHostToDevice);
    }
    ptr = std::shared_ptr<void>(gDefaultFlagBuffer, [](void*) {});
    gDefaultFlagBufferWeak = ptr;
  }
  return {ptr, gDefaultFlagCount * sizeof(uint32_t)};
}
|
|
|
|
} // namespace mscclpp
|