mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-06-29 02:47:23 +00:00
This pull request introduces support for ROCm 7.2 across the build system, CI pipelines, Docker images, and documentation, while also improving ROCm FP8 type selection and CUDA IPC memory handle management. It updates dependencies and configurations to ensure compatibility with ROCm 7.2, adds new options for native FP8 variants, and refines some benchmarking and internal memory handling logic. Please note: there is an issue in ROCm 7.2 (ROCm 7.2 user library + ROCm 6.2 driver) when code executes in this order: allocate memory -> IPC communication -> allocate new memory -> free old memory. --------- Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
402 lines
15 KiB
Python
402 lines
15 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT License.
|
|
|
|
from __future__ import annotations
|
|
|
|
import math
|
|
from dataclasses import dataclass
|
|
from typing import Any
|
|
|
|
import cupy as cp
|
|
from mpi4py import MPI
|
|
|
|
# Cache for the lazily imported ``mscclpp`` module; populated by ``_mscclpp()`` on first use.
_mscclpp_module = None
|
|
|
|
|
|
def _mscclpp():
    """Return the ``mscclpp`` module, importing it on the first call only.

    The import is deferred to call time and the module object is remembered
    in the module-level ``_mscclpp_module`` so later calls skip the import.
    """
    global _mscclpp_module
    if _mscclpp_module is not None:
        return _mscclpp_module

    import mscclpp

    _mscclpp_module = mscclpp
    return _mscclpp_module
|
|
|
|
|
|
@dataclass(frozen=True)
class CorrectnessStats:
    """Immutable summary of a correctness comparison.

    Instances are truthy exactly when ``ok`` is true, so a result can be used
    directly in an ``if`` statement.
    """

    # Overall pass/fail verdict.
    ok: bool
    # Largest absolute difference observed between output and expectation.
    max_abs_diff: float = 0.0
    # Average absolute difference over all compared elements.
    mean_abs_diff: float = 0.0
    # Number of elements flagged as mismatching.
    mismatches: int = 0
    # Number of elements compared.
    total: int = 0

    def __bool__(self) -> bool:
        """Expose the verdict as the truth value of the record."""
        return self.ok
|
|
|
|
|
|
def config_accum_dtype(case: Any) -> Any:
    """Return the accumulation dtype configured for ``case``.

    Uses ``case.dtype_spec.accum_dtype`` when it is truthy and otherwise falls
    back to ``case.dtype_spec.mscclpp_dtype``.
    """
    explicit = case.dtype_spec.accum_dtype
    if explicit:
        return explicit
    return case.dtype_spec.mscclpp_dtype
|
|
|
|
|
|
def fill_case_for_benchmark(case: Any, rank: int) -> None:
    """Load this rank's benchmark input data into ``case.input``.

    The values come from ``_benchmark_input_values`` and are converted to the
    buffer's storage encoding by ``_encode_correctness_input``. For every
    collective other than ``"allreduce"`` the output buffer is zeroed before
    the input is written.
    """
    encoded = _encode_correctness_input(case, _benchmark_input_values(case, rank))
    if case.collective != "allreduce":
        case.output.fill(0)
    case.input[...] = encoded
