Fix FP8 ROCm build/test issues and dtype naming (#792)

## Summary
- Fix ROCm FP8 build failure by using the actual FP8 `DataType` enum
constants in allreduce packet tuning.
- Fix FP8 E4M3FNUZ test encoding so small negative values do not produce
the FNUZ NaN byte (`0x80`).
- Align FP8 `DataType` enum constants and Python bindings with
torch-style names (`FLOAT8_E4M3FN`, `FLOAT8_E4M3FNUZ`, `FLOAT8_E5M2FNUZ`
/ `float8_e4m3fn`, `float8_e4m3fnuz`, `float8_e5m2fnuz`).

## Validation
- `./tools/lint.sh`
- `make -j` from `build/`
- `mpirun --allow-run-as-root -np 8 python3 -m pytest
python/test/test_fp8_accum.py -q` (`36 passed, 9 skipped`)
- `DTYPE=float8_e4m3fnuz ACCUM_DTYPE=float32 torchrun --nnodes=1
--nproc_per_node=8
examples/torch-integration/customized_comm_with_tuning.py`

---------

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Binyang Li
2026-04-28 15:02:22 -07:00
committed by GitHub
parent c97be492d5
commit 2c52937b26
12 changed files with 271 additions and 138 deletions

View File

@@ -21,7 +21,10 @@ using __bfloat162 = __hip_bfloat162;
#if defined(HIP_VERSION_MAJOR) && (HIP_VERSION_MAJOR >= 6)
#include <hip/hip_fp8.h>
// Create aliases matching CUDA naming convention for cross-platform compatibility
// Create aliases matching CUDA naming convention for cross-platform compatibility.
// Define __FP8_E4M3_IS_FNUZ__ / __FP8_E5M2_IS_FNUZ__ when the platform-native FP8 is the
// "fnuz" variant (no infinities, NaN-only at 0x80, bias differs from OCP). Dispatch layers
// use these macros to throw on unsupported variants requested via DataType.
#if (HIP_VERSION_MAJOR == 6) || (HIP_VERSION_MAJOR > 6 && HIP_FP8_TYPE_FNUZ && !HIP_FP8_TYPE_OCP)
using __fp8_e4m3 = __hip_fp8_e4m3_fnuz;
using __fp8_e5m2 = __hip_fp8_e5m2_fnuz;
@@ -29,6 +32,8 @@ using __fp8x2_e4m3 = __hip_fp8x2_e4m3_fnuz;
using __fp8x2_e5m2 = __hip_fp8x2_e5m2_fnuz;
using __fp8x4_e4m3 = __hip_fp8x4_e4m3_fnuz;
using __fp8x4_e5m2 = __hip_fp8x4_e5m2_fnuz;
#define __FP8_E4M3_IS_FNUZ__
#define __FP8_E5M2_IS_FNUZ__
#else
using __fp8_e4m3 = __hip_fp8_e4m3;
using __fp8_e5m2 = __hip_fp8_e5m2;
@@ -66,8 +71,8 @@ using __bfloat162 = __nv_bfloat162;
/// Software float8 with 4 exponent bits, 3 mantissa bits, exponent bias = 15.
/// Format (MSB first): [sign:1][exponent:4][mantissa:3]
/// No infinities; exp=15 is NaN. Negative zero is NaN (fnuz convention).
/// Max finite value: 0.9375, min normal: ~6.1e-5, min subnormal: ~7.6e-6.
/// No infinities, no NaN. Encode saturates to ±1.75 (0x7e/0xfe).
/// Adapted from the Triton compiler's fp8e4b15 format.
struct alignas(1) __fp8_e4m3b15 {
uint8_t __x;
@@ -97,35 +102,15 @@ struct alignas(1) __fp8_e4m3b15 {
/// Algorithm: reinterpret fp8 bits into an fp16 bit pattern with exponent shifted by -8,
/// then convert fp16 → float32.
static MSCCLPP_HOST_DEVICE_INLINE float toFloat(uint8_t bits) {
// Handle special values: negative zero (0x80) → NaN, exponent=15 → NaN.
uint32_t exp = (bits >> 3) & 0xFu;
if (bits == 0x80 || exp == 15) {
union {
uint32_t u;
float f;
} nan_val = {0x7FC00000u};
return nan_val.f;
}
if (bits == 0) return 0.0f;
// Triton-style bit manipulation: fp8 → fp16 → fp32.
// fp8 layout: [S:1][E:4][M:3] (bias=15)
// fp16 layout: [S:1][E:5][M:10] (bias=15)
//
// Place fp8 in upper byte of fp16, then right-shift exponent+mantissa by 1
// to convert E4 → E5 (both share bias=15). Sign bit stays at bit 15.
// Branch-free decode: fp8 → fp16 → fp32, no special-case handling.
// Encode saturates to ±1.75, so 0x7f/0xff are never produced.
// Refer:
// https://github.com/triton-lang/triton/blob/cf34004b8a67d290a962da166f5aa2fc66751326/python/triton/language/extra/cuda/utils.py#L34
uint16_t h = (uint16_t)bits << 8; // place fp8 in upper byte of fp16
uint16_t sign16 = h & 0x8000u; // extract sign at fp16 position
uint16_t nosign = h & 0x7F00u; // exponent + mantissa (no sign)
uint16_t fp16_bits = sign16 | (nosign >> 1); // shift exponent right by 1
uint16_t fp16_bits = sign16 | (nosign >> 1); // shift exponent right by 1 (E4→E5)
// For subnormals: when fp8 exponent=0, the above gives fp16 exponent=0
// and fp16 mantissa = (fp8_mantissa << 7), which correctly represents
// the subnormal fp16 value since both share bias=15.
// Convert fp16 bits to float via __half (works on host and device, CUDA and HIP).
union {
uint16_t u;
__half h;
@@ -139,14 +124,6 @@ struct alignas(1) __fp8_e4m3b15 {
/// The key insight is to convert to fp16 first (which shares bias=15 with e4m3b15),
/// then pack the fp16 bits back into 8 bits by shifting the exponent left by 1.
static MSCCLPP_HOST_DEVICE_INLINE uint8_t fromFloat(float val) {
union {
float f;
uint32_t u;
} in = {val};
// NaN → 0x80 (negative-zero bit pattern = NaN in fnuz).
if ((in.u & 0x7F800000u) == 0x7F800000u && (in.u & 0x007FFFFFu) != 0) return 0x80u;
// Convert float32 → fp16 bits via __half (works on host and device, CUDA and HIP).
__half h_val = __float2half_rn(val);
union {
@@ -155,32 +132,19 @@ struct alignas(1) __fp8_e4m3b15 {
} cvt = {h_val};
uint16_t fp16_bits = cvt.u;
// Clamp absolute value to max finite e4m3b15: 0.9375 → fp16 = 0x3B80.
// Clamp abs to max encodable value: 1.75 → fp16 = 0x3F00.
// Matches Triton: encode saturates, 0x7f/0xff are never produced.
uint16_t abs_fp16 = fp16_bits & 0x7FFFu;
if (abs_fp16 > 0x3B80u) abs_fp16 = 0x3B80u;
if (abs_fp16 > 0x3F00u) abs_fp16 = 0x3F00u;
// Reconstruct with sign.
uint16_t sign16 = fp16_bits & 0x8000u;
// Triton-style: fp16 → fp8.
// fp16 layout: [S:1][E:5][M:10] (bias=15)
// fp8 layout: [S:1][E:4][M:3] (bias=15)
//
// mad.lo.u32 a0, a0, 2, 0x00800080 → (abs_fp16 * 2 + 0x0080)
// This shifts left by 1 (undoing the right-shift in decode) and adds rounding bias.
// Then: lop3.b32 b0, $1, 0x80008000, a0, 0xea → (sign & 0x8000) | a0
// Finally: prmt for byte extraction.
//
// Simplified for scalar: shift abs_fp16 left by 1, add rounding bias, take upper byte.
// fp16 → fp8: shift abs left by 1 (undo decode's right-shift), add rounding bias, take upper byte.
uint16_t adjusted = (uint16_t)(abs_fp16 * 2u + 0x0080u);
// The upper byte now contains [E:4][M:3][round_bit].
// Combine with sign and extract.
uint16_t with_sign = sign16 | adjusted;
uint8_t result = (uint8_t)(with_sign >> 8);
// Zero → 0x00 (ensure positive zero, not negative zero which is NaN).
if ((result & 0x7Fu) == 0) result = 0x00u;
return result;
}
};
@@ -199,16 +163,18 @@ namespace mscclpp {
/// Data types supported by mscclpp operations.
enum class DataType {
INT32, // 32-bit signed integer.
UINT32, // 32-bit unsigned integer.
FLOAT16, // IEEE 754 half precision.
FLOAT32, // IEEE 754 single precision.
BFLOAT16, // bfloat16 precision.
FLOAT8_E4M3, // float8 with E4M3 layout.
FLOAT8_E5M2, // float8 with E5M2 layout.
UINT8, // 8-bit unsigned integer.
FLOAT8_E4M3B15, // float8 with E4M3 layout, bias=15 (software, no HW accel).
AUTO = 255, // Sentinel: resolve to the input dtype at runtime.
INT32, // 32-bit signed integer.
UINT32, // 32-bit unsigned integer.
FLOAT16, // IEEE 754 half precision.
FLOAT32, // IEEE 754 single precision.
BFLOAT16, // bfloat16 precision.
FLOAT8_E4M3FN, // float8 E4M3, OCP variant (NV; AMD HIP > 6 with OCP enabled).
FLOAT8_E4M3FNUZ, // float8 E4M3, fnuz variant (AMD HIP 6, or HIP > 6 with FNUZ enabled).
FLOAT8_E5M2, // float8 E5M2, OCP variant (NV; AMD HIP > 6 with OCP enabled).
FLOAT8_E5M2FNUZ, // float8 E5M2, fnuz variant (AMD HIP 6, or HIP > 6 with FNUZ enabled).
UINT8, // 8-bit unsigned integer.
FLOAT8_E4M3B15, // float8 with E4M3 layout, bias=15 (software, no HW accel).
AUTO = 255, // Sentinel: resolve to the input dtype at runtime.
};
/// Word array.
@@ -1137,11 +1103,11 @@ MSCCLPP_DEVICE_INLINE f8_e4m3b15x2 to<f8_e4m3b15x2, f16x2>(const f16x2& v) {
#if defined(MSCCLPP_DEVICE_CUDA)
uint32_t in0;
asm("mov.b32 %0, %1;" : "=r"(in0) : "r"(*reinterpret_cast<const uint32_t*>(&v)));
// Clamp abs to max finite e4m3b15 (0x3B80 = 0.9375 in fp16).
// Clamp abs to max encodable e4m3b15 (0x3F00 = 1.75 in fp16).
uint32_t lo = in0 & 0xFFFFu, hi = in0 >> 16;
uint32_t alo = lo & 0x7FFFu, ahi = hi & 0x7FFFu;
alo = alo < 0x3B80u ? alo : 0x3B80u;
ahi = ahi < 0x3B80u ? ahi : 0x3B80u;
alo = alo < 0x3F00u ? alo : 0x3F00u;
ahi = ahi < 0x3F00u ? ahi : 0x3F00u;
uint32_t a0 = alo | (ahi << 16);
a0 = a0 * 2u + 0x00800080u;
uint32_t b0 = a0 | (in0 & 0x80008000u);
@@ -1152,7 +1118,7 @@ MSCCLPP_DEVICE_INLINE f8_e4m3b15x2 to<f8_e4m3b15x2, f16x2>(const f16x2& v) {
uint32_t in0 = v.words[0];
uint32_t abs0 = in0 & 0x7fff7fffu;
uint32_t a0;
asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a0) : "v"(abs0), "v"(0x3B803B80u));
asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a0) : "v"(abs0), "v"(0x3F003F00u));
a0 = a0 * 2u + 0x00800080u;
uint32_t b0 = a0 | (in0 & 0x80008000u);
uint16_t packed = (uint16_t)(((b0 >> 8) & 0xFFu) | ((b0 >> 16) & 0xFF00u));
@@ -1175,8 +1141,8 @@ MSCCLPP_DEVICE_INLINE f8_e4m3b15x4 to<f8_e4m3b15x4, f16x4>(const f16x4& v) {
asm("mov.b32 %0, %1;" : "=r"(in1) : "r"(v.words[1]));
uint32_t abs0 = in0 & 0x7fff7fffu;
uint32_t abs1 = in1 & 0x7fff7fffu;
uint32_t a0 = __vminu2(abs0, 0x3B803B80u);
uint32_t a1 = __vminu2(abs1, 0x3B803B80u);
uint32_t a0 = __vminu2(abs0, 0x3F003F00u);
uint32_t a1 = __vminu2(abs1, 0x3F003F00u);
a0 = a0 * 2u + 0x00800080u;
a1 = a1 * 2u + 0x00800080u;
uint32_t b0, b1;
@@ -1189,8 +1155,8 @@ MSCCLPP_DEVICE_INLINE f8_e4m3b15x4 to<f8_e4m3b15x4, f16x4>(const f16x4& v) {
uint32_t in0 = v.words[0], in1 = v.words[1];
uint32_t abs0 = in0 & 0x7fff7fffu, abs1 = in1 & 0x7fff7fffu;
uint32_t a0, a1;
asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a0) : "v"(abs0), "v"(0x3B803B80u));
asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a1) : "v"(abs1), "v"(0x3B803B80u));
asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a0) : "v"(abs0), "v"(0x3F003F00u));
asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a1) : "v"(abs1), "v"(0x3F003F00u));
a0 = a0 * 2u + 0x00800080u;
a1 = a1 * 2u + 0x00800080u;
uint32_t b0 = a0 | (in0 & 0x80008000u);

