mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-11 08:50:21 +00:00
Merge remote-tracking branch 'origin/main' into binyli/mnnvl
This commit is contained in:
@@ -21,7 +21,10 @@ using __bfloat162 = __hip_bfloat162;
|
||||
#if defined(HIP_VERSION_MAJOR) && (HIP_VERSION_MAJOR >= 6)
|
||||
#include <hip/hip_fp8.h>
|
||||
|
||||
// Create aliases matching CUDA naming convention for cross-platform compatibility
|
||||
// Create aliases matching CUDA naming convention for cross-platform compatibility.
|
||||
// Define __FP8_E4M3_IS_FNUZ__ / __FP8_E5M2_IS_FNUZ__ when the platform-native FP8 is the
|
||||
// "fnuz" variant (no infinities, NaN-only at 0x80, bias differs from OCP). Dispatch layers
|
||||
// use these macros to throw on unsupported variants requested via DataType.
|
||||
#if (HIP_VERSION_MAJOR == 6) || (HIP_VERSION_MAJOR > 6 && HIP_FP8_TYPE_FNUZ && !HIP_FP8_TYPE_OCP)
|
||||
using __fp8_e4m3 = __hip_fp8_e4m3_fnuz;
|
||||
using __fp8_e5m2 = __hip_fp8_e5m2_fnuz;
|
||||
@@ -29,6 +32,8 @@ using __fp8x2_e4m3 = __hip_fp8x2_e4m3_fnuz;
|
||||
using __fp8x2_e5m2 = __hip_fp8x2_e5m2_fnuz;
|
||||
using __fp8x4_e4m3 = __hip_fp8x4_e4m3_fnuz;
|
||||
using __fp8x4_e5m2 = __hip_fp8x4_e5m2_fnuz;
|
||||
#define __FP8_E4M3_IS_FNUZ__
|
||||
#define __FP8_E5M2_IS_FNUZ__
|
||||
#else
|
||||
using __fp8_e4m3 = __hip_fp8_e4m3;
|
||||
using __fp8_e5m2 = __hip_fp8_e5m2;
|
||||
@@ -66,8 +71,8 @@ using __bfloat162 = __nv_bfloat162;
|
||||
|
||||
/// Software float8 with 4 exponent bits, 3 mantissa bits, exponent bias = 15.
|
||||
/// Format (MSB first): [sign:1][exponent:4][mantissa:3]
|
||||
/// No infinities; exp=15 is NaN. Negative zero is NaN (fnuz convention).
|
||||
/// Max finite value: 0.9375, min normal: ~6.1e-5, min subnormal: ~7.6e-6.
|
||||
/// No infinities, no NaN. Encode saturates to ±1.75 (0x7e/0xfe).
|
||||
/// Adapted from the Triton compiler's fp8e4b15 format.
|
||||
struct alignas(1) __fp8_e4m3b15 {
|
||||
uint8_t __x;
|
||||
|
||||
@@ -97,35 +102,15 @@ struct alignas(1) __fp8_e4m3b15 {
|
||||
/// Algorithm: reinterpret fp8 bits into an fp16 bit pattern with exponent shifted by -8,
|
||||
/// then convert fp16 → float32.
|
||||
static MSCCLPP_HOST_DEVICE_INLINE float toFloat(uint8_t bits) {
|
||||
// Handle special values: negative zero (0x80) → NaN, exponent=15 → NaN.
|
||||
uint32_t exp = (bits >> 3) & 0xFu;
|
||||
if (bits == 0x80 || exp == 15) {
|
||||
union {
|
||||
uint32_t u;
|
||||
float f;
|
||||
} nan_val = {0x7FC00000u};
|
||||
return nan_val.f;
|
||||
}
|
||||
if (bits == 0) return 0.0f;
|
||||
|
||||
// Triton-style bit manipulation: fp8 → fp16 → fp32.
|
||||
// fp8 layout: [S:1][E:4][M:3] (bias=15)
|
||||
// fp16 layout: [S:1][E:5][M:10] (bias=15)
|
||||
//
|
||||
// Place fp8 in upper byte of fp16, then right-shift exponent+mantissa by 1
|
||||
// to convert E4 → E5 (both share bias=15). Sign bit stays at bit 15.
|
||||
// Branch-free decode: fp8 → fp16 → fp32, no special-case handling.
|
||||
// Encode saturates to ±1.75, so 0x7f/0xff are never produced.
|
||||
// Refer:
|
||||
// https://github.com/triton-lang/triton/blob/cf34004b8a67d290a962da166f5aa2fc66751326/python/triton/language/extra/cuda/utils.py#L34
|
||||
uint16_t h = (uint16_t)bits << 8; // place fp8 in upper byte of fp16
|
||||
uint16_t sign16 = h & 0x8000u; // extract sign at fp16 position
|
||||
uint16_t nosign = h & 0x7F00u; // exponent + mantissa (no sign)
|
||||
uint16_t fp16_bits = sign16 | (nosign >> 1); // shift exponent right by 1
|
||||
uint16_t fp16_bits = sign16 | (nosign >> 1); // shift exponent right by 1 (E4→E5)
|
||||
|
||||
// For subnormals: when fp8 exponent=0, the above gives fp16 exponent=0
|
||||
// and fp16 mantissa = (fp8_mantissa << 7), which correctly represents
|
||||
// the subnormal fp16 value since both share bias=15.
|
||||
|
||||
// Convert fp16 bits to float via __half (works on host and device, CUDA and HIP).
|
||||
union {
|
||||
uint16_t u;
|
||||
__half h;
|
||||
@@ -139,14 +124,6 @@ struct alignas(1) __fp8_e4m3b15 {
|
||||
/// The key insight is to convert to fp16 first (which shares bias=15 with e4m3b15),
|
||||
/// then pack the fp16 bits back into 8 bits by shifting the exponent left by 1.
|
||||
static MSCCLPP_HOST_DEVICE_INLINE uint8_t fromFloat(float val) {
|
||||
union {
|
||||
float f;
|
||||
uint32_t u;
|
||||
} in = {val};
|
||||
|
||||
// NaN → 0x80 (negative-zero bit pattern = NaN in fnuz).
|
||||
if ((in.u & 0x7F800000u) == 0x7F800000u && (in.u & 0x007FFFFFu) != 0) return 0x80u;
|
||||
|
||||
// Convert float32 → fp16 bits via __half (works on host and device, CUDA and HIP).
|
||||
__half h_val = __float2half_rn(val);
|
||||
union {
|
||||
@@ -155,32 +132,19 @@ struct alignas(1) __fp8_e4m3b15 {
|
||||
} cvt = {h_val};
|
||||
uint16_t fp16_bits = cvt.u;
|
||||
|
||||
// Clamp absolute value to max finite e4m3b15: 0.9375 → fp16 = 0x3B80.
|
||||
// Clamp abs to max encodable value: 1.75 → fp16 = 0x3F00.
|
||||
// Matches Triton: encode saturates, 0x7f/0xff are never produced.
|
||||
uint16_t abs_fp16 = fp16_bits & 0x7FFFu;
|
||||
if (abs_fp16 > 0x3B80u) abs_fp16 = 0x3B80u;
|
||||
if (abs_fp16 > 0x3F00u) abs_fp16 = 0x3F00u;
|
||||
|
||||
// Reconstruct with sign.
|
||||
uint16_t sign16 = fp16_bits & 0x8000u;
|
||||
|
||||
// Triton-style: fp16 → fp8.
|
||||
// fp16 layout: [S:1][E:5][M:10] (bias=15)
|
||||
// fp8 layout: [S:1][E:4][M:3] (bias=15)
|
||||
//
|
||||
// mad.lo.u32 a0, a0, 2, 0x00800080 → (abs_fp16 * 2 + 0x0080)
|
||||
// This shifts left by 1 (undoing the right-shift in decode) and adds rounding bias.
|
||||
// Then: lop3.b32 b0, $1, 0x80008000, a0, 0xea → (sign & 0x8000) | a0
|
||||
// Finally: prmt for byte extraction.
|
||||
//
|
||||
// Simplified for scalar: shift abs_fp16 left by 1, add rounding bias, take upper byte.
|
||||
// fp16 → fp8: shift abs left by 1 (undo decode's right-shift), add rounding bias, take upper byte.
|
||||
uint16_t adjusted = (uint16_t)(abs_fp16 * 2u + 0x0080u);
|
||||
// The upper byte now contains [E:4][M:3][round_bit].
|
||||
// Combine with sign and extract.
|
||||
uint16_t with_sign = sign16 | adjusted;
|
||||
uint8_t result = (uint8_t)(with_sign >> 8);
|
||||
|
||||
// Zero → 0x00 (ensure positive zero, not negative zero which is NaN).
|
||||
if ((result & 0x7Fu) == 0) result = 0x00u;
|
||||
|
||||
return result;
|
||||
}
|
||||
};
|
||||
@@ -199,16 +163,18 @@ namespace mscclpp {
|
||||
|
||||
/// Data types supported by mscclpp operations.
|
||||
enum class DataType {
|
||||
INT32, // 32-bit signed integer.
|
||||
UINT32, // 32-bit unsigned integer.
|
||||
FLOAT16, // IEEE 754 half precision.
|
||||
FLOAT32, // IEEE 754 single precision.
|
||||
BFLOAT16, // bfloat16 precision.
|
||||
FLOAT8_E4M3, // float8 with E4M3 layout.
|
||||
FLOAT8_E5M2, // float8 with E5M2 layout.
|
||||
UINT8, // 8-bit unsigned integer.
|
||||
FLOAT8_E4M3B15, // float8 with E4M3 layout, bias=15 (software, no HW accel).
|
||||
AUTO = 255, // Sentinel: resolve to the input dtype at runtime.
|
||||
INT32, // 32-bit signed integer.
|
||||
UINT32, // 32-bit unsigned integer.
|
||||
FLOAT16, // IEEE 754 half precision.
|
||||
FLOAT32, // IEEE 754 single precision.
|
||||
BFLOAT16, // bfloat16 precision.
|
||||
FLOAT8_E4M3FN, // float8 E4M3, OCP variant (NV; AMD HIP > 6 with OCP enabled).
|
||||
FLOAT8_E4M3FNUZ, // float8 E4M3, fnuz variant (AMD HIP 6, or HIP > 6 with FNUZ enabled).
|
||||
FLOAT8_E5M2, // float8 E5M2, OCP variant (NV; AMD HIP > 6 with OCP enabled).
|
||||
FLOAT8_E5M2FNUZ, // float8 E5M2, fnuz variant (AMD HIP 6, or HIP > 6 with FNUZ enabled).
|
||||
UINT8, // 8-bit unsigned integer.
|
||||
FLOAT8_E4M3B15, // float8 with E4M3 layout, bias=15 (software, no HW accel).
|
||||
AUTO = 255, // Sentinel: resolve to the input dtype at runtime.
|
||||
};
|
||||
|
||||
/// Word array.
|
||||
@@ -1137,11 +1103,11 @@ MSCCLPP_DEVICE_INLINE f8_e4m3b15x2 to<f8_e4m3b15x2, f16x2>(const f16x2& v) {
|
||||
#if defined(MSCCLPP_DEVICE_CUDA)
|
||||
uint32_t in0;
|
||||
asm("mov.b32 %0, %1;" : "=r"(in0) : "r"(*reinterpret_cast<const uint32_t*>(&v)));
|
||||
// Clamp abs to max finite e4m3b15 (0x3B80 = 0.9375 in fp16).
|
||||
// Clamp abs to max encodable e4m3b15 (0x3F00 = 1.75 in fp16).
|
||||
uint32_t lo = in0 & 0xFFFFu, hi = in0 >> 16;
|
||||
uint32_t alo = lo & 0x7FFFu, ahi = hi & 0x7FFFu;
|
||||
alo = alo < 0x3B80u ? alo : 0x3B80u;
|
||||
ahi = ahi < 0x3B80u ? ahi : 0x3B80u;
|
||||
alo = alo < 0x3F00u ? alo : 0x3F00u;
|
||||
ahi = ahi < 0x3F00u ? ahi : 0x3F00u;
|
||||
uint32_t a0 = alo | (ahi << 16);
|
||||
a0 = a0 * 2u + 0x00800080u;
|
||||
uint32_t b0 = a0 | (in0 & 0x80008000u);
|
||||
@@ -1152,7 +1118,7 @@ MSCCLPP_DEVICE_INLINE f8_e4m3b15x2 to<f8_e4m3b15x2, f16x2>(const f16x2& v) {
|
||||
uint32_t in0 = v.words[0];
|
||||
uint32_t abs0 = in0 & 0x7fff7fffu;
|
||||
uint32_t a0;
|
||||
asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a0) : "v"(abs0), "v"(0x3B803B80u));
|
||||
asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a0) : "v"(abs0), "v"(0x3F003F00u));
|
||||
a0 = a0 * 2u + 0x00800080u;
|
||||
uint32_t b0 = a0 | (in0 & 0x80008000u);
|
||||
uint16_t packed = (uint16_t)(((b0 >> 8) & 0xFFu) | ((b0 >> 16) & 0xFF00u));
|
||||
@@ -1175,8 +1141,8 @@ MSCCLPP_DEVICE_INLINE f8_e4m3b15x4 to<f8_e4m3b15x4, f16x4>(const f16x4& v) {
|
||||
asm("mov.b32 %0, %1;" : "=r"(in1) : "r"(v.words[1]));
|
||||
uint32_t abs0 = in0 & 0x7fff7fffu;
|
||||
uint32_t abs1 = in1 & 0x7fff7fffu;
|
||||
uint32_t a0 = __vminu2(abs0, 0x3B803B80u);
|
||||
uint32_t a1 = __vminu2(abs1, 0x3B803B80u);
|
||||
uint32_t a0 = __vminu2(abs0, 0x3F003F00u);
|
||||
uint32_t a1 = __vminu2(abs1, 0x3F003F00u);
|
||||
a0 = a0 * 2u + 0x00800080u;
|
||||
a1 = a1 * 2u + 0x00800080u;
|
||||
uint32_t b0, b1;
|
||||
@@ -1189,8 +1155,8 @@ MSCCLPP_DEVICE_INLINE f8_e4m3b15x4 to<f8_e4m3b15x4, f16x4>(const f16x4& v) {
|
||||
uint32_t in0 = v.words[0], in1 = v.words[1];
|
||||
uint32_t abs0 = in0 & 0x7fff7fffu, abs1 = in1 & 0x7fff7fffu;
|
||||
uint32_t a0, a1;
|
||||
asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a0) : "v"(abs0), "v"(0x3B803B80u));
|
||||
asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a1) : "v"(abs1), "v"(0x3B803B80u));
|
||||
asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a0) : "v"(abs0), "v"(0x3F003F00u));
|
||||
asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a1) : "v"(abs1), "v"(0x3F003F00u));
|
||||
a0 = a0 * 2u + 0x00800080u;
|
||||
a1 = a1 * 2u + 0x00800080u;
|
||||
uint32_t b0 = a0 | (in0 & 0x80008000u);
|
||||
|
||||
@@ -45,8 +45,10 @@ void register_core(nb::module_& m) {
|
||||
.value("float16", DataType::FLOAT16)
|
||||
.value("float32", DataType::FLOAT32)
|
||||
.value("bfloat16", DataType::BFLOAT16)
|
||||
.value("float8_e4m3", DataType::FLOAT8_E4M3)
|
||||
.value("float8_e4m3fn", DataType::FLOAT8_E4M3FN)
|
||||
.value("float8_e4m3fnuz", DataType::FLOAT8_E4M3FNUZ)
|
||||
.value("float8_e5m2", DataType::FLOAT8_E5M2)
|
||||
.value("float8_e5m2fnuz", DataType::FLOAT8_E5M2FNUZ)
|
||||
.value("uint8", DataType::UINT8)
|
||||
.value("float8_e4m3b15", DataType::FLOAT8_E4M3B15);
|
||||
|
||||
@@ -328,4 +330,4 @@ NB_MODULE(_mscclpp, m) {
|
||||
|
||||
// ext
|
||||
register_algorithm_collection_builder(m);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -192,12 +192,14 @@ def torch_dtype_to_mscclpp_dtype(dtype: "torch.dtype") -> DataType:
|
||||
return DataType.int32
|
||||
elif dtype == torch.bfloat16:
|
||||
return DataType.bfloat16
|
||||
# Hardware supports either OCP format or FNUZ format for float8.
|
||||
# Mapping both to the same MSCClPP data type.
|
||||
elif dtype == torch.float8_e5m2 or dtype == torch.float8_e5m2fnuz:
|
||||
elif dtype == torch.float8_e5m2:
|
||||
return DataType.float8_e5m2
|
||||
elif dtype == torch.float8_e4m3fn or dtype == torch.float8_e4m3fnuz:
|
||||
return DataType.float8_e4m3
|
||||
elif dtype == torch.float8_e5m2fnuz:
|
||||
return DataType.float8_e5m2fnuz
|
||||
elif dtype == torch.float8_e4m3fn:
|
||||
return DataType.float8_e4m3fn
|
||||
elif dtype == torch.float8_e4m3fnuz:
|
||||
return DataType.float8_e4m3fnuz
|
||||
elif dtype == torch.uint8:
|
||||
return DataType.uint8
|
||||
else:
|
||||
|
||||
@@ -21,6 +21,13 @@ from .mscclpp_mpi import MpiGroup, parametrize_mpi_groups, mpi_group
|
||||
# FP8 E4M3 (hardware) requires SM >= 89 (Ada / Hopper) on NVIDIA GPUs.
|
||||
# On AMD/ROCm (e.g. MI300X), FP8 is supported natively — no skip needed.
|
||||
_is_hip = hasattr(cp.cuda.runtime, "is_hip") and cp.cuda.runtime.is_hip
|
||||
_gcn_arch_name = ""
|
||||
if _is_hip:
|
||||
_gcn_arch_name = cp.cuda.runtime.getDeviceProperties(0).get("gcnArchName", b"")
|
||||
if isinstance(_gcn_arch_name, bytes):
|
||||
_gcn_arch_name = _gcn_arch_name.decode()
|
||||
_gcn_arch_name = _gcn_arch_name.split(":", maxsplit=1)[0]
|
||||
_is_cdna4 = _gcn_arch_name.startswith("gfx95")
|
||||
_skip_fp8 = not _is_hip and int(cp.cuda.Device().compute_capability) < 89
|
||||
pytestmark = pytest.mark.skipif(_skip_fp8, reason="FP8 accum tests require SM >= 89 on CUDA")
|
||||
|
||||
@@ -90,7 +97,78 @@ def float_to_e4m3fn(f32_array, chunk_size=65536):
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FP8 E4M3B15 helpers (bias=15, max=0.9375, NaN = exp==15 or bits==0x80)
|
||||
# FP8 E4M3FNUZ helpers (AMD/ROCm; bias=8, max=240, NaN = bits==0x80, no -0)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def e4m3fnuz_to_float(uint8_array):
|
||||
"""Decode a cupy uint8 array of E4M3FNUZ bit patterns to float32."""
|
||||
bits = uint8_array.astype(cp.int32)
|
||||
sign = (bits >> 7) & 1
|
||||
exp = (bits >> 3) & 0xF
|
||||
mant = bits & 0x7
|
||||
|
||||
# Normal: (-1)^s * 2^(exp-8) * (1 + mant/8)
|
||||
normal_val = cp.ldexp(cp.float32(1.0) + mant.astype(cp.float32) / cp.float32(8.0), (exp - 8).astype(cp.int32))
|
||||
# Subnormal (exp==0): (-1)^s * 2^(-7) * (mant/8)
|
||||
subnormal_val = cp.ldexp(mant.astype(cp.float32) / cp.float32(8.0), cp.int32(-7))
|
||||
|
||||
result = cp.where(exp == 0, subnormal_val, normal_val)
|
||||
result = cp.where(sign == 1, -result, result)
|
||||
# Zero is only 0x00; the 0x80 encoding is reserved for NaN under fnuz.
|
||||
result = cp.where(uint8_array.astype(cp.int32) == 0, cp.float32(0.0), result)
|
||||
nan_mask = uint8_array.astype(cp.int32) == 0x80
|
||||
result = cp.where(nan_mask, cp.float32(float("nan")), result)
|
||||
return result
|
||||
|
||||
|
||||
def float_to_e4m3fnuz(f32_array, chunk_size=65536):
|
||||
"""Encode a cupy float32 array to uint8 E4M3FNUZ bit patterns.
|
||||
|
||||
Same lookup-table approach as float_to_e4m3fn but using the fnuz table.
|
||||
"""
|
||||
all_bytes = cp.arange(128, dtype=cp.uint8)
|
||||
all_floats = e4m3fnuz_to_float(all_bytes)
|
||||
all_floats = cp.where(cp.isnan(all_floats), cp.float32(float("inf")), all_floats)
|
||||
|
||||
clamped = f32_array.astype(cp.float32)
|
||||
clamped = cp.clip(clamped, -240.0, 240.0)
|
||||
signs = (clamped < 0).astype(cp.uint8)
|
||||
absval = cp.abs(clamped)
|
||||
|
||||
result = cp.zeros(absval.shape, dtype=cp.uint8)
|
||||
n = absval.size
|
||||
absval_flat = absval.ravel()
|
||||
result_flat = result.ravel()
|
||||
|
||||
for start in range(0, n, chunk_size):
|
||||
end = min(start + chunk_size, n)
|
||||
chunk = absval_flat[start:end]
|
||||
diffs = cp.abs(chunk[:, None] - all_floats[None, :])
|
||||
result_flat[start:end] = cp.argmin(diffs, axis=1).astype(cp.uint8)
|
||||
|
||||
result = result_flat.reshape(absval.shape)
|
||||
result = result | (signs << 7)
|
||||
# 0x80 is NaN under fnuz (no negative zero). Collapse any encoding that
|
||||
# landed on 0x80 (small negatives quantised to zero magnitude) to 0x00.
|
||||
result = cp.where(result == 0x80, cp.uint8(0), result)
|
||||
return result
|
||||
|
||||
|
||||
# Platform-aware E4M3 native helpers: ROCm CDNA4 and CUDA use OCP fn; older ROCm uses fnuz.
|
||||
if _is_hip and not _is_cdna4:
|
||||
e4m3_native_to_float = e4m3fnuz_to_float
|
||||
float_to_e4m3_native = float_to_e4m3fnuz
|
||||
fp8_native_dtype = DataType.float8_e4m3fnuz
|
||||
else:
|
||||
e4m3_native_to_float = e4m3fn_to_float
|
||||
float_to_e4m3_native = float_to_e4m3fn
|
||||
fp8_native_dtype = DataType.float8_e4m3fn
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FP8 E4M3B15 helpers (bias=15, encode saturates to ±1.75, no NaN)
|
||||
# Matches Triton's fp8e4b15: all 256 bit patterns are finite.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@@ -108,11 +186,6 @@ def e4m3b15_to_float(uint8_array):
|
||||
|
||||
result = cp.where(exp == 0, subnormal_val, normal_val)
|
||||
result = cp.where(sign == 1, -result, result)
|
||||
# Zero
|
||||
result = cp.where((exp == 0) & (mant == 0), cp.float32(0.0), result)
|
||||
# NaN: exp==15 or negative zero (0x80)
|
||||
nan_mask = (exp == 15) | (uint8_array.astype(cp.int32) == 0x80)
|
||||
result = cp.where(nan_mask, cp.float32(float("nan")), result)
|
||||
return result
|
||||
|
||||
|
||||
@@ -120,18 +193,17 @@ def float_to_e4m3b15(f32_array, chunk_size=65536):
|
||||
"""Encode a cupy float32 array to uint8 E4M3B15 bit patterns.
|
||||
|
||||
Same lookup-table approach as float_to_e4m3fn.
|
||||
Saturates to ±1.75 (0x7e/0xfe), matching Triton's fp8e4b15.
|
||||
"""
|
||||
# Build lookup table of all 128 positive E4M3B15 values (0x00..0x7F)
|
||||
all_bytes = cp.arange(128, dtype=cp.uint8)
|
||||
all_floats = e4m3b15_to_float(all_bytes) # (128,) float32
|
||||
# Mark NaN entries as inf so they're never selected as nearest
|
||||
all_floats = cp.where(cp.isnan(all_floats), cp.float32(float("inf")), all_floats)
|
||||
|
||||
# Clamp input and extract sign
|
||||
clamped = f32_array.astype(cp.float32)
|
||||
clamped = cp.clip(clamped, -0.9375, 0.9375)
|
||||
signs = (clamped < 0).astype(cp.uint8)
|
||||
absval = cp.abs(clamped)
|
||||
# Clamp input and extract sign.
|
||||
values = f32_array.astype(cp.float32)
|
||||
signs = cp.signbit(values).astype(cp.uint8)
|
||||
absval = cp.abs(values)
|
||||
absval = cp.clip(absval, cp.float32(0.0), cp.float32(1.75))
|
||||
|
||||
result = cp.zeros(absval.shape, dtype=cp.uint8)
|
||||
n = absval.size
|
||||
@@ -148,8 +220,6 @@ def float_to_e4m3b15(f32_array, chunk_size=65536):
|
||||
# Combine with sign bit
|
||||
result = result_flat.reshape(absval.shape)
|
||||
result = result | (signs << 7)
|
||||
# Handle exact zero
|
||||
result = cp.where(absval == 0, cp.uint8(0), result)
|
||||
return result
|
||||
|
||||
|
||||
@@ -226,12 +296,6 @@ def test_fp8_e4m3_accum(mpi_group: MpiGroup, algo_name: str, size: int):
|
||||
|
||||
buf = GpuBuffer(size, dtype=cp.uint8)
|
||||
|
||||
accum_configs = [
|
||||
("fp8_native", DataType.float8_e4m3),
|
||||
("float16", DataType.float16),
|
||||
("float32", DataType.float32),
|
||||
]
|
||||
|
||||
# rsag_zero_copy and fullmesh need explicit block/thread counts
|
||||
if "rsag" in algo_name:
|
||||
nb = max(1, min(32, size // (world_size * 32)))
|
||||
@@ -243,13 +307,19 @@ def test_fp8_e4m3_accum(mpi_group: MpiGroup, algo_name: str, size: int):
|
||||
nb = 0
|
||||
nt = 0
|
||||
|
||||
accum_configs = [
|
||||
("fp8_native", fp8_native_dtype),
|
||||
("float16", DataType.float16),
|
||||
("float32", DataType.float32),
|
||||
]
|
||||
|
||||
errors = {}
|
||||
for accum_label, accum_dtype in accum_configs:
|
||||
# Generate deterministic per-rank data (use numpy to avoid hipRAND issues on ROCm)
|
||||
rng = np.random.RandomState(42 + rank)
|
||||
src_f32 = cp.asarray(rng.randn(size).astype(np.float32))
|
||||
src_f32 = cp.clip(src_f32, -240.0, 240.0)
|
||||
src_fp8 = float_to_e4m3fn(src_f32)
|
||||
src_fp8 = float_to_e4m3_native(src_f32)
|
||||
|
||||
# Copy into symmetric buffer
|
||||
buf[:] = src_fp8
|
||||
@@ -260,12 +330,12 @@ def test_fp8_e4m3_accum(mpi_group: MpiGroup, algo_name: str, size: int):
|
||||
algo,
|
||||
comm_group,
|
||||
buf,
|
||||
dtype=DataType.float8_e4m3,
|
||||
dtype=fp8_native_dtype,
|
||||
accum_dtype=accum_dtype,
|
||||
nblocks=nb,
|
||||
nthreads_per_block=nt,
|
||||
)
|
||||
result_f32 = e4m3fn_to_float(result)
|
||||
result_f32 = e4m3_native_to_float(result)
|
||||
|
||||
# Compute float32 reference: sum all ranks' quantized FP8 inputs in float32
|
||||
ref_f32 = cp.zeros(size, dtype=cp.float32)
|
||||
@@ -273,12 +343,13 @@ def test_fp8_e4m3_accum(mpi_group: MpiGroup, algo_name: str, size: int):
|
||||
rng_r = np.random.RandomState(42 + r)
|
||||
rank_data = cp.asarray(rng_r.randn(size).astype(np.float32))
|
||||
rank_data = cp.clip(rank_data, -240.0, 240.0)
|
||||
rank_data_fp8 = float_to_e4m3fn(rank_data)
|
||||
ref_f32 += e4m3fn_to_float(rank_data_fp8)
|
||||
rank_data_fp8 = float_to_e4m3_native(rank_data)
|
||||
ref_f32 += e4m3_native_to_float(rank_data_fp8)
|
||||
|
||||
# Compute errors
|
||||
abs_err = cp.abs(result_f32 - ref_f32)
|
||||
mean_abs_err = float(cp.mean(abs_err))
|
||||
# Compute errors (only on valid, non-NaN entries)
|
||||
valid = ~cp.isnan(result_f32) & ~cp.isnan(ref_f32)
|
||||
abs_err = cp.abs(result_f32[valid] - ref_f32[valid])
|
||||
mean_abs_err = float(cp.mean(abs_err)) if abs_err.size > 0 else 0.0
|
||||
errors[accum_label] = mean_abs_err
|
||||
|
||||
# Reset between runs
|
||||
@@ -341,13 +412,10 @@ def test_fp8_e4m3b15_accum(mpi_group: MpiGroup, algo_name: str, size: int):
|
||||
|
||||
errors = {}
|
||||
for accum_label, accum_dtype in accum_configs:
|
||||
# Generate deterministic per-rank random uint8 values in valid e4m3b15 range
|
||||
# Generate deterministic per-rank random uint8 values covering the full e4m3b15 range.
|
||||
# All 256 bit patterns are valid (no NaN in this format).
|
||||
rng = np.random.RandomState(42 + rank)
|
||||
raw = cp.asarray(rng.randint(0, 0x78, (size,)).astype(np.uint8))
|
||||
signs = cp.asarray(rng.randint(0, 2, (size,)).astype(np.uint8)) << 7
|
||||
src_uint8 = raw | signs
|
||||
# Fix negative zero -> positive zero
|
||||
src_uint8 = cp.where(src_uint8 == 0x80, cp.uint8(0), src_uint8)
|
||||
src_uint8 = cp.asarray(rng.randint(0, 256, (size,)).astype(np.uint8))
|
||||
|
||||
# Copy into symmetric buffer
|
||||
buf[:] = src_uint8
|
||||
@@ -371,19 +439,15 @@ def test_fp8_e4m3b15_accum(mpi_group: MpiGroup, algo_name: str, size: int):
|
||||
ref_f32 = cp.zeros(size, dtype=cp.float32)
|
||||
for r in range(world_size):
|
||||
rng_r = np.random.RandomState(42 + r)
|
||||
raw_r = cp.asarray(rng_r.randint(0, 0x78, (size,)).astype(np.uint8))
|
||||
signs_r = cp.asarray(rng_r.randint(0, 2, (size,)).astype(np.uint8)) << 7
|
||||
bits_r = raw_r | signs_r
|
||||
bits_r = cp.where(bits_r == 0x80, cp.uint8(0), bits_r)
|
||||
bits_r = cp.asarray(rng_r.randint(0, 256, (size,)).astype(np.uint8))
|
||||
ref_f32 += e4m3b15_to_float(bits_r)
|
||||
|
||||
# Clamp reference to e4m3b15 representable range
|
||||
ref_f32 = cp.clip(ref_f32, -0.9375, 0.9375)
|
||||
ref_f32 = cp.clip(ref_f32, -1.75, 1.75)
|
||||
|
||||
# Compute errors (only on valid entries)
|
||||
valid = ~cp.isnan(result_f32) & ~cp.isnan(ref_f32)
|
||||
abs_err = cp.abs(result_f32[valid] - ref_f32[valid])
|
||||
mean_abs_err = float(cp.mean(abs_err)) if abs_err.size > 0 else 0.0
|
||||
# Compute errors
|
||||
abs_err = cp.abs(result_f32 - ref_f32)
|
||||
mean_abs_err = float(cp.mean(abs_err))
|
||||
errors[accum_label] = mean_abs_err
|
||||
|
||||
algo.reset()
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#include <filesystem>
|
||||
#include <mscclpp/algorithm.hpp>
|
||||
#include <mscclpp/errors.hpp>
|
||||
#include <mscclpp/gpu_utils.hpp>
|
||||
|
||||
#include "logger.hpp"
|
||||
@@ -182,13 +183,41 @@ CommResult DslAlgorithm::execute(std::shared_ptr<Communicator> comm, const void*
|
||||
stream);
|
||||
break;
|
||||
#if defined(__FP8_TYPES_EXIST__)
|
||||
case DataType::FLOAT8_E4M3:
|
||||
executor->execute(rank, (__fp8_e4m3*)input, (__fp8_e4m3*)output, inputSize, outputSize, DataType::FLOAT8_E4M3,
|
||||
case DataType::FLOAT8_E4M3FN:
|
||||
#if defined(__FP8_E4M3_IS_FNUZ__)
|
||||
THROW(EXEC, Error, ErrorCode::InvalidUsage,
|
||||
"FLOAT8_E4M3FN is not natively supported on this platform; use FLOAT8_E4M3FNUZ");
|
||||
#else
|
||||
executor->execute(rank, (__fp8_e4m3*)input, (__fp8_e4m3*)output, inputSize, outputSize, DataType::FLOAT8_E4M3FN,
|
||||
plan_, stream);
|
||||
#endif
|
||||
break;
|
||||
case DataType::FLOAT8_E4M3FNUZ:
|
||||
#if defined(__FP8_E4M3_IS_FNUZ__)
|
||||
executor->execute(rank, (__fp8_e4m3*)input, (__fp8_e4m3*)output, inputSize, outputSize, DataType::FLOAT8_E4M3FNUZ,
|
||||
plan_, stream);
|
||||
#else
|
||||
THROW(EXEC, Error, ErrorCode::InvalidUsage,
|
||||
"FLOAT8_E4M3FNUZ is not natively supported on this platform; use FLOAT8_E4M3FN");
|
||||
#endif
|
||||
break;
|
||||
case DataType::FLOAT8_E5M2:
|
||||
#if defined(__FP8_E5M2_IS_FNUZ__)
|
||||
THROW(EXEC, Error, ErrorCode::InvalidUsage,
|
||||
"FLOAT8_E5M2 is not natively supported on this platform; use FLOAT8_E5M2FNUZ");
|
||||
#else
|
||||
executor->execute(rank, (__fp8_e5m2*)input, (__fp8_e5m2*)output, inputSize, outputSize, DataType::FLOAT8_E5M2,
|
||||
plan_, stream);
|
||||
#endif
|
||||
break;
|
||||
case DataType::FLOAT8_E5M2FNUZ:
|
||||
#if defined(__FP8_E5M2_IS_FNUZ__)
|
||||
executor->execute(rank, (__fp8_e5m2*)input, (__fp8_e5m2*)output, inputSize, outputSize, DataType::FLOAT8_E5M2FNUZ,
|
||||
plan_, stream);
|
||||
#else
|
||||
THROW(EXEC, Error, ErrorCode::InvalidUsage,
|
||||
"FLOAT8_E5M2FNUZ is not natively supported on this platform; use FLOAT8_E5M2");
|
||||
#endif
|
||||
break;
|
||||
#endif
|
||||
case DataType::FLOAT8_E4M3B15:
|
||||
@@ -230,4 +259,4 @@ std::pair<std::shared_ptr<void>, size_t> getFlagBuffer() {
|
||||
return {ptr, gDefaultFlagCount * sizeof(uint32_t)};
|
||||
}
|
||||
|
||||
} // namespace mscclpp
|
||||
} // namespace mscclpp
|
||||
|
||||
@@ -78,8 +78,10 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo
|
||||
);
|
||||
#endif
|
||||
break;
|
||||
case DataType::FLOAT8_E4M3:
|
||||
case DataType::FLOAT8_E4M3FN:
|
||||
case DataType::FLOAT8_E4M3FNUZ:
|
||||
case DataType::FLOAT8_E5M2:
|
||||
case DataType::FLOAT8_E5M2FNUZ:
|
||||
// FP8 is not supported in CUDA execution kernel.
|
||||
break;
|
||||
case DataType::FLOAT8_E4M3B15:
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
#include <mscclpp/switch_channel.hpp>
|
||||
#include <mscclpp/utils.hpp>
|
||||
|
||||
#include "debug.h"
|
||||
#include "execution_kernel.hpp"
|
||||
#include "execution_plan.hpp"
|
||||
|
||||
@@ -509,8 +508,7 @@ Executor::Executor(std::shared_ptr<Communicator> comm, std::shared_ptr<char> def
|
||||
void Executor::execute(int rank, void* sendbuff, void* recvbuff, size_t sendBuffSize,
|
||||
[[maybe_unused]] size_t recvBuffSize, DataType dataType, const ExecutionPlan& plan,
|
||||
cudaStream_t stream, PacketType packetType) {
|
||||
INFO(MSCCLPP_EXECUTOR, "Starting execution with plan: %s, collective: %s", plan.name().c_str(),
|
||||
plan.collective().c_str());
|
||||
INFO(LogSubsys::EXEC, "Starting execution with plan: ", plan.name(), ", collective: ", plan.collective());
|
||||
size_t sendMemRange, recvMemRange;
|
||||
CUdeviceptr sendBasePtr, recvBasePtr;
|
||||
MSCCLPP_CUTHROW(cuMemGetAddressRange(&sendBasePtr, &sendMemRange, (CUdeviceptr)sendbuff));
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
#include <mscclpp/switch_channel_device.hpp>
|
||||
|
||||
#include "execution_common.hpp"
|
||||
#include "logger.hpp"
|
||||
#include "reduce_kernel.hpp"
|
||||
namespace mscclpp {
|
||||
|
||||
@@ -876,7 +877,19 @@ class ExecutionKernel {
|
||||
#endif
|
||||
break;
|
||||
#if defined(__FP8_TYPES_EXIST__)
|
||||
case DataType::FLOAT8_E4M3:
|
||||
case DataType::FLOAT8_E4M3FN:
|
||||
case DataType::FLOAT8_E4M3FNUZ:
|
||||
#if defined(__FP8_E4M3_IS_FNUZ__)
|
||||
if (dataType == DataType::FLOAT8_E4M3FN) {
|
||||
THROW(LogSubsys::EXEC, Error, ErrorCode::InvalidUsage,
|
||||
"FLOAT8_E4M3FN is not natively supported on this platform; use FLOAT8_E4M3FNUZ");
|
||||
}
|
||||
#else
|
||||
if (dataType == DataType::FLOAT8_E4M3FNUZ) {
|
||||
THROW(LogSubsys::EXEC, Error, ErrorCode::InvalidUsage,
|
||||
"FLOAT8_E4M3FNUZ is not natively supported on this platform; use FLOAT8_E4M3FN");
|
||||
}
|
||||
#endif
|
||||
executionKernel<__fp8_e4m3, PacketType, ReuseScratch><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
|
||||
rank, (__fp8_e4m3*)src, (__fp8_e4m3*)dst, (__fp8_e4m3*)scratch, scratchOffset, scratchChunkSize, plan,
|
||||
semaphores, localMemoryIdBegin, flag
|
||||
@@ -888,6 +901,18 @@ class ExecutionKernel {
|
||||
#endif
|
||||
break;
|
||||
case DataType::FLOAT8_E5M2:
|
||||
case DataType::FLOAT8_E5M2FNUZ:
|
||||
#if defined(__FP8_E5M2_IS_FNUZ__)
|
||||
if (dataType == DataType::FLOAT8_E5M2) {
|
||||
THROW(LogSubsys::EXEC, Error, ErrorCode::InvalidUsage,
|
||||
"FLOAT8_E5M2 is not natively supported on this platform; use FLOAT8_E5M2FNUZ");
|
||||
}
|
||||
#else
|
||||
if (dataType == DataType::FLOAT8_E5M2FNUZ) {
|
||||
THROW(LogSubsys::EXEC, Error, ErrorCode::InvalidUsage,
|
||||
"FLOAT8_E5M2FNUZ is not natively supported on this platform; use FLOAT8_E5M2");
|
||||
}
|
||||
#endif
|
||||
executionKernel<__fp8_e5m2, PacketType, ReuseScratch><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
|
||||
rank, (__fp8_e5m2*)src, (__fp8_e5m2*)dst, (__fp8_e5m2*)scratch, scratchOffset, scratchChunkSize, plan,
|
||||
semaphores, localMemoryIdBegin, flag
|
||||
|
||||
@@ -200,7 +200,8 @@ inline std::pair<int, int> getDefaultBlockNumAndThreadNum(size_t inputSize, int
|
||||
{
|
||||
bool isFp8 = dtype == DataType::FLOAT8_E4M3B15;
|
||||
#if defined(__FP8_TYPES_EXIST__)
|
||||
isFp8 = isFp8 || dtype == DataType::FLOAT8_E4M3 || dtype == DataType::FLOAT8_E5M2;
|
||||
isFp8 = isFp8 || dtype == DataType::FLOAT8_E4M3FN || dtype == DataType::FLOAT8_E4M3FNUZ ||
|
||||
dtype == DataType::FLOAT8_E5M2 || dtype == DataType::FLOAT8_E5M2FNUZ;
|
||||
#endif
|
||||
if (isFp8) {
|
||||
if (inputSize < (64 << 10)) {
|
||||
|
||||
@@ -101,10 +101,20 @@ AllreduceFunc dispatchByDtype(mscclpp::DataType dtype, mscclpp::DataType accumDt
|
||||
return Adapter<Op, __bfloat16, __bfloat16>::call;
|
||||
#endif
|
||||
#if defined(__FP8_TYPES_EXIST__)
|
||||
} else if (dtype == mscclpp::DataType::FLOAT8_E4M3) {
|
||||
#if defined(__FP8_E4M3_IS_FNUZ__)
|
||||
} else if (dtype == mscclpp::DataType::FLOAT8_E4M3FNUZ) {
|
||||
return dispatchFp8Accum<Op, __fp8_e4m3, Adapter>(accumDtype, dtype);
|
||||
#else
|
||||
} else if (dtype == mscclpp::DataType::FLOAT8_E4M3FN) {
|
||||
return dispatchFp8Accum<Op, __fp8_e4m3, Adapter>(accumDtype, dtype);
|
||||
#endif
|
||||
#if defined(__FP8_E5M2_IS_FNUZ__)
|
||||
} else if (dtype == mscclpp::DataType::FLOAT8_E5M2FNUZ) {
|
||||
return dispatchFp8Accum<Op, __fp8_e5m2, Adapter>(accumDtype, dtype);
|
||||
#else
|
||||
} else if (dtype == mscclpp::DataType::FLOAT8_E5M2) {
|
||||
return dispatchFp8Accum<Op, __fp8_e5m2, Adapter>(accumDtype, dtype);
|
||||
#endif
|
||||
#endif
|
||||
} else if (dtype == mscclpp::DataType::FLOAT8_E4M3B15) {
|
||||
return dispatchFp8Accum<Op, __fp8_e4m3b15, Adapter>(accumDtype, dtype);
|
||||
@@ -125,4 +135,4 @@ AllreduceFunc dispatch(ReduceOp op, mscclpp::DataType dtype, mscclpp::DataType a
|
||||
} // namespace collective
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // MSCCLPP_ALLREDUCE_COMMON_HPP_
|
||||
#endif // MSCCLPP_ALLREDUCE_COMMON_HPP_
|
||||
|
||||
@@ -20,7 +20,8 @@ static bool isNvlsSupportedForDataType(const AlgorithmSelectorConfig& config, Da
|
||||
return false;
|
||||
}
|
||||
|
||||
const bool isFp8 = dtype == DataType::FLOAT8_E4M3 || dtype == DataType::FLOAT8_E5M2;
|
||||
const bool isFp8 = dtype == DataType::FLOAT8_E4M3FN || dtype == DataType::FLOAT8_E4M3FNUZ ||
|
||||
dtype == DataType::FLOAT8_E5M2 || dtype == DataType::FLOAT8_E5M2FNUZ;
|
||||
|
||||
if (!isFp8) {
|
||||
return nvlsSupported;
|
||||
|
||||
@@ -28,12 +28,21 @@ inline mscclpp::DataType ncclDataTypeToMscclpp(ncclDataType_t dtype) {
|
||||
return mscclpp::DataType::BFLOAT16;
|
||||
#ifdef __FP8_TYPES_EXIST__
|
||||
case ncclFloat8e4m3:
|
||||
return mscclpp::DataType::FLOAT8_E4M3;
|
||||
#if defined(__FP8_E4M3_IS_FNUZ__)
|
||||
return mscclpp::DataType::FLOAT8_E4M3FNUZ;
|
||||
#else
|
||||
return mscclpp::DataType::FLOAT8_E4M3FN;
|
||||
#endif
|
||||
case ncclFloat8e5m2:
|
||||
#if defined(__FP8_E5M2_IS_FNUZ__)
|
||||
return mscclpp::DataType::FLOAT8_E5M2FNUZ;
|
||||
#else
|
||||
return mscclpp::DataType::FLOAT8_E5M2;
|
||||
#endif
|
||||
#endif
|
||||
default:
|
||||
throw mscclpp::Error("Unsupported ncclDataType_t: " + std::to_string(dtype), mscclpp::ErrorCode::InvalidUsage);
|
||||
THROW(mscclpp::LogSubsys::NCCL, mscclpp::Error, mscclpp::ErrorCode::InvalidUsage,
|
||||
"Unsupported ncclDataType_t: " + std::to_string(dtype));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -41,8 +50,10 @@ inline mscclpp::DataType ncclDataTypeToMscclpp(ncclDataType_t dtype) {
|
||||
inline size_t getDataTypeSize(mscclpp::DataType dtype) {
|
||||
switch (dtype) {
|
||||
case mscclpp::DataType::UINT8:
|
||||
case mscclpp::DataType::FLOAT8_E4M3:
|
||||
case mscclpp::DataType::FLOAT8_E4M3FN:
|
||||
case mscclpp::DataType::FLOAT8_E4M3FNUZ:
|
||||
case mscclpp::DataType::FLOAT8_E5M2:
|
||||
case mscclpp::DataType::FLOAT8_E5M2FNUZ:
|
||||
case mscclpp::DataType::FLOAT8_E4M3B15:
|
||||
return 1;
|
||||
case mscclpp::DataType::FLOAT16:
|
||||
@@ -72,10 +83,32 @@ static inline ncclDataType_t mscclppToNcclDataType(mscclpp::DataType dtype) {
|
||||
case mscclpp::DataType::BFLOAT16:
|
||||
return ncclBfloat16;
|
||||
#ifdef __FP8_TYPES_EXIST__
|
||||
case mscclpp::DataType::FLOAT8_E4M3:
|
||||
#if defined(__FP8_E4M3_IS_FNUZ__)
|
||||
case mscclpp::DataType::FLOAT8_E4M3FNUZ:
|
||||
return ncclFloat8e4m3;
|
||||
case mscclpp::DataType::FLOAT8_E4M3FN:
|
||||
THROW(mscclpp::LogSubsys::NCCL, mscclpp::Error, mscclpp::ErrorCode::InvalidUsage,
|
||||
"FLOAT8_E4M3FN is not natively supported on this platform; use FLOAT8_E4M3FNUZ for NCCL collectives");
|
||||
#else
|
||||
case mscclpp::DataType::FLOAT8_E4M3FN:
|
||||
return ncclFloat8e4m3;
|
||||
case mscclpp::DataType::FLOAT8_E4M3FNUZ:
|
||||
THROW(mscclpp::LogSubsys::NCCL, mscclpp::Error, mscclpp::ErrorCode::InvalidUsage,
|
||||
"FLOAT8_E4M3FNUZ is not natively supported on this platform; use FLOAT8_E4M3FN for NCCL collectives");
|
||||
#endif
|
||||
#if defined(__FP8_E5M2_IS_FNUZ__)
|
||||
case mscclpp::DataType::FLOAT8_E5M2FNUZ:
|
||||
return ncclFloat8e5m2;
|
||||
case mscclpp::DataType::FLOAT8_E5M2:
|
||||
THROW(mscclpp::LogSubsys::NCCL, mscclpp::Error, mscclpp::ErrorCode::InvalidUsage,
|
||||
"FLOAT8_E5M2 is not natively supported on this platform; use FLOAT8_E5M2FNUZ for NCCL collectives");
|
||||
#else
|
||||
case mscclpp::DataType::FLOAT8_E5M2:
|
||||
return ncclFloat8e5m2;
|
||||
case mscclpp::DataType::FLOAT8_E5M2FNUZ:
|
||||
THROW(mscclpp::LogSubsys::NCCL, mscclpp::Error, mscclpp::ErrorCode::InvalidUsage,
|
||||
"FLOAT8_E5M2FNUZ is not natively supported on this platform; use FLOAT8_E5M2 for NCCL collectives");
|
||||
#endif
|
||||
#endif
|
||||
case mscclpp::DataType::FLOAT8_E4M3B15:
|
||||
// float8_e4m3b15 has no NCCL equivalent; NCCL cannot reduce this type correctly.
|
||||
@@ -98,4 +131,4 @@ inline mscclpp::ReduceOp ncclRedOpToMscclpp(ncclRedOp_t op) {
|
||||
}
|
||||
}
|
||||
|
||||
#endif // MSCCLPP_DATATYPE_CONVERSION_HPP_
|
||||
#endif // MSCCLPP_DATATYPE_CONVERSION_HPP_
|
||||
|
||||
Reference in New Issue
Block a user