diff --git a/.azure-pipelines/templates/deploy.yml b/.azure-pipelines/templates/deploy.yml index fc116acf..2f642f1d 100644 --- a/.azure-pipelines/templates/deploy.yml +++ b/.azure-pipelines/templates/deploy.yml @@ -94,7 +94,27 @@ steps: du -sh build/bin/* 2>/dev/null || true workingDirectory: '$(System.DefaultWorkingDirectory)' -# 2. Download SSH key + install packages + start VMSS +# 2. Write CMake args for pip install on remote VMs +- task: Bash@3 + name: WritePipCmakeArgs + displayName: Write pip CMake args + inputs: + targetType: 'inline' + script: | + set -e + PIP_CMAKE_ARGS="" + if [ -n "${{ parameters.gpuArch }}" ]; then + PIP_CMAKE_ARGS="-DMSCCLPP_GPU_ARCHS=${{ parameters.gpuArch }}" + fi + CMAKE_EXTRA_ARGS='${{ parameters.cmakeArgs }}' + if [ -n "${CMAKE_EXTRA_ARGS}" ]; then + PIP_CMAKE_ARGS="${PIP_CMAKE_ARGS} ${CMAKE_EXTRA_ARGS}" + fi + echo "${PIP_CMAKE_ARGS}" > pip_cmake_args.txt + echo "pip CMake args: $(cat pip_cmake_args.txt)" + workingDirectory: '$(System.DefaultWorkingDirectory)' + +# 3. Download SSH key + install packages + start VMSS - task: DownloadSecureFile@1 name: SshKeyFile displayName: Download key file @@ -120,7 +140,7 @@ steps: inlineScript: | az vmss start --name ${{ parameters.vmssName }} --resource-group ${{ parameters.resourceGroup }} -# 3. Deploy test environment +# 4. Deploy test environment - task: Bash@3 name: DeployTestEnv displayName: Deploy Test Env diff --git a/.azure-pipelines/templates/ut-npkit.yml b/.azure-pipelines/templates/ut-npkit.yml index e53b5cf5..1bd89caf 100644 --- a/.azure-pipelines/templates/ut-npkit.yml +++ b/.azure-pipelines/templates/ut-npkit.yml @@ -28,7 +28,7 @@ steps: grep -q NPKIT_EVENT_EXECUTOR_INIT_ENTRY ./npkit_output/npkit_event_trace.json grep -q NPKIT_EVENT_EXECUTOR_SIGNAL_ENTRY ./npkit_output/npkit_event_trace.json grep -q NPKIT_EVENT_EXECUTOR_WAIT_ENTRY ./npkit_output/npkit_event_trace.json - grep -q NPKIT_EVENT_EXECUTOR_READ_REDUCE_COPY_SEND_ENTRY ./npkit_output/npkit_event_trace.json + grep -q NPKIT_EVENT_EXECUTOR_READ_REDUCE_SEND_ENTRY ./npkit_output/npkit_event_trace.json - template: run-remote-task.yml parameters: @@ -42,14 +42,14 @@ steps: grep -q NPKIT_EVENT_EXECUTOR_INIT_ENTRY ./npkit_output/npkit_event_trace.json grep -q NPKIT_EVENT_EXECUTOR_SIGNAL_ENTRY ./npkit_output/npkit_event_trace.json grep -q NPKIT_EVENT_EXECUTOR_WAIT_ENTRY ./npkit_output/npkit_event_trace.json - grep -q NPKIT_EVENT_EXECUTOR_READ_REDUCE_COPY_SEND_ENTRY ./npkit_output/npkit_event_trace.json + grep -q NPKIT_EVENT_EXECUTOR_READ_REDUCE_SEND_ENTRY ./npkit_output/npkit_event_trace.json rm -rf ./npkit_dump && mkdir ./npkit_dump && rm -rf ./npkit_output && mkdir ./npkit_output mpirun --allow-run-as-root -tag-output -x MSCCLPP_HOME=/root/mscclpp -np 8 python3 -m pytest ./python/test/test_mscclpp.py -x -k 'test_executor[allreduce_packet.json' python3 ./tools/npkit/npkit_trace_generator.py --npkit_dump_dir=./npkit_dump --npkit_event_header_path=./include/mscclpp/npkit/npkit_event.hpp --output_dir=./npkit_output grep -q NPKIT_EVENT_EXECUTOR_INIT_ENTRY ./npkit_output/npkit_event_trace.json - grep -q NPKIT_EVENT_EXECUTOR_COPY_PACKET_ENTRY ./npkit_output/npkit_event_trace.json - grep -q NPKIT_EVENT_EXECUTOR_PUT_PACKET_ENTRY ./npkit_output/npkit_event_trace.json - grep -q NPKIT_EVENT_EXECUTOR_REDUCE_SEND_PACKET_ENTRY ./npkit_output/npkit_event_trace.json + grep -q NPKIT_EVENT_EXECUTOR_PUT_PACKETS_ENTRY ./npkit_output/npkit_event_trace.json + grep -q NPKIT_EVENT_EXECUTOR_REDUCE_SEND_PACKETS_ENTRY ./npkit_output/npkit_event_trace.json + grep -q 
NPKIT_EVENT_EXECUTOR_UNPACK_PACKETS_ENTRY ./npkit_output/npkit_event_trace.json - template: stop.yml parameters: diff --git a/.azure-pipelines/templates/ut.yml b/.azure-pipelines/templates/ut.yml index 9d17e923..743c66e6 100644 --- a/.azure-pipelines/templates/ut.yml +++ b/.azure-pipelines/templates/ut.yml @@ -41,6 +41,7 @@ steps: displayName: Run pytests remoteScript: | mpirun --allow-run-as-root -tag-output -x MSCCLPP_HOME=/root/mscclpp -x GPU_MAX_HW_QUEUES=8 -np 8 python3 -m pytest ./python/test/test_mscclpp.py -x + mpirun --allow-run-as-root -tag-output -x MSCCLPP_HOME=/root/mscclpp -x GPU_MAX_HW_QUEUES=8 -np 8 python3 -m pytest ./python/test/test_fp8_accum.py -x - template: stop.yml parameters: diff --git a/VERSION b/VERSION index a3df0a69..ac39a106 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.8.0 +0.9.0 diff --git a/cmake/FindGDRCopy.cmake b/cmake/FindGDRCopy.cmake index c1f786ae..54e0ba1c 100644 --- a/cmake/FindGDRCopy.cmake +++ b/cmake/FindGDRCopy.cmake @@ -30,19 +30,15 @@ find_library(GDRCOPY_LIBRARIES ${GDRCOPY_ROOT_DIR}/lib /usr/local/lib /usr/lib - /usr/lib/x86_64-linux-gnu - /usr/lib/aarch64-linux-gnu) + /usr/lib/x86_64-linux-gnu) if(GDRCOPY_INCLUDE_DIRS) - include(CheckCXXSourceCompiles) + include(CheckSymbolExists) set(CMAKE_REQUIRED_INCLUDES ${GDRCOPY_INCLUDE_DIRS}) set(CMAKE_REQUIRED_LIBRARIES ${GDRCOPY_LIBRARIES}) - check_cxx_source_compiles(" - #include - int main() { gdr_pin_buffer_v2(0, 0, 0, 0, 0); return 0; } - " GDRCOPY_HAS_PIN_BUFFER_V2) - unset(CMAKE_REQUIRED_INCLUDES) + check_symbol_exists(gdr_pin_buffer_v2 "gdrapi.h" GDRCOPY_HAS_PIN_BUFFER_V2) unset(CMAKE_REQUIRED_LIBRARIES) + unset(CMAKE_REQUIRED_INCLUDES) if(NOT GDRCOPY_HAS_PIN_BUFFER_V2) message(STATUS "GDRCopy found but too old (gdr_pin_buffer_v2 not available). Requires >= 2.5.") set(GDRCOPY_INCLUDE_DIRS GDRCOPY_INCLUDE_DIRS-NOTFOUND) diff --git a/docker/base-dev-x.dockerfile b/docker/base-dev-x.dockerfile index e0e2a043..47436202 100644 --- a/docker/base-dev-x.dockerfile +++ b/docker/base-dev-x.dockerfile @@ -52,7 +52,7 @@ RUN OS_ARCH=$(uname -m) && \ # Install GDRCopy userspace library for CUDA targets ARG TARGET="cuda13.0" RUN if echo "$TARGET" | grep -q "^cuda"; then \ - GDRCOPY_VERSION="2.5.1" && \ + GDRCOPY_VERSION="2.5.2" && \ apt-get update -y && \ apt-get install -y --no-install-recommends devscripts debhelper fakeroot pkg-config dkms && \ cd /tmp && \ diff --git a/docs/Makefile b/docs/Makefile index 5bc7422e..bf82c03a 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -5,7 +5,7 @@ # from the environment for the first two. SPHINXOPTS ?= SPHINXBUILD ?= sphinx-build -SPHINXMULTIVERSION ?= sphinx-multiversion +SPHINXMULTIVERSION ?= python3 build_multiversion.py SOURCEDIR = . BUILDDIR = _build diff --git a/docs/build_multiversion.py b/docs/build_multiversion.py new file mode 100644 index 00000000..ace20fc0 --- /dev/null +++ b/docs/build_multiversion.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Wrapper around sphinx-multiversion that patches copy_tree to generate +_version.py in each tag checkout. This is needed because setuptools_scm +generates _version.py at build time, but sphinx-multiversion uses +`git archive` which only contains committed files. + +Usage (called by Makefile): + python3 build_multiversion.py [sphinx-opts...] 
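+
+All command-line arguments are forwarded unchanged to sphinx_multiversion.main().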
+""" + +import os +import re +import subprocess +import sys + +import sphinx_multiversion.git as smv_git +from sphinx_multiversion import main as smv_main + +# Save the original copy_tree +_original_copy_tree = smv_git.copy_tree + + +def _patched_copy_tree(gitroot, src, dst, reference, sourcepath="."): + """Call original copy_tree, then generate _version.py from the VERSION file.""" + _original_copy_tree(gitroot, src, dst, reference, sourcepath) + + # Extract version from the tag name (e.g., "v0.9.0" -> "0.9.0") + refname = getattr(reference, "refname", "") or "" + match = re.search(r"v(\d+\.\d+\.\d+)", refname) + if not match: + return + + version = match.group(1) + version_py_dir = os.path.join(dst, "python", "mscclpp") + if os.path.isdir(version_py_dir): + version_py = os.path.join(version_py_dir, "_version.py") + if not os.path.exists(version_py): + with open(version_py, "w") as f: + f.write(f'__version__ = "{version}"\n') + + +# Monkey-patch +smv_git.copy_tree = _patched_copy_tree + +if __name__ == "__main__": + sys.exit(smv_main(sys.argv[1:])) diff --git a/docs/guide/mscclpp-torch-integration.md b/docs/guide/mscclpp-torch-integration.md index 1c966155..b4e4fcdf 100644 --- a/docs/guide/mscclpp-torch-integration.md +++ b/docs/guide/mscclpp-torch-integration.md @@ -332,7 +332,8 @@ public: size_t inputSize, size_t outputSize, mscclpp::DataType dtype, mscclpp::ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock, - const std::unordered_map& extras) { + const std::unordered_map& extras, + [[maybe_unused]] mscclpp::DataType accumDtype) { return self->kernelFunc(ctx, input, output, inputSize, dtype, stream); }, // Context initialization function diff --git a/docs/quickstart.md b/docs/quickstart.md index e0a383b7..c9c98128 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -31,7 +31,7 @@ ``` If you don't want to build Python module, you need to set `-DMSCCLPP_BUILD_PYTHON_BINDINGS=OFF` in your `cmake` command (see details in [Install from Source](#install-from-source)). * (Optional, for benchmarks) MPI - * (Optional, for NVIDIA platforms) [GDRCopy](https://github.com/NVIDIA/gdrcopy) >= 2.5.0 + * (Optional, for NVIDIA platforms) [GDRCopy](https://github.com/NVIDIA/gdrcopy) >= 2.5.1 * GDRCopy is required for IB `HostNoAtomic` mode, which uses CPU-side signal forwarding to GPU memory via BAR1 mappings. This mode is used on platforms where RDMA atomics are not available (e.g., when using Data Direct Virtual Functions). * Install GDRCopy from source or via packages. See the [GDRCopy installation guide](https://github.com/NVIDIA/gdrcopy#installation). 
* Others diff --git a/examples/customized-collective-algorithm/customized_allgather.cu b/examples/customized-collective-algorithm/customized_allgather.cu index e78c4777..02df3685 100644 --- a/examples/customized-collective-algorithm/customized_allgather.cu +++ b/examples/customized-collective-algorithm/customized_allgather.cu @@ -101,7 +101,8 @@ class AllgatherAlgoBuilder : public mscclpp::AlgorithmBuilder { "allgather", "allgather", [self](std::shared_ptr comm) { self->initialize(comm); }, [self](const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, size_t outputSize, mscclpp::DataType dtype, [[maybe_unused]] mscclpp::ReduceOp op, cudaStream_t stream, int nBlocks, - int nThreadsPerBlock, const std::unordered_map& extras) { + int nThreadsPerBlock, const std::unordered_map& extras, + [[maybe_unused]] mscclpp::DataType accumDtype) { return self->allgatherKernelFunc(ctx, input, output, inputSize, stream); }, [self](std::shared_ptr comm, const void* input, void* output, size_t inputSize, diff --git a/examples/torch-integration/customized_allgather.cu b/examples/torch-integration/customized_allgather.cu index d48c4410..907b3ada 100644 --- a/examples/torch-integration/customized_allgather.cu +++ b/examples/torch-integration/customized_allgather.cu @@ -69,7 +69,8 @@ class AllgatherAlgoBuilder : public mscclpp::AlgorithmBuilder { "allgather", "allgather", [self](std::shared_ptr comm) { self->initialize(comm); }, [self](const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, size_t outputSize, mscclpp::DataType dtype, [[maybe_unused]] mscclpp::ReduceOp op, cudaStream_t stream, int nBlocks, - int nThreadsPerBlock, const std::unordered_map& extras) { + int nThreadsPerBlock, const std::unordered_map& extras, + [[maybe_unused]] mscclpp::DataType accumDtype) { return self->allgatherKernelFunc(ctx, input, output, inputSize, dtype, stream); }, [self](std::shared_ptr comm, const void* input, void* output, size_t inputSize, diff --git a/examples/torch-integration/customized_comm_with_tuning.py b/examples/torch-integration/customized_comm_with_tuning.py index 41be5825..060a0097 100644 --- a/examples/torch-integration/customized_comm_with_tuning.py +++ b/examples/torch-integration/customized_comm_with_tuning.py @@ -1,193 +1,117 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# MSCCLPP_MASTER_ADDR= MSCCLPP_MASTER_PORT= torchrun --nnodes=1 --nproc_per_node=8 customized_comm_with_tuning.py +# torchrun --nnodes=1 --nproc_per_node=8 examples/torch-integration/customized_comm_with_tuning.py import os -import torch -import mscclpp.utils as mscclpp_utils -import mscclpp -import mscclpp.ext -import netifaces as ni import ipaddress +import netifaces as ni +import torch +import mscclpp +import mscclpp.ext +import mscclpp.utils as mscclpp_utils -def load_algorithms(scratch_buffer: torch.tensor, rank: int) -> mscclpp.AlgorithmCollection: - collection_builder = mscclpp.ext.AlgorithmCollectionBuilder() - return collection_builder.build_default_algorithms( - scratch_buffer=scratch_buffer.data_ptr(), scratch_buffer_size=scratch_buffer.nbytes, rank=rank +# -- Helpers ------------------------------------------------------------------ + + +def _make_tensor(size_bytes: int, dtype: torch.dtype) -> torch.Tensor: + """Allocate a tensor backed by RawGpuBuffer (symmetric memory).""" + # PyTorch's from_dlpack does not support certain float8 DLPack type codes. + # Work around by importing as uint8 and reinterpreting via .view(). 
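+    # Since fp8 elements are 1 byte, the uint8 view has the same length and data_ptr;
+    # .view(dtype) only reinterprets the bytes in place, no copy is made.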
+ _DLPACK_UNSUPPORTED = (torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz) + if dtype in _DLPACK_UNSUPPORTED: + dlpack = mscclpp.RawGpuBuffer(size_bytes).to_dlpack(data_type=str(torch.uint8)) + return torch.utils.dlpack.from_dlpack(dlpack).view(dtype) + dlpack = mscclpp.RawGpuBuffer(size_bytes).to_dlpack(data_type=str(dtype)) + return torch.utils.dlpack.from_dlpack(dlpack) + + +def _load_algorithms(scratch: torch.Tensor, rank: int): + return mscclpp.ext.AlgorithmCollectionBuilder().build_default_algorithms( + scratch_buffer=scratch.data_ptr(), + scratch_buffer_size=scratch.nbytes, + rank=rank, ) -def interfaces_for_ip_netifaces(ip: str): +def _interfaces_for_ip(ip: str): target = ipaddress.ip_address(ip) - for interface in ni.interfaces(): - addresses = ni.ifaddresses(interface) - if ni.AF_INET in addresses: - for link in addresses[ni.AF_INET]: - if "addr" in link: - addr = ipaddress.ip_address(link["addr"]) - if addr == target: - return interface + for iface in ni.interfaces(): + addrs = ni.ifaddresses(iface) + if ni.AF_INET in addrs: + for link in addrs[ni.AF_INET]: + if "addr" in link and ipaddress.ip_address(link["addr"]) == target: + return iface return None -def to_mscclpp_reduce_op(op: torch.distributed.ReduceOp) -> mscclpp.ReduceOp: +def _to_mscclpp_op(op) -> mscclpp.ReduceOp: if op == torch.distributed.ReduceOp.SUM: return mscclpp.ReduceOp.SUM - elif op == torch.distributed.ReduceOp.MIN: + if op == torch.distributed.ReduceOp.MIN: return mscclpp.ReduceOp.MIN - else: - raise ValueError(f"unsupported op: {op}") + raise ValueError(f"unsupported op: {op}") + + +def _round_pow2(size: int) -> int: + """Round up to next power-of-2, clamped to [1024, 256 MB].""" + size = max(size, 1024) + size = min(size, 256 << 20) + return 1 << (size - 1).bit_length() + + +# -- CustomizedComm ----------------------------------------------------------- class CustomizedComm: - def __init__(self, comm: mscclpp.CommGroup): + """Exposes all_reduce, all_gather, barrier with lazy per-size tuning.""" + + _TUNE_N_WARMUP = 5 + _TUNE_N_GRAPH_LAUNCHES = 10 + _TUNE_N_OPS_PER_GRAPH = 100 + _CANDIDATE_NBLOCKS = [4, 8, 16, 24, 32, 48, 64, 128] + _CANDIDATE_NTHREADS = [512, 768, 1024] + _NBLOCKS_LIMIT = { + "default_allreduce_nvls_packet": 16, + "default_allreduce_packet": 56, + "default_allreduce_allpair_packet": 56, + "default_allreduce_fullmesh": 64, + "default_allgather_fullmesh2": 32, + } + + def __init__(self, comm: mscclpp.CommGroup, symmetric_memory: bool = False): self.comm = comm self.rank = comm.my_rank self.world_size = comm.nranks - self.local_rank = comm.my_rank % comm.nranks_per_node - self.n_ranks_per_node = comm.nranks_per_node - dlpack = mscclpp.RawGpuBuffer(1 << 27).to_dlpack(data_type=str(torch.float16)) - self.scratch_buffer = torch.utils.dlpack.from_dlpack(dlpack) - algorithms = load_algorithms(scratch_buffer=self.scratch_buffer, rank=self.rank) - self._algorithm_nvls_packet = [ - algo - for algo in algorithms - if algo.collective == "allreduce" and algo.name == "default_allreduce_nvls_packet" - ][0] - self._algorithm_rsag_zero_copy = [ - algo - for algo in algorithms - if algo.collective == "allreduce" and algo.name == "default_allreduce_rsag_zero_copy" - ][0] - self._algorithm_packet = [ - algo for algo in algorithms if algo.collective == "allreduce" and algo.name == "default_allreduce_packet" - ][0] - if mscclpp.is_nvls_supported(): - self._algorithm_nvls_zero_copy = [ - algo - for algo in algorithms - if algo.collective == "allreduce" and algo.name == 
"default_allreduce_nvls_zero_copy" - ][0] - self._tune(n_warmup=5, n_graph_launches=10, n_ops_per_graph=100) + self.symmetric_memory = symmetric_memory + self._nvls = mscclpp.is_nvls_supported() - def _tune(self, n_warmup, n_graph_launches, n_ops_per_graph): - sizes = [1 << i for i in range(10, 28)] - # Pre-fill with defaults for barrier - self.best_configs = {1024: (self._algorithm_nvls_packet, 0, 0)} + self._scratch = _make_tensor(1 << 27, torch.float16) + self._barrier_tensor = _make_tensor(4096, torch.float32) - tune_tensor = mscclpp.RawGpuBuffer(1 << 27).to_dlpack(data_type=str(torch.float16)) - tune_tensor = torch.utils.dlpack.from_dlpack(tune_tensor) - tune_tensor.normal_() - candidates_nblocks = [4, 8, 16, 24, 32, 48, 64, 128] - candidates_nthreads = [512, 768, 1024] + algos = _load_algorithms(self._scratch, self.rank) + self._algos = {(a.collective, a.name): a for a in algos} - for size in sizes: - algos = [] - if mscclpp.is_nvls_supported(): - algos.append(self._algorithm_nvls_zero_copy) - if size <= 4 * 1024 * 1024: - algos.append(self._algorithm_nvls_packet) - algos.append(self._algorithm_packet) - if size >= 512 * 1024: - algos.append(self._algorithm_rsag_zero_copy) + # {collective: {rounded_size: (algo, nblocks, nthreads)}} + self._tune_cache: dict[str, dict[int, tuple]] = {"allreduce": {}, "allgather": {}} + self._tune_buf = None + self._time_buf = None - best_time = float("inf") - best_config = None + def _algo(self, collective: str, name: str): + return self._algos.get((collective, name)) - for algo in algos: - for nb in candidates_nblocks: - if algo.name == "default_allreduce_nvls_packet" and nb > 16: - continue - if algo.name == "default_allreduce_packet" and nb > 56: - continue - for nt in candidates_nthreads: - if self._run_algo(algo, tune_tensor, size, nb, nt) != 0: - continue + def _default_ar_config(self): + """Fallback allreduce config for barrier / timing sync.""" + pkt = self._algo("allreduce", "default_allreduce_nvls_packet") + if self._nvls and pkt: + return (pkt, 0, 0) + return (self._algo("allreduce", "default_allreduce_packet"), 0, 0) - for _ in range(n_warmup): - self._run_algo(algo, tune_tensor, size, nb, nt) - self.barrier() + # -- low-level execute -- - capture_stream = torch.cuda.Stream() - capture_stream.wait_stream(torch.cuda.current_stream()) - - g = torch.cuda.CUDAGraph() - # Warmup on capture stream - with torch.cuda.stream(capture_stream): - self._run_algo(algo, tune_tensor, size, nb, nt) - capture_stream.synchronize() - - with torch.cuda.graph(g, stream=capture_stream): - for _ in range(n_ops_per_graph): - self._run_algo(algo, tune_tensor, size, nb, nt) - - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - start_event.record(capture_stream) - with torch.cuda.stream(capture_stream): - for _ in range(n_graph_launches): - g.replay() - end_event.record(capture_stream) - end_event.synchronize() - - elapsed = start_event.elapsed_time(end_event) - - # Synchronize timing results across all ranks to ensure consistent algorithm selection - # replicate n times such due to algo limitations - time_tensor = torch.full((self.world_size,), elapsed, dtype=torch.float64, device="cuda").to( - dtype=torch.float32 - ) - torch.cuda.current_stream().wait_stream(capture_stream) - # TODO: use all_reduce may cause problem if the time elapsed between different algos are too close. - # May change to broadcast in the future if that becomes an issue. 
- self.all_reduce(time_tensor, op=torch.distributed.ReduceOp.SUM) - avg_time = time_tensor[self.rank].item() / self.world_size - - if avg_time < best_time: - best_time = avg_time - best_config = (algo, nb, nt) - - if best_config: - self.best_configs[size] = best_config - if self.rank == 0: - print( - f"Size {size}: Best Algo {best_config[0].name} nblocks {best_config[1]} nthreads {best_config[2]} Time {(best_time/(n_graph_launches * n_ops_per_graph))*1000:.2f} us" - ) - # reset the algorithms after tuning - torch.cuda.synchronize() - for algo in algos: - algo.reset() - - def _run_algo(self, algo: mscclpp.Algorithm, tensor, size, nblocks, nthreads): - return algo.execute( - comm=self.comm.communicator, - input_buffer=tensor.data_ptr(), - output_buffer=tensor.data_ptr(), - input_size=size, - output_size=size, - dtype=mscclpp_utils.torch_dtype_to_mscclpp_dtype(tensor.dtype), - op=mscclpp.ReduceOp.SUM, - stream=torch.cuda.current_stream().cuda_stream, - nblocks=nblocks, - nthreads_per_block=nthreads, - symmetric_memory=True, - ) - - def get_tuned_config(self, size): - if size < 1024: - target_size = 1024 - elif size > 256 * 1024 * 1024: - target_size = 256 * 1024 * 1024 - else: - target_size = 1 << (size - 1).bit_length() - return self.best_configs.get(target_size) - - def all_reduce(self, tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM, stream: torch.cuda.Stream = None): - assert op == torch.distributed.ReduceOp.SUM - config = self.get_tuned_config(tensor.nbytes) - algo, nblocks, nthreads = config if config else (self._algorithm_nvls_packet, 0, 0) + def _exec_ar(self, tensor, algo, nb, nt, op=mscclpp.ReduceOp.SUM, stream=None, accum_dtype=None, sym=True): + s = stream.cuda_stream if stream else torch.cuda.current_stream().cuda_stream ret = algo.execute( comm=self.comm.communicator, input_buffer=tensor.data_ptr(), @@ -195,107 +119,357 @@ class CustomizedComm: input_size=tensor.nbytes, output_size=tensor.nbytes, dtype=mscclpp_utils.torch_dtype_to_mscclpp_dtype(tensor.dtype), - op=to_mscclpp_reduce_op(op), - stream=stream.cuda_stream if stream is not None else torch.cuda.current_stream().cuda_stream, - nblocks=nblocks, - nthreads_per_block=nthreads, - symmetric_memory=True, + op=op, + stream=s, + nblocks=nb, + nthreads_per_block=nt, + symmetric_memory=sym, + accum_dtype=accum_dtype, ) if ret != 0: - print(f"Rank {self.rank}: Algo {algo.name} failed with error {ret}") + print(f"Rank {self.rank}: {algo.name} failed ({ret})") + return ret + + def _exec_ag(self, inp, out, algo, nb, nt, stream=None, sym=None): + if sym is None: + sym = self.symmetric_memory + s = stream.cuda_stream if stream else torch.cuda.current_stream().cuda_stream + ret = algo.execute( + comm=self.comm.communicator, + input_buffer=inp.data_ptr(), + output_buffer=out.data_ptr(), + input_size=inp.nbytes, + output_size=out.nbytes, + dtype=mscclpp_utils.torch_dtype_to_mscclpp_dtype(inp.dtype), + op=mscclpp.ReduceOp.NOP, + stream=s, + nblocks=nb, + nthreads_per_block=nt, + symmetric_memory=sym, + ) + if ret != 0: + print(f"Rank {self.rank}: AG {algo.name} failed ({ret})") + return ret + + def _barrier_internal(self): + a, nb, nt = self._default_ar_config() + self._exec_ar(self._barrier_tensor, a, nb, nt, sym=True) + + # -- lazy tuning -- + + def _ensure_tune_bufs(self): + if self._tune_buf is None: + self._tune_buf = _make_tensor(1 << 27, torch.float16) + self._tune_buf.normal_() + self._time_buf = _make_tensor(4096, torch.float32) + return self._tune_buf + + def _ar_candidates(self, size: int): + out = [] + if size <= 4 << 
20: + a = self._algo("allreduce", "default_allreduce_nvls_packet") + if self._nvls and a: + out.append(a) + a = self._algo("allreduce", "default_allreduce_packet") + if a: + out.append(a) + a = self._algo("allreduce", "default_allreduce_allpair_packet") + if a: + out.append(a) + if size >= 512 << 10: + a = self._algo("allreduce", "default_allreduce_nvls_zero_copy") + if self._nvls and self.symmetric_memory and a: + out.append(a) + a = self._algo("allreduce", "default_allreduce_rsag_zero_copy") + if a: + out.append(a) + if torch.version.hip is not None: + a = self._algo("allreduce", "default_allreduce_fullmesh") + if a: + out.append(a) + return out + + def _ag_candidates(self): + a = self._algo("allgather", "default_allgather_fullmesh2") + return [a] if a else [] + + def _run_tune(self, collective, algo, buf, size, nb, nt): + """Single tune invocation for either collective.""" + if collective == "allreduce": + return algo.execute( + comm=self.comm.communicator, + input_buffer=buf.data_ptr(), + output_buffer=buf.data_ptr(), + input_size=size, + output_size=size, + dtype=mscclpp_utils.torch_dtype_to_mscclpp_dtype(buf.dtype), + op=mscclpp.ReduceOp.SUM, + stream=torch.cuda.current_stream().cuda_stream, + nblocks=nb, + nthreads_per_block=nt, + symmetric_memory=True, + ) + else: + total = size * self.world_size + out_ptr = buf.data_ptr() + return algo.execute( + comm=self.comm.communicator, + input_buffer=out_ptr + self.rank * size, + output_buffer=out_ptr, + input_size=size, + output_size=total, + dtype=mscclpp_utils.torch_dtype_to_mscclpp_dtype(buf.dtype), + op=mscclpp.ReduceOp.NOP, + stream=torch.cuda.current_stream().cuda_stream, + nblocks=nb, + nthreads_per_block=nt, + symmetric_memory=False, + ) + + def _tune_size(self, collective: str, target_size: int): + """Auto-tune one (collective, target_size) pair and cache result.""" + buf = self._ensure_tune_bufs() + cands = self._ar_candidates(target_size) if collective == "allreduce" else self._ag_candidates() + + best_time, best_cfg = float("inf"), None + used = set() + run = lambda a, nb, nt: self._run_tune(collective, a, buf, target_size, nb, nt) + + for algo in cands: + nb_limit = self._NBLOCKS_LIMIT.get(algo.name, 128) + for nb in self._CANDIDATE_NBLOCKS: + if nb > nb_limit: + continue + for nt in self._CANDIDATE_NTHREADS: + # Feasibility — sync result across ranks so all agree + ret = run(algo, nb, nt) + torch.cuda.synchronize() + self._time_buf[0] = float(ret) + self._exec_ar(self._time_buf[:1], *self._default_ar_config(), sym=True) + if self._time_buf[0].item() != 0: + continue + used.add(algo) + + # Warmup + for _ in range(self._TUNE_N_WARMUP): + run(algo, nb, nt) + + # CUDA-graph timed benchmark + cs = torch.cuda.Stream() + cs.wait_stream(torch.cuda.current_stream()) + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g, stream=cs): + for _ in range(self._TUNE_N_OPS_PER_GRAPH): + run(algo, nb, nt) + + start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) + start.record(cs) + with torch.cuda.stream(cs): + for _ in range(self._TUNE_N_GRAPH_LAUNCHES): + g.replay() + end.record(cs) + end.synchronize() + elapsed = start.elapsed_time(end) + + # Cross-rank timing sync + self._time_buf.fill_(elapsed) + torch.cuda.current_stream().wait_stream(cs) + self._exec_ar(self._time_buf, *self._default_ar_config(), sym=True) + avg = self._time_buf[self.rank].item() / self.world_size + + if avg < best_time: + best_time, best_cfg = avg, (algo, nb, nt) + + if best_cfg: + self._tune_cache[collective][target_size] = best_cfg 
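+            # Later all_reduce/all_gather calls with the same rounded size reuse this cached entry.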
+ if self.rank == 0: + n = self._TUNE_N_GRAPH_LAUNCHES * self._TUNE_N_OPS_PER_GRAPH + print( + f"[tune] {collective} size={target_size}: {best_cfg[0].name} " + f"nb={best_cfg[1]} nt={best_cfg[2]} time={best_time / n * 1000:.2f}us", + flush=True, + ) + else: + fb = ( + self._default_ar_config() + if collective == "allreduce" + else ((self._ag_candidates()[0], 32, 512) if self._ag_candidates() else None) + ) + self._tune_cache[collective][target_size] = fb + + torch.cuda.synchronize() + self._barrier_internal() + for a in used: + a.reset() + + # -- public API -- + + def all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, stream=None, accum_dtype=None): + sz = _round_pow2(tensor.nbytes) + if sz not in self._tune_cache["allreduce"]: + self._tune_size("allreduce", sz) + a, nb, nt = self._tune_cache["allreduce"][sz] + self._exec_ar( + tensor, a, nb, nt, op=_to_mscclpp_op(op), stream=stream, accum_dtype=accum_dtype, sym=self.symmetric_memory + ) + + def all_gather(self, output_tensor, input_tensor, stream=None): + sz = _round_pow2(input_tensor.nbytes) + if sz not in self._tune_cache["allgather"]: + self._tune_size("allgather", sz) + a, nb, nt = self._tune_cache["allgather"][sz] + self._exec_ag(input_tensor, output_tensor, a, nb, nt, stream=stream, sym=self.symmetric_memory) def barrier(self): - tensor = torch.empty(self.world_size, dtype=torch.float, device=torch.device("cuda")) - self.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM, stream=torch.cuda.current_stream()) - - def benchmark(self, n_warmup=10, n_graph_launches=10, n_iter_per_graph=100): - low = 5 * 1024 - high = 80 * 1024 * 1024 - sizes = [] - curr = low - while curr <= high: - sizes.append(curr) - curr *= 2 - - if self.rank == 0: - print(f"{'Size (Bytes)':<20} {'Time (us)':<20} {'AlgoBW (GB/s)':<20}") - - dtype = torch.float16 - capture_stream = torch.cuda.Stream() - - # Allocate a single large RawGpuBuffer (symmetric memory) and reuse it for all sizes. - # Cannot allocate per-size tensors with symmetric memory. 
- bench_buf = mscclpp.RawGpuBuffer(1 << 27).to_dlpack(data_type=str(dtype)) - bench_buf = torch.utils.dlpack.from_dlpack(bench_buf) - bench_buf.normal_() - - for size in sizes: - n_elements = size // bench_buf.element_size() - tensor = bench_buf[:n_elements] - - capture_stream.wait_stream(torch.cuda.current_stream()) - # Capture Graph - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g, stream=capture_stream): - for _ in range(n_iter_per_graph): - self.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM) - - # warmup: Execute the graph once to prime the driver - with torch.cuda.stream(capture_stream): - for _ in range(n_warmup): - g.replay() - self.barrier() - capture_stream.synchronize() - - # Benchmark - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - start_event.record(capture_stream) - with torch.cuda.stream(capture_stream): - for _ in range(n_graph_launches): - g.replay() - end_event.record(capture_stream) - end_event.synchronize() - - # Get elapsed time in milliseconds - elapsed_ms = start_event.elapsed_time(end_event) - avg_time_ms = elapsed_ms / (n_graph_launches * n_iter_per_graph) - time_us = avg_time_ms * 1000 - - alg_bw = size / (avg_time_ms * 1e-3) if avg_time_ms > 0 else 0 - if self.rank == 0: - print(f"{size:<20} {time_us:<20.2f} {alg_bw / 1e9:<20.2f}") + self._barrier_internal() def destroy(self): - self._algorithm_nvls_nonzero_copy = None - self._algorithm_nvls_packet = None - self.scratch_buffer = None - self.comm = None + self._algos.clear() + self._tune_cache = {"allreduce": {}, "allgather": {}} + self._tune_buf = self._time_buf = self._barrier_tensor = self._scratch = self.comm = None -def init_dist() -> CustomizedComm: - rank = int(os.environ["RANK"]) - world = int(os.environ["WORLD_SIZE"]) - master_addr = os.environ["MSCCLPP_MASTER_ADDR"] - master_port = os.environ["MSCCLPP_MASTER_PORT"] - interface = interfaces_for_ip_netifaces(master_addr) - if interface is None: - raise ValueError(f"Cannot find network interface for IP address {master_addr}") - interfaceIpPortTrio = f"{interface}:{master_addr}:{master_port}" - mscclpp_group = mscclpp.CommGroup(interfaceIpPortTrio=interfaceIpPortTrio, rank=rank, size=world) - return CustomizedComm(mscclpp_group) +# -- Benchmarks (standalone) -------------------------------------------------- + + +def _bench_sizes(low=5 * 1024, high=80 << 20): + sizes, c = [], low + while c <= high: + sizes.append(c) + c *= 2 + return sizes + + +def benchmark_allreduce( + comm: CustomizedComm, dtype=torch.float16, accum_dtype=None, n_warmup=10, n_graph_launches=10, n_iter=100 +): + sizes = _bench_sizes() + if comm.rank == 0: + print(f"\n{'='*60}\nAllreduce Benchmark\n{'='*60}") + print(f"{'Nelements':<18} {'Size(B)':<18} {'Time(us)':<18} {'AlgoBW(GB/s)':<18}") + + cs = torch.cuda.Stream() + buf = _make_tensor(1 << 27, dtype) + buf.normal_() if dtype in (torch.float16, torch.float32, torch.bfloat16) else buf.fill_(0) + + for size in sizes: + nelems = size // buf.element_size() + t = buf[: size // buf.element_size()] + comm.all_reduce(t, accum_dtype=accum_dtype) + torch.cuda.synchronize() + + cs.wait_stream(torch.cuda.current_stream()) + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g, stream=cs): + for _ in range(n_iter): + comm.all_reduce(t, accum_dtype=accum_dtype) + with torch.cuda.stream(cs): + for _ in range(n_warmup): + g.replay() + comm.barrier() + cs.synchronize() + + s, e = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) + s.record(cs) + with 
torch.cuda.stream(cs): + for _ in range(n_graph_launches): + g.replay() + e.record(cs) + e.synchronize() + + ms = s.elapsed_time(e) / (n_graph_launches * n_iter) + if comm.rank == 0: + print(f"{nelems:<18} {size:<18} {ms*1000:<18.2f} {size/(ms*1e-3)/1e9:<18.2f}") + + +def benchmark_allgather(comm: CustomizedComm, dtype=torch.float16, n_warmup=10, n_graph_launches=10, n_iter=100): + sizes = _bench_sizes() + if comm.rank == 0: + print(f"\n{'='*60}\nAllgather Benchmark\n{'='*60}") + print(f"{'PerRank(B)':<18} {'Total(B)':<18} {'Time(us)':<18} {'AlgoBW(GB/s)':<18}") + + cs = torch.cuda.Stream() + buf = _make_tensor(1 << 27, dtype) + buf.normal_() if dtype in (torch.float16, torch.float32, torch.bfloat16) else buf.fill_(0) + + for prs in sizes: + total = prs * comm.world_size + if total > buf.nbytes: + break + nt = total // buf.element_size() + npr = prs // buf.element_size() + out = buf[:nt] + inp = out[comm.rank * npr : (comm.rank + 1) * npr] + + comm.all_gather(out, inp) + torch.cuda.synchronize() + + cs.wait_stream(torch.cuda.current_stream()) + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g, stream=cs): + for _ in range(n_iter): + comm.all_gather(out, inp) + with torch.cuda.stream(cs): + for _ in range(n_warmup): + g.replay() + comm.barrier() + cs.synchronize() + + s, e = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) + s.record(cs) + with torch.cuda.stream(cs): + for _ in range(n_graph_launches): + g.replay() + e.record(cs) + e.synchronize() + + ms = s.elapsed_time(e) / (n_graph_launches * n_iter) + if comm.rank == 0: + print(f"{prs:<18} {total:<18} {ms*1000:<18.2f} {total/(ms*1e-3)/1e9:<18.2f}") + + +# -- Bootstrap & main --------------------------------------------------------- + + +def init_dist() -> mscclpp.CommGroup: + addr = os.environ.get("MSCCLPP_MASTER_ADDR") + if addr: + rank, world = int(os.environ["RANK"]), int(os.environ["WORLD_SIZE"]) + port = os.environ["MSCCLPP_MASTER_PORT"] + iface = _interfaces_for_ip(addr) + if not iface: + raise ValueError(f"No interface for {addr}") + return mscclpp.CommGroup(interfaceIpPortTrio=f"{iface}:{addr}:{port}", rank=rank, size=world) + import torch.distributed as dist + + dist.init_process_group(backend="gloo") + return mscclpp.CommGroup(torch_group=dist.group.WORLD) def main(): local = int(os.environ["LOCAL_RANK"]) torch.cuda.set_device(local) - comm = init_dist() - comm.benchmark(n_warmup=5, n_graph_launches=10, n_iter_per_graph=100) - comm.barrier() + + dtype_str = os.environ.get("DTYPE", "float16") + dtype = getattr(torch, dtype_str, torch.float16) + accum_map = {"float32": mscclpp.DataType.float32, "float16": mscclpp.DataType.float16} + accum_str = os.environ.get("ACCUM_DTYPE") + accum_dtype = accum_map.get(accum_str) if accum_str else None + + comm_group = init_dist() + cc = CustomizedComm(comm_group) + + print(f"rank {local} starting benchmarks with dtype={dtype} accum_dtype={accum_dtype}...") + benchmark_allreduce(cc, dtype=dtype, accum_dtype=accum_dtype) + cc.barrier() torch.cuda.synchronize() - comm.destroy() - print(f"rank {local} All-reduce operation completed successfully.") + + benchmark_allgather(cc, dtype=dtype) + cc.barrier() + torch.cuda.synchronize() + + cc.destroy() + print(f"rank {local} completed successfully.") if __name__ == "__main__": diff --git a/examples/torch-integration/dsl_with_nccl_api.py b/examples/torch-integration/dsl_with_nccl_api.py index 975d3749..5a4dd1c4 100644 --- a/examples/torch-integration/dsl_with_nccl_api.py +++ b/examples/torch-integration/dsl_with_nccl_api.py 
@@ -1,19 +1,20 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# LD_PRELOAD=/build/lib/nccl/libmscclpp_nccl.so torchrun --nnodes=1 --nproc_per_node=8 dsl_with_nccl_api.py +# LD_PRELOAD=/build/lib/libmscclpp_nccl.so torchrun --nnodes=1 --nproc_per_node=8 dsl_with_nccl_api.py import os from typing import Any, Dict import torch, torch.distributed as dist -import mscclpp +import mscclpp.ext from mscclpp.language.collectives import AllReduce from mscclpp.language.channel import SwitchChannel, MemoryChannel, BufferType, SyncType from mscclpp.language.program import CollectiveProgram from mscclpp.language.rank import Rank +from mscclpp.language.utils import AlgoSpec -def allreduce_nvls(spec: mscclpp.AlgoSpec) -> CollectiveProgram: +def allreduce_nvls(spec: AlgoSpec) -> CollectiveProgram: gpu_size = spec.world_size with CollectiveProgram.from_spec(spec) as program: # Creating Channels @@ -63,8 +64,8 @@ def allreduce_nvls(spec: mscclpp.AlgoSpec) -> CollectiveProgram: return program -def setup_plan(algo_collection_builder: mscclpp.AlgorithmCollectionBuilder, rank: int, world_size: int): - spec = mscclpp.AlgoSpec( +def setup_plan(algo_collection_builder: mscclpp.ext.AlgorithmCollectionBuilder, rank: int, world_size: int): + spec = AlgoSpec( name="allreduce_nvls", collective=AllReduce(8, 1, True), nranks_per_node=8, @@ -94,10 +95,10 @@ def init_dist(): rank = int(os.environ["RANK"]) world = int(os.environ["WORLD_SIZE"]) local = int(os.environ["LOCAL_RANK"]) - algorithm_collection_builder = mscclpp.AlgorithmCollectionBuilder() + algorithm_collection_builder = mscclpp.ext.AlgorithmCollectionBuilder() setup_plan(algorithm_collection_builder, rank, world) algorithm_collection_builder.set_algorithm_selector(selector) - dist.init_process_group(backend="nccl", device_id=local) + dist.init_process_group(backend="nccl", device_id=torch.device("cuda", local)) return rank, world, local diff --git a/include/mscclpp/algorithm.hpp b/include/mscclpp/algorithm.hpp index 65b1ab3c..531cb857 100644 --- a/include/mscclpp/algorithm.hpp +++ b/include/mscclpp/algorithm.hpp @@ -103,12 +103,14 @@ class Algorithm { /// @param nThreadsPerBlock Number of threads per block (0 for auto-selection). /// @param symmetricMemory Whether to use symmetric memory optimization. /// @param extras Additional parameters for algorithm-specific customization. + /// @param accumDtype Data type for accumulation during reduction. DataType::AUTO resolves to dtype. /// @return The result of the operation. virtual CommResult execute(std::shared_ptr comm, const void* input, void* output, size_t inputSize, size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, std::shared_ptr executor, int nBlocks = 0, int nThreadsPerBlock = 0, bool symmetricMemory = false, - const std::unordered_map& extras = {}) = 0; + const std::unordered_map& extras = {}, + DataType accumDtype = DataType::AUTO) = 0; /// Reset the algorithm state, clearing any cached contexts. virtual void reset() = 0; @@ -186,10 +188,11 @@ class NativeAlgorithm : public Algorithm { /// @param nBlocks Number of CUDA blocks. /// @param nThreadsPerBlock Number of threads per block. /// @param extras Additional algorithm-specific parameters. + /// @param accumDtype Data type for accumulation (resolved from input dtype if sentinel). /// @return The result of the operation. 
using KernelFunc = std::function, const void*, void*, size_t, size_t, DataType, ReduceOp, - cudaStream_t, int, int, const std::unordered_map&)>; + cudaStream_t, int, int, const std::unordered_map&, DataType)>; /// Function type for creating algorithm contexts. /// @param comm The communicator. @@ -233,8 +236,8 @@ class NativeAlgorithm : public Algorithm { CommResult execute(std::shared_ptr comm, const void* input, void* output, size_t inputSize, size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, std::shared_ptr executor, int nBlocks = 0, int nThreadsPerBlock = 0, - bool symmetricMemory = false, - const std::unordered_map& extras = {}) override; + bool symmetricMemory = false, const std::unordered_map& extras = {}, + DataType accumDtype = DataType::AUTO) override; const std::string& name() const override; const std::string& collective() const override; const std::pair& messageRange() const override; @@ -285,8 +288,8 @@ class DslAlgorithm : public Algorithm, public AlgorithmBuilder, public std::enab CommResult execute(std::shared_ptr comm, const void* input, void* output, size_t inputSize, size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, std::shared_ptr executor, int nBlocks = 0, int nThreadsPerBlock = 0, - bool symmetricMemory = false, - const std::unordered_map& extras = {}) override; + bool symmetricMemory = false, const std::unordered_map& extras = {}, + DataType accumDtype = DataType::AUTO) override; AlgorithmType type() const override { return AlgorithmType::DSL; } Constraint constraint() const override; void reset() override; diff --git a/include/mscclpp/gpu_data_types.hpp b/include/mscclpp/gpu_data_types.hpp index 1cecbea6..41bd5928 100644 --- a/include/mscclpp/gpu_data_types.hpp +++ b/include/mscclpp/gpu_data_types.hpp @@ -64,18 +64,151 @@ using __bfloat162 = __nv_bfloat162; #endif +/// Software float8 with 4 exponent bits, 3 mantissa bits, exponent bias = 15. +/// Format (MSB first): [sign:1][exponent:4][mantissa:3] +/// No infinities; exp=15 is NaN. Negative zero is NaN (fnuz convention). +/// Max finite value: 0.9375, min normal: ~6.1e-5, min subnormal: ~7.6e-6. +struct alignas(1) __fp8_e4m3b15 { + uint8_t __x; + + __fp8_e4m3b15() = default; + + /// Construct from raw bits (use __fp8_e4m3b15::fromRaw() for clarity). + MSCCLPP_HOST_DEVICE_INLINE explicit __fp8_e4m3b15(uint8_t raw) : __x(raw) {} + + /// Construct from float32 (explicit to avoid ambiguous conversion chains). + MSCCLPP_HOST_DEVICE_INLINE explicit __fp8_e4m3b15(float val) : __x(fromFloat(val)) {} + + /// Convert to float32. + MSCCLPP_HOST_DEVICE_INLINE operator float() const { return toFloat(__x); } + + /// Construct from a raw bit pattern without conversion. + static MSCCLPP_HOST_DEVICE_INLINE __fp8_e4m3b15 fromRaw(uint8_t bits) { + __fp8_e4m3b15 r; + r.__x = bits; + return r; + } + + private: + /// Decode fp8_e4m3b15 bits → float32. + /// + /// Uses bit manipulation through fp16 as intermediate, adapted from the Triton compiler. + /// fp8_e4m3b15 is identical to fp8_e4m3fn (NVIDIA) except exponent bias is 15 vs 7. + /// Algorithm: reinterpret fp8 bits into an fp16 bit pattern with exponent shifted by -8, + /// then convert fp16 → float32. + static MSCCLPP_HOST_DEVICE_INLINE float toFloat(uint8_t bits) { + // Handle special values: negative zero (0x80) → NaN, exponent=15 → NaN. 
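+    // (exp = 15 covers bit patterns 0x78-0x7F and 0xF8-0xFF; the largest finite encoding
+    //  is 0x77 -> 2^(14-15) * (1 + 7/8) = 0.9375, matching the max documented above.)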
+ uint32_t exp = (bits >> 3) & 0xFu; + if (bits == 0x80 || exp == 15) { + union { + uint32_t u; + float f; + } nan_val = {0x7FC00000u}; + return nan_val.f; + } + if (bits == 0) return 0.0f; + + // Triton-style bit manipulation: fp8 → fp16 → fp32. + // fp8 layout: [S:1][E:4][M:3] (bias=15) + // fp16 layout: [S:1][E:5][M:10] (bias=15) + // + // Place fp8 in upper byte of fp16, then right-shift exponent+mantissa by 1 + // to convert E4 → E5 (both share bias=15). Sign bit stays at bit 15. + // Refer: + // https://github.com/triton-lang/triton/blob/cf34004b8a67d290a962da166f5aa2fc66751326/python/triton/language/extra/cuda/utils.py#L34 + uint16_t h = (uint16_t)bits << 8; // place fp8 in upper byte of fp16 + uint16_t sign16 = h & 0x8000u; // extract sign at fp16 position + uint16_t nosign = h & 0x7F00u; // exponent + mantissa (no sign) + uint16_t fp16_bits = sign16 | (nosign >> 1); // shift exponent right by 1 + + // For subnormals: when fp8 exponent=0, the above gives fp16 exponent=0 + // and fp16 mantissa = (fp8_mantissa << 7), which correctly represents + // the subnormal fp16 value since both share bias=15. + + // Convert fp16 bits to float via __half (works on host and device, CUDA and HIP). + union { + uint16_t u; + __half h; + } cvt = {fp16_bits}; + return __half2float(cvt.h); + } + + /// Encode float32 → fp8_e4m3b15 bits. + /// + /// Algorithm adapted from Triton: float32 → fp16 → bit-manipulate → fp8. + /// The key insight is to convert to fp16 first (which shares bias=15 with e4m3b15), + /// then pack the fp16 bits back into 8 bits by shifting the exponent left by 1. + static MSCCLPP_HOST_DEVICE_INLINE uint8_t fromFloat(float val) { + union { + float f; + uint32_t u; + } in = {val}; + + // NaN → 0x80 (negative-zero bit pattern = NaN in fnuz). + if ((in.u & 0x7F800000u) == 0x7F800000u && (in.u & 0x007FFFFFu) != 0) return 0x80u; + + // Convert float32 → fp16 bits via __half (works on host and device, CUDA and HIP). + __half h_val = __float2half_rn(val); + union { + __half h; + uint16_t u; + } cvt = {h_val}; + uint16_t fp16_bits = cvt.u; + + // Clamp absolute value to max finite e4m3b15: 0.9375 → fp16 = 0x3B80. + uint16_t abs_fp16 = fp16_bits & 0x7FFFu; + if (abs_fp16 > 0x3B80u) abs_fp16 = 0x3B80u; + + // Reconstruct with sign. + uint16_t sign16 = fp16_bits & 0x8000u; + + // Triton-style: fp16 → fp8. + // fp16 layout: [S:1][E:5][M:10] (bias=15) + // fp8 layout: [S:1][E:4][M:3] (bias=15) + // + // mad.lo.u32 a0, a0, 2, 0x00800080 → (abs_fp16 * 2 + 0x0080) + // This shifts left by 1 (undoing the right-shift in decode) and adds rounding bias. + // Then: lop3.b32 b0, $1, 0x80008000, a0, 0xea → (sign & 0x8000) | a0 + // Finally: prmt for byte extraction. + // + // Simplified for scalar: shift abs_fp16 left by 1, add rounding bias, take upper byte. + uint16_t adjusted = (uint16_t)(abs_fp16 * 2u + 0x0080u); + // The upper byte now contains [E:4][M:3][round_bit]. + // Combine with sign and extract. + uint16_t with_sign = sign16 | adjusted; + uint8_t result = (uint8_t)(with_sign >> 8); + + // Zero → 0x00 (ensure positive zero, not negative zero which is NaN). + if ((result & 0x7Fu) == 0) result = 0x00u; + + return result; + } +}; + +/// Packed 2x fp8_e4m3b15 storage. +struct alignas(2) __fp8x2_e4m3b15 { + uint16_t __x; +}; + +/// Packed 4x fp8_e4m3b15 storage. +struct alignas(4) __fp8x4_e4m3b15 { + uint32_t __x; +}; + namespace mscclpp { /// Data types supported by mscclpp operations. enum class DataType { - INT32, // 32-bit signed integer. - UINT32, // 32-bit unsigned integer. 
- FLOAT16, // IEEE 754 half precision. - FLOAT32, // IEEE 754 single precision. - BFLOAT16, // bfloat16 precision. - FLOAT8_E4M3, // float8 with E4M3 layout. - FLOAT8_E5M2, // float8 with E5M2 layout. - UINT8, // 8-bit unsigned integer. + INT32, // 32-bit signed integer. + UINT32, // 32-bit unsigned integer. + FLOAT16, // IEEE 754 half precision. + FLOAT32, // IEEE 754 single precision. + BFLOAT16, // bfloat16 precision. + FLOAT8_E4M3, // float8 with E4M3 layout. + FLOAT8_E5M2, // float8 with E5M2 layout. + UINT8, // 8-bit unsigned integer. + FLOAT8_E4M3B15, // float8 with E4M3 layout, bias=15 (software, no HW accel). + AUTO = 255, // Sentinel: resolve to the input dtype at runtime. }; /// Word array. @@ -97,6 +230,7 @@ struct alignas(Bytes) Words {}; template union alignas(sizeof(T) * N) VectorTypeImpl { static_assert(N > 0, "N must be greater than 0"); + static_assert(sizeof(StorageT) >= sizeof(T) * N, "StorageT must cover the full vector size"); T data[N]; Words words; @@ -127,13 +261,14 @@ union alignas(sizeof(T) * N) VectorTypeImpl { MSCCLPP_HOST_DEVICE_INLINE const T& operator[](int i) const { return data[i]; } }; -// Helper template to get the appropriate vector type for a given element type and count +// Helper template to get the appropriate vector type for a given element type and count. template struct VectorTypeHelper { - using type = - VectorTypeImpl>>; + static constexpr int Bytes = N * sizeof(T); + using type = VectorTypeImpl< + T, N, + std::conditional_t>>>>; }; /// Vector type - clean user interface (automatically selects appropriate storage type) @@ -170,6 +305,11 @@ DEFINE_VEC(bf16x4, __bfloat16, 4, uint2); DEFINE_VEC(f16x8, __half, 8, uint4); DEFINE_VEC(bf16x8, __bfloat16, 8, uint4); +// Aliases for large vector types (>16 bytes) where no native CUDA storage type exists. +using f32x8 = VectorType; +using f32x16 = VectorType; +using f16x16 = VectorType<__half, 16>; + #if defined(__FP8_TYPES_EXIST__) DEFINE_VEC(f8_e4m3x2, __fp8_e4m3, 2, __fp8x2_e4m3); DEFINE_VEC(f8_e4m3x4, __fp8_e4m3, 4, __fp8x4_e4m3); @@ -181,6 +321,12 @@ DEFINE_VEC(f8_e5m2x4, __fp8_e5m2, 4, __fp8x4_e5m2); DEFINE_VEC(f8_e5m2x8, __fp8_e5m2, 8, uint2); DEFINE_VEC(f8_e5m2x16, __fp8_e5m2, 16, uint4); #endif + +// fp8_e4m3b15 vectors (always available — software type, no HW dependency) +DEFINE_VEC(f8_e4m3b15x2, __fp8_e4m3b15, 2, __fp8x2_e4m3b15); +DEFINE_VEC(f8_e4m3b15x4, __fp8_e4m3b15, 4, __fp8x4_e4m3b15); +DEFINE_VEC(f8_e4m3b15x8, __fp8_e4m3b15, 8, uint2); +DEFINE_VEC(f8_e4m3b15x16, __fp8_e4m3b15, 16, uint4); #undef DEFINE_VEC #if defined(MSCCLPP_DEVICE_COMPILE) @@ -254,6 +400,21 @@ MSCCLPP_DEVICE_INLINE __fp8_e5m2 clip(__fp8_e5m2 val) { } #endif +// --- f32x2 arithmetic --- + +template +MSCCLPP_DEVICE_INLINE f32x2 operator+(const f32x2& a, const f32x2& b) { +#if defined(MSCCLPP_DEVICE_CUDA) && (__CUDA_ARCH__ >= 1000) + // Blackwell (SM 10.0+): packed float2 add in a single instruction. + return __fadd2_rn(a.storage, b.storage); +#else + f32x2 result; + result.data[0] = a.data[0] + b.data[0]; + result.data[1] = a.data[1] + b.data[1]; + return result; +#endif +} + template MSCCLPP_DEVICE_INLINE f16x2 operator+(const f16x2& a, const f16x2& b) { __half2 result; @@ -265,6 +426,18 @@ MSCCLPP_DEVICE_INLINE f16x2 operator+(const f16x2& a, const f16x2& b) { return result; } +template +MSCCLPP_DEVICE_INLINE f16x4 operator+(const f16x4& a, const f16x4& b) { + // Decompose into 2× packed __hadd2 (2 instructions instead of 4 scalar __hadd). 
+ const f16x2* a2 = reinterpret_cast(&a); + const f16x2* b2 = reinterpret_cast(&b); + f16x4 result; + f16x2* r2 = reinterpret_cast(&result); + r2[0] = a2[0] + b2[0]; + r2[1] = a2[1] + b2[1]; + return result; +} + template MSCCLPP_DEVICE_INLINE bf16x2 operator+(const bf16x2& a, const bf16x2& b) { __bfloat162 result; @@ -449,6 +622,14 @@ MSCCLPP_DEVICE_INLINE T min(const T& a, const T& b) { return (a < b ? a : b); } +template <> +MSCCLPP_DEVICE_INLINE f32x2 min(const f32x2& a, const f32x2& b) { + f32x2 result; + result.data[0] = fminf(a.data[0], b.data[0]); + result.data[1] = fminf(a.data[1], b.data[1]); + return result; +} + template <> MSCCLPP_DEVICE_INLINE f16x2 min(const f16x2& a, const f16x2& b) { #if defined(MSCCLPP_DEVICE_HIP) @@ -489,6 +670,51 @@ MSCCLPP_DEVICE_INLINE u8x4 min(const u8x4& a, const u8x4& b) { #endif } +/// Convert a vector type From to vector type To. +/// Primary template with auto-decomposition: vectors with N > 4 elements decompose into x4 chunks, +/// vectors with N == 4 decompose into x2 chunks, enabling optimized x2/x4 specializations to be reached. +/// Specialized below for optimized FP8 conversion paths at x2/x4 level. +template +MSCCLPP_DEVICE_INLINE To to(const From& v) { + static_assert(To::Size == From::Size, "to: vector sizes must match"); + constexpr int N = From::Size; + + // Auto-decompose: N > 4 → split into x4 chunks + if constexpr (N > 4 && N % 4 == 0) { + constexpr int nChunks = N / 4; + using FromChunk = VectorType; + using ToChunk = VectorType; + const FromChunk* in = reinterpret_cast(&v); + To result; + ToChunk* out = reinterpret_cast(&result); +#pragma unroll + for (int c = 0; c < nChunks; ++c) { + out[c] = to(in[c]); + } + return result; + } + // Auto-decompose: N == 4 → split into 2x x2 chunks + else if constexpr (N == 4) { + using FromChunk = VectorType; + using ToChunk = VectorType; + const FromChunk* in = reinterpret_cast(&v); + To result; + ToChunk* out = reinterpret_cast(&result); + out[0] = to(in[0]); + out[1] = to(in[1]); + return result; + } + // Base case: element-wise conversion + else { + To result; +#pragma unroll + for (int i = 0; i < N; ++i) { + result.data[i] = static_cast(v.data[i]); + } + return result; + } +} + #if defined(__FP8_TYPES_EXIST__) template <> MSCCLPP_DEVICE_INLINE __fp8_e4m3 min(const __fp8_e4m3& a, const __fp8_e4m3& b) { @@ -551,7 +777,592 @@ MSCCLPP_DEVICE_INLINE f8_e5m2x4 min(const f8_e5m2x4& a, const f8_e5m2x4& b) { return result; } + +// --- f8_e4m3 -> f32 specializations --- + +/// f8_e4m3x2 -> f32x2. +/// NVIDIA: fp8 -> half (via __nv_cvt_fp8x2_to_halfraw2) -> float. +/// HIP gfx942: fp8 -> float (via __builtin_amdgcn_cvt_pk_f32_fp8). +template <> +MSCCLPP_DEVICE_INLINE f32x2 to(const f8_e4m3x2& v) { +#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__) + auto f = __builtin_amdgcn_cvt_pk_f32_fp8(v.storage.__x, 0); + f32x2 result; + result.data[0] = f[0]; + result.data[1] = f[1]; + return result; +#elif defined(MSCCLPP_DEVICE_CUDA) && __CUDA_ARCH__ >= 900 + __half2_raw h2 = __nv_cvt_fp8x2_to_halfraw2(bit_cast<__nv_fp8x2_storage_t>(v.storage), __NV_E4M3); + f32x2 result; + result.data[0] = __half2float(bit_cast<__half>(h2.x)); + result.data[1] = __half2float(bit_cast<__half>(h2.y)); + return result; +#else + f32x2 result; + result.data[0] = float(v.data[0]); + result.data[1] = float(v.data[1]); + return result; +#endif +} + +/// f8_e4m3x4 -> f32x4. 
+template <>
+MSCCLPP_DEVICE_INLINE f32x4 to<f32x4>(const f8_e4m3x4& v) {
+#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
+  auto lo = __builtin_amdgcn_cvt_pk_f32_fp8(v.storage.__x, false);
+  auto hi = __builtin_amdgcn_cvt_pk_f32_fp8(v.storage.__x, true);
+  f32x4 result;
+  result.data[0] = lo[0];
+  result.data[1] = lo[1];
+  result.data[2] = hi[0];
+  result.data[3] = hi[1];
+  return result;
+#else
+  const f8_e4m3x2* pair = reinterpret_cast<const f8_e4m3x2*>(&v);
+  f32x2 lo = to<f32x2>(pair[0]);
+  f32x2 hi = to<f32x2>(pair[1]);
+  f32x4 result;
+  result.data[0] = lo.data[0];
+  result.data[1] = lo.data[1];
+  result.data[2] = hi.data[0];
+  result.data[3] = hi.data[1];
+  return result;
+#endif
+}
+
+// --- f8_e5m2 -> f32 specializations ---
+
+/// f8_e5m2x2 -> f32x2.
+/// NVIDIA: fp8 -> half (via __nv_cvt_fp8x2_to_halfraw2) -> float.
+/// HIP gfx942: bf8 -> float (via __builtin_amdgcn_cvt_pk_f32_bf8).
+template <>
+MSCCLPP_DEVICE_INLINE f32x2 to<f32x2>(const f8_e5m2x2& v) {
+#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
+  auto f = __builtin_amdgcn_cvt_pk_f32_bf8(v.storage.__x, 0);
+  f32x2 result;
+  result.data[0] = f[0];
+  result.data[1] = f[1];
+  return result;
+#elif defined(MSCCLPP_DEVICE_CUDA) && __CUDA_ARCH__ >= 900
+  __half2_raw h2 = __nv_cvt_fp8x2_to_halfraw2(bit_cast<__nv_fp8x2_storage_t>(v.storage), __NV_E5M2);
+  f32x2 result;
+  result.data[0] = __half2float(bit_cast<__half>(h2.x));
+  result.data[1] = __half2float(bit_cast<__half>(h2.y));
+  return result;
+#else
+  f32x2 result;
+  result.data[0] = float(v.data[0]);
+  result.data[1] = float(v.data[1]);
+  return result;
+#endif
+}
+
+/// f8_e5m2x4 -> f32x4.
+template <>
+MSCCLPP_DEVICE_INLINE f32x4 to<f32x4>(const f8_e5m2x4& v) {
+#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
+  auto lo = __builtin_amdgcn_cvt_pk_f32_bf8(v.storage.__x, false);
+  auto hi = __builtin_amdgcn_cvt_pk_f32_bf8(v.storage.__x, true);
+  f32x4 result;
+  result.data[0] = lo[0];
+  result.data[1] = lo[1];
+  result.data[2] = hi[0];
+  result.data[3] = hi[1];
+  return result;
+#else
+  const f8_e5m2x2* pair = reinterpret_cast<const f8_e5m2x2*>(&v);
+  f32x2 lo = to<f32x2>(pair[0]);
+  f32x2 hi = to<f32x2>(pair[1]);
+  f32x4 result;
+  result.data[0] = lo.data[0];
+  result.data[1] = lo.data[1];
+  result.data[2] = hi.data[0];
+  result.data[3] = hi.data[1];
+  return result;
+#endif
+}
+
+// --- f32 -> f8_e4m3 specializations (downcast) ---
+
+/// f32x2 -> f8_e4m3x2.
+/// HIP gfx942: float -> fp8 (via __builtin_amdgcn_cvt_pk_fp8_f32).
+/// NVIDIA SM90+: float -> half -> fp8 (via __nv_cvt_halfraw2_to_fp8x2).
+/// NVIDIA pre-SM90: float -> half -> fp8 (via __nv_cvt_halfraw_to_fp8, element-wise).
+template <>
+MSCCLPP_DEVICE_INLINE f8_e4m3x2 to<f8_e4m3x2>(const f32x2& v) {
+#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
+  uint32_t packed = __builtin_amdgcn_cvt_pk_fp8_f32(v.data[0], v.data[1], 0, false);
+  return bit_cast<f8_e4m3x2>(static_cast<__hip_fp8x2_storage_t>(packed));
+#elif defined(MSCCLPP_DEVICE_CUDA) && __CUDA_ARCH__ >= 900
+  __half2_raw h2;
+  h2.x = bit_cast<uint16_t>(__float2half_rn(v.data[0]));
+  h2.y = bit_cast<uint16_t>(__float2half_rn(v.data[1]));
+  __nv_fp8x2_storage_t fp8x2 = __nv_cvt_halfraw2_to_fp8x2(h2, __NV_SATFINITE, __NV_E4M3);
+  return bit_cast<f8_e4m3x2>(fp8x2);
+#elif defined(MSCCLPP_DEVICE_CUDA)
+  __half_raw h0, h1;
+  h0.x = bit_cast<uint16_t>(__float2half_rn(v.data[0]));
+  h1.x = bit_cast<uint16_t>(__float2half_rn(v.data[1]));
+  f8_e4m3x2 result;
+  result.data[0] = bit_cast<__fp8_e4m3>(__nv_cvt_halfraw_to_fp8(h0, __NV_SATFINITE, __NV_E4M3));
+  result.data[1] = bit_cast<__fp8_e4m3>(__nv_cvt_halfraw_to_fp8(h1, __NV_SATFINITE, __NV_E4M3));
+  return result;
+#else
+  f8_e4m3x2 result;
+  result.data[0] = static_cast<__fp8_e4m3>(v.data[0]);
+  result.data[1] = static_cast<__fp8_e4m3>(v.data[1]);
+  return result;
+#endif
+}
+
+/// f32x4 -> f8_e4m3x4.
+template <>
+MSCCLPP_DEVICE_INLINE f8_e4m3x4 to<f8_e4m3x4>(const f32x4& v) {
+#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
+  uint32_t packed = __builtin_amdgcn_cvt_pk_fp8_f32(v.data[0], v.data[1], 0, false);
+  packed = __builtin_amdgcn_cvt_pk_fp8_f32(v.data[2], v.data[3], packed, true);
+  return bit_cast<f8_e4m3x4>(packed);
+#else
+  f32x2 lo, hi;
+  lo.data[0] = v.data[0];
+  lo.data[1] = v.data[1];
+  hi.data[0] = v.data[2];
+  hi.data[1] = v.data[3];
+  f8_e4m3x2 lo_fp8 = to<f8_e4m3x2>(lo);
+  f8_e4m3x2 hi_fp8 = to<f8_e4m3x2>(hi);
+  f8_e4m3x4 result;
+  result.data[0] = lo_fp8.data[0];
+  result.data[1] = lo_fp8.data[1];
+  result.data[2] = hi_fp8.data[0];
+  result.data[3] = hi_fp8.data[1];
+  return result;
+#endif
+}
+
+// --- f32 -> f8_e5m2 specializations (downcast) ---
+
+/// f32x2 -> f8_e5m2x2.
+/// HIP gfx942: float -> bf8 (via __builtin_amdgcn_cvt_pk_bf8_f32).
+/// NVIDIA SM90+: float -> half -> fp8 (via __nv_cvt_halfraw2_to_fp8x2 with __NV_E5M2).
+/// NVIDIA pre-SM90: float -> half -> fp8 (via __nv_cvt_halfraw_to_fp8, element-wise).
+template <>
+MSCCLPP_DEVICE_INLINE f8_e5m2x2 to<f8_e5m2x2>(const f32x2& v) {
+#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
+  uint32_t packed = __builtin_amdgcn_cvt_pk_bf8_f32(v.data[0], v.data[1], 0, false);
+  return bit_cast<f8_e5m2x2>(static_cast<__hip_fp8x2_storage_t>(packed));
+#elif defined(MSCCLPP_DEVICE_CUDA) && __CUDA_ARCH__ >= 900
+  __half2_raw h2;
+  h2.x = bit_cast<uint16_t>(__float2half_rn(v.data[0]));
+  h2.y = bit_cast<uint16_t>(__float2half_rn(v.data[1]));
+  __nv_fp8x2_storage_t fp8x2 = __nv_cvt_halfraw2_to_fp8x2(h2, __NV_SATFINITE, __NV_E5M2);
+  return bit_cast<f8_e5m2x2>(fp8x2);
+#elif defined(MSCCLPP_DEVICE_CUDA)
+  __half_raw h0, h1;
+  h0.x = bit_cast<uint16_t>(__float2half_rn(v.data[0]));
+  h1.x = bit_cast<uint16_t>(__float2half_rn(v.data[1]));
+  f8_e5m2x2 result;
+  result.data[0] = bit_cast<__fp8_e5m2>(__nv_cvt_halfraw_to_fp8(h0, __NV_SATFINITE, __NV_E5M2));
+  result.data[1] = bit_cast<__fp8_e5m2>(__nv_cvt_halfraw_to_fp8(h1, __NV_SATFINITE, __NV_E5M2));
+  return result;
+#else
+  f8_e5m2x2 result;
+  result.data[0] = static_cast<__fp8_e5m2>(v.data[0]);
+  result.data[1] = static_cast<__fp8_e5m2>(v.data[1]);
+  return result;
+#endif
+}
+
+/// f32x4 -> f8_e5m2x4.
+template <>
+MSCCLPP_DEVICE_INLINE f8_e5m2x4 to<f8_e5m2x4>(const f32x4& v) {
+#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
+  uint32_t packed = __builtin_amdgcn_cvt_pk_bf8_f32(v.data[0], v.data[1], 0, false);
+  packed = __builtin_amdgcn_cvt_pk_bf8_f32(v.data[2], v.data[3], packed, true);
+  return bit_cast<f8_e5m2x4>(packed);
+#else
+  f32x2 lo, hi;
+  lo.data[0] = v.data[0];
+  lo.data[1] = v.data[1];
+  hi.data[0] = v.data[2];
+  hi.data[1] = v.data[3];
+  f8_e5m2x2 lo_fp8 = to<f8_e5m2x2>(lo);
+  f8_e5m2x2 hi_fp8 = to<f8_e5m2x2>(hi);
+  f8_e5m2x4 result;
+  result.data[0] = lo_fp8.data[0];
+  result.data[1] = lo_fp8.data[1];
+  result.data[2] = hi_fp8.data[0];
+  result.data[3] = hi_fp8.data[1];
+  return result;
+#endif
+}
+
+// --- f8_e4m3 <-> f16 conversion specializations ---
+
+/// f8_e4m3x2 -> f16x2.
+/// NVIDIA SM90+: packed intrinsic (1 instruction).
+/// HIP gfx942: fp8 -> float -> half (via AMD builtin).
+/// Pre-SM90 / fallback: element-wise scalar conversion.
+template <>
+MSCCLPP_DEVICE_INLINE f16x2 to<f16x2>(const f8_e4m3x2& v) {
+#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
+  auto f = __builtin_amdgcn_cvt_pk_f32_fp8(v.storage.__x, 0);
+  f16x2 result;
+  result.data[0] = __float2half(f[0]);
+  result.data[1] = __float2half(f[1]);
+  return result;
+#elif defined(MSCCLPP_DEVICE_CUDA) && __CUDA_ARCH__ >= 900
+  __half2_raw h2 = __nv_cvt_fp8x2_to_halfraw2(bit_cast<__nv_fp8x2_storage_t>(v.storage), __NV_E4M3);
+  return bit_cast<f16x2>(h2);
+#else
+  f16x2 result;
+  result.data[0] = static_cast<__half>(v.data[0]);
+  result.data[1] = static_cast<__half>(v.data[1]);
+  return result;
+#endif
+}
+
+/// f16x2 -> f8_e4m3x2.
+/// NVIDIA SM90+: packed intrinsic (1 instruction).
+/// HIP gfx942: half -> float -> fp8 (via AMD builtin).
+/// Pre-SM90: element-wise scalar conversion.
+template <>
+MSCCLPP_DEVICE_INLINE f8_e4m3x2 to<f8_e4m3x2>(const f16x2& v) {
+#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
+  float f0 = __half2float(v.data[0]);
+  float f1 = __half2float(v.data[1]);
+  uint32_t packed = __builtin_amdgcn_cvt_pk_fp8_f32(f0, f1, 0, false);
+  return bit_cast<f8_e4m3x2>(static_cast<__hip_fp8x2_storage_t>(packed));
+#elif defined(MSCCLPP_DEVICE_CUDA) && __CUDA_ARCH__ >= 900
+  __half2_raw h2 = bit_cast<__half2_raw>(v);
+  __nv_fp8x2_storage_t fp8x2 = __nv_cvt_halfraw2_to_fp8x2(h2, __NV_SATFINITE, __NV_E4M3);
+  return bit_cast<f8_e4m3x2>(fp8x2);
+#elif defined(MSCCLPP_DEVICE_CUDA)
+  __half_raw h0, h1;
+  h0.x = bit_cast<uint16_t>(v.data[0]);
+  h1.x = bit_cast<uint16_t>(v.data[1]);
+  f8_e4m3x2 result;
+  result.data[0] = bit_cast<__fp8_e4m3>(__nv_cvt_halfraw_to_fp8(h0, __NV_SATFINITE, __NV_E4M3));
+  result.data[1] = bit_cast<__fp8_e4m3>(__nv_cvt_halfraw_to_fp8(h1, __NV_SATFINITE, __NV_E4M3));
+  return result;
+#else
+  f8_e4m3x2 result;
+  result.data[0] = static_cast<__fp8_e4m3>(v.data[0]);
+  result.data[1] = static_cast<__fp8_e4m3>(v.data[1]);
+  return result;
+#endif
+}
+
 #endif  // defined(__FP8_TYPES_EXIST__)
+
+// --- fp8_e4m3b15 <-> fp16 direct conversion specializations ---
+// These are the PRIMARY conversions: fp8_b15 <-> fp16 is just a 1-bit exponent shift
+// (E4 bias=15 <-> E5 bias=15), no precision loss since fp16 has 10 mantissa bits
+// vs fp8's 3. fp32 conversions are derived by routing through fp16.
+
+/// f8_e4m3b15x2 -> f16x2.
+/// Direct fp8 -> fp16 via branch-free bit manipulation.
+template <>
+MSCCLPP_DEVICE_INLINE f16x2 to<f16x2>(const f8_e4m3b15x2& v) {
+#if defined(MSCCLPP_DEVICE_CUDA)
+  uint16_t in = v.storage.__x;
+  // Spread 2 fp8 bytes into packed fp16 pair, adjust exponent E4->E5.
+  uint32_t a0 = ((uint32_t)(in & 0xFFu) << 8) | ((uint32_t)(in >> 8) << 24);
+  uint32_t b0 = (a0 & 0x7f007f00u) >> 1;
+  uint32_t out0 = b0 | (a0 & 0x80008000u);
+  __half2 h;
+  asm("mov.b32 %0, %1;" : "=r"(*reinterpret_cast<uint32_t*>(&h)) : "r"(out0));
+  return h;
+#elif defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
+  // gfx942: same bit manipulation as CUDA, store packed fp16 bits via words[].
+  uint16_t in = v.storage.__x;
+  uint32_t a0 = ((uint32_t)(in & 0xFFu) << 8) | ((uint32_t)(in >> 8) << 24);
+  uint32_t b0 = (a0 & 0x7f007f00u) >> 1;
+  uint32_t out0 = b0 | (a0 & 0x80008000u);
+  f16x2 result;
+  result.words[0] = out0;
+  return result;
+#else
+  f16x2 result;
+  result.data[0] = __float2half(float(v.data[0]));
+  result.data[1] = __float2half(float(v.data[1]));
+  return result;
+#endif
+}
+
+/// f8_e4m3b15x4 -> f16x4.
+/// Uses __byte_perm + lop3 for branch-free vectorized conversion.
+template <>
+MSCCLPP_DEVICE_INLINE f16x4 to<f16x4>(const f8_e4m3b15x4& v) {
+#if defined(MSCCLPP_DEVICE_CUDA)
+  uint32_t in = v.storage.__x;
+  uint32_t a0 = __byte_perm(0u, in, 0x5746u);
+  uint32_t a0_shr = a0 >> 1;
+  uint32_t a0_sign = a0 & 0x80008000u;
+  uint32_t out0;
+  asm("lop3.b32 %0, %1, %2, %3, 0xEA;" : "=r"(out0) : "r"(a0_shr), "r"(0x3f803f80u), "r"(a0_sign));
+  uint32_t a1 = __byte_perm(a0, 0u, 0x2301u);
+  uint32_t a1_shr = a1 >> 1;
+  uint32_t a1_sign = a1 & 0x80008000u;
+  uint32_t out1;
+  asm("lop3.b32 %0, %1, %2, %3, 0xEA;" : "=r"(out1) : "r"(a1_shr), "r"(0x3f803f80u), "r"(a1_sign));
+  f16x4 result;
+  asm("mov.b32 %0, %1;" : "=r"(result.words[0]) : "r"(out0));
+  asm("mov.b32 %0, %1;" : "=r"(result.words[1]) : "r"(out1));
+  return result;
+#elif defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
+  // gfx942: __byte_perm + bitwise E4→E5 shift (no lop3), store via words[].
+  uint32_t in = v.storage.__x;
+  uint32_t a0 = __byte_perm(0u, in, 0x5746u);
+  uint32_t out0 = ((a0 >> 1) & 0x3f803f80u) | (a0 & 0x80008000u);
+  uint32_t a1 = __byte_perm(a0, 0u, 0x2301u);
+  uint32_t out1 = ((a1 >> 1) & 0x3f803f80u) | (a1 & 0x80008000u);
+  f16x4 result;
+  result.words[0] = out0;
+  result.words[1] = out1;
+  return result;
+#else
+  f16x4 result;
+#pragma unroll
+  for (int i = 0; i < 4; ++i) {
+    result.data[i] = __float2half(float(v.data[i]));
+  }
+  return result;
+#endif
+}
+
+/// f16x2 -> f8_e4m3b15x2.
+/// Direct fp16 -> fp8 via clamp + exponent shift E5->E4 + pack.
+template <>
+MSCCLPP_DEVICE_INLINE f8_e4m3b15x2 to<f8_e4m3b15x2>(const f16x2& v) {
+#if defined(MSCCLPP_DEVICE_CUDA)
+  uint32_t in0;
+  asm("mov.b32 %0, %1;" : "=r"(in0) : "r"(*reinterpret_cast<const uint32_t*>(&v)));
+  // Clamp abs to max finite e4m3b15 (0x3B80 = 0.9375 in fp16).
+  uint32_t lo = in0 & 0xFFFFu, hi = in0 >> 16;
+  uint32_t alo = lo & 0x7FFFu, ahi = hi & 0x7FFFu;
+  alo = alo < 0x3B80u ? alo : 0x3B80u;
+  ahi = ahi < 0x3B80u ? ahi : 0x3B80u;
+  uint32_t a0 = alo | (ahi << 16);
+  a0 = a0 * 2u + 0x00800080u;
+  uint32_t b0 = a0 | (in0 & 0x80008000u);
+  uint16_t packed = (uint16_t)(((b0 >> 8) & 0xFFu) | ((b0 >> 16) & 0xFF00u));
+  return bit_cast<f8_e4m3b15x2>(packed);
+#elif defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
+  // gfx942: read packed fp16 bits, clamp via v_pk_min_u16, shift E5→E4, pack.
+  uint32_t in0 = v.words[0];
+  uint32_t abs0 = in0 & 0x7fff7fffu;
+  uint32_t a0;
+  asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a0) : "v"(abs0), "v"(0x3B803B80u));
+  a0 = a0 * 2u + 0x00800080u;
+  uint32_t b0 = a0 | (in0 & 0x80008000u);
+  uint16_t packed = (uint16_t)(((b0 >> 8) & 0xFFu) | ((b0 >> 16) & 0xFF00u));
+  return bit_cast<f8_e4m3b15x2>(packed);
+#else
+  f8_e4m3b15x2 result;
+  result.data[0] = __fp8_e4m3b15(__half2float(v.data[0]));
+  result.data[1] = __fp8_e4m3b15(__half2float(v.data[1]));
+  return result;
+#endif
+}
+
+/// f16x4 -> f8_e4m3b15x4.
+/// Uses __vminu2 + lop3 + __byte_perm for branch-free vectorized conversion.
+template <>
+MSCCLPP_DEVICE_INLINE f8_e4m3b15x4 to<f8_e4m3b15x4>(const f16x4& v) {
+#if defined(MSCCLPP_DEVICE_CUDA)
+  uint32_t in0, in1;
+  asm("mov.b32 %0, %1;" : "=r"(in0) : "r"(v.words[0]));
+  asm("mov.b32 %0, %1;" : "=r"(in1) : "r"(v.words[1]));
+  uint32_t abs0 = in0 & 0x7fff7fffu;
+  uint32_t abs1 = in1 & 0x7fff7fffu;
+  uint32_t a0 = __vminu2(abs0, 0x3B803B80u);
+  uint32_t a1 = __vminu2(abs1, 0x3B803B80u);
+  a0 = a0 * 2u + 0x00800080u;
+  a1 = a1 * 2u + 0x00800080u;
+  uint32_t b0, b1;
+  asm("lop3.b32 %0, %1, %2, %3, 0xf8;" : "=r"(b0) : "r"(a0), "r"(in0), "r"(0x80008000u));
+  asm("lop3.b32 %0, %1, %2, %3, 0xf8;" : "=r"(b1) : "r"(a1), "r"(in1), "r"(0x80008000u));
+  uint32_t packed = __byte_perm(b0, b1, 0x7531u);
+  return bit_cast<f8_e4m3b15x4>(packed);
+#elif defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
+  // gfx942: read packed fp16 bits, clamp via v_pk_min_u16, shift E5→E4, __byte_perm pack.
+  uint32_t in0 = v.words[0], in1 = v.words[1];
+  uint32_t abs0 = in0 & 0x7fff7fffu, abs1 = in1 & 0x7fff7fffu;
+  uint32_t a0, a1;
+  asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a0) : "v"(abs0), "v"(0x3B803B80u));
+  asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a1) : "v"(abs1), "v"(0x3B803B80u));
+  a0 = a0 * 2u + 0x00800080u;
+  a1 = a1 * 2u + 0x00800080u;
+  uint32_t b0 = a0 | (in0 & 0x80008000u);
+  uint32_t b1 = a1 | (in1 & 0x80008000u);
+  uint32_t packed = __byte_perm(b0, b1, 0x7531u);
+  return bit_cast<f8_e4m3b15x4>(packed);
+#else
+  f8_e4m3b15x4 result;
+#pragma unroll
+  for (int i = 0; i < 4; ++i) {
+    result.data[i] = __fp8_e4m3b15(__half2float(v.data[i]));
+  }
+  return result;
+#endif
+}
+
+// --- fp8_e4m3b15 <-> f32 conversion specializations (software, always available) ---
+
+/// f8_e4m3b15x2 -> f32x2.
+/// Routes through fp16: fp8→fp16 (bit manip) then fp16→f32.
+template <>
+MSCCLPP_DEVICE_INLINE f32x2 to<f32x2>(const f8_e4m3b15x2& v) {
+#if defined(MSCCLPP_DEVICE_CUDA)
+  f16x2 h = to<f16x2>(v);
+  float2 f2 = __half22float2(h);
+  return bit_cast<f32x2>(f2);
+#elif defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
+  f16x2 h = to<f16x2>(v);
+  f32x2 result;
+  result.data[0] = __half2float(h.data[0]);
+  result.data[1] = __half2float(h.data[1]);
+  return result;
+#else
+  f32x2 result;
+  result.data[0] = float(v.data[0]);
+  result.data[1] = float(v.data[1]);
+  return result;
+#endif
+}
+
+/// f8_e4m3b15x4 -> f32x4.
+/// Routes through fp16: fp8→fp16 (bit manip) then fp16→f32.
+template <>
+MSCCLPP_DEVICE_INLINE f32x4 to<f32x4>(const f8_e4m3b15x4& v) {
+#if defined(MSCCLPP_DEVICE_CUDA)
+  f16x4 h = to<f16x4>(v);
+  __half2 h0, h1;
+  asm("mov.b32 %0, %1;" : "=r"(*reinterpret_cast<uint32_t*>(&h0)) : "r"(h.words[0]));
+  asm("mov.b32 %0, %1;" : "=r"(*reinterpret_cast<uint32_t*>(&h1)) : "r"(h.words[1]));
+  float2 f0 = __half22float2(h0);
+  float2 f1 = __half22float2(h1);
+  f32x4 result;
+  result.data[0] = f0.x;
+  result.data[1] = f0.y;
+  result.data[2] = f1.x;
+  result.data[3] = f1.y;
+  return result;
+#elif defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
+  f16x4 h = to<f16x4>(v);
+  f32x4 result;
+  result.data[0] = __half2float(h.data[0]);
+  result.data[1] = __half2float(h.data[1]);
+  result.data[2] = __half2float(h.data[2]);
+  result.data[3] = __half2float(h.data[3]);
+  return result;
+#else
+  f32x4 result;
+#pragma unroll
+  for (int i = 0; i < 4; ++i) {
+    result.data[i] = float(v.data[i]);
+  }
+  return result;
+#endif
+}
+
+/// f32x2 -> f8_e4m3b15x2.
+/// Routes through fp16: f32→fp16 then fp16→fp8 (clamp + exponent shift + pack).
+template <>
+MSCCLPP_DEVICE_INLINE f8_e4m3b15x2 to<f8_e4m3b15x2>(const f32x2& v) {
+#if defined(MSCCLPP_DEVICE_CUDA)
+  float2 f2 = {v.data[0], v.data[1]};
+  __half2 h = __float22half2_rn(f2);
+  return to<f8_e4m3b15x2>(h);
+#elif defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
+  f16x2 h;
+  h.data[0] = __float2half_rn(v.data[0]);
+  h.data[1] = __float2half_rn(v.data[1]);
+  return to<f8_e4m3b15x2>(h);
+#else
+  f8_e4m3b15x2 result;
+  result.data[0] = __fp8_e4m3b15(v.data[0]);
+  result.data[1] = __fp8_e4m3b15(v.data[1]);
+  return result;
+#endif
+}
+
+/// f32x4 -> f8_e4m3b15x4.
+/// Routes through fp16: f32→fp16 then fp16→fp8 (clamp + exponent shift + pack).
+template <>
+MSCCLPP_DEVICE_INLINE f8_e4m3b15x4 to<f8_e4m3b15x4>(const f32x4& v) {
+#if defined(MSCCLPP_DEVICE_CUDA)
+  float2 f01 = {v.data[0], v.data[1]};
+  float2 f23 = {v.data[2], v.data[3]};
+  __half2 h01 = __float22half2_rn(f01);
+  __half2 h23 = __float22half2_rn(f23);
+  f16x4 h;
+  asm("mov.b32 %0, %1;" : "=r"(h.words[0]) : "r"(*reinterpret_cast<const uint32_t*>(&h01)));
+  asm("mov.b32 %0, %1;" : "=r"(h.words[1]) : "r"(*reinterpret_cast<const uint32_t*>(&h23)));
+  return to<f8_e4m3b15x4>(h);
+#elif defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
+  f16x4 h;
+  h.words[0] = __builtin_bit_cast(uint32_t, __builtin_amdgcn_cvt_pkrtz(v.data[0], v.data[1]));
+  h.words[1] = __builtin_bit_cast(uint32_t, __builtin_amdgcn_cvt_pkrtz(v.data[2], v.data[3]));
+  return to<f8_e4m3b15x4>(h);
+#else
+  f8_e4m3b15x4 result;
+#pragma unroll
+  for (int i = 0; i < 4; ++i) {
+    result.data[i] = __fp8_e4m3b15(v.data[i]);
+  }
+  return result;
+#endif
+}
+
+// --- fp8_e4m3b15 arithmetic (software, always available) ---
+
+template
+MSCCLPP_DEVICE_INLINE __fp8_e4m3b15 operator+(const __fp8_e4m3b15& a, const __fp8_e4m3b15& b) {
+  return __fp8_e4m3b15(float(a) + float(b));
+}
+
+template
+MSCCLPP_DEVICE_INLINE f8_e4m3b15x2 operator+(const f8_e4m3b15x2& a, const f8_e4m3b15x2& b) {
+  f8_e4m3b15x2 result;
+  result.data[0] = __fp8_e4m3b15(float(a.data[0]) + float(b.data[0]));
+  result.data[1] = __fp8_e4m3b15(float(a.data[1]) + float(b.data[1]));
+  return result;
+}
+
+template
+MSCCLPP_DEVICE_INLINE f8_e4m3b15x4 operator+(const f8_e4m3b15x4& a, const f8_e4m3b15x4& b) {
+  f8_e4m3b15x4 result;
+#pragma unroll
+  for (int i = 0; i < 4; ++i) {
+    result.data[i] = __fp8_e4m3b15(float(a.data[i]) + float(b.data[i]));
+  }
+  return result;
+}
+
+// --- fp8_e4m3b15 min (software) ---
+
+template <>
+MSCCLPP_DEVICE_INLINE __fp8_e4m3b15 min(const __fp8_e4m3b15& a, const __fp8_e4m3b15& b) {
+  return __fp8_e4m3b15(fminf(float(a), float(b)));
+}
+
+MSCCLPP_DEVICE_INLINE f8_e4m3b15x2
min(const f8_e4m3b15x2& a, const f8_e4m3b15x2& b) { + f8_e4m3b15x2 result; + result.data[0] = mscclpp::min(a.data[0], b.data[0]); + result.data[1] = mscclpp::min(a.data[1], b.data[1]); + return result; +} + +MSCCLPP_DEVICE_INLINE f8_e4m3b15x4 min(const f8_e4m3b15x4& a, const f8_e4m3b15x4& b) { + f8_e4m3b15x4 result; +#pragma unroll + for (int i = 0; i < 4; ++i) { + result.data[i] = mscclpp::min(a.data[i], b.data[i]); + } + return result; +} + #endif // MSCCLPP_DEVICE_COMPILE } // namespace mscclpp diff --git a/python/csrc/CMakeLists.txt b/python/csrc/CMakeLists.txt index 8759201f..44fb150f 100644 --- a/python/csrc/CMakeLists.txt +++ b/python/csrc/CMakeLists.txt @@ -24,4 +24,7 @@ set_target_properties(mscclpp_py PROPERTIES OUTPUT_NAME _mscclpp) set_target_properties(mscclpp_py PROPERTIES INSTALL_RPATH "\$ORIGIN/lib") target_link_libraries(mscclpp_py PRIVATE dlpack mscclpp mscclpp_collectives ${GPU_LIBRARIES}) target_include_directories(mscclpp_py SYSTEM PRIVATE ${GPU_INCLUDE_DIRS}) +if(MSCCLPP_USE_ROCM) + target_compile_definitions(mscclpp_py PRIVATE MSCCLPP_USE_ROCM) +endif() install(TARGETS mscclpp_py LIBRARY DESTINATION .) diff --git a/python/csrc/algorithm.cpp b/python/csrc/algorithm.cpp index 1a93cbc0..1cb3f253 100644 --- a/python/csrc/algorithm.cpp +++ b/python/csrc/algorithm.cpp @@ -75,15 +75,17 @@ void register_algorithm(nb::module_& m) { [](Algorithm& self, std::shared_ptr comm, uintptr_t input, uintptr_t output, size_t inputSize, size_t outputSize, DataType dtype, ReduceOp op, uintptr_t stream, std::shared_ptr executor, int nBlocks, int nThreadsPerBlock, bool symmetricMemory, - std::unordered_map extras) { + std::unordered_map extras, int32_t accumDtype) { return self.execute(comm, reinterpret_cast(input), reinterpret_cast(output), inputSize, outputSize, dtype, op, reinterpret_cast(stream), executor, - nBlocks, nThreadsPerBlock, symmetricMemory, extras); + nBlocks, nThreadsPerBlock, symmetricMemory, extras, + static_cast(accumDtype)); }, nb::arg("comm"), nb::arg("input"), nb::arg("output"), nb::arg("input_size"), nb::arg("output_size"), nb::arg("dtype"), nb::arg("op") = ReduceOp::NOP, nb::arg("stream") = 0, nb::arg("executor") = nullptr, nb::arg("n_blocks") = 0, nb::arg("n_threads_per_block") = 0, nb::arg("symmetric_memory") = false, - nb::arg("extras") = std::unordered_map()) + nb::arg("extras") = std::unordered_map(), + nb::arg("accum_dtype") = static_cast(DataType::AUTO)) .def("reset", &Algorithm::reset); nb::class_(algorithmClass, "Constraint") diff --git a/python/csrc/core_py.cpp b/python/csrc/core_py.cpp index 47d76ac4..b8649564 100644 --- a/python/csrc/core_py.cpp +++ b/python/csrc/core_py.cpp @@ -47,7 +47,8 @@ void register_core(nb::module_& m) { .value("bfloat16", DataType::BFLOAT16) .value("float8_e4m3", DataType::FLOAT8_E4M3) .value("float8_e5m2", DataType::FLOAT8_E5M2) - .value("uint8", DataType::UINT8); + .value("uint8", DataType::UINT8) + .value("float8_e4m3b15", DataType::FLOAT8_E4M3B15); nb::class_(m, "CppBootstrap") .def("get_rank", &Bootstrap::getRank) diff --git a/python/csrc/ext/algorithm_collection_builder_py.cpp b/python/csrc/ext/algorithm_collection_builder_py.cpp index be7f944e..4a3563d9 100644 --- a/python/csrc/ext/algorithm_collection_builder_py.cpp +++ b/python/csrc/ext/algorithm_collection_builder_py.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include diff --git a/python/csrc/gpu_utils_py.cpp b/python/csrc/gpu_utils_py.cpp index 6995756b..60880456 100644 --- a/python/csrc/gpu_utils_py.cpp +++ b/python/csrc/gpu_utils_py.cpp @@ 
-34,6 +34,19 @@ static DLDataType getDlType(std::string type) { return DLDataType{kDLBfloat, 16, 1}; } else if (type == "torch.float16") { return DLDataType{kDLFloat, 16, 1}; + } else if (type == "torch.float8_e4m3fn") { + return DLDataType{kDLFloat8_e4m3fn, 8, 1}; + } else if (type == "torch.float8_e4m3fnuz") { + return DLDataType{kDLFloat8_e4m3fnuz, 8, 1}; + } else if (type == "torch.float8_e5m2") { + return DLDataType{kDLFloat8_e5m2, 8, 1}; + } else if (type == "torch.float8_e5m2fnuz") { + return DLDataType{kDLFloat8_e5m2fnuz, 8, 1}; + } else if (type == "torch.uint8") { + return DLDataType{kDLUInt, 8, 1}; + } else if (type == "fp8_e4m3b15") { + // No standard DLPack code for fp8_e4m3b15; store as raw uint8 bytes. + return DLDataType{kDLUInt, 8, 1}; } else { throw Error("Unsupported type: " + type, ErrorCode::InvalidUsage); } diff --git a/python/mscclpp/_core/algorithm.py b/python/mscclpp/_core/algorithm.py index 744cf39e..f12a3027 100644 --- a/python/mscclpp/_core/algorithm.py +++ b/python/mscclpp/_core/algorithm.py @@ -177,6 +177,7 @@ class Algorithm: nthreads_per_block=0, symmetric_memory: bool = False, extras: Optional[Dict[str, int]] = None, + accum_dtype: Optional[CppDataType] = None, ) -> int: """Execute the collective algorithm. @@ -194,10 +195,14 @@ class Algorithm: nthreads_per_block: Number of threads per block (0 for auto-selection). symmetric_memory: Whether to use symmetric memory optimization (default: False). extras: Additional algorithm-specific parameters. + accum_dtype: Data type for accumulation during reduction. If None, defaults to + the same as dtype. Use DataType.float32 for high-precision FP8 accumulation. Returns: The result code (0 for success). """ + merged_extras = dict(extras) if extras is not None else {} + accum_dtype = accum_dtype if accum_dtype is not None else dtype return self._algorithm.execute( comm, int(input_buffer), @@ -211,7 +216,8 @@ class Algorithm: nblocks, nthreads_per_block, symmetric_memory, - extras if extras is not None else {}, + merged_extras, + int(accum_dtype), ) def reset(self): diff --git a/python/mscclpp/language/channel.py b/python/mscclpp/language/channel.py index 1b22e4e2..23d76eda 100644 --- a/python/mscclpp/language/channel.py +++ b/python/mscclpp/language/channel.py @@ -140,7 +140,7 @@ class MemoryChannel: for tb_id in tb_list: tb_chunk_id = get_program().setup_remote_chunk(self.src_rank, tb_id, remote_chunk, self.channel_type) - tb_channel_ids = get_program().setup_channel(tb, self) + tb_channel_ids = get_program().setup_channel(tb_id, self) op = GetOperation( src_buff=[RemoteChunk(src_chunk.buffer, src_chunk.index, src_chunk.size, tb_chunk_id)], dst_buff=[LocalChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size)], diff --git a/python/mscclpp/language/internal/operations.py b/python/mscclpp/language/internal/operations.py index 5f719c21..5fb392e3 100644 --- a/python/mscclpp/language/internal/operations.py +++ b/python/mscclpp/language/internal/operations.py @@ -745,7 +745,7 @@ class ReduceOperation(BaseOperation): remote_dst_buff=self.remote_dst_buff + other.dst_buff, channel_ids=self.channel_ids, put_channel_ids=self.put_channel_ids + other.channel_ids, - channel_type=self.channel_type, + channel_type=other.channel_type, reduce_operation=self.reduce_operation, tbg_info=self.tbg_info, packet=self.packet, diff --git a/python/requirements_rocm6.txt b/python/requirements_rocm6.txt index d2a3389b..7ed4fef3 100644 --- a/python/requirements_rocm6.txt +++ b/python/requirements_rocm6.txt @@ -1,5 +1,5 @@ -mpi4py==4.1.1 
-cupy==13.6.0 +mpi4py +cupy prettytable netifaces pytest diff --git a/python/test/test_fp8_accum.py b/python/test/test_fp8_accum.py new file mode 100644 index 00000000..82981ce1 --- /dev/null +++ b/python/test/test_fp8_accum.py @@ -0,0 +1,397 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# Correctness test for FP8 allreduce with different accumulation types. +# +# Verifies that FP8 allreduce with higher-precision accumulation produces +# results at least as accurate as native FP8 accumulation, by comparing +# against a float32 reference. +# +# Usage: +# mpirun -np 8 pytest python/test/test_fp8_accum.py -v + +import cupy as cp +import numpy as np +import pytest + +from mscclpp import CommGroup, GpuBuffer, DataType, ReduceOp, is_nvls_supported +from mscclpp.ext import AlgorithmCollectionBuilder +from .mscclpp_mpi import MpiGroup, parametrize_mpi_groups, mpi_group + +# FP8 E4M3 (hardware) requires SM >= 89 (Ada / Hopper) on NVIDIA GPUs. +# On AMD/ROCm (e.g. MI300X), FP8 is supported natively — no skip needed. +_is_hip = hasattr(cp.cuda.runtime, "is_hip") and cp.cuda.runtime.is_hip +_skip_fp8 = not _is_hip and int(cp.cuda.Device().compute_capability) < 89 +pytestmark = pytest.mark.skipif(_skip_fp8, reason="FP8 accum tests require SM >= 89 on CUDA") + +# --------------------------------------------------------------------------- +# FP8 E4M3FN helpers (bias=7, no infinity, NaN = exp=15 & mant=7) +# --------------------------------------------------------------------------- + + +def e4m3fn_to_float(uint8_array): + """Decode a cupy uint8 array of E4M3FN bit patterns to float32.""" + bits = uint8_array.astype(cp.int32) + sign = (bits >> 7) & 1 + exp = (bits >> 3) & 0xF + mant = bits & 0x7 + + # Normal: (-1)^s * 2^(exp-7) * (1 + mant/8) + normal_val = cp.ldexp(cp.float32(1.0) + mant.astype(cp.float32) / cp.float32(8.0), (exp - 7).astype(cp.int32)) + # Subnormal (exp==0): (-1)^s * 2^(-6) * (mant/8) + subnormal_val = cp.ldexp(mant.astype(cp.float32) / cp.float32(8.0), cp.int32(-6)) + + result = cp.where(exp == 0, subnormal_val, normal_val) + result = cp.where(sign == 1, -result, result) + # Zero + result = cp.where((exp == 0) & (mant == 0), cp.float32(0.0), result) + # NaN: exp==15 & mant==7 + nan_mask = (exp == 15) & (mant == 7) + result = cp.where(nan_mask, cp.float32(float("nan")), result) + return result + + +def float_to_e4m3fn(f32_array, chunk_size=65536): + """Encode a cupy float32 array to uint8 E4M3FN bit patterns. + + Uses a lookup-table approach: precompute all 128 positive E4M3FN values, + then find nearest match per element via chunked broadcast comparison. 
+ """ + # Build lookup table of all 128 positive E4M3FN values (0x00..0x7F) + all_bytes = cp.arange(128, dtype=cp.uint8) + all_floats = e4m3fn_to_float(all_bytes) # (128,) float32 + # Mark NaN entries as inf so they're never selected as nearest + all_floats = cp.where(cp.isnan(all_floats), cp.float32(float("inf")), all_floats) + + # Clamp input and extract sign + clamped = f32_array.astype(cp.float32) + clamped = cp.clip(clamped, -448.0, 448.0) + signs = (clamped < 0).astype(cp.uint8) + absval = cp.abs(clamped) + + result = cp.zeros(absval.shape, dtype=cp.uint8) + n = absval.size + absval_flat = absval.ravel() + result_flat = result.ravel() + + for start in range(0, n, chunk_size): + end = min(start + chunk_size, n) + chunk = absval_flat[start:end] + # (chunk_size, 128) difference matrix + diffs = cp.abs(chunk[:, None] - all_floats[None, :]) + result_flat[start:end] = cp.argmin(diffs, axis=1).astype(cp.uint8) + + # Combine with sign bit + result = result_flat.reshape(absval.shape) + result = result | (signs << 7) + # Handle exact zero + result = cp.where(absval == 0, cp.uint8(0), result) + return result + + +# --------------------------------------------------------------------------- +# FP8 E4M3B15 helpers (bias=15, max=0.9375, NaN = exp==15 or bits==0x80) +# --------------------------------------------------------------------------- + + +def e4m3b15_to_float(uint8_array): + """Decode a cupy uint8 array of E4M3B15 bit patterns to float32.""" + bits = uint8_array.astype(cp.int32) + sign = (bits >> 7) & 1 + exp = (bits >> 3) & 0xF + mant = bits & 0x7 + + # Normal: (-1)^s * 2^(exp-15) * (1 + mant/8) + normal_val = cp.ldexp(cp.float32(1.0) + mant.astype(cp.float32) / cp.float32(8.0), (exp - 15).astype(cp.int32)) + # Subnormal (exp==0): (-1)^s * 2^(-14) * (mant/8) + subnormal_val = cp.ldexp(mant.astype(cp.float32) / cp.float32(8.0), cp.int32(-14)) + + result = cp.where(exp == 0, subnormal_val, normal_val) + result = cp.where(sign == 1, -result, result) + # Zero + result = cp.where((exp == 0) & (mant == 0), cp.float32(0.0), result) + # NaN: exp==15 or negative zero (0x80) + nan_mask = (exp == 15) | (uint8_array.astype(cp.int32) == 0x80) + result = cp.where(nan_mask, cp.float32(float("nan")), result) + return result + + +def float_to_e4m3b15(f32_array, chunk_size=65536): + """Encode a cupy float32 array to uint8 E4M3B15 bit patterns. + + Same lookup-table approach as float_to_e4m3fn. 
+ """ + # Build lookup table of all 128 positive E4M3B15 values (0x00..0x7F) + all_bytes = cp.arange(128, dtype=cp.uint8) + all_floats = e4m3b15_to_float(all_bytes) # (128,) float32 + # Mark NaN entries as inf so they're never selected as nearest + all_floats = cp.where(cp.isnan(all_floats), cp.float32(float("inf")), all_floats) + + # Clamp input and extract sign + clamped = f32_array.astype(cp.float32) + clamped = cp.clip(clamped, -0.9375, 0.9375) + signs = (clamped < 0).astype(cp.uint8) + absval = cp.abs(clamped) + + result = cp.zeros(absval.shape, dtype=cp.uint8) + n = absval.size + absval_flat = absval.ravel() + result_flat = result.ravel() + + for start in range(0, n, chunk_size): + end = min(start + chunk_size, n) + chunk = absval_flat[start:end] + # (chunk_size, 128) difference matrix + diffs = cp.abs(chunk[:, None] - all_floats[None, :]) + result_flat[start:end] = cp.argmin(diffs, axis=1).astype(cp.uint8) + + # Combine with sign bit + result = result_flat.reshape(absval.shape) + result = result | (signs << 7) + # Handle exact zero + result = cp.where(absval == 0, cp.uint8(0), result) + return result + + +# --------------------------------------------------------------------------- +# Shared test helpers +# --------------------------------------------------------------------------- + + +def setup_algorithms(mpi_group): + """Build default algorithms and return (comm_group, algo_map, scratch_buf).""" + comm_group = CommGroup(mpi_group.comm) + scratch = GpuBuffer(1 << 27, dtype=cp.uint8) # 128 MB + AlgorithmCollectionBuilder.reset() + builder = AlgorithmCollectionBuilder() + algorithms = builder.build_default_algorithms( + scratch_buffer=scratch.data.ptr, + scratch_buffer_size=scratch.nbytes, + rank=comm_group.my_rank, + ) + algo_map = {a.name: a for a in algorithms} + return comm_group, algo_map, scratch + + +def run_allreduce(algo, comm_group, buffer, dtype, accum_dtype=None, nblocks=0, nthreads_per_block=0): + """Run allreduce in-place on buffer and return a copy of the result.""" + ret = algo.execute( + comm=comm_group.communicator, + input_buffer=buffer.data.ptr, + output_buffer=buffer.data.ptr, + input_size=buffer.nbytes, + output_size=buffer.nbytes, + dtype=dtype, + op=ReduceOp.SUM, + stream=cp.cuda.get_current_stream().ptr, + nblocks=nblocks, + nthreads_per_block=nthreads_per_block, + symmetric_memory=True, + accum_dtype=accum_dtype, + ) + cp.cuda.Device().synchronize() + assert ret == 0, f"Allreduce failed with error code {ret}" + return buffer.copy() + + +# --------------------------------------------------------------------------- +# Test: FP8 E4M3 accumulation correctness +# --------------------------------------------------------------------------- + + +@parametrize_mpi_groups(8) +@pytest.mark.parametrize( + "algo_name", + [ + "default_allreduce_packet", + "default_allreduce_nvls_packet", + "default_allreduce_fullmesh", + "default_allreduce_rsag_zero_copy", + "default_allreduce_allpair_packet", + ], +) +@pytest.mark.parametrize("size", [1024, 4096, 16384, 65536, 262144, 1048576]) +def test_fp8_e4m3_accum(mpi_group: MpiGroup, algo_name: str, size: int): + """Verify that FP8 E4M3 allreduce with higher-precision accumulation is at + least as accurate as native FP8 accumulation, across all algorithm variants.""" + rank = mpi_group.comm.rank + world_size = mpi_group.comm.size + + comm_group, algo_map, scratch = setup_algorithms(mpi_group) + if algo_name not in algo_map: + pytest.skip(f"{algo_name} not available") + if "nvls" in algo_name and not is_nvls_supported(): + 
pytest.skip(f"{algo_name} requires NVLS which is not supported on this platform") + algo = algo_map[algo_name] + + buf = GpuBuffer(size, dtype=cp.uint8) + + accum_configs = [ + ("fp8_native", DataType.float8_e4m3), + ("float16", DataType.float16), + ("float32", DataType.float32), + ] + + # rsag_zero_copy and fullmesh need explicit block/thread counts + if "rsag" in algo_name: + nb = max(1, min(32, size // (world_size * 32))) + nt = 1024 + elif "fullmesh" in algo_name: + nb = 35 + nt = 512 + else: + nb = 0 + nt = 0 + + errors = {} + for accum_label, accum_dtype in accum_configs: + # Generate deterministic per-rank data (use numpy to avoid hipRAND issues on ROCm) + rng = np.random.RandomState(42 + rank) + src_f32 = cp.asarray(rng.randn(size).astype(np.float32)) + src_f32 = cp.clip(src_f32, -240.0, 240.0) + src_fp8 = float_to_e4m3fn(src_f32) + + # Copy into symmetric buffer + buf[:] = src_fp8 + cp.cuda.Device().synchronize() + + # Run allreduce + result = run_allreduce( + algo, + comm_group, + buf, + dtype=DataType.float8_e4m3, + accum_dtype=accum_dtype, + nblocks=nb, + nthreads_per_block=nt, + ) + result_f32 = e4m3fn_to_float(result) + + # Compute float32 reference: sum all ranks' quantized FP8 inputs in float32 + ref_f32 = cp.zeros(size, dtype=cp.float32) + for r in range(world_size): + rng_r = np.random.RandomState(42 + r) + rank_data = cp.asarray(rng_r.randn(size).astype(np.float32)) + rank_data = cp.clip(rank_data, -240.0, 240.0) + rank_data_fp8 = float_to_e4m3fn(rank_data) + ref_f32 += e4m3fn_to_float(rank_data_fp8) + + # Compute errors + abs_err = cp.abs(result_f32 - ref_f32) + mean_abs_err = float(cp.mean(abs_err)) + errors[accum_label] = mean_abs_err + + # Reset between runs + algo.reset() + + # Higher-precision accumulation should be at least as accurate as native fp8 + assert ( + errors["float16"] <= errors["fp8_native"] + 1e-6 + ), f"float16 accum ({errors['float16']:.6f}) worse than native ({errors['fp8_native']:.6f})" + assert ( + errors["float32"] <= errors["fp8_native"] + 1e-6 + ), f"float32 accum ({errors['float32']:.6f}) worse than native ({errors['fp8_native']:.6f})" + + +# --------------------------------------------------------------------------- +# Test: FP8 E4M3B15 accumulation correctness +# --------------------------------------------------------------------------- + + +@parametrize_mpi_groups(8) +@pytest.mark.parametrize( + "algo_name", + [ + "default_allreduce_packet", + "default_allreduce_nvls_packet", + "default_allreduce_rsag_zero_copy", + "default_allreduce_fullmesh", + "default_allreduce_allpair_packet", + ], +) +@pytest.mark.parametrize("size", [1024, 4096, 65536]) +def test_fp8_e4m3b15_accum(mpi_group: MpiGroup, algo_name: str, size: int): + """Verify that FP8 E4M3B15 allreduce with higher-precision accumulation is at + least as accurate as native E4M3B15 accumulation.""" + rank = mpi_group.comm.rank + world_size = mpi_group.comm.size + + comm_group, algo_map, scratch = setup_algorithms(mpi_group) + if algo_name not in algo_map: + pytest.skip(f"{algo_name} not available") + if "nvls" in algo_name and not is_nvls_supported(): + pytest.skip(f"{algo_name} requires NVLS which is not supported on this platform") + + algo = algo_map[algo_name] + buf = GpuBuffer(size, dtype=cp.uint8) + + accum_configs = [ + ("e4m3b15_native", DataType.float8_e4m3b15), + ("float16", DataType.float16), + ("float32", DataType.float32), + ] + + # rsag_zero_copy needs explicit block/thread counts, scaled to data size + if "rsag" in algo_name: + nb = max(1, min(32, size // (world_size * 
32))) + nt = 1024 + else: + nb = 0 + nt = 0 + + errors = {} + for accum_label, accum_dtype in accum_configs: + # Generate deterministic per-rank random uint8 values in valid e4m3b15 range + rng = np.random.RandomState(42 + rank) + raw = cp.asarray(rng.randint(0, 0x78, (size,)).astype(np.uint8)) + signs = cp.asarray(rng.randint(0, 2, (size,)).astype(np.uint8)) << 7 + src_uint8 = raw | signs + # Fix negative zero -> positive zero + src_uint8 = cp.where(src_uint8 == 0x80, cp.uint8(0), src_uint8) + + # Copy into symmetric buffer + buf[:] = src_uint8 + cp.cuda.Device().synchronize() + + # Run allreduce + result = run_allreduce( + algo, + comm_group, + buf, + dtype=DataType.float8_e4m3b15, + accum_dtype=accum_dtype, + nblocks=nb, + nthreads_per_block=nt, + ) + + # Decode result + result_f32 = e4m3b15_to_float(result) + + # Compute float32 reference + ref_f32 = cp.zeros(size, dtype=cp.float32) + for r in range(world_size): + rng_r = np.random.RandomState(42 + r) + raw_r = cp.asarray(rng_r.randint(0, 0x78, (size,)).astype(np.uint8)) + signs_r = cp.asarray(rng_r.randint(0, 2, (size,)).astype(np.uint8)) << 7 + bits_r = raw_r | signs_r + bits_r = cp.where(bits_r == 0x80, cp.uint8(0), bits_r) + ref_f32 += e4m3b15_to_float(bits_r) + + # Clamp reference to e4m3b15 representable range + ref_f32 = cp.clip(ref_f32, -0.9375, 0.9375) + + # Compute errors (only on valid entries) + valid = ~cp.isnan(result_f32) & ~cp.isnan(ref_f32) + abs_err = cp.abs(result_f32[valid] - ref_f32[valid]) + mean_abs_err = float(cp.mean(abs_err)) if abs_err.size > 0 else 0.0 + errors[accum_label] = mean_abs_err + + algo.reset() + + # Higher-precision accumulation should be at least as accurate as native + assert ( + errors["float16"] <= errors["e4m3b15_native"] + 1e-8 + ), f"float16 accum ({errors['float16']:.8f}) worse than native ({errors['e4m3b15_native']:.8f})" + assert ( + errors["float32"] <= errors["e4m3b15_native"] + 1e-8 + ), f"float32 accum ({errors['float32']:.8f}) worse than native ({errors['e4m3b15_native']:.8f})" diff --git a/src/core/algorithm.cc b/src/core/algorithm.cc index 99e7b031..ffa53aa8 100644 --- a/src/core/algorithm.cc +++ b/src/core/algorithm.cc @@ -41,7 +41,9 @@ NativeAlgorithm::NativeAlgorithm(std::string name, std::string collective, InitF CommResult NativeAlgorithm::execute(std::shared_ptr comm, const void* input, void* output, size_t inputSize, size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, std::shared_ptr, int nBlocks, int nThreadsPerBlock, - bool symmetricMemory, const std::unordered_map& extras) { + bool symmetricMemory, const std::unordered_map& extras, + DataType accumDtype) { + if (accumDtype == DataType::AUTO) accumDtype = dtype; if (!initialized_) { initFunc_(comm); initialized_ = true; @@ -53,7 +55,7 @@ CommResult NativeAlgorithm::execute(std::shared_ptr comm, const vo contexts_[ctxKey] = ctx; } return kernelLaunchFunc_(contexts_[ctxKey], input, output, inputSize, outputSize, dtype, op, stream, nBlocks, - nThreadsPerBlock, extras); + nThreadsPerBlock, extras, accumDtype); } const std::string& NativeAlgorithm::name() const { return name_; } @@ -77,10 +79,7 @@ const CollectiveBufferMode& NativeAlgorithm::bufferMode() const { return bufferM Algorithm::Constraint NativeAlgorithm::constraint() const { return constraint_; } -void NativeAlgorithm::reset() { - contexts_.clear(); - initialized_ = false; -} +void NativeAlgorithm::reset() { contexts_.clear(); } void AlgorithmCollection::registerAlgorithm(const std::string collective, const std::string algoName, std::shared_ptr 
algorithm) { @@ -166,7 +165,7 @@ Algorithm::Constraint DslAlgorithm::constraint() const { return constraint_; } CommResult DslAlgorithm::execute(std::shared_ptr comm, const void* input, void* output, size_t inputSize, size_t outputSize, DataType dtype, ReduceOp, cudaStream_t stream, std::shared_ptr executor, int, int, bool, - const std::unordered_map&) { + const std::unordered_map&, DataType) { if (!executor) { THROW(EXEC, Error, ErrorCode::InvalidUsage, "Executor is null in DslAlgorithm::execute"); } @@ -192,6 +191,10 @@ CommResult DslAlgorithm::execute(std::shared_ptr comm, const void* plan_, stream); break; #endif + case DataType::FLOAT8_E4M3B15: + executor->execute(rank, (__fp8_e4m3b15*)input, (__fp8_e4m3b15*)output, inputSize, outputSize, + DataType::FLOAT8_E4M3B15, plan_, stream); + break; case DataType::INT32: case DataType::UINT32: executor->execute(rank, (int*)input, (int*)output, inputSize, outputSize, DataType::UINT32, plan_, stream); diff --git a/src/core/endpoint.cc b/src/core/endpoint.cc index 3ae2e154..87042e1e 100644 --- a/src/core/endpoint.cc +++ b/src/core/endpoint.cc @@ -56,7 +56,7 @@ Endpoint::Impl::Impl(const EndpointConfig& config, Context::Impl& contextImpl) } ibQp_ = contextImpl.getIbContext(config_.transport) - ->createQp(config_.ib.port, gidIndex, config_.ib.maxCqSize, config_.ib.maxCqPollNum, + ->createQp(config_.ib.port, config_.ib.gidIndex, config_.ib.maxCqSize, config_.ib.maxCqPollNum, config_.ib.maxSendWr, maxRecvWr, config_.ib.maxWrPerSend, ibNoAtomic_); ibQpInfo_ = ibQp_->getInfo(); } else if (config_.transport == Transport::Ethernet) { diff --git a/src/core/executor/execution_kernel.cu b/src/core/executor/execution_kernel.cu index 2d36bcf5..28ced77f 100644 --- a/src/core/executor/execution_kernel.cu +++ b/src/core/executor/execution_kernel.cu @@ -82,6 +82,12 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo case DataType::FLOAT8_E5M2: // FP8 is not supported in CUDA execution kernel. break; + case DataType::FLOAT8_E4M3B15: + // fp8_e4m3b15 is a software type not supported in the CUDA execution kernel. + break; + case DataType::AUTO: + // AUTO is a sentinel resolved before reaching this point; nothing to do. 
+ break; } } diff --git a/src/core/include/execution_kernel.hpp b/src/core/include/execution_kernel.hpp index 7719e61a..0451ea52 100644 --- a/src/core/include/execution_kernel.hpp +++ b/src/core/include/execution_kernel.hpp @@ -210,7 +210,7 @@ MSCCLPP_DEVICE_INLINE void handleReadReduceSend(const Operation& op, void* input sizeof(int4); void* remoteMemory = static_cast(memoryChannelBufferPtrs_[op.inputBufferRefs[index + 1].id]); val = mscclpp::read(remoteMemory, srcOffset + idx); - tmp = cal_vector(tmp, val); + tmp = calVector(tmp, val); } output4[outputOffset4 + idx] = tmp; if constexpr (SendToRemote) { @@ -353,9 +353,9 @@ MSCCLPP_DEVICE_INLINE void handleReduceSendPackets(const Operation& op, void* in for (uint32_t index = 0; index < nSrcs; ++index) { PacketType* pkt = (PacketType*)((char*)scratch + scratchOffset_ + 2 * inputOffsets[index]); PacketPayload val = pkt[idx].read(flag_); - data = cal_vector(data, val); + data = calVector(data, val); } - data = cal_vector(data, srcPacketPayload[idx]); + data = calVector(data, srcPacketPayload[idx]); dstPacketPayload[idx] = data; if constexpr (SendToRemote) { @@ -394,9 +394,9 @@ MSCCLPP_DEVICE_INLINE void handleReduceCopySendPackets(const Operation& op, void for (uint32_t index = 0; index < nSrcs; ++index) { PacketType* pkt = (PacketType*)((char*)scratch + scratchOffset_ + 2 * inputOffsets[index]); PacketPayload val = pkt[idx].read(flag_); - data = cal_vector(data, val); + data = calVector(data, val); } - data = cal_vector(data, srcPacketPayload[idx]); + data = calVector(data, srcPacketPayload[idx]); dstPacketPayload[idx] = data; PacketType* dst_val = &dstPkt[idx]; dst_val->write(data, flag_); @@ -464,7 +464,7 @@ MSCCLPP_DEVICE_INLINE void handleReduceSend(const Operation& op, void* input, vo size_t buffOffset = (inputOffsets[index] + getOffset(outputBufferRefs[index].type, offset)) / sizeof(int4); int4 val = buff4[buffOffset + idx]; - tmp = cal_vector(tmp, val); + tmp = calVector(tmp, val); } dst4[dstOffset4 + idx] = tmp; if constexpr (SendToRemote) { @@ -899,6 +899,17 @@ class ExecutionKernel { #endif break; #endif // __FP8_TYPES_EXIST__ + case DataType::FLOAT8_E4M3B15: + executionKernel<__fp8_e4m3b15, PacketType, ReuseScratch><<>>( + rank, (__fp8_e4m3b15*)src, (__fp8_e4m3b15*)dst, (__fp8_e4m3b15*)scratch, scratchOffset, scratchChunkSize, + plan, semaphores, localMemoryIdBegin, flag +#if defined(ENABLE_NPKIT) + , + NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp()); +#else + ); +#endif + break; case DataType::UINT8: executionKernel<<>>( rank, (uint8_t*)src, (uint8_t*)dst, (uint8_t*)scratch, scratchOffset, scratchChunkSize, plan, semaphores, @@ -910,6 +921,10 @@ class ExecutionKernel { ); #endif break; + case DataType::AUTO: + // AUTO is a sentinel that must be resolved before reaching this point. 
+ assert(false && "DataType::AUTO must be resolved before kernel launch"); + break; } } #else // !defined(MSCCLPP_DEVICE_HIP) diff --git a/src/core/include/reduce_kernel.hpp b/src/core/include/reduce_kernel.hpp index fd9bd1e9..463f827d 100644 --- a/src/core/include/reduce_kernel.hpp +++ b/src/core/include/reduce_kernel.hpp @@ -14,7 +14,7 @@ namespace mscclpp { // Generic element-wise calculation helper template -MSCCLPP_DEVICE_INLINE T cal_elements(const T& a, const T& b) { +MSCCLPP_DEVICE_INLINE T calElements(const T& a, const T& b) { if constexpr (OpType == SUM) { return a + b; } else if constexpr (OpType == MIN) { @@ -24,56 +24,168 @@ MSCCLPP_DEVICE_INLINE T cal_elements(const T& a, const T& b) { } // Generic vector reduction helpers -template -MSCCLPP_DEVICE_INLINE int4 cal_vector_helper(const int4& a, const int4& b) { - int4 ret; - ret.w = bit_cast(cal_elements(bit_cast(a.w), bit_cast(b.w))); - ret.x = bit_cast(cal_elements(bit_cast(a.x), bit_cast(b.x))); - ret.y = bit_cast(cal_elements(bit_cast(a.y), bit_cast(b.y))); - ret.z = bit_cast(cal_elements(bit_cast(a.z), bit_cast(b.z))); - return ret; -} template -MSCCLPP_DEVICE_INLINE uint2 cal_vector_helper(const uint2& a, const uint2& b) { +MSCCLPP_DEVICE_INLINE uint2 calVectorHelper(const uint2& a, const uint2& b) { uint2 ret; - ret.x = bit_cast(cal_elements(bit_cast(a.x), bit_cast(b.x))); - ret.y = bit_cast(cal_elements(bit_cast(a.y), bit_cast(b.y))); + ret.x = bit_cast(calElements(bit_cast(a.x), bit_cast(b.x))); + ret.y = bit_cast(calElements(bit_cast(a.y), bit_cast(b.y))); return ret; } -template -MSCCLPP_DEVICE_INLINE int cal_vector_helper(const int& a, const int& b) { - return bit_cast(cal_elements(bit_cast(a), bit_cast(b))); +/// f32x2 specialization for uint2: uses packed f32x2 operator+ (Blackwell __fadd2_rn when available). +template <> +MSCCLPP_DEVICE_INLINE uint2 calVectorHelper(const uint2& a, const uint2& b) { + f32x2 fa = bit_cast(a); + f32x2 fb = bit_cast(b); + f32x2 fr = fa + fb; + return bit_cast(fr); +} + +template <> +MSCCLPP_DEVICE_INLINE uint2 calVectorHelper(const uint2& a, const uint2& b) { + f32x2 fa = bit_cast(a); + f32x2 fb = bit_cast(b); + f32x2 fr = mscclpp::min(fa, fb); + return bit_cast(fr); } template -MSCCLPP_DEVICE_INLINE uint32_t cal_vector_helper(const uint32_t& a, const uint32_t& b) { - return bit_cast(cal_elements(bit_cast(a), bit_cast(b))); +MSCCLPP_DEVICE_INLINE int4 calVectorHelper(const int4& a, const int4& b) { + int4 ret; + ret.w = bit_cast(calElements(bit_cast(a.w), bit_cast(b.w))); + ret.x = bit_cast(calElements(bit_cast(a.x), bit_cast(b.x))); + ret.y = bit_cast(calElements(bit_cast(a.y), bit_cast(b.y))); + ret.z = bit_cast(calElements(bit_cast(a.z), bit_cast(b.z))); + return ret; } -// cal_vector wrapper - converts scalar types to vector types and calls cal_vector_helper +/// f32x2 specialization for int4: process as two uint2 pairs using packed f32x2 arithmetic. 
+template <> +MSCCLPP_DEVICE_INLINE int4 calVectorHelper(const int4& a, const int4& b) { + uint2 lo_a = {(uint32_t)a.x, (uint32_t)a.y}; + uint2 hi_a = {(uint32_t)a.z, (uint32_t)a.w}; + uint2 lo_b = {(uint32_t)b.x, (uint32_t)b.y}; + uint2 hi_b = {(uint32_t)b.z, (uint32_t)b.w}; + uint2 lo_r = calVectorHelper(lo_a, lo_b); + uint2 hi_r = calVectorHelper(hi_a, hi_b); + return {(int)lo_r.x, (int)lo_r.y, (int)hi_r.x, (int)hi_r.y}; +} + +template <> +MSCCLPP_DEVICE_INLINE int4 calVectorHelper(const int4& a, const int4& b) { + uint2 lo_a = {(uint32_t)a.x, (uint32_t)a.y}; + uint2 hi_a = {(uint32_t)a.z, (uint32_t)a.w}; + uint2 lo_b = {(uint32_t)b.x, (uint32_t)b.y}; + uint2 hi_b = {(uint32_t)b.z, (uint32_t)b.w}; + uint2 lo_r = calVectorHelper(lo_a, lo_b); + uint2 hi_r = calVectorHelper(hi_a, hi_b); + return {(int)lo_r.x, (int)lo_r.y, (int)hi_r.x, (int)hi_r.y}; +} + +template +MSCCLPP_DEVICE_INLINE int calVectorHelper(const int& a, const int& b) { + return bit_cast(calElements(bit_cast(a), bit_cast(b))); +} + +template +MSCCLPP_DEVICE_INLINE uint32_t calVectorHelper(const uint32_t& a, const uint32_t& b) { + return bit_cast(calElements(bit_cast(a), bit_cast(b))); +} + +/// f32x2 specialization for uint32_t: a single float packed in 32 bits (scalar fallback). +template <> +MSCCLPP_DEVICE_INLINE uint32_t calVectorHelper(const uint32_t& a, const uint32_t& b) { + float fa = bit_cast(a); + float fb = bit_cast(b); + return bit_cast(fa + fb); +} + +template <> +MSCCLPP_DEVICE_INLINE uint32_t calVectorHelper(const uint32_t& a, const uint32_t& b) { + float fa = bit_cast(a); + float fb = bit_cast(b); + return bit_cast(fminf(fa, fb)); +} + +// calVector wrapper – converts scalar types to vector types and calls calVectorHelper template -MSCCLPP_DEVICE_INLINE DataType cal_vector(const DataType& a, const DataType& b) { +MSCCLPP_DEVICE_INLINE DataType calVector(const DataType& a, const DataType& b) { // Define the vectorized computation type based on the element type static_assert(sizeof(DataType) % sizeof(T) == 0, "DataType size must be multiple of T size"); static_assert(sizeof(DataType) >= 4, "DataType size must be at least 4 bytes"); using CompType = typename std::conditional_t< - std::is_same_v, f16x2, + std::is_same_v, f32x2, std::conditional_t< - std::is_same_v, bf16x2, - std::conditional_t, u8x4, + std::is_same_v, f16x2, + std::conditional_t< + std::is_same_v, bf16x2, + std::conditional_t< + std::is_same_v, u8x4, + std::conditional_t, f8_e4m3b15x4, #if defined(__FP8_TYPES_EXIST__) - std::conditional_t, f8_e4m3x4, - std::conditional_t, f8_e5m2x4, -#endif - T -#if defined(__FP8_TYPES_EXIST__) - >>>>>; + std::conditional_t, f8_e4m3x4, + std::conditional_t, f8_e5m2x4, T>> #else - >>>; + T #endif - return cal_vector_helper(a, b); + >>>>>; + return calVectorHelper(a, b); +} + +/// Upcast a packed DataType (containing T elements) to a packed AccDataType (containing AccumT elements). +/// Uses the optimized to<>() specializations when available (e.g. FP8 -> float hardware intrinsics). +/// When AccumT == T, this is a no-op identity. +template +MSCCLPP_DEVICE_INLINE AccDataType upcastVector(const DataType& val) { + if constexpr (std::is_same_v) { + return val; + } else { + constexpr int nElems = sizeof(DataType) / sizeof(T); + using FromVec = VectorType; + using ToVec = VectorType; + ToVec result = mscclpp::to(reinterpret_cast(val)); + return reinterpret_cast(result); + } +} + +/// Downcast a packed AccDataType (containing AccumT elements) back to DataType (containing T elements). 
+/// Uses the optimized to<>() specializations when available. +/// When AccumT == T, this is a no-op identity. +template +MSCCLPP_DEVICE_INLINE DataType downcastVector(const AccDataType& val) { + if constexpr (std::is_same_v) { + return val; + } else { + constexpr int nElems = sizeof(DataType) / sizeof(T); + using FromVec = VectorType; + using ToVec = VectorType; + FromVec result = mscclpp::to(reinterpret_cast(val)); + return reinterpret_cast(result); + } +} + +/// Accumulate `val` (packed T elements in DataType) into `acc` (packed AccumT elements in AccDataType). +/// When AccumT == T, falls back to the standard calVector. +/// Otherwise, upcasts val to AccumT, reduces element-wise, and returns the AccumT accumulator. +template +MSCCLPP_DEVICE_INLINE AccDataType calVectorAccum(const AccDataType& acc, const DataType& val) { + if constexpr (std::is_same_v) { + return calVector(acc, val); + } else { + constexpr int nElems = sizeof(DataType) / sizeof(T); + using FromVec = VectorType; + using ToVec = VectorType; + + ToVec fv = mscclpp::to(reinterpret_cast(val)); + const ToVec& fa = reinterpret_cast(acc); + ToVec fr; +#pragma unroll + for (int i = 0; i < nElems; ++i) { + fr.data[i] = calElements(fa.data[i], fv.data[i]); + } + return reinterpret_cast(fr); + } } #endif // defined(MSCCLPP_DEVICE_COMPILE) diff --git a/src/ext/collectives/allgather/allgather_fullmesh.cu b/src/ext/collectives/allgather/allgather_fullmesh.cu index 0b288b38..fb51a342 100644 --- a/src/ext/collectives/allgather/allgather_fullmesh.cu +++ b/src/ext/collectives/allgather/allgather_fullmesh.cu @@ -183,7 +183,8 @@ std::shared_ptr AllgatherFullmesh::build() { [self](const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, [[maybe_unused]] DataType dtype, [[maybe_unused]] ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock, - const std::unordered_map& extras) -> CommResult { + const std::unordered_map& extras, + [[maybe_unused]] DataType accumDtype) -> CommResult { return self->allgatherKernelFunc(ctx, input, output, inputSize, stream, nBlocks, nThreadsPerBlock, extras); }, [self](std::shared_ptr comm, const void* input, void* output, size_t inputSize, diff --git a/src/ext/collectives/allgather/allgather_fullmesh_2.cu b/src/ext/collectives/allgather/allgather_fullmesh_2.cu index cf6027d9..9d169d68 100644 --- a/src/ext/collectives/allgather/allgather_fullmesh_2.cu +++ b/src/ext/collectives/allgather/allgather_fullmesh_2.cu @@ -212,7 +212,8 @@ std::shared_ptr AllgatherFullmesh2::build() { [self](const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, [[maybe_unused]] mscclpp::DataType dtype, [[maybe_unused]] ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock, - const std::unordered_map& extras) -> mscclpp::CommResult { + const std::unordered_map& extras, + [[maybe_unused]] mscclpp::DataType accumDtype) -> mscclpp::CommResult { return self->allgatherKernelFunc(ctx, input, output, inputSize, stream, nBlocks, nThreadsPerBlock, extras); }, [self](std::shared_ptr comm, const void* input, void* output, size_t inputSize, diff --git a/src/ext/collectives/allreduce/allreduce_allpair_packet.cu b/src/ext/collectives/allreduce/allreduce_allpair_packet.cu index 83950d7c..17bcfc33 100644 --- a/src/ext/collectives/allreduce/allreduce_allpair_packet.cu +++ b/src/ext/collectives/allreduce/allreduce_allpair_packet.cu @@ -2,6 +2,7 @@ // Licensed under the MIT license. 
#include +#include #include "allreduce/allreduce_allpair_packet.hpp" #include "allreduce/common.hpp" @@ -11,7 +12,7 @@ namespace mscclpp { namespace collective { -template +template __global__ void allreduceAllPairs(T* buff, T* scratch, T* resultBuff, DeviceHandle* memoryChannels, size_t channelDataOffset, size_t scratchBufferSize, int rank, int nRanksPerNode, int worldSize, size_t nelems, uint32_t numScratchBuff, void* flags, @@ -43,13 +44,16 @@ __global__ void allreduceAllPairs(T* buff, T* scratch, T* resultBuff, DeviceHand // step 2: Reduce Data for (size_t idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nelems; idx += blockDim.x * gridDim.x) { uint32_t data = src[idx]; + using AccRaw = std::conditional_t, uint32_t, + mscclpp::VectorType>; + AccRaw acc = mscclpp::upcastVector(data); for (int index = 0; index < nPeers; index++) { const int remoteRank = index < rank ? index : index + 1; LL8Packet* dstPkt = (LL8Packet*)scratchBuff + remoteRank * nelems; uint32_t val = dstPkt[idx].read(flag, -1); - data = cal_vector(val, data); + acc = mscclpp::calVectorAccum(acc, val); } - dst[idx] = data; + dst[idx] = mscclpp::downcastVector(acc); } __syncthreads(); if (threadIdx.x == 0) { @@ -67,7 +71,7 @@ inline std::pair getDefaultBlockNumAndThreadNum(size_t inputSize, int return {(worldSize - 1) * 4, 512}; } -template +template struct AllpairAdapter { static cudaError_t call(const void* buff, void* scratch, void* resultBuff, void* memoryChannels, void*, DeviceHandle*, DeviceHandle*, size_t channelInOffset, size_t, @@ -76,7 +80,12 @@ struct AllpairAdapter { int nThreadsPerBlock = 0) { using ChannelType = DeviceHandle; const size_t nelems = inputSize / sizeof(T); - allreduceAllPairs<<>>( + // Round nBlocks to multiple of nPeers so every block maps to a valid peer. + const int nPeers = worldSize - 1; + if (nPeers > 0) { + nBlocks = (nBlocks / nPeers) * nPeers; + } + allreduceAllPairs<<>>( (T*)buff, (T*)scratch, (T*)resultBuff, (ChannelType*)memoryChannels, channelInOffset, scratchBufferSize, rank, nRanksPerNode, worldSize, nelems, numScratchBuff, flags, flagSize); return cudaGetLastError(); @@ -94,18 +103,24 @@ void AllreduceAllpairPacket::initialize(std::shared_ptr comm) { CommResult AllreduceAllpairPacket::allreduceKernelFunc(const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, [[maybe_unused]] DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock, - const std::unordered_map&) { + const std::unordered_map&, + DataType accumDtype) { auto algoCtx = std::static_pointer_cast(ctx); std::pair blockAndThreadNum{nBlocks, nThreadsPerBlock}; if (blockAndThreadNum.first == 0 || blockAndThreadNum.second == 0) { blockAndThreadNum = getDefaultBlockNumAndThreadNum(inputSize, algoCtx->workSize); } + // nBlocks must be at least nPeers for allpair — each block maps to one peer. 
+ const int nPeers = algoCtx->nRanksPerNode - 1; + if (nPeers > 0 && blockAndThreadNum.first < nPeers) { + return CommResult::CommInvalidArgument; + } size_t sendBytes; CUdeviceptr sendBasePtr; MSCCLPP_CUTHROW(cuMemGetAddressRange(&sendBasePtr, &sendBytes, (CUdeviceptr)input)); size_t channelInOffset = (char*)input - (char*)sendBasePtr; - AllreduceFunc allreduce = dispatch(op, dtype); + AllreduceFunc allreduce = dispatch(op, dtype, accumDtype); if (!allreduce) { WARN("Unsupported operation or data type for allreduce: op=%d, dtype=%d", op, static_cast(dtype)); return CommResult::CommInvalidArgument; @@ -161,9 +176,9 @@ std::shared_ptr AllreduceAllpairPacket::build() { [self](std::shared_ptr comm) { self->initialize(comm); }, [self](const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, - int nThreadsPerBlock, const std::unordered_map& extras) { + int nThreadsPerBlock, const std::unordered_map& extras, DataType accumDtype) { return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock, - extras); + extras, accumDtype); }, [self](std::shared_ptr comm, const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, diff --git a/src/ext/collectives/allreduce/allreduce_fullmesh.cu b/src/ext/collectives/allreduce/allreduce_fullmesh.cu index 13c63ba1..24d2a31c 100644 --- a/src/ext/collectives/allreduce/allreduce_fullmesh.cu +++ b/src/ext/collectives/allreduce/allreduce_fullmesh.cu @@ -9,7 +9,7 @@ namespace mscclpp { namespace collective { -template +template __global__ void __launch_bounds__(512, 1) allreduceFullmesh(T* buff, T* scratch, T* resultBuff, DeviceHandle* memoryChannels, DeviceHandle* memoryOutChannels, size_t channelOutDataOffset, int rank, @@ -26,6 +26,10 @@ __global__ void __launch_bounds__(512, 1) int4* scratch4 = reinterpret_cast((char*)scratch); int4* resultBuff4 = reinterpret_cast(resultBuff); + // AccumVec: wider vector for mixed-precision accumulation. When AccumT==T, this is just int4 (no-op). + constexpr int nElemsPerInt4 = sizeof(int4) / sizeof(T); + using AccumVec = std::conditional_t, int4, mscclpp::VectorType>; + // Distribute `nInt4PerRank` across all blocks with the unit size `unitNInt4` constexpr size_t unitNInt4 = 512; const size_t maxNInt4PerBlock = @@ -81,12 +85,14 @@ __global__ void __launch_bounds__(512, 1) __syncthreads(); for (size_t idx = threadIdx.x; idx < nInt4PerChunk; idx += blockDim.x) { - int4 data = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock]; + int4 rawData = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock]; + AccumVec acc = mscclpp::upcastVector(rawData); for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) { const int remoteRank = (peerIdx < rank) ? 
peerIdx : peerIdx + 1; int4 val = scratch4[chunkSizePerRank * remoteRank + blockOffset + idx]; - data = cal_vector(val, data); + acc = mscclpp::calVectorAccum(acc, val); } + int4 data = mscclpp::downcastVector(acc); resultBuff4[nInt4PerRank * rank + idx + offsetOfThisBlock] = data; for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) { outChannels[peerIdx].write(nInt4PerRank * rank + idx + offsetOfThisBlock + channelOutDataOffset / sizeof(int4), @@ -121,12 +127,14 @@ __global__ void __launch_bounds__(512, 1) __syncthreads(); for (size_t idx = threadIdx.x; idx < restNInt4; idx += blockDim.x) { - int4 data = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock]; + int4 rawData = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock]; + AccumVec acc = mscclpp::upcastVector(rawData); for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) { const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1; int4 val = scratch4[chunkSizePerRank * remoteRank + blockOffset + idx]; - data = cal_vector(val, data); + acc = mscclpp::calVectorAccum(acc, val); } + int4 data = mscclpp::downcastVector(acc); resultBuff4[nInt4PerRank * rank + idx + offsetOfThisBlock] = data; for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) { outChannels[peerIdx].write(nInt4PerRank * rank + idx + offsetOfThisBlock + channelOutDataOffset / sizeof(int4), @@ -144,7 +152,7 @@ __global__ void __launch_bounds__(512, 1) } } -template +template struct AllreduceAllconnectAdapter { static cudaError_t call(const void* input, void* scratch, void* output, void* memoryChannels, void* memoryOutChannels, DeviceHandle*, DeviceHandle*, size_t, @@ -155,7 +163,7 @@ struct AllreduceAllconnectAdapter { size_t nelems = inputSize / sizeof(T); if (nBlocks == 0) nBlocks = 35; if (nThreadsPerBlock == 0) nThreadsPerBlock = 512; - allreduceFullmesh<<>>( + allreduceFullmesh<<>>( (T*)input, (T*)scratch, (T*)output, (ChannelType*)memoryChannels, (ChannelType*)memoryOutChannels, channelOutDataOffset, rank, nRanksPerNode, worldSize, nelems); return cudaGetLastError(); @@ -174,10 +182,10 @@ void AllreduceFullmesh::initialize(std::shared_ptr comm) { localScratchMemory_ = std::move(localMemory); } -CommResult AllreduceFullmesh::allreduceKernelFunc(const std::shared_ptr ctx_void, const void* input, void* output, - size_t inputSize, DataType dtype, ReduceOp op, cudaStream_t stream, - int nBlocks, int nThreadsPerBlock, - const std::unordered_map&) { +CommResult AllreduceFullmesh::allreduceKernelFunc( + const std::shared_ptr ctx_void, const void* input, void* output, size_t inputSize, DataType dtype, + ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock, + [[maybe_unused]] const std::unordered_map& extras, DataType accumDtype) { auto ctx = std::static_pointer_cast(ctx_void); size_t recvBytes; CUdeviceptr recvBasePtr; @@ -198,13 +206,20 @@ CommResult AllreduceFullmesh::allreduceKernelFunc(const std::shared_ptr ct } inputChannelHandles = this->memoryChannelsMap_[input].second; - AllreduceFunc allreduce = dispatch(op, dtype); + AllreduceFunc allreduce = dispatch(op, dtype, accumDtype); if (!allreduce) { WARN("Unsupported operation or data type for allreduce: op=%d, dtype=%d", static_cast(op), static_cast(dtype)); return CommResult::CommInvalidArgument; } std::pair numBlocksAndThreads = {nBlocks, nThreadsPerBlock}; + if (numBlocksAndThreads.first > 64) { + WARN("AllreduceFullmesh: number of blocks exceeds maximum supported blocks, which is 64"); + return mscclpp::CommResult::CommInvalidArgument; + } + if (numBlocksAndThreads.first == 0 || numBlocksAndThreads.second 
== 0) { + numBlocksAndThreads = {35, 512}; + } cudaError_t error = allreduce(input, this->scratchBuffer_, output, inputChannelHandles.get(), ctx->memoryChannelDeviceHandles.get(), nullptr, nullptr, 0, channelOutOffset, 0, ctx->rank, ctx->nRanksPerNode, ctx->workSize, inputSize, @@ -261,9 +276,10 @@ std::shared_ptr AllreduceFullmesh::build() { [self](std::shared_ptr comm) { self->initialize(comm); }, [self](const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, - int nThreadsPerBlock, const std::unordered_map& extras) -> CommResult { + int nThreadsPerBlock, const std::unordered_map& extras, + DataType accumDtype) -> CommResult { return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock, - extras); + extras, accumDtype); }, [self](std::shared_ptr comm, const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, diff --git a/src/ext/collectives/allreduce/allreduce_nvls_block_pipeline.cu b/src/ext/collectives/allreduce/allreduce_nvls_block_pipeline.cu index b542a6a6..2d71cd63 100644 --- a/src/ext/collectives/allreduce/allreduce_nvls_block_pipeline.cu +++ b/src/ext/collectives/allreduce/allreduce_nvls_block_pipeline.cu @@ -146,7 +146,7 @@ __global__ void __launch_bounds__(1024, 1) #endif } -template +template struct NvlsBlockPipelineAdapter { static cudaError_t call(const void* input, void* scratch, void* output, void* memoryChannels, void*, DeviceHandle* nvlsChannels, DeviceHandle*, size_t, size_t, @@ -155,6 +155,9 @@ struct NvlsBlockPipelineAdapter { // uint8_t is not supported for NVLS (no hardware support for byte-level reduction) if constexpr (std::is_same_v) { return cudaErrorNotSupported; + } else if constexpr (std::is_same_v) { + // fp8_e4m3b15 is a software-only type with no hardware NVLS support. 
+ return cudaErrorNotSupported; } else #if defined(__CUDA_ARCH__) // Skip the __CUDA_ARCH__ < 1000 since FP8 has not been supported for NVLS if constexpr (std::is_same_v || std::is_same_v) { @@ -187,9 +190,10 @@ void AllreduceNvlsBlockPipeline::initialize(std::shared_ptr comm) CommResult AllreduceNvlsBlockPipeline::allreduceKernelFunc(const std::shared_ptr ctx_void, const void* input, void* output, size_t inputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock, - const std::unordered_map&) { + const std::unordered_map& extras, + DataType accumDtype) { auto ctx = std::static_pointer_cast(ctx_void); - AllreduceFunc allreduce = dispatch(op, dtype); + AllreduceFunc allreduce = dispatch(op, dtype, accumDtype); if (!allreduce) { WARN("Unsupported operation or data type for allreduce, dtype=%d", static_cast(dtype)); return CommResult::CommInvalidArgument; @@ -235,9 +239,9 @@ std::shared_ptr AllreduceNvlsBlockPipeline::build() { [self](std::shared_ptr comm) { self->initialize(comm); }, [self](const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, - int nThreadsPerBlock, const std::unordered_map& extras) { + int nThreadsPerBlock, const std::unordered_map& extras, DataType accumDtype) { return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock, - extras); + extras, accumDtype); }, [self](std::shared_ptr comm, const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, diff --git a/src/ext/collectives/allreduce/allreduce_nvls_packet.cu b/src/ext/collectives/allreduce/allreduce_nvls_packet.cu index 9824fbcd..a616485e 100644 --- a/src/ext/collectives/allreduce/allreduce_nvls_packet.cu +++ b/src/ext/collectives/allreduce/allreduce_nvls_packet.cu @@ -1,15 +1,17 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +#include + #include "allreduce/allreduce_nvls_packet.hpp" #include "allreduce/common.hpp" #include "collective_utils.hpp" -#include "debug.h" +#include "logger.hpp" namespace mscclpp { namespace collective { -template +template __global__ void __launch_bounds__(1024, 1) allreduceNvlsPacket([[maybe_unused]] const T* input, [[maybe_unused]] T* scratch, [[maybe_unused]] T* output, [[maybe_unused]] mscclpp::DeviceHandle* multicast, @@ -31,15 +33,16 @@ __global__ void __launch_bounds__(1024, 1) mscclpp::SwitchChannelDeviceHandle::multimemStore(*(mscclpp::f32x2*)(&pkt), multiPkt + i); } for (uint32_t i = tid; i < nPktPerRank * worldSize; i += blockDim.x * gridDim.x) { - uint data = src[i]; + // When T == AccumT, stay with raw uint to avoid type mismatch in identity path. 
+ using AccRaw = + std::conditional_t, uint, mscclpp::VectorType>; + AccRaw acc = mscclpp::upcastVector(src[i]); for (int peer = 0; peer < worldSize; peer++) { - if (peer == rank) { - continue; - } + if (peer == rank) continue; uint val = scratchPkt[peer * worldSize * nPktPerRank + i].read(flag); - data = cal_vector(data, val); + acc = mscclpp::calVectorAccum(acc, val); } - dst[i] = data; + dst[i] = mscclpp::downcastVector(acc); } __syncthreads(); if (threadIdx.x == 0) { @@ -62,13 +65,13 @@ inline std::pair getDefaultBlockNumAndThreadNum(size_t inputSize) { return {blockNum, threadNum}; } -template +template struct AllreduceNvlsPacketAdapter { static cudaError_t call(const void* input, void* scratch, void* output, void*, void*, DeviceHandle* nvlsChannels, DeviceHandle*, size_t, size_t, size_t scratchBufferSize, int rank, int, int worldSize, size_t inputSize, cudaStream_t stream, void* flags, uint32_t flagBufferSize, uint32_t, int nBlocks, int nThreadsPerBlock) { - allreduceNvlsPacket<<>>( + allreduceNvlsPacket<<>>( (const T*)input, (T*)scratch, (T*)output, nvlsChannels, inputSize / sizeof(T), scratchBufferSize, rank, worldSize, flags, flagBufferSize); return cudaGetLastError(); @@ -78,6 +81,8 @@ struct AllreduceNvlsPacketAdapter { void AllreduceNvlsPacket::initialize(std::shared_ptr comm) { int nSwitchChannels = 1; this->nvlsConnections_ = setupNvlsConnections(comm, nvlsBufferSize_, nSwitchChannels); + this->switchChannels_ = + setupNvlsChannels(this->nvlsConnections_, this->scratchBuffer_, this->scratchBufferSize_, nSwitchChannels); } AlgorithmCtxKey AllreduceNvlsPacket::generateAllreduceContextKey(const void*, void*, size_t, DataType, bool) { @@ -92,9 +97,7 @@ std::shared_ptr AllreduceNvlsPacket::initAllreduceContext(std::shared_ptr< ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode(); // setup channels - int nSwitchChannels = 1; - ctx->switchChannels = - setupNvlsChannels(this->nvlsConnections_, this->scratchBuffer_, this->scratchBufferSize_, nSwitchChannels); + ctx->switchChannels = this->switchChannels_; ctx->switchChannelDeviceHandles = setupNvlsChannelDeviceHandles(ctx->switchChannels); return ctx; } @@ -102,19 +105,20 @@ std::shared_ptr AllreduceNvlsPacket::initAllreduceContext(std::shared_ptr< CommResult AllreduceNvlsPacket::allreduceKernelFunc(const std::shared_ptr ctx_void, const void* input, void* output, size_t inputSize, mscclpp::DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock, - const std::unordered_map&) { + const std::unordered_map&, + mscclpp::DataType accumDtype) { auto ctx = std::static_pointer_cast(ctx_void); std::pair blockAndThreadNum = {nBlocks, nThreadsPerBlock}; if (blockAndThreadNum.first == 0 || blockAndThreadNum.second == 0) { blockAndThreadNum = getDefaultBlockNumAndThreadNum(inputSize); } if (blockAndThreadNum.first > maxBlockNum_) { - WARN("Block number %d exceeds the maximum limit %d", blockAndThreadNum.first, maxBlockNum_); + WARN(ALGO, "Block number ", blockAndThreadNum.first, " exceeds the maximum limit ", maxBlockNum_); return CommResult::CommInvalidArgument; } - AllreduceFunc allreduce = dispatch(op, dtype); + AllreduceFunc allreduce = dispatch(op, dtype, accumDtype); if (!allreduce) { - WARN("Unsupported operation or data type for allreduce, dtype=%d", static_cast(dtype)); + WARN(ALGO, "Unsupported operation or data type for allreduce, dtype=", static_cast(dtype)); return CommResult::CommInvalidArgument; } cudaError_t error = @@ -122,7 +126,7 @@ CommResult AllreduceNvlsPacket::allreduceKernelFunc(const 
std::shared_ptr 0, 0, this->scratchBufferSize_, ctx->rank, ctx->nRanksPerNode, ctx->workSize, inputSize, stream, (void*)flagBuffer_, (uint32_t)flagBufferSize_, 0, blockAndThreadNum.first, blockAndThreadNum.second); if (error != cudaSuccess) { - WARN("AllreduceNvlsPacket failed with error: %s", cudaGetErrorString(error)); + WARN(ALGO, "AllreduceNvlsPacket failed with error: ", cudaGetErrorString(error)); return CommResult::CommUnhandledCudaError; } return CommResult::CommSuccess; @@ -136,9 +140,10 @@ std::shared_ptr AllreduceNvlsPacket::build() { [self](std::shared_ptr comm) { self->initialize(comm); }, [self](const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, mscclpp::DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, - int nThreadsPerBlock, const std::unordered_map& extras) { + int nThreadsPerBlock, const std::unordered_map& extras, + mscclpp::DataType accumDtype) { return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock, - extras); + extras, accumDtype); }, [self](std::shared_ptr comm, const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, diff --git a/src/ext/collectives/allreduce/allreduce_nvls_warp_pipeline.cu b/src/ext/collectives/allreduce/allreduce_nvls_warp_pipeline.cu index bc03ab26..3bb054da 100644 --- a/src/ext/collectives/allreduce/allreduce_nvls_warp_pipeline.cu +++ b/src/ext/collectives/allreduce/allreduce_nvls_warp_pipeline.cu @@ -109,7 +109,7 @@ __global__ void __launch_bounds__(1024, 1) #endif } -template +template struct NvlsWarpPipelineAdapter { static cudaError_t call(const void* input, void* scratch, void* output, void* memoryChannels, void*, DeviceHandle* nvlsChannels, DeviceHandle*, size_t, size_t, @@ -118,6 +118,9 @@ struct NvlsWarpPipelineAdapter { // uint8_t is not supported for NVLS (no hardware support for byte-level reduction) if constexpr (std::is_same_v) { return cudaErrorNotSupported; + } else if constexpr (std::is_same_v) { + // fp8_e4m3b15 is a software-only type with no hardware NVLS support. 
+ return cudaErrorNotSupported; } else #if defined(__CUDA_ARCH__) // Skip the __CUDA_ARCH__ < 1000 since FP8 has not been supported for NVLS if constexpr (std::is_same_v || std::is_same_v) { @@ -147,12 +150,12 @@ void AllreduceNvlsWarpPipeline::initialize(std::shared_ptr comm) { this->nvlsConnections_ = setupNvlsConnections(comm, nvlsBufferSize_, nSwitchChannels_); } -CommResult AllreduceNvlsWarpPipeline::allreduceKernelFunc(const std::shared_ptr ctx_void, const void* input, - void* output, size_t inputSize, DataType dtype, ReduceOp op, - cudaStream_t stream, int nBlocks, int nThreadsPerBlock, - const std::unordered_map&) { +CommResult AllreduceNvlsWarpPipeline::allreduceKernelFunc( + const std::shared_ptr ctx_void, const void* input, void* output, size_t inputSize, DataType dtype, + ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock, + [[maybe_unused]] const std::unordered_map& extras, DataType accumDtype) { auto ctx = std::static_pointer_cast(ctx_void); - AllreduceFunc allreduce = dispatch(op, dtype); + AllreduceFunc allreduce = dispatch(op, dtype, accumDtype); if (!allreduce) { WARN("Unsupported operation or data type for allreduce, dtype=%d", static_cast(dtype)); return CommResult::CommInvalidArgument; @@ -198,9 +201,9 @@ std::shared_ptr AllreduceNvlsWarpPipeline::build() { [self](std::shared_ptr comm) { self->initialize(comm); }, [self](const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, - int nThreadsPerBlock, const std::unordered_map& extras) { + int nThreadsPerBlock, const std::unordered_map& extras, DataType accumDtype) { return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock, - extras); + extras, accumDtype); }, [self](std::shared_ptr comm, const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, diff --git a/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu b/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu index f251bcda..e7f2028f 100644 --- a/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu +++ b/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu @@ -67,7 +67,7 @@ __global__ void __launch_bounds__(1024, 1) #endif } -template +template struct NvlsAdapter { static cudaError_t call(const void*, void*, void*, void* memoryChannels, void*, mscclpp::DeviceHandle* nvlsChannels, @@ -77,6 +77,9 @@ struct NvlsAdapter { // uint8_t is not supported for NVLS (no hardware support for byte-level reduction) if constexpr (std::is_same_v) { return cudaErrorNotSupported; + } else if constexpr (std::is_same_v) { + // fp8_e4m3b15 is a software-only type with no hardware NVLS support. 
+ return cudaErrorNotSupported; } else #if (!defined(__CUDA_ARCH_SPECIFIC__) && !defined(__CUDA_ARCH_FAMILY_SPECIFIC__)) || (__CUDA_ARCH__ < 1000) if constexpr (std::is_same_v || std::is_same_v) { @@ -114,13 +117,14 @@ void AllreduceNvls::initialize(std::shared_ptr comm) { CommResult AllreduceNvls::allreduceKernelFunc(const std::shared_ptr ctx_void, const void* input, void* output, size_t inputSize, mscclpp::DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock, - const std::unordered_map&) { + [[maybe_unused]] const std::unordered_map& extras, + mscclpp::DataType accumDtype) { if (!symmetricMemory_) { WARN("AllreduceNvls requires symmetric memory for now."); return CommResult::CommInvalidArgument; } auto ctx = std::static_pointer_cast(ctx_void); - AllreduceFunc allreduce = dispatch(op, dtype); + AllreduceFunc allreduce = dispatch(op, dtype, accumDtype); if (!allreduce) { WARN("Unsupported operation or data type for allreduce, dtype=%d", static_cast(dtype)); return CommResult::CommInvalidArgument; @@ -203,9 +207,10 @@ std::shared_ptr AllreduceNvls::build() { [self](std::shared_ptr comm) { self->initialize(comm); }, [self](const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, mscclpp::DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, - int nThreadsPerBlock, const std::unordered_map& extras) { + int nThreadsPerBlock, const std::unordered_map& extras, + mscclpp::DataType accumDtype) { return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock, - extras); + extras, accumDtype); }, [self](std::shared_ptr comm, const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, diff --git a/src/ext/collectives/allreduce/allreduce_packet.cu b/src/ext/collectives/allreduce/allreduce_packet.cu index ceb545ee..e2d8ef73 100644 --- a/src/ext/collectives/allreduce/allreduce_packet.cu +++ b/src/ext/collectives/allreduce/allreduce_packet.cu @@ -2,16 +2,17 @@ // Licensed under the MIT License. #include +#include #include "allreduce/allreduce_packet.hpp" #include "allreduce/common.hpp" #include "collective_utils.hpp" -#include "debug.h" +#include "logger.hpp" namespace mscclpp { namespace collective { -template +template __global__ void __launch_bounds__(1024, 1) allreducePacket(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle* memoryChannels, size_t channelDataOffset, size_t scratchBufferSize, int rank, int nRanksPerNode, int worldSize, @@ -92,12 +93,21 @@ __global__ void __launch_bounds__(1024, 1) // step 2: get data from scratch buffer, reduce data and write result to remote scratch buffer for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nPktsPerRank; idx += blockDim.x * gridDim.x) { uint2 data = src[idx]; - for (int index = 0; index < nPeers; index++) { - const int remoteRank = index < rank ? index : index + 1; - mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)scratchBuff + remoteRank * nPktsPerRank; - uint2 val = dstPkt[idx].read(flag); - data.x = cal_vector(val.x, data.x); - data.y = cal_vector(val.y, data.y); + { + // When T == AccumT, stay with raw uint32_t to avoid type mismatch in identity path. + using AccRaw = std::conditional_t, uint32_t, + mscclpp::VectorType>; + AccRaw accX = mscclpp::upcastVector(data.x); + AccRaw accY = mscclpp::upcastVector(data.y); + for (int index = 0; index < nPeers; index++) { + const int remoteRank = index < rank ? 
index : index + 1; + mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)scratchBuff + remoteRank * nPktsPerRank; + uint2 val = dstPkt[idx].read(flag); + accX = mscclpp::calVectorAccum(accX, val.x); + accY = mscclpp::calVectorAccum(accY, val.y); + } + data.x = mscclpp::downcastVector(accX); + data.y = mscclpp::downcastVector(accY); } dst[idx].x = data.x; @@ -142,7 +152,7 @@ __global__ void __launch_bounds__(1024, 1) #endif } -template +template struct PacketAdapter { static cudaError_t call(const void* buff, void* scratch, void* resultBuff, void* memoryChannels, void*, DeviceHandle*, DeviceHandle*, size_t channelInOffset, size_t, @@ -155,12 +165,12 @@ struct PacketAdapter { nBlocks = nBlocks / (worldSize - 1) * (worldSize - 1); #if defined(ENABLE_NPKIT) size_t sharedMemSize = sizeof(NpKitEvent) * NPKIT_SHM_NUM_EVENTS; - allreducePacket<<>>( + allreducePacket<<>>( (T*)buff, (T*)scratch, (T*)resultBuff, (ChannelType*)memoryChannels, channelInOffset, scratchBufferSize, rank, nRanksPerNode, worldSize, nelems, flags, flagBufferSize, numScratchBuff, NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp()); #else - allreducePacket<<>>( + allreducePacket<<>>( (T*)buff, (T*)scratch, (T*)resultBuff, (ChannelType*)memoryChannels, channelInOffset, scratchBufferSize, rank, nRanksPerNode, worldSize, nelems, flags, flagBufferSize, numScratchBuff); #endif @@ -186,18 +196,22 @@ inline std::pair getDefaultBlockNumAndThreadNum(size_t inputSize, int } } -#if defined(__FP8_TYPES_EXIST__) // FP8-specific tuning for 32KB-256KB range - if (dtype == DataType::FLOAT8_E4M3 || dtype == DataType::FLOAT8_E5M2) { - if (inputSize < (64 << 10)) { - nThreadsPerBlock = 64; - } else if (inputSize >= (64 << 10) && inputSize <= (128 << 10)) { - nThreadsPerBlock = 128; - } else if (inputSize >= (128 << 10) && inputSize <= (256 << 10)) { - nThreadsPerBlock = 256; + { + bool isFp8 = dtype == DataType::FLOAT8_E4M3B15; +#if defined(__FP8_TYPES_EXIST__) + isFp8 = isFp8 || dtype == DataType::FLOAT8_E4M3 || dtype == DataType::FLOAT8_E5M2; +#endif + if (isFp8) { + if (inputSize < (64 << 10)) { + nThreadsPerBlock = 64; + } else if (inputSize >= (64 << 10) && inputSize <= (128 << 10)) { + nThreadsPerBlock = 128; + } else if (inputSize >= (128 << 10) && inputSize <= (256 << 10)) { + nThreadsPerBlock = 256; + } } } -#endif #endif return {nBlocks, nThreadsPerBlock}; } @@ -213,7 +227,8 @@ void AllreducePacket::initialize(std::shared_ptr comm) { CommResult AllreducePacket::allreduceKernelFunc(const std::shared_ptr ctx_void, const void* input, void* output, size_t inputSize, [[maybe_unused]] DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock, - const std::unordered_map&) { + const std::unordered_map&, + DataType accumDtype) { auto ctx = std::static_pointer_cast(ctx_void); std::pair blockAndThreadNum = {nBlocks, nThreadsPerBlock}; if (blockAndThreadNum.first == 0 || blockAndThreadNum.second == 0) { @@ -225,9 +240,10 @@ CommResult AllreducePacket::allreduceKernelFunc(const std::shared_ptr ctx_ MSCCLPP_CUTHROW(cuMemGetAddressRange(&sendBasePtr, &sendBytes, (CUdeviceptr)input)); size_t channelInOffset = (char*)input - (char*)sendBasePtr; - AllreduceFunc allreduce = dispatch(op, dtype); + AllreduceFunc allreduce = dispatch(op, dtype, accumDtype); if (!allreduce) { - WARN("Unsupported operation or data type for allreduce: op=%d, dtype=%d", op, static_cast(dtype)); + WARN(ALGO, "Unsupported operation or data type for allreduce: op=", static_cast(op), + ", dtype=", static_cast(dtype)); return 
CommResult::CommInvalidArgument; } cudaError_t error = @@ -236,7 +252,7 @@ CommResult AllreducePacket::allreduceKernelFunc(const std::shared_ptr ctx_ stream, (void*)flagBuffer_, (uint32_t)flagBufferSize_, this->nSegmentsForScratchBuffer_, blockAndThreadNum.first, blockAndThreadNum.second); if (error != cudaSuccess) { - WARN("AllreducePacket failed with error: %s", cudaGetErrorString(error)); + WARN(ALGO, "AllreducePacket failed with error: ", cudaGetErrorString(error)); return CommResult::CommUnhandledCudaError; } return CommResult::CommSuccess; @@ -280,9 +296,9 @@ std::shared_ptr AllreducePacket::build() { "default_allreduce_packet", "allreduce", [self](std::shared_ptr comm) { self->initialize(comm); }, [self](const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, - int nThreadsPerBlock, const std::unordered_map& extras) { + int nThreadsPerBlock, const std::unordered_map& extras, DataType accumDtype) { return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock, - extras); + extras, accumDtype); }, [self](std::shared_ptr comm, const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, diff --git a/src/ext/collectives/allreduce/allreduce_rsag.cu b/src/ext/collectives/allreduce/allreduce_rsag.cu index d5be2257..db471b93 100644 --- a/src/ext/collectives/allreduce/allreduce_rsag.cu +++ b/src/ext/collectives/allreduce/allreduce_rsag.cu @@ -87,7 +87,7 @@ __global__ void __launch_bounds__(1024, 1) int rankIdx = (rank + i + 1) % nRanksPerNode; int peerIdx = rankIdx < rank ? rankIdx : rankIdx - 1; int4 data = mscclpp::read(((void**)remoteMemories)[peerIdx], offset); - tmp = cal_vector(data, tmp); + tmp = calVector(data, tmp); } for (uint32_t i = 0; i < nPeers; i++) { int rankIdx = (rank + i + 1) % nRanksPerNode; @@ -123,7 +123,7 @@ __global__ void __launch_bounds__(1024, 1) } } -template +template struct AllreduceRsAgAdapter { static cudaError_t call(const void* input, void* scratch, void* output, void* memoryChannels, void* remoteMemories, DeviceHandle* switchChannel, DeviceHandle*, size_t, size_t, @@ -166,9 +166,9 @@ void AllreduceRsAg::initialize(std::shared_ptr comm) { CommResult AllreduceRsAg::allreduceKernelFunc(const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock, - const std::unordered_map&) { + const std::unordered_map&, DataType accumDtype) { auto algoCtx = std::static_pointer_cast(ctx); - AllreduceFunc allreduce = dispatch(op, dtype); + AllreduceFunc allreduce = dispatch(op, dtype, accumDtype); if (!allreduce) { WARN(ALGO, "Unsupported operation or data type for allreduce: op=", static_cast(op), ", dtype=", static_cast(dtype)); @@ -213,9 +213,10 @@ std::shared_ptr AllreduceRsAg::build() { [self](std::shared_ptr comm) { self->initialize(comm); }, [self](const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, - int nThreadsPerBlock, const std::unordered_map& extras) -> CommResult { + int nThreadsPerBlock, const std::unordered_map& extras, + DataType accumDtype) -> CommResult { return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock, - extras); + extras, accumDtype); }, [self](std::shared_ptr comm, const void* input, 
void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, diff --git a/src/ext/collectives/allreduce/allreduce_rsag_pipeline.cu b/src/ext/collectives/allreduce/allreduce_rsag_pipeline.cu index a230d8cd..eabe3dc5 100644 --- a/src/ext/collectives/allreduce/allreduce_rsag_pipeline.cu +++ b/src/ext/collectives/allreduce/allreduce_rsag_pipeline.cu @@ -168,7 +168,7 @@ __global__ void __launch_bounds__(1024, 1) uint32_t peerSlotOffset = baseOffset + remoteRankId * nInt4PerIter + threadIdInPut + putStep * blockDim.x * nblocksForPut; int4 data = scratch4[peerSlotOffset]; - tmp = cal_vector(data, tmp); + tmp = calVector(data, tmp); } storeVec(resultBuff, myChunkOffset, tmp, nelems); // Broadcast reduced result to all peers' scratch at SCATTER_AG_OFFSET + rank * nInt4PerIter @@ -220,7 +220,7 @@ __global__ void __launch_bounds__(1024, 1) } } -template +template struct AllreduceRsAgPipelineAdapter { static cudaError_t call(const void* input, void* scratch, void* output, void* memoryChannels, void* remoteMemories, DeviceHandle* switchChannel, DeviceHandle*, size_t, size_t, @@ -274,12 +274,12 @@ void AllreduceRsAgPipeline::initialize(std::shared_ptr comm) { cudaMemcpyHostToDevice); } -CommResult AllreduceRsAgPipeline::allreduceKernelFunc(const std::shared_ptr ctx, const void* input, void* output, - size_t inputSize, DataType dtype, ReduceOp op, - cudaStream_t stream, int nBlocks, int nThreadsPerBlock, - const std::unordered_map&) { +CommResult AllreduceRsAgPipeline::allreduceKernelFunc( + const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, DataType dtype, ReduceOp op, + cudaStream_t stream, int nBlocks, int nThreadsPerBlock, + [[maybe_unused]] const std::unordered_map& extras, DataType accumDtype) { auto algoCtx = std::static_pointer_cast(ctx); - AllreduceFunc allreduce = dispatch(op, dtype); + AllreduceFunc allreduce = dispatch(op, dtype, accumDtype); if (!allreduce) { WARN(ALGO, "Unsupported operation or data type for allreduce: op=", static_cast(op), ", dtype=", static_cast(dtype)); @@ -320,9 +320,10 @@ std::shared_ptr AllreduceRsAgPipeline::build() { [self](std::shared_ptr comm) { self->initialize(comm); }, [self](const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, - int nThreadsPerBlock, const std::unordered_map& extras) -> CommResult { + int nThreadsPerBlock, const std::unordered_map& extras, + DataType accumDtype) -> CommResult { return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock, - extras); + extras, accumDtype); }, [self](std::shared_ptr comm, const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, diff --git a/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu b/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu index caac07ae..f95ba7e3 100644 --- a/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu +++ b/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +#include + #include "allreduce/allreduce_rsag_zero_copy.hpp" #include "allreduce/common.hpp" #include "collective_utils.hpp" @@ -36,7 +38,7 @@ __device__ mscclpp::DeviceSyncer globalSyncer; // the extra copy steps of the standard RSAG. The NRanksPerNode template // parameter enables compile-time unrolling of peer loops (supports 4 or 8). 
-template +template __global__ void __launch_bounds__(1024, 1) allreduceRsAgZeroCopy(T* buff, T* scratch, T* resultBuff, DeviceHandle* memoryChannels, DeviceHandle* switchChannels, void* remoteMemories, int rank, int worldSize, @@ -73,19 +75,26 @@ __global__ void __launch_bounds__(1024, 1) } __syncthreads(); int4 data[NPeers]; + // AccumVec: when AccumT != T, use a wider accumulator type. + // For AccumT == T, this is just int4 (no-op conversion). + constexpr int nElemsPerInt4 = sizeof(int4) / sizeof(T); + // When T == AccumT, stay with raw int4 to avoid type mismatch in identity path. + using AccumVec = std::conditional_t, int4, mscclpp::VectorType>; for (uint32_t idx = threadIdx.x; idx < nInt4PerBlock; idx += blockDim.x) { uint32_t offset = idx + offset4 + rank * nInt4PerRank; if (offset >= nInt4Total) continue; - int4 tmp = buff4[offset]; + int4 tmp_raw = buff4[offset]; #pragma unroll for (int i = 0; i < NPeers; i++) { int rankIdx = (rank + i + 1) % NRanksPerNode; int peerIdx = rankIdx < rank ? rankIdx : rankIdx - 1; data[i] = mscclpp::read(((void**)remoteMemories)[peerIdx], offset); } + AccumVec acc = mscclpp::upcastVector(tmp_raw); for (int i = 0; i < NPeers; i++) { - tmp = cal_vector(data[i], tmp); + acc = mscclpp::calVectorAccum(acc, data[i]); } + int4 tmp = mscclpp::downcastVector(acc); #pragma unroll for (int i = 0; i < NPeers; i++) { int rankIdx = (rank + i + 1) % NRanksPerNode; @@ -102,7 +111,7 @@ __global__ void __launch_bounds__(1024, 1) } } -template +template struct AllreduceRsAgZeroCopyAdapter { static cudaError_t call(const void* input, void* scratch, void* output, void* memoryChannels, void* remoteMemories, DeviceHandle* switchChannel, DeviceHandle*, size_t, size_t, @@ -118,11 +127,11 @@ struct AllreduceRsAgZeroCopyAdapter { } } if (nRanksPerNode == 4) { - allreduceRsAgZeroCopy<4, OpType, T> + allreduceRsAgZeroCopy<4, OpType, T, AccumT> <<>>((T*)input, (T*)scratch, (T*)output, (ChannelType*)memoryChannels, switchChannel, remoteMemories, rank, worldSize, nelems); } else if (nRanksPerNode == 8) { - allreduceRsAgZeroCopy<8, OpType, T> + allreduceRsAgZeroCopy<8, OpType, T, AccumT> <<>>((T*)input, (T*)scratch, (T*)output, (ChannelType*)memoryChannels, switchChannel, remoteMemories, rank, worldSize, nelems); } else { @@ -145,9 +154,10 @@ void AllreduceRsAgZeroCopy::initialize(std::shared_ptr comm) { CommResult AllreduceRsAgZeroCopy::allreduceKernelFunc(const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock, - const std::unordered_map&) { + const std::unordered_map&, + DataType accumDtype) { auto algoCtx = std::static_pointer_cast(ctx); - AllreduceFunc allreduce = dispatch(op, dtype); + AllreduceFunc allreduce = dispatch(op, dtype, accumDtype); if (!allreduce) { WARN(ALGO, "Unsupported operation or data type for allreduce: op=", static_cast(op), ", dtype=", static_cast(dtype)); @@ -220,9 +230,10 @@ std::shared_ptr AllreduceRsAgZeroCopy::build() { [self](std::shared_ptr comm) { self->initialize(comm); }, [self](const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, - int nThreadsPerBlock, const std::unordered_map& extras) -> CommResult { + int nThreadsPerBlock, const std::unordered_map& extras, + DataType accumDtype) -> CommResult { return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock, - extras); +
extras, accumDtype); }, [self](std::shared_ptr comm, const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, diff --git a/src/ext/collectives/include/allreduce/allreduce_allpair_packet.hpp b/src/ext/collectives/include/allreduce/allreduce_allpair_packet.hpp index bd402cfa..362308b2 100644 --- a/src/ext/collectives/include/allreduce/allreduce_allpair_packet.hpp +++ b/src/ext/collectives/include/allreduce/allreduce_allpair_packet.hpp @@ -20,7 +20,7 @@ class AllreduceAllpairPacket : public AlgorithmBuilder { void initialize(std::shared_ptr comm); CommResult allreduceKernelFunc(const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock, - const std::unordered_map& extras); + const std::unordered_map& extras, DataType accumDtype); std::shared_ptr initAllreduceContext(std::shared_ptr comm, const void*, void* output, size_t, DataType); diff --git a/src/ext/collectives/include/allreduce/allreduce_fullmesh.hpp b/src/ext/collectives/include/allreduce/allreduce_fullmesh.hpp index fa811b15..a54352b3 100644 --- a/src/ext/collectives/include/allreduce/allreduce_fullmesh.hpp +++ b/src/ext/collectives/include/allreduce/allreduce_fullmesh.hpp @@ -16,7 +16,7 @@ class AllreduceFullmesh : public mscclpp::AlgorithmBuilder { void initialize(std::shared_ptr comm); CommResult allreduceKernelFunc(const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock, - const std::unordered_map& extras); + const std::unordered_map& extras, DataType accumDtype); std::shared_ptr initAllreduceContext(std::shared_ptr comm, const void*, void* output, size_t, DataType); diff --git a/src/ext/collectives/include/allreduce/allreduce_nvls_block_pipeline.hpp b/src/ext/collectives/include/allreduce/allreduce_nvls_block_pipeline.hpp index 8b9b04ae..81b74add 100644 --- a/src/ext/collectives/include/allreduce/allreduce_nvls_block_pipeline.hpp +++ b/src/ext/collectives/include/allreduce/allreduce_nvls_block_pipeline.hpp @@ -19,7 +19,7 @@ class AllreduceNvlsBlockPipeline : public AlgorithmBuilder { void initialize(std::shared_ptr comm); CommResult allreduceKernelFunc(const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock, - const std::unordered_map& extras); + const std::unordered_map& extras, DataType accumDtype); std::shared_ptr initAllreduceContext(std::shared_ptr comm, const void*, void* output, size_t, DataType); diff --git a/src/ext/collectives/include/allreduce/allreduce_nvls_packet.hpp b/src/ext/collectives/include/allreduce/allreduce_nvls_packet.hpp index 65a48923..fb0c63b8 100644 --- a/src/ext/collectives/include/allreduce/allreduce_nvls_packet.hpp +++ b/src/ext/collectives/include/allreduce/allreduce_nvls_packet.hpp @@ -21,7 +21,8 @@ class AllreduceNvlsPacket : public mscclpp::AlgorithmBuilder { void initialize(std::shared_ptr comm); CommResult allreduceKernelFunc(const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, mscclpp::DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, - int nThreadsPerBlock, const std::unordered_map& extras); + int nThreadsPerBlock, const std::unordered_map& extras, + mscclpp::DataType accumDtype); std::shared_ptr initAllreduceContext(std::shared_ptr comm, const void*, void* output, size_t, mscclpp::DataType); @@ -34,6 +35,7 @@ class 
AllreduceNvlsPacket : public mscclpp::AlgorithmBuilder { uintptr_t flagBuffer_; size_t flagBufferSize_; std::vector> nvlsConnections_; + std::vector switchChannels_; }; } // namespace collective } // namespace mscclpp diff --git a/src/ext/collectives/include/allreduce/allreduce_nvls_warp_pipeline.hpp b/src/ext/collectives/include/allreduce/allreduce_nvls_warp_pipeline.hpp index e392b54e..8f02a873 100644 --- a/src/ext/collectives/include/allreduce/allreduce_nvls_warp_pipeline.hpp +++ b/src/ext/collectives/include/allreduce/allreduce_nvls_warp_pipeline.hpp @@ -19,7 +19,7 @@ class AllreduceNvlsWarpPipeline : public AlgorithmBuilder { void initialize(std::shared_ptr comm); CommResult allreduceKernelFunc(const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock, - const std::unordered_map& extras); + const std::unordered_map& extras, DataType accumDtype); std::shared_ptr initAllreduceContext(std::shared_ptr comm, const void*, void* output, size_t, DataType); diff --git a/src/ext/collectives/include/allreduce/allreduce_nvls_zero_copy.hpp b/src/ext/collectives/include/allreduce/allreduce_nvls_zero_copy.hpp index d0593500..d53ea180 100644 --- a/src/ext/collectives/include/allreduce/allreduce_nvls_zero_copy.hpp +++ b/src/ext/collectives/include/allreduce/allreduce_nvls_zero_copy.hpp @@ -19,7 +19,7 @@ class AllreduceNvls : public AlgorithmBuilder { void initialize(std::shared_ptr comm); CommResult allreduceKernelFunc(const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock, - const std::unordered_map& extras); + const std::unordered_map& extras, DataType accumDtype); std::shared_ptr initAllreduceContext(std::shared_ptr comm, const void*, void* output, size_t, DataType); diff --git a/src/ext/collectives/include/allreduce/allreduce_packet.hpp b/src/ext/collectives/include/allreduce/allreduce_packet.hpp index f0438dea..de7ca471 100644 --- a/src/ext/collectives/include/allreduce/allreduce_packet.hpp +++ b/src/ext/collectives/include/allreduce/allreduce_packet.hpp @@ -20,7 +20,7 @@ class AllreducePacket : public AlgorithmBuilder { void initialize(std::shared_ptr comm); CommResult allreduceKernelFunc(const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock, - const std::unordered_map& extras); + const std::unordered_map& extras, DataType accumDtype); std::shared_ptr initAllreduceContext(std::shared_ptr comm, const void*, void* output, size_t, DataType); diff --git a/src/ext/collectives/include/allreduce/allreduce_rsag.hpp b/src/ext/collectives/include/allreduce/allreduce_rsag.hpp index 6e033f67..1fd663da 100644 --- a/src/ext/collectives/include/allreduce/allreduce_rsag.hpp +++ b/src/ext/collectives/include/allreduce/allreduce_rsag.hpp @@ -19,7 +19,7 @@ class AllreduceRsAg : public mscclpp::AlgorithmBuilder { void initialize(std::shared_ptr comm); CommResult allreduceKernelFunc(const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock, - const std::unordered_map& extras); + const std::unordered_map& extras, DataType accumDtype); std::shared_ptr initAllreduceContext(std::shared_ptr comm, const void*, void* output, size_t, DataType); diff --git 
a/src/ext/collectives/include/allreduce/allreduce_rsag_pipeline.hpp b/src/ext/collectives/include/allreduce/allreduce_rsag_pipeline.hpp index 2a740ac0..7629f2fe 100644 --- a/src/ext/collectives/include/allreduce/allreduce_rsag_pipeline.hpp +++ b/src/ext/collectives/include/allreduce/allreduce_rsag_pipeline.hpp @@ -19,7 +19,7 @@ class AllreduceRsAgPipeline : public mscclpp::AlgorithmBuilder { void initialize(std::shared_ptr comm); CommResult allreduceKernelFunc(const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock, - const std::unordered_map& extras); + const std::unordered_map& extras, DataType accumDtype); std::shared_ptr initAllreduceContext(std::shared_ptr comm, const void*, void* output, size_t, DataType); diff --git a/src/ext/collectives/include/allreduce/allreduce_rsag_zero_copy.hpp b/src/ext/collectives/include/allreduce/allreduce_rsag_zero_copy.hpp index 6153a0e4..05bf2ef3 100644 --- a/src/ext/collectives/include/allreduce/allreduce_rsag_zero_copy.hpp +++ b/src/ext/collectives/include/allreduce/allreduce_rsag_zero_copy.hpp @@ -18,7 +18,7 @@ class AllreduceRsAgZeroCopy : public mscclpp::AlgorithmBuilder { void initialize(std::shared_ptr comm); CommResult allreduceKernelFunc(const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock, - const std::unordered_map& extras); + const std::unordered_map& extras, DataType accumDtype); std::shared_ptr initAllreduceContext(std::shared_ptr comm, const void*, void* output, size_t, DataType); diff --git a/src/ext/collectives/include/allreduce/common.hpp b/src/ext/collectives/include/allreduce/common.hpp index 9bfac69a..1e0e7e69 100644 --- a/src/ext/collectives/include/allreduce/common.hpp +++ b/src/ext/collectives/include/allreduce/common.hpp @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -#ifndef MSCCLPP_ALLREDUCE_COMMOM_HPP_ -#define MSCCLPP_ALLREDUCE_COMMOM_HPP_ +#ifndef MSCCLPP_ALLREDUCE_COMMON_HPP_ +#define MSCCLPP_ALLREDUCE_COMMON_HPP_ #include #include @@ -77,55 +77,51 @@ using AllreduceFunc = mscclpp::DeviceHandle*, size_t, size_t, size_t, int, int, int, size_t, cudaStream_t, void*, uint32_t, uint32_t, int, int)>; -template