Files
mscclpp/python/mscclpp_benchmark/comm.py
Binyang Li c9f8be64bb Add collective benchmark and correctness check (#814)
- Add unit-test for float8_e4m3b15 data type.
- Add tuner and benchmark for allreduce/allgather algorithms to ensure
correctness and performance.
2026-06-04 09:22:10 -07:00

410 lines
14 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
import logging
from typing import Any
# Module-level logger for this file.
logger = logging.getLogger(__name__)
# Collective names used as dictionary keys when grouping algorithms and
# when selecting what Comm.run() executes.
_ALLREDUCE_COLLECTIVE = "allreduce"
_ALLGATHER_COLLECTIVE = "allgather"
# Cache for the lazily imported mscclpp package; filled in by _mscclpp().
_mscclpp_module = None
from mscclpp_benchmark.gpu import current_device, device_name, set_device
from mscclpp_benchmark.tuning_config import HardwareProfile, TunedConfig, TunedConfigStore, normalize_sku
def _mscclpp():
    """Return the ``mscclpp`` package, importing it on first use.

    The import is deferred until the first call and the resulting module
    object is cached in the module-level ``_mscclpp_module`` variable, so
    later calls return the cached module without importing again.
    """
    global _mscclpp_module
    if _mscclpp_module is not None:
        return _mscclpp_module

    import mscclpp
    import mscclpp.ext  # imported so that ``mscclpp.ext`` is reachable as an attribute

    _mscclpp_module = mscclpp
    return _mscclpp_module
class Buffer:
    """A dtype/shape description attached to a raw buffer object.

    Either wraps a caller-supplied ``buffer`` or allocates a new
    ``RawGpuBuffer`` sized from ``nbytes`` or from ``shape`` and ``dtype``.
    """

    def __init__(
        self,
        nbytes: int | None = None,
        *,
        dtype: str | Any = "float16",
        shape: tuple[int, ...] | None = None,
        buffer: Any | None = None,
    ) -> None:
        """Wrap ``buffer`` or allocate a new one.

        Args:
            nbytes: Allocation size in bytes; only consulted when ``buffer`` is None.
            dtype: Element type, resolved to a byte width by ``_dtype_size``.
            shape: Logical shape. When omitted, a 1-D shape is derived from the
                buffer size divided by the element size.
            buffer: Existing buffer object exposing a ``bytes()`` method.

        Raises:
            ValueError: If ``buffer``, ``nbytes`` and ``shape`` are all None, or
                if ``dtype`` is not recognised by ``_dtype_size``.
        """
        self.dtype = dtype
        self.element_size = _dtype_size(dtype)

        backing = self._allocate(nbytes, shape) if buffer is None else buffer
        self.buffer = backing
        # The size is always read back from the buffer itself, not from the request.
        self.nbytes = int(backing.bytes())

        if shape is None:
            self.shape = (self.nbytes // self.element_size,)
        else:
            self.shape = shape

    def _allocate(self, nbytes: int | None, shape: tuple[int, ...] | None) -> Any:
        """Allocate a ``RawGpuBuffer`` of ``nbytes``, or of ``shape`` elements when ``nbytes`` is None."""
        if nbytes is None and shape is None:
            raise ValueError("Either nbytes or shape is required")
        requested = _numel(shape) * self.element_size if nbytes is None else nbytes
        _ensure_device()
        return _mscclpp().RawGpuBuffer(int(requested))

    @property
    def ndim(self) -> int:
        """Number of dimensions in ``shape``."""
        return len(self.shape)

    @property
    def size(self) -> int:
        """Total number of elements described by ``shape``."""
        return _numel(self.shape)

    def data_ptr(self) -> int:
        """Return the integer value of the wrapped buffer's ``data()``."""
        address = self.buffer.data()
        return int(address)
class _AllReduceOp:
def __init__(self, comm: "Comm", x: Any, *, symmetric_memory: bool = False) -> None:
self._comm = comm
self._x = x
self._symmetric_memory = symmetric_memory
def __call__(self, **_: Any) -> Any:
self._comm.run(self._x, symmetric_memory=self._symmetric_memory)
return self._x
class _AllGatherOp:
    """Callable that runs an allgather of a fixed input buffer into ``self.y``."""

    def __init__(self, comm: "Comm", x: Any, *, dim: int, y: Any | None = None, symmetric_memory: bool = False) -> None:
        """Validate ``dim`` and prepare the output buffer.

        Args:
            comm: Communicator whose ``run`` method executes the collective.
            x: Input buffer; its shape is obtained through ``_shape``.
            dim: Gather dimension. Only the leading dimension is supported,
                written either as ``0`` or as ``-ndim``.
            y: Optional pre-allocated output. When omitted, a ``Buffer`` is
                created whose leading dimension is ``comm._scale()`` times
                that of ``x``.
            symmetric_memory: Forwarded to ``comm.run`` on every call.

        Raises:
            ValueError: If ``x`` has an empty shape or ``dim`` is outside
                ``[-ndim, ndim)``.
            NotImplementedError: If ``dim`` is valid but not the leading dimension.
        """
        shape = _shape(x)
        ndim = len(shape)
        if ndim == 0:
            raise ValueError("MSCCL++ allgather requires a non-scalar buffer")
        # A bare ``dim % ndim`` check would let out-of-range values such as
        # ``dim == ndim`` pass as if they were dim 0, so reject those first.
        if not -ndim <= dim < ndim:
            raise ValueError(f"dim={dim} is out of range for a buffer with {ndim} dimension(s)")
        if dim % ndim != 0:
            raise NotImplementedError("Raw-buffer allgather currently supports only dim=0")
        if y is None:
            y_shape = (comm._scale() * shape[0], *shape[1:])
            y = Buffer(dtype=_dtype(x), shape=y_shape)
        self._comm = comm
        self._x = x
        self.y = y
        self._symmetric_memory = symmetric_memory

    def __call__(self, **_: Any) -> Any:
        """Run the allgather and return the output buffer; keyword arguments are ignored."""
        self._comm.run(
            self._x,
            collective=_ALLGATHER_COLLECTIVE,
            output_tensor=self.y,
            symmetric_memory=self._symmetric_memory,
        )
        return self.y
class Comm:
    """Runtime MSCCL++ wrapper that owns algorithm handles and execution without Torch/CuPy tensors."""

    def __init__(
        self,
        comm_group: Any,
        scratch_buffer_size: int = 1 << 27,
        *,
        config_store: "TunedConfigStore | None" = None,
        hardware_profile: HardwareProfile | None = None,
    ) -> None:
        """Build the default algorithms for ``comm_group``.

        Args:
            comm_group: Object exposing ``my_rank``, ``nranks`` and
                ``communicator``; an optional ``_mpi_comm`` attribute is
                picked up when present.
            scratch_buffer_size: Size in bytes of the scratch ``RawGpuBuffer``
                handed to the algorithm builder.
            config_store: Tuned configurations to consult; an empty store is
                used when omitted.
            hardware_profile: Profile used for config lookups; detected from
                the current device and ``_scale()`` when omitted.
        """
        self._comm_group = comm_group
        self._mpi_comm = getattr(comm_group, "_mpi_comm", None)
        self._rank = comm_group.my_rank
        self._closed = False
        _ensure_device()
        self._mscclpp = _mscclpp()
        self._scratch_buffer = self._mscclpp.RawGpuBuffer(scratch_buffer_size)
        self._config_store = TunedConfigStore.empty() if config_store is None else config_store
        self._hardware_profile = (
            _detect_hardware_profile(scale=self._scale()) if hardware_profile is None else hardware_profile
        )
        # Keys for which the "default config" warning has already been logged.
        self._default_config_warning_keys: set[tuple[str, str, str, int]] = set()
        algorithms = self._mscclpp.ext.AlgorithmCollectionBuilder().build_default_algorithms(
            scratch_buffer=self._scratch_buffer.data(),
            scratch_buffer_size=self._scratch_buffer.bytes(),
            rank=self._rank,
        )
        # collective name -> algorithm name -> algorithm handle
        self._algorithms_by_collective: dict[str, dict[str, Any]] = {}
        for algorithm in algorithms:
            self._algorithms_by_collective.setdefault(algorithm.collective, {})[algorithm.name] = algorithm

    @property
    def comm_group(self) -> Any:
        """The communicator group passed to the constructor."""
        return self._comm_group

    @property
    def rank(self) -> int:
        """This process's rank, as reported by ``comm_group.my_rank``."""
        return self._rank

    @property
    def nranks(self) -> int:
        """Number of ranks, as reported by ``comm_group.nranks``."""
        return self._comm_group.nranks

    @property
    def algorithms(self) -> dict[str, dict[str, Any]]:
        """Algorithm handles grouped by collective name, then by algorithm name."""
        return self._algorithms_by_collective

    @property
    def hardware_profile(self) -> HardwareProfile:
        """Hardware profile used for tuned-config lookups."""
        return self._hardware_profile

    def make_allreduce(self, x: Any, *, symmetric_memory: bool = False) -> _AllReduceOp:
        """Return a callable that runs an allreduce on ``x``."""
        return _AllReduceOp(self, x, symmetric_memory=symmetric_memory)

    def make_allgather(self, x: Any, dim: int, y: Any | None = None, *, symmetric_memory: bool = False) -> _AllGatherOp:
        """Return a callable that gathers ``x`` along ``dim`` into ``y`` (allocated when omitted)."""
        return _AllGatherOp(self, x, dim=dim, y=y, symmetric_memory=symmetric_memory)

    def _scale(self) -> int:
        """Return the number of participating ranks.

        Uses the MPI communicator size when one is attached to the group and
        otherwise falls back to ``comm_group.nranks``. Falling back to a
        constant 1 would size allgather outputs for a single rank and key the
        hardware profile as a single-rank job.
        """
        if self._mpi_comm is not None:
            return int(self._mpi_comm.Get_size())
        return int(self._comm_group.nranks)

    def resolve_config(self, case: Any, *, symmetric_memory: bool = False) -> TunedConfig:
        """Resolve the config that ``run`` would use for a benchmark ``case``.

        ``case`` must expose ``collective`` and ``input``; ``dtype_spec`` and
        ``symmetric_memory`` are read when present.
        """
        dtype_override = getattr(getattr(case, "dtype_spec", None), "mscclpp_dtype", None)
        accum_dtype = getattr(getattr(case, "dtype_spec", None), "accum_dtype", None) or dtype_override
        symmetric_memory = symmetric_memory or bool(getattr(case, "symmetric_memory", False))
        return self._resolve_config(
            case.collective,
            case.input,
            dtype_override=dtype_override,
            accum_dtype=accum_dtype,
            symmetric_memory=symmetric_memory,
        )

    def _resolve_config(
        self,
        collective: str,
        buffer: Any,
        *,
        dtype_override: Any | None = None,
        accum_dtype: Any | None = None,
        symmetric_memory: bool = False,
    ) -> TunedConfig:
        """Pick a tuned config for ``buffer``, or fall back to the built-in default.

        A stored config is used only when its algorithm is actually available
        for ``collective``. Otherwise rank 0 logs a warning once per
        (collective, dtype, accum dtype, dim) key and the default heuristics
        from ``_default_tuned_config`` are applied.
        """
        tuned_config = self._config_store.select(self._hardware_profile, collective, _nbytes(buffer))
        if tuned_config is not None and tuned_config.algorithm in self._algorithms_by_collective.get(collective, {}):
            return tuned_config
        if self._rank == 0:
            dim = int(_shape(buffer)[1]) if len(_shape(buffer)) > 1 else 1
            warning_key = (
                collective,
                str(dtype_override if dtype_override is not None else _dtype(buffer)),
                str(
                    accum_dtype
                    if accum_dtype is not None
                    else dtype_override if dtype_override is not None else _dtype(buffer)
                ),
                dim,
            )
            if warning_key not in self._default_config_warning_keys:
                self._default_config_warning_keys.add(warning_key)
                logger.warning(
                    "MSCCL++ default config: no tuning for collective=%s profile=%s dtype=%s accum=%s dim=%s; perf may be poor",
                    collective,
                    self._hardware_profile,
                    warning_key[1],
                    warning_key[2],
                    dim,
                )
        return _default_tuned_config(
            collective,
            _nbytes(buffer),
            self._algorithms_by_collective,
            symmetric_memory=symmetric_memory,
        )

    def run(
        self,
        buffer: Any,
        config: TunedConfig | None = None,
        stream: Any | None = None,
        *,
        collective: str = _ALLREDUCE_COLLECTIVE,
        output_tensor: Any | None = None,
        dtype_override: Any | None = None,
        accum_dtype: Any | None = None,
        symmetric_memory: bool = False,
    ) -> int:
        """Execute one collective and return the algorithm's integer status.

        ``buffer`` may be a plain buffer, or a benchmark case object exposing
        ``input``, ``output`` and ``dtype_spec``. For a case object the
        collective, buffers and dtypes are taken from the case, and a non-zero
        status is returned to the caller instead of being raised.

        Args:
            buffer: Input buffer or benchmark case.
            config: Config to use; resolved via ``_resolve_config`` when omitted.
            stream: Stream object or handle, converted with ``_stream_ptr``.
            collective: Collective name; defaults to allreduce.
            output_tensor: Output buffer; the input buffer is reused when omitted.
            dtype_override: Data type passed to the algorithm instead of the one
                derived from ``buffer``.
            accum_dtype: Accumulation type; defaults to the data type.
            symmetric_memory: OR-ed with the case's and the config's flag.

        Raises:
            RuntimeError: If the comm is closed, if no algorithm is registered
                for ``collective``, or if a plain-buffer run returns non-zero.
        """
        if self._closed:
            raise RuntimeError("Cannot use a closed MSCCL++ comm")
        raise_on_error = True
        if hasattr(buffer, "input") and hasattr(buffer, "output") and hasattr(buffer, "dtype_spec"):
            case = buffer
            buffer = case.input
            output_tensor = case.output
            collective = case.collective
            dtype_override = case.dtype_spec.mscclpp_dtype
            accum_dtype = case.dtype_spec.accum_dtype or dtype_override
            symmetric_memory = symmetric_memory or bool(getattr(case, "symmetric_memory", False))
            # Case-driven runs hand the status code back to the caller.
            raise_on_error = False
        if collective not in self._algorithms_by_collective:
            raise RuntimeError(f"No supported MSCCL++ {collective} algorithm is available")
        if config is None:
            config = self._resolve_config(
                collective,
                buffer,
                dtype_override=dtype_override,
                accum_dtype=accum_dtype,
                symmetric_memory=symmetric_memory,
            )
        symmetric_memory = symmetric_memory or config.symmetric_memory
        algorithm = self._algorithms_by_collective[collective][config.algorithm]
        output = buffer if output_tensor is None else output_tensor
        dtype = dtype_override if dtype_override is not None else _dtype_to_mscclpp(_dtype(buffer))
        accum = accum_dtype if accum_dtype is not None else dtype
        ret = algorithm.execute(
            comm=self._comm_group.communicator,
            input_buffer=_data_ptr(buffer),
            output_buffer=_data_ptr(output),
            input_size=_nbytes(buffer),
            output_size=_nbytes(output),
            dtype=dtype,
            op=self._mscclpp.ReduceOp.SUM if collective == _ALLREDUCE_COLLECTIVE else self._mscclpp.ReduceOp.NOP,
            stream=_stream_ptr(stream),
            nblocks=config.nblocks or 0,
            nthreads_per_block=config.nthreads or 0,
            symmetric_memory=symmetric_memory,
            accum_dtype=accum,
        )
        if ret != 0 and raise_on_error:
            raise RuntimeError(f"MSCCL++ {collective} failed on rank {self._rank} with error code {ret}")
        return ret

    def reset(self, config: TunedConfig | None = None) -> None:
        """Reset algorithm state.

        With ``config``, only algorithms named ``config.algorithm`` are reset;
        without it, every registered algorithm is reset.
        """
        if config is not None:
            for algorithms_by_name in self._algorithms_by_collective.values():
                algorithm = algorithms_by_name.get(config.algorithm)
                if algorithm is not None:
                    algorithm.reset()
            # NOTE(review): an unknown algorithm name is treated as a no-op
            # here - confirm this is intended rather than resetting everything.
            return
        for algorithms_by_name in self._algorithms_by_collective.values():
            for algorithm in algorithms_by_name.values():
                algorithm.reset()

    def close(self) -> None:
        """Reset all algorithms, drop the handles and scratch buffer, and mark the comm closed.

        Calling ``close`` again is a no-op, so the algorithm collection
        builder is reset once per comm.
        """
        if self._closed:
            return
        self.reset()
        self._algorithms_by_collective = {}
        self._scratch_buffer = None
        self._closed = True
        self._mscclpp.ext.AlgorithmCollectionBuilder.reset()
def _numel(shape: tuple[int, ...]) -> int:
out = 1
for dim in shape:
out *= int(dim)
return out
def _dtype_size(dtype: Any) -> int:
    """Return the size in bytes of one element of ``dtype``.

    ``dtype`` is normalised with ``_dtype_name`` before the lookup.

    Raises:
        ValueError: If the normalised name is not one of the known types.
    """
    byte_widths = {
        "float16": 2,
        "bfloat16": 2,
        "float32": 4,
        "int32": 4,
        "uint32": 4,
        "uint8": 1,
        "float8_e4m3b15": 1,
        "float8_e4m3fn": 1,
        "float8_e4m3fnuz": 1,
    }
    width = byte_widths.get(_dtype_name(dtype))
    if width is None:
        raise ValueError(f"Unknown data type size for {dtype}")
    return width
def _dtype_name(dtype: Any) -> str:
if isinstance(dtype, str):
return dtype.strip().lower().replace("-", "_")
name = str(dtype).rsplit(".", 1)[-1]
return name.strip().lower().replace("-", "_")
def _dtype_to_mscclpp(dtype: Any) -> Any:
    """Translate ``dtype`` into the matching ``mscclpp.DataType`` member.

    The set of accepted names mirrors ``_dtype_size``, so any dtype a buffer
    can be created with can also be passed to an algorithm. Only the
    requested member is looked up on ``DataType``; a member that the
    installed mscclpp build does not provide is reported as an unknown type
    for that request instead of breaking every lookup.

    Raises:
        ValueError: If the name is unsupported or missing from ``DataType``.
    """
    supported_names = (
        "float16",
        "bfloat16",
        "float32",
        "int32",
        "uint32",
        "uint8",
        "float8_e4m3b15",
        "float8_e4m3fn",
        "float8_e4m3fnuz",
    )
    dtype_name = _dtype_name(dtype)
    if dtype_name in supported_names:
        member = getattr(_mscclpp().DataType, dtype_name, None)
        if member is not None:
            return member
    raise ValueError(f"Unknown data type: {dtype}")
def _data_ptr(buffer: Any) -> int:
if hasattr(buffer, "data_ptr"):
data_ptr = buffer.data_ptr
return int(data_ptr() if callable(data_ptr) else data_ptr)
if hasattr(buffer, "data"):
data = buffer.data
if callable(data):
return int(data())
if hasattr(data, "ptr"):
return int(data.ptr)
raise TypeError(f"Cannot get device pointer from {type(buffer)!r}")
def _stream_ptr(stream: Any | None) -> int:
if stream is None:
return 0
return int(getattr(stream, "ptr", stream))
def _nbytes(buffer: Any) -> int:
if hasattr(buffer, "nbytes"):
return int(buffer.nbytes)
if hasattr(buffer, "bytes"):
value = buffer.bytes
return int(value() if callable(value) else value)
raise TypeError(f"Cannot get byte size from {type(buffer)!r}")
def _shape(buffer: Any) -> tuple[int, ...]:
shape = getattr(buffer, "shape", None)
if shape is None:
return (_nbytes(buffer) // _dtype_size(_dtype(buffer)),)
return tuple(int(dim) for dim in shape)
def _dtype(buffer: Any) -> Any:
dtype = getattr(buffer, "dtype", None)
if dtype is None:
return "uint8"
return dtype
def _detect_hardware_profile(*, scale: int) -> HardwareProfile:
    """Build a ``HardwareProfile`` from the current device name and ``scale``.

    Querying the device name is best effort: any failure is logged at debug
    level and the SKU falls back to ``"UNKNOWN"`` so that construction of the
    profile never fails because of it. The SKU is passed through
    ``normalize_sku`` before being stored.
    """
    try:
        sku = device_name()
    except Exception:
        # Deliberately broad: profile detection must not abort start-up, but
        # the reason is kept in the log instead of being dropped silently.
        logger.debug("Could not query the device name; using SKU 'UNKNOWN'", exc_info=True)
        sku = "UNKNOWN"
    return HardwareProfile(sku=normalize_sku(sku), scale=scale)
def _ensure_device() -> None:
    """Pass the value returned by ``current_device()`` to ``set_device()``."""
    device = current_device()
    set_device(device)
def _default_tuned_config(
    collective: str,
    message_size: int,
    algorithms_by_collective: dict[str, dict[str, Any]],
    *,
    symmetric_memory: bool = False,
) -> TunedConfig:
    """Choose a built-in algorithm when no tuned config applies.

    Allgather prefers ``default_allgather_fullmesh2``. If other allgather
    algorithms are registered but that one is not, the first registered
    algorithm is used instead, so the returned name can be looked up by the
    caller. With nothing registered the preferred name is still returned.

    For every other collective the choice is, in order:
      1. ``default_allreduce_nvls_zero_copy`` when symmetric memory is
         requested, NVLS is supported and the algorithm is registered;
      2. ``default_allreduce_packet`` for messages of at most 512 KiB;
      3. ``default_allreduce_rsag_zero_copy``;
      4. the first registered algorithm.

    Raises:
        RuntimeError: If a non-allgather collective has no registered algorithm.
    """
    available = algorithms_by_collective.get(collective, {})
    if collective == _ALLGATHER_COLLECTIVE:
        algorithm_name = "default_allgather_fullmesh2"
        if available and algorithm_name not in available:
            algorithm_name = next(iter(available))
        return TunedConfig(algorithm_name, symmetric_memory=symmetric_memory)
    if symmetric_memory and _mscclpp().is_nvls_supported() and "default_allreduce_nvls_zero_copy" in available:
        return TunedConfig("default_allreduce_nvls_zero_copy", symmetric_memory=True)
    if message_size <= 512 * 1024 and "default_allreduce_packet" in available:
        return TunedConfig("default_allreduce_packet", symmetric_memory=symmetric_memory)
    if "default_allreduce_rsag_zero_copy" in available:
        return TunedConfig("default_allreduce_rsag_zero_copy", symmetric_memory=symmetric_memory)
    if available:
        return TunedConfig(next(iter(available)), symmetric_memory=symmetric_memory)
    raise RuntimeError(f"No MSCCL++ algorithm is available for {collective}")