This commit is contained in:
Binyang Li
2026-05-16 19:29:18 +00:00
parent 94af88d88d
commit f32cfb1fb8

View File

@@ -54,16 +54,12 @@ def _round_pow2(size: int) -> int:
# -- CustomizedComm -----------------------------------------------------------
def _env_int(name: str, default: int) -> int:
return int(os.environ.get(name, default))
class CustomizedComm:
"""Exposes all_reduce, all_gather, barrier with lazy per-size tuning."""
_TUNE_N_WARMUP = _env_int("TUNE_N_WARMUP", 2)
_TUNE_N_GRAPH_LAUNCHES = _env_int("TUNE_N_GRAPH_LAUNCHES", 3)
_TUNE_N_OPS_PER_GRAPH = _env_int("TUNE_N_OPS_PER_GRAPH", 20)
_TUNE_N_WARMUP = 3
_TUNE_N_GRAPH_LAUNCHES = 5
_TUNE_N_OPS_PER_GRAPH = 50
_CANDIDATE_NBLOCKS = [4, 8, 16, 24, 32, 48, 64, 112, 128]
_CANDIDATE_NTHREADS = [512, 768, 1024]
_NBLOCKS_LIMIT = {
@@ -88,10 +84,10 @@ class CustomizedComm:
("default_allreduce_rsag", 512 << 10, None, None),
]
_AR_CANDIDATES_SINGLE = [
("default_allreduce_packet", 0, 512 << 10, None),
("default_allreduce_allpair_packet", 0, 128 << 10, None),
("default_allreduce_nvls_packet", 0, 64 << 10, lambda c: c._nvls),
("default_allreduce_rsag_zero_copy", 512 << 10, None, lambda c: not (c._nvls and c.symmetric_memory)),
("default_allreduce_packet", 0, 4 << 20, None),
("default_allreduce_allpair_packet", 0, 512 << 10, None),
("default_allreduce_nvls_packet", 0, 512 << 10, lambda c: c._nvls),
("default_allreduce_rsag_zero_copy", 512 << 10, None, None),
("default_allreduce_nvls_zero_copy", 512 << 10, None, lambda c: c._nvls and c.symmetric_memory),
("default_allreduce_fullmesh", 0, None, lambda c: torch.version.hip is not None),
]
@@ -228,11 +224,6 @@ class CustomizedComm:
symmetric_memory=False,
)
def _is_tune_config_supported(self, algo, nb, nt):
if algo.name in ("default_allreduce_packet", "default_allreduce_allpair_packet"):
return nb >= self.world_size - 1 and nt in (512, 1024)
return True
def _tune_size(self, collective: str, target_size: int):
"""Auto-tune one (collective, target_size) pair and cache result."""
buf = self._ensure_tune_bufs()
@@ -248,15 +239,13 @@ class CustomizedComm:
if nb > nb_limit:
continue
for nt in self._CANDIDATE_NTHREADS:
if not self._is_tune_config_supported(algo, nb, nt):
continue
# Feasibility — sync result across ranks so all agree
ret = run(algo, nb, nt)
torch.cuda.synchronize()
self._time_buf[0] = float(ret)
self._exec_ar(self._time_buf[:1], *self._default_ar_config(), sym=self.symmetric_memory)
if self._time_buf[0].item() != 0:
continue
torch.cuda.synchronize()
used.add(algo)
# Warmup