Files
mscclpp/python/mscclpp_benchmark/correctness.py
Binyang Li c9f8be64bb Add collective benchmark and correctness check (#814)
- Add unit-test for float8_e4m3b15 data type.
- Add tuner and benchmark for allreduce/allgather algos to verify
correctness and performance.
2026-06-04 09:22:10 -07:00

403 lines
15 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import Any
import cupy as cp
from mpi4py import MPI
# Holds the mscclpp module once _mscclpp() has imported it; None until then.
_mscclpp_module = None


def _mscclpp():
    """Return the ``mscclpp`` module, importing it on the first call and reusing it afterwards."""
    global _mscclpp_module
    if _mscclpp_module is not None:
        return _mscclpp_module
    import mscclpp

    _mscclpp_module = mscclpp
    return _mscclpp_module
@dataclass(frozen=True)
class CorrectnessStats:
    """Immutable summary of a correctness comparison.

    An instance is truthy exactly when its ``ok`` field is, so it can be used directly in a condition.
    """

    # Overall pass/fail flag.
    ok: bool
    # Largest absolute difference between an output element and its expected value.
    max_abs_diff: float = 0.0
    # Average absolute difference over the compared elements.
    mean_abs_diff: float = 0.0
    # Number of elements flagged as mismatching.
    mismatches: int = 0
    # Number of elements compared.
    total: int = 0

    def __bool__(self) -> bool:
        return self.ok
def config_accum_dtype(case: Any) -> Any:
    """Return the case's accumulation dtype, falling back to its mscclpp dtype when that is unset or falsy."""
    spec = case.dtype_spec
    accum = spec.accum_dtype
    return accum if accum else spec.mscclpp_dtype
def fill_case_for_benchmark(case: Any, rank: int) -> None:
    """Fill ``case.input`` with this rank's benchmark data.

    For every collective other than allreduce, ``case.output`` is zeroed before the input is written.
    """
    payload = _encode_correctness_input(case, _benchmark_input_values(case, rank))
    if case.collective != "allreduce":
        case.output.fill(0)
    case.input[...] = payload
def check_correctness(
    comm: Any,
    case: Any,
    config: Any,
    *,
    niter: int = 1,
) -> CorrectnessStats:
    """Run the collective ``niter`` times and compare every output against a locally computed reference.

    Args:
        comm: Communicator object; ``rank``, ``nranks``, ``comm_group.barrier()`` and ``run(case, config)`` are used.
        case: Case whose ``input`` is refilled and whose ``output`` is checked on each iteration.
        config: Forwarded unchanged to ``comm.run``.
        niter: Number of iterations to run; inputs are reseeded per iteration.

    Returns:
        Statistics reduced over ``MPI.COMM_WORLD``: ``ok`` is the logical AND across ranks and is False when any
        run returned non-zero or any element mismatched; ``max_abs_diff`` is the global maximum; ``mean_abs_diff``
        is the global sum of absolute differences divided by the global element count; ``mismatches`` and ``total``
        are global sums. Iterations whose run returned non-zero contribute nothing to the numeric statistics.
    """
    all_ok = True
    local_max_abs_diff = 0.0
    local_sum_abs_diff = 0.0
    local_mismatches = 0
    local_total = 0
    for iteration in range(niter):
        _fill_case_for_correctness(case, comm.rank, iteration)
        comm.comm_group.barrier()
        ret = comm.run(case, config)
        cp.cuda.runtime.deviceSynchronize()
        comm.comm_group.barrier()
        if ret != 0:
            all_ok = False
            continue
        expected, stats_expected = _expected_outputs(case, comm.nranks, iteration)
        iter_stats = _local_diff_stats(case, case.output, expected, comm.nranks, stats_expected=stats_expected)
        # iter_stats.ok is "mismatch count == 0" over the same mask that _compare_output would rebuild and
        # reduce with all(~mask); reusing it avoids recomputing the mask on every iteration.
        local_ok = iter_stats.ok
        all_ok = all_ok and local_ok
        local_max_abs_diff = max(local_max_abs_diff, iter_stats.max_abs_diff)
        # Undo the per-iteration mean so the global mean can be weighted by element count.
        local_sum_abs_diff += iter_stats.mean_abs_diff * iter_stats.total
        local_mismatches += iter_stats.mismatches
        local_total += iter_stats.total
        if not local_ok:
            # The mask is only needed here, to pick the first offending element for the report.
            mismatch = _mismatch_mask(case, case.output, expected, comm.nranks)
            print(
                "not close: "
                f"iter={iteration}, rank={comm.rank}, output={case.output[mismatch][0]}, "
                f"expected={expected[mismatch][0]}, max_abs_diff={iter_stats.max_abs_diff:.6g}, "
                f"mean_abs_diff={iter_stats.mean_abs_diff:.6g}, mismatches={iter_stats.mismatches}/{iter_stats.total}",
                flush=True,
            )
    global_ok = bool(MPI.COMM_WORLD.allreduce(all_ok, op=MPI.LAND))
    global_max_abs_diff = float(MPI.COMM_WORLD.allreduce(local_max_abs_diff, op=MPI.MAX))
    global_sum_abs_diff = float(MPI.COMM_WORLD.allreduce(local_sum_abs_diff, op=MPI.SUM))
    global_mismatches = int(MPI.COMM_WORLD.allreduce(local_mismatches, op=MPI.SUM))
    global_total = int(MPI.COMM_WORLD.allreduce(local_total, op=MPI.SUM))
    global_mean_abs_diff = global_sum_abs_diff / global_total if global_total else 0.0
    return CorrectnessStats(
        ok=global_ok,
        max_abs_diff=global_max_abs_diff,
        mean_abs_diff=global_mean_abs_diff,
        mismatches=global_mismatches,
        total=global_total,
    )
def _fill_case_for_correctness(case: Any, rank: int, iteration: int) -> None:
    """Write this rank's input for *iteration* into ``case.input``; zero ``case.output`` unless allreduce."""
    payload = _encode_correctness_input(case, _correctness_input_values(case, rank, iteration))
    if case.collective != "allreduce":
        case.output.fill(0)
    case.input[...] = payload
def _correctness_input_values(case: Any, rank: int, iteration: int):
    """Generate the unencoded input values that *rank* uses in *iteration*.

    The generator is seeded from ``(rank, iteration)`` only, so the same call reproduces the same values.
    """
    input_shape = case.input.shape
    seeded_rng = cp.random.RandomState(_correctness_seed(rank, iteration))
    return _random_input_values(case, seeded_rng, input_shape)
def _benchmark_input_values(case: Any, rank: int):
    """Generate unencoded benchmark input values from a fixed, rank-dependent seed."""
    seed = 17_000_003 + rank
    return _random_input_values(case, cp.random.RandomState(seed), case.input.shape)
def _random_input_values(case: Any, rng, shape):
if case.dtype_spec.fp8_format is not None:
value_range = _fp8_correctness_input_range(case)
return rng.uniform(-value_range, value_range, size=shape).astype(cp.float32)
if case.dtype_spec.cupy_dtype == cp.int32:
return rng.randint(-1, 2, size=shape).astype(cp.int32)
if case.dtype_spec.cupy_dtype == cp.uint8:
return rng.randint(0, 2, size=shape).astype(cp.uint8)
return rng.uniform(-1.0, 1.0, size=shape).astype(cp.float32)
def _correctness_seed(rank: int, iteration: int) -> int:
return (iteration + 1) * 1_000_003 + rank
def _fp8_correctness_input_range(case: Any) -> float:
    """Return the symmetric bound used when drawing FP8 test inputs.

    Collectives other than allreduce, and cases without an FP8 format, use 1.0. For allreduce the bound is the
    format's largest magnitude divided by the ``MPI.COMM_WORLD`` size, capped at 1.0 (presumably so that a sum over
    all ranks stays representable - confirm against the device kernels).
    """
    if case.collective == "allreduce":
        fmt = case.dtype_spec.fp8_format
        if fmt is not None:
            peak = _fp8_max_abs_value(fmt)
            world_size = max(1, MPI.COMM_WORLD.size)
            return min(1.0, peak / world_size)
    return 1.0
def _encode_correctness_input(case: Any, values):
if case.dtype_spec.fp8_format is not None:
# FP8 buffers are stored as uint8 raw bytes, so a normal astype(uint8) cast would not produce FP8 bits.
return _encode_fp8_values(case.dtype_spec.fp8_format, values)
return values.astype(case.dtype_spec.cupy_dtype)
def _local_diff_stats(case: Any, output, expected, nranks: int, *, stats_expected=None) -> CorrectnessStats:
    """Compute mismatch count and absolute-difference statistics for one output buffer.

    Args:
        case: Case describing the dtype and collective.
        output: Buffer produced by the collective.
        expected: Reference buffer in the same storage representation as *output*.
        nranks: Rank count, forwarded to the tolerance selection.
        stats_expected: Optional numeric reference used for the max/mean difference instead of decoding
            *expected*; it is widened to float64 first.

    Returns:
        Stats for this buffer. For an empty *output* only ``ok`` is set and the other fields keep their defaults.
    """
    mismatch_count = int(cp.count_nonzero(_mismatch_mask(case, output, expected, nranks)).item())
    passed = mismatch_count == 0
    element_count = int(output.size)
    if element_count == 0:
        return CorrectnessStats(ok=passed)
    actual = _stats_values(case, output)
    if stats_expected is None:
        reference = _stats_values(case, expected)
    else:
        reference = stats_expected.astype(cp.float64)
    abs_diff = cp.abs(actual - reference)
    largest = float(cp.max(abs_diff).item())
    average = float(cp.mean(abs_diff).item())
    return CorrectnessStats(
        ok=passed,
        max_abs_diff=largest,
        mean_abs_diff=average,
        mismatches=mismatch_count,
        total=element_count,
    )
def _stats_values(case: Any, values):
# Convert storage buffers into numeric values before computing max/mean diff.
if case.dtype_spec.fp8_format is not None:
return _decode_fp8_array(case.dtype_spec.fp8_format, values)
if cp.issubdtype(values.dtype, cp.floating):
return values.astype(cp.float64)
return values.astype(cp.int64)
def _expected_outputs(case: Any, nranks: int, iteration: int):
    """Compute the reference output for one iteration from every rank's regenerated input.

    Returns:
        ``(expected, stats_expected)``. For allreduce, ``expected`` is the encoded element-wise sum of all ranks'
        inputs; for FP8 the sum follows the configured accumulation dtype and ``stats_expected`` carries the
        accumulated values before re-encoding. For other collectives, ``expected`` is every rank's flattened
        input laid out consecutively in rank order. ``stats_expected`` is None except for FP8 allreduce.
    """
    per_rank = _encoded_rank_inputs(case, nranks, iteration)
    if case.collective != "allreduce":
        gathered = cp.empty_like(case.output)
        chunk_len = case.input.size
        start = 0
        for rank_values in per_rank:
            gathered[start : start + chunk_len] = rank_values.reshape(-1)
            start += chunk_len
        return gathered, None
    if case.dtype_spec.fp8_format is None:
        # Sum in float32 regardless of the storage dtype, then convert back.
        total = sum(rank_values.astype(cp.float32) for rank_values in per_rank)
        return _encode_reduced_output(case, total), None
    accumulated = _expected_fp8_accum_values(case, per_rank)
    return _encode_reduced_output(case, accumulated), accumulated
def _encoded_rank_inputs(case: Any, nranks: int, iteration: int) -> list[Any]:
    """Regenerate the encoded input of every rank for *iteration*, returned in rank order."""
    encoded = []
    for rank in range(nranks):
        raw_values = _correctness_input_values(case, rank, iteration)
        encoded.append(_encode_correctness_input(case, raw_values))
    return encoded
def _expected_fp8_accum_values(case: Any, encoded_inputs: list[Any]):
    """Sum FP8-encoded rank inputs using the case's accumulation dtype and return the result as float32.

    Args:
        case: Case whose ``dtype_spec.fp8_format`` selects the FP8 encoding.
        encoded_inputs: One FP8 byte buffer per rank.

    Raises:
        ValueError: If the case has no FP8 format.
    """
    fmt = case.dtype_spec.fp8_format
    if fmt is None:
        raise ValueError("FP8 format is required")

    def decode(raw):
        return _decode_fp8_array(fmt, raw)

    accum_dtype = config_accum_dtype(case)
    data_types = _mscclpp().DataType
    if accum_dtype == data_types.float16:
        running = cp.zeros_like(decode(encoded_inputs[0]), dtype=cp.float16)
        for raw in encoded_inputs:
            # Round back to float16 after every addition.
            running = (running + decode(raw).astype(cp.float16)).astype(cp.float16)
        return running.astype(cp.float32)
    if accum_dtype == data_types.float32:
        running = cp.zeros_like(decode(encoded_inputs[0]), dtype=cp.float32)
        for raw in encoded_inputs:
            running += decode(raw).astype(cp.float32)
        return running
    # Any other accumulation dtype: keep the running sum in FP8, re-encoding after each addition.
    running_bytes = encoded_inputs[0]
    for raw in encoded_inputs[1:]:
        running_bytes = _encode_fp8_values(fmt, decode(running_bytes) + decode(raw))
    return decode(running_bytes).astype(cp.float32)
def _encode_reduced_output(case: Any, values):
if case.dtype_spec.fp8_format is not None:
return _encode_fp8_values(case.dtype_spec.fp8_format, values)
return values.astype(case.output.dtype)
def _compare_output(case: Any, output, expected, nranks: int) -> bool:
    """Return True when no element of *output* is flagged as a mismatch against *expected*."""
    matches = ~_mismatch_mask(case, output, expected, nranks)
    return bool(cp.all(matches).item())
def _mismatch_mask(case: Any, output, expected, nranks: int):
    """Return a boolean array that is True where *output* disagrees with *expected*.

    When the case has no tolerance the buffers are compared exactly; otherwise both are converted to numeric
    values and compared with ``isclose`` using the case's ``(rtol, atol)``.
    """
    tolerance = _comparison_tolerance(case, nranks)
    if tolerance is not None:
        rtol, atol = tolerance
        close = cp.isclose(_stats_values(case, output), _stats_values(case, expected), rtol=rtol, atol=atol)
        return ~close
    return output != expected
def _comparison_tolerance(case: Any, nranks: int) -> tuple[float, float] | None:
scale = max(1, nranks) if case.collective == "allreduce" else 1
if case.dtype_spec.fp8_format is not None:
accum_dtype = config_accum_dtype(case)
if accum_dtype == _mscclpp().DataType.float32:
return None
atol = _max_fp8_spacing(case.dtype_spec.fp8_format, float(scale))
if accum_dtype == _mscclpp().DataType.float16:
return (0.0, atol)
return (0.0, atol * 2)
if case.dtype_spec.cupy_dtype == cp.float16:
return (1.0e-2, 5.0e-4 * scale)
if case.dtype_spec.cupy_dtype == cp.float32:
return (1.0e-5 * scale, 1.0e-6 * scale)
return None
# Per-format list of (byte, value) pairs for every byte that does not decode to NaN (built by _build_fp8_table).
_FP8_TABLES: dict[str, list[tuple[int, float]]] = {}
# Per-format pair of arrays (values sorted ascending, matching bytes) filled by _fp8_lookup_arrays.
_FP8_LOOKUP_CACHE: dict[str, tuple[Any, Any]] = {}
# Largest gap between adjacent representable values, keyed by (format, max_abs_value); filled by _max_fp8_spacing.
_FP8_SPACING_CACHE: dict[tuple[str, float], float] = {}
def _encode_fp8_values(fp8_format: str, values):
    """Encode numeric values as FP8 bytes of the given format, preserving the input shape.

    ``e4m3b15`` is handled by a dedicated bit-level encoder. Every other format rounds each value to the nearest
    entry of the format's sorted value table; an exact tie goes to the upper neighbour only if its byte is even,
    otherwise to the lower neighbour.
    """
    as_f32 = values.astype(cp.float32)
    if fp8_format == "e4m3b15":
        return _encode_e4m3b15_values(as_f32)
    sorted_values, sorted_bytes = _fp8_lookup_arrays(fp8_format)
    flat = as_f32.ravel()
    # Bracket each value between two adjacent table entries; clipping keeps both indices inside the table.
    hi = cp.clip(cp.searchsorted(sorted_values, flat), 1, sorted_values.size - 1)
    lo = hi - 1
    hi_bytes = sorted_bytes[hi]
    lo_bytes = sorted_bytes[lo]
    gap_above = sorted_values[hi] - flat
    gap_below = flat - sorted_values[lo]
    hi_is_closer = gap_above < gap_below
    tie_goes_to_hi = (gap_above == gap_below) & ((hi_bytes & cp.uint8(1)) == 0)
    chosen = cp.where(hi_is_closer | tie_goes_to_hi, hi_bytes, lo_bytes)
    return chosen.reshape(as_f32.shape)
def _fp8_lookup_arrays(fp8_format: str):
    """Return ``(table_values, table_bytes)`` for nearest-value FP8 encoding, cached per format.

    ``table_values`` is a float32 array of the distinct decodable values in ascending order and ``table_bytes``
    is the uint8 array of the byte chosen for each value.
    """
    cached = _FP8_LOOKUP_CACHE.get(fp8_format)
    if cached is not None:
        return cached
    # Build the byte table only when it is missing. dict.setdefault(key, build()) evaluates build() before the
    # lookup, so it would rebuild the table even when one is already stored.
    if fp8_format not in _FP8_TABLES:
        _FP8_TABLES[fp8_format] = _build_fp8_table(fp8_format)
    # Different bytes can decode to the same value (e.g. +0 and -0); keep the smallest byte per value.
    byte_for_value: dict[float, int] = {}
    for byte, value in _FP8_TABLES[fp8_format]:
        if value not in byte_for_value or byte < byte_for_value[value]:
            byte_for_value[value] = byte
    table = sorted(byte_for_value.items())
    table_values = cp.asarray([value for value, _ in table], dtype=cp.float32)
    table_bytes = cp.asarray([byte for _, byte in table], dtype=cp.uint8)
    _FP8_LOOKUP_CACHE[fp8_format] = (table_values, table_bytes)
    return _FP8_LOOKUP_CACHE[fp8_format]
def _max_fp8_spacing(fp8_format: str, max_abs_value: float) -> float:
    """Return the largest gap between adjacent representable values with magnitude at most *max_abs_value*.

    Results are cached per ``(fp8_format, max_abs_value)``. When fewer than two values fall inside the range the
    spacing is 0.0.
    """
    cache_key = (fp8_format, max_abs_value)
    if cache_key in _FP8_SPACING_CACHE:
        return _FP8_SPACING_CACHE[cache_key]
    # Build the byte table only when it is missing. dict.setdefault(key, build()) evaluates build() before the
    # lookup, so it would rebuild the table even when one is already stored.
    if fp8_format not in _FP8_TABLES:
        _FP8_TABLES[fp8_format] = _build_fp8_table(fp8_format)
    in_range = sorted({value for _, value in _FP8_TABLES[fp8_format] if abs(value) <= max_abs_value})
    # `default` covers the case of fewer than two values, where there are no adjacent pairs.
    spacing = max((right - left for left, right in zip(in_range, in_range[1:])), default=0.0)
    _FP8_SPACING_CACHE[cache_key] = spacing
    return spacing
def _fp8_max_abs_value(fp8_format: str) -> float:
    """Return the largest magnitude among the format's non-NaN values."""
    # Build the byte table only when it is missing. dict.setdefault(key, build()) evaluates build() before the
    # lookup, so every call would decode all 256 bytes again even though the table is already stored.
    if fp8_format not in _FP8_TABLES:
        _FP8_TABLES[fp8_format] = _build_fp8_table(fp8_format)
    return max(abs(value) for _, value in _FP8_TABLES[fp8_format])
def _encode_e4m3b15_values(values):
# Mirrors the device e4m3b15 encode (gpu_data_types.hpp): clamp the fp16 intermediate
# to 0x3F80 (+/-1.875) so the max encodable byte is 0x7F/0xFF.
fp16_bits = values.astype(cp.float16).view(cp.uint16)
abs_fp16 = fp16_bits & cp.uint16(0x7FFF)
abs_fp16 = cp.minimum(abs_fp16, cp.uint16(0x3F80)).astype(cp.uint32)
sign16 = (fp16_bits & cp.uint16(0x8000)).astype(cp.uint32)
adjusted = abs_fp16 * cp.uint32(2) + cp.uint32(0x0080)
return (((sign16 | adjusted) >> cp.uint32(8)) & cp.uint32(0xFF)).astype(cp.uint8)
def _build_fp8_table(fp8_format: str) -> list[tuple[int, float]]:
    """Return ``(byte, value)`` pairs, in byte order, for all 256 bytes whose decoded value is not NaN."""
    decoded = ((byte, _decode_fp8_scalar(fp8_format, byte)) for byte in range(256))
    return [(byte, value) for byte, value in decoded if not math.isnan(value)]
def _decode_fp8_scalar(fp8_format: str, byte: int) -> float:
if fp8_format == "e4m3fnuz" and byte == 0x80:
return float("nan")
sign = -1.0 if byte & 0x80 else 1.0
return sign * _decode_fp8_positive(fp8_format, byte & 0x7F)
def _decode_fp8_positive(fp8_format: str, byte: int) -> float:
exp = (byte >> 3) & 0xF
mant = byte & 0x7
if fp8_format == "e4m3fn" and exp == 0xF and mant == 0x7:
return float("nan")
if exp == 0 and mant == 0:
return 0.0
if fp8_format == "e4m3fn":
return math.ldexp(mant / 8.0, -6) if exp == 0 else math.ldexp(1.0 + mant / 8.0, exp - 7)
if fp8_format == "e4m3fnuz":
return math.ldexp(mant / 8.0, -7) if exp == 0 else math.ldexp(1.0 + mant / 8.0, exp - 8)
if fp8_format == "e4m3b15":
return math.ldexp(mant / 8.0, -14) if exp == 0 else math.ldexp(1.0 + mant / 8.0, exp - 15)
raise ValueError(f"Unknown FP8 format: {fp8_format}")
def _decode_fp8_array(fp8_format: str, values):
bits = values.astype(cp.int32)
sign = (bits >> 7) & 1
exp = (bits >> 3) & 0xF
mant = bits & 0x7
if fp8_format == "e4m3fn":
subnormal = cp.ldexp(mant.astype(cp.float32) / cp.float32(8.0), cp.int32(-6))
normal = cp.ldexp(cp.float32(1.0) + mant.astype(cp.float32) / cp.float32(8.0), exp.astype(cp.int32) - 7)
decoded = cp.where(exp == 0, subnormal, normal)
decoded = cp.where((exp == 0xF) & (mant == 0x7), cp.nan, decoded)
elif fp8_format == "e4m3fnuz":
subnormal = cp.ldexp(mant.astype(cp.float32) / cp.float32(8.0), cp.int32(-7))
normal = cp.ldexp(cp.float32(1.0) + mant.astype(cp.float32) / cp.float32(8.0), exp.astype(cp.int32) - 8)
decoded = cp.where(exp == 0, subnormal, normal)
elif fp8_format == "e4m3b15":
subnormal = cp.ldexp(mant.astype(cp.float32) / cp.float32(8.0), cp.int32(-14))
normal = cp.ldexp(cp.float32(1.0) + mant.astype(cp.float32) / cp.float32(8.0), exp.astype(cp.int32) - 15)
decoded = cp.where(exp == 0, subnormal, normal)
else:
raise ValueError(f"Unknown FP8 format: {fp8_format}")
result = cp.where(sign == 1, -decoded, decoded)
if fp8_format == "e4m3fnuz":
result = cp.where(bits == 0x80, cp.float32(float("nan")), result)
return result