View File

@@ -45,8 +45,10 @@ void register_core(nb::module_& m) {
.value("float16", DataType::FLOAT16)
.value("float32", DataType::FLOAT32)
.value("bfloat16", DataType::BFLOAT16)
.value("float8_e4m3", DataType::FLOAT8_E4M3)
.value("float8_e4m3fn", DataType::FLOAT8_E4M3FN)
.value("float8_e4m3fnuz", DataType::FLOAT8_E4M3FNUZ)
.value("float8_e5m2", DataType::FLOAT8_E5M2)
.value("float8_e5m2fnuz", DataType::FLOAT8_E5M2FNUZ)
.value("uint8", DataType::UINT8)
.value("float8_e4m3b15", DataType::FLOAT8_E4M3B15);
@@ -328,4 +330,4 @@ NB_MODULE(_mscclpp, m) {
// ext
register_algorithm_collection_builder(m);
}
}

View File

@@ -192,12 +192,14 @@ def torch_dtype_to_mscclpp_dtype(dtype: "torch.dtype") -> DataType:
return DataType.int32
elif dtype == torch.bfloat16:
return DataType.bfloat16
# Hardware supports either OCP format or FNUZ format for float8.
# Mapping both to the same MSCCLPP data type.
elif dtype == torch.float8_e5m2 or dtype == torch.float8_e5m2fnuz:
elif dtype == torch.float8_e5m2:
return DataType.float8_e5m2
elif dtype == torch.float8_e4m3fn or dtype == torch.float8_e4m3fnuz:
return DataType.float8_e4m3
elif dtype == torch.float8_e5m2fnuz:
return DataType.float8_e5m2fnuz
elif dtype == torch.float8_e4m3fn:
return DataType.float8_e4m3fn
elif dtype == torch.float8_e4m3fnuz:
return DataType.float8_e4m3fnuz
elif dtype == torch.uint8:
return DataType.uint8
else:

