Files
mscclpp/python/mscclpp_benchmark/gpu.py
Binyang Li c9f8be64bb Add collective benchmark and correctness check (#814)
- Add unit-test for float8_e4m3b15 data type.
- Add a tuner and benchmark for the allreduce/allgather algorithms to verify
correctness and measure performance.
2026-06-04 09:22:10 -07:00

188 lines
5.9 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Callable
# Generic API name -> (HIP binding function name, CUDA binding function name).
# ``_Runtime.create`` reads index 0 for the "hip" backend and index 1 otherwise;
# a None entry would mark that API as unavailable on the corresponding backend.
_API_NAMES: dict[str, tuple[str | None, str | None]] = {
    "get_device_count": ("hipGetDeviceCount", "cudaGetDeviceCount"),
    "get_device": ("hipGetDevice", "cudaGetDevice"),
    "get_device_properties": ("hipGetDeviceProperties", "cudaGetDeviceProperties"),
    "set_device": ("hipSetDevice", "cudaSetDevice"),
    "stream_begin_capture": ("hipStreamBeginCapture", "cudaStreamBeginCapture"),
    "stream_end_capture": ("hipStreamEndCapture", "cudaStreamEndCapture"),
    "graph_instantiate": ("hipGraphInstantiate", "cudaGraphInstantiate"),
    "graph_launch": ("hipGraphLaunch", "cudaGraphLaunch"),
    "graph_destroy": ("hipGraphDestroy", "cudaGraphDestroy"),
    "graph_exec_destroy": ("hipGraphExecDestroy", "cudaGraphExecDestroy"),
    "get_error_string": ("hipGetErrorString", "cudaGetErrorString"),
}
@dataclass(frozen=True)
class _Runtime:
name: str
success: Any
capture_mode_relaxed: Any
funcs: dict[str, Callable[..., Any] | None]
@classmethod
def create(cls, name: str, module: Any, success: Any, capture_mode_relaxed: Any) -> "_Runtime":
index = 0 if name == "hip" else 1
funcs = {
attr: (None if names[index] is None else getattr(module, names[index]))
for attr, names in _API_NAMES.items()
}
return cls(name=name, success=success, capture_mode_relaxed=capture_mode_relaxed, funcs=funcs)
def call(self, name: str, *args: Any) -> tuple[Any, ...]:
fn = self.funcs[name]
if fn is None:
raise RuntimeError(f"{name} is not available for {self.name}")
result = fn(*args)
if not isinstance(result, tuple):
result = (result,)
self.check(result[0], name)
return result[1:]
def check(self, error: Any, api: str) -> None:
if error == self.success:
return
result = self.funcs["get_error_string"](error)
if not isinstance(result, tuple):
result = (result,)
err, message = result
if err != self.success:
raise RuntimeError(f"{api} failed with error {int(error)}")
decoded = message.decode("utf-8") if isinstance(message, bytes) else str(message)
raise RuntimeError(f"{api} failed: {decoded} ({int(error)})")
def _load_runtime() -> _Runtime:
    """Return the first usable GPU runtime binding, trying HIP before CUDA.

    A backend is accepted when its Python binding imports and its
    device-count query succeeds with a positive count. When the HIP probe
    fails, the reason is recorded and the CUDA binding is tried next. Possible
    failures are a missing package, a missing symbol, or a runtime call that
    reports an error.

    Returns:
        The selected :class:`_Runtime`.

    Raises:
        RuntimeError: if neither backend is usable. The message lists why
            each probe was rejected.
    """
    errors: list[str] = []
    try:
        from hip import hip

        runtime = _Runtime.create(
            name="hip",
            module=hip,
            success=hip.hipError_t.hipSuccess,
            capture_mode_relaxed=hip.hipStreamCaptureMode.hipStreamCaptureModeRelaxed,
        )
        count = runtime.call("get_device_count")[0]
        if count and count > 0:
            return runtime
        errors.append(f"hipGetDeviceCount returned count={count}")
    except ImportError as exc:
        errors.append(f"hip-python unavailable: {exc}")
    except (AttributeError, OSError, RuntimeError) as exc:
        # runtime.call raises RuntimeError on a non-success status (presumably
        # e.g. hip-python installed but no HIP device). Record it and fall
        # through so the CUDA binding still gets a chance.
        errors.append(f"hip runtime unusable: {exc}")
    try:
        from cuda.bindings import runtime as cuda_runtime

        runtime = _Runtime.create(
            name="cuda",
            module=cuda_runtime,
            success=cuda_runtime.cudaError_t.cudaSuccess,
            capture_mode_relaxed=cuda_runtime.cudaStreamCaptureMode.cudaStreamCaptureModeRelaxed,
        )
        count = runtime.call("get_device_count")[0]
        if count and count > 0:
            return runtime
        errors.append(f"cudaGetDeviceCount returned count={count}")
    except ImportError as exc:
        errors.append(f"cuda-bindings unavailable: {exc}")
    except (AttributeError, OSError, RuntimeError) as exc:
        errors.append(f"cuda runtime unusable: {exc}")
    raise RuntimeError("No usable CUDA/HIP Python runtime found: " + "; ".join(errors))
# The backend is selected once, when this module is imported. Import fails
# (``_load_runtime`` raises RuntimeError) if no HIP/CUDA binding reporting at
# least one device can be loaded.
_RUNTIME: _Runtime = _load_runtime()
class Graph:
    """Owning handle for an instantiated (executable) CUDA/HIP graph.

    Release the handle with :meth:`close`, which is idempotent. Instances can
    also be used as context managers, so the executable graph is freed even
    if the body raises.
    """

    def __init__(self, graph_exec: Any) -> None:
        # Executable-graph handle; reset to None once the graph is destroyed.
        self._graph_exec = graph_exec

    def __enter__(self) -> "Graph":
        return self

    def __exit__(self, *exc_info: object) -> None:
        self.close()

    def launch(self, stream: Any) -> None:
        """Enqueue one execution of the graph on *stream*.

        Raises:
            RuntimeError: if the graph was already closed, or the launch fails.
        """
        if self._graph_exec is None:
            # Refuse rather than hand a None handle to the runtime.
            raise RuntimeError("cannot launch a Graph after close()")
        _api("graph_launch")(self._graph_exec, _stream_ptr(stream))

    def close(self) -> None:
        """Destroy the executable graph. Calling it again has no effect."""
        if self._graph_exec is not None:
            _api("graph_exec_destroy")(self._graph_exec)
            self._graph_exec = None
def init_runtime() -> None:
    """Do nothing and return None.

    The backend is already chosen when this module is imported (``_RUNTIME``),
    so there is no additional setup to perform here.
    """
def capture_graph(stream: Any, capture_fn: Callable[[], None]) -> Graph:
    """Record the work enqueued by *capture_fn* on *stream* as an executable graph.

    Steps:

    1. Re-select the current device.
    2. Begin a relaxed-mode capture on *stream* and run ``capture_fn``.
    3. End the capture and instantiate the result into a :class:`Graph`.

    The intermediate (non-executable) graph is always destroyed afterwards.

    Args:
        stream: Stream handle, or an object resolvable by ``_stream_ptr``.
        capture_fn: Zero-argument callable that enqueues work on *stream*.

    Returns:
        A :class:`Graph` owning the instantiated executable graph.

    Raises:
        Whatever ``capture_fn`` raises, after the capture has been torn down.
        RuntimeError if a runtime call fails.
    """
    _api("set_device")(current_device())
    stream_ptr = _stream_ptr(stream)
    _api("stream_begin_capture")(stream_ptr, _RUNTIME.capture_mode_relaxed)
    try:
        capture_fn()
    except BaseException:
        # Take the stream out of capture mode (and free any partial graph)
        # before propagating, even for KeyboardInterrupt.
        _abort_capture(stream_ptr)
        raise
    graph = _api("stream_end_capture")(stream_ptr)[0]
    try:
        return Graph(_instantiate_graph(graph))
    finally:
        # Release the source graph whether or not instantiation succeeded;
        # the returned Graph holds only the executable handle.
        if graph is not None:
            _api("graph_destroy")(graph)


def _abort_capture(stream_ptr: int) -> None:
    """Best-effort: end an in-progress capture and destroy any graph it yields."""
    try:
        leftover = _api("stream_end_capture")(stream_ptr)[0]
    except Exception:
        # The capture may already be invalidated or ended. The caller is
        # re-raising the original error, so a cleanup failure is ignored.
        return
    if leftover is not None:
        try:
            _api("graph_destroy")(leftover)
        except Exception:
            # Deliberately best-effort for the same reason as above.
            pass
def current_device() -> int:
    """Return the device ordinal currently selected in the runtime, as an int."""
    outputs = _api("get_device")()
    device = outputs[0]
    return int(device)
def device_name(device_id: int | None = None) -> str:
    """Return the ``name`` property of a device (the current one by default).

    Bytes names are decoded as UTF-8. When the property object has no
    ``name`` attribute, "UNKNOWN" is returned.
    """
    target = current_device() if device_id is None else device_id
    properties = _api("get_device_properties")(int(target))[0]
    raw_name = getattr(properties, "name", "UNKNOWN")
    if isinstance(raw_name, bytes):
        return raw_name.decode("utf-8")
    return str(raw_name)
def _stream_ptr(stream: Any) -> int:
return int(getattr(stream, "ptr", stream))
def _instantiate_graph(graph: Any) -> Any:
    """Instantiate *graph* into an executable graph handle and return it.

    The two bindings are called with different argument lists. The HIP one
    gets (graph, None, 0) and the CUDA one gets (graph, 0). Presumably these
    are (log buffer, buffer size) versus (flags); confirm against the binding
    docs.
    """
    if _RUNTIME.name == "hip":
        call_args = (graph, None, 0)
    else:
        call_args = (graph, 0)
    return _api("graph_instantiate")(*call_args)[0]
def _api(name: str) -> Callable[..., tuple[Any, ...]]:
api = globals().get(name)
if api is None:
api = __getattr__(name)
return api
def _make_api(name: str) -> Callable[..., tuple[Any, ...]]:
def api(*args: Any) -> tuple[Any, ...]:
return _RUNTIME.call(name, *args)
api.__name__ = name
return api
def __getattr__(name: str) -> Callable[..., tuple[Any, ...]]:
    """Module-level attribute hook: lazily create runtime API wrappers.

    For a name listed in ``_API_NAMES``, build its wrapper, store it in the
    module globals so later lookups bypass this hook, and return it. Any
    other name raises AttributeError.
    """
    if name not in _API_NAMES:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    wrapper = _make_api(name)
    globals()[name] = wrapper
    return wrapper