diff --git a/examples/torch-integration/customized_comm_with_tuning.py b/examples/torch-integration/customized_comm_with_tuning.py
index 41be5825..060a0097 100644
--- a/examples/torch-integration/customized_comm_with_tuning.py
+++ b/examples/torch-integration/customized_comm_with_tuning.py
@@ -1,193 +1,117 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.

-# MSCCLPP_MASTER_ADDR= MSCCLPP_MASTER_PORT= torchrun --nnodes=1 --nproc_per_node=8 customized_comm_with_tuning.py
+# torchrun --nnodes=1 --nproc_per_node=8 examples/torch-integration/customized_comm_with_tuning.py

 import os
-import torch
-import mscclpp.utils as mscclpp_utils
-import mscclpp
-import mscclpp.ext
-import netifaces as ni
 import ipaddress
+
+import netifaces as ni
+import torch
+
+import mscclpp
+import mscclpp.ext
+import mscclpp.utils as mscclpp_utils


-def load_algorithms(scratch_buffer: torch.tensor, rank: int) -> mscclpp.AlgorithmCollection:
-    collection_builder = mscclpp.ext.AlgorithmCollectionBuilder()
-    return collection_builder.build_default_algorithms(
-        scratch_buffer=scratch_buffer.data_ptr(), scratch_buffer_size=scratch_buffer.nbytes, rank=rank
+# -- Helpers ------------------------------------------------------------------
+
+
+def _make_tensor(size_bytes: int, dtype: torch.dtype) -> torch.Tensor:
+    """Allocate a tensor backed by RawGpuBuffer (symmetric memory)."""
+    # PyTorch's from_dlpack does not support certain float8 DLPack type codes.
+    # Work around by importing as uint8 and reinterpreting via .view().
+    _DLPACK_UNSUPPORTED = (torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz)
+    if dtype in _DLPACK_UNSUPPORTED:
+        dlpack = mscclpp.RawGpuBuffer(size_bytes).to_dlpack(data_type=str(torch.uint8))
+        return torch.utils.dlpack.from_dlpack(dlpack).view(dtype)
+    dlpack = mscclpp.RawGpuBuffer(size_bytes).to_dlpack(data_type=str(dtype))
+    return torch.utils.dlpack.from_dlpack(dlpack)
+
+
+def _load_algorithms(scratch: torch.Tensor, rank: int):
+    return mscclpp.ext.AlgorithmCollectionBuilder().build_default_algorithms(
+        scratch_buffer=scratch.data_ptr(),
+        scratch_buffer_size=scratch.nbytes,
+        rank=rank,
     )


-def interfaces_for_ip_netifaces(ip: str):
+def _interfaces_for_ip(ip: str):
     target = ipaddress.ip_address(ip)
-    for interface in ni.interfaces():
-        addresses = ni.ifaddresses(interface)
-        if ni.AF_INET in addresses:
-            for link in addresses[ni.AF_INET]:
-                if "addr" in link:
-                    addr = ipaddress.ip_address(link["addr"])
-                    if addr == target:
-                        return interface
+    for iface in ni.interfaces():
+        addrs = ni.ifaddresses(iface)
+        if ni.AF_INET in addrs:
+            for link in addrs[ni.AF_INET]:
+                if "addr" in link and ipaddress.ip_address(link["addr"]) == target:
+                    return iface
     return None


-def to_mscclpp_reduce_op(op: torch.distributed.ReduceOp) -> mscclpp.ReduceOp:
+def _to_mscclpp_op(op) -> mscclpp.ReduceOp:
     if op == torch.distributed.ReduceOp.SUM:
         return mscclpp.ReduceOp.SUM
-    elif op == torch.distributed.ReduceOp.MIN:
+    if op == torch.distributed.ReduceOp.MIN:
         return mscclpp.ReduceOp.MIN
-    else:
-        raise ValueError(f"unsupported op: {op}")
+    raise ValueError(f"unsupported op: {op}")
+
+
+def _round_pow2(size: int) -> int:
+    """Round up to next power-of-2, clamped to [1024, 256 MB]."""
+    size = max(size, 1024)
+    size = min(size, 256 << 20)
+    return 1 << (size - 1).bit_length()
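+
+
+# Worked examples for _round_pow2 (illustrative values added for clarity; not
+# called anywhere in the tuning path):
+#   _round_pow2(1)         -> 1024        # clamped up to the 1 KiB floor
+#   _round_pow2(1500)      -> 2048        # rounded up to the next power of two
+#   _round_pow2(1 << 20)   -> 1 << 20     # exact powers of two pass through
+#   _round_pow2(300 << 20) -> 256 << 20   # clamped down to the 256 MiB ceiling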
+
+
+# -- CustomizedComm -----------------------------------------------------------
+
+
 class CustomizedComm:
-    def __init__(self, comm: mscclpp.CommGroup):
+    """Exposes all_reduce, all_gather, barrier with lazy per-size tuning."""
+
+    _TUNE_N_WARMUP = 5
+    _TUNE_N_GRAPH_LAUNCHES = 10
+    _TUNE_N_OPS_PER_GRAPH = 100
+    _CANDIDATE_NBLOCKS = [4, 8, 16, 24, 32, 48, 64, 128]
+    _CANDIDATE_NTHREADS = [512, 768, 1024]
+    _NBLOCKS_LIMIT = {
+        "default_allreduce_nvls_packet": 16,
+        "default_allreduce_packet": 56,
+        "default_allreduce_allpair_packet": 56,
+        "default_allreduce_fullmesh": 64,
+        "default_allgather_fullmesh2": 32,
+    }
+
+    def __init__(self, comm: mscclpp.CommGroup, symmetric_memory: bool = False):
         self.comm = comm
         self.rank = comm.my_rank
         self.world_size = comm.nranks
-        self.local_rank = comm.my_rank % comm.nranks_per_node
-        self.n_ranks_per_node = comm.nranks_per_node
-        dlpack = mscclpp.RawGpuBuffer(1 << 27).to_dlpack(data_type=str(torch.float16))
-        self.scratch_buffer = torch.utils.dlpack.from_dlpack(dlpack)
-        algorithms = load_algorithms(scratch_buffer=self.scratch_buffer, rank=self.rank)
-        self._algorithm_nvls_packet = [
-            algo
-            for algo in algorithms
-            if algo.collective == "allreduce" and algo.name == "default_allreduce_nvls_packet"
-        ][0]
-        self._algorithm_rsag_zero_copy = [
-            algo
-            for algo in algorithms
-            if algo.collective == "allreduce" and algo.name == "default_allreduce_rsag_zero_copy"
-        ][0]
-        self._algorithm_packet = [
-            algo for algo in algorithms if algo.collective == "allreduce" and algo.name == "default_allreduce_packet"
-        ][0]
-        if mscclpp.is_nvls_supported():
-            self._algorithm_nvls_zero_copy = [
-                algo
-                for algo in algorithms
-                if algo.collective == "allreduce" and algo.name == "default_allreduce_nvls_zero_copy"
-            ][0]
-        self._tune(n_warmup=5, n_graph_launches=10, n_ops_per_graph=100)
+        self.symmetric_memory = symmetric_memory
+        self._nvls = mscclpp.is_nvls_supported()

-    def _tune(self, n_warmup, n_graph_launches, n_ops_per_graph):
-        sizes = [1 << i for i in range(10, 28)]
-        # Pre-fill with defaults for barrier
-        self.best_configs = {1024: (self._algorithm_nvls_packet, 0, 0)}
+        self._scratch = _make_tensor(1 << 27, torch.float16)
+        self._barrier_tensor = _make_tensor(4096, torch.float32)

-        tune_tensor = mscclpp.RawGpuBuffer(1 << 27).to_dlpack(data_type=str(torch.float16))
-        tune_tensor = torch.utils.dlpack.from_dlpack(tune_tensor)
-        tune_tensor.normal_()
-        candidates_nblocks = [4, 8, 16, 24, 32, 48, 64, 128]
-        candidates_nthreads = [512, 768, 1024]
+        algos = _load_algorithms(self._scratch, self.rank)
+        self._algos = {(a.collective, a.name): a for a in algos}

-        for size in sizes:
-            algos = []
-            if mscclpp.is_nvls_supported():
-                algos.append(self._algorithm_nvls_zero_copy)
-                if size <= 4 * 1024 * 1024:
-                    algos.append(self._algorithm_nvls_packet)
-            algos.append(self._algorithm_packet)
-            if size >= 512 * 1024:
-                algos.append(self._algorithm_rsag_zero_copy)
+        # {collective: {rounded_size: (algo, nblocks, nthreads)}}
+        self._tune_cache: dict[str, dict[int, tuple]] = {"allreduce": {}, "allgather": {}}
+        self._tune_buf = None
+        self._time_buf = None

-            best_time = float("inf")
-            best_config = None
+    def _algo(self, collective: str, name: str):
+        return self._algos.get((collective, name))

-            for algo in algos:
-                for nb in candidates_nblocks:
-                    if algo.name == "default_allreduce_nvls_packet" and nb > 16:
-                        continue
-                    if algo.name == "default_allreduce_packet" and nb > 56:
-                        continue
-                    for nt in candidates_nthreads:
-                        if self._run_algo(algo, tune_tensor, size, nb, nt) != 0:
-                            continue
+    def _default_ar_config(self):
+        """Fallback allreduce config for barrier / timing sync."""
+        pkt = self._algo("allreduce", "default_allreduce_nvls_packet")
+        if self._nvls and pkt:
+            return (pkt, 0, 0)
+        return (self._algo("allreduce", "default_allreduce_packet"), 0, 0)

-                        for _ in range(n_warmup):
-                            self._run_algo(algo, tune_tensor, size, nb, nt)
-                        self.barrier()
+    # -- low-level execute --

-                        capture_stream = torch.cuda.Stream()
-                        capture_stream.wait_stream(torch.cuda.current_stream())
-
-                        g = torch.cuda.CUDAGraph()
-                        # Warmup on capture stream
-                        with torch.cuda.stream(capture_stream):
-                            self._run_algo(algo, tune_tensor, size, nb, nt)
-                        capture_stream.synchronize()
-
-                        with torch.cuda.graph(g, stream=capture_stream):
-                            for _ in range(n_ops_per_graph):
-                                self._run_algo(algo, tune_tensor, size, nb, nt)
-
-                        start_event = torch.cuda.Event(enable_timing=True)
-                        end_event = torch.cuda.Event(enable_timing=True)
-                        start_event.record(capture_stream)
-                        with torch.cuda.stream(capture_stream):
-                            for _ in range(n_graph_launches):
-                                g.replay()
-                        end_event.record(capture_stream)
-                        end_event.synchronize()
-
-                        elapsed = start_event.elapsed_time(end_event)
-
-                        # Synchronize timing results across all ranks to ensure consistent algorithm selection
-                        # replicate n times such due to algo limitations
-                        time_tensor = torch.full((self.world_size,), elapsed, dtype=torch.float64, device="cuda").to(
-                            dtype=torch.float32
-                        )
-                        torch.cuda.current_stream().wait_stream(capture_stream)
-                        # TODO: use all_reduce may cause problem if the time elapsed between different algos are too close.
-                        # May change to broadcast in the future if that becomes an issue.
-                        self.all_reduce(time_tensor, op=torch.distributed.ReduceOp.SUM)
-                        avg_time = time_tensor[self.rank].item() / self.world_size
-
-                        if avg_time < best_time:
-                            best_time = avg_time
-                            best_config = (algo, nb, nt)
-
-            if best_config:
-                self.best_configs[size] = best_config
-                if self.rank == 0:
-                    print(
-                        f"Size {size}: Best Algo {best_config[0].name} nblocks {best_config[1]} nthreads {best_config[2]} Time {(best_time/(n_graph_launches * n_ops_per_graph))*1000:.2f} us"
-                    )
-        # reset the algorithms after tuning
-        torch.cuda.synchronize()
-        for algo in algos:
-            algo.reset()
-
-    def _run_algo(self, algo: mscclpp.Algorithm, tensor, size, nblocks, nthreads):
-        return algo.execute(
-            comm=self.comm.communicator,
-            input_buffer=tensor.data_ptr(),
-            output_buffer=tensor.data_ptr(),
-            input_size=size,
-            output_size=size,
-            dtype=mscclpp_utils.torch_dtype_to_mscclpp_dtype(tensor.dtype),
-            op=mscclpp.ReduceOp.SUM,
-            stream=torch.cuda.current_stream().cuda_stream,
-            nblocks=nblocks,
-            nthreads_per_block=nthreads,
-            symmetric_memory=True,
-        )
-
-    def get_tuned_config(self, size):
-        if size < 1024:
-            target_size = 1024
-        elif size > 256 * 1024 * 1024:
-            target_size = 256 * 1024 * 1024
-        else:
-            target_size = 1 << (size - 1).bit_length()
-        return self.best_configs.get(target_size)
-
-    def all_reduce(self, tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM, stream: torch.cuda.Stream = None):
-        assert op == torch.distributed.ReduceOp.SUM
-        config = self.get_tuned_config(tensor.nbytes)
-        algo, nblocks, nthreads = config if config else (self._algorithm_nvls_packet, 0, 0)
+    def _exec_ar(self, tensor, algo, nb, nt, op=mscclpp.ReduceOp.SUM, stream=None, accum_dtype=None, sym=True):
+        s = stream.cuda_stream if stream else torch.cuda.current_stream().cuda_stream
         ret = algo.execute(
             comm=self.comm.communicator,
             input_buffer=tensor.data_ptr(),
@@ -195,107 +119,357 @@ class CustomizedComm:
             input_size=tensor.nbytes,
             output_size=tensor.nbytes,
             dtype=mscclpp_utils.torch_dtype_to_mscclpp_dtype(tensor.dtype),
-            op=to_mscclpp_reduce_op(op),
-            stream=stream.cuda_stream if stream is not None else torch.cuda.current_stream().cuda_stream,
-            nblocks=nblocks,
-            nthreads_per_block=nthreads,
-            symmetric_memory=True,
+            op=op,
+            stream=s,
+            nblocks=nb,
+            nthreads_per_block=nt,
+            symmetric_memory=sym,
+            accum_dtype=accum_dtype,
         )
         if ret != 0:
-            print(f"Rank {self.rank}: Algo {algo.name} failed with error {ret}")
+            print(f"Rank {self.rank}: {algo.name} failed ({ret})")
+        return ret
+
+    def _exec_ag(self, inp, out, algo, nb, nt, stream=None, sym=None):
+        if sym is None:
+            sym = self.symmetric_memory
+        s = stream.cuda_stream if stream else torch.cuda.current_stream().cuda_stream
+        ret = algo.execute(
+            comm=self.comm.communicator,
+            input_buffer=inp.data_ptr(),
+            output_buffer=out.data_ptr(),
+            input_size=inp.nbytes,
+            output_size=out.nbytes,
+            dtype=mscclpp_utils.torch_dtype_to_mscclpp_dtype(inp.dtype),
+            op=mscclpp.ReduceOp.NOP,
+            stream=s,
+            nblocks=nb,
+            nthreads_per_block=nt,
+            symmetric_memory=sym,
+        )
+        if ret != 0:
+            print(f"Rank {self.rank}: AG {algo.name} failed ({ret})")
+        return ret
+
+    def _barrier_internal(self):
+        a, nb, nt = self._default_ar_config()
+        self._exec_ar(self._barrier_tensor, a, nb, nt, sym=True)
+
+    # -- lazy tuning --
+
+    def _ensure_tune_bufs(self):
+        if self._tune_buf is None:
+            self._tune_buf = _make_tensor(1 << 27, torch.float16)
+            self._tune_buf.normal_()
+            self._time_buf = _make_tensor(4096, torch.float32)
+        return self._tune_buf
+
+    def _ar_candidates(self, size: int):
+        out = []
+        if size <= 4 << 20:
+            a = self._algo("allreduce", "default_allreduce_nvls_packet")
+            if self._nvls and a:
+                out.append(a)
+            a = self._algo("allreduce", "default_allreduce_packet")
+            if a:
+                out.append(a)
+            a = self._algo("allreduce", "default_allreduce_allpair_packet")
+            if a:
+                out.append(a)
+        if size >= 512 << 10:
+            a = self._algo("allreduce", "default_allreduce_nvls_zero_copy")
+            if self._nvls and self.symmetric_memory and a:
+                out.append(a)
+            a = self._algo("allreduce", "default_allreduce_rsag_zero_copy")
+            if a:
+                out.append(a)
+        if torch.version.hip is not None:
+            a = self._algo("allreduce", "default_allreduce_fullmesh")
+            if a:
+                out.append(a)
+        return out
+
+    def _ag_candidates(self):
+        a = self._algo("allgather", "default_allgather_fullmesh2")
+        return [a] if a else []
+
+    def _run_tune(self, collective, algo, buf, size, nb, nt):
+        """Single tune invocation for either collective."""
+        if collective == "allreduce":
+            return algo.execute(
+                comm=self.comm.communicator,
+                input_buffer=buf.data_ptr(),
+                output_buffer=buf.data_ptr(),
+                input_size=size,
+                output_size=size,
+                dtype=mscclpp_utils.torch_dtype_to_mscclpp_dtype(buf.dtype),
+                op=mscclpp.ReduceOp.SUM,
+                stream=torch.cuda.current_stream().cuda_stream,
+                nblocks=nb,
+                nthreads_per_block=nt,
+                symmetric_memory=True,
+            )
+        else:
+            total = size * self.world_size
+            out_ptr = buf.data_ptr()
+            return algo.execute(
+                comm=self.comm.communicator,
+                input_buffer=out_ptr + self.rank * size,
+                output_buffer=out_ptr,
+                input_size=size,
+                output_size=total,
+                dtype=mscclpp_utils.torch_dtype_to_mscclpp_dtype(buf.dtype),
+                op=mscclpp.ReduceOp.NOP,
+                stream=torch.cuda.current_stream().cuda_stream,
+                nblocks=nb,
+                nthreads_per_block=nt,
+                symmetric_memory=False,
+            )
+
+    def _tune_size(self, collective: str, target_size: int):
+        """Auto-tune one (collective, target_size) pair and cache the result."""
+        buf = self._ensure_tune_bufs()
+        cands = self._ar_candidates(target_size) if collective == "allreduce" else self._ag_candidates()
+
+        best_time, best_cfg = float("inf"), None
+        used = set()
+        run = lambda a, nb, nt: self._run_tune(collective, a, buf, target_size, nb, nt)
+
+        for algo in cands:
+            nb_limit = self._NBLOCKS_LIMIT.get(algo.name, 128)
+            for nb in self._CANDIDATE_NBLOCKS:
+                if nb > nb_limit:
+                    continue
+                for nt in self._CANDIDATE_NTHREADS:
+                    # Feasibility — sync result across ranks so all agree
+                    ret = run(algo, nb, nt)
+                    torch.cuda.synchronize()
+                    self._time_buf[0] = float(ret)
+                    self._exec_ar(self._time_buf[:1], *self._default_ar_config(), sym=True)
+                    if self._time_buf[0].item() != 0:
+                        continue
+                    used.add(algo)
+
+                    # Warmup
+                    for _ in range(self._TUNE_N_WARMUP):
+                        run(algo, nb, nt)
+
+                    # CUDA-graph timed benchmark
+                    cs = torch.cuda.Stream()
+                    cs.wait_stream(torch.cuda.current_stream())
+                    g = torch.cuda.CUDAGraph()
+                    with torch.cuda.graph(g, stream=cs):
+                        for _ in range(self._TUNE_N_OPS_PER_GRAPH):
+                            run(algo, nb, nt)
+
+                    start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
+                    start.record(cs)
+                    with torch.cuda.stream(cs):
+                        for _ in range(self._TUNE_N_GRAPH_LAUNCHES):
+                            g.replay()
+                    end.record(cs)
+                    end.synchronize()
+                    elapsed = start.elapsed_time(end)
+
+                    # Cross-rank timing sync
+                    self._time_buf.fill_(elapsed)
+                    torch.cuda.current_stream().wait_stream(cs)
+                    self._exec_ar(self._time_buf, *self._default_ar_config(), sym=True)
+                    avg = self._time_buf[self.rank].item() / self.world_size
+
+                    if avg < best_time:
+                        best_time, best_cfg = avg, (algo, nb, nt)
+
+        if best_cfg:
+            self._tune_cache[collective][target_size] = best_cfg
+            if self.rank == 0:
+                n = self._TUNE_N_GRAPH_LAUNCHES * self._TUNE_N_OPS_PER_GRAPH
+                print(
+                    f"[tune] {collective} size={target_size}: {best_cfg[0].name} "
+                    f"nb={best_cfg[1]} nt={best_cfg[2]} time={best_time / n * 1000:.2f}us",
+                    flush=True,
+                )
+        else:
+            fb = (
+                self._default_ar_config()
+                if collective == "allreduce"
+                else ((self._ag_candidates()[0], 32, 512) if self._ag_candidates() else None)
+            )
+            self._tune_cache[collective][target_size] = fb
+
+        torch.cuda.synchronize()
+        self._barrier_internal()
+        for a in used:
+            a.reset()
+
+    # -- public API --
+
+    def all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, stream=None, accum_dtype=None):
+        sz = _round_pow2(tensor.nbytes)
+        if sz not in self._tune_cache["allreduce"]:
+            self._tune_size("allreduce", sz)
+        a, nb, nt = self._tune_cache["allreduce"][sz]
+        self._exec_ar(
+            tensor, a, nb, nt, op=_to_mscclpp_op(op), stream=stream, accum_dtype=accum_dtype, sym=self.symmetric_memory
+        )
+
+    def all_gather(self, output_tensor, input_tensor, stream=None):
+        sz = _round_pow2(input_tensor.nbytes)
+        if sz not in self._tune_cache["allgather"]:
+            self._tune_size("allgather", sz)
+        a, nb, nt = self._tune_cache["allgather"][sz]
+        self._exec_ag(input_tensor, output_tensor, a, nb, nt, stream=stream, sym=self.symmetric_memory)

     def barrier(self):
-        tensor = torch.empty(self.world_size, dtype=torch.float, device=torch.device("cuda"))
-        self.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM, stream=torch.cuda.current_stream())
-
-    def benchmark(self, n_warmup=10, n_graph_launches=10, n_iter_per_graph=100):
-        low = 5 * 1024
-        high = 80 * 1024 * 1024
-        sizes = []
-        curr = low
-        while curr <= high:
-            sizes.append(curr)
-            curr *= 2
-
-        if self.rank == 0:
-            print(f"{'Size (Bytes)':<20} {'Time (us)':<20} {'AlgoBW (GB/s)':<20}")
-
-        dtype = torch.float16
-        capture_stream = torch.cuda.Stream()
-
-        # Allocate a single large RawGpuBuffer (symmetric memory) and reuse it for all sizes.
-        # Cannot allocate per-size tensors with symmetric memory.
-        bench_buf = mscclpp.RawGpuBuffer(1 << 27).to_dlpack(data_type=str(dtype))
-        bench_buf = torch.utils.dlpack.from_dlpack(bench_buf)
-        bench_buf.normal_()
-
-        for size in sizes:
-            n_elements = size // bench_buf.element_size()
-            tensor = bench_buf[:n_elements]
-
-            capture_stream.wait_stream(torch.cuda.current_stream())
-            # Capture Graph
-            g = torch.cuda.CUDAGraph()
-            with torch.cuda.graph(g, stream=capture_stream):
-                for _ in range(n_iter_per_graph):
-                    self.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM)
-
-            # warmup: Execute the graph once to prime the driver
-            with torch.cuda.stream(capture_stream):
-                for _ in range(n_warmup):
-                    g.replay()
-            self.barrier()
-            capture_stream.synchronize()
-
-            # Benchmark
-            start_event = torch.cuda.Event(enable_timing=True)
-            end_event = torch.cuda.Event(enable_timing=True)
-
-            start_event.record(capture_stream)
-            with torch.cuda.stream(capture_stream):
-                for _ in range(n_graph_launches):
-                    g.replay()
-            end_event.record(capture_stream)
-            end_event.synchronize()
-
-            # Get elapsed time in milliseconds
-            elapsed_ms = start_event.elapsed_time(end_event)
-            avg_time_ms = elapsed_ms / (n_graph_launches * n_iter_per_graph)
-            time_us = avg_time_ms * 1000
-
-            alg_bw = size / (avg_time_ms * 1e-3) if avg_time_ms > 0 else 0
-            if self.rank == 0:
-                print(f"{size:<20} {time_us:<20.2f} {alg_bw / 1e9:<20.2f}")
+        self._barrier_internal()

     def destroy(self):
-        self._algorithm_nvls_nonzero_copy = None
-        self._algorithm_nvls_packet = None
-        self.scratch_buffer = None
-        self.comm = None
+        self._algos.clear()
+        self._tune_cache = {"allreduce": {}, "allgather": {}}
+        self._tune_buf = self._time_buf = self._barrier_tensor = self._scratch = self.comm = None


-def init_dist() -> CustomizedComm:
-    rank = int(os.environ["RANK"])
-    world = int(os.environ["WORLD_SIZE"])
-    master_addr = os.environ["MSCCLPP_MASTER_ADDR"]
-    master_port = os.environ["MSCCLPP_MASTER_PORT"]
-    interface = interfaces_for_ip_netifaces(master_addr)
-    if interface is None:
-        raise ValueError(f"Cannot find network interface for IP address {master_addr}")
-    interfaceIpPortTrio = f"{interface}:{master_addr}:{master_port}"
-    mscclpp_group = mscclpp.CommGroup(interfaceIpPortTrio=interfaceIpPortTrio, rank=rank, size=world)
-    return CustomizedComm(mscclpp_group)
+# -- Benchmarks (standalone) --------------------------------------------------
+
+
+def _bench_sizes(low=5 * 1024, high=80 << 20):
+    sizes, c = [], low
+    while c <= high:
+        sizes.append(c)
+        c *= 2
+    return sizes
+
+
+def benchmark_allreduce(
+    comm: CustomizedComm, dtype=torch.float16, accum_dtype=None, n_warmup=10, n_graph_launches=10, n_iter=100
+):
+    sizes = _bench_sizes()
+    if comm.rank == 0:
+        print(f"\n{'='*60}\nAllreduce Benchmark\n{'='*60}")
+        print(f"{'Nelements':<18} {'Size(B)':<18} {'Time(us)':<18} {'AlgoBW(GB/s)':<18}")
+
+    cs = torch.cuda.Stream()
+    buf = _make_tensor(1 << 27, dtype)
+    if dtype in (torch.float16, torch.float32, torch.bfloat16):
+        buf.normal_()
+    else:
+        buf.fill_(0)
+
+    for size in sizes:
+        nelems = size // buf.element_size()
+        t = buf[:nelems]
+        comm.all_reduce(t, accum_dtype=accum_dtype)
+        torch.cuda.synchronize()
+
+        cs.wait_stream(torch.cuda.current_stream())
+        g = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(g, stream=cs):
+            for _ in range(n_iter):
+                comm.all_reduce(t, accum_dtype=accum_dtype)
+        with torch.cuda.stream(cs):
+            for _ in range(n_warmup):
+                g.replay()
+        comm.barrier()
+        cs.synchronize()
+
+        s, e = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
+        s.record(cs)
+        with torch.cuda.stream(cs):
+            for _ in range(n_graph_launches):
+                g.replay()
+        e.record(cs)
+        e.synchronize()
+
+        ms = s.elapsed_time(e) / (n_graph_launches * n_iter)
+        if comm.rank == 0:
+            print(f"{nelems:<18} {size:<18} {ms*1000:<18.2f} {size/(ms*1e-3)/1e9:<18.2f}")
+
+
+def benchmark_allgather(comm: CustomizedComm, dtype=torch.float16, n_warmup=10, n_graph_launches=10, n_iter=100):
+    sizes = _bench_sizes()
+    if comm.rank == 0:
+        print(f"\n{'='*60}\nAllgather Benchmark\n{'='*60}")
+        print(f"{'PerRank(B)':<18} {'Total(B)':<18} {'Time(us)':<18} {'AlgoBW(GB/s)':<18}")
+
+    cs = torch.cuda.Stream()
+    buf = _make_tensor(1 << 27, dtype)
+    if dtype in (torch.float16, torch.float32, torch.bfloat16):
+        buf.normal_()
+    else:
+        buf.fill_(0)
+
+    for prs in sizes:
+        total = prs * comm.world_size
+        if total > buf.nbytes:
+            break
+        nt = total // buf.element_size()
+        npr = prs // buf.element_size()
+        out = buf[:nt]
+        inp = out[comm.rank * npr : (comm.rank + 1) * npr]
+
+        comm.all_gather(out, inp)
+        torch.cuda.synchronize()
+
+        cs.wait_stream(torch.cuda.current_stream())
+        g = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(g, stream=cs):
+            for _ in range(n_iter):
+                comm.all_gather(out, inp)
+        with torch.cuda.stream(cs):
+            for _ in range(n_warmup):
+                g.replay()
+        comm.barrier()
+        cs.synchronize()
+
+        s, e = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
+        s.record(cs)
+        with torch.cuda.stream(cs):
+            for _ in range(n_graph_launches):
+                g.replay()
+        e.record(cs)
+        e.synchronize()
+
+        ms = s.elapsed_time(e) / (n_graph_launches * n_iter)
+        if comm.rank == 0:
+            print(f"{prs:<18} {total:<18} {ms*1000:<18.2f} {total/(ms*1e-3)/1e9:<18.2f}")
+
+
+# -- Bootstrap & main ----------------------------------------------------------
+
+
+def init_dist() -> mscclpp.CommGroup:
+    addr = os.environ.get("MSCCLPP_MASTER_ADDR")
+    if addr:
+        rank, world = int(os.environ["RANK"]), int(os.environ["WORLD_SIZE"])
+        port = os.environ["MSCCLPP_MASTER_PORT"]
+        iface = _interfaces_for_ip(addr)
+        if not iface:
+            raise ValueError(f"No interface for {addr}")
+        return mscclpp.CommGroup(interfaceIpPortTrio=f"{iface}:{addr}:{port}", rank=rank, size=world)
+    import torch.distributed as dist
+
+    dist.init_process_group(backend="gloo")
+    return mscclpp.CommGroup(torch_group=dist.group.WORLD)


 def main():
     local = int(os.environ["LOCAL_RANK"])
     torch.cuda.set_device(local)
-    comm = init_dist()
-    comm.benchmark(n_warmup=5, n_graph_launches=10, n_iter_per_graph=100)
-    comm.barrier()
+
+    dtype_str = os.environ.get("DTYPE", "float16")
+    dtype = getattr(torch, dtype_str, torch.float16)
+    accum_map = {"float32": mscclpp.DataType.float32, "float16": mscclpp.DataType.float16}
+    accum_str = os.environ.get("ACCUM_DTYPE")
+    accum_dtype = accum_map.get(accum_str) if accum_str else None
+
+    comm_group = init_dist()
+    cc = CustomizedComm(comm_group)
+
+    print(f"rank {local} starting benchmarks with dtype={dtype} accum_dtype={accum_dtype}...")
+    benchmark_allreduce(cc, dtype=dtype, accum_dtype=accum_dtype)
+    cc.barrier()
     torch.cuda.synchronize()
-    comm.destroy()
-    print(f"rank {local} All-reduce operation completed successfully.")
+
+    benchmark_allgather(cc, dtype=dtype)
+    cc.barrier()
+    torch.cuda.synchronize()
+
+    cc.destroy()
+    print(f"rank {local} completed successfully.")


 if __name__ == "__main__":
     main()
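For readers skimming the rewritten example, here is a minimal driver showing how the pieces fit together. This is an illustrative sketch, not part of the patch: it assumes a `torchrun` launch (so `LOCAL_RANK` is set) and reuses `init_dist`, `CustomizedComm`, and `_make_tensor` exactly as defined in the file above.

```python
import os

import torch

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
comm = CustomizedComm(init_dist())

x = _make_tensor(1 << 20, torch.float16)  # 1 MiB symmetric-memory tensor
x.normal_()
comm.all_reduce(x)  # first call for this rounded size triggers one-off tuning
comm.all_reduce(x)  # later calls reuse the cached (algo, nblocks, nthreads)
torch.cuda.synchronize()
comm.destroy()
```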
diff --git a/include/mscclpp/gpu_data_types.hpp b/include/mscclpp/gpu_data_types.hpp
index fa31a28f..41bd5928 100644
--- a/include/mscclpp/gpu_data_types.hpp
+++ b/include/mscclpp/gpu_data_types.hpp
@@ -1072,6 +1072,15 @@ MSCCLPP_DEVICE_INLINE f16x2 to(const f8_e4m3b15x2& v) {
   __half2 h;
   asm("mov.b32 %0, %1;" : "=r"(*reinterpret_cast<uint32_t*>(&h)) : "r"(out0));
   return h;
+#elif defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
+  // gfx942: same bit manipulation as CUDA, store packed fp16 bits via words[].
+  uint16_t in = v.storage.__x;
+  uint32_t a0 = ((uint32_t)(in & 0xFFu) << 8) | ((uint32_t)(in >> 8) << 24);
+  uint32_t b0 = (a0 & 0x7f007f00u) >> 1;
+  uint32_t out0 = b0 | (a0 & 0x80008000u);
+  f16x2 result;
+  result.words[0] = out0;
+  return result;
 #else
   f16x2 result;
   result.data[0] = __float2half(float(v.data[0]));
@@ -1100,6 +1109,17 @@ MSCCLPP_DEVICE_INLINE f16x4 to(const f8_e4m3b15x4& v) {
   asm("mov.b32 %0, %1;" : "=r"(result.words[0]) : "r"(out0));
   asm("mov.b32 %0, %1;" : "=r"(result.words[1]) : "r"(out1));
   return result;
+#elif defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
+  // gfx942: __byte_perm + bitwise E4→E5 shift (no lop3), store via words[].
+  uint32_t in = v.storage.__x;
+  uint32_t a0 = __byte_perm(0u, in, 0x5746u);
+  uint32_t out0 = ((a0 >> 1) & 0x3f803f80u) | (a0 & 0x80008000u);
+  uint32_t a1 = __byte_perm(a0, 0u, 0x2301u);
+  uint32_t out1 = ((a1 >> 1) & 0x3f803f80u) | (a1 & 0x80008000u);
+  f16x4 result;
+  result.words[0] = out0;
+  result.words[1] = out1;
+  return result;
 #else
   f16x4 result;
 #pragma unroll
@@ -1127,6 +1147,16 @@ MSCCLPP_DEVICE_INLINE f8_e4m3b15x2 to(const f16x2& v) {
   uint32_t b0 = a0 | (in0 & 0x80008000u);
   uint16_t packed = (uint16_t)(((b0 >> 8) & 0xFFu) | ((b0 >> 16) & 0xFF00u));
   return bit_cast<f8_e4m3b15x2>(packed);
+#elif defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
+  // gfx942: read packed fp16 bits, clamp via v_pk_min_u16, shift E5→E4, pack.
+  uint32_t in0 = v.words[0];
+  uint32_t abs0 = in0 & 0x7fff7fffu;
+  uint32_t a0;
+  asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a0) : "v"(abs0), "v"(0x3B803B80u));
+  a0 = a0 * 2u + 0x00800080u;
+  uint32_t b0 = a0 | (in0 & 0x80008000u);
+  uint16_t packed = (uint16_t)(((b0 >> 8) & 0xFFu) | ((b0 >> 16) & 0xFF00u));
+  return bit_cast<f8_e4m3b15x2>(packed);
 #else
   f8_e4m3b15x2 result;
   result.data[0] = __fp8_e4m3b15(__half2float(v.data[0]));
@@ -1154,6 +1184,19 @@ MSCCLPP_DEVICE_INLINE f8_e4m3b15x4 to(const f16x4& v) {
   asm("lop3.b32 %0, %1, %2, %3, 0xf8;" : "=r"(b1) : "r"(a1), "r"(in1), "r"(0x80008000u));
   uint32_t packed = __byte_perm(b0, b1, 0x7531u);
   return bit_cast<f8_e4m3b15x4>(packed);
+#elif defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
+  // gfx942: read packed fp16 bits, clamp via v_pk_min_u16, shift E5→E4, __byte_perm pack.
+  uint32_t in0 = v.words[0], in1 = v.words[1];
+  uint32_t abs0 = in0 & 0x7fff7fffu, abs1 = in1 & 0x7fff7fffu;
+  uint32_t a0, a1;
+  asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a0) : "v"(abs0), "v"(0x3B803B80u));
+  asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a1) : "v"(abs1), "v"(0x3B803B80u));
+  a0 = a0 * 2u + 0x00800080u;
+  a1 = a1 * 2u + 0x00800080u;
+  uint32_t b0 = a0 | (in0 & 0x80008000u);
+  uint32_t b1 = a1 | (in1 & 0x80008000u);
+  uint32_t packed = __byte_perm(b0, b1, 0x7531u);
+  return bit_cast<f8_e4m3b15x4>(packed);
 #else
   f8_e4m3b15x4 result;
 #pragma unroll
@@ -1164,8 +1207,7 @@ MSCCLPP_DEVICE_INLINE f8_e4m3b15x4 to(const f16x4& v) {
 #endif
 }

-// --- fp8_e4m3b15 <-> f32 conversion specializations ---
-// Derived from fp16 conversions: fp8→f32 = fp8→fp16→f32, f32→fp8 = f32→fp16→fp8.
+// --- fp8_e4m3b15 <-> f32 conversion specializations (software, always available) ---

 /// f8_e4m3b15x2 -> f32x2.
 /// Routes through fp16: fp8→fp16 (bit manip) then fp16→f32.
@@ -1175,6 +1217,12 @@ MSCCLPP_DEVICE_INLINE f32x2 to(const f8_e4m3b15x2& v) {
   f16x2 h = to(v);
   float2 f2 = __half22float2(h);
   return bit_cast<f32x2>(f2);
+#elif defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
+  f16x2 h = to(v);
+  f32x2 result;
+  result.data[0] = __half2float(h.data[0]);
+  result.data[1] = __half2float(h.data[1]);
+  return result;
 #else
   f32x2 result;
   result.data[0] = float(v.data[0]);
@@ -1200,6 +1248,14 @@ MSCCLPP_DEVICE_INLINE f32x4 to(const f8_e4m3b15x4& v) {
   result.data[2] = f1.x;
   result.data[3] = f1.y;
   return result;
+#elif defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
+  f16x4 h = to(v);
+  f32x4 result;
+  result.data[0] = __half2float(h.data[0]);
+  result.data[1] = __half2float(h.data[1]);
+  result.data[2] = __half2float(h.data[2]);
+  result.data[3] = __half2float(h.data[3]);
+  return result;
 #else
   f32x4 result;
 #pragma unroll
@@ -1218,6 +1274,11 @@ MSCCLPP_DEVICE_INLINE f8_e4m3b15x2 to(const f32x2& v) {
   float2 f2 = {v.data[0], v.data[1]};
   __half2 h = __float22half2_rn(f2);
   return to(h);
+#elif defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
+  f16x2 h;
+  h.data[0] = __float2half_rn(v.data[0]);
+  h.data[1] = __float2half_rn(v.data[1]);
+  return to(h);
 #else
   f8_e4m3b15x2 result;
   result.data[0] = __fp8_e4m3b15(v.data[0]);
@@ -1239,6 +1300,11 @@ MSCCLPP_DEVICE_INLINE f8_e4m3b15x4 to(const f32x4& v) {
   asm("mov.b32 %0, %1;" : "=r"(h.words[0]) : "r"(*reinterpret_cast<uint32_t*>(&h01)));
   asm("mov.b32 %0, %1;" : "=r"(h.words[1]) : "r"(*reinterpret_cast<uint32_t*>(&h23)));
   return to(h);
+#elif defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
+  f16x4 h;
+  h.words[0] = __builtin_bit_cast(uint32_t, __builtin_amdgcn_cvt_pkrtz(v.data[0], v.data[1]));
+  h.words[1] = __builtin_bit_cast(uint32_t, __builtin_amdgcn_cvt_pkrtz(v.data[2], v.data[3]));
+  return to(h);
 #else
   f8_e4m3b15x4 result;
 #pragma unroll
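The gfx942 branches above mirror the CUDA bit manipulation: because fp8 e4m3b15 and fp16 both use an exponent bias of 15 (as the masks in the code imply), widening a normal value is a pure shift of the exponent/mantissa field. A scalar Python model of that shift, illustrative only; `e4m3b15_to_f16_bits` and `f16_bits_to_float` are names invented here:

```python
import struct

def e4m3b15_to_f16_bits(b: int) -> int:
    # Place exp+mant at bits 14..8 (as the packed kernel does per 16-bit lane),
    # shift right once so they land in fp16's exponent/mantissa fields, then
    # restore the sign. 0x3F80 is the scalar half of the 0x3f803f80u mask above.
    body = (b & 0x7F) << 8
    return ((body >> 1) & 0x3F80) | ((b & 0x80) << 8)

def f16_bits_to_float(h: int) -> float:
    return struct.unpack("<e", struct.pack("<H", h))[0]

# Exponent field 15 equals the bias, so 0x78 encodes 1.0; field 8 encodes 2**-7.
assert f16_bits_to_float(e4m3b15_to_f16_bits(0x78)) == 1.0
assert f16_bits_to_float(e4m3b15_to_f16_bits(0x40)) == 2.0 ** -7
```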
diff --git a/python/csrc/CMakeLists.txt b/python/csrc/CMakeLists.txt
index 8759201f..44fb150f 100644
--- a/python/csrc/CMakeLists.txt
+++ b/python/csrc/CMakeLists.txt
@@ -24,4 +24,7 @@ set_target_properties(mscclpp_py PROPERTIES OUTPUT_NAME _mscclpp)
 set_target_properties(mscclpp_py PROPERTIES INSTALL_RPATH "\$ORIGIN/lib")
 target_link_libraries(mscclpp_py PRIVATE dlpack mscclpp mscclpp_collectives ${GPU_LIBRARIES})
 target_include_directories(mscclpp_py SYSTEM PRIVATE ${GPU_INCLUDE_DIRS})
+if(MSCCLPP_USE_ROCM)
+  target_compile_definitions(mscclpp_py PRIVATE MSCCLPP_USE_ROCM)
+endif()
 install(TARGETS mscclpp_py LIBRARY DESTINATION .)
diff --git a/python/requirements_rocm6.txt b/python/requirements_rocm6.txt
index d2a3389b..7ed4fef3 100644
--- a/python/requirements_rocm6.txt
+++ b/python/requirements_rocm6.txt
@@ -1,5 +1,5 @@
-mpi4py==4.1.1
-cupy==13.6.0
+mpi4py
+cupy
 prettytable
 netifaces
 pytest
diff --git a/python/test/test_fp8_accum.py b/python/test/test_fp8_accum.py
index 3a6c67f1..82981ce1 100644
--- a/python/test/test_fp8_accum.py
+++ b/python/test/test_fp8_accum.py
@@ -21,9 +21,8 @@ from .mscclpp_mpi import MpiGroup, parametrize_mpi_groups, mpi_group
 # FP8 E4M3 (hardware) requires SM >= 89 (Ada / Hopper) on NVIDIA GPUs.
 # On AMD/ROCm (e.g. MI300X), FP8 is supported natively — no skip needed.
 _is_hip = hasattr(cp.cuda.runtime, "is_hip") and cp.cuda.runtime.is_hip
-# TODO(binyli): Skip hip for now, will fix it in the next PR
-_skip_fp8 = _is_hip or int(cp.cuda.Device().compute_capability) < 89
-pytestmark = pytest.mark.skipif(_skip_fp8, reason="FP8 accum tests require SM >= 89 on CUDA (HIP not yet supported)")
+_skip_fp8 = not _is_hip and int(cp.cuda.Device().compute_capability) < 89
+pytestmark = pytest.mark.skipif(_skip_fp8, reason="FP8 accum tests require SM >= 89 on CUDA")

 # ---------------------------------------------------------------------------
 # FP8 E4M3FN helpers (bias=7, no infinity, NaN = exp=15 & mant=7)
@@ -208,6 +207,7 @@ def run_allreduce(algo, comm_group, buffer, dtype, accum_dtype=None, nblocks=0,
         "default_allreduce_nvls_packet",
         "default_allreduce_fullmesh",
         "default_allreduce_rsag_zero_copy",
+        "default_allreduce_allpair_packet",
     ],
 )
 @pytest.mark.parametrize("size", [1024, 4096, 16384, 65536, 262144, 1048576])
@@ -220,6 +220,8 @@ def test_fp8_e4m3_accum(mpi_group: MpiGroup, algo_name: str, size: int):
     comm_group, algo_map, scratch = setup_algorithms(mpi_group)
     if algo_name not in algo_map:
         pytest.skip(f"{algo_name} not available")
+    if "nvls" in algo_name and not is_nvls_supported():
+        pytest.skip(f"{algo_name} requires NVLS which is not supported on this platform")
     algo = algo_map[algo_name]

     buf = GpuBuffer(size, dtype=cp.uint8)
@@ -243,9 +245,9 @@ def test_fp8_e4m3_accum(mpi_group: MpiGroup, algo_name: str, size: int):
     errors = {}
     for accum_label, accum_dtype in accum_configs:
-        # Generate deterministic per-rank data
-        cp.random.seed(42 + rank)
-        src_f32 = cp.random.randn(size).astype(cp.float32)
+        # Generate deterministic per-rank data (use numpy to avoid hipRAND issues on ROCm)
+        rng = np.random.RandomState(42 + rank)
+        src_f32 = cp.asarray(rng.randn(size).astype(np.float32))
         src_f32 = cp.clip(src_f32, -240.0, 240.0)
         src_fp8 = float_to_e4m3fn(src_f32)
@@ -268,8 +270,8 @@ def test_fp8_e4m3_accum(mpi_group: MpiGroup, algo_name: str, size: int):
     # Compute float32 reference: sum all ranks' quantized FP8 inputs in float32
     ref_f32 = cp.zeros(size, dtype=cp.float32)
     for r in range(world_size):
-        cp.random.seed(42 + r)
-        rank_data = cp.random.randn(size).astype(cp.float32)
+        rng_r = np.random.RandomState(42 + r)
+        rank_data = cp.asarray(rng_r.randn(size).astype(np.float32))
         rank_data = cp.clip(rank_data, -240.0, 240.0)
         rank_data_fp8 = float_to_e4m3fn(rank_data)
         ref_f32 += e4m3fn_to_float(rank_data_fp8)
@@ -303,6 +305,8 @@ def test_fp8_e4m3_accum(mpi_group: MpiGroup, algo_name: str, size: int):
         "default_allreduce_packet",
         "default_allreduce_nvls_packet",
         "default_allreduce_rsag_zero_copy",
+        "default_allreduce_fullmesh",
+        "default_allreduce_allpair_packet",
     ],
 )
 @pytest.mark.parametrize("size", [1024, 4096, 65536])
@@ -315,6 +319,8 @@ def test_fp8_e4m3b15_accum(mpi_group: MpiGroup, algo_name: str, size: int):
     comm_group, algo_map, scratch = setup_algorithms(mpi_group)
     if algo_name not in algo_map:
         pytest.skip(f"{algo_name} not available")
+    if "nvls" in algo_name and not is_nvls_supported():
+        pytest.skip(f"{algo_name} requires NVLS which is not supported on this platform")
     algo = algo_map[algo_name]

     buf = GpuBuffer(size, dtype=cp.uint8)
@@ -336,9 +342,9 @@ def test_fp8_e4m3b15_accum(mpi_group: MpiGroup, algo_name: str, size: int):
     errors = {}
     for accum_label, accum_dtype in accum_configs:
         # Generate deterministic per-rank random uint8 values in valid e4m3b15 range
-        cp.random.seed(42 + rank)
-        raw = cp.random.randint(0, 0x78, (size,), dtype=cp.uint8)
-        signs = cp.random.randint(0, 2, (size,), dtype=cp.uint8).astype(cp.uint8) << 7
+        rng = np.random.RandomState(42 + rank)
+        raw = cp.asarray(rng.randint(0, 0x78, (size,)).astype(np.uint8))
+        signs = cp.asarray(rng.randint(0, 2, (size,)).astype(np.uint8)) << 7
         src_uint8 = raw | signs
         # Fix negative zero -> positive zero
         src_uint8 = cp.where(src_uint8 == 0x80, cp.uint8(0), src_uint8)
@@ -364,9 +370,9 @@ def test_fp8_e4m3b15_accum(mpi_group: MpiGroup, algo_name: str, size: int):
     # Compute float32 reference
     ref_f32 = cp.zeros(size, dtype=cp.float32)
     for r in range(world_size):
-        cp.random.seed(42 + r)
-        raw_r = cp.random.randint(0, 0x78, (size,), dtype=cp.uint8)
-        signs_r = cp.random.randint(0, 2, (size,), dtype=cp.uint8).astype(cp.uint8) << 7
+        rng_r = np.random.RandomState(42 + r)
+        raw_r = cp.asarray(rng_r.randint(0, 0x78, (size,)).astype(np.uint8))
+        signs_r = cp.asarray(rng_r.randint(0, 2, (size,)).astype(np.uint8)) << 7
         bits_r = raw_r | signs_r
         bits_r = cp.where(bits_r == 0x80, cp.uint8(0), bits_r)
         ref_f32 += e4m3b15_to_float(bits_r)
diff --git a/src/ext/collectives/allreduce/allreduce_allpair_packet.cu b/src/ext/collectives/allreduce/allreduce_allpair_packet.cu
index 6cbc8977..17bcfc33 100644
--- a/src/ext/collectives/allreduce/allreduce_allpair_packet.cu
+++ b/src/ext/collectives/allreduce/allreduce_allpair_packet.cu
@@ -2,6 +2,7 @@
 // Licensed under the MIT license.

 #include
+#include <type_traits>

 #include "allreduce/allreduce_allpair_packet.hpp"
 #include "allreduce/common.hpp"
@@ -11,7 +12,7 @@ namespace mscclpp {
 namespace collective {

-template
+template
 __global__ void allreduceAllPairs(T* buff, T* scratch, T* resultBuff, DeviceHandle* memoryChannels,
                                   size_t channelDataOffset, size_t scratchBufferSize, int rank, int nRanksPerNode,
                                   int worldSize, size_t nelems, uint32_t numScratchBuff, void* flags,
@@ -43,13 +44,16 @@ __global__ void allreduceAllPairs(T* buff, T* scratch, T* resultBuff, DeviceHand
   // step 2: Reduce Data
   for (size_t idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nelems; idx += blockDim.x * gridDim.x) {
     uint32_t data = src[idx];
+    using AccRaw = std::conditional_t, uint32_t,
+                   mscclpp::VectorType>;
+    AccRaw acc = mscclpp::upcastVector(data);
     for (int index = 0; index < nPeers; index++) {
       const int remoteRank = index < rank ? index : index + 1;
       LL8Packet* dstPkt = (LL8Packet*)scratchBuff + remoteRank * nelems;
       uint32_t val = dstPkt[idx].read(flag, -1);
-      data = calVector(val, data);
+      acc = mscclpp::calVectorAccum(acc, val);
     }
-    dst[idx] = data;
+    dst[idx] = mscclpp::downcastVector(acc);
   }
   __syncthreads();
   if (threadIdx.x == 0) {
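The hunk above is the heart of the kernel change: instead of folding each peer's packet directly into a `uint32_t` of packed `T` lanes, the loop now accumulates in the wider `AccRaw` type and converts back once at the end. A scalar Python stand-in for that pattern, illustrative only; the real kernel operates on packed lanes via `upcastVector`/`calVectorAccum`/`downcastVector`, and `allpair_reduce_scalar` is a name invented here:

```python
def allpair_reduce_scalar(own, peer_vals, upcast=float, downcast=round):
    # Upcast once, reduce every peer contribution in the wide type, and
    # downcast once at the end: one rounding step instead of one per peer.
    acc = upcast(own)
    for v in peer_vals:
        acc += upcast(v)
    return downcast(acc)
```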
@@ -76,7 +80,12 @@ struct AllpairAdapter {
                          int nThreadsPerBlock = 0) {
     using ChannelType = DeviceHandle;
     const size_t nelems = inputSize / sizeof(T);
-    allreduceAllPairs<<>>(
+    // Round nBlocks to multiple of nPeers so every block maps to a valid peer.
+    const int nPeers = worldSize - 1;
+    if (nPeers > 0) {
+      nBlocks = (nBlocks / nPeers) * nPeers;
+    }
+    allreduceAllPairs<<>>(
         (T*)buff, (T*)scratch, (T*)resultBuff, (ChannelType*)memoryChannels, channelInOffset, scratchBufferSize, rank,
         nRanksPerNode, worldSize, nelems, numScratchBuff, flags, flagSize);
     return cudaGetLastError();
@@ -101,6 +110,11 @@ CommResult AllreduceAllpairPacket::allreduceKernelFunc(const std::shared_ptr
                                                        workSize);
   }
+  // nBlocks must be at least nPeers for allpair — each block maps to one peer.
+  const int nPeers = algoCtx->nRanksPerNode - 1;
+  if (nPeers > 0 && blockAndThreadNum.first < nPeers) {
+    return CommResult::CommInvalidArgument;
+  }
   size_t sendBytes;
   CUdeviceptr sendBasePtr;
   MSCCLPP_CUTHROW(cuMemGetAddressRange(&sendBasePtr, &sendBytes, (CUdeviceptr)input));
diff --git a/src/ext/collectives/allreduce/allreduce_fullmesh.cu b/src/ext/collectives/allreduce/allreduce_fullmesh.cu
index ee46fd77..24d2a31c 100644
--- a/src/ext/collectives/allreduce/allreduce_fullmesh.cu
+++ b/src/ext/collectives/allreduce/allreduce_fullmesh.cu
@@ -213,6 +213,13 @@ CommResult AllreduceFullmesh::allreduceKernelFunc(
     return CommResult::CommInvalidArgument;
   }
   std::pair<int, int> numBlocksAndThreads = {nBlocks, nThreadsPerBlock};
+  if (numBlocksAndThreads.first > 64) {
+    WARN("AllreduceFullmesh: number of blocks exceeds maximum supported blocks, which is 64");
+    return mscclpp::CommResult::CommInvalidArgument;
+  }
+  if (numBlocksAndThreads.first == 0 || numBlocksAndThreads.second == 0) {
+    numBlocksAndThreads = {35, 512};
+  }
   cudaError_t error = allreduce(input, this->scratchBuffer_, output, inputChannelHandles.get(),
                                 ctx->memoryChannelDeviceHandles.get(), nullptr, nullptr, 0, channelOutOffset, 0,
                                 ctx->rank, ctx->nRanksPerNode, ctx->workSize, inputSize,