|
|
|
|
|
|
def check_correctness(
    comm: Any,
    case: Any,
    config: Any,
    *,
    niter: int = 1,
) -> CorrectnessStats:
    """Run ``case`` ``niter`` times and compare each result with the expected output.

    For every iteration the inputs are regenerated, the collective is run
    through ``comm.run``, and the output buffer is compared against the
    reference produced by ``_expected_outputs``. Per-rank statistics are then
    combined across ``MPI.COMM_WORLD``.

    Args:
        comm: Communicator wrapper providing ``rank``, ``nranks``, ``run`` and
            ``comm_group.barrier``.
        case: Benchmark case holding the ``input``/``output`` buffers.
        config: Opaque configuration forwarded to ``comm.run``.
        niter: Number of iterations to check.

    Returns:
        Aggregated statistics. ``ok`` is true only if every iteration passed
        on every rank. An iteration whose ``comm.run`` returns non-zero counts
        as a failure and contributes nothing to the numeric statistics.
    """
    all_ok = True
    local_max_abs_diff = 0.0
    local_sum_abs_diff = 0.0
    local_mismatches = 0
    local_total = 0
    for iteration in range(niter):
        _fill_case_for_correctness(case, comm.rank, iteration)
        ret = comm.run(case, config)
        cp.cuda.runtime.deviceSynchronize()
        comm.comm_group.barrier()
        if ret != 0:
            all_ok = False
            continue

        expected, stats_expected = _expected_outputs(case, comm.nranks, iteration)
        iter_stats = _local_diff_stats(case, case.output, expected, comm.nranks, stats_expected=stats_expected)
        # _local_diff_stats already derives its verdict from the mismatch mask
        # (ok == "no mismatching element"), so reuse it instead of running a
        # second full comparison pass over the buffers.
        local_ok = iter_stats.ok
        all_ok = all_ok and local_ok
        local_max_abs_diff = max(local_max_abs_diff, iter_stats.max_abs_diff)
        # Undo the per-iteration mean so the global mean can be weighted by element count.
        local_sum_abs_diff += iter_stats.mean_abs_diff * iter_stats.total
        local_mismatches += iter_stats.mismatches
        local_total += iter_stats.total

        if not local_ok:
            # Only on failure: recompute the mask to report the first offending element.
            mismatch = _mismatch_mask(case, case.output, expected, comm.nranks)
            print(
                "not close: "
                f"iter={iteration}, rank={comm.rank}, output={case.output[mismatch][0]}, "
                f"expected={expected[mismatch][0]}, max_abs_diff={iter_stats.max_abs_diff:.6g}, "
                f"mean_abs_diff={iter_stats.mean_abs_diff:.6g}, mismatches={iter_stats.mismatches}/{iter_stats.total}",
                flush=True,
            )

    global_ok = bool(MPI.COMM_WORLD.allreduce(all_ok, op=MPI.LAND))
    global_max_abs_diff = float(MPI.COMM_WORLD.allreduce(local_max_abs_diff, op=MPI.MAX))
    global_sum_abs_diff = float(MPI.COMM_WORLD.allreduce(local_sum_abs_diff, op=MPI.SUM))
    global_mismatches = int(MPI.COMM_WORLD.allreduce(local_mismatches, op=MPI.SUM))
    global_total = int(MPI.COMM_WORLD.allreduce(local_total, op=MPI.SUM))
    # Guard against division by zero when no element was compared anywhere.
    global_mean_abs_diff = global_sum_abs_diff / global_total if global_total else 0.0
    return CorrectnessStats(
        ok=global_ok,
        max_abs_diff=global_max_abs_diff,
        mean_abs_diff=global_mean_abs_diff,
        mismatches=global_mismatches,
        total=global_total,
    )
|
|
|
|
|
|
def _fill_case_for_correctness(case: Any, rank: int, iteration: int) -> None:
    """Load the correctness-check input for ``rank``/``iteration`` into ``case.input``.

    The values come from ``_correctness_input_values`` and are converted to
    the buffer's storage encoding. For every collective other than
    ``"allreduce"`` the output buffer is zeroed before the input is written.
    """
    encoded = _encode_correctness_input(case, _correctness_input_values(case, rank, iteration))
    if case.collective != "allreduce":
        case.output.fill(0)
    case.input[...] = encoded
|
|
|
|
|
|
def _correctness_input_values(case: Any, rank: int, iteration: int):
    """Generate the random input values for one rank and iteration.

    The generator is seeded from ``_correctness_seed(rank, iteration)``, so
    any rank can regenerate any other rank's input for the same iteration.
    """
    input_shape = case.input.shape
    seed = _correctness_seed(rank, iteration)
    generator = cp.random.RandomState(seed)
    return _random_input_values(case, generator, input_shape)
|
|
|
|
|
|
def _benchmark_input_values(case: Any, rank: int):
    """Generate the random benchmark input values for ``rank``.

    The generator seed is a fixed offset (17_000_003) plus the rank.
    """
    seed = 17_000_003 + rank
    generator = cp.random.RandomState(seed)
    return _random_input_values(case, generator, case.input.shape)
