Files
mscclpp/python/mscclpp_benchmark/correctness.py
Binyang Li 9aab9cacc0 support rocm7.2 (#819)
This pull request introduces support for ROCm 7.2 across the build
system, CI pipelines, Docker images, and documentation, while also
improving ROCm FP8 type selection and CUDA IPC memory handle management.
It updates dependencies and configurations to ensure compatibility with
ROCm 7.2, adds new options for native FP8 variants, and refines some
benchmarking and internal memory handling logic.

Please note: there is a known issue with ROCm 7.2 (ROCm 7.2 user-space
libraries + ROCm 6.2 driver) when code executes in this order: allocate memory ->
IPC communication -> allocate new memory -> free the old memory.

---------

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-06-24 16:09:34 -07:00

402 lines
15 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import Any
import cupy as cp
from mpi4py import MPI
# Cached handle to the lazily imported ``mscclpp`` package; stays None until first use.
_mscclpp_module = None


def _mscclpp():
    """Return the ``mscclpp`` package, importing it on the first call only.

    The import is deferred to call time and the module object is stored in the
    module-level ``_mscclpp_module`` so that later calls skip the import.
    """
    global _mscclpp_module
    if _mscclpp_module is not None:
        return _mscclpp_module
    import mscclpp as loaded_module

    _mscclpp_module = loaded_module
    return loaded_module
@dataclass(frozen=True)
class CorrectnessStats:
    """Immutable summary of one correctness comparison.

    Truthiness mirrors ``ok`` so an instance can be used directly in a condition.
    """

    # Pass/fail flag of the comparison.
    ok: bool
    # Largest absolute difference between output and expected values.
    max_abs_diff: float = 0.0
    # Mean absolute difference over the compared elements.
    mean_abs_diff: float = 0.0
    # Number of elements flagged as mismatching.
    mismatches: int = 0
    # Number of elements compared.
    total: int = 0

    def __bool__(self) -> bool:
        """Return the ``ok`` flag so that ``if stats:`` reads naturally."""
        passed = self.ok
        return passed
def config_accum_dtype(case: Any) -> Any:
    """Return the accumulation dtype for ``case``.

    Uses ``case.dtype_spec.accum_dtype`` when it is truthy and otherwise falls back to
    ``case.dtype_spec.mscclpp_dtype``.
    """
    # NOTE(review): the fallback triggers on any falsy accum_dtype, not only None —
    # confirm that no valid accumulation dtype evaluates as falsy.
    explicit_accum = case.dtype_spec.accum_dtype
    if explicit_accum:
        return explicit_accum
    return case.dtype_spec.mscclpp_dtype
def fill_case_for_benchmark(case: Any, rank: int) -> None:
    """Populate the case buffers with rank-seeded random data before a benchmark run.

    The input buffer always receives the encoded random values. For every collective
    other than ``"allreduce"`` the output buffer is zeroed first.
    """
    payload = _encode_correctness_input(case, _benchmark_input_values(case, rank))
    if case.collective != "allreduce":
        case.output.fill(0)
    case.input[...] = payload
def check_correctness(
    comm: Any,
    case: Any,
    config: Any,
    *,
    niter: int = 1,
) -> CorrectnessStats:
    """Run the collective ``niter`` times and compare each output with a locally computed reference.

    Each iteration refills the case buffers with seeded random data, runs
    ``comm.run(case, config)``, synchronizes the device and the communicator group, and
    compares ``case.output`` with the expected result. Per-rank statistics are then
    reduced across ``MPI.COMM_WORLD``.

    Args:
        comm: Communicator wrapper; ``rank``, ``nranks``, ``run`` and
            ``comm_group.barrier`` are used here.
        case: Benchmark case holding the ``input``/``output`` buffers and the
            dtype/collective metadata.
        config: Run configuration forwarded unchanged to ``comm.run``.
        niter: Number of correctness iterations to execute.

    Returns:
        Globally reduced statistics. ``ok`` is true only when every rank passed every
        iteration. An iteration whose ``comm.run`` returns non-zero counts as a failure
        and contributes nothing to the diff statistics.
    """
    all_ok = True
    local_max_abs_diff = 0.0
    local_sum_abs_diff = 0.0
    local_mismatches = 0
    local_total = 0
    for iteration in range(niter):
        _fill_case_for_correctness(case, comm.rank, iteration)
        ret = comm.run(case, config)
        cp.cuda.runtime.deviceSynchronize()
        comm.comm_group.barrier()
        if ret != 0:
            all_ok = False
            continue
        expected, stats_expected = _expected_outputs(case, comm.nranks, iteration)
        iter_stats = _local_diff_stats(case, case.output, expected, comm.nranks, stats_expected=stats_expected)
        # iter_stats.ok is "zero mismatches" over the same mask _compare_output would
        # rebuild, so reuse it instead of running the device-wide comparison again.
        local_ok = iter_stats.ok
        all_ok = all_ok and local_ok
        local_max_abs_diff = max(local_max_abs_diff, iter_stats.max_abs_diff)
        # Re-weight the per-iteration mean so the global mean is taken over all elements.
        local_sum_abs_diff += iter_stats.mean_abs_diff * iter_stats.total
        local_mismatches += iter_stats.mismatches
        local_total += iter_stats.total
        if not local_ok:
            # Only the failure path needs the mask itself, to report the first bad element.
            mismatch = _mismatch_mask(case, case.output, expected, comm.nranks)
            print(
                "not close: "
                f"iter={iteration}, rank={comm.rank}, output={case.output[mismatch][0]}, "
                f"expected={expected[mismatch][0]}, max_abs_diff={iter_stats.max_abs_diff:.6g}, "
                f"mean_abs_diff={iter_stats.mean_abs_diff:.6g}, mismatches={iter_stats.mismatches}/{iter_stats.total}",
                flush=True,
            )
    global_ok = bool(MPI.COMM_WORLD.allreduce(all_ok, op=MPI.LAND))
    global_max_abs_diff = float(MPI.COMM_WORLD.allreduce(local_max_abs_diff, op=MPI.MAX))
    global_sum_abs_diff = float(MPI.COMM_WORLD.allreduce(local_sum_abs_diff, op=MPI.SUM))
    global_mismatches = int(MPI.COMM_WORLD.allreduce(local_mismatches, op=MPI.SUM))
    global_total = int(MPI.COMM_WORLD.allreduce(local_total, op=MPI.SUM))
    global_mean_abs_diff = global_sum_abs_diff / global_total if global_total else 0.0
    return CorrectnessStats(
        ok=global_ok,
        max_abs_diff=global_max_abs_diff,
        mean_abs_diff=global_mean_abs_diff,
        mismatches=global_mismatches,
        total=global_total,
    )
def _fill_case_for_correctness(case: Any, rank: int, iteration: int) -> None:
    """Load the case buffers with the random data seeded for ``rank`` and ``iteration``.

    The encoded values are written to the input buffer. For every collective other
    than ``"allreduce"`` the output buffer is zeroed beforehand.
    """
    payload = _encode_correctness_input(case, _correctness_input_values(case, rank, iteration))
    if case.collective != "allreduce":
        case.output.fill(0)
    case.input[...] = payload
def _correctness_input_values(case: Any, rank: int, iteration: int):
    """Generate the unencoded random input for ``rank`` at ``iteration``.

    The generator is seeded from ``(rank, iteration)``, so any rank can reproduce
    another rank's data.
    """
    seeded_rng = cp.random.RandomState(_correctness_seed(rank, iteration))
    return _random_input_values(case, seeded_rng, case.input.shape)
def _benchmark_input_values(case: Any, rank: int):
    """Generate the unencoded random benchmark input for ``rank``, using a fixed per-rank seed."""
    seed = 17_000_003 + rank
    generator = cp.random.RandomState(seed)
    return _random_input_values(case, generator, case.input.shape)
def _random_input_values(case: Any, rng, shape):
    """Draw random values of ``shape`` from ``rng``, using a distribution chosen by the case dtype.

    * FP8 cases: float32 uniform in ``[-r, r]`` with ``r`` from ``_fp8_correctness_input_range``.
    * int32 cases: ``rng.randint(-1, 2)`` as int32.
    * uint8 cases: ``rng.randint(0, 2)`` as uint8.
    * everything else: float32 uniform in ``[-1, 1]``.
    """
    spec = case.dtype_spec
    if spec.fp8_format is not None:
        limit = _fp8_correctness_input_range(case)
        samples = rng.uniform(-limit, limit, size=shape)
        return samples.astype(cp.float32)
    storage_dtype = spec.cupy_dtype
    if storage_dtype == cp.int32:
        return rng.randint(-1, 2, size=shape).astype(cp.int32)
    if storage_dtype == cp.uint8:
        return rng.randint(0, 2, size=shape).astype(cp.uint8)
    samples = rng.uniform(-1.0, 1.0, size=shape)
    return samples.astype(cp.float32)
def _correctness_seed(rank: int, iteration: int) -> int:
return (iteration + 1) * 1_000_003 + rank
def _fp8_correctness_input_range(case: Any) -> float:
    """Return the half-width of the uniform input range used for FP8 correctness data.

    For allreduce with an FP8 format the range is the format's largest magnitude divided
    by the ``MPI.COMM_WORLD`` size, capped at 1.0. Every other case uses 1.0.
    """
    if case.collective != "allreduce" or case.dtype_spec.fp8_format is None:
        return 1.0
    largest_magnitude = _fp8_max_abs_value(case.dtype_spec.fp8_format)
    world_size = max(1, MPI.COMM_WORLD.size)
    return min(1.0, largest_magnitude / world_size)
def _encode_correctness_input(case: Any, values):
    """Convert generated values into the storage representation of the case's input buffer."""
    fmt = case.dtype_spec.fp8_format
    if fmt is None:
        return values.astype(case.dtype_spec.cupy_dtype)
    # FP8 buffers hold raw bytes in uint8 storage; a plain astype(uint8) would truncate
    # the numbers instead of producing FP8 bit patterns.
    return _encode_fp8_values(fmt, values)
def _local_diff_stats(case: Any, output, expected, nranks: int, *, stats_expected=None) -> CorrectnessStats:
    """Compute this rank's mismatch count and absolute-difference statistics.

    Args:
        case: Benchmark case supplying dtype metadata.
        output: Buffer produced by the collective.
        expected: Reference buffer in the same storage representation as ``output``.
        nranks: Number of ranks, forwarded to the tolerance selection.
        stats_expected: Optional numeric reference used for the diff statistics in place
            of decoding ``expected``; it is cast to float64.

    Returns:
        Statistics whose ``ok`` is true when no element is flagged as a mismatch. For an
        empty ``output`` only ``ok`` is set and the other fields keep their defaults.
    """
    mask = _mismatch_mask(case, output, expected, nranks)
    bad_count = int(cp.count_nonzero(mask).item())
    passed = bad_count == 0
    element_count = int(output.size)
    if element_count == 0:
        return CorrectnessStats(ok=passed)
    actual = _stats_values(case, output)
    if stats_expected is None:
        reference = _stats_values(case, expected)
    else:
        reference = stats_expected.astype(cp.float64)
    deviation = cp.abs(actual - reference)
    largest = float(cp.max(deviation).item())
    average = float(cp.mean(deviation).item())
    return CorrectnessStats(
        ok=passed,
        max_abs_diff=largest,
        mean_abs_diff=average,
        mismatches=bad_count,
        total=element_count,
    )
def _stats_values(case: Any, values):
    """Turn a storage buffer into numeric values suitable for max/mean difference computation.

    FP8 cases are decoded from their raw bytes. Floating-point buffers are widened to
    float64 and all remaining buffers to int64.
    """
    fmt = case.dtype_spec.fp8_format
    if fmt is not None:
        return _decode_fp8_array(fmt, values)
    widened_dtype = cp.float64 if cp.issubdtype(values.dtype, cp.floating) else cp.int64
    return values.astype(widened_dtype)
def _expected_outputs(case: Any, nranks: int, iteration: int):
    """Build the reference output for ``iteration`` by regenerating every rank's input locally.

    Returns:
        A pair ``(expected, stats_expected)``. ``expected`` is in the output buffer's
        storage representation. ``stats_expected`` is the numeric accumulation result
        for FP8 allreduce and ``None`` otherwise.
    """
    per_rank = _encoded_rank_inputs(case, nranks, iteration)
    if case.collective != "allreduce":
        # Non-reducing collectives: rank r's flattened input occupies the r-th chunk of the output.
        gathered = cp.empty_like(case.output)
        chunk = case.input.size
        for rank, values in enumerate(per_rank):
            start = rank * chunk
            gathered[start : start + chunk] = values.reshape(-1)
        return gathered, None
    if case.dtype_spec.fp8_format is not None:
        accumulated = _expected_fp8_accum_values(case, per_rank)
        return _encode_reduced_output(case, accumulated), accumulated
    # Non-FP8 allreduce: sum all ranks in float32, then cast to the output dtype.
    total = 0
    for values in per_rank:
        total = total + values.astype(cp.float32)
    return _encode_reduced_output(case, total), None
def _encoded_rank_inputs(case: Any, nranks: int, iteration: int) -> list[Any]:
    """Return the encoded input of every rank ``0..nranks-1`` for ``iteration``, in rank order."""
    encoded_per_rank = []
    for peer in range(nranks):
        raw_values = _correctness_input_values(case, peer, iteration)
        encoded_per_rank.append(_encode_correctness_input(case, raw_values))
    return encoded_per_rank
def _expected_fp8_accum_values(case: Any, encoded_inputs: list[Any]):
    """Reduce FP8-encoded per-rank inputs the way the configured accumulation dtype would.

    * float16 accumulation: decode each input, add in float16, rounding to float16 after every add.
    * float32 accumulation: decode each input and add in float32.
    * any other accumulation dtype: add pairwise and re-encode to FP8 after every add.

    Args:
        case: Benchmark case; ``dtype_spec.fp8_format`` must be set.
        encoded_inputs: FP8 byte buffers, one per rank (at least one entry is required).

    Returns:
        The accumulated values as a float32 array.

    Raises:
        ValueError: If the case has no FP8 format.
    """
    fp8_format = case.dtype_spec.fp8_format
    if fp8_format is None:
        raise ValueError("FP8 format is required")
    accum_dtype = config_accum_dtype(case)
    if accum_dtype == _mscclpp().DataType.float16:
        # Size the accumulator from the input shape; decoding a whole buffer only to
        # learn its shape would be wasted device work.
        acc = cp.zeros(encoded_inputs[0].shape, dtype=cp.float16)
        for values in encoded_inputs:
            acc = (acc + _decode_fp8_array(fp8_format, values).astype(cp.float16)).astype(cp.float16)
        return acc.astype(cp.float32)
    if accum_dtype == _mscclpp().DataType.float32:
        acc = cp.zeros(encoded_inputs[0].shape, dtype=cp.float32)
        for values in encoded_inputs:
            acc += _decode_fp8_array(fp8_format, values).astype(cp.float32)
        return acc
    # Remaining case: the running sum itself is kept in FP8, so re-encode after each step.
    acc = encoded_inputs[0]
    for values in encoded_inputs[1:]:
        acc = _encode_fp8_values(fp8_format, _decode_fp8_array(fp8_format, acc) + _decode_fp8_array(fp8_format, values))
    return _decode_fp8_array(fp8_format, acc).astype(cp.float32)
def _encode_reduced_output(case: Any, values):
    """Convert reduced numeric values into the storage representation of the output buffer."""
    fmt = case.dtype_spec.fp8_format
    if fmt is None:
        return values.astype(case.output.dtype)
    return _encode_fp8_values(fmt, values)
def _compare_output(case: Any, output, expected, nranks: int) -> bool:
    """Return True when no element of ``output`` is flagged as a mismatch against ``expected``."""
    mask = _mismatch_mask(case, output, expected, nranks)
    any_mismatch = bool(cp.any(mask).item())
    return not any_mismatch
def _mismatch_mask(case: Any, output, expected, nranks: int):
    """Return a boolean array that is True wherever ``output`` disagrees with ``expected``.

    When ``_comparison_tolerance`` yields ``None`` the raw buffers are compared for exact
    equality. Otherwise both are converted to numeric values and compared with
    ``cp.isclose`` using the returned ``(rtol, atol)``.
    """
    tolerance = _comparison_tolerance(case, nranks)
    if tolerance is not None:
        rtol, atol = tolerance
        within_tolerance = cp.isclose(
            _stats_values(case, output),
            _stats_values(case, expected),
            rtol=rtol,
            atol=atol,
        )
        return ~within_tolerance
    return output != expected
def _comparison_tolerance(case: Any, nranks: int) -> tuple[float, float] | None:
    """Choose the ``(rtol, atol)`` pair for comparing outputs, or ``None`` for exact comparison.

    Tolerances scale with the number of ranks for allreduce and use a scale of 1 for
    other collectives.

    * FP8 with float32 accumulation: ``None``.
    * FP8 with float16 accumulation: absolute tolerance of one maximum FP8 spacing
      within ``[-scale, scale]``.
    * FP8 with any other accumulation dtype: twice that spacing.
    * float16 / float32 storage: fixed relative and absolute tolerances.
    * every other dtype: ``None``.
    """
    scale = 1
    if case.collective == "allreduce":
        scale = max(1, nranks)
    spec = case.dtype_spec
    if spec.fp8_format is not None:
        accum = config_accum_dtype(case)
        data_types = _mscclpp().DataType
        if accum == data_types.float32:
            return None
        spacing = _max_fp8_spacing(spec.fp8_format, float(scale))
        if accum == data_types.float16:
            return (0.0, spacing)
        return (0.0, spacing * 2)
    storage_dtype = spec.cupy_dtype
    if storage_dtype == cp.float16:
        return (1.0e-2, 5.0e-4 * scale)
    if storage_dtype == cp.float32:
        return (1.0e-5 * scale, 1.0e-6 * scale)
    return None
# Per-format list of (byte, decoded value) pairs as produced by _build_fp8_table.
_FP8_TABLES: dict[str, list[tuple[int, float]]] = {}
# Per-format pair of device arrays (sorted values, matching bytes) filled by _fp8_lookup_arrays.
_FP8_LOOKUP_CACHE: dict[str, tuple[Any, Any]] = {}
# Memoized results of _max_fp8_spacing, keyed by (format, max_abs_value).
_FP8_SPACING_CACHE: dict[tuple[str, float], float] = {}
def _encode_fp8_values(fp8_format: str, values):
    """Encode numeric ``values`` as FP8 bytes of ``fp8_format``, keeping the input shape.

    ``e4m3b15`` is delegated to its dedicated bit-level encoder. Other formats round
    each value to the nearest entry of the format's lookup table, resolving exact ties
    in favour of the entry whose byte is even.
    """
    as_f32 = values.astype(cp.float32)
    if fp8_format == "e4m3b15":
        return _encode_e4m3b15_values(as_f32)
    grid, grid_bytes = _fp8_lookup_arrays(fp8_format)
    flat = as_f32.ravel()
    # Bracket every value between two adjacent table entries; clipping keeps both
    # indices valid for values outside the table range.
    above = cp.clip(cp.searchsorted(grid, flat), 1, grid.size - 1)
    below = above - 1
    gap_above = grid[above] - flat
    gap_below = flat - grid[below]
    byte_above = grid_bytes[above]
    # On an exact tie the upper neighbour wins only if its byte is even.
    tie_goes_up = (gap_above == gap_below) & ((byte_above & cp.uint8(1)) == 0)
    take_above = (gap_above < gap_below) | tie_goes_up
    chosen = cp.where(take_above, byte_above, grid_bytes[below])
    return chosen.reshape(as_f32.shape)
def _fp8_lookup_arrays(fp8_format: str):
    """Return ``(table_values, table_bytes)`` device arrays for nearest-value FP8 lookup.

    ``table_values`` is a sorted float32 array of the distinct values of ``fp8_format``
    and ``table_bytes`` holds the uint8 byte for each of them. The result is cached
    per format.
    """
    if fp8_format in _FP8_LOOKUP_CACHE:
        return _FP8_LOOKUP_CACHE[fp8_format]
    # Look the table up explicitly: dict.setdefault(key, build()) would evaluate build()
    # before the call and rebuild the table even when it is already cached.
    entries = _FP8_TABLES.get(fp8_format)
    if entries is None:
        entries = _FP8_TABLES[fp8_format] = _build_fp8_table(fp8_format)
    # Different bytes can decode to the same value (e.g. +0 and -0); keep the smallest byte per value.
    byte_for_value: dict[float, int] = {}
    for byte, value in entries:
        if value not in byte_for_value or byte < byte_for_value[value]:
            byte_for_value[value] = byte
    table = sorted(byte_for_value.items())
    table_values = cp.asarray([value for value, _ in table], dtype=cp.float32)
    table_bytes = cp.asarray([byte for _, byte in table], dtype=cp.uint8)
    _FP8_LOOKUP_CACHE[fp8_format] = (table_values, table_bytes)
    return _FP8_LOOKUP_CACHE[fp8_format]
def _max_fp8_spacing(fp8_format: str, max_abs_value: float) -> float:
    """Return the largest gap between adjacent representable values within ``[-max_abs_value, max_abs_value]``.

    Returns 0.0 when fewer than two representable values fall in that range. Results
    are cached per ``(fp8_format, max_abs_value)``.
    """
    cache_key = (fp8_format, max_abs_value)
    if cache_key in _FP8_SPACING_CACHE:
        return _FP8_SPACING_CACHE[cache_key]
    # Look the table up explicitly: dict.setdefault(key, build()) would evaluate build()
    # before the call and rebuild the table even when it is already cached.
    entries = _FP8_TABLES.get(fp8_format)
    if entries is None:
        entries = _FP8_TABLES[fp8_format] = _build_fp8_table(fp8_format)
    values = sorted({value for _, value in entries if abs(value) <= max_abs_value})
    # With fewer than two values the pairing is empty and the default applies.
    spacing = max((right - left for left, right in zip(values, values[1:])), default=0.0)
    _FP8_SPACING_CACHE[cache_key] = spacing
    return spacing
def _fp8_max_abs_value(fp8_format: str) -> float:
    """Return the largest magnitude among the non-NaN values representable in ``fp8_format``."""
    # Look the table up explicitly: dict.setdefault(key, build()) would evaluate build()
    # before the call, so every invocation would decode all 256 bytes again.
    entries = _FP8_TABLES.get(fp8_format)
    if entries is None:
        entries = _FP8_TABLES[fp8_format] = _build_fp8_table(fp8_format)
    return max(abs(value) for _, value in entries)
def _encode_e4m3b15_values(values):
    """Encode ``values`` as e4m3b15 bytes by repacking their float16 bit patterns.

    Mirrors the device e4m3b15 encode (gpu_data_types.hpp): the float16 magnitude is
    clamped to 0x3F80 (+/-1.875) so the largest encodable byte is 0x7F/0xFF.
    """
    half_bits = values.astype(cp.float16).view(cp.uint16)
    sign_field = (half_bits & cp.uint16(0x8000)).astype(cp.uint32)
    magnitude = half_bits & cp.uint16(0x7FFF)
    clamped = cp.minimum(magnitude, cp.uint16(0x3F80)).astype(cp.uint32)
    # Doubling shifts the magnitude up one bit; adding 0x0080 rounds before the low
    # eight bits are discarded by the shift below.
    rounded = clamped * cp.uint32(2) + cp.uint32(0x0080)
    packed = (sign_field | rounded) >> cp.uint32(8)
    return (packed & cp.uint32(0xFF)).astype(cp.uint8)
def _build_fp8_table(fp8_format: str) -> list[tuple[int, float]]:
    """Decode all 256 byte values of ``fp8_format`` and return ``(byte, value)`` pairs, omitting NaNs."""
    decoded = (_decode_fp8_scalar(fp8_format, code) for code in range(256))
    return [(code, number) for code, number in enumerate(decoded) if not math.isnan(number)]
def _decode_fp8_scalar(fp8_format: str, byte: int) -> float:
    """Decode one FP8 byte to a Python float, applying the sign bit to the decoded magnitude.

    In ``e4m3fnuz`` the byte 0x80 decodes to NaN.
    """
    is_fnuz_nan = fp8_format == "e4m3fnuz" and byte == 0x80
    if is_fnuz_nan:
        return float("nan")
    magnitude = _decode_fp8_positive(fp8_format, byte & 0x7F)
    if byte & 0x80:
        return -magnitude
    return magnitude
def _decode_fp8_positive(fp8_format: str, byte: int) -> float:
exp = (byte >> 3) & 0xF
mant = byte & 0x7
if fp8_format == "e4m3fn" and exp == 0xF and mant == 0x7:
return float("nan")
if exp == 0 and mant == 0:
return 0.0
if fp8_format == "e4m3fn":
return math.ldexp(mant / 8.0, -6) if exp == 0 else math.ldexp(1.0 + mant / 8.0, exp - 7)
if fp8_format == "e4m3fnuz":
return math.ldexp(mant / 8.0, -7) if exp == 0 else math.ldexp(1.0 + mant / 8.0, exp - 8)
if fp8_format == "e4m3b15":
return math.ldexp(mant / 8.0, -14) if exp == 0 else math.ldexp(1.0 + mant / 8.0, exp - 15)
raise ValueError(f"Unknown FP8 format: {fp8_format}")
def _decode_fp8_array(fp8_format: str, values):
    """Decode an array of FP8 bytes into numeric values, elementwise.

    Args:
        fp8_format: One of ``"e4m3fn"``, ``"e4m3fnuz"`` or ``"e4m3b15"``.
        values: Array of raw FP8 bytes.  # presumably uint8 storage — confirm against callers

    Returns:
        Array of decoded values with the same shape as ``values``. The all-ones
        ``e4m3fn`` magnitude and the ``e4m3fnuz`` byte 0x80 decode to NaN.

    Raises:
        ValueError: If ``fp8_format`` is unknown.
    """
    raw = values.astype(cp.int32)
    sign_bit = (raw >> 7) & 1
    exponent = (raw >> 3) & 0xF
    mantissa = raw & 0x7
    # The three formats share one layout and differ only in exponent bias.
    exponent_bias = {"e4m3fn": 7, "e4m3fnuz": 8, "e4m3b15": 15}.get(fp8_format)
    if exponent_bias is None:
        raise ValueError(f"Unknown FP8 format: {fp8_format}")
    # Subnormals have no implicit leading one and use the fixed exponent 1 - bias.
    subnormal = cp.ldexp(mantissa.astype(cp.float32) / cp.float32(8.0), cp.int32(1 - exponent_bias))
    normal = cp.ldexp(
        cp.float32(1.0) + mantissa.astype(cp.float32) / cp.float32(8.0),
        exponent.astype(cp.int32) - exponent_bias,
    )
    magnitude = cp.where(exponent == 0, subnormal, normal)
    if fp8_format == "e4m3fn":
        magnitude = cp.where((exponent == 0xF) & (mantissa == 0x7), cp.nan, magnitude)
    decoded = cp.where(sign_bit == 1, -magnitude, magnitude)
    if fp8_format == "e4m3fnuz":
        decoded = cp.where(raw == 0x80, cp.float32(float("nan")), decoded)
    return decoded