View File

@@ -21,6 +21,13 @@ from .mscclpp_mpi import MpiGroup, parametrize_mpi_groups, mpi_group
# FP8 E4M3 (hardware) requires SM >= 89 (Ada / Hopper) on NVIDIA GPUs.
# On AMD/ROCm (e.g. MI300X), FP8 is supported natively — no skip needed.
_is_hip = hasattr(cp.cuda.runtime, "is_hip") and cp.cuda.runtime.is_hip
_gcn_arch_name = ""
if _is_hip:
_gcn_arch_name = cp.cuda.runtime.getDeviceProperties(0).get("gcnArchName", b"")
if isinstance(_gcn_arch_name, bytes):
_gcn_arch_name = _gcn_arch_name.decode()
_gcn_arch_name = _gcn_arch_name.split(":", maxsplit=1)[0]
_is_cdna4 = _gcn_arch_name.startswith("gfx95")
_skip_fp8 = not _is_hip and int(cp.cuda.Device().compute_capability) < 89
pytestmark = pytest.mark.skipif(_skip_fp8, reason="FP8 accum tests require SM >= 89 on CUDA")
@@ -90,7 +97,78 @@ def float_to_e4m3fn(f32_array, chunk_size=65536):
# ---------------------------------------------------------------------------
# FP8 E4M3B15 helpers (bias=15, max=0.9375, NaN = exp==15 or bits==0x80)
# FP8 E4M3FNUZ helpers (AMD/ROCm; bias=8, max=240, NaN = bits==0x80, no -0)
# ---------------------------------------------------------------------------
def e4m3fnuz_to_float(uint8_array):
    """Decode a cupy uint8 array of E4M3FNUZ bit patterns to float32.

    E4M3FNUZ layout (MSB first): [sign:1][exponent:4][mantissa:3] with an
    exponent bias of 8 (see the ``exp - 8`` below). Under the fnuz
    convention there is no negative zero: the 0x80 bit pattern is the
    format's single NaN encoding.
    """
    bits = uint8_array.astype(cp.int32)
    sign = (bits >> 7) & 1  # sign bit
    exp = (bits >> 3) & 0xF  # 4-bit biased exponent
    mant = bits & 0x7  # 3-bit mantissa
    # Normal: (-1)^s * 2^(exp-8) * (1 + mant/8)
    normal_val = cp.ldexp(cp.float32(1.0) + mant.astype(cp.float32) / cp.float32(8.0), (exp - 8).astype(cp.int32))
    # Subnormal (exp==0): (-1)^s * 2^(-7) * (mant/8)
    subnormal_val = cp.ldexp(mant.astype(cp.float32) / cp.float32(8.0), cp.int32(-7))
    result = cp.where(exp == 0, subnormal_val, normal_val)
    result = cp.where(sign == 1, -result, result)
    # Zero is only 0x00; the 0x80 encoding is reserved for NaN under fnuz.
    result = cp.where(uint8_array.astype(cp.int32) == 0, cp.float32(0.0), result)
    # NaN override applied last so 0x80 wins over the sign/zero fixups above.
    nan_mask = uint8_array.astype(cp.int32) == 0x80
    result = cp.where(nan_mask, cp.float32(float("nan")), result)
    return result