|
|
|
|
|
|
def _random_input_values(case: Any, rng, shape):
    """Draw random input values of the given ``shape`` from ``rng``.

    * FP8 cases: float32 values uniform in ``[-r, r)``, with ``r`` taken from
      ``_fp8_correctness_input_range``.
    * int32 storage: integers from ``rng.randint(-1, 2)``.
    * uint8 storage: integers from ``rng.randint(0, 2)``.
    * anything else: float32 values uniform in ``[-1, 1)``.
    """
    spec = case.dtype_spec
    if spec.fp8_format is not None:
        bound = _fp8_correctness_input_range(case)
        samples = rng.uniform(-bound, bound, size=shape)
        return samples.astype(cp.float32)

    storage = spec.cupy_dtype
    if storage == cp.int32:
        low, high, target = -1, 2, cp.int32
    elif storage == cp.uint8:
        low, high, target = 0, 2, cp.uint8
    else:
        return rng.uniform(-1.0, 1.0, size=shape).astype(cp.float32)
    return rng.randint(low, high, size=shape).astype(target)
|
|
|
|
|
|
def _correctness_seed(rank: int, iteration: int) -> int:
    """Return the RNG seed for ``rank`` at ``iteration``.

    Each iteration moves the seed by 1_000_003 and the rank is added on top.
    """
    iteration_base = 1_000_003 * (iteration + 1)
    return iteration_base + rank
|
|
|
|
|
|
def _fp8_correctness_input_range(case: Any) -> float:
    """Return the symmetric bound used when sampling FP8 inputs.

    Non-allreduce collectives, and cases without an FP8 format, use 1.0. For
    allreduce the bound is the largest magnitude in the format's table divided
    by the MPI world size, capped at 1.0 (presumably so the sum over all ranks
    stays representable -- confirm against the kernel's saturation behaviour).
    """
    if case.collective != "allreduce" or case.dtype_spec.fp8_format is None:
        return 1.0
    largest = _fp8_max_abs_value(case.dtype_spec.fp8_format)
    world_size = max(1, MPI.COMM_WORLD.size)
    return min(1.0, largest / world_size)
|
|
|
|
|
|
def _encode_correctness_input(case: Any, values):
    """Convert numeric ``values`` into the storage representation of ``case``.

    FP8 cases go through ``_encode_fp8_values``; every other dtype is a plain
    ``astype`` to ``case.dtype_spec.cupy_dtype``.
    """
    spec = case.dtype_spec
    if spec.fp8_format is None:
        return values.astype(spec.cupy_dtype)
    # FP8 buffers hold raw uint8 bytes; astype(uint8) would truncate the
    # numbers instead of producing FP8 bit patterns.
    return _encode_fp8_values(spec.fp8_format, values)
|
|
|
|
|
|
def _local_diff_stats(case: Any, output, expected, nranks: int, *, stats_expected=None) -> CorrectnessStats:
    """Compute this rank's comparison statistics for one iteration.

    Args:
        case: Benchmark case; selects the tolerance and decoding rules.
        output: Buffer produced by the collective.
        expected: Reference buffer in the same storage format as ``output``.
        nranks: Number of participating ranks (feeds the tolerance).
        stats_expected: Optional numeric reference used for the diff
            statistics in place of decoding ``expected``.

    Returns:
        ``CorrectnessStats`` whose ``ok`` is true when no element mismatches.
        For an empty ``output`` only ``ok`` is populated.
    """
    bad = _mismatch_mask(case, output, expected, nranks)
    n_bad = int(cp.count_nonzero(bad).item())
    n_total = int(output.size)
    passed = n_bad == 0
    if n_total == 0:
        return CorrectnessStats(ok=passed)

    actual = _stats_values(case, output)
    if stats_expected is None:
        reference = _stats_values(case, expected)
    else:
        reference = stats_expected.astype(cp.float64)
    deltas = cp.abs(actual - reference)
    worst = float(cp.max(deltas).item())
    average = float(cp.mean(deltas).item())
    return CorrectnessStats(
        ok=passed,
        max_abs_diff=worst,
        mean_abs_diff=average,
        mismatches=n_bad,
        total=n_total,
    )
|
|
|
|
|
|
def _stats_values(case: Any, values):
    """Turn a storage buffer into numeric values suitable for diff statistics.

    FP8 buffers are decoded from their raw bytes; floating-point buffers are
    widened to float64 and everything else to int64.
    """
    fmt = case.dtype_spec.fp8_format
    if fmt is not None:
        return _decode_fp8_array(fmt, values)
    widened = cp.float64 if cp.issubdtype(values.dtype, cp.floating) else cp.int64
    return values.astype(widened)
