From 2c52937b26e6b72846cb8bec2f7479fb90162913 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Tue, 28 Apr 2026 15:02:22 -0700 Subject: [PATCH 1/2] Fix FP8 ROCm build/test issues and dtype naming (#792) ## Summary - Fix ROCm FP8 build failure by using the actual FP8 `DataType` enum constants in allreduce packet tuning. - Fix FP8 E4M3FNUZ test encoding so small negative values do not produce the FNUZ NaN byte (`0x80`). - Align FP8 `DataType` enum constants and Python bindings with torch-style names (`FLOAT8_E4M3FN`, `FLOAT8_E4M3FNUZ`, `FLOAT8_E5M2FNUZ` / `float8_e4m3fn`, `float8_e4m3fnuz`, `float8_e5m2fnuz`). ## Validation - `./tools/lint.sh` - `make -j` from `build/` - `mpirun --allow-run-as-root -np 8 python3 -m pytest python/test/test_fp8_accum.py -q` (`36 passed, 9 skipped`) - `DTYPE=float8_e4m3fnuz ACCUM_DTYPE=float32 torchrun --nnodes=1 --nproc_per_node=8 examples/torch-integration/customized_comm_with_tuning.py` --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- include/mscclpp/gpu_data_types.hpp | 104 ++++-------- python/csrc/core_py.cpp | 6 +- python/mscclpp/utils.py | 12 +- python/test/test_fp8_accum.py | 152 +++++++++++++----- src/core/algorithm.cc | 35 +++- src/core/executor/execution_kernel.cu | 4 +- src/core/executor/executor.cc | 4 +- src/core/include/execution_kernel.hpp | 27 +++- .../collectives/allreduce/allreduce_packet.cu | 5 +- .../collectives/include/allreduce/common.hpp | 14 +- src/ext/nccl/algorithm_selector.cc | 3 +- src/ext/nccl/datatype_conversion.hpp | 43 ++++- 12 files changed, 271 insertions(+), 138 deletions(-) diff --git a/include/mscclpp/gpu_data_types.hpp b/include/mscclpp/gpu_data_types.hpp index 41bd5928..672434f9 100644 --- a/include/mscclpp/gpu_data_types.hpp +++ b/include/mscclpp/gpu_data_types.hpp @@ -21,7 +21,10 @@ using __bfloat162 = __hip_bfloat162; #if defined(HIP_VERSION_MAJOR) && (HIP_VERSION_MAJOR >= 6) #include -// Create aliases matching CUDA naming convention for cross-platform 
compatibility +// Create aliases matching CUDA naming convention for cross-platform compatibility. +// Define __FP8_E4M3_IS_FNUZ__ / __FP8_E5M2_IS_FNUZ__ when the platform-native FP8 is the +// "fnuz" variant (no infinities, NaN-only at 0x80, bias differs from OCP). Dispatch layers +// use these macros to throw on unsupported variants requested via DataType. #if (HIP_VERSION_MAJOR == 6) || (HIP_VERSION_MAJOR > 6 && HIP_FP8_TYPE_FNUZ && !HIP_FP8_TYPE_OCP) using __fp8_e4m3 = __hip_fp8_e4m3_fnuz; using __fp8_e5m2 = __hip_fp8_e5m2_fnuz; @@ -29,6 +32,8 @@ using __fp8x2_e4m3 = __hip_fp8x2_e4m3_fnuz; using __fp8x2_e5m2 = __hip_fp8x2_e5m2_fnuz; using __fp8x4_e4m3 = __hip_fp8x4_e4m3_fnuz; using __fp8x4_e5m2 = __hip_fp8x4_e5m2_fnuz; +#define __FP8_E4M3_IS_FNUZ__ +#define __FP8_E5M2_IS_FNUZ__ #else using __fp8_e4m3 = __hip_fp8_e4m3; using __fp8_e5m2 = __hip_fp8_e5m2; @@ -66,8 +71,8 @@ using __bfloat162 = __nv_bfloat162; /// Software float8 with 4 exponent bits, 3 mantissa bits, exponent bias = 15. /// Format (MSB first): [sign:1][exponent:4][mantissa:3] -/// No infinities; exp=15 is NaN. Negative zero is NaN (fnuz convention). -/// Max finite value: 0.9375, min normal: ~6.1e-5, min subnormal: ~7.6e-6. +/// No infinities, no NaN. Encode saturates to ±1.75 (0x7e/0xfe). +/// Adapted from the Triton compiler's fp8e4b15 format. struct alignas(1) __fp8_e4m3b15 { uint8_t __x; @@ -97,35 +102,15 @@ struct alignas(1) __fp8_e4m3b15 { /// Algorithm: reinterpret fp8 bits into an fp16 bit pattern with exponent shifted by -8, /// then convert fp16 → float32. static MSCCLPP_HOST_DEVICE_INLINE float toFloat(uint8_t bits) { - // Handle special values: negative zero (0x80) → NaN, exponent=15 → NaN. - uint32_t exp = (bits >> 3) & 0xFu; - if (bits == 0x80 || exp == 15) { - union { - uint32_t u; - float f; - } nan_val = {0x7FC00000u}; - return nan_val.f; - } - if (bits == 0) return 0.0f; - - // Triton-style bit manipulation: fp8 → fp16 → fp32. 
- // fp8 layout: [S:1][E:4][M:3] (bias=15) - // fp16 layout: [S:1][E:5][M:10] (bias=15) - // - // Place fp8 in upper byte of fp16, then right-shift exponent+mantissa by 1 - // to convert E4 → E5 (both share bias=15). Sign bit stays at bit 15. + // Branch-free decode: fp8 → fp16 → fp32, no special-case handling. + // Encode saturates to ±1.75, so 0x7f/0xff are never produced. // Refer: // https://github.com/triton-lang/triton/blob/cf34004b8a67d290a962da166f5aa2fc66751326/python/triton/language/extra/cuda/utils.py#L34 uint16_t h = (uint16_t)bits << 8; // place fp8 in upper byte of fp16 uint16_t sign16 = h & 0x8000u; // extract sign at fp16 position uint16_t nosign = h & 0x7F00u; // exponent + mantissa (no sign) - uint16_t fp16_bits = sign16 | (nosign >> 1); // shift exponent right by 1 + uint16_t fp16_bits = sign16 | (nosign >> 1); // shift exponent right by 1 (E4→E5) - // For subnormals: when fp8 exponent=0, the above gives fp16 exponent=0 - // and fp16 mantissa = (fp8_mantissa << 7), which correctly represents - // the subnormal fp16 value since both share bias=15. - - // Convert fp16 bits to float via __half (works on host and device, CUDA and HIP). union { uint16_t u; __half h; @@ -139,14 +124,6 @@ struct alignas(1) __fp8_e4m3b15 { /// The key insight is to convert to fp16 first (which shares bias=15 with e4m3b15), /// then pack the fp16 bits back into 8 bits by shifting the exponent left by 1. static MSCCLPP_HOST_DEVICE_INLINE uint8_t fromFloat(float val) { - union { - float f; - uint32_t u; - } in = {val}; - - // NaN → 0x80 (negative-zero bit pattern = NaN in fnuz). - if ((in.u & 0x7F800000u) == 0x7F800000u && (in.u & 0x007FFFFFu) != 0) return 0x80u; - // Convert float32 → fp16 bits via __half (works on host and device, CUDA and HIP). __half h_val = __float2half_rn(val); union { @@ -155,32 +132,19 @@ struct alignas(1) __fp8_e4m3b15 { } cvt = {h_val}; uint16_t fp16_bits = cvt.u; - // Clamp absolute value to max finite e4m3b15: 0.9375 → fp16 = 0x3B80. 
+ // Clamp abs to max encodable value: 1.75 → fp16 = 0x3F00. + // Matches Triton: encode saturates, 0x7f/0xff are never produced. uint16_t abs_fp16 = fp16_bits & 0x7FFFu; - if (abs_fp16 > 0x3B80u) abs_fp16 = 0x3B80u; + if (abs_fp16 > 0x3F00u) abs_fp16 = 0x3F00u; // Reconstruct with sign. uint16_t sign16 = fp16_bits & 0x8000u; - // Triton-style: fp16 → fp8. - // fp16 layout: [S:1][E:5][M:10] (bias=15) - // fp8 layout: [S:1][E:4][M:3] (bias=15) - // - // mad.lo.u32 a0, a0, 2, 0x00800080 → (abs_fp16 * 2 + 0x0080) - // This shifts left by 1 (undoing the right-shift in decode) and adds rounding bias. - // Then: lop3.b32 b0, $1, 0x80008000, a0, 0xea → (sign & 0x8000) | a0 - // Finally: prmt for byte extraction. - // - // Simplified for scalar: shift abs_fp16 left by 1, add rounding bias, take upper byte. + // fp16 → fp8: shift abs left by 1 (undo decode's right-shift), add rounding bias, take upper byte. uint16_t adjusted = (uint16_t)(abs_fp16 * 2u + 0x0080u); - // The upper byte now contains [E:4][M:3][round_bit]. - // Combine with sign and extract. uint16_t with_sign = sign16 | adjusted; uint8_t result = (uint8_t)(with_sign >> 8); - // Zero → 0x00 (ensure positive zero, not negative zero which is NaN). - if ((result & 0x7Fu) == 0) result = 0x00u; - return result; } }; @@ -199,16 +163,18 @@ namespace mscclpp { /// Data types supported by mscclpp operations. enum class DataType { - INT32, // 32-bit signed integer. - UINT32, // 32-bit unsigned integer. - FLOAT16, // IEEE 754 half precision. - FLOAT32, // IEEE 754 single precision. - BFLOAT16, // bfloat16 precision. - FLOAT8_E4M3, // float8 with E4M3 layout. - FLOAT8_E5M2, // float8 with E5M2 layout. - UINT8, // 8-bit unsigned integer. - FLOAT8_E4M3B15, // float8 with E4M3 layout, bias=15 (software, no HW accel). - AUTO = 255, // Sentinel: resolve to the input dtype at runtime. + INT32, // 32-bit signed integer. + UINT32, // 32-bit unsigned integer. + FLOAT16, // IEEE 754 half precision. 
+ FLOAT32, // IEEE 754 single precision. + BFLOAT16, // bfloat16 precision. + FLOAT8_E4M3FN, // float8 E4M3, OCP variant (NV; AMD HIP > 6 with OCP enabled). + FLOAT8_E4M3FNUZ, // float8 E4M3, fnuz variant (AMD HIP 6, or HIP > 6 with FNUZ enabled). + FLOAT8_E5M2, // float8 E5M2, OCP variant (NV; AMD HIP > 6 with OCP enabled). + FLOAT8_E5M2FNUZ, // float8 E5M2, fnuz variant (AMD HIP 6, or HIP > 6 with FNUZ enabled). + UINT8, // 8-bit unsigned integer. + FLOAT8_E4M3B15, // float8 with E4M3 layout, bias=15 (software, no HW accel). + AUTO = 255, // Sentinel: resolve to the input dtype at runtime. }; /// Word array. @@ -1137,11 +1103,11 @@ MSCCLPP_DEVICE_INLINE f8_e4m3b15x2 to(const f16x2& v) { #if defined(MSCCLPP_DEVICE_CUDA) uint32_t in0; asm("mov.b32 %0, %1;" : "=r"(in0) : "r"(*reinterpret_cast(&v))); - // Clamp abs to max finite e4m3b15 (0x3B80 = 0.9375 in fp16). + // Clamp abs to max encodable e4m3b15 (0x3F00 = 1.75 in fp16). uint32_t lo = in0 & 0xFFFFu, hi = in0 >> 16; uint32_t alo = lo & 0x7FFFu, ahi = hi & 0x7FFFu; - alo = alo < 0x3B80u ? alo : 0x3B80u; - ahi = ahi < 0x3B80u ? ahi : 0x3B80u; + alo = alo < 0x3F00u ? alo : 0x3F00u; + ahi = ahi < 0x3F00u ? 
ahi : 0x3F00u; uint32_t a0 = alo | (ahi << 16); a0 = a0 * 2u + 0x00800080u; uint32_t b0 = a0 | (in0 & 0x80008000u); @@ -1152,7 +1118,7 @@ MSCCLPP_DEVICE_INLINE f8_e4m3b15x2 to(const f16x2& v) { uint32_t in0 = v.words[0]; uint32_t abs0 = in0 & 0x7fff7fffu; uint32_t a0; - asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a0) : "v"(abs0), "v"(0x3B803B80u)); + asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a0) : "v"(abs0), "v"(0x3F003F00u)); a0 = a0 * 2u + 0x00800080u; uint32_t b0 = a0 | (in0 & 0x80008000u); uint16_t packed = (uint16_t)(((b0 >> 8) & 0xFFu) | ((b0 >> 16) & 0xFF00u)); @@ -1175,8 +1141,8 @@ MSCCLPP_DEVICE_INLINE f8_e4m3b15x4 to(const f16x4& v) { asm("mov.b32 %0, %1;" : "=r"(in1) : "r"(v.words[1])); uint32_t abs0 = in0 & 0x7fff7fffu; uint32_t abs1 = in1 & 0x7fff7fffu; - uint32_t a0 = __vminu2(abs0, 0x3B803B80u); - uint32_t a1 = __vminu2(abs1, 0x3B803B80u); + uint32_t a0 = __vminu2(abs0, 0x3F003F00u); + uint32_t a1 = __vminu2(abs1, 0x3F003F00u); a0 = a0 * 2u + 0x00800080u; a1 = a1 * 2u + 0x00800080u; uint32_t b0, b1; @@ -1189,8 +1155,8 @@ MSCCLPP_DEVICE_INLINE f8_e4m3b15x4 to(const f16x4& v) { uint32_t in0 = v.words[0], in1 = v.words[1]; uint32_t abs0 = in0 & 0x7fff7fffu, abs1 = in1 & 0x7fff7fffu; uint32_t a0, a1; - asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a0) : "v"(abs0), "v"(0x3B803B80u)); - asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a1) : "v"(abs1), "v"(0x3B803B80u)); + asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a0) : "v"(abs0), "v"(0x3F003F00u)); + asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a1) : "v"(abs1), "v"(0x3F003F00u)); a0 = a0 * 2u + 0x00800080u; a1 = a1 * 2u + 0x00800080u; uint32_t b0 = a0 | (in0 & 0x80008000u); diff --git a/python/csrc/core_py.cpp b/python/csrc/core_py.cpp index b8649564..a94f9863 100644 --- a/python/csrc/core_py.cpp +++ b/python/csrc/core_py.cpp @@ -45,8 +45,10 @@ void register_core(nb::module_& m) { .value("float16", DataType::FLOAT16) .value("float32", DataType::FLOAT32) .value("bfloat16", DataType::BFLOAT16) - 
.value("float8_e4m3", DataType::FLOAT8_E4M3) + .value("float8_e4m3fn", DataType::FLOAT8_E4M3FN) + .value("float8_e4m3fnuz", DataType::FLOAT8_E4M3FNUZ) .value("float8_e5m2", DataType::FLOAT8_E5M2) + .value("float8_e5m2fnuz", DataType::FLOAT8_E5M2FNUZ) .value("uint8", DataType::UINT8) .value("float8_e4m3b15", DataType::FLOAT8_E4M3B15); @@ -328,4 +330,4 @@ NB_MODULE(_mscclpp, m) { // ext register_algorithm_collection_builder(m); -} \ No newline at end of file +} diff --git a/python/mscclpp/utils.py b/python/mscclpp/utils.py index 93cd786b..0f0a28d4 100644 --- a/python/mscclpp/utils.py +++ b/python/mscclpp/utils.py @@ -192,12 +192,14 @@ def torch_dtype_to_mscclpp_dtype(dtype: "torch.dtype") -> DataType: return DataType.int32 elif dtype == torch.bfloat16: return DataType.bfloat16 - # Hardware supports either OCP format or FNUZ format for float8. - # Mapping both to the same MSCClPP data type. - elif dtype == torch.float8_e5m2 or dtype == torch.float8_e5m2fnuz: + elif dtype == torch.float8_e5m2: return DataType.float8_e5m2 - elif dtype == torch.float8_e4m3fn or dtype == torch.float8_e4m3fnuz: - return DataType.float8_e4m3 + elif dtype == torch.float8_e5m2fnuz: + return DataType.float8_e5m2fnuz + elif dtype == torch.float8_e4m3fn: + return DataType.float8_e4m3fn + elif dtype == torch.float8_e4m3fnuz: + return DataType.float8_e4m3fnuz elif dtype == torch.uint8: return DataType.uint8 else: diff --git a/python/test/test_fp8_accum.py b/python/test/test_fp8_accum.py index 82981ce1..ba33c085 100644 --- a/python/test/test_fp8_accum.py +++ b/python/test/test_fp8_accum.py @@ -21,6 +21,13 @@ from .mscclpp_mpi import MpiGroup, parametrize_mpi_groups, mpi_group # FP8 E4M3 (hardware) requires SM >= 89 (Ada / Hopper) on NVIDIA GPUs. # On AMD/ROCm (e.g. MI300X), FP8 is supported natively — no skip needed. 
_is_hip = hasattr(cp.cuda.runtime, "is_hip") and cp.cuda.runtime.is_hip +_gcn_arch_name = "" +if _is_hip: + _gcn_arch_name = cp.cuda.runtime.getDeviceProperties(0).get("gcnArchName", b"") + if isinstance(_gcn_arch_name, bytes): + _gcn_arch_name = _gcn_arch_name.decode() + _gcn_arch_name = _gcn_arch_name.split(":", maxsplit=1)[0] +_is_cdna4 = _gcn_arch_name.startswith("gfx95") _skip_fp8 = not _is_hip and int(cp.cuda.Device().compute_capability) < 89 pytestmark = pytest.mark.skipif(_skip_fp8, reason="FP8 accum tests require SM >= 89 on CUDA") @@ -90,7 +97,78 @@ def float_to_e4m3fn(f32_array, chunk_size=65536): # --------------------------------------------------------------------------- -# FP8 E4M3B15 helpers (bias=15, max=0.9375, NaN = exp==15 or bits==0x80) +# FP8 E4M3FNUZ helpers (AMD/ROCm; bias=8, max=240, NaN = bits==0x80, no -0) +# --------------------------------------------------------------------------- + + +def e4m3fnuz_to_float(uint8_array): + """Decode a cupy uint8 array of E4M3FNUZ bit patterns to float32.""" + bits = uint8_array.astype(cp.int32) + sign = (bits >> 7) & 1 + exp = (bits >> 3) & 0xF + mant = bits & 0x7 + + # Normal: (-1)^s * 2^(exp-8) * (1 + mant/8) + normal_val = cp.ldexp(cp.float32(1.0) + mant.astype(cp.float32) / cp.float32(8.0), (exp - 8).astype(cp.int32)) + # Subnormal (exp==0): (-1)^s * 2^(-7) * (mant/8) + subnormal_val = cp.ldexp(mant.astype(cp.float32) / cp.float32(8.0), cp.int32(-7)) + + result = cp.where(exp == 0, subnormal_val, normal_val) + result = cp.where(sign == 1, -result, result) + # Zero is only 0x00; the 0x80 encoding is reserved for NaN under fnuz. + result = cp.where(uint8_array.astype(cp.int32) == 0, cp.float32(0.0), result) + nan_mask = uint8_array.astype(cp.int32) == 0x80 + result = cp.where(nan_mask, cp.float32(float("nan")), result) + return result + + +def float_to_e4m3fnuz(f32_array, chunk_size=65536): + """Encode a cupy float32 array to uint8 E4M3FNUZ bit patterns. 
+ + Same lookup-table approach as float_to_e4m3fn but using the fnuz table. + """ + all_bytes = cp.arange(128, dtype=cp.uint8) + all_floats = e4m3fnuz_to_float(all_bytes) + all_floats = cp.where(cp.isnan(all_floats), cp.float32(float("inf")), all_floats) + + clamped = f32_array.astype(cp.float32) + clamped = cp.clip(clamped, -240.0, 240.0) + signs = (clamped < 0).astype(cp.uint8) + absval = cp.abs(clamped) + + result = cp.zeros(absval.shape, dtype=cp.uint8) + n = absval.size + absval_flat = absval.ravel() + result_flat = result.ravel() + + for start in range(0, n, chunk_size): + end = min(start + chunk_size, n) + chunk = absval_flat[start:end] + diffs = cp.abs(chunk[:, None] - all_floats[None, :]) + result_flat[start:end] = cp.argmin(diffs, axis=1).astype(cp.uint8) + + result = result_flat.reshape(absval.shape) + result = result | (signs << 7) + # 0x80 is NaN under fnuz (no negative zero). Collapse any encoding that + # landed on 0x80 (small negatives quantised to zero magnitude) to 0x00. + result = cp.where(result == 0x80, cp.uint8(0), result) + return result + + +# Platform-aware E4M3 native helpers: ROCm CDNA4 and CUDA use OCP fn; older ROCm uses fnuz. +if _is_hip and not _is_cdna4: + e4m3_native_to_float = e4m3fnuz_to_float + float_to_e4m3_native = float_to_e4m3fnuz + fp8_native_dtype = DataType.float8_e4m3fnuz +else: + e4m3_native_to_float = e4m3fn_to_float + float_to_e4m3_native = float_to_e4m3fn + fp8_native_dtype = DataType.float8_e4m3fn + + +# --------------------------------------------------------------------------- +# FP8 E4M3B15 helpers (bias=15, encode saturates to ±1.75, no NaN) +# Matches Triton's fp8e4b15: all 256 bit patterns are finite. 
# --------------------------------------------------------------------------- @@ -108,11 +186,6 @@ def e4m3b15_to_float(uint8_array): result = cp.where(exp == 0, subnormal_val, normal_val) result = cp.where(sign == 1, -result, result) - # Zero - result = cp.where((exp == 0) & (mant == 0), cp.float32(0.0), result) - # NaN: exp==15 or negative zero (0x80) - nan_mask = (exp == 15) | (uint8_array.astype(cp.int32) == 0x80) - result = cp.where(nan_mask, cp.float32(float("nan")), result) return result @@ -120,18 +193,17 @@ def float_to_e4m3b15(f32_array, chunk_size=65536): """Encode a cupy float32 array to uint8 E4M3B15 bit patterns. Same lookup-table approach as float_to_e4m3fn. + Saturates to ±1.75 (0x7e/0xfe), matching Triton's fp8e4b15. """ # Build lookup table of all 128 positive E4M3B15 values (0x00..0x7F) all_bytes = cp.arange(128, dtype=cp.uint8) all_floats = e4m3b15_to_float(all_bytes) # (128,) float32 - # Mark NaN entries as inf so they're never selected as nearest - all_floats = cp.where(cp.isnan(all_floats), cp.float32(float("inf")), all_floats) - # Clamp input and extract sign - clamped = f32_array.astype(cp.float32) - clamped = cp.clip(clamped, -0.9375, 0.9375) - signs = (clamped < 0).astype(cp.uint8) - absval = cp.abs(clamped) + # Clamp input and extract sign. 
+ values = f32_array.astype(cp.float32) + signs = cp.signbit(values).astype(cp.uint8) + absval = cp.abs(values) + absval = cp.clip(absval, cp.float32(0.0), cp.float32(1.75)) result = cp.zeros(absval.shape, dtype=cp.uint8) n = absval.size @@ -148,8 +220,6 @@ def float_to_e4m3b15(f32_array, chunk_size=65536): # Combine with sign bit result = result_flat.reshape(absval.shape) result = result | (signs << 7) - # Handle exact zero - result = cp.where(absval == 0, cp.uint8(0), result) return result @@ -226,12 +296,6 @@ def test_fp8_e4m3_accum(mpi_group: MpiGroup, algo_name: str, size: int): buf = GpuBuffer(size, dtype=cp.uint8) - accum_configs = [ - ("fp8_native", DataType.float8_e4m3), - ("float16", DataType.float16), - ("float32", DataType.float32), - ] - # rsag_zero_copy and fullmesh need explicit block/thread counts if "rsag" in algo_name: nb = max(1, min(32, size // (world_size * 32))) @@ -243,13 +307,19 @@ def test_fp8_e4m3_accum(mpi_group: MpiGroup, algo_name: str, size: int): nb = 0 nt = 0 + accum_configs = [ + ("fp8_native", fp8_native_dtype), + ("float16", DataType.float16), + ("float32", DataType.float32), + ] + errors = {} for accum_label, accum_dtype in accum_configs: # Generate deterministic per-rank data (use numpy to avoid hipRAND issues on ROCm) rng = np.random.RandomState(42 + rank) src_f32 = cp.asarray(rng.randn(size).astype(np.float32)) src_f32 = cp.clip(src_f32, -240.0, 240.0) - src_fp8 = float_to_e4m3fn(src_f32) + src_fp8 = float_to_e4m3_native(src_f32) # Copy into symmetric buffer buf[:] = src_fp8 @@ -260,12 +330,12 @@ def test_fp8_e4m3_accum(mpi_group: MpiGroup, algo_name: str, size: int): algo, comm_group, buf, - dtype=DataType.float8_e4m3, + dtype=fp8_native_dtype, accum_dtype=accum_dtype, nblocks=nb, nthreads_per_block=nt, ) - result_f32 = e4m3fn_to_float(result) + result_f32 = e4m3_native_to_float(result) # Compute float32 reference: sum all ranks' quantized FP8 inputs in float32 ref_f32 = cp.zeros(size, dtype=cp.float32) @@ -273,12 +343,13 @@ 
def test_fp8_e4m3_accum(mpi_group: MpiGroup, algo_name: str, size: int): rng_r = np.random.RandomState(42 + r) rank_data = cp.asarray(rng_r.randn(size).astype(np.float32)) rank_data = cp.clip(rank_data, -240.0, 240.0) - rank_data_fp8 = float_to_e4m3fn(rank_data) - ref_f32 += e4m3fn_to_float(rank_data_fp8) + rank_data_fp8 = float_to_e4m3_native(rank_data) + ref_f32 += e4m3_native_to_float(rank_data_fp8) - # Compute errors - abs_err = cp.abs(result_f32 - ref_f32) - mean_abs_err = float(cp.mean(abs_err)) + # Compute errors (only on valid, non-NaN entries) + valid = ~cp.isnan(result_f32) & ~cp.isnan(ref_f32) + abs_err = cp.abs(result_f32[valid] - ref_f32[valid]) + mean_abs_err = float(cp.mean(abs_err)) if abs_err.size > 0 else 0.0 errors[accum_label] = mean_abs_err # Reset between runs @@ -341,13 +412,10 @@ def test_fp8_e4m3b15_accum(mpi_group: MpiGroup, algo_name: str, size: int): errors = {} for accum_label, accum_dtype in accum_configs: - # Generate deterministic per-rank random uint8 values in valid e4m3b15 range + # Generate deterministic per-rank random uint8 values covering the full e4m3b15 range. + # All 256 bit patterns are valid (no NaN in this format). 
rng = np.random.RandomState(42 + rank) - raw = cp.asarray(rng.randint(0, 0x78, (size,)).astype(np.uint8)) - signs = cp.asarray(rng.randint(0, 2, (size,)).astype(np.uint8)) << 7 - src_uint8 = raw | signs - # Fix negative zero -> positive zero - src_uint8 = cp.where(src_uint8 == 0x80, cp.uint8(0), src_uint8) + src_uint8 = cp.asarray(rng.randint(0, 256, (size,)).astype(np.uint8)) # Copy into symmetric buffer buf[:] = src_uint8 @@ -371,19 +439,15 @@ def test_fp8_e4m3b15_accum(mpi_group: MpiGroup, algo_name: str, size: int): ref_f32 = cp.zeros(size, dtype=cp.float32) for r in range(world_size): rng_r = np.random.RandomState(42 + r) - raw_r = cp.asarray(rng_r.randint(0, 0x78, (size,)).astype(np.uint8)) - signs_r = cp.asarray(rng_r.randint(0, 2, (size,)).astype(np.uint8)) << 7 - bits_r = raw_r | signs_r - bits_r = cp.where(bits_r == 0x80, cp.uint8(0), bits_r) + bits_r = cp.asarray(rng_r.randint(0, 256, (size,)).astype(np.uint8)) ref_f32 += e4m3b15_to_float(bits_r) # Clamp reference to e4m3b15 representable range - ref_f32 = cp.clip(ref_f32, -0.9375, 0.9375) + ref_f32 = cp.clip(ref_f32, -1.75, 1.75) - # Compute errors (only on valid entries) - valid = ~cp.isnan(result_f32) & ~cp.isnan(ref_f32) - abs_err = cp.abs(result_f32[valid] - ref_f32[valid]) - mean_abs_err = float(cp.mean(abs_err)) if abs_err.size > 0 else 0.0 + # Compute errors + abs_err = cp.abs(result_f32 - ref_f32) + mean_abs_err = float(cp.mean(abs_err)) errors[accum_label] = mean_abs_err algo.reset() diff --git a/src/core/algorithm.cc b/src/core/algorithm.cc index ffa53aa8..c0713daa 100644 --- a/src/core/algorithm.cc +++ b/src/core/algorithm.cc @@ -3,6 +3,7 @@ #include #include +#include #include #include "logger.hpp" @@ -182,13 +183,41 @@ CommResult DslAlgorithm::execute(std::shared_ptr comm, const void* stream); break; #if defined(__FP8_TYPES_EXIST__) - case DataType::FLOAT8_E4M3: - executor->execute(rank, (__fp8_e4m3*)input, (__fp8_e4m3*)output, inputSize, outputSize, DataType::FLOAT8_E4M3, + case 
DataType::FLOAT8_E4M3FN: +#if defined(__FP8_E4M3_IS_FNUZ__) + THROW(EXEC, Error, ErrorCode::InvalidUsage, + "FLOAT8_E4M3FN is not natively supported on this platform; use FLOAT8_E4M3FNUZ"); +#else + executor->execute(rank, (__fp8_e4m3*)input, (__fp8_e4m3*)output, inputSize, outputSize, DataType::FLOAT8_E4M3FN, plan_, stream); +#endif + break; + case DataType::FLOAT8_E4M3FNUZ: +#if defined(__FP8_E4M3_IS_FNUZ__) + executor->execute(rank, (__fp8_e4m3*)input, (__fp8_e4m3*)output, inputSize, outputSize, DataType::FLOAT8_E4M3FNUZ, + plan_, stream); +#else + THROW(EXEC, Error, ErrorCode::InvalidUsage, + "FLOAT8_E4M3FNUZ is not natively supported on this platform; use FLOAT8_E4M3FN"); +#endif break; case DataType::FLOAT8_E5M2: +#if defined(__FP8_E5M2_IS_FNUZ__) + THROW(EXEC, Error, ErrorCode::InvalidUsage, + "FLOAT8_E5M2 is not natively supported on this platform; use FLOAT8_E5M2FNUZ"); +#else executor->execute(rank, (__fp8_e5m2*)input, (__fp8_e5m2*)output, inputSize, outputSize, DataType::FLOAT8_E5M2, plan_, stream); +#endif + break; + case DataType::FLOAT8_E5M2FNUZ: +#if defined(__FP8_E5M2_IS_FNUZ__) + executor->execute(rank, (__fp8_e5m2*)input, (__fp8_e5m2*)output, inputSize, outputSize, DataType::FLOAT8_E5M2FNUZ, + plan_, stream); +#else + THROW(EXEC, Error, ErrorCode::InvalidUsage, + "FLOAT8_E5M2FNUZ is not natively supported on this platform; use FLOAT8_E5M2"); +#endif break; #endif case DataType::FLOAT8_E4M3B15: @@ -230,4 +259,4 @@ std::pair, size_t> getFlagBuffer() { return {ptr, gDefaultFlagCount * sizeof(uint32_t)}; } -} // namespace mscclpp \ No newline at end of file +} // namespace mscclpp diff --git a/src/core/executor/execution_kernel.cu b/src/core/executor/execution_kernel.cu index 28ced77f..d639efb7 100644 --- a/src/core/executor/execution_kernel.cu +++ b/src/core/executor/execution_kernel.cu @@ -78,8 +78,10 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo ); #endif break; - case DataType::FLOAT8_E4M3: + case 
DataType::FLOAT8_E4M3FN: + case DataType::FLOAT8_E4M3FNUZ: case DataType::FLOAT8_E5M2: + case DataType::FLOAT8_E5M2FNUZ: // FP8 is not supported in CUDA execution kernel. break; case DataType::FLOAT8_E4M3B15: diff --git a/src/core/executor/executor.cc b/src/core/executor/executor.cc index bf2caf97..fcecc4dd 100644 --- a/src/core/executor/executor.cc +++ b/src/core/executor/executor.cc @@ -7,7 +7,6 @@ #include #include -#include "debug.h" #include "execution_kernel.hpp" #include "execution_plan.hpp" @@ -509,8 +508,7 @@ Executor::Executor(std::shared_ptr comm, std::shared_ptr def void Executor::execute(int rank, void* sendbuff, void* recvbuff, size_t sendBuffSize, [[maybe_unused]] size_t recvBuffSize, DataType dataType, const ExecutionPlan& plan, cudaStream_t stream, PacketType packetType) { - INFO(MSCCLPP_EXECUTOR, "Starting execution with plan: %s, collective: %s", plan.name().c_str(), - plan.collective().c_str()); + INFO(LogSubsys::EXEC, "Starting execution with plan: ", plan.name(), ", collective: ", plan.collective()); size_t sendMemRange, recvMemRange; CUdeviceptr sendBasePtr, recvBasePtr; MSCCLPP_CUTHROW(cuMemGetAddressRange(&sendBasePtr, &sendMemRange, (CUdeviceptr)sendbuff)); diff --git a/src/core/include/execution_kernel.hpp b/src/core/include/execution_kernel.hpp index 87b88888..cb808bc8 100644 --- a/src/core/include/execution_kernel.hpp +++ b/src/core/include/execution_kernel.hpp @@ -17,6 +17,7 @@ #include #include "execution_common.hpp" +#include "logger.hpp" #include "reduce_kernel.hpp" namespace mscclpp { @@ -876,7 +877,19 @@ class ExecutionKernel { #endif break; #if defined(__FP8_TYPES_EXIST__) - case DataType::FLOAT8_E4M3: + case DataType::FLOAT8_E4M3FN: + case DataType::FLOAT8_E4M3FNUZ: +#if defined(__FP8_E4M3_IS_FNUZ__) + if (dataType == DataType::FLOAT8_E4M3FN) { + THROW(LogSubsys::EXEC, Error, ErrorCode::InvalidUsage, + "FLOAT8_E4M3FN is not natively supported on this platform; use FLOAT8_E4M3FNUZ"); + } +#else + if (dataType == 
DataType::FLOAT8_E4M3FNUZ) { + THROW(LogSubsys::EXEC, Error, ErrorCode::InvalidUsage, + "FLOAT8_E4M3FNUZ is not natively supported on this platform; use FLOAT8_E4M3FN"); + } +#endif executionKernel<__fp8_e4m3, PacketType, ReuseScratch><<>>( rank, (__fp8_e4m3*)src, (__fp8_e4m3*)dst, (__fp8_e4m3*)scratch, scratchOffset, scratchChunkSize, plan, semaphores, localMemoryIdBegin, flag @@ -888,6 +901,18 @@ class ExecutionKernel { #endif break; case DataType::FLOAT8_E5M2: + case DataType::FLOAT8_E5M2FNUZ: +#if defined(__FP8_E5M2_IS_FNUZ__) + if (dataType == DataType::FLOAT8_E5M2) { + THROW(LogSubsys::EXEC, Error, ErrorCode::InvalidUsage, + "FLOAT8_E5M2 is not natively supported on this platform; use FLOAT8_E5M2FNUZ"); + } +#else + if (dataType == DataType::FLOAT8_E5M2FNUZ) { + THROW(LogSubsys::EXEC, Error, ErrorCode::InvalidUsage, + "FLOAT8_E5M2FNUZ is not natively supported on this platform; use FLOAT8_E5M2"); + } +#endif executionKernel<__fp8_e5m2, PacketType, ReuseScratch><<>>( rank, (__fp8_e5m2*)src, (__fp8_e5m2*)dst, (__fp8_e5m2*)scratch, scratchOffset, scratchChunkSize, plan, semaphores, localMemoryIdBegin, flag diff --git a/src/ext/collectives/allreduce/allreduce_packet.cu b/src/ext/collectives/allreduce/allreduce_packet.cu index e2d8ef73..6199f192 100644 --- a/src/ext/collectives/allreduce/allreduce_packet.cu +++ b/src/ext/collectives/allreduce/allreduce_packet.cu @@ -200,7 +200,8 @@ inline std::pair getDefaultBlockNumAndThreadNum(size_t inputSize, int { bool isFp8 = dtype == DataType::FLOAT8_E4M3B15; #if defined(__FP8_TYPES_EXIST__) - isFp8 = isFp8 || dtype == DataType::FLOAT8_E4M3 || dtype == DataType::FLOAT8_E5M2; + isFp8 = isFp8 || dtype == DataType::FLOAT8_E4M3FN || dtype == DataType::FLOAT8_E4M3FNUZ || + dtype == DataType::FLOAT8_E5M2 || dtype == DataType::FLOAT8_E5M2FNUZ; #endif if (isFp8) { if (inputSize < (64 << 10)) { @@ -310,4 +311,4 @@ std::shared_ptr AllreducePacket::build() { } } // namespace collective -} // namespace mscclpp \ No newline at end of 
file +} // namespace mscclpp diff --git a/src/ext/collectives/include/allreduce/common.hpp b/src/ext/collectives/include/allreduce/common.hpp index 1e0e7e69..93b18e26 100644 --- a/src/ext/collectives/include/allreduce/common.hpp +++ b/src/ext/collectives/include/allreduce/common.hpp @@ -101,10 +101,20 @@ AllreduceFunc dispatchByDtype(mscclpp::DataType dtype, mscclpp::DataType accumDt return Adapter::call; #endif #if defined(__FP8_TYPES_EXIST__) - } else if (dtype == mscclpp::DataType::FLOAT8_E4M3) { +#if defined(__FP8_E4M3_IS_FNUZ__) + } else if (dtype == mscclpp::DataType::FLOAT8_E4M3FNUZ) { return dispatchFp8Accum(accumDtype, dtype); +#else + } else if (dtype == mscclpp::DataType::FLOAT8_E4M3FN) { + return dispatchFp8Accum(accumDtype, dtype); +#endif +#if defined(__FP8_E5M2_IS_FNUZ__) + } else if (dtype == mscclpp::DataType::FLOAT8_E5M2FNUZ) { + return dispatchFp8Accum(accumDtype, dtype); +#else } else if (dtype == mscclpp::DataType::FLOAT8_E5M2) { return dispatchFp8Accum(accumDtype, dtype); +#endif #endif } else if (dtype == mscclpp::DataType::FLOAT8_E4M3B15) { return dispatchFp8Accum(accumDtype, dtype); @@ -125,4 +135,4 @@ AllreduceFunc dispatch(ReduceOp op, mscclpp::DataType dtype, mscclpp::DataType a } // namespace collective } // namespace mscclpp -#endif // MSCCLPP_ALLREDUCE_COMMON_HPP_ \ No newline at end of file +#endif // MSCCLPP_ALLREDUCE_COMMON_HPP_ diff --git a/src/ext/nccl/algorithm_selector.cc b/src/ext/nccl/algorithm_selector.cc index 0b9592d7..c94aab34 100644 --- a/src/ext/nccl/algorithm_selector.cc +++ b/src/ext/nccl/algorithm_selector.cc @@ -20,7 +20,8 @@ static bool isNvlsSupportedForDataType(const AlgorithmSelectorConfig& config, Da return false; } - const bool isFp8 = dtype == DataType::FLOAT8_E4M3 || dtype == DataType::FLOAT8_E5M2; + const bool isFp8 = dtype == DataType::FLOAT8_E4M3FN || dtype == DataType::FLOAT8_E4M3FNUZ || + dtype == DataType::FLOAT8_E5M2 || dtype == DataType::FLOAT8_E5M2FNUZ; if (!isFp8) { return nvlsSupported; diff --git 
a/src/ext/nccl/datatype_conversion.hpp b/src/ext/nccl/datatype_conversion.hpp index dcfb645a..a5c74def 100644 --- a/src/ext/nccl/datatype_conversion.hpp +++ b/src/ext/nccl/datatype_conversion.hpp @@ -28,12 +28,21 @@ inline mscclpp::DataType ncclDataTypeToMscclpp(ncclDataType_t dtype) { return mscclpp::DataType::BFLOAT16; #ifdef __FP8_TYPES_EXIST__ case ncclFloat8e4m3: - return mscclpp::DataType::FLOAT8_E4M3; +#if defined(__FP8_E4M3_IS_FNUZ__) + return mscclpp::DataType::FLOAT8_E4M3FNUZ; +#else + return mscclpp::DataType::FLOAT8_E4M3FN; +#endif case ncclFloat8e5m2: +#if defined(__FP8_E5M2_IS_FNUZ__) + return mscclpp::DataType::FLOAT8_E5M2FNUZ; +#else return mscclpp::DataType::FLOAT8_E5M2; +#endif #endif default: - throw mscclpp::Error("Unsupported ncclDataType_t: " + std::to_string(dtype), mscclpp::ErrorCode::InvalidUsage); + THROW(mscclpp::LogSubsys::NCCL, mscclpp::Error, mscclpp::ErrorCode::InvalidUsage, + "Unsupported ncclDataType_t: " + std::to_string(dtype)); } } @@ -41,8 +50,10 @@ inline mscclpp::DataType ncclDataTypeToMscclpp(ncclDataType_t dtype) { inline size_t getDataTypeSize(mscclpp::DataType dtype) { switch (dtype) { case mscclpp::DataType::UINT8: - case mscclpp::DataType::FLOAT8_E4M3: + case mscclpp::DataType::FLOAT8_E4M3FN: + case mscclpp::DataType::FLOAT8_E4M3FNUZ: case mscclpp::DataType::FLOAT8_E5M2: + case mscclpp::DataType::FLOAT8_E5M2FNUZ: case mscclpp::DataType::FLOAT8_E4M3B15: return 1; case mscclpp::DataType::FLOAT16: @@ -72,10 +83,32 @@ static inline ncclDataType_t mscclppToNcclDataType(mscclpp::DataType dtype) { case mscclpp::DataType::BFLOAT16: return ncclBfloat16; #ifdef __FP8_TYPES_EXIST__ - case mscclpp::DataType::FLOAT8_E4M3: +#if defined(__FP8_E4M3_IS_FNUZ__) + case mscclpp::DataType::FLOAT8_E4M3FNUZ: return ncclFloat8e4m3; + case mscclpp::DataType::FLOAT8_E4M3FN: + THROW(mscclpp::LogSubsys::NCCL, mscclpp::Error, mscclpp::ErrorCode::InvalidUsage, + "FLOAT8_E4M3FN is not natively supported on this platform; use FLOAT8_E4M3FNUZ for NCCL 
collectives"); +#else + case mscclpp::DataType::FLOAT8_E4M3FN: + return ncclFloat8e4m3; + case mscclpp::DataType::FLOAT8_E4M3FNUZ: + THROW(mscclpp::LogSubsys::NCCL, mscclpp::Error, mscclpp::ErrorCode::InvalidUsage, + "FLOAT8_E4M3FNUZ is not natively supported on this platform; use FLOAT8_E4M3FN for NCCL collectives"); +#endif +#if defined(__FP8_E5M2_IS_FNUZ__) + case mscclpp::DataType::FLOAT8_E5M2FNUZ: + return ncclFloat8e5m2; + case mscclpp::DataType::FLOAT8_E5M2: + THROW(mscclpp::LogSubsys::NCCL, mscclpp::Error, mscclpp::ErrorCode::InvalidUsage, + "FLOAT8_E5M2 is not natively supported on this platform; use FLOAT8_E5M2FNUZ for NCCL collectives"); +#else case mscclpp::DataType::FLOAT8_E5M2: return ncclFloat8e5m2; + case mscclpp::DataType::FLOAT8_E5M2FNUZ: + THROW(mscclpp::LogSubsys::NCCL, mscclpp::Error, mscclpp::ErrorCode::InvalidUsage, + "FLOAT8_E5M2FNUZ is not natively supported on this platform; use FLOAT8_E5M2 for NCCL collectives"); +#endif #endif case mscclpp::DataType::FLOAT8_E4M3B15: // float8_e4m3b15 has no NCCL equivalent; NCCL cannot reduce this type correctly. @@ -98,4 +131,4 @@ inline mscclpp::ReduceOp ncclRedOpToMscclpp(ncclRedOp_t op) { } } -#endif // MSCCLPP_DATATYPE_CONVERSION_HPP_ \ No newline at end of file +#endif // MSCCLPP_DATATYPE_CONVERSION_HPP_ From 9ec26fa4d11325ca33dd4dca83b99dee9146e6bf Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Mon, 4 May 2026 15:11:47 -0700 Subject: [PATCH 2/2] Reset GPU tokens before reuse (#795) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes a token-reuse bug in `TokenPool` that's independent of MNNVL. ## Bug `TokenPool` hands out 8-byte device-memory slots used as device-semaphore counters. The deleter only cleared the bitmap — the underlying GPU memory was left as-is. When a token was freed and later re-allocated, the new semaphore inherited the previous counter value instead of starting at 0, breaking subsequent `signal()/wait()` math. 
## Fix * Add a synchronous `gpuMemset` host helper (mirrors `gpuMemcpy` / `gpuMemcpyAsync`). * Zero the slot inside the `TokenPool` deleter so recycled tokens hand out a clean counter. The very-first allocation is already zeroed by `gpuCallocPhysical` (`src/core/gpu_utils.cc:227-228`), so first-time tokens are also clean — the deleter only has to handle the recycle case. ## Notes * Public wrapper is named `mscclpp::gpuMemset` (not `mscclpp::memset`) for symmetry with `gpuMemcpy` and to avoid shadowing `std::memset` in TUs that pull the namespace in. * Zeroing happens on release rather than acquire so the cost is paid in the typically less perf-sensitive teardown path. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- include/mscclpp/gpu_utils.hpp | 7 +++++++ src/core/gpu_utils.cc | 7 +++++++ src/core/utils_internal.cc | 3 +++ 3 files changed, 17 insertions(+) diff --git a/include/mscclpp/gpu_utils.hpp b/include/mscclpp/gpu_utils.hpp index ecd13c47..b079e0fd 100644 --- a/include/mscclpp/gpu_utils.hpp +++ b/include/mscclpp/gpu_utils.hpp @@ -165,6 +165,7 @@ void gpuFreePhysical(void* ptr); void gpuMemcpyAsync(void* dst, const void* src, size_t bytes, cudaStream_t stream, cudaMemcpyKind kind = cudaMemcpyDefault); void gpuMemcpy(void* dst, const void* src, size_t bytes, cudaMemcpyKind kind = cudaMemcpyDefault); +void gpuMemset(void* ptr, int value, size_t bytes); /// A template function that allocates memory while ensuring that the memory will be freed when the returned object is /// destroyed. @@ -300,6 +301,12 @@ void gpuMemcpy(T* dst, const T* src, size_t nelems, cudaMemcpyKind kind = cudaMe detail::gpuMemcpy(dst, src, nelems * sizeof(T), kind); } +/// Sets `bytes` of memory at `ptr` to `value` synchronously. +/// @param ptr Destination address. +/// @param value Value to set (interpreted as unsigned char per CUDA semantics). +/// @param bytes Number of bytes to set. 
+inline void gpuMemset(void* ptr, int value, size_t bytes) { detail::gpuMemset(ptr, value, bytes); } + /// Check if NVLink SHARP (NVLS) is supported. /// /// @return True if NVLink SHARP (NVLS) is supported, false otherwise. diff --git a/src/core/gpu_utils.cc b/src/core/gpu_utils.cc index 09d5025d..1ce61322 100644 --- a/src/core/gpu_utils.cc +++ b/src/core/gpu_utils.cc @@ -267,6 +267,13 @@ void gpuMemcpy(void* dst, const void* src, size_t bytes, cudaMemcpyKind kind) { MSCCLPP_CUDATHROW(cudaStreamSynchronize(stream)); } +void gpuMemset(void* ptr, int value, size_t bytes) { + AvoidCudaGraphCaptureGuard cgcGuard; + CudaStreamWithFlags stream(cudaStreamNonBlocking); + MSCCLPP_CUDATHROW(cudaMemsetAsync(ptr, value, bytes, stream)); + MSCCLPP_CUDATHROW(cudaStreamSynchronize(stream)); +} + } // namespace detail bool isNvlsSupported() { diff --git a/src/core/utils_internal.cc b/src/core/utils_internal.cc index 9504a52c..8cc55430 100644 --- a/src/core/utils_internal.cc +++ b/src/core/utils_internal.cc @@ -248,6 +248,9 @@ TokenPool::TokenPool(size_t nToken) : nToken_(nToken) { std::shared_ptr TokenPool::getToken() { auto deleter = [self = shared_from_this()](uint64_t* token) { + // Zero the slot on release so the next getToken() hands out a clean + // semaphore counter (matches a freshly-allocated slot). + mscclpp::gpuMemset(token, 0, sizeof(uint64_t)); size_t index = (token - self->baseAddr_) / UINT64_WIDTH; size_t bit = (token - self->baseAddr_) % UINT64_WIDTH; uint64_t mask = 1UL << bit;