def float_to_e4m3fnuz(f32_array, chunk_size=65536):
    """Encode a cupy float32 array to uint8 E4M3FNUZ bit patterns.

    Same lookup-table approach as float_to_e4m3fn but using the fnuz table:
    each |value| is matched (by argmin of absolute difference) against the
    128 non-negative fnuz magnitudes, then the sign bit is OR'd back in.

    chunk_size bounds the (chunk, 128) distance matrix built per iteration,
    keeping peak memory use independent of the total input size.
    """
    # Lookup table: decoded values for all non-negative bit patterns 0x00..0x7F.
    all_bytes = cp.arange(128, dtype=cp.uint8)
    all_floats = e4m3fnuz_to_float(all_bytes)
    # Defensive: map any NaN table entries to +inf so argmin never selects
    # them (the 0x00..0x7F range itself contains no NaN under fnuz).
    all_floats = cp.where(cp.isnan(all_floats), cp.float32(float("inf")), all_floats)
    clamped = f32_array.astype(cp.float32)
    clamped = cp.clip(clamped, -240.0, 240.0)  # saturate to the format's max magnitude
    signs = (clamped < 0).astype(cp.uint8)
    absval = cp.abs(clamped)
    result = cp.zeros(absval.shape, dtype=cp.uint8)
    n = absval.size
    absval_flat = absval.ravel()
    result_flat = result.ravel()
    # Chunked nearest-neighbour search against the 128-entry table.
    for start in range(0, n, chunk_size):
        end = min(start + chunk_size, n)
        chunk = absval_flat[start:end]
        diffs = cp.abs(chunk[:, None] - all_floats[None, :])
        result_flat[start:end] = cp.argmin(diffs, axis=1).astype(cp.uint8)
    result = result_flat.reshape(absval.shape)
    result = result | (signs << 7)
    # 0x80 is NaN under fnuz (no negative zero). Collapse any encoding that
    # landed on 0x80 (small negatives quantised to zero magnitude) to 0x00.
    result = cp.where(result == 0x80, cp.uint8(0), result)
    return result
# Platform-aware E4M3 native helpers: ROCm CDNA4 and CUDA use OCP fn; older ROCm uses fnuz.
if _is_hip and not _is_cdna4:
    # Pre-CDNA4 AMD GPUs: native FP8 is the fnuz variant.
    e4m3_native_to_float = e4m3fnuz_to_float
    float_to_e4m3_native = float_to_e4m3fnuz
    fp8_native_dtype = DataType.float8_e4m3fnuz
else:
    # NVIDIA, and ROCm CDNA4 (gfx95*): native FP8 is the OCP "fn" variant.
    e4m3_native_to_float = e4m3fn_to_float
    float_to_e4m3_native = float_to_e4m3fn
    fp8_native_dtype = DataType.float8_e4m3fn