|
|
|
|
|
|
def _expected_outputs(case: Any, nranks: int, iteration: int):
    """Build the reference output for ``case`` at ``iteration``.

    Returns:
        A pair ``(expected, stats_expected)``. ``expected`` is in the output
        buffer's storage format. ``stats_expected`` holds the numeric
        accumulated values for FP8 allreduce and is ``None`` otherwise.

    For allreduce the per-rank inputs are summed (FP8 through
    ``_expected_fp8_accum_values``, other dtypes in float32). For any other
    collective each rank's flattened input is copied into consecutive
    ``case.input.size``-sized slices of the output.
    """
    if case.collective != "allreduce":
        gathered = cp.empty_like(case.output)
        stride = case.input.size
        for rank, payload in enumerate(_encoded_rank_inputs(case, nranks, iteration)):
            start = rank * stride
            gathered[start : start + stride] = payload.reshape(-1)
        return gathered, None

    per_rank = _encoded_rank_inputs(case, nranks, iteration)
    if case.dtype_spec.fp8_format is None:
        total = sum(payload.astype(cp.float32) for payload in per_rank)
        return _encode_reduced_output(case, total), None

    accumulated = _expected_fp8_accum_values(case, per_rank)
    return _encode_reduced_output(case, accumulated), accumulated
|
|
|
|
|
|
def _encoded_rank_inputs(case: Any, nranks: int, iteration: int) -> list[Any]:
    """Regenerate the encoded input of every rank for ``iteration``.

    Returns a list indexed by rank, each entry in the input buffer's storage
    format.
    """
    encoded: list[Any] = []
    for rank in range(nranks):
        raw = _correctness_input_values(case, rank, iteration)
        encoded.append(_encode_correctness_input(case, raw))
    return encoded
|
|
|
|
|
|
def _expected_fp8_accum_values(case: Any, encoded_inputs: list[Any]):
    """Emulate the FP8 allreduce accumulation and return float32 sums.

    The accumulation precision follows ``config_accum_dtype(case)``:

    * float16: inputs are decoded and added in float16, rounding after each
      addition.
    * float32: inputs are decoded and added in float32.
    * otherwise: the running sum is re-encoded to FP8 after every addition.

    Raises:
        ValueError: if ``case`` has no FP8 format.
    """
    fmt = case.dtype_spec.fp8_format
    if fmt is None:
        raise ValueError("FP8 format is required")

    accum_dtype = config_accum_dtype(case)
    if accum_dtype == _mscclpp().DataType.float16:
        running = cp.zeros_like(_decode_fp8_array(fmt, encoded_inputs[0]), dtype=cp.float16)
        for payload in encoded_inputs:
            term = _decode_fp8_array(fmt, payload).astype(cp.float16)
            running = (running + term).astype(cp.float16)
        return running.astype(cp.float32)

    if accum_dtype == _mscclpp().DataType.float32:
        running = cp.zeros_like(_decode_fp8_array(fmt, encoded_inputs[0]), dtype=cp.float32)
        for payload in encoded_inputs:
            running += _decode_fp8_array(fmt, payload).astype(cp.float32)
        return running

    # Native FP8 accumulation: round the partial sum back to FP8 at each step.
    running = encoded_inputs[0]
    for payload in encoded_inputs[1:]:
        partial = _decode_fp8_array(fmt, running) + _decode_fp8_array(fmt, payload)
        running = _encode_fp8_values(fmt, partial)
    return _decode_fp8_array(fmt, running).astype(cp.float32)
|
|
|
|
|
|
def _encode_reduced_output(case: Any, values):
    """Convert reduced numeric ``values`` into the output buffer's storage format.

    FP8 cases are encoded to raw FP8 bytes; all other cases are cast to
    ``case.output.dtype``.
    """
    fmt = case.dtype_spec.fp8_format
    if fmt is None:
        return values.astype(case.output.dtype)
    return _encode_fp8_values(fmt, values)
