Files
mscclpp/python/csrc/gpu_utils_py.cpp
Binyang Li 96a72bbd3e Support E4M3B15 datatype (#765)
## Summary

- **Add `fp8_e4m3b15` datatype**: A software-defined FP8 type with 4
exponent bits, 3 mantissa bits, and bias=15 (max finite value: 0.9375).
Implemented entirely in software with no HW dependency, using
Triton-style bit manipulation through fp16 as intermediate for efficient
conversion.
- **Add mixed-precision accumulation for allreduce**: All allreduce
algorithm variants (packet, NVLS packet, fullmesh, RSAG zero-copy, and
others) now support a configurable `accumDtype` parameter, enabling FP8
inputs to be reduced in float16 or float32 for higher accuracy.
- **Propagate `accumDtype` through the full API**: The new parameter is
threaded from `Algorithm::execute()` → `NativeAlgorithm` → `KernelFunc`
→ dispatch → CUDA kernels, with `DataType::AUTO` as the default
(resolves to input dtype at runtime).
- **Add FP8 accumulation correctness tests**: New `test_fp8_accum.py`
validates that higher-precision accumulation produces results at least
as accurate as native FP8 accumulation across multiple algorithms and
sizes. Skipped on CUDA SM < 89 (pre-Ada Lovelace); runs on HIP/ROCm.
- **Add `test_fp8_accum.py` to CI**: Azure Pipeline `ut.yml` now runs
FP8 accumulation tests alongside existing pytests.
- **NCCL shim logging cleanup**: Migrated `printf`-style `WARN`/`INFO`
calls to streaming-style logging.

## Key files

| Area | Files |
|------|-------|
| New datatype + vector ops | `include/mscclpp/gpu_data_types.hpp` |
| Accumulation reduce helpers | `src/core/include/reduce_kernel.hpp` |
| Algorithm API (`accumDtype`) | `include/mscclpp/algorithm.hpp`, `src/core/algorithm.cc` |
| Allreduce kernels | `src/ext/collectives/allreduce/*.cu` |
| Dispatch + common | `src/ext/collectives/include/allreduce/common.hpp` |
| Python bindings | `python/csrc/algorithm.cpp`, `python/mscclpp/_core/algorithm.py` |
| Tests | `python/test/test_fp8_accum.py` |
| CI | `.azure-pipelines/templates/ut.yml` |

## Test plan

- [x] CI passes on H100 (CUDA SM 90) — full FP8 E4M3 + E4M3B15
accumulation tests
- [x] CI passes on A100 (CUDA SM 80) — FP8 tests correctly skipped
- [x] CI passes on MI300X (ROCm) — FP8 tests run via HIP
- [x] Existing `test_mscclpp.py` tests continue to pass
- [x] NCCL shim builds and runs correctly with new `accumDtype` defaults

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-04-07 13:37:02 -07:00

130 lines
4.6 KiB
C++

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#include <dlpack/dlpack.h>
#include <nanobind/nanobind.h>
#include <nanobind/stl/shared_ptr.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/vector.h>

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <mscclpp/gpu_data_types.hpp>
#include <mscclpp/gpu_utils.hpp>
namespace nb = nanobind;
using namespace mscclpp;
// Bits per byte; converts DLPack dtype widths (expressed in bits) into per-element byte sizes.
constexpr int BYTE_BITS{8};
/// Returns the DLPack device type stamped on exported tensors.
/// The choice is fixed at compile time: kDLROCM when built with MSCCLPP_USE_ROCM, kDLCUDA otherwise.
static DLDeviceType getDeviceType() {
  constexpr DLDeviceType kBackendDevice =
#if defined(MSCCLPP_USE_ROCM)
      kDLROCM;
#else
      kDLCUDA;
#endif
  return kBackendDevice;
}
/// Maps a dtype name to the single-lane DLPack data type used when exporting a buffer.
///
/// @param type Dtype name. Either a PyTorch-style dtype string such as "torch.float16" (presumably
///             produced by `str(dtype)` on the Python side; confirm against callers), or the
///             mscclpp-specific name "fp8_e4m3b15".
/// @return The matching DLDataType (lanes == 1).
/// @throws Error with ErrorCode::InvalidUsage if the name is not recognized.
static DLDataType getDlType(const std::string& type) {
  struct DtypeEntry {
    const char* name;
    DLDataType dlType;
  };
  // Table-driven so supporting another dtype is a one-line change.
  static const DtypeEntry kDtypes[] = {
      {"torch.float64", {kDLFloat, 64, 1}},
      {"torch.float32", {kDLFloat, 32, 1}},
      {"torch.float16", {kDLFloat, 16, 1}},
      {"torch.bfloat16", {kDLBfloat, 16, 1}},
      {"torch.float8_e4m3fn", {kDLFloat8_e4m3fn, 8, 1}},
      {"torch.float8_e4m3fnuz", {kDLFloat8_e4m3fnuz, 8, 1}},
      {"torch.float8_e5m2", {kDLFloat8_e5m2, 8, 1}},
      {"torch.float8_e5m2fnuz", {kDLFloat8_e5m2fnuz, 8, 1}},
      {"torch.int8", {kDLInt, 8, 1}},
      {"torch.int16", {kDLInt, 16, 1}},
      {"torch.int32", {kDLInt, 32, 1}},
      {"torch.int64", {kDLInt, 64, 1}},
      {"torch.uint8", {kDLUInt, 8, 1}},
      {"torch.uint16", {kDLUInt, 16, 1}},
      {"torch.uint32", {kDLUInt, 32, 1}},
      {"torch.uint64", {kDLUInt, 64, 1}},
      {"torch.bool", {kDLBool, 8, 1}},
      // No standard DLPack code exists for the software-defined fp8_e4m3b15 format; export it as raw uint8 bytes.
      {"fp8_e4m3b15", {kDLUInt, 8, 1}},
  };
  for (const DtypeEntry& entry : kDtypes) {
    if (type == entry.name) {
      return entry.dlType;
    }
  }
  throw Error("Unsupported type: " + type, ErrorCode::InvalidUsage);
}
static nb::capsule toDlpack(GpuBuffer<char> buffer, std::string dataType, std::vector<int64_t>& shape,
std::vector<int64_t>& strides) {
DLDataType dtype = getDlType(dataType);
int64_t* tensorShape = shape.size() > 0 ? new int64_t[shape.size()] : new int64_t[1];
int64_t* tensorStrides = strides.size() > 0 ? new int64_t[strides.size()] : nullptr;
if (shape.size() == 0) {
tensorShape[0] = (int64_t)(buffer.nelems() / ((dtype.bits * dtype.lanes + 7) / BYTE_BITS));
} else {
for (size_t i = 0; i < shape.size(); ++i) {
tensorShape[i] = shape[i];
}
}
for (size_t i = 0; i < strides.size(); ++i) {
tensorStrides[i] = strides[i];
}
DLManagedTensor* dlManagedTensor = new DLManagedTensor();
dlManagedTensor->dl_tensor.data = buffer.data();
dlManagedTensor->dl_tensor.device.device_type = getDeviceType();
dlManagedTensor->dl_tensor.device.device_id = buffer.deviceId();
dlManagedTensor->dl_tensor.ndim = shape.size() == 0 ? 1 : shape.size();
dlManagedTensor->dl_tensor.strides = tensorStrides;
dlManagedTensor->dl_tensor.shape = tensorShape;
dlManagedTensor->dl_tensor.byte_offset = 0;
dlManagedTensor->dl_tensor.dtype = dtype;
dlManagedTensor->manager_ctx = new GpuBuffer<char>(buffer);
dlManagedTensor->deleter = [](DLManagedTensor* self) {
delete static_cast<GpuBuffer<char>*>(self->manager_ctx);
self->manager_ctx = nullptr;
self->dl_tensor.data = nullptr;
if (self->dl_tensor.shape != nullptr) {
delete[] self->dl_tensor.shape;
self->dl_tensor.shape = nullptr;
if (self->dl_tensor.strides) {
delete[] self->dl_tensor.strides;
self->dl_tensor.strides = nullptr;
}
}
delete self;
};
PyObject* dlCapsule = PyCapsule_New(static_cast<void*>(dlManagedTensor), "dltensor", [](PyObject* capsule) {
if (PyCapsule_IsValid(capsule, "used_dltensor")) {
return;
}
if (!PyCapsule_IsValid(capsule, "dltensor")) {
return;
}
DLManagedTensor* managedTensor = static_cast<DLManagedTensor*>(PyCapsule_GetPointer(capsule, "dltensor"));
if (managedTensor == nullptr) {
return;
}
if (managedTensor->deleter) {
managedTensor->deleter(managedTensor);
}
});
return nb::steal<nb::capsule>(dlCapsule);
}
/// Adds the GPU utility bindings to module `m`.
///   * `is_nvls_supported()` forwards to isNvlsSupported().
///   * `CppRawGpuBuffer` wraps GpuBuffer<char>. It supports construction from an element count,
///     the nelems/bytes size queries, the device pointer as an integer, the device id, and export
///     to a DLPack capsule via toDlpack().
void register_gpu_utils(nb::module_& m) {
  using ByteBuffer = GpuBuffer<char>;

  m.def("is_nvls_supported", &isNvlsSupported);

  // The device pointer is handed to Python as a plain integer address.
  auto dataAddress = [](ByteBuffer& self) { return reinterpret_cast<uintptr_t>(self.data()); };
  auto exportDlpack = [](ByteBuffer& self, std::string dataType, std::vector<int64_t> shape,
                         std::vector<int64_t> strides) { return toDlpack(self, dataType, shape, strides); };

  nb::class_<ByteBuffer> rawBuffer(m, "CppRawGpuBuffer");
  rawBuffer.def(nb::init<size_t>(), nb::arg("nelems"));
  rawBuffer.def("nelems", &ByteBuffer::nelems);
  rawBuffer.def("bytes", &ByteBuffer::bytes);
  rawBuffer.def("data", dataAddress);
  rawBuffer.def("device_id", &ByteBuffer::deviceId);
  rawBuffer.def("to_dlpack", exportDlpack, nb::arg("data_type"), nb::arg("shape") = std::vector<int64_t>(),
                nb::arg("strides") = std::vector<int64_t>());
}