# ---------------------------------------------------------------------------
# FP8 E4M3B15 helpers (bias=15, encode saturates to ±1.75, no NaN)
# Matches Triton's fp8e4b15: all 256 bit patterns are finite.
# ---------------------------------------------------------------------------
@@ -108,11 +186,6 @@ def e4m3b15_to_float(uint8_array):
result = cp.where(exp == 0, subnormal_val, normal_val)
result = cp.where(sign == 1, -result, result)
# Zero
result = cp.where((exp == 0) & (mant == 0), cp.float32(0.0), result)
# NaN: exp==15 or negative zero (0x80)
nan_mask = (exp == 15) | (uint8_array.astype(cp.int32) == 0x80)
result = cp.where(nan_mask, cp.float32(float("nan")), result)
return result
@@ -120,18 +193,17 @@ def float_to_e4m3b15(f32_array, chunk_size=65536):
"""Encode a cupy float32 array to uint8 E4M3B15 bit patterns.
Same lookup-table approach as float_to_e4m3fn.
Saturates to ±1.75 (0x7e/0xfe), matching Triton's fp8e4b15.
"""
# Build lookup table of all 128 positive E4M3B15 values (0x00..0x7F)
all_bytes = cp.arange(128, dtype=cp.uint8)
all_floats = e4m3b15_to_float(all_bytes) # (128,) float32
# Mark NaN entries as inf so they're never selected as nearest
all_floats = cp.where(cp.isnan(all_floats), cp.float32(float("inf")), all_floats)
# Clamp input and extract sign
clamped = f32_array.astype(cp.float32)
clamped = cp.clip(clamped, -0.9375, 0.9375)
signs = (clamped < 0).astype(cp.uint8)
absval = cp.abs(clamped)
# Clamp input and extract sign.
values = f32_array.astype(cp.float32)
signs = cp.signbit(values).astype(cp.uint8)
absval = cp.abs(values)
absval = cp.clip(absval, cp.float32(0.0), cp.float32(1.75))
result = cp.zeros(absval.shape, dtype=cp.uint8)
n = absval.size
@@ -148,8 +220,6 @@ def float_to_e4m3b15(f32_array, chunk_size=65536):
# Combine with sign bit
result = result_flat.reshape(absval.shape)
result = result | (signs << 7)
# Handle exact zero
result = cp.where(absval == 0, cp.uint8(0), result)
return result
@@ -226,12 +296,6 @@ def test_fp8_e4m3_accum(mpi_group: MpiGroup, algo_name: str, size: int):
buf = GpuBuffer(size, dtype=cp.uint8)
accum_configs = [
("fp8_native", DataType.float8_e4m3),
("float16", DataType.float16),
("float32", DataType.float32),
]
# rsag_zero_copy and fullmesh need explicit block/thread counts
if "rsag" in algo_name:
nb = max(1, min(32, size // (world_size * 32)))
@@ -243,13 +307,19 @@ def test_fp8_e4m3_accum(mpi_group: MpiGroup, algo_name: str, size: int):
nb = 0
nt = 0
accum_configs = [
("fp8_native", fp8_native_dtype),
("float16", DataType.float16),
("float32", DataType.float32),
]
errors = {}
for accum_label, accum_dtype in accum_configs:
# Generate deterministic per-rank data (use numpy to avoid hipRAND issues on ROCm)
rng = np.random.RandomState(42 + rank)
src_f32 = cp.asarray(rng.randn(size).astype(np.float32))
src_f32 = cp.clip(src_f32, -240.0, 240.0)
src_fp8 = float_to_e4m3fn(src_f32)
src_fp8 = float_to_e4m3_native(src_f32)
# Copy into symmetric buffer
buf[:] = src_fp8
@@ -260,12 +330,12 @@ def test_fp8_e4m3_accum(mpi_group: MpiGroup, algo_name: str, size: int):
algo,
comm_group,
buf,
dtype=DataType.float8_e4m3,
dtype=fp8_native_dtype,
accum_dtype=accum_dtype,
nblocks=nb,
nthreads_per_block=nt,
)
result_f32 = e4m3fn_to_float(result)
result_f32 = e4m3_native_to_float(result)
# Compute float32 reference: sum all ranks' quantized FP8 inputs in float32
ref_f32 = cp.zeros(size, dtype=cp.float32)
@@ -273,12 +343,13 @@ def test_fp8_e4m3_accum(mpi_group: MpiGroup, algo_name: str, size: int):
rng_r = np.random.RandomState(42 + r)
rank_data = cp.asarray(rng_r.randn(size).astype(np.float32))
rank_data = cp.clip(rank_data, -240.0, 240.0)
rank_data_fp8 = float_to_e4m3fn(rank_data)
ref_f32 += e4m3fn_to_float(rank_data_fp8)
rank_data_fp8 = float_to_e4m3_native(rank_data)
ref_f32 += e4m3_native_to_float(rank_data_fp8)
# Compute errors
abs_err = cp.abs(result_f32 - ref_f32)
mean_abs_err = float(cp.mean(abs_err))
# Compute errors (only on valid, non-NaN entries)
valid = ~cp.isnan(result_f32) & ~cp.isnan(ref_f32)
abs_err = cp.abs(result_f32[valid] - ref_f32[valid])
mean_abs_err = float(cp.mean(abs_err)) if abs_err.size > 0 else 0.0
errors[accum_label] = mean_abs_err
# Reset between runs
@@ -341,13 +412,10 @@ def test_fp8_e4m3b15_accum(mpi_group: MpiGroup, algo_name: str, size: int):
errors = {}
for accum_label, accum_dtype in accum_configs:
# Generate deterministic per-rank random uint8 values in valid e4m3b15 range
# Generate deterministic per-rank random uint8 values covering the full e4m3b15 range.
# All 256 bit patterns are valid (no NaN in this format).
rng = np.random.RandomState(42 + rank)
raw = cp.asarray(rng.randint(0, 0x78, (size,)).astype(np.uint8))
signs = cp.asarray(rng.randint(0, 2, (size,)).astype(np.uint8)) << 7
src_uint8 = raw | signs
# Fix negative zero -> positive zero
src_uint8 = cp.where(src_uint8 == 0x80, cp.uint8(0), src_uint8)
src_uint8 = cp.asarray(rng.randint(0, 256, (size,)).astype(np.uint8))
# Copy into symmetric buffer
buf[:] = src_uint8
@@ -371,19 +439,15 @@ def test_fp8_e4m3b15_accum(mpi_group: MpiGroup, algo_name: str, size: int):
ref_f32 = cp.zeros(size, dtype=cp.float32)
for r in range(world_size):
rng_r = np.random.RandomState(42 + r)
raw_r = cp.asarray(rng_r.randint(0, 0x78, (size,)).astype(np.uint8))
signs_r = cp.asarray(rng_r.randint(0, 2, (size,)).astype(np.uint8)) << 7
bits_r = raw_r | signs_r
bits_r = cp.where(bits_r == 0x80, cp.uint8(0), bits_r)
bits_r = cp.asarray(rng_r.randint(0, 256, (size,)).astype(np.uint8))
ref_f32 += e4m3b15_to_float(bits_r)
# Clamp reference to e4m3b15 representable range
ref_f32 = cp.clip(ref_f32, -0.9375, 0.9375)
ref_f32 = cp.clip(ref_f32, -1.75, 1.75)
# Compute errors (only on valid entries)
valid = ~cp.isnan(result_f32) & ~cp.isnan(ref_f32)
abs_err = cp.abs(result_f32[valid] - ref_f32[valid])
mean_abs_err = float(cp.mean(abs_err)) if abs_err.size > 0 else 0.0
# Compute errors
abs_err = cp.abs(result_f32 - ref_f32)
mean_abs_err = float(cp.mean(abs_err))
errors[accum_label] = mean_abs_err
algo.reset()

View File

@@ -3,6 +3,7 @@
#include <filesystem>
#include <mscclpp/algorithm.hpp>
#include <mscclpp/errors.hpp>
#include <mscclpp/gpu_utils.hpp>
#include "logger.hpp"
@@ -182,13 +183,41 @@ CommResult DslAlgorithm::execute(std::shared_ptr<Communicator> comm, const void*
stream);
break;
#if defined(__FP8_TYPES_EXIST__)
case DataType::FLOAT8_E4M3:
executor->execute(rank, (__fp8_e4m3*)input, (__fp8_e4m3*)output, inputSize, outputSize, DataType::FLOAT8_E4M3,
case DataType::FLOAT8_E4M3FN:
#if defined(__FP8_E4M3_IS_FNUZ__)
THROW(EXEC, Error, ErrorCode::InvalidUsage,
"FLOAT8_E4M3FN is not natively supported on this platform; use FLOAT8_E4M3FNUZ");
#else
executor->execute(rank, (__fp8_e4m3*)input, (__fp8_e4m3*)output, inputSize, outputSize, DataType::FLOAT8_E4M3FN,
plan_, stream);
#endif
break;
case DataType::FLOAT8_E4M3FNUZ:
#if defined(__FP8_E4M3_IS_FNUZ__)
executor->execute(rank, (__fp8_e4m3*)input, (__fp8_e4m3*)output, inputSize, outputSize, DataType::FLOAT8_E4M3FNUZ,
plan_, stream);
#else
THROW(EXEC, Error, ErrorCode::InvalidUsage,
"FLOAT8_E4M3FNUZ is not natively supported on this platform; use FLOAT8_E4M3FN");
#endif
break;
case DataType::FLOAT8_E5M2:
#if defined(__FP8_E5M2_IS_FNUZ__)
THROW(EXEC, Error, ErrorCode::InvalidUsage,
"FLOAT8_E5M2 is not natively supported on this platform; use FLOAT8_E5M2FNUZ");
#else
executor->execute(rank, (__fp8_e5m2*)input, (__fp8_e5m2*)output, inputSize, outputSize, DataType::FLOAT8_E5M2,
plan_, stream);
#endif
break;
case DataType::FLOAT8_E5M2FNUZ:
#if defined(__FP8_E5M2_IS_FNUZ__)
executor->execute(rank, (__fp8_e5m2*)input, (__fp8_e5m2*)output, inputSize, outputSize, DataType::FLOAT8_E5M2FNUZ,
plan_, stream);
#else
THROW(EXEC, Error, ErrorCode::InvalidUsage,
"FLOAT8_E5M2FNUZ is not natively supported on this platform; use FLOAT8_E5M2");
#endif
break;
#endif
case DataType::FLOAT8_E4M3B15:
@@ -230,4 +259,4 @@ std::pair<std::shared_ptr<void>, size_t> getFlagBuffer() {
return {ptr, gDefaultFlagCount * sizeof(uint32_t)};
}
} // namespace mscclpp
} // namespace mscclpp