|
|
|
|
|
|
def _compare_output(case: Any, output, expected, nranks: int) -> bool:
    """Return True when no element of ``output`` mismatches ``expected``."""
    mismatch = _mismatch_mask(case, output, expected, nranks)
    everything_matches = cp.all(~mismatch)
    return bool(everything_matches.item())
|
|
|
|
|
|
def _mismatch_mask(case: Any, output, expected, nranks: int):
    """Return a boolean mask that is True where ``output`` disagrees with ``expected``.

    When ``_comparison_tolerance`` yields ``None`` the raw buffers are compared
    for exact equality; otherwise both are converted with ``_stats_values`` and
    compared with ``isclose`` using the returned ``(rtol, atol)``.
    """
    tolerance = _comparison_tolerance(case, nranks)
    if tolerance is not None:
        relative, absolute = tolerance
        close = cp.isclose(
            _stats_values(case, output),
            _stats_values(case, expected),
            rtol=relative,
            atol=absolute,
        )
        return ~close
    return output != expected
|
|
|
|
|
|
def _comparison_tolerance(case: Any, nranks: int) -> tuple[float, float] | None:
    """Return ``(rtol, atol)`` for comparing outputs, or ``None`` for exact comparison.

    Tolerances are scaled by the rank count for allreduce and left unscaled
    for other collectives.

    * FP8 with float32 accumulation: exact.
    * FP8 with float16 accumulation: ``atol`` is the largest FP8 spacing within
      ``[-scale, scale]``; any other FP8 accumulation doubles that.
    * float16 / float32 storage: fixed tolerances scaled as coded below.
    * everything else: exact.
    """
    spec = case.dtype_spec
    scale = 1
    if case.collective == "allreduce":
        scale = max(1, nranks)

    if spec.fp8_format is not None:
        accum_dtype = config_accum_dtype(case)
        data_types = _mscclpp().DataType
        if accum_dtype == data_types.float32:
            return None
        spacing = _max_fp8_spacing(spec.fp8_format, float(scale))
        if accum_dtype == data_types.float16:
            return (0.0, spacing)
        return (0.0, spacing * 2)

    storage = spec.cupy_dtype
    if storage == cp.float16:
        return (1.0e-2, 5.0e-4 * scale)
    if storage == cp.float32:
        return (1.0e-5 * scale, 1.0e-6 * scale)
    return None
|
|
|
|
|
|
# Per-format list of (byte, decoded value) pairs as produced by _build_fp8_table.
_FP8_TABLES: dict[str, list[tuple[int, float]]] = {}
|
|
# Per-format (sorted values, matching bytes) array pair filled in by _fp8_lookup_arrays.
_FP8_LOOKUP_CACHE: dict[str, tuple[Any, Any]] = {}
|
|
# Memoized results of _max_fp8_spacing, keyed by (format, max_abs_value).
_FP8_SPACING_CACHE: dict[tuple[str, float], float] = {}
|
|
|
|
|
|
def _encode_fp8_values(fp8_format: str, values):
    """Encode numeric ``values`` as raw FP8 bytes of the given format.

    ``e4m3b15`` is delegated to ``_encode_e4m3b15_values``. Every other format
    rounds each value to the nearest entry of the format's sorted value table;
    an exact tie goes to the neighbour whose byte is even. The result has the
    same shape as ``values``.
    """
    as_f32 = values.astype(cp.float32)
    if fp8_format == "e4m3b15":
        return _encode_e4m3b15_values(as_f32)

    grid, grid_bytes = _fp8_lookup_arrays(fp8_format)
    flat = as_f32.ravel()

    # Bracket every value between two adjacent table entries; clipping keeps
    # both indices inside the table for values beyond either end.
    hi = cp.clip(cp.searchsorted(grid, flat), 1, grid.size - 1)
    lo = hi - 1

    gap_above = grid[hi] - flat
    gap_below = flat - grid[lo]
    hi_byte = grid_bytes[hi]

    closer_to_hi = gap_above < gap_below
    tie_goes_hi = (gap_above == gap_below) & ((hi_byte & cp.uint8(1)) == 0)

    chosen = cp.where(closer_to_hi | tie_goes_hi, hi_byte, grid_bytes[lo])
    return chosen.reshape(as_f32.shape)
