mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 01:10:22 +00:00
## Summary - Fix ROCm FP8 build failure by using the actual FP8 `DataType` enum constants in allreduce packet tuning. - Fix FP8 E4M3FNUZ test encoding so small negative values do not produce the FNUZ NaN byte (`0x80`). - Align FP8 `DataType` enum constants and Python bindings with torch-style names (`FLOAT8_E4M3FN`, `FLOAT8_E4M3FNUZ`, `FLOAT8_E5M2FNUZ` / `float8_e4m3fn`, `float8_e4m3fnuz`, `float8_e5m2fnuz`). ## Validation - `./tools/lint.sh` - `make -j` from `build/` - `mpirun --allow-run-as-root -np 8 python3 -m pytest python/test/test_fp8_accum.py -q` (`36 passed, 9 skipped`) - `DTYPE=float8_e4m3fnuz ACCUM_DTYPE=float32 torchrun --nnodes=1 --nproc_per_node=8 examples/torch-integration/customized_comm_with_tuning.py` --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
188 lines
7.0 KiB
C++
188 lines
7.0 KiB
C++
// Copyright (c) Microsoft Corporation.
|
|
// Licensed under the MIT License.
|
|
|
|
#include "algorithm_selector.hpp"
|
|
|
|
#include <mscclpp/env.hpp>
|
|
#include <mscclpp/utils.hpp>
|
|
|
|
#include "debug.h"
|
|
|
|
namespace mscclpp {
|
|
namespace nccl {
|
|
|
|
static bool isNvlsSupportedForDataType(const AlgorithmSelectorConfig& config, DataType dtype) {
|
|
bool nvlsSupported = config.nvlsSupported;
|
|
|
|
// NVLS does not support uint8_t (no hardware support for byte-level reduction)
|
|
// NVLS also does not support float8_e4m3b15 (software-defined type with no hardware NVLS reduction support)
|
|
if (dtype == DataType::UINT8 || dtype == DataType::FLOAT8_E4M3B15) {
|
|
return false;
|
|
}
|
|
|
|
const bool isFp8 = dtype == DataType::FLOAT8_E4M3FN || dtype == DataType::FLOAT8_E4M3FNUZ ||
|
|
dtype == DataType::FLOAT8_E5M2 || dtype == DataType::FLOAT8_E5M2FNUZ;
|
|
|
|
if (!isFp8) {
|
|
return nvlsSupported;
|
|
}
|
|
|
|
// FP8 handling
|
|
#if !defined(__HIP_PLATFORM_AMD__)
|
|
// NVLS does not support FP8 on devices with compute capability < 10
|
|
if (config.computeCapability.first < 10) {
|
|
return false;
|
|
}
|
|
#if (defined(__CUDA_ARCH_SPECIFIC__) || defined(__CUDA_ARCH_FAMILY_SPECIFIC__))
|
|
return true;
|
|
#else
|
|
return false;
|
|
#endif
|
|
#else
|
|
return nvlsSupported;
|
|
#endif
|
|
}
|
|
|
|
bool matchExecutionPlan(std::shared_ptr<DslAlgorithm> algo, const CollectiveRequest& request) {
|
|
bool worldSizeMatch = algo->constraint().worldSize == request.worldSize;
|
|
bool ranksPerNodeMatch = algo->constraint().nRanksPerNode == request.nRanksPerNode;
|
|
bool collectiveMatch = algo->collective() == request.collective;
|
|
bool bufferModeMatch = algo->bufferMode() == CollectiveBufferMode::Any || request.bufferMode() == algo->bufferMode();
|
|
size_t effectiveSize =
|
|
(request.collective == "allgather") ? (request.messageSize * request.worldSize) : request.messageSize;
|
|
bool minSizeMatch = effectiveSize >= algo->messageRange().first;
|
|
bool maxSizeMatch = effectiveSize <= algo->messageRange().second;
|
|
bool result =
|
|
worldSizeMatch && ranksPerNodeMatch && collectiveMatch && bufferModeMatch && minSizeMatch && maxSizeMatch;
|
|
return result;
|
|
}
|
|
|
|
static std::shared_ptr<Algorithm> selectSingleNodeAllreduceBlackwell(
|
|
const std::unordered_map<std::string, std::shared_ptr<Algorithm>>& algoMap, const CollectiveRequest& request,
|
|
const AlgorithmSelectorConfig& config) {
|
|
const size_t messageSize = request.messageSize;
|
|
|
|
const bool nvlsSupported = isNvlsSupportedForDataType(config, request.dtype);
|
|
|
|
// Small messages always use NVLS packet algorithm
|
|
if (messageSize <= (1 << 15)) { // <= 32KB
|
|
return algoMap.at("default_allreduce_nvls_packet");
|
|
}
|
|
|
|
if (!config.symmetricMemory) {
|
|
if (messageSize <= (1 << 21)) { // <= 2MB
|
|
return algoMap.at("default_allreduce_packet");
|
|
}
|
|
if (config.inCaptureMode) {
|
|
// CUDA graph mode: setup new connections each time (zero-copy for graph)
|
|
return algoMap.at("default_allreduce_rsag_zero_copy");
|
|
}
|
|
// Non-graph mode: use non-zero-copy algorithms
|
|
if (messageSize <= (1 << 23)) { // <= 8MB
|
|
return algoMap.at("default_allreduce_rsag");
|
|
}
|
|
return algoMap.at("default_allreduce_rsag_pipeline");
|
|
}
|
|
|
|
// Symmetric memory path: can use cached memory handles
|
|
const bool useNvlsWithZeroCopy = nvlsSupported && config.isCuMemMapAllocated;
|
|
if (messageSize <= (1 << 16) || (messageSize <= (1 << 20) && !useNvlsWithZeroCopy)) { // <= 64KB or <= 1MB
|
|
return algoMap.at("default_allreduce_packet");
|
|
}
|
|
if (useNvlsWithZeroCopy) {
|
|
return algoMap.at("default_allreduce_nvls_zero_copy");
|
|
}
|
|
|
|
return algoMap.at("default_allreduce_rsag_zero_copy");
|
|
}
|
|
|
|
std::shared_ptr<Algorithm> selectSingleNodeAllreduce(
|
|
const std::unordered_map<std::string, std::shared_ptr<Algorithm>>& algoMap, const CollectiveRequest& request,
|
|
const AlgorithmSelectorConfig& config) {
|
|
// Use Blackwell-specific selection for compute capability 10.x
|
|
if (config.computeCapability.first == 10) {
|
|
return selectSingleNodeAllreduceBlackwell(algoMap, request, config);
|
|
}
|
|
|
|
const size_t messageSize = request.messageSize;
|
|
|
|
// Determine NVLS availability based on data type and device capability
|
|
const bool nvlsSupported = isNvlsSupportedForDataType(config, request.dtype);
|
|
|
|
const bool useNvlsWithZeroCopy = nvlsSupported && config.symmetricMemory && config.isCuMemMapAllocated;
|
|
|
|
// Very small messages: use allpair packet algorithm
|
|
if (messageSize <= (1 << 14)) { // <= 16KB
|
|
return algoMap.at("default_allreduce_allpair_packet");
|
|
}
|
|
// Small messages with NVLS support
|
|
if (messageSize <= (1 << 15) && nvlsSupported) { // <= 32KB
|
|
return algoMap.at("default_allreduce_nvls_packet");
|
|
}
|
|
// Medium messages: use packet algorithm
|
|
if (messageSize <= (1 << 16) || (messageSize <= (1 << 20) && !useNvlsWithZeroCopy)) { // <= 64KB or <= 1MB
|
|
return algoMap.at("default_allreduce_packet");
|
|
}
|
|
// Large messages with NVLS zero-copy support
|
|
if (nvlsSupported && useNvlsWithZeroCopy) {
|
|
return algoMap.at("default_allreduce_nvls_zero_copy");
|
|
}
|
|
// Large messages with NVLS but without zero-copy
|
|
if (nvlsSupported) {
|
|
if (messageSize < (1 << 24)) { // < 16MB
|
|
return algoMap.at("default_allreduce_nvls_warp_pipeline");
|
|
}
|
|
return algoMap.at("default_allreduce_nvls_block_pipeline");
|
|
}
|
|
#if defined(__HIP_PLATFORM_AMD__)
|
|
// AMD platform: use fullmesh algorithm
|
|
return algoMap.at("default_allreduce_fullmesh");
|
|
#else
|
|
// NVIDIA without NVLS: use RSAG pipeline if no NCCL fallback
|
|
if (!config.ncclDlopenSharedLib) {
|
|
return algoMap.at("default_allreduce_fullmesh");
|
|
}
|
|
return nullptr;
|
|
#endif
|
|
}
|
|
|
|
std::shared_ptr<Algorithm> selectSingleNodeAllgather(
|
|
const std::unordered_map<std::string, std::shared_ptr<Algorithm>>& algoMap, const CollectiveRequest& request,
|
|
[[maybe_unused]] const AlgorithmSelectorConfig& config) {
|
|
const size_t messageSize = request.messageSize;
|
|
|
|
// For messages up to 32MB, use fullmesh2 algorithm
|
|
if (messageSize <= 32 * (1 << 20)) {
|
|
return algoMap.at("default_allgather_fullmesh2");
|
|
}
|
|
|
|
#if defined(__HIP_PLATFORM_AMD__)
|
|
// AMD platform always uses fullmesh2
|
|
return algoMap.at("default_allgather_fullmesh2");
|
|
#else
|
|
// NVIDIA: use fullmesh for large messages if no NCCL fallback is available
|
|
if (!config.ncclDlopenSharedLib) {
|
|
return algoMap.at("default_allgather_fullmesh");
|
|
}
|
|
return nullptr;
|
|
#endif
|
|
}
|
|
|
|
std::shared_ptr<Algorithm> selectMultiNodeAlgorithm(
|
|
const std::unordered_map<std::string, std::shared_ptr<Algorithm>>& algoMap [[maybe_unused]],
|
|
const CollectiveRequest& request [[maybe_unused]], const AlgorithmSelectorConfig& config [[maybe_unused]]) {
|
|
// TODO: Implement multi-node algorithm selection
|
|
// Multi-node scenarios will need to consider:
|
|
// 1. Multi-node NVLS (if supported by hardware)
|
|
// 2. Multi-node IB (InfiniBand)
|
|
// 3. Hierarchical algorithms (intra-node + inter-node)
|
|
// 4. Network topology awareness
|
|
|
|
// For now, return nullptr to fallback to NCCL/RCCL
|
|
INFO(MSCCLPP_NCCL, "Multi-node collective not yet supported, fallback to nccl/rccl");
|
|
return nullptr;
|
|
}
|
|
|
|
} // namespace nccl
|
|
} // namespace mscclpp
|