mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-25 07:14:40 +00:00
update
This commit is contained in:
@@ -54,16 +54,12 @@ def _round_pow2(size: int) -> int:
|
||||
# -- CustomizedComm -----------------------------------------------------------
|
||||
|
||||
|
||||
def _env_int(name: str, default: int) -> int:
|
||||
return int(os.environ.get(name, default))
|
||||
|
||||
|
||||
class CustomizedComm:
|
||||
"""Exposes all_reduce, all_gather, barrier with lazy per-size tuning."""
|
||||
|
||||
_TUNE_N_WARMUP = _env_int("TUNE_N_WARMUP", 2)
|
||||
_TUNE_N_GRAPH_LAUNCHES = _env_int("TUNE_N_GRAPH_LAUNCHES", 3)
|
||||
_TUNE_N_OPS_PER_GRAPH = _env_int("TUNE_N_OPS_PER_GRAPH", 20)
|
||||
_TUNE_N_WARMUP = 3
|
||||
_TUNE_N_GRAPH_LAUNCHES = 5
|
||||
_TUNE_N_OPS_PER_GRAPH = 50
|
||||
_CANDIDATE_NBLOCKS = [4, 8, 16, 24, 32, 48, 64, 112, 128]
|
||||
_CANDIDATE_NTHREADS = [512, 768, 1024]
|
||||
_NBLOCKS_LIMIT = {
|
||||
@@ -88,10 +84,10 @@ class CustomizedComm:
|
||||
("default_allreduce_rsag", 512 << 10, None, None),
|
||||
]
|
||||
_AR_CANDIDATES_SINGLE = [
|
||||
("default_allreduce_packet", 0, 512 << 10, None),
|
||||
("default_allreduce_allpair_packet", 0, 128 << 10, None),
|
||||
("default_allreduce_nvls_packet", 0, 64 << 10, lambda c: c._nvls),
|
||||
("default_allreduce_rsag_zero_copy", 512 << 10, None, lambda c: not (c._nvls and c.symmetric_memory)),
|
||||
("default_allreduce_packet", 0, 4 << 20, None),
|
||||
("default_allreduce_allpair_packet", 0, 512 << 10, None),
|
||||
("default_allreduce_nvls_packet", 0, 512 << 10, lambda c: c._nvls),
|
||||
("default_allreduce_rsag_zero_copy", 512 << 10, None, None),
|
||||
("default_allreduce_nvls_zero_copy", 512 << 10, None, lambda c: c._nvls and c.symmetric_memory),
|
||||
("default_allreduce_fullmesh", 0, None, lambda c: torch.version.hip is not None),
|
||||
]
|
||||
@@ -228,11 +224,6 @@ class CustomizedComm:
|
||||
symmetric_memory=False,
|
||||
)
|
||||
|
||||
def _is_tune_config_supported(self, algo, nb, nt):
|
||||
if algo.name in ("default_allreduce_packet", "default_allreduce_allpair_packet"):
|
||||
return nb >= self.world_size - 1 and nt in (512, 1024)
|
||||
return True
|
||||
|
||||
def _tune_size(self, collective: str, target_size: int):
|
||||
"""Auto-tune one (collective, target_size) pair and cache result."""
|
||||
buf = self._ensure_tune_bufs()
|
||||
@@ -248,15 +239,13 @@ class CustomizedComm:
|
||||
if nb > nb_limit:
|
||||
continue
|
||||
for nt in self._CANDIDATE_NTHREADS:
|
||||
if not self._is_tune_config_supported(algo, nb, nt):
|
||||
continue
|
||||
# Feasibility — sync result across ranks so all agree
|
||||
ret = run(algo, nb, nt)
|
||||
torch.cuda.synchronize()
|
||||
self._time_buf[0] = float(ret)
|
||||
self._exec_ar(self._time_buf[:1], *self._default_ar_config(), sym=self.symmetric_memory)
|
||||
if self._time_buf[0].item() != 0:
|
||||
continue
|
||||
torch.cuda.synchronize()
|
||||
used.add(algo)
|
||||
|
||||
# Warmup
|
||||
|
||||
Reference in New Issue
Block a user