From f32cfb1fb87be2adce4a33b695ccec43441bf3bc Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Sat, 16 May 2026 19:29:18 +0000 Subject: [PATCH] update --- .../customized_comm_with_tuning.py | 27 ++++++------------- 1 file changed, 8 insertions(+), 19 deletions(-) diff --git a/examples/torch-integration/customized_comm_with_tuning.py b/examples/torch-integration/customized_comm_with_tuning.py index 0a07ca32..040fda58 100644 --- a/examples/torch-integration/customized_comm_with_tuning.py +++ b/examples/torch-integration/customized_comm_with_tuning.py @@ -54,16 +54,12 @@ def _round_pow2(size: int) -> int: # -- CustomizedComm ----------------------------------------------------------- -def _env_int(name: str, default: int) -> int: - return int(os.environ.get(name, default)) - - class CustomizedComm: """Exposes all_reduce, all_gather, barrier with lazy per-size tuning.""" - _TUNE_N_WARMUP = _env_int("TUNE_N_WARMUP", 2) - _TUNE_N_GRAPH_LAUNCHES = _env_int("TUNE_N_GRAPH_LAUNCHES", 3) - _TUNE_N_OPS_PER_GRAPH = _env_int("TUNE_N_OPS_PER_GRAPH", 20) + _TUNE_N_WARMUP = 3 + _TUNE_N_GRAPH_LAUNCHES = 5 + _TUNE_N_OPS_PER_GRAPH = 50 _CANDIDATE_NBLOCKS = [4, 8, 16, 24, 32, 48, 64, 112, 128] _CANDIDATE_NTHREADS = [512, 768, 1024] _NBLOCKS_LIMIT = { @@ -88,10 +84,10 @@ class CustomizedComm: ("default_allreduce_rsag", 512 << 10, None, None), ] _AR_CANDIDATES_SINGLE = [ - ("default_allreduce_packet", 0, 512 << 10, None), - ("default_allreduce_allpair_packet", 0, 128 << 10, None), - ("default_allreduce_nvls_packet", 0, 64 << 10, lambda c: c._nvls), - ("default_allreduce_rsag_zero_copy", 512 << 10, None, lambda c: not (c._nvls and c.symmetric_memory)), + ("default_allreduce_packet", 0, 4 << 20, None), + ("default_allreduce_allpair_packet", 0, 512 << 10, None), + ("default_allreduce_nvls_packet", 0, 512 << 10, lambda c: c._nvls), + ("default_allreduce_rsag_zero_copy", 512 << 10, None, None), ("default_allreduce_nvls_zero_copy", 512 << 10, None, lambda c: c._nvls and c.symmetric_memory), ("default_allreduce_fullmesh", 0, None, lambda c: torch.version.hip is not None), ] @@ -228,11 +224,6 @@ class CustomizedComm: symmetric_memory=False, ) - def _is_tune_config_supported(self, algo, nb, nt): - if algo.name in ("default_allreduce_packet", "default_allreduce_allpair_packet"): - return nb >= self.world_size - 1 and nt in (512, 1024) - return True - def _tune_size(self, collective: str, target_size: int): """Auto-tune one (collective, target_size) pair and cache result.""" buf = self._ensure_tune_bufs() @@ -248,15 +239,13 @@ class CustomizedComm: if nb > nb_limit: continue for nt in self._CANDIDATE_NTHREADS: - if not self._is_tune_config_supported(algo, nb, nt): - continue # Feasibility — sync result across ranks so all agree ret = run(algo, nb, nt) + torch.cuda.synchronize() self._time_buf[0] = float(ret) self._exec_ar(self._time_buf[:1], *self._default_ar_config(), sym=self.symmetric_memory) if self._time_buf[0].item() != 0: continue - torch.cuda.synchronize() used.add(algo) # Warmup