|
|
|
|
|
|
def _fp8_lookup_arrays(fp8_format: str):
    """Return ``(table_values, table_bytes)`` for nearest-value FP8 encoding.

    ``table_values`` is a float32 array of the distinct values of the format
    in ascending order and ``table_bytes`` is the uint8 array of the byte
    chosen for each value. The pair is cached per format in
    ``_FP8_LOOKUP_CACHE``.
    """
    cached = _FP8_LOOKUP_CACHE.get(fp8_format)
    if cached is not None:
        return cached

    # Build the raw table only when it is missing. dict.setdefault would
    # evaluate _build_fp8_table() eagerly even if the entry already exists.
    table = _FP8_TABLES.get(fp8_format)
    if table is None:
        table = _build_fp8_table(fp8_format)
        _FP8_TABLES[fp8_format] = table

    # Different bytes can decode to the same value (e.g. +0 and -0); keep the
    # smallest byte for each value.
    byte_for_value: dict[float, int] = {}
    for byte, value in table:
        if value not in byte_for_value or byte < byte_for_value[value]:
            byte_for_value[value] = byte

    ordered = sorted(byte_for_value.items())
    table_values = cp.asarray([value for value, _ in ordered], dtype=cp.float32)
    table_bytes = cp.asarray([byte for _, byte in ordered], dtype=cp.uint8)
    lookup = (table_values, table_bytes)
    _FP8_LOOKUP_CACHE[fp8_format] = lookup
    return lookup
|
|
|
|
|
|
def _max_fp8_spacing(fp8_format: str, max_abs_value: float) -> float:
    """Return the largest gap between adjacent FP8 values within ``[-max_abs_value, max_abs_value]``.

    Results are memoized in ``_FP8_SPACING_CACHE``. When fewer than two values
    fall inside the range the spacing is 0.0.
    """
    cache_key = (fp8_format, max_abs_value)
    cached = _FP8_SPACING_CACHE.get(cache_key)
    if cached is not None:
        return cached

    # Build the raw table only when it is missing. dict.setdefault would
    # evaluate _build_fp8_table() eagerly even if the entry already exists.
    table = _FP8_TABLES.get(fp8_format)
    if table is None:
        table = _build_fp8_table(fp8_format)
        _FP8_TABLES[fp8_format] = table

    in_range = sorted({value for _, value in table if abs(value) <= max_abs_value})
    spacing = max(
        (right - left for left, right in zip(in_range, in_range[1:])),
        default=0.0,
    )
    _FP8_SPACING_CACHE[cache_key] = spacing
    return spacing
|
|
|
|
|
|
def _fp8_max_abs_value(fp8_format: str) -> float:
    """Return the largest magnitude among the non-NaN values of ``fp8_format``."""
    # Build the raw table only when it is missing. dict.setdefault would
    # evaluate _build_fp8_table() eagerly and rebuild all 256 entries on every
    # call, even when the table is already cached.
    table = _FP8_TABLES.get(fp8_format)
    if table is None:
        table = _build_fp8_table(fp8_format)
        _FP8_TABLES[fp8_format] = table
    return max(abs(value) for _, value in table)
|
|
|
|
|
|
def _encode_e4m3b15_values(values):
    """Encode ``values`` as e4m3b15 bytes via a float16 intermediate.

    Mirrors the device e4m3b15 encode (gpu_data_types.hpp): the magnitude of
    the fp16 bit pattern is clamped to 0x3F80 (+/-1.875) so the largest
    encodable byte is 0x7F/0xFF. Returns a uint8 array.
    """
    half_bits = values.astype(cp.float16).view(cp.uint16)
    magnitude = cp.minimum(half_bits & cp.uint16(0x7FFF), cp.uint16(0x3F80)).astype(cp.uint32)
    sign = (half_bits & cp.uint16(0x8000)).astype(cp.uint32)
    # Double the magnitude bits and add 0x80 before the 8-bit shift below.
    rounded = magnitude * cp.uint32(2) + cp.uint32(0x0080)
    packed = (sign | rounded) >> cp.uint32(8)
    return (packed & cp.uint32(0xFF)).astype(cp.uint8)