View File

@@ -78,8 +78,10 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo
);
#endif
break;
case DataType::FLOAT8_E4M3:
case DataType::FLOAT8_E4M3FN:
case DataType::FLOAT8_E4M3FNUZ:
case DataType::FLOAT8_E5M2:
case DataType::FLOAT8_E5M2FNUZ:
// FP8 is not supported in CUDA execution kernel.
break;
case DataType::FLOAT8_E4M3B15:

View File

@@ -7,7 +7,6 @@
#include <mscclpp/switch_channel.hpp>
#include <mscclpp/utils.hpp>
#include "debug.h"
#include "execution_kernel.hpp"
#include "execution_plan.hpp"
@@ -509,8 +508,7 @@ Executor::Executor(std::shared_ptr<Communicator> comm, std::shared_ptr<char> def
void Executor::execute(int rank, void* sendbuff, void* recvbuff, size_t sendBuffSize,
[[maybe_unused]] size_t recvBuffSize, DataType dataType, const ExecutionPlan& plan,
cudaStream_t stream, PacketType packetType) {
INFO(MSCCLPP_EXECUTOR, "Starting execution with plan: %s, collective: %s", plan.name().c_str(),
plan.collective().c_str());
INFO(LogSubsys::EXEC, "Starting execution with plan: ", plan.name(), ", collective: ", plan.collective());
size_t sendMemRange, recvMemRange;
CUdeviceptr sendBasePtr, recvBasePtr;
MSCCLPP_CUTHROW(cuMemGetAddressRange(&sendBasePtr, &sendMemRange, (CUdeviceptr)sendbuff));

View File

@@ -17,6 +17,7 @@
#include <mscclpp/switch_channel_device.hpp>
#include "execution_common.hpp"
#include "logger.hpp"
#include "reduce_kernel.hpp"
namespace mscclpp {
@@ -876,7 +877,19 @@ class ExecutionKernel {
#endif
break;
#if defined(__FP8_TYPES_EXIST__)
case DataType::FLOAT8_E4M3:
case DataType::FLOAT8_E4M3FN:
case DataType::FLOAT8_E4M3FNUZ:
#if defined(__FP8_E4M3_IS_FNUZ__)
if (dataType == DataType::FLOAT8_E4M3FN) {
THROW(LogSubsys::EXEC, Error, ErrorCode::InvalidUsage,
"FLOAT8_E4M3FN is not natively supported on this platform; use FLOAT8_E4M3FNUZ");
}
#else
if (dataType == DataType::FLOAT8_E4M3FNUZ) {
THROW(LogSubsys::EXEC, Error, ErrorCode::InvalidUsage,
"FLOAT8_E4M3FNUZ is not natively supported on this platform; use FLOAT8_E4M3FN");
}
#endif
executionKernel<__fp8_e4m3, PacketType, ReuseScratch><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
rank, (__fp8_e4m3*)src, (__fp8_e4m3*)dst, (__fp8_e4m3*)scratch, scratchOffset, scratchChunkSize, plan,
semaphores, localMemoryIdBegin, flag
@@ -888,6 +901,18 @@ class ExecutionKernel {
#endif
break;
case DataType::FLOAT8_E5M2:
case DataType::FLOAT8_E5M2FNUZ:
#if defined(__FP8_E5M2_IS_FNUZ__)
if (dataType == DataType::FLOAT8_E5M2) {
THROW(LogSubsys::EXEC, Error, ErrorCode::InvalidUsage,
"FLOAT8_E5M2 is not natively supported on this platform; use FLOAT8_E5M2FNUZ");
}
#else
if (dataType == DataType::FLOAT8_E5M2FNUZ) {
THROW(LogSubsys::EXEC, Error, ErrorCode::InvalidUsage,
"FLOAT8_E5M2FNUZ is not natively supported on this platform; use FLOAT8_E5M2");
}
#endif
executionKernel<__fp8_e5m2, PacketType, ReuseScratch><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
rank, (__fp8_e5m2*)src, (__fp8_e5m2*)dst, (__fp8_e5m2*)scratch, scratchOffset, scratchChunkSize, plan,
semaphores, localMemoryIdBegin, flag

View File

@@ -200,7 +200,8 @@ inline std::pair<int, int> getDefaultBlockNumAndThreadNum(size_t inputSize, int
{
bool isFp8 = dtype == DataType::FLOAT8_E4M3B15;
#if defined(__FP8_TYPES_EXIST__)
isFp8 = isFp8 || dtype == DataType::FLOAT8_E4M3 || dtype == DataType::FLOAT8_E5M2;
isFp8 = isFp8 || dtype == DataType::FLOAT8_E4M3FN || dtype == DataType::FLOAT8_E4M3FNUZ ||
dtype == DataType::FLOAT8_E5M2 || dtype == DataType::FLOAT8_E5M2FNUZ;
#endif
if (isFp8) {
if (inputSize < (64 << 10)) {
@@ -310,4 +311,4 @@ std::shared_ptr<Algorithm> AllreducePacket::build() {
}
} // namespace collective
} // namespace mscclpp
} // namespace mscclpp

View File

@@ -101,10 +101,20 @@ AllreduceFunc dispatchByDtype(mscclpp::DataType dtype, mscclpp::DataType accumDt
return Adapter<Op, __bfloat16, __bfloat16>::call;
#endif
#if defined(__FP8_TYPES_EXIST__)
} else if (dtype == mscclpp::DataType::FLOAT8_E4M3) {
#if defined(__FP8_E4M3_IS_FNUZ__)
} else if (dtype == mscclpp::DataType::FLOAT8_E4M3FNUZ) {
return dispatchFp8Accum<Op, __fp8_e4m3, Adapter>(accumDtype, dtype);
#else
} else if (dtype == mscclpp::DataType::FLOAT8_E4M3FN) {
return dispatchFp8Accum<Op, __fp8_e4m3, Adapter>(accumDtype, dtype);
#endif
#if defined(__FP8_E5M2_IS_FNUZ__)
} else if (dtype == mscclpp::DataType::FLOAT8_E5M2FNUZ) {
return dispatchFp8Accum<Op, __fp8_e5m2, Adapter>(accumDtype, dtype);
#else
} else if (dtype == mscclpp::DataType::FLOAT8_E5M2) {
return dispatchFp8Accum<Op, __fp8_e5m2, Adapter>(accumDtype, dtype);
#endif
#endif
} else if (dtype == mscclpp::DataType::FLOAT8_E4M3B15) {
return dispatchFp8Accum<Op, __fp8_e4m3b15, Adapter>(accumDtype, dtype);
@@ -125,4 +135,4 @@ AllreduceFunc dispatch(ReduceOp op, mscclpp::DataType dtype, mscclpp::DataType a
} // namespace collective
} // namespace mscclpp
#endif // MSCCLPP_ALLREDUCE_COMMON_HPP_
#endif // MSCCLPP_ALLREDUCE_COMMON_HPP_

View File

@@ -20,7 +20,8 @@ static bool isNvlsSupportedForDataType(const AlgorithmSelectorConfig& config, Da
return false;
}
const bool isFp8 = dtype == DataType::FLOAT8_E4M3 || dtype == DataType::FLOAT8_E5M2;
const bool isFp8 = dtype == DataType::FLOAT8_E4M3FN || dtype == DataType::FLOAT8_E4M3FNUZ ||
dtype == DataType::FLOAT8_E5M2 || dtype == DataType::FLOAT8_E5M2FNUZ;
if (!isFp8) {
return nvlsSupported;

View File

@@ -28,12 +28,21 @@ inline mscclpp::DataType ncclDataTypeToMscclpp(ncclDataType_t dtype) {
return mscclpp::DataType::BFLOAT16;
#ifdef __FP8_TYPES_EXIST__
case ncclFloat8e4m3:
return mscclpp::DataType::FLOAT8_E4M3;
#if defined(__FP8_E4M3_IS_FNUZ__)
return mscclpp::DataType::FLOAT8_E4M3FNUZ;
#else
return mscclpp::DataType::FLOAT8_E4M3FN;
#endif
case ncclFloat8e5m2:
#if defined(__FP8_E5M2_IS_FNUZ__)
return mscclpp::DataType::FLOAT8_E5M2FNUZ;
#else
return mscclpp::DataType::FLOAT8_E5M2;
#endif
#endif
default:
throw mscclpp::Error("Unsupported ncclDataType_t: " + std::to_string(dtype), mscclpp::ErrorCode::InvalidUsage);
THROW(mscclpp::LogSubsys::NCCL, mscclpp::Error, mscclpp::ErrorCode::InvalidUsage,
"Unsupported ncclDataType_t: " + std::to_string(dtype));
}
}
@@ -41,8 +50,10 @@ inline mscclpp::DataType ncclDataTypeToMscclpp(ncclDataType_t dtype) {
inline size_t getDataTypeSize(mscclpp::DataType dtype) {
switch (dtype) {
case mscclpp::DataType::UINT8:
case mscclpp::DataType::FLOAT8_E4M3:
case mscclpp::DataType::FLOAT8_E4M3FN:
case mscclpp::DataType::FLOAT8_E4M3FNUZ:
case mscclpp::DataType::FLOAT8_E5M2:
case mscclpp::DataType::FLOAT8_E5M2FNUZ:
case mscclpp::DataType::FLOAT8_E4M3B15:
return 1;
case mscclpp::DataType::FLOAT16:
@@ -72,10 +83,32 @@ static inline ncclDataType_t mscclppToNcclDataType(mscclpp::DataType dtype) {
case mscclpp::DataType::BFLOAT16:
return ncclBfloat16;
#ifdef __FP8_TYPES_EXIST__
case mscclpp::DataType::FLOAT8_E4M3:
#if defined(__FP8_E4M3_IS_FNUZ__)
case mscclpp::DataType::FLOAT8_E4M3FNUZ:
return ncclFloat8e4m3;
case mscclpp::DataType::FLOAT8_E4M3FN:
THROW(mscclpp::LogSubsys::NCCL, mscclpp::Error, mscclpp::ErrorCode::InvalidUsage,
"FLOAT8_E4M3FN is not natively supported on this platform; use FLOAT8_E4M3FNUZ for NCCL collectives");
#else
case mscclpp::DataType::FLOAT8_E4M3FN:
return ncclFloat8e4m3;
case mscclpp::DataType::FLOAT8_E4M3FNUZ:
THROW(mscclpp::LogSubsys::NCCL, mscclpp::Error, mscclpp::ErrorCode::InvalidUsage,
"FLOAT8_E4M3FNUZ is not natively supported on this platform; use FLOAT8_E4M3FN for NCCL collectives");
#endif
#if defined(__FP8_E5M2_IS_FNUZ__)
case mscclpp::DataType::FLOAT8_E5M2FNUZ:
return ncclFloat8e5m2;
case mscclpp::DataType::FLOAT8_E5M2:
THROW(mscclpp::LogSubsys::NCCL, mscclpp::Error, mscclpp::ErrorCode::InvalidUsage,
"FLOAT8_E5M2 is not natively supported on this platform; use FLOAT8_E5M2FNUZ for NCCL collectives");
#else
case mscclpp::DataType::FLOAT8_E5M2:
return ncclFloat8e5m2;
case mscclpp::DataType::FLOAT8_E5M2FNUZ:
THROW(mscclpp::LogSubsys::NCCL, mscclpp::Error, mscclpp::ErrorCode::InvalidUsage,
"FLOAT8_E5M2FNUZ is not natively supported on this platform; use FLOAT8_E5M2 for NCCL collectives");
#endif
#endif
case mscclpp::DataType::FLOAT8_E4M3B15:
// float8_e4m3b15 has no NCCL equivalent; NCCL cannot reduce this type correctly.
@@ -98,4 +131,4 @@ inline mscclpp::ReduceOp ncclRedOpToMscclpp(ncclRedOp_t op) {
}
}
#endif // MSCCLPP_DATATYPE_CONVERSION_HPP_
#endif // MSCCLPP_DATATYPE_CONVERSION_HPP_