mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-06-29 19:07:30 +00:00
- Add a unit test for the float8_e4m3b15 data type. - Add a tuner and benchmark for the allreduce/allgather algorithms to verify correctness and performance.
85 lines
3.2 KiB
Python
85 lines
3.2 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT License.
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Any, Callable, Iterable
|
|
|
|
from mscclpp_benchmark.tuning_config import TunedConfig
|
|
|
|
|
|
class OfflineTuner:
    """Search algorithm / ``nblocks`` / ``nthreads`` combinations for one case.

    For every ``(algorithm, spec)`` pair returned by ``candidate_algorithms``,
    the tuner tries each ``(nblocks, nthreads)`` pair from the configured grids.
    ``nblocks`` values above ``spec.max_nblocks`` are skipped. For each pair it
    checks correctness, times the configuration, and keeps the fastest one.
    """

    def __init__(
        self,
        comm: Any,
        *,
        candidate_nblocks: Iterable[int],
        candidate_nthreads: Iterable[int],
        n_warmup: int,
        n_graph_launches: int,
        n_ops_per_graph: int,
        candidate_algorithms: Callable[[Any, Any], list[tuple[Any, Any]]],
        check_correctness: Callable[..., bool],
        measure: Callable[..., float | None],
    ) -> None:
        """Store the communicator, the search grid and the evaluation callbacks.

        Args:
            comm: Communicator wrapper. ``tune`` reads ``comm.rank`` and calls
                ``comm.reset(config)`` and ``comm.resolve_config(case)``.
            candidate_nblocks: Candidate ``nblocks`` values. Stored as a tuple so
                a one-shot iterable can be re-scanned for every algorithm and
                every ``tune`` call.
            candidate_nthreads: Candidate ``nthreads`` values, stored as a tuple
                for the same reason.
            n_warmup: Forwarded to ``measure`` as the ``n_warmup`` keyword.
            n_graph_launches: Forwarded to ``measure`` as ``n_graph_launches``.
            n_ops_per_graph: Forwarded to ``measure`` as ``n_ops_per_graph``.
            candidate_algorithms: ``(comm, case) -> [(algorithm, spec), ...]``.
                ``algorithm.name`` and ``spec.max_nblocks`` are read from each
                pair.
            check_correctness: ``(comm, case, config) -> bool``. A falsy
                result discards the config.
            measure: ``(comm, case, config, *, n_warmup, n_graph_launches,
                n_ops_per_graph) -> float | None``. ``None`` discards the config.
                The value is presumably microseconds, per the ``time_us``
                naming; confirm against the implementation.
        """
        self.comm = comm
        self.candidate_nblocks = tuple(candidate_nblocks)
        self.candidate_nthreads = tuple(candidate_nthreads)
        self.n_warmup = n_warmup
        self.n_graph_launches = n_graph_launches
        self.n_ops_per_graph = n_ops_per_graph
        self._candidate_algorithms = candidate_algorithms
        self._check_correctness = check_correctness
        self._measure = measure

    def tune(self, case: Any) -> TunedConfig | None:
        """Return the fastest correct configuration for ``case``.

        ``case`` must expose ``collective`` and ``message_size``, which are used
        in the skip message. It may expose ``symmetric_memory``; a missing
        attribute is treated as ``False``.

        Returns:
            ``None`` if ``candidate_algorithms`` yields no candidates; rank 0
            also prints a skip message. Otherwise, the fastest config that
            passed ``check_correctness`` and produced a usable timing, with
            ``time_us`` set. Ties keep the earliest candidate. If no config
            qualifies, the result of ``comm.resolve_config(case)`` is returned.

        Raises:
            Whatever ``check_correctness`` or ``measure`` raise. ``comm.reset``
            is still called for the config under evaluation before the
            exception propagates.
        """
        symmetric_memory = bool(getattr(case, "symmetric_memory", False))
        candidates = self._candidate_algorithms(self.comm, case)
        if not candidates:
            if self.comm.rank == 0:
                print(
                    f"[skip] no supported tuning candidates for collective={case.collective} "
                    f"size={case.message_size}",
                    flush=True,
                )
            return None

        # NOTE(review): tune() appears to run on every rank. If ranks can observe
        # different correctness results or timings, they could select different
        # configs. Presumably check_correctness/measure already agree across
        # ranks; confirm.
        best_config: TunedConfig | None = None
        best_time_us = float("inf")
        for algorithm_name, nblocks, nthreads in self._iter_search_points(candidates):
            config = TunedConfig(
                algorithm=algorithm_name,
                nblocks=nblocks,
                nthreads=nthreads,
                symmetric_memory=symmetric_memory,
            )
            time_us = self._evaluate(case, config)
            # Written as ``not <`` rather than ``>=``. A NaN timing compares
            # False either way, so this rejects it instead of letting it
            # displace the current best. Equal timings keep the earlier config.
            if time_us is None or not time_us < best_time_us:
                continue
            best_time_us = time_us
            best_config = TunedConfig(
                algorithm=algorithm_name,
                nblocks=nblocks,
                nthreads=nthreads,
                symmetric_memory=symmetric_memory,
                time_us=time_us,
            )

        if best_config is None:
            # Nothing passed or produced a usable timing; defer to the communicator.
            return self.comm.resolve_config(case)
        return best_config

    def _iter_search_points(
        self, candidates: Iterable[tuple[Any, Any]]
    ) -> Iterable[tuple[Any, int, int]]:
        """Yield ``(algorithm_name, nblocks, nthreads)`` in search order.

        ``nblocks`` is the outer grid loop and ``nthreads`` the inner one.
        ``nblocks`` values greater than a candidate's ``spec.max_nblocks`` are
        skipped when that limit is not ``None``.
        """
        for algorithm, candidate_spec in candidates:
            for nblocks in self.candidate_nblocks:
                if candidate_spec.max_nblocks is not None and nblocks > candidate_spec.max_nblocks:
                    continue
                for nthreads in self.candidate_nthreads:
                    yield algorithm.name, nblocks, nthreads

    def _evaluate(self, case: Any, config: TunedConfig) -> float | None:
        """Check and time ``config``; return its timing, or ``None`` if it failed.

        ``comm.reset(config)`` is called after the correctness check and again
        after measurement. ``finally`` guarantees a raising callback still gets
        the same reset call as a normal run before its exception propagates.
        """
        try:
            passed = self._check_correctness(self.comm, case, config)
        finally:
            self.comm.reset(config)
        if not passed:
            return None
        try:
            return self._measure(
                self.comm,
                case,
                config,
                n_warmup=self.n_warmup,
                n_graph_launches=self.n_graph_launches,
                n_ops_per_graph=self.n_ops_per_graph,
            )
        finally:
            self.comm.reset(config)
|