|
|
|
|
|
|
def _build_fp8_table(fp8_format: str) -> list[tuple[int, float]]:
    """Decode all 256 byte patterns of ``fp8_format``.

    Returns ``(byte, value)`` pairs in ascending byte order, leaving out the
    bytes that decode to NaN.
    """
    decoded = ((byte, _decode_fp8_scalar(fp8_format, byte)) for byte in range(256))
    return [(byte, value) for byte, value in decoded if not math.isnan(value)]
|
|
|
|
|
|
def _decode_fp8_scalar(fp8_format: str, byte: int) -> float:
    """Decode one FP8 byte, sign bit included, to a Python float.

    In ``e4m3fnuz`` the pattern 0x80 decodes to NaN. Otherwise bit 7 is the
    sign and the low seven bits are decoded by ``_decode_fp8_positive``.
    """
    if byte == 0x80 and fp8_format == "e4m3fnuz":
        return float("nan")
    sign = -1.0 if byte & 0x80 else 1.0
    magnitude = _decode_fp8_positive(fp8_format, byte & 0x7F)
    return sign * magnitude
|
|
|
|
|
|
def _decode_fp8_positive(fp8_format: str, byte: int) -> float:
    """Decode the magnitude bits (sign already removed) of one FP8 byte.

    The byte is split into a 4-bit exponent and a 3-bit mantissa. With
    exponent 0 the value is subnormal, ``mant/8 * 2**(1 - bias)``; otherwise
    it is ``(1 + mant/8) * 2**(exp - bias)``. The biases are 7 for ``e4m3fn``,
    8 for ``e4m3fnuz`` and 15 for ``e4m3b15``. In ``e4m3fn`` the all-ones
    pattern (exp 0xF, mant 0x7) is NaN.

    Raises:
        ValueError: if ``fp8_format`` is not a known format. This is checked
            up front, so an unknown format can no longer slip through as 0.0
            for an all-zero byte.
    """
    exponent_bias = {"e4m3fn": 7, "e4m3fnuz": 8, "e4m3b15": 15}.get(fp8_format)
    if exponent_bias is None:
        raise ValueError(f"Unknown FP8 format: {fp8_format}")

    exp = (byte >> 3) & 0xF
    mant = byte & 0x7
    if fp8_format == "e4m3fn" and exp == 0xF and mant == 0x7:
        return float("nan")
    if exp == 0:
        # Subnormal range; also yields exactly 0.0 when the mantissa is zero.
        return math.ldexp(mant / 8.0, 1 - exponent_bias)
    return math.ldexp(1.0 + mant / 8.0, exp - exponent_bias)
|
|
|
|
|
|
def _decode_fp8_array(fp8_format: str, values):
    """Decode an array of raw FP8 bytes into numeric values.

    Array counterpart of the scalar decoder: each byte is split into sign,
    4-bit exponent and 3-bit mantissa and evaluated with the format's exponent
    bias (7 for ``e4m3fn``, 8 for ``e4m3fnuz``, 15 for ``e4m3b15``). NaN is
    produced for the ``e4m3fn`` all-ones magnitude and for the ``e4m3fnuz``
    byte 0x80.

    Raises:
        ValueError: if ``fp8_format`` is not a known format.
    """
    raw = values.astype(cp.int32)
    sign_bit = (raw >> 7) & 1
    exponent = (raw >> 3) & 0xF
    mantissa = raw & 0x7

    biases = {"e4m3fn": 7, "e4m3fnuz": 8, "e4m3b15": 15}
    if fp8_format not in biases:
        raise ValueError(f"Unknown FP8 format: {fp8_format}")
    bias = biases[fp8_format]

    fraction = mantissa.astype(cp.float32) / cp.float32(8.0)
    subnormal = cp.ldexp(fraction, cp.int32(1 - bias))
    normal = cp.ldexp(cp.float32(1.0) + fraction, exponent.astype(cp.int32) - bias)
    magnitude = cp.where(exponent == 0, subnormal, normal)
    if fp8_format == "e4m3fn":
        # exp=0xF with mant=0x7 is the NaN pattern in this format.
        magnitude = cp.where((exponent == 0xF) & (mantissa == 0x7), cp.nan, magnitude)

    decoded = cp.where(sign_bit == 1, -magnitude, magnitude)
    if fp8_format == "e4m3fnuz":
        # 0x80 is NaN in this format.
        decoded = cp.where(raw == 0x80, cp.float32(float("nan")), decoded)
    return decoded
|