Files
mscclpp/src/core/algorithm.cc
Binyang Li 2c52937b26 Fix FP8 ROCm build/test issues and dtype naming (#792)
## Summary
- Fix ROCm FP8 build failure by using the actual FP8 `DataType` enum
constants in allreduce packet tuning.
- Fix FP8 E4M3FNUZ test encoding so small negative values do not produce
the FNUZ NaN byte (`0x80`).
- Align FP8 `DataType` enum constants and Python bindings with
torch-style names (`FLOAT8_E4M3FN`, `FLOAT8_E4M3FNUZ`, `FLOAT8_E5M2FNUZ`
/ `float8_e4m3fn`, `float8_e4m3fnuz`, `float8_e5m2fnuz`).

## Validation
- `./tools/lint.sh`
- `make -j` from `build/`
- `mpirun --allow-run-as-root -np 8 python3 -m pytest
python/test/test_fp8_accum.py -q` (`36 passed, 9 skipped`)
- `DTYPE=float8_e4m3fnuz ACCUM_DTYPE=float32 torchrun --nnodes=1
--nproc_per_node=8
examples/torch-integration/customized_comm_with_tuning.py`

---------

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-04-28 15:02:22 -07:00

263 lines
10 KiB
C++

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#include <filesystem>
#include <mscclpp/algorithm.hpp>
#include <mscclpp/errors.hpp>
#include <mscclpp/gpu_utils.hpp>
#include <utility>

#include "logger.hpp"
namespace mscclpp {
/// Classifies this request as in-place or out-of-place from its buffer addresses.
///
/// A request is in-place when the input and output pointers are equal. For the
/// "allgather" collective it is also in-place when the input pointer equals the
/// output pointer advanced by `rank * messageSize` bytes. Every other case is
/// out-of-place.
///
/// @return CollectiveBufferMode::InPlace or CollectiveBufferMode::OutOfPlace.
CollectiveBufferMode CollectiveRequest::bufferMode() const {
  if (inputBuffer == outputBuffer) {
    return CollectiveBufferMode::InPlace;
  }
  if (!(collective == "allgather")) {
    return CollectiveBufferMode::OutOfPlace;
  }
  // Byte offset of this rank's slice inside the gathered output buffer.
  size_t ownSliceOffset = rank * messageSize;
  const char* ownSlice = static_cast<const char*>(outputBuffer) + ownSliceOffset;
  const bool inputIsOwnSlice = (static_cast<const void*>(ownSlice) == inputBuffer);
  return inputIsOwnSlice ? CollectiveBufferMode::InPlace : CollectiveBufferMode::OutOfPlace;
}
/// Builds a native algorithm from its callbacks and static properties.
///
/// @param name              Algorithm name returned by name().
/// @param collective        Collective name returned by collective().
/// @param initFunc          Called once with the communicator on the first execute().
/// @param kernelFunc        Launches the kernel for a given context.
/// @param contextInitFunc   Creates a context when no cached one matches the key.
/// @param contextKeyGenFunc Produces the cache key for a set of execute() arguments.
/// @param minMessageSize    Lower bound reported by messageRange().
/// @param maxMessageSize    Upper bound reported by messageRange().
/// @param bufferMode        Buffer mode reported by bufferMode().
/// @param tags              Tag map reported by tags().
/// @param constraint        Constraint reported by constraint().
///
/// All owning parameters are taken by value, so they are moved into the
/// members rather than copied a second time.
NativeAlgorithm::NativeAlgorithm(std::string name, std::string collective, InitFunc initFunc, KernelFunc kernelFunc,
                                 ContextInitFunc contextInitFunc, ContextKeyGenFunc contextKeyGenFunc,
                                 size_t minMessageSize, size_t maxMessageSize, CollectiveBufferMode bufferMode,
                                 std::unordered_map<std::string, uint64_t> tags, Constraint constraint)
    : name_(std::move(name)),
      collective_(std::move(collective)),
      initFunc_(std::move(initFunc)),
      kernelLaunchFunc_(std::move(kernelFunc)),
      contextInitFunc_(std::move(contextInitFunc)),
      contextKeyGenFunc_(std::move(contextKeyGenFunc)),
      minMessageSize_(minMessageSize),
      maxMessageSize_(maxMessageSize),
      bufferMode_(bufferMode),
      tags_(std::move(tags)),
      constraint_(constraint) {}
/// Runs the algorithm, lazily initialising it and caching one context per key.
///
/// On the first call `initFunc_` is invoked with @p comm. A context key is then
/// derived from the buffers, sizes, dtype and symmetric-memory flag. A context
/// is created through `contextInitFunc_` only when no cached context matches.
/// The cached context is finally handed to `kernelLaunchFunc_`.
///
/// @param comm             Communicator forwarded to the init callbacks.
/// @param input            Input buffer.
/// @param output           Output buffer.
/// @param inputSize        Input size forwarded to the callbacks.
/// @param outputSize       Output size forwarded to the callbacks.
/// @param dtype            Element data type.
/// @param op               Reduction operation forwarded to the kernel launcher.
/// @param stream           Stream forwarded to the kernel launcher.
/// @param nBlocks          Forwarded to the kernel launcher.
/// @param nThreadsPerBlock Forwarded to the kernel launcher.
/// @param symmetricMemory  Only participates in the context key.
/// @param extras           Extra arguments forwarded to the kernel launcher.
/// @param accumDtype       Accumulation type; DataType::AUTO means "same as dtype".
/// @return Whatever `kernelLaunchFunc_` returns.
///
/// The executor argument is unnamed and ignored by native algorithms.
CommResult NativeAlgorithm::execute(std::shared_ptr<Communicator> comm, const void* input, void* output,
                                    size_t inputSize, size_t outputSize, DataType dtype, ReduceOp op,
                                    cudaStream_t stream, std::shared_ptr<Executor>, int nBlocks, int nThreadsPerBlock,
                                    bool symmetricMemory, const std::unordered_map<std::string, uintptr_t>& extras,
                                    DataType accumDtype) {
  if (accumDtype == DataType::AUTO) accumDtype = dtype;
  if (!initialized_) {
    initFunc_(comm);
    initialized_ = true;
  }
  const AlgorithmCtxKey ctxKey = contextKeyGenFunc_(input, output, inputSize, outputSize, dtype, symmetricMemory);
  // Keep the iterator so the cache is probed once on a hit instead of once for
  // find() plus once more for every operator[] access.
  auto it = contexts_.find(ctxKey);
  if (it == contexts_.end()) {
    // The context is built before emplace() runs, so a throwing init callback
    // leaves the cache untouched.
    it = contexts_.emplace(ctxKey, contextInitFunc_(comm, input, output, inputSize, outputSize, dtype)).first;
  }
  return kernelLaunchFunc_(it->second, input, output, inputSize, outputSize, dtype, op, stream, nBlocks,
                           nThreadsPerBlock, extras, accumDtype);
}
/// Returns the algorithm name stored at construction.
const std::string& NativeAlgorithm::name() const { return name_; }
/// Returns the collective name stored at construction.
const std::string& NativeAlgorithm::collective() const { return collective_; }
/// Returns the {min, max} message-size range currently configured.
///
/// NOTE(review): the returned reference points at a function-local static,
/// which is shared by every NativeAlgorithm instance and overwritten on each
/// call. Copy the pair before calling messageRange() again on any instance.
const std::pair<size_t, size_t>& NativeAlgorithm::messageRange() const {
  static std::pair<size_t, size_t> sharedRange;
  sharedRange.first = minMessageSize_;
  sharedRange.second = maxMessageSize_;
  return sharedRange;
}
/// Overwrites the message-size bounds later reported by messageRange().
/// The two values are stored as given; no ordering check is performed here.
void NativeAlgorithm::setMessageSizeRange(size_t minMessageSize, size_t maxMessageSize) {
minMessageSize_ = minMessageSize;
maxMessageSize_ = maxMessageSize;
}
/// Returns the tag map stored at construction.
const std::unordered_map<std::string, uint64_t>& NativeAlgorithm::tags() const { return tags_; }
/// Returns the buffer mode stored at construction.
const CollectiveBufferMode& NativeAlgorithm::bufferMode() const { return bufferMode_; }
/// Returns a copy of the constraint stored at construction.
Algorithm::Constraint NativeAlgorithm::constraint() const { return constraint_; }
/// Drops every cached context, so the next execute() rebuilds one through contextInitFunc_.
/// The initialized_ flag is not touched, so initFunc_ is not run again.
void NativeAlgorithm::reset() { contexts_.clear(); }
void AlgorithmCollection::registerAlgorithm(const std::string collective, const std::string algoName,
std::shared_ptr<Algorithm> algorithm) {
this->algoMapByCollective_[collective][algoName] = algorithm;
}
/// Picks an algorithm for @p request using the configured selectors.
///
/// The primary selector runs first when it is set. The fallback selector runs
/// only when the primary produced no algorithm and a fallback is installed.
///
/// @param request Description of the collective call to serve.
/// @return The selected algorithm, or a null pointer when no installed selector
///         produced one.
/// @throws via THROW with ErrorCode::InvalidUsage when neither selector is set.
std::shared_ptr<Algorithm> AlgorithmCollection::selectAlgorithm(const CollectiveRequest& request) {
  if (!algoSelector_ && !fallbackAlgoSelector_) {
    THROW(ALGO, Error, ErrorCode::InvalidUsage, "No algorithm selector is set in AlgorithmCollection.");
  }
  std::shared_ptr<Algorithm> algo;
  if (algoSelector_) {
    algo = algoSelector_(algoMapByCollective_, request);
  }
  // Only a primary selector may be installed. Calling an empty fallback here
  // would fail in the call itself instead of reporting "no algorithm".
  if (!algo && fallbackAlgoSelector_) {
    algo = fallbackAlgoSelector_(algoMapByCollective_, request);
  }
  return algo;
}
void AlgorithmCollection::extend(const AlgorithmCollection& other) {
for (const auto& [collective, algoMap] : other.algoMapByCollective_) {
for (const auto& [algoName, algorithm] : algoMap) {
this->registerAlgorithm(collective, algoName, algorithm);
}
}
}
/// Installs the primary and fallback algorithm selectors.
///
/// @param algoSelector         Selector consulted first by selectAlgorithm().
/// @param fallbackAlgoSelector Selector consulted when the primary yields nothing.
///
/// Both callables arrive by value, so they are moved into place instead of
/// being copied a second time.
void AlgorithmCollection::setSelectors(AlgoSelectFunc algoSelector, AlgoSelectFunc fallbackAlgoSelector) {
  algoSelector_ = std::move(algoSelector);
  fallbackAlgoSelector_ = std::move(fallbackAlgoSelector);
}
std::vector<std::shared_ptr<Algorithm>> AlgorithmCollection::getAllAlgorithms() const {
std::vector<std::shared_ptr<Algorithm>> allAlgos;
for (const auto& [collective, algoMap] : algoMapByCollective_) {
for (const auto& [algoName, algorithm] : algoMap) {
allAlgos.push_back(algorithm);
}
}
return allAlgos;
}
/// Returns a copy of the name-to-algorithm map registered for @p collective,
/// or an empty map when nothing is registered under that collective.
std::unordered_map<std::string, std::shared_ptr<Algorithm>> AlgorithmCollection::getAlgorithmsByCollective(
    const std::string& collective) const {
  const auto found = algoMapByCollective_.find(collective);
  if (found == algoMapByCollective_.end()) {
    return {};
  }
  return found->second;
}
/// Wraps an execution plan as an algorithm.
///
/// @param id         Identifier stored in id_.
/// @param plan       Execution plan that name(), collective(), messageRange() and execute() use.
/// @param tags       Tag map reported by tags().
/// @param constraint Constraint reported by constraint().
///
/// Owning parameters are taken by value, so they are moved into the members
/// rather than copied a second time.
DslAlgorithm::DslAlgorithm(std::string id, ExecutionPlan plan, std::unordered_map<std::string, uint64_t> tags,
                           Constraint constraint)
    : plan_(std::move(plan)), id_(std::move(id)), tags_(std::move(tags)), constraint_(constraint) {}
/// Returns the name reported by the wrapped execution plan (not the id passed to the constructor).
const std::string& DslAlgorithm::name() const { return plan_.name(); }
/// Returns the collective name reported by the wrapped execution plan.
const std::string& DslAlgorithm::collective() const { return plan_.collective(); }
/// Returns the {min, max} message-size range reported by the execution plan.
///
/// NOTE(review): the returned reference points at a function-local static,
/// which is shared by every DslAlgorithm instance and overwritten on each
/// call. Copy the pair before calling messageRange() again on any instance.
const std::pair<size_t, size_t>& DslAlgorithm::messageRange() const {
  static std::pair<size_t, size_t> sharedRange;
  sharedRange.first = plan_.minMessageSize();
  sharedRange.second = plan_.maxMessageSize();
  return sharedRange;
}
/// Always rejects the call: the range of a DSL algorithm is read from its plan
/// (see messageRange()), so both arguments are unnamed and ignored.
/// Reports ErrorCode::InvalidUsage through the THROW macro.
void DslAlgorithm::setMessageSizeRange(size_t, size_t) {
THROW(EXEC, Error, ErrorCode::InvalidUsage, "setMessageSizeRange is only supported for native algorithms");
}
/// Returns the tag map stored at construction.
const std::unordered_map<std::string, uint64_t>& DslAlgorithm::tags() const { return tags_; }
/// Returns the buffer mode implied by this algorithm's execution plan.
///
/// The signature requires a reference, so the result refers to one of two
/// immutable function-local constants chosen on every call. A single mutable
/// static would be initialised once, from whichever instance called first, and
/// every later instance would then report that first plan's mode.
///
/// @return Reference to InPlace when `plan_.isInPlace()` is true, otherwise to
///         OutOfPlace. The reference stays valid for the program's lifetime.
const CollectiveBufferMode& DslAlgorithm::bufferMode() const {
  static const CollectiveBufferMode kInPlace = CollectiveBufferMode::InPlace;
  static const CollectiveBufferMode kOutOfPlace = CollectiveBufferMode::OutOfPlace;
  return plan_.isInPlace() ? kInPlace : kOutOfPlace;
}
/// Returns a copy of the constraint stored at construction.
Algorithm::Constraint DslAlgorithm::constraint() const { return constraint_; }
/// Runs the wrapped execution plan through @p executor for the given data type.
///
/// The rank is read from the communicator's bootstrap. The switch below casts
/// the buffers to the element type that matches @p dtype and forwards them to
/// `executor->execute` together with the plan and stream.
///
/// The ReduceOp, both int parameters, the bool, the extras map and the trailing
/// DataType are unnamed and therefore unused by this implementation.
///
/// @param comm       Communicator used only to obtain the local rank.
/// @param input      Input buffer.
/// @param output     Output buffer.
/// @param inputSize  Input size forwarded to the executor.
/// @param outputSize Output size forwarded to the executor.
/// @param dtype      Element data type that selects the dispatch branch.
/// @param stream     Stream forwarded to the executor.
/// @param executor   Executor that runs the plan; must not be null.
/// @return CommResult::CommSuccess after dispatch, or
///         CommResult::CommInvalidArgument for a dtype with no case here.
///
/// Reports ErrorCode::InvalidUsage through THROW when @p executor is null or
/// when an FP8 variant is requested that the platform macros mark as not native.
///
/// NOTE(review): the C-style casts on `input` drop its const qualifier; confirm
/// the executor never writes through the send buffer.
CommResult DslAlgorithm::execute(std::shared_ptr<Communicator> comm, const void* input, void* output, size_t inputSize,
size_t outputSize, DataType dtype, ReduceOp, cudaStream_t stream,
std::shared_ptr<Executor> executor, int, int, bool,
const std::unordered_map<std::string, uintptr_t>&, DataType) {
if (!executor) {
THROW(EXEC, Error, ErrorCode::InvalidUsage, "Executor is null in DslAlgorithm::execute");
}
int rank = comm->bootstrap()->getRank();
switch (dtype) {
case DataType::FLOAT16:
executor->execute(rank, (half*)input, (half*)output, inputSize, outputSize, DataType::FLOAT16, plan_, stream);
break;
case DataType::FLOAT32:
executor->execute(rank, (float*)input, (float*)output, inputSize, outputSize, DataType::FLOAT32, plan_, stream);
break;
case DataType::BFLOAT16:
executor->execute(rank, (__bfloat16*)input, (__bfloat16*)output, inputSize, outputSize, DataType::BFLOAT16, plan_,
stream);
break;
// The four FP8 cases below exist only when the toolchain provides FP8 types.
// Within each E4M3 / E5M2 pair, the *_IS_FNUZ__ macro decides which variant is
// executed and which one is rejected with InvalidUsage.
#if defined(__FP8_TYPES_EXIST__)
case DataType::FLOAT8_E4M3FN:
#if defined(__FP8_E4M3_IS_FNUZ__)
THROW(EXEC, Error, ErrorCode::InvalidUsage,
"FLOAT8_E4M3FN is not natively supported on this platform; use FLOAT8_E4M3FNUZ");
#else
executor->execute(rank, (__fp8_e4m3*)input, (__fp8_e4m3*)output, inputSize, outputSize, DataType::FLOAT8_E4M3FN,
plan_, stream);
#endif
break;
case DataType::FLOAT8_E4M3FNUZ:
#if defined(__FP8_E4M3_IS_FNUZ__)
executor->execute(rank, (__fp8_e4m3*)input, (__fp8_e4m3*)output, inputSize, outputSize, DataType::FLOAT8_E4M3FNUZ,
plan_, stream);
#else
THROW(EXEC, Error, ErrorCode::InvalidUsage,
"FLOAT8_E4M3FNUZ is not natively supported on this platform; use FLOAT8_E4M3FN");
#endif
break;
case DataType::FLOAT8_E5M2:
#if defined(__FP8_E5M2_IS_FNUZ__)
THROW(EXEC, Error, ErrorCode::InvalidUsage,
"FLOAT8_E5M2 is not natively supported on this platform; use FLOAT8_E5M2FNUZ");
#else
executor->execute(rank, (__fp8_e5m2*)input, (__fp8_e5m2*)output, inputSize, outputSize, DataType::FLOAT8_E5M2,
plan_, stream);
#endif
break;
case DataType::FLOAT8_E5M2FNUZ:
#if defined(__FP8_E5M2_IS_FNUZ__)
executor->execute(rank, (__fp8_e5m2*)input, (__fp8_e5m2*)output, inputSize, outputSize, DataType::FLOAT8_E5M2FNUZ,
plan_, stream);
#else
THROW(EXEC, Error, ErrorCode::InvalidUsage,
"FLOAT8_E5M2FNUZ is not natively supported on this platform; use FLOAT8_E5M2");
#endif
break;
#endif
// FLOAT8_E4M3B15 sits outside the __FP8_TYPES_EXIST__ guard, so this case is
// compiled on every platform.
case DataType::FLOAT8_E4M3B15:
executor->execute(rank, (__fp8_e4m3b15*)input, (__fp8_e4m3b15*)output, inputSize, outputSize,
DataType::FLOAT8_E4M3B15, plan_, stream);
break;
// NOTE(review): INT32 falls through to UINT32, and both are forwarded as
// DataType::UINT32 with int* buffers; confirm this is intended for every
// operation the plan may perform.
case DataType::INT32:
case DataType::UINT32:
executor->execute(rank, (int*)input, (int*)output, inputSize, outputSize, DataType::UINT32, plan_, stream);
break;
default:
WARN(ALGO, "Unsupported data type: ", static_cast<int>(dtype), " in DslAlgorithm");
return CommResult::CommInvalidArgument;
}
return CommResult::CommSuccess;
}
/// Returns a shared pointer to this same object; no new algorithm is created.
/// NOTE(review): relies on shared_from_this(), so presumably the instance must already be owned by a
/// std::shared_ptr when this is called — confirm against the class declaration.
std::shared_ptr<Algorithm> DslAlgorithm::build() { return shared_from_this(); }
// TODO: implement this. Currently a no-op: nothing held by DslAlgorithm is
// cleared or rebuilt when reset() is called.
void DslAlgorithm::reset() {}
// File-local state behind getFlagBuffer().
// Raw pointer to the flag buffer; stays null until getFlagBuffer() allocates it.
static uint32_t* gDefaultFlagBuffer = nullptr;
// Tracks the most recent shared_ptr handed out by getFlagBuffer(), so a handle that is
// still alive can be returned again instead of creating a new one.
static std::weak_ptr<void> gDefaultFlagBufferWeak;
// Number of uint32_t flags in the buffer.
static size_t gDefaultFlagCount = 128;
std::pair<std::shared_ptr<void>, size_t> getFlagBuffer() {
auto ptr = gDefaultFlagBufferWeak.lock();
if (!ptr) {
if (!gDefaultFlagBuffer) {
// Intentionally never freed — CUDA driver reclaims GPU memory at process exit.
gDefaultFlagBuffer = static_cast<uint32_t*>(mscclpp::detail::gpuCalloc(gDefaultFlagCount * sizeof(uint32_t)));
std::vector<uint32_t> initFlags(gDefaultFlagCount, 1);
mscclpp::gpuMemcpy(gDefaultFlagBuffer, initFlags.data(), gDefaultFlagCount, cudaMemcpyHostToDevice);
}
ptr = std::shared_ptr<void>(gDefaultFlagBuffer, [](void*) {});
gDefaultFlagBufferWeak = ptr;
}
return {ptr, gDefaultFlagCount * sizeof(uint32_t)};
}
} // namespace mscclpp