mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 01:10:22 +00:00
merge main
This commit is contained in:
@@ -24,4 +24,7 @@ set_target_properties(mscclpp_py PROPERTIES OUTPUT_NAME _mscclpp)
|
||||
set_target_properties(mscclpp_py PROPERTIES INSTALL_RPATH "\$ORIGIN/lib")
|
||||
target_link_libraries(mscclpp_py PRIVATE dlpack mscclpp mscclpp_collectives ${GPU_LIBRARIES})
|
||||
target_include_directories(mscclpp_py SYSTEM PRIVATE ${GPU_INCLUDE_DIRS})
|
||||
if(MSCCLPP_USE_ROCM)
|
||||
target_compile_definitions(mscclpp_py PRIVATE MSCCLPP_USE_ROCM)
|
||||
endif()
|
||||
install(TARGETS mscclpp_py LIBRARY DESTINATION .)
|
||||
|
||||
@@ -75,15 +75,17 @@ void register_algorithm(nb::module_& m) {
|
||||
[](Algorithm& self, std::shared_ptr<Communicator> comm, uintptr_t input, uintptr_t output,
|
||||
size_t inputSize, size_t outputSize, DataType dtype, ReduceOp op, uintptr_t stream,
|
||||
std::shared_ptr<Executor> executor, int nBlocks, int nThreadsPerBlock, bool symmetricMemory,
|
||||
std::unordered_map<std::string, uintptr_t> extras) {
|
||||
std::unordered_map<std::string, uintptr_t> extras, int32_t accumDtype) {
|
||||
return self.execute(comm, reinterpret_cast<const void*>(input), reinterpret_cast<void*>(output),
|
||||
inputSize, outputSize, dtype, op, reinterpret_cast<cudaStream_t>(stream), executor,
|
||||
nBlocks, nThreadsPerBlock, symmetricMemory, extras);
|
||||
nBlocks, nThreadsPerBlock, symmetricMemory, extras,
|
||||
static_cast<DataType>(accumDtype));
|
||||
},
|
||||
nb::arg("comm"), nb::arg("input"), nb::arg("output"), nb::arg("input_size"), nb::arg("output_size"),
|
||||
nb::arg("dtype"), nb::arg("op") = ReduceOp::NOP, nb::arg("stream") = 0, nb::arg("executor") = nullptr,
|
||||
nb::arg("n_blocks") = 0, nb::arg("n_threads_per_block") = 0, nb::arg("symmetric_memory") = false,
|
||||
nb::arg("extras") = std::unordered_map<std::string, uintptr_t>())
|
||||
nb::arg("extras") = std::unordered_map<std::string, uintptr_t>(),
|
||||
nb::arg("accum_dtype") = static_cast<int32_t>(DataType::AUTO))
|
||||
.def("reset", &Algorithm::reset);
|
||||
|
||||
nb::class_<Algorithm::Constraint>(algorithmClass, "Constraint")
|
||||
|
||||
@@ -47,7 +47,8 @@ void register_core(nb::module_& m) {
|
||||
.value("bfloat16", DataType::BFLOAT16)
|
||||
.value("float8_e4m3", DataType::FLOAT8_E4M3)
|
||||
.value("float8_e5m2", DataType::FLOAT8_E5M2)
|
||||
.value("uint8", DataType::UINT8);
|
||||
.value("uint8", DataType::UINT8)
|
||||
.value("float8_e4m3b15", DataType::FLOAT8_E4M3B15);
|
||||
|
||||
nb::class_<Bootstrap>(m, "CppBootstrap")
|
||||
.def("get_rank", &Bootstrap::getRank)
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/function.h>
|
||||
#include <nanobind/stl/shared_ptr.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
#include <nanobind/stl/unordered_map.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
|
||||
@@ -34,6 +34,19 @@ static DLDataType getDlType(std::string type) {
|
||||
return DLDataType{kDLBfloat, 16, 1};
|
||||
} else if (type == "torch.float16") {
|
||||
return DLDataType{kDLFloat, 16, 1};
|
||||
} else if (type == "torch.float8_e4m3fn") {
|
||||
return DLDataType{kDLFloat8_e4m3fn, 8, 1};
|
||||
} else if (type == "torch.float8_e4m3fnuz") {
|
||||
return DLDataType{kDLFloat8_e4m3fnuz, 8, 1};
|
||||
} else if (type == "torch.float8_e5m2") {
|
||||
return DLDataType{kDLFloat8_e5m2, 8, 1};
|
||||
} else if (type == "torch.float8_e5m2fnuz") {
|
||||
return DLDataType{kDLFloat8_e5m2fnuz, 8, 1};
|
||||
} else if (type == "torch.uint8") {
|
||||
return DLDataType{kDLUInt, 8, 1};
|
||||
} else if (type == "fp8_e4m3b15") {
|
||||
// No standard DLPack code for fp8_e4m3b15; store as raw uint8 bytes.
|
||||
return DLDataType{kDLUInt, 8, 1};
|
||||
} else {
|
||||
throw Error("Unsupported type: " + type, ErrorCode::InvalidUsage);
|
||||
}
|
||||
|
||||
@@ -177,6 +177,7 @@ class Algorithm:
|
||||
nthreads_per_block=0,
|
||||
symmetric_memory: bool = False,
|
||||
extras: Optional[Dict[str, int]] = None,
|
||||
accum_dtype: Optional[CppDataType] = None,
|
||||
) -> int:
|
||||
"""Execute the collective algorithm.
|
||||
|
||||
@@ -194,10 +195,14 @@ class Algorithm:
|
||||
nthreads_per_block: Number of threads per block (0 for auto-selection).
|
||||
symmetric_memory: Whether to use symmetric memory optimization (default: False).
|
||||
extras: Additional algorithm-specific parameters.
|
||||
accum_dtype: Data type for accumulation during reduction. If None, defaults to
|
||||
the same as dtype. Use DataType.float32 for high-precision FP8 accumulation.
|
||||
|
||||
Returns:
|
||||
The result code (0 for success).
|
||||
"""
|
||||
merged_extras = dict(extras) if extras is not None else {}
|
||||
accum_dtype = accum_dtype if accum_dtype is not None else dtype
|
||||
return self._algorithm.execute(
|
||||
comm,
|
||||
int(input_buffer),
|
||||
@@ -211,7 +216,8 @@ class Algorithm:
|
||||
nblocks,
|
||||
nthreads_per_block,
|
||||
symmetric_memory,
|
||||
extras if extras is not None else {},
|
||||
merged_extras,
|
||||
int(accum_dtype),
|
||||
)
|
||||
|
||||
def reset(self):
|
||||
|
||||
@@ -140,7 +140,7 @@ class MemoryChannel:
|
||||
|
||||
for tb_id in tb_list:
|
||||
tb_chunk_id = get_program().setup_remote_chunk(self.src_rank, tb_id, remote_chunk, self.channel_type)
|
||||
tb_channel_ids = get_program().setup_channel(tb, self)
|
||||
tb_channel_ids = get_program().setup_channel(tb_id, self)
|
||||
op = GetOperation(
|
||||
src_buff=[RemoteChunk(src_chunk.buffer, src_chunk.index, src_chunk.size, tb_chunk_id)],
|
||||
dst_buff=[LocalChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size)],
|
||||
|
||||
@@ -745,7 +745,7 @@ class ReduceOperation(BaseOperation):
|
||||
remote_dst_buff=self.remote_dst_buff + other.dst_buff,
|
||||
channel_ids=self.channel_ids,
|
||||
put_channel_ids=self.put_channel_ids + other.channel_ids,
|
||||
channel_type=self.channel_type,
|
||||
channel_type=other.channel_type,
|
||||
reduce_operation=self.reduce_operation,
|
||||
tbg_info=self.tbg_info,
|
||||
packet=self.packet,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
mpi4py==4.1.1
|
||||
cupy==13.6.0
|
||||
mpi4py
|
||||
cupy
|
||||
prettytable
|
||||
netifaces
|
||||
pytest
|
||||
|
||||
397
python/test/test_fp8_accum.py
Normal file
397
python/test/test_fp8_accum.py
Normal file
@@ -0,0 +1,397 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# Correctness test for FP8 allreduce with different accumulation types.
|
||||
#
|
||||
# Verifies that FP8 allreduce with higher-precision accumulation produces
|
||||
# results at least as accurate as native FP8 accumulation, by comparing
|
||||
# against a float32 reference.
|
||||
#
|
||||
# Usage:
|
||||
# mpirun -np 8 pytest python/test/test_fp8_accum.py -v
|
||||
|
||||
import cupy as cp
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mscclpp import CommGroup, GpuBuffer, DataType, ReduceOp, is_nvls_supported
|
||||
from mscclpp.ext import AlgorithmCollectionBuilder
|
||||
from .mscclpp_mpi import MpiGroup, parametrize_mpi_groups, mpi_group
|
||||
|
||||
# FP8 E4M3 (hardware) requires SM >= 89 (Ada / Hopper) on NVIDIA GPUs.
|
||||
# On AMD/ROCm (e.g. MI300X), FP8 is supported natively — no skip needed.
|
||||
_is_hip = hasattr(cp.cuda.runtime, "is_hip") and cp.cuda.runtime.is_hip
|
||||
_skip_fp8 = not _is_hip and int(cp.cuda.Device().compute_capability) < 89
|
||||
pytestmark = pytest.mark.skipif(_skip_fp8, reason="FP8 accum tests require SM >= 89 on CUDA")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FP8 E4M3FN helpers (bias=7, no infinity, NaN = exp=15 & mant=7)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def e4m3fn_to_float(uint8_array):
|
||||
"""Decode a cupy uint8 array of E4M3FN bit patterns to float32."""
|
||||
bits = uint8_array.astype(cp.int32)
|
||||
sign = (bits >> 7) & 1
|
||||
exp = (bits >> 3) & 0xF
|
||||
mant = bits & 0x7
|
||||
|
||||
# Normal: (-1)^s * 2^(exp-7) * (1 + mant/8)
|
||||
normal_val = cp.ldexp(cp.float32(1.0) + mant.astype(cp.float32) / cp.float32(8.0), (exp - 7).astype(cp.int32))
|
||||
# Subnormal (exp==0): (-1)^s * 2^(-6) * (mant/8)
|
||||
subnormal_val = cp.ldexp(mant.astype(cp.float32) / cp.float32(8.0), cp.int32(-6))
|
||||
|
||||
result = cp.where(exp == 0, subnormal_val, normal_val)
|
||||
result = cp.where(sign == 1, -result, result)
|
||||
# Zero
|
||||
result = cp.where((exp == 0) & (mant == 0), cp.float32(0.0), result)
|
||||
# NaN: exp==15 & mant==7
|
||||
nan_mask = (exp == 15) & (mant == 7)
|
||||
result = cp.where(nan_mask, cp.float32(float("nan")), result)
|
||||
return result
|
||||
|
||||
|
||||
def float_to_e4m3fn(f32_array, chunk_size=65536):
|
||||
"""Encode a cupy float32 array to uint8 E4M3FN bit patterns.
|
||||
|
||||
Uses a lookup-table approach: precompute all 128 positive E4M3FN values,
|
||||
then find nearest match per element via chunked broadcast comparison.
|
||||
"""
|
||||
# Build lookup table of all 128 positive E4M3FN values (0x00..0x7F)
|
||||
all_bytes = cp.arange(128, dtype=cp.uint8)
|
||||
all_floats = e4m3fn_to_float(all_bytes) # (128,) float32
|
||||
# Mark NaN entries as inf so they're never selected as nearest
|
||||
all_floats = cp.where(cp.isnan(all_floats), cp.float32(float("inf")), all_floats)
|
||||
|
||||
# Clamp input and extract sign
|
||||
clamped = f32_array.astype(cp.float32)
|
||||
clamped = cp.clip(clamped, -448.0, 448.0)
|
||||
signs = (clamped < 0).astype(cp.uint8)
|
||||
absval = cp.abs(clamped)
|
||||
|
||||
result = cp.zeros(absval.shape, dtype=cp.uint8)
|
||||
n = absval.size
|
||||
absval_flat = absval.ravel()
|
||||
result_flat = result.ravel()
|
||||
|
||||
for start in range(0, n, chunk_size):
|
||||
end = min(start + chunk_size, n)
|
||||
chunk = absval_flat[start:end]
|
||||
# (chunk_size, 128) difference matrix
|
||||
diffs = cp.abs(chunk[:, None] - all_floats[None, :])
|
||||
result_flat[start:end] = cp.argmin(diffs, axis=1).astype(cp.uint8)
|
||||
|
||||
# Combine with sign bit
|
||||
result = result_flat.reshape(absval.shape)
|
||||
result = result | (signs << 7)
|
||||
# Handle exact zero
|
||||
result = cp.where(absval == 0, cp.uint8(0), result)
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FP8 E4M3B15 helpers (bias=15, max=0.9375, NaN = exp==15 or bits==0x80)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def e4m3b15_to_float(uint8_array):
|
||||
"""Decode a cupy uint8 array of E4M3B15 bit patterns to float32."""
|
||||
bits = uint8_array.astype(cp.int32)
|
||||
sign = (bits >> 7) & 1
|
||||
exp = (bits >> 3) & 0xF
|
||||
mant = bits & 0x7
|
||||
|
||||
# Normal: (-1)^s * 2^(exp-15) * (1 + mant/8)
|
||||
normal_val = cp.ldexp(cp.float32(1.0) + mant.astype(cp.float32) / cp.float32(8.0), (exp - 15).astype(cp.int32))
|
||||
# Subnormal (exp==0): (-1)^s * 2^(-14) * (mant/8)
|
||||
subnormal_val = cp.ldexp(mant.astype(cp.float32) / cp.float32(8.0), cp.int32(-14))
|
||||
|
||||
result = cp.where(exp == 0, subnormal_val, normal_val)
|
||||
result = cp.where(sign == 1, -result, result)
|
||||
# Zero
|
||||
result = cp.where((exp == 0) & (mant == 0), cp.float32(0.0), result)
|
||||
# NaN: exp==15 or negative zero (0x80)
|
||||
nan_mask = (exp == 15) | (uint8_array.astype(cp.int32) == 0x80)
|
||||
result = cp.where(nan_mask, cp.float32(float("nan")), result)
|
||||
return result
|
||||
|
||||
|
||||
def float_to_e4m3b15(f32_array, chunk_size=65536):
|
||||
"""Encode a cupy float32 array to uint8 E4M3B15 bit patterns.
|
||||
|
||||
Same lookup-table approach as float_to_e4m3fn.
|
||||
"""
|
||||
# Build lookup table of all 128 positive E4M3B15 values (0x00..0x7F)
|
||||
all_bytes = cp.arange(128, dtype=cp.uint8)
|
||||
all_floats = e4m3b15_to_float(all_bytes) # (128,) float32
|
||||
# Mark NaN entries as inf so they're never selected as nearest
|
||||
all_floats = cp.where(cp.isnan(all_floats), cp.float32(float("inf")), all_floats)
|
||||
|
||||
# Clamp input and extract sign
|
||||
clamped = f32_array.astype(cp.float32)
|
||||
clamped = cp.clip(clamped, -0.9375, 0.9375)
|
||||
signs = (clamped < 0).astype(cp.uint8)
|
||||
absval = cp.abs(clamped)
|
||||
|
||||
result = cp.zeros(absval.shape, dtype=cp.uint8)
|
||||
n = absval.size
|
||||
absval_flat = absval.ravel()
|
||||
result_flat = result.ravel()
|
||||
|
||||
for start in range(0, n, chunk_size):
|
||||
end = min(start + chunk_size, n)
|
||||
chunk = absval_flat[start:end]
|
||||
# (chunk_size, 128) difference matrix
|
||||
diffs = cp.abs(chunk[:, None] - all_floats[None, :])
|
||||
result_flat[start:end] = cp.argmin(diffs, axis=1).astype(cp.uint8)
|
||||
|
||||
# Combine with sign bit
|
||||
result = result_flat.reshape(absval.shape)
|
||||
result = result | (signs << 7)
|
||||
# Handle exact zero
|
||||
result = cp.where(absval == 0, cp.uint8(0), result)
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared test helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def setup_algorithms(mpi_group):
|
||||
"""Build default algorithms and return (comm_group, algo_map, scratch_buf)."""
|
||||
comm_group = CommGroup(mpi_group.comm)
|
||||
scratch = GpuBuffer(1 << 27, dtype=cp.uint8) # 128 MB
|
||||
AlgorithmCollectionBuilder.reset()
|
||||
builder = AlgorithmCollectionBuilder()
|
||||
algorithms = builder.build_default_algorithms(
|
||||
scratch_buffer=scratch.data.ptr,
|
||||
scratch_buffer_size=scratch.nbytes,
|
||||
rank=comm_group.my_rank,
|
||||
)
|
||||
algo_map = {a.name: a for a in algorithms}
|
||||
return comm_group, algo_map, scratch
|
||||
|
||||
|
||||
def run_allreduce(algo, comm_group, buffer, dtype, accum_dtype=None, nblocks=0, nthreads_per_block=0):
|
||||
"""Run allreduce in-place on buffer and return a copy of the result."""
|
||||
ret = algo.execute(
|
||||
comm=comm_group.communicator,
|
||||
input_buffer=buffer.data.ptr,
|
||||
output_buffer=buffer.data.ptr,
|
||||
input_size=buffer.nbytes,
|
||||
output_size=buffer.nbytes,
|
||||
dtype=dtype,
|
||||
op=ReduceOp.SUM,
|
||||
stream=cp.cuda.get_current_stream().ptr,
|
||||
nblocks=nblocks,
|
||||
nthreads_per_block=nthreads_per_block,
|
||||
symmetric_memory=True,
|
||||
accum_dtype=accum_dtype,
|
||||
)
|
||||
cp.cuda.Device().synchronize()
|
||||
assert ret == 0, f"Allreduce failed with error code {ret}"
|
||||
return buffer.copy()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: FP8 E4M3 accumulation correctness
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@parametrize_mpi_groups(8)
|
||||
@pytest.mark.parametrize(
|
||||
"algo_name",
|
||||
[
|
||||
"default_allreduce_packet",
|
||||
"default_allreduce_nvls_packet",
|
||||
"default_allreduce_fullmesh",
|
||||
"default_allreduce_rsag_zero_copy",
|
||||
"default_allreduce_allpair_packet",
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("size", [1024, 4096, 16384, 65536, 262144, 1048576])
|
||||
def test_fp8_e4m3_accum(mpi_group: MpiGroup, algo_name: str, size: int):
|
||||
"""Verify that FP8 E4M3 allreduce with higher-precision accumulation is at
|
||||
least as accurate as native FP8 accumulation, across all algorithm variants."""
|
||||
rank = mpi_group.comm.rank
|
||||
world_size = mpi_group.comm.size
|
||||
|
||||
comm_group, algo_map, scratch = setup_algorithms(mpi_group)
|
||||
if algo_name not in algo_map:
|
||||
pytest.skip(f"{algo_name} not available")
|
||||
if "nvls" in algo_name and not is_nvls_supported():
|
||||
pytest.skip(f"{algo_name} requires NVLS which is not supported on this platform")
|
||||
algo = algo_map[algo_name]
|
||||
|
||||
buf = GpuBuffer(size, dtype=cp.uint8)
|
||||
|
||||
accum_configs = [
|
||||
("fp8_native", DataType.float8_e4m3),
|
||||
("float16", DataType.float16),
|
||||
("float32", DataType.float32),
|
||||
]
|
||||
|
||||
# rsag_zero_copy and fullmesh need explicit block/thread counts
|
||||
if "rsag" in algo_name:
|
||||
nb = max(1, min(32, size // (world_size * 32)))
|
||||
nt = 1024
|
||||
elif "fullmesh" in algo_name:
|
||||
nb = 35
|
||||
nt = 512
|
||||
else:
|
||||
nb = 0
|
||||
nt = 0
|
||||
|
||||
errors = {}
|
||||
for accum_label, accum_dtype in accum_configs:
|
||||
# Generate deterministic per-rank data (use numpy to avoid hipRAND issues on ROCm)
|
||||
rng = np.random.RandomState(42 + rank)
|
||||
src_f32 = cp.asarray(rng.randn(size).astype(np.float32))
|
||||
src_f32 = cp.clip(src_f32, -240.0, 240.0)
|
||||
src_fp8 = float_to_e4m3fn(src_f32)
|
||||
|
||||
# Copy into symmetric buffer
|
||||
buf[:] = src_fp8
|
||||
cp.cuda.Device().synchronize()
|
||||
|
||||
# Run allreduce
|
||||
result = run_allreduce(
|
||||
algo,
|
||||
comm_group,
|
||||
buf,
|
||||
dtype=DataType.float8_e4m3,
|
||||
accum_dtype=accum_dtype,
|
||||
nblocks=nb,
|
||||
nthreads_per_block=nt,
|
||||
)
|
||||
result_f32 = e4m3fn_to_float(result)
|
||||
|
||||
# Compute float32 reference: sum all ranks' quantized FP8 inputs in float32
|
||||
ref_f32 = cp.zeros(size, dtype=cp.float32)
|
||||
for r in range(world_size):
|
||||
rng_r = np.random.RandomState(42 + r)
|
||||
rank_data = cp.asarray(rng_r.randn(size).astype(np.float32))
|
||||
rank_data = cp.clip(rank_data, -240.0, 240.0)
|
||||
rank_data_fp8 = float_to_e4m3fn(rank_data)
|
||||
ref_f32 += e4m3fn_to_float(rank_data_fp8)
|
||||
|
||||
# Compute errors
|
||||
abs_err = cp.abs(result_f32 - ref_f32)
|
||||
mean_abs_err = float(cp.mean(abs_err))
|
||||
errors[accum_label] = mean_abs_err
|
||||
|
||||
# Reset between runs
|
||||
algo.reset()
|
||||
|
||||
# Higher-precision accumulation should be at least as accurate as native fp8
|
||||
assert (
|
||||
errors["float16"] <= errors["fp8_native"] + 1e-6
|
||||
), f"float16 accum ({errors['float16']:.6f}) worse than native ({errors['fp8_native']:.6f})"
|
||||
assert (
|
||||
errors["float32"] <= errors["fp8_native"] + 1e-6
|
||||
), f"float32 accum ({errors['float32']:.6f}) worse than native ({errors['fp8_native']:.6f})"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: FP8 E4M3B15 accumulation correctness
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@parametrize_mpi_groups(8)
|
||||
@pytest.mark.parametrize(
|
||||
"algo_name",
|
||||
[
|
||||
"default_allreduce_packet",
|
||||
"default_allreduce_nvls_packet",
|
||||
"default_allreduce_rsag_zero_copy",
|
||||
"default_allreduce_fullmesh",
|
||||
"default_allreduce_allpair_packet",
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("size", [1024, 4096, 65536])
|
||||
def test_fp8_e4m3b15_accum(mpi_group: MpiGroup, algo_name: str, size: int):
|
||||
"""Verify that FP8 E4M3B15 allreduce with higher-precision accumulation is at
|
||||
least as accurate as native E4M3B15 accumulation."""
|
||||
rank = mpi_group.comm.rank
|
||||
world_size = mpi_group.comm.size
|
||||
|
||||
comm_group, algo_map, scratch = setup_algorithms(mpi_group)
|
||||
if algo_name not in algo_map:
|
||||
pytest.skip(f"{algo_name} not available")
|
||||
if "nvls" in algo_name and not is_nvls_supported():
|
||||
pytest.skip(f"{algo_name} requires NVLS which is not supported on this platform")
|
||||
|
||||
algo = algo_map[algo_name]
|
||||
buf = GpuBuffer(size, dtype=cp.uint8)
|
||||
|
||||
accum_configs = [
|
||||
("e4m3b15_native", DataType.float8_e4m3b15),
|
||||
("float16", DataType.float16),
|
||||
("float32", DataType.float32),
|
||||
]
|
||||
|
||||
# rsag_zero_copy needs explicit block/thread counts, scaled to data size
|
||||
if "rsag" in algo_name:
|
||||
nb = max(1, min(32, size // (world_size * 32)))
|
||||
nt = 1024
|
||||
else:
|
||||
nb = 0
|
||||
nt = 0
|
||||
|
||||
errors = {}
|
||||
for accum_label, accum_dtype in accum_configs:
|
||||
# Generate deterministic per-rank random uint8 values in valid e4m3b15 range
|
||||
rng = np.random.RandomState(42 + rank)
|
||||
raw = cp.asarray(rng.randint(0, 0x78, (size,)).astype(np.uint8))
|
||||
signs = cp.asarray(rng.randint(0, 2, (size,)).astype(np.uint8)) << 7
|
||||
src_uint8 = raw | signs
|
||||
# Fix negative zero -> positive zero
|
||||
src_uint8 = cp.where(src_uint8 == 0x80, cp.uint8(0), src_uint8)
|
||||
|
||||
# Copy into symmetric buffer
|
||||
buf[:] = src_uint8
|
||||
cp.cuda.Device().synchronize()
|
||||
|
||||
# Run allreduce
|
||||
result = run_allreduce(
|
||||
algo,
|
||||
comm_group,
|
||||
buf,
|
||||
dtype=DataType.float8_e4m3b15,
|
||||
accum_dtype=accum_dtype,
|
||||
nblocks=nb,
|
||||
nthreads_per_block=nt,
|
||||
)
|
||||
|
||||
# Decode result
|
||||
result_f32 = e4m3b15_to_float(result)
|
||||
|
||||
# Compute float32 reference
|
||||
ref_f32 = cp.zeros(size, dtype=cp.float32)
|
||||
for r in range(world_size):
|
||||
rng_r = np.random.RandomState(42 + r)
|
||||
raw_r = cp.asarray(rng_r.randint(0, 0x78, (size,)).astype(np.uint8))
|
||||
signs_r = cp.asarray(rng_r.randint(0, 2, (size,)).astype(np.uint8)) << 7
|
||||
bits_r = raw_r | signs_r
|
||||
bits_r = cp.where(bits_r == 0x80, cp.uint8(0), bits_r)
|
||||
ref_f32 += e4m3b15_to_float(bits_r)
|
||||
|
||||
# Clamp reference to e4m3b15 representable range
|
||||
ref_f32 = cp.clip(ref_f32, -0.9375, 0.9375)
|
||||
|
||||
# Compute errors (only on valid entries)
|
||||
valid = ~cp.isnan(result_f32) & ~cp.isnan(ref_f32)
|
||||
abs_err = cp.abs(result_f32[valid] - ref_f32[valid])
|
||||
mean_abs_err = float(cp.mean(abs_err)) if abs_err.size > 0 else 0.0
|
||||
errors[accum_label] = mean_abs_err
|
||||
|
||||
algo.reset()
|
||||
|
||||
# Higher-precision accumulation should be at least as accurate as native
|
||||
assert (
|
||||
errors["float16"] <= errors["e4m3b15_native"] + 1e-8
|
||||
), f"float16 accum ({errors['float16']:.8f}) worse than native ({errors['e4m3b15_native']:.8f})"
|
||||
assert (
|
||||
errors["float32"] <= errors["e4m3b15_native"] + 1e-8
|
||||
), f"float32 accum ({errors['float32']:.8f}) worse than native ({errors['e4m3b15_native']:.8f})"
|
||||
Reference in New Issue
Block a user