Files
mscclpp/python/mscclpp_benchmark/tuner.py
Binyang Li c9f8be64bb Add collective benchmark and correctness check (#814)
- Add unit-test for float8_e4m3b15 data type.
- Add tuner and benchmark for allreduce/allgather algorithms to ensure
correctness and performance.
2026-06-04 09:22:10 -07:00

85 lines
3.2 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
from typing import Any, Callable, Iterable
from mscclpp_benchmark.tuning_config import TunedConfig
class OfflineTuner:
    """Grid-search tuner that picks the fastest correct configuration for a case.

    For every ``(algorithm, spec)`` pair returned by ``candidate_algorithms``
    and every ``nblocks`` x ``nthreads`` combination from the candidate lists,
    the tuner builds a :class:`TunedConfig`, validates it with
    ``check_correctness``, times it with ``measure`` and keeps the one with the
    lowest reported time.
    """

    def __init__(
        self,
        comm: Any,
        *,
        candidate_nblocks: Iterable[int],
        candidate_nthreads: Iterable[int],
        n_warmup: int,
        n_graph_launches: int,
        n_ops_per_graph: int,
        candidate_algorithms: Callable[[Any, Any], list[tuple[Any, Any]]],
        check_correctness: Callable[..., bool],
        measure: Callable[..., float | None],
    ) -> None:
        """Store the search space and the callbacks used to evaluate candidates.

        Args:
            comm: Communicator handed to every callback. The tuner itself reads
                ``comm.rank`` and calls ``comm.reset(config)`` and
                ``comm.resolve_config(case)``.
            candidate_nblocks: Block counts to try. Materialised into a tuple
                because it is iterated once per candidate algorithm.
            candidate_nthreads: Thread counts to try. Materialised into a tuple
                because it is iterated once per (algorithm, nblocks) pair.
            n_warmup: Forwarded to ``measure`` as the ``n_warmup`` keyword.
            n_graph_launches: Forwarded to ``measure`` as ``n_graph_launches``.
            n_ops_per_graph: Forwarded to ``measure`` as ``n_ops_per_graph``.
            candidate_algorithms: ``(comm, case) -> [(algorithm, spec), ...]``.
                ``algorithm.name`` and ``spec.max_nblocks`` are read.
            check_correctness: ``(comm, case, config) -> bool``. A falsy result
                rejects the configuration.
            measure: ``(comm, case, config, *, n_warmup, n_graph_launches,
                n_ops_per_graph) -> time | None``. ``None`` rejects the
                configuration.
        """
        self.comm = comm
        self.candidate_nblocks = tuple(candidate_nblocks)
        self.candidate_nthreads = tuple(candidate_nthreads)
        self.n_warmup = n_warmup
        self.n_graph_launches = n_graph_launches
        self.n_ops_per_graph = n_ops_per_graph
        self._candidate_algorithms = candidate_algorithms
        self._check_correctness = check_correctness
        self._measure = measure

    def tune(self, case: Any) -> TunedConfig | None:
        """Return the fastest correct configuration for ``case``.

        Args:
            case: Benchmark case. ``collective`` and ``message_size`` are read
                for the skip message. The optional ``symmetric_memory``
                attribute (default False) is copied into every config.

        Returns:
            The best :class:`TunedConfig`, with ``time_us`` set to its measured
            time, among candidates that passed ``check_correctness`` and for
            which ``measure`` returned a time. On ties the earliest candidate
            wins.

            If candidates existed but none qualified, returns
            ``comm.resolve_config(case)`` instead. If ``candidate_algorithms``
            produced no candidates at all, returns ``None``, and rank 0 prints
            a skip message.

        Raises:
            Whatever ``check_correctness`` or ``measure`` raise. ``comm`` is
            still reset for the failing configuration before the exception
            propagates.
        """
        symmetric_memory = bool(getattr(case, "symmetric_memory", False))
        candidates = self._candidate_algorithms(self.comm, case)
        if not candidates:
            if self.comm.rank == 0:
                print(
                    f"[skip] no supported tuning candidates for collective={case.collective} "
                    f"size={case.message_size}",
                    flush=True,
                )
            return None

        best_config: TunedConfig | None = None
        best_time_us = float("inf")
        for algorithm, candidate_spec in candidates:
            for nblocks in self.candidate_nblocks:
                # Skip block counts above the candidate's limit (None = no limit).
                if candidate_spec.max_nblocks is not None and nblocks > candidate_spec.max_nblocks:
                    continue
                for nthreads in self.candidate_nthreads:
                    config = TunedConfig(
                        algorithm=algorithm.name,
                        nblocks=nblocks,
                        nthreads=nthreads,
                        symmetric_memory=symmetric_memory,
                    )
                    time_us = self._evaluate(case, config)
                    # Strict improvement only, so ties keep the earlier candidate.
                    if time_us is None or time_us >= best_time_us:
                        continue
                    best_time_us = time_us
                    best_config = TunedConfig(
                        algorithm=algorithm.name,
                        nblocks=nblocks,
                        nthreads=nthreads,
                        symmetric_memory=symmetric_memory,
                        time_us=time_us,
                    )

        if best_config is None:
            # NOTE(review): presumably the communicator's default/heuristic
            # configuration for this case; confirm against resolve_config.
            return self.comm.resolve_config(case)
        return best_config

    def _evaluate(self, case: Any, config: TunedConfig) -> float | None:
        """Check and time one configuration; return its time, or None if rejected.

        ``comm.reset(config)`` is called between the correctness run and the
        timed run, and once more when evaluation finishes. The final reset sits
        in a ``finally`` block, so it is not skipped when ``check_correctness``
        or ``measure`` raises.
        """
        try:
            if not self._check_correctness(self.comm, case, config):
                return None
            # Reset between the correctness run and the timed run.
            self.comm.reset(config)
            return self._measure(
                self.comm,
                case,
                config,
                n_warmup=self.n_warmup,
                n_graph_launches=self.n_graph_launches,
                n_ops_per_graph=self.n_ops_per_graph,
            )
        finally:
            self.comm.reset(config)