merge main

Ubuntu
2026-04-10 23:19:15 +00:00
60 changed files with 2289 additions and 564 deletions

View File

@@ -94,7 +94,27 @@ steps:
du -sh build/bin/* 2>/dev/null || true
workingDirectory: '$(System.DefaultWorkingDirectory)'
# 2. Download SSH key + install packages + start VMSS
# 2. Write CMake args for pip install on remote VMs
- task: Bash@3
name: WritePipCmakeArgs
displayName: Write pip CMake args
inputs:
targetType: 'inline'
script: |
set -e
PIP_CMAKE_ARGS=""
if [ -n "${{ parameters.gpuArch }}" ]; then
PIP_CMAKE_ARGS="-DMSCCLPP_GPU_ARCHS=${{ parameters.gpuArch }}"
fi
CMAKE_EXTRA_ARGS='${{ parameters.cmakeArgs }}'
if [ -n "${CMAKE_EXTRA_ARGS}" ]; then
PIP_CMAKE_ARGS="${PIP_CMAKE_ARGS} ${CMAKE_EXTRA_ARGS}"
fi
echo "${PIP_CMAKE_ARGS}" > pip_cmake_args.txt
echo "pip CMake args: $(cat pip_cmake_args.txt)"
workingDirectory: '$(System.DefaultWorkingDirectory)'
# 3. Download SSH key + install packages + start VMSS
- task: DownloadSecureFile@1
name: SshKeyFile
displayName: Download key file
@@ -120,7 +140,7 @@ steps:
inlineScript: |
az vmss start --name ${{ parameters.vmssName }} --resource-group ${{ parameters.resourceGroup }}
# 3. Deploy test environment
# 4. Deploy test environment
- task: Bash@3
name: DeployTestEnv
displayName: Deploy Test Env

View File

@@ -28,7 +28,7 @@ steps:
grep -q NPKIT_EVENT_EXECUTOR_INIT_ENTRY ./npkit_output/npkit_event_trace.json
grep -q NPKIT_EVENT_EXECUTOR_SIGNAL_ENTRY ./npkit_output/npkit_event_trace.json
grep -q NPKIT_EVENT_EXECUTOR_WAIT_ENTRY ./npkit_output/npkit_event_trace.json
grep -q NPKIT_EVENT_EXECUTOR_READ_REDUCE_COPY_SEND_ENTRY ./npkit_output/npkit_event_trace.json
grep -q NPKIT_EVENT_EXECUTOR_READ_REDUCE_SEND_ENTRY ./npkit_output/npkit_event_trace.json
- template: run-remote-task.yml
parameters:
@@ -42,14 +42,14 @@ steps:
grep -q NPKIT_EVENT_EXECUTOR_INIT_ENTRY ./npkit_output/npkit_event_trace.json
grep -q NPKIT_EVENT_EXECUTOR_SIGNAL_ENTRY ./npkit_output/npkit_event_trace.json
grep -q NPKIT_EVENT_EXECUTOR_WAIT_ENTRY ./npkit_output/npkit_event_trace.json
grep -q NPKIT_EVENT_EXECUTOR_READ_REDUCE_COPY_SEND_ENTRY ./npkit_output/npkit_event_trace.json
grep -q NPKIT_EVENT_EXECUTOR_READ_REDUCE_SEND_ENTRY ./npkit_output/npkit_event_trace.json
rm -rf ./npkit_dump && mkdir ./npkit_dump && rm -rf ./npkit_output && mkdir ./npkit_output
mpirun --allow-run-as-root -tag-output -x MSCCLPP_HOME=/root/mscclpp -np 8 python3 -m pytest ./python/test/test_mscclpp.py -x -k 'test_executor[allreduce_packet.json'
python3 ./tools/npkit/npkit_trace_generator.py --npkit_dump_dir=./npkit_dump --npkit_event_header_path=./include/mscclpp/npkit/npkit_event.hpp --output_dir=./npkit_output
grep -q NPKIT_EVENT_EXECUTOR_INIT_ENTRY ./npkit_output/npkit_event_trace.json
grep -q NPKIT_EVENT_EXECUTOR_COPY_PACKET_ENTRY ./npkit_output/npkit_event_trace.json
grep -q NPKIT_EVENT_EXECUTOR_PUT_PACKET_ENTRY ./npkit_output/npkit_event_trace.json
grep -q NPKIT_EVENT_EXECUTOR_REDUCE_SEND_PACKET_ENTRY ./npkit_output/npkit_event_trace.json
grep -q NPKIT_EVENT_EXECUTOR_PUT_PACKETS_ENTRY ./npkit_output/npkit_event_trace.json
grep -q NPKIT_EVENT_EXECUTOR_REDUCE_SEND_PACKETS_ENTRY ./npkit_output/npkit_event_trace.json
grep -q NPKIT_EVENT_EXECUTOR_UNPACK_PACKETS_ENTRY ./npkit_output/npkit_event_trace.json
- template: stop.yml
parameters:

View File

@@ -41,6 +41,7 @@ steps:
displayName: Run pytests
remoteScript: |
mpirun --allow-run-as-root -tag-output -x MSCCLPP_HOME=/root/mscclpp -x GPU_MAX_HW_QUEUES=8 -np 8 python3 -m pytest ./python/test/test_mscclpp.py -x
mpirun --allow-run-as-root -tag-output -x MSCCLPP_HOME=/root/mscclpp -x GPU_MAX_HW_QUEUES=8 -np 8 python3 -m pytest ./python/test/test_fp8_accum.py -x
- template: stop.yml
parameters:

View File

@@ -1 +1 @@
0.8.0
0.9.0

View File

@@ -30,19 +30,15 @@ find_library(GDRCOPY_LIBRARIES
${GDRCOPY_ROOT_DIR}/lib
/usr/local/lib
/usr/lib
/usr/lib/x86_64-linux-gnu
/usr/lib/aarch64-linux-gnu)
/usr/lib/x86_64-linux-gnu)
if(GDRCOPY_INCLUDE_DIRS)
include(CheckCXXSourceCompiles)
include(CheckSymbolExists)
set(CMAKE_REQUIRED_INCLUDES ${GDRCOPY_INCLUDE_DIRS})
set(CMAKE_REQUIRED_LIBRARIES ${GDRCOPY_LIBRARIES})
check_cxx_source_compiles("
#include <gdrapi.h>
int main() { gdr_pin_buffer_v2(0, 0, 0, 0, 0); return 0; }
" GDRCOPY_HAS_PIN_BUFFER_V2)
unset(CMAKE_REQUIRED_INCLUDES)
check_symbol_exists(gdr_pin_buffer_v2 "gdrapi.h" GDRCOPY_HAS_PIN_BUFFER_V2)
unset(CMAKE_REQUIRED_LIBRARIES)
unset(CMAKE_REQUIRED_INCLUDES)
if(NOT GDRCOPY_HAS_PIN_BUFFER_V2)
message(STATUS "GDRCopy found but too old (gdr_pin_buffer_v2 not available). Requires >= 2.5.")
set(GDRCOPY_INCLUDE_DIRS GDRCOPY_INCLUDE_DIRS-NOTFOUND)

View File

@@ -52,7 +52,7 @@ RUN OS_ARCH=$(uname -m) && \
# Install GDRCopy userspace library for CUDA targets
ARG TARGET="cuda13.0"
RUN if echo "$TARGET" | grep -q "^cuda"; then \
GDRCOPY_VERSION="2.5.1" && \
GDRCOPY_VERSION="2.5.2" && \
apt-get update -y && \
apt-get install -y --no-install-recommends devscripts debhelper fakeroot pkg-config dkms && \
cd /tmp && \

View File

@@ -5,7 +5,7 @@
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SPHINXMULTIVERSION ?= sphinx-multiversion
SPHINXMULTIVERSION ?= python3 build_multiversion.py
SOURCEDIR = .
BUILDDIR = _build

View File

@@ -0,0 +1,49 @@
#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Wrapper around sphinx-multiversion that patches copy_tree to generate
_version.py in each tag checkout. This is needed because setuptools_scm
generates _version.py at build time, but sphinx-multiversion uses
`git archive` which only contains committed files.
Usage (called by Makefile):
python3 build_multiversion.py <sourcedir> <outputdir> [sphinx-opts...]
"""
import os
import re
import subprocess
import sys
import sphinx_multiversion.git as smv_git
from sphinx_multiversion import main as smv_main
# Save the original copy_tree
_original_copy_tree = smv_git.copy_tree
def _patched_copy_tree(gitroot, src, dst, reference, sourcepath="."):
"""Call original copy_tree, then generate _version.py from the VERSION file."""
_original_copy_tree(gitroot, src, dst, reference, sourcepath)
# Extract version from the tag name (e.g., "v0.9.0" -> "0.9.0")
refname = getattr(reference, "refname", "") or ""
match = re.search(r"v(\d+\.\d+\.\d+)", refname)
if not match:
return
version = match.group(1)
version_py_dir = os.path.join(dst, "python", "mscclpp")
if os.path.isdir(version_py_dir):
version_py = os.path.join(version_py_dir, "_version.py")
if not os.path.exists(version_py):
with open(version_py, "w") as f:
f.write(f'__version__ = "{version}"\n')
# Monkey-patch
smv_git.copy_tree = _patched_copy_tree
if __name__ == "__main__":
sys.exit(smv_main(sys.argv[1:]))

View File

@@ -332,7 +332,8 @@ public:
size_t inputSize, size_t outputSize,
mscclpp::DataType dtype, mscclpp::ReduceOp op,
cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
const std::unordered_map<std::string, uintptr_t>& extras) {
const std::unordered_map<std::string, uintptr_t>& extras,
[[maybe_unused]] mscclpp::DataType accumDtype) {
return self->kernelFunc(ctx, input, output, inputSize, dtype, stream);
},
// Context initialization function

View File

@@ -31,7 +31,7 @@
```
If you don't want to build the Python module, set `-DMSCCLPP_BUILD_PYTHON_BINDINGS=OFF` in your `cmake` command (see details in [Install from Source](#install-from-source)).
* (Optional, for benchmarks) MPI
* (Optional, for NVIDIA platforms) [GDRCopy](https://github.com/NVIDIA/gdrcopy) >= 2.5.0
* (Optional, for NVIDIA platforms) [GDRCopy](https://github.com/NVIDIA/gdrcopy) >= 2.5.1
* GDRCopy is required for IB `HostNoAtomic` mode, which uses CPU-side signal forwarding to GPU memory via BAR1 mappings. This mode is used on platforms where RDMA atomics are not available (e.g., when using Data Direct Virtual Functions).
* Install GDRCopy from source or via packages. See the [GDRCopy installation guide](https://github.com/NVIDIA/gdrcopy#installation).
* Others

View File

@@ -101,7 +101,8 @@ class AllgatherAlgoBuilder : public mscclpp::AlgorithmBuilder {
"allgather", "allgather", [self](std::shared_ptr<mscclpp::Communicator> comm) { self->initialize(comm); },
[self](const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize, size_t outputSize,
mscclpp::DataType dtype, [[maybe_unused]] mscclpp::ReduceOp op, cudaStream_t stream, int nBlocks,
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras) {
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras,
[[maybe_unused]] mscclpp::DataType accumDtype) {
return self->allgatherKernelFunc(ctx, input, output, inputSize, stream);
},
[self](std::shared_ptr<mscclpp::Communicator> comm, const void* input, void* output, size_t inputSize,

View File

@@ -69,7 +69,8 @@ class AllgatherAlgoBuilder : public mscclpp::AlgorithmBuilder {
"allgather", "allgather", [self](std::shared_ptr<mscclpp::Communicator> comm) { self->initialize(comm); },
[self](const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize, size_t outputSize,
mscclpp::DataType dtype, [[maybe_unused]] mscclpp::ReduceOp op, cudaStream_t stream, int nBlocks,
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras) {
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras,
[[maybe_unused]] mscclpp::DataType accumDtype) {
return self->allgatherKernelFunc(ctx, input, output, inputSize, dtype, stream);
},
[self](std::shared_ptr<mscclpp::Communicator> comm, const void* input, void* output, size_t inputSize,

View File

@@ -1,193 +1,117 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# MSCCLPP_MASTER_ADDR=<master_ip> MSCCLPP_MASTER_PORT=<port> torchrun --nnodes=1 --nproc_per_node=8 customized_comm_with_tuning.py
# torchrun --nnodes=1 --nproc_per_node=8 examples/torch-integration/customized_comm_with_tuning.py
import os
import torch
import mscclpp.utils as mscclpp_utils
import mscclpp
import mscclpp.ext
import netifaces as ni
import ipaddress
import netifaces as ni
import torch
import mscclpp
import mscclpp.ext
import mscclpp.utils as mscclpp_utils
def load_algorithms(scratch_buffer: torch.tensor, rank: int) -> mscclpp.AlgorithmCollection:
collection_builder = mscclpp.ext.AlgorithmCollectionBuilder()
return collection_builder.build_default_algorithms(
scratch_buffer=scratch_buffer.data_ptr(), scratch_buffer_size=scratch_buffer.nbytes, rank=rank
# -- Helpers ------------------------------------------------------------------
def _make_tensor(size_bytes: int, dtype: torch.dtype) -> torch.Tensor:
"""Allocate a tensor backed by RawGpuBuffer (symmetric memory)."""
# PyTorch's from_dlpack does not support certain float8 DLPack type codes.
# Work around by importing as uint8 and reinterpreting via .view().
_DLPACK_UNSUPPORTED = (torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz)
if dtype in _DLPACK_UNSUPPORTED:
dlpack = mscclpp.RawGpuBuffer(size_bytes).to_dlpack(data_type=str(torch.uint8))
return torch.utils.dlpack.from_dlpack(dlpack).view(dtype)
dlpack = mscclpp.RawGpuBuffer(size_bytes).to_dlpack(data_type=str(dtype))
return torch.utils.dlpack.from_dlpack(dlpack)
def _load_algorithms(scratch: torch.Tensor, rank: int):
return mscclpp.ext.AlgorithmCollectionBuilder().build_default_algorithms(
scratch_buffer=scratch.data_ptr(),
scratch_buffer_size=scratch.nbytes,
rank=rank,
)
def interfaces_for_ip_netifaces(ip: str):
def _interfaces_for_ip(ip: str):
target = ipaddress.ip_address(ip)
for interface in ni.interfaces():
addresses = ni.ifaddresses(interface)
if ni.AF_INET in addresses:
for link in addresses[ni.AF_INET]:
if "addr" in link:
addr = ipaddress.ip_address(link["addr"])
if addr == target:
return interface
for iface in ni.interfaces():
addrs = ni.ifaddresses(iface)
if ni.AF_INET in addrs:
for link in addrs[ni.AF_INET]:
if "addr" in link and ipaddress.ip_address(link["addr"]) == target:
return iface
return None
def to_mscclpp_reduce_op(op: torch.distributed.ReduceOp) -> mscclpp.ReduceOp:
def _to_mscclpp_op(op) -> mscclpp.ReduceOp:
if op == torch.distributed.ReduceOp.SUM:
return mscclpp.ReduceOp.SUM
elif op == torch.distributed.ReduceOp.MIN:
if op == torch.distributed.ReduceOp.MIN:
return mscclpp.ReduceOp.MIN
else:
raise ValueError(f"unsupported op: {op}")
raise ValueError(f"unsupported op: {op}")
def _round_pow2(size: int) -> int:
"""Round up to next power-of-2, clamped to [1024, 256 MB]."""
size = max(size, 1024)
size = min(size, 256 << 20)
return 1 << (size - 1).bit_length()
# -- CustomizedComm -----------------------------------------------------------
class CustomizedComm:
def __init__(self, comm: mscclpp.CommGroup):
"""Exposes all_reduce, all_gather, barrier with lazy per-size tuning."""
_TUNE_N_WARMUP = 5
_TUNE_N_GRAPH_LAUNCHES = 10
_TUNE_N_OPS_PER_GRAPH = 100
_CANDIDATE_NBLOCKS = [4, 8, 16, 24, 32, 48, 64, 128]
_CANDIDATE_NTHREADS = [512, 768, 1024]
_NBLOCKS_LIMIT = {
"default_allreduce_nvls_packet": 16,
"default_allreduce_packet": 56,
"default_allreduce_allpair_packet": 56,
"default_allreduce_fullmesh": 64,
"default_allgather_fullmesh2": 32,
}
def __init__(self, comm: mscclpp.CommGroup, symmetric_memory: bool = False):
self.comm = comm
self.rank = comm.my_rank
self.world_size = comm.nranks
self.local_rank = comm.my_rank % comm.nranks_per_node
self.n_ranks_per_node = comm.nranks_per_node
dlpack = mscclpp.RawGpuBuffer(1 << 27).to_dlpack(data_type=str(torch.float16))
self.scratch_buffer = torch.utils.dlpack.from_dlpack(dlpack)
algorithms = load_algorithms(scratch_buffer=self.scratch_buffer, rank=self.rank)
self._algorithm_nvls_packet = [
algo
for algo in algorithms
if algo.collective == "allreduce" and algo.name == "default_allreduce_nvls_packet"
][0]
self._algorithm_rsag_zero_copy = [
algo
for algo in algorithms
if algo.collective == "allreduce" and algo.name == "default_allreduce_rsag_zero_copy"
][0]
self._algorithm_packet = [
algo for algo in algorithms if algo.collective == "allreduce" and algo.name == "default_allreduce_packet"
][0]
if mscclpp.is_nvls_supported():
self._algorithm_nvls_zero_copy = [
algo
for algo in algorithms
if algo.collective == "allreduce" and algo.name == "default_allreduce_nvls_zero_copy"
][0]
self._tune(n_warmup=5, n_graph_launches=10, n_ops_per_graph=100)
self.symmetric_memory = symmetric_memory
self._nvls = mscclpp.is_nvls_supported()
def _tune(self, n_warmup, n_graph_launches, n_ops_per_graph):
sizes = [1 << i for i in range(10, 28)]
# Pre-fill with defaults for barrier
self.best_configs = {1024: (self._algorithm_nvls_packet, 0, 0)}
self._scratch = _make_tensor(1 << 27, torch.float16)
self._barrier_tensor = _make_tensor(4096, torch.float32)
tune_tensor = mscclpp.RawGpuBuffer(1 << 27).to_dlpack(data_type=str(torch.float16))
tune_tensor = torch.utils.dlpack.from_dlpack(tune_tensor)
tune_tensor.normal_()
candidates_nblocks = [4, 8, 16, 24, 32, 48, 64, 128]
candidates_nthreads = [512, 768, 1024]
algos = _load_algorithms(self._scratch, self.rank)
self._algos = {(a.collective, a.name): a for a in algos}
for size in sizes:
algos = []
if mscclpp.is_nvls_supported():
algos.append(self._algorithm_nvls_zero_copy)
if size <= 4 * 1024 * 1024:
algos.append(self._algorithm_nvls_packet)
algos.append(self._algorithm_packet)
if size >= 512 * 1024:
algos.append(self._algorithm_rsag_zero_copy)
# {collective: {rounded_size: (algo, nblocks, nthreads)}}
self._tune_cache: dict[str, dict[int, tuple]] = {"allreduce": {}, "allgather": {}}
self._tune_buf = None
self._time_buf = None
best_time = float("inf")
best_config = None
def _algo(self, collective: str, name: str):
return self._algos.get((collective, name))
for algo in algos:
for nb in candidates_nblocks:
if algo.name == "default_allreduce_nvls_packet" and nb > 16:
continue
if algo.name == "default_allreduce_packet" and nb > 56:
continue
for nt in candidates_nthreads:
if self._run_algo(algo, tune_tensor, size, nb, nt) != 0:
continue
def _default_ar_config(self):
"""Fallback allreduce config for barrier / timing sync."""
pkt = self._algo("allreduce", "default_allreduce_nvls_packet")
if self._nvls and pkt:
return (pkt, 0, 0)
return (self._algo("allreduce", "default_allreduce_packet"), 0, 0)
for _ in range(n_warmup):
self._run_algo(algo, tune_tensor, size, nb, nt)
self.barrier()
# -- low-level execute --
capture_stream = torch.cuda.Stream()
capture_stream.wait_stream(torch.cuda.current_stream())
g = torch.cuda.CUDAGraph()
# Warmup on capture stream
with torch.cuda.stream(capture_stream):
self._run_algo(algo, tune_tensor, size, nb, nt)
capture_stream.synchronize()
with torch.cuda.graph(g, stream=capture_stream):
for _ in range(n_ops_per_graph):
self._run_algo(algo, tune_tensor, size, nb, nt)
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record(capture_stream)
with torch.cuda.stream(capture_stream):
for _ in range(n_graph_launches):
g.replay()
end_event.record(capture_stream)
end_event.synchronize()
elapsed = start_event.elapsed_time(end_event)
# Synchronize timing results across all ranks to ensure consistent algorithm selection
# replicate n times due to algo limitations
time_tensor = torch.full((self.world_size,), elapsed, dtype=torch.float64, device="cuda").to(
dtype=torch.float32
)
torch.cuda.current_stream().wait_stream(capture_stream)
# TODO: use all_reduce may cause problem if the time elapsed between different algos are too close.
# May change to broadcast in the future if that becomes an issue.
self.all_reduce(time_tensor, op=torch.distributed.ReduceOp.SUM)
avg_time = time_tensor[self.rank].item() / self.world_size
if avg_time < best_time:
best_time = avg_time
best_config = (algo, nb, nt)
if best_config:
self.best_configs[size] = best_config
if self.rank == 0:
print(
f"Size {size}: Best Algo {best_config[0].name} nblocks {best_config[1]} nthreads {best_config[2]} Time {(best_time/(n_graph_launches * n_ops_per_graph))*1000:.2f} us"
)
# reset the algorithms after tuning
torch.cuda.synchronize()
for algo in algos:
algo.reset()
def _run_algo(self, algo: mscclpp.Algorithm, tensor, size, nblocks, nthreads):
return algo.execute(
comm=self.comm.communicator,
input_buffer=tensor.data_ptr(),
output_buffer=tensor.data_ptr(),
input_size=size,
output_size=size,
dtype=mscclpp_utils.torch_dtype_to_mscclpp_dtype(tensor.dtype),
op=mscclpp.ReduceOp.SUM,
stream=torch.cuda.current_stream().cuda_stream,
nblocks=nblocks,
nthreads_per_block=nthreads,
symmetric_memory=True,
)
def get_tuned_config(self, size):
if size < 1024:
target_size = 1024
elif size > 256 * 1024 * 1024:
target_size = 256 * 1024 * 1024
else:
target_size = 1 << (size - 1).bit_length()
return self.best_configs.get(target_size)
def all_reduce(self, tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM, stream: torch.cuda.Stream = None):
assert op == torch.distributed.ReduceOp.SUM
config = self.get_tuned_config(tensor.nbytes)
algo, nblocks, nthreads = config if config else (self._algorithm_nvls_packet, 0, 0)
def _exec_ar(self, tensor, algo, nb, nt, op=mscclpp.ReduceOp.SUM, stream=None, accum_dtype=None, sym=True):
s = stream.cuda_stream if stream else torch.cuda.current_stream().cuda_stream
ret = algo.execute(
comm=self.comm.communicator,
input_buffer=tensor.data_ptr(),
@@ -195,107 +119,357 @@ class CustomizedComm:
input_size=tensor.nbytes,
output_size=tensor.nbytes,
dtype=mscclpp_utils.torch_dtype_to_mscclpp_dtype(tensor.dtype),
op=to_mscclpp_reduce_op(op),
stream=stream.cuda_stream if stream is not None else torch.cuda.current_stream().cuda_stream,
nblocks=nblocks,
nthreads_per_block=nthreads,
symmetric_memory=True,
op=op,
stream=s,
nblocks=nb,
nthreads_per_block=nt,
symmetric_memory=sym,
accum_dtype=accum_dtype,
)
if ret != 0:
print(f"Rank {self.rank}: Algo {algo.name} failed with error {ret}")
print(f"Rank {self.rank}: {algo.name} failed ({ret})")
return ret
def _exec_ag(self, inp, out, algo, nb, nt, stream=None, sym=None):
if sym is None:
sym = self.symmetric_memory
s = stream.cuda_stream if stream else torch.cuda.current_stream().cuda_stream
ret = algo.execute(
comm=self.comm.communicator,
input_buffer=inp.data_ptr(),
output_buffer=out.data_ptr(),
input_size=inp.nbytes,
output_size=out.nbytes,
dtype=mscclpp_utils.torch_dtype_to_mscclpp_dtype(inp.dtype),
op=mscclpp.ReduceOp.NOP,
stream=s,
nblocks=nb,
nthreads_per_block=nt,
symmetric_memory=sym,
)
if ret != 0:
print(f"Rank {self.rank}: AG {algo.name} failed ({ret})")
return ret
def _barrier_internal(self):
a, nb, nt = self._default_ar_config()
self._exec_ar(self._barrier_tensor, a, nb, nt, sym=True)
# -- lazy tuning --
def _ensure_tune_bufs(self):
if self._tune_buf is None:
self._tune_buf = _make_tensor(1 << 27, torch.float16)
self._tune_buf.normal_()
self._time_buf = _make_tensor(4096, torch.float32)
return self._tune_buf
def _ar_candidates(self, size: int):
out = []
if size <= 4 << 20:
a = self._algo("allreduce", "default_allreduce_nvls_packet")
if self._nvls and a:
out.append(a)
a = self._algo("allreduce", "default_allreduce_packet")
if a:
out.append(a)
a = self._algo("allreduce", "default_allreduce_allpair_packet")
if a:
out.append(a)
if size >= 512 << 10:
a = self._algo("allreduce", "default_allreduce_nvls_zero_copy")
if self._nvls and self.symmetric_memory and a:
out.append(a)
a = self._algo("allreduce", "default_allreduce_rsag_zero_copy")
if a:
out.append(a)
if torch.version.hip is not None:
a = self._algo("allreduce", "default_allreduce_fullmesh")
if a:
out.append(a)
return out
def _ag_candidates(self):
a = self._algo("allgather", "default_allgather_fullmesh2")
return [a] if a else []
def _run_tune(self, collective, algo, buf, size, nb, nt):
"""Single tune invocation for either collective."""
if collective == "allreduce":
return algo.execute(
comm=self.comm.communicator,
input_buffer=buf.data_ptr(),
output_buffer=buf.data_ptr(),
input_size=size,
output_size=size,
dtype=mscclpp_utils.torch_dtype_to_mscclpp_dtype(buf.dtype),
op=mscclpp.ReduceOp.SUM,
stream=torch.cuda.current_stream().cuda_stream,
nblocks=nb,
nthreads_per_block=nt,
symmetric_memory=True,
)
else:
total = size * self.world_size
out_ptr = buf.data_ptr()
return algo.execute(
comm=self.comm.communicator,
input_buffer=out_ptr + self.rank * size,
output_buffer=out_ptr,
input_size=size,
output_size=total,
dtype=mscclpp_utils.torch_dtype_to_mscclpp_dtype(buf.dtype),
op=mscclpp.ReduceOp.NOP,
stream=torch.cuda.current_stream().cuda_stream,
nblocks=nb,
nthreads_per_block=nt,
symmetric_memory=False,
)
def _tune_size(self, collective: str, target_size: int):
"""Auto-tune one (collective, target_size) pair and cache result."""
buf = self._ensure_tune_bufs()
cands = self._ar_candidates(target_size) if collective == "allreduce" else self._ag_candidates()
best_time, best_cfg = float("inf"), None
used = set()
run = lambda a, nb, nt: self._run_tune(collective, a, buf, target_size, nb, nt)
for algo in cands:
nb_limit = self._NBLOCKS_LIMIT.get(algo.name, 128)
for nb in self._CANDIDATE_NBLOCKS:
if nb > nb_limit:
continue
for nt in self._CANDIDATE_NTHREADS:
# Feasibility — sync result across ranks so all agree
ret = run(algo, nb, nt)
torch.cuda.synchronize()
self._time_buf[0] = float(ret)
self._exec_ar(self._time_buf[:1], *self._default_ar_config(), sym=True)
if self._time_buf[0].item() != 0:
continue
used.add(algo)
# Warmup
for _ in range(self._TUNE_N_WARMUP):
run(algo, nb, nt)
# CUDA-graph timed benchmark
cs = torch.cuda.Stream()
cs.wait_stream(torch.cuda.current_stream())
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g, stream=cs):
for _ in range(self._TUNE_N_OPS_PER_GRAPH):
run(algo, nb, nt)
start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
start.record(cs)
with torch.cuda.stream(cs):
for _ in range(self._TUNE_N_GRAPH_LAUNCHES):
g.replay()
end.record(cs)
end.synchronize()
elapsed = start.elapsed_time(end)
# Cross-rank timing sync
self._time_buf.fill_(elapsed)
torch.cuda.current_stream().wait_stream(cs)
self._exec_ar(self._time_buf, *self._default_ar_config(), sym=True)
avg = self._time_buf[self.rank].item() / self.world_size
if avg < best_time:
best_time, best_cfg = avg, (algo, nb, nt)
if best_cfg:
self._tune_cache[collective][target_size] = best_cfg
if self.rank == 0:
n = self._TUNE_N_GRAPH_LAUNCHES * self._TUNE_N_OPS_PER_GRAPH
print(
f"[tune] {collective} size={target_size}: {best_cfg[0].name} "
f"nb={best_cfg[1]} nt={best_cfg[2]} time={best_time / n * 1000:.2f}us",
flush=True,
)
else:
fb = (
self._default_ar_config()
if collective == "allreduce"
else ((self._ag_candidates()[0], 32, 512) if self._ag_candidates() else None)
)
self._tune_cache[collective][target_size] = fb
torch.cuda.synchronize()
self._barrier_internal()
for a in used:
a.reset()
# -- public API --
def all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, stream=None, accum_dtype=None):
sz = _round_pow2(tensor.nbytes)
if sz not in self._tune_cache["allreduce"]:
self._tune_size("allreduce", sz)
a, nb, nt = self._tune_cache["allreduce"][sz]
self._exec_ar(
tensor, a, nb, nt, op=_to_mscclpp_op(op), stream=stream, accum_dtype=accum_dtype, sym=self.symmetric_memory
)
def all_gather(self, output_tensor, input_tensor, stream=None):
sz = _round_pow2(input_tensor.nbytes)
if sz not in self._tune_cache["allgather"]:
self._tune_size("allgather", sz)
a, nb, nt = self._tune_cache["allgather"][sz]
self._exec_ag(input_tensor, output_tensor, a, nb, nt, stream=stream, sym=self.symmetric_memory)
def barrier(self):
tensor = torch.empty(self.world_size, dtype=torch.float, device=torch.device("cuda"))
self.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM, stream=torch.cuda.current_stream())
def benchmark(self, n_warmup=10, n_graph_launches=10, n_iter_per_graph=100):
low = 5 * 1024
high = 80 * 1024 * 1024
sizes = []
curr = low
while curr <= high:
sizes.append(curr)
curr *= 2
if self.rank == 0:
print(f"{'Size (Bytes)':<20} {'Time (us)':<20} {'AlgoBW (GB/s)':<20}")
dtype = torch.float16
capture_stream = torch.cuda.Stream()
# Allocate a single large RawGpuBuffer (symmetric memory) and reuse it for all sizes.
# Cannot allocate per-size tensors with symmetric memory.
bench_buf = mscclpp.RawGpuBuffer(1 << 27).to_dlpack(data_type=str(dtype))
bench_buf = torch.utils.dlpack.from_dlpack(bench_buf)
bench_buf.normal_()
for size in sizes:
n_elements = size // bench_buf.element_size()
tensor = bench_buf[:n_elements]
capture_stream.wait_stream(torch.cuda.current_stream())
# Capture Graph
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g, stream=capture_stream):
for _ in range(n_iter_per_graph):
self.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM)
# warmup: Execute the graph once to prime the driver
with torch.cuda.stream(capture_stream):
for _ in range(n_warmup):
g.replay()
self.barrier()
capture_stream.synchronize()
# Benchmark
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record(capture_stream)
with torch.cuda.stream(capture_stream):
for _ in range(n_graph_launches):
g.replay()
end_event.record(capture_stream)
end_event.synchronize()
# Get elapsed time in milliseconds
elapsed_ms = start_event.elapsed_time(end_event)
avg_time_ms = elapsed_ms / (n_graph_launches * n_iter_per_graph)
time_us = avg_time_ms * 1000
alg_bw = size / (avg_time_ms * 1e-3) if avg_time_ms > 0 else 0
if self.rank == 0:
print(f"{size:<20} {time_us:<20.2f} {alg_bw / 1e9:<20.2f}")
self._barrier_internal()
def destroy(self):
self._algorithm_nvls_nonzero_copy = None
self._algorithm_nvls_packet = None
self.scratch_buffer = None
self.comm = None
self._algos.clear()
self._tune_cache = {"allreduce": {}, "allgather": {}}
self._tune_buf = self._time_buf = self._barrier_tensor = self._scratch = self.comm = None
def init_dist() -> CustomizedComm:
rank = int(os.environ["RANK"])
world = int(os.environ["WORLD_SIZE"])
master_addr = os.environ["MSCCLPP_MASTER_ADDR"]
master_port = os.environ["MSCCLPP_MASTER_PORT"]
interface = interfaces_for_ip_netifaces(master_addr)
if interface is None:
raise ValueError(f"Cannot find network interface for IP address {master_addr}")
interfaceIpPortTrio = f"{interface}:{master_addr}:{master_port}"
mscclpp_group = mscclpp.CommGroup(interfaceIpPortTrio=interfaceIpPortTrio, rank=rank, size=world)
return CustomizedComm(mscclpp_group)
# -- Benchmarks (standalone) --------------------------------------------------
def _bench_sizes(low=5 * 1024, high=80 << 20):
sizes, c = [], low
while c <= high:
sizes.append(c)
c *= 2
return sizes
def benchmark_allreduce(
comm: CustomizedComm, dtype=torch.float16, accum_dtype=None, n_warmup=10, n_graph_launches=10, n_iter=100
):
sizes = _bench_sizes()
if comm.rank == 0:
print(f"\n{'='*60}\nAllreduce Benchmark\n{'='*60}")
print(f"{'Nelements':<18} {'Size(B)':<18} {'Time(us)':<18} {'AlgoBW(GB/s)':<18}")
cs = torch.cuda.Stream()
buf = _make_tensor(1 << 27, dtype)
buf.normal_() if dtype in (torch.float16, torch.float32, torch.bfloat16) else buf.fill_(0)
for size in sizes:
nelems = size // buf.element_size()
t = buf[: size // buf.element_size()]
comm.all_reduce(t, accum_dtype=accum_dtype)
torch.cuda.synchronize()
cs.wait_stream(torch.cuda.current_stream())
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g, stream=cs):
for _ in range(n_iter):
comm.all_reduce(t, accum_dtype=accum_dtype)
with torch.cuda.stream(cs):
for _ in range(n_warmup):
g.replay()
comm.barrier()
cs.synchronize()
s, e = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
s.record(cs)
with torch.cuda.stream(cs):
for _ in range(n_graph_launches):
g.replay()
e.record(cs)
e.synchronize()
ms = s.elapsed_time(e) / (n_graph_launches * n_iter)
if comm.rank == 0:
print(f"{nelems:<18} {size:<18} {ms*1000:<18.2f} {size/(ms*1e-3)/1e9:<18.2f}")
def benchmark_allgather(comm: CustomizedComm, dtype=torch.float16, n_warmup=10, n_graph_launches=10, n_iter=100):
sizes = _bench_sizes()
if comm.rank == 0:
print(f"\n{'='*60}\nAllgather Benchmark\n{'='*60}")
print(f"{'PerRank(B)':<18} {'Total(B)':<18} {'Time(us)':<18} {'AlgoBW(GB/s)':<18}")
cs = torch.cuda.Stream()
buf = _make_tensor(1 << 27, dtype)
buf.normal_() if dtype in (torch.float16, torch.float32, torch.bfloat16) else buf.fill_(0)
for prs in sizes:
total = prs * comm.world_size
if total > buf.nbytes:
break
nt = total // buf.element_size()
npr = prs // buf.element_size()
out = buf[:nt]
inp = out[comm.rank * npr : (comm.rank + 1) * npr]
comm.all_gather(out, inp)
torch.cuda.synchronize()
cs.wait_stream(torch.cuda.current_stream())
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g, stream=cs):
for _ in range(n_iter):
comm.all_gather(out, inp)
with torch.cuda.stream(cs):
for _ in range(n_warmup):
g.replay()
comm.barrier()
cs.synchronize()
s, e = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
s.record(cs)
with torch.cuda.stream(cs):
for _ in range(n_graph_launches):
g.replay()
e.record(cs)
e.synchronize()
ms = s.elapsed_time(e) / (n_graph_launches * n_iter)
if comm.rank == 0:
print(f"{prs:<18} {total:<18} {ms*1000:<18.2f} {total/(ms*1e-3)/1e9:<18.2f}")
# -- Bootstrap & main ---------------------------------------------------------
def init_dist() -> mscclpp.CommGroup:
addr = os.environ.get("MSCCLPP_MASTER_ADDR")
if addr:
rank, world = int(os.environ["RANK"]), int(os.environ["WORLD_SIZE"])
port = os.environ["MSCCLPP_MASTER_PORT"]
iface = _interfaces_for_ip(addr)
if not iface:
raise ValueError(f"No interface for {addr}")
return mscclpp.CommGroup(interfaceIpPortTrio=f"{iface}:{addr}:{port}", rank=rank, size=world)
import torch.distributed as dist
dist.init_process_group(backend="gloo")
return mscclpp.CommGroup(torch_group=dist.group.WORLD)
def main():
local = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local)
comm = init_dist()
comm.benchmark(n_warmup=5, n_graph_launches=10, n_iter_per_graph=100)
comm.barrier()
dtype_str = os.environ.get("DTYPE", "float16")
dtype = getattr(torch, dtype_str, torch.float16)
accum_map = {"float32": mscclpp.DataType.float32, "float16": mscclpp.DataType.float16}
accum_str = os.environ.get("ACCUM_DTYPE")
accum_dtype = accum_map.get(accum_str) if accum_str else None
comm_group = init_dist()
cc = CustomizedComm(comm_group)
print(f"rank {local} starting benchmarks with dtype={dtype} accum_dtype={accum_dtype}...")
benchmark_allreduce(cc, dtype=dtype, accum_dtype=accum_dtype)
cc.barrier()
torch.cuda.synchronize()
comm.destroy()
print(f"rank {local} All-reduce operation completed successfully.")
benchmark_allgather(cc, dtype=dtype)
cc.barrier()
torch.cuda.synchronize()
cc.destroy()
print(f"rank {local} completed successfully.")
if __name__ == "__main__":

View File

@@ -1,19 +1,20 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# LD_PRELOAD=<MSCCLPP_REPO>/build/lib/nccl/libmscclpp_nccl.so torchrun --nnodes=1 --nproc_per_node=8 dsl_with_nccl_api.py
# LD_PRELOAD=<MSCCLPP_REPO>/build/lib/libmscclpp_nccl.so torchrun --nnodes=1 --nproc_per_node=8 dsl_with_nccl_api.py
import os
from typing import Any, Dict
import torch, torch.distributed as dist
import mscclpp
import mscclpp.ext
from mscclpp.language.collectives import AllReduce
from mscclpp.language.channel import SwitchChannel, MemoryChannel, BufferType, SyncType
from mscclpp.language.program import CollectiveProgram
from mscclpp.language.rank import Rank
from mscclpp.language.utils import AlgoSpec
def allreduce_nvls(spec: mscclpp.AlgoSpec) -> CollectiveProgram:
def allreduce_nvls(spec: AlgoSpec) -> CollectiveProgram:
gpu_size = spec.world_size
with CollectiveProgram.from_spec(spec) as program:
# Creating Channels
@@ -63,8 +64,8 @@ def allreduce_nvls(spec: mscclpp.AlgoSpec) -> CollectiveProgram:
return program
def setup_plan(algo_collection_builder: mscclpp.AlgorithmCollectionBuilder, rank: int, world_size: int):
spec = mscclpp.AlgoSpec(
def setup_plan(algo_collection_builder: mscclpp.ext.AlgorithmCollectionBuilder, rank: int, world_size: int):
spec = AlgoSpec(
name="allreduce_nvls",
collective=AllReduce(8, 1, True),
nranks_per_node=8,
@@ -94,10 +95,10 @@ def init_dist():
rank = int(os.environ["RANK"])
world = int(os.environ["WORLD_SIZE"])
local = int(os.environ["LOCAL_RANK"])
algorithm_collection_builder = mscclpp.AlgorithmCollectionBuilder()
algorithm_collection_builder = mscclpp.ext.AlgorithmCollectionBuilder()
setup_plan(algorithm_collection_builder, rank, world)
algorithm_collection_builder.set_algorithm_selector(selector)
dist.init_process_group(backend="nccl", device_id=local)
dist.init_process_group(backend="nccl", device_id=torch.device("cuda", local))
return rank, world, local

View File

@@ -103,12 +103,14 @@ class Algorithm {
/// @param nThreadsPerBlock Number of threads per block (0 for auto-selection).
/// @param symmetricMemory Whether to use symmetric memory optimization.
/// @param extras Additional parameters for algorithm-specific customization.
/// @param accumDtype Data type for accumulation during reduction. DataType::AUTO resolves to dtype.
/// @return The result of the operation.
virtual CommResult execute(std::shared_ptr<Communicator> comm, const void* input, void* output, size_t inputSize,
size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream,
std::shared_ptr<Executor> executor, int nBlocks = 0, int nThreadsPerBlock = 0,
bool symmetricMemory = false,
const std::unordered_map<std::string, uintptr_t>& extras = {}) = 0;
const std::unordered_map<std::string, uintptr_t>& extras = {},
DataType accumDtype = DataType::AUTO) = 0;
/// Reset the algorithm state, clearing any cached contexts.
virtual void reset() = 0;
@@ -186,10 +188,11 @@ class NativeAlgorithm : public Algorithm {
/// @param nBlocks Number of CUDA blocks.
/// @param nThreadsPerBlock Number of threads per block.
/// @param extras Additional algorithm-specific parameters.
/// @param accumDtype Data type for accumulation (resolved from input dtype if sentinel).
/// @return The result of the operation.
using KernelFunc =
std::function<CommResult(const std::shared_ptr<void>, const void*, void*, size_t, size_t, DataType, ReduceOp,
cudaStream_t, int, int, const std::unordered_map<std::string, uintptr_t>&)>;
cudaStream_t, int, int, const std::unordered_map<std::string, uintptr_t>&, DataType)>;
/// Function type for creating algorithm contexts.
/// @param comm The communicator.
@@ -233,8 +236,8 @@ class NativeAlgorithm : public Algorithm {
CommResult execute(std::shared_ptr<Communicator> comm, const void* input, void* output, size_t inputSize,
size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream,
std::shared_ptr<Executor> executor, int nBlocks = 0, int nThreadsPerBlock = 0,
bool symmetricMemory = false,
const std::unordered_map<std::string, uintptr_t>& extras = {}) override;
bool symmetricMemory = false, const std::unordered_map<std::string, uintptr_t>& extras = {},
DataType accumDtype = DataType::AUTO) override;
const std::string& name() const override;
const std::string& collective() const override;
const std::pair<size_t, size_t>& messageRange() const override;
@@ -285,8 +288,8 @@ class DslAlgorithm : public Algorithm, public AlgorithmBuilder, public std::enab
CommResult execute(std::shared_ptr<Communicator> comm, const void* input, void* output, size_t inputSize,
size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream,
std::shared_ptr<Executor> executor, int nBlocks = 0, int nThreadsPerBlock = 0,
bool symmetricMemory = false,
const std::unordered_map<std::string, uintptr_t>& extras = {}) override;
bool symmetricMemory = false, const std::unordered_map<std::string, uintptr_t>& extras = {},
DataType accumDtype = DataType::AUTO) override;
AlgorithmType type() const override { return AlgorithmType::DSL; }
Constraint constraint() const override;
void reset() override;
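
The new `accumDtype` parameter defaults to `DataType::AUTO`, which the doc comment says resolves to the input `dtype` at runtime. A minimal standalone sketch of that resolution step, using a local mimic of the enum rather than the real mscclpp headers (the helper name `resolveAccumDtype` is illustrative, not part of the library):

```cpp
#include <cassert>

// Local mimic of the documented enum; only the members this sketch needs.
enum class DataType { FLOAT16, FLOAT32, BFLOAT16, AUTO = 255 };

// Hypothetical helper: resolve the AUTO sentinel to the input dtype,
// as the execute() docs describe ("DataType::AUTO resolves to dtype").
DataType resolveAccumDtype(DataType dtype, DataType accumDtype) {
  return (accumDtype == DataType::AUTO) ? dtype : accumDtype;
}

int main() {
  // fp16 input with no explicit accumulator type accumulates in fp16 ...
  assert(resolveAccumDtype(DataType::FLOAT16, DataType::AUTO) == DataType::FLOAT16);
  // ... while an explicit FLOAT32 requests higher-precision accumulation.
  assert(resolveAccumDtype(DataType::FLOAT16, DataType::FLOAT32) == DataType::FLOAT32);
  return 0;
}
```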

View File

@@ -64,18 +64,151 @@ using __bfloat162 = __nv_bfloat162;
#endif
/// Software float8 with 4 exponent bits, 3 mantissa bits, exponent bias = 15.
/// Format (MSB first): [sign:1][exponent:4][mantissa:3]
/// No infinities; exp=15 is NaN. Negative zero is NaN (fnuz convention).
/// Max finite value: 0.9375, min normal: ~6.1e-5, min subnormal: ~7.6e-6.
struct alignas(1) __fp8_e4m3b15 {
uint8_t __x;
__fp8_e4m3b15() = default;
/// Construct from raw bits (use __fp8_e4m3b15::fromRaw() for clarity).
MSCCLPP_HOST_DEVICE_INLINE explicit __fp8_e4m3b15(uint8_t raw) : __x(raw) {}
/// Construct from float32 (explicit to avoid ambiguous conversion chains).
MSCCLPP_HOST_DEVICE_INLINE explicit __fp8_e4m3b15(float val) : __x(fromFloat(val)) {}
/// Convert to float32.
MSCCLPP_HOST_DEVICE_INLINE operator float() const { return toFloat(__x); }
/// Construct from a raw bit pattern without conversion.
static MSCCLPP_HOST_DEVICE_INLINE __fp8_e4m3b15 fromRaw(uint8_t bits) {
__fp8_e4m3b15 r;
r.__x = bits;
return r;
}
private:
/// Decode fp8_e4m3b15 bits → float32.
///
/// Uses bit manipulation through fp16 as intermediate, adapted from the Triton compiler.
/// fp8_e4m3b15 is identical to fp8_e4m3fn (NVIDIA) except exponent bias is 15 vs 7.
/// Algorithm: reinterpret fp8 bits into an fp16 bit pattern with exponent shifted by -8,
/// then convert fp16 → float32.
static MSCCLPP_HOST_DEVICE_INLINE float toFloat(uint8_t bits) {
// Handle special values: negative zero (0x80) → NaN, exponent=15 → NaN.
uint32_t exp = (bits >> 3) & 0xFu;
if (bits == 0x80 || exp == 15) {
union {
uint32_t u;
float f;
} nan_val = {0x7FC00000u};
return nan_val.f;
}
if (bits == 0) return 0.0f;
// Triton-style bit manipulation: fp8 → fp16 → fp32.
// fp8 layout: [S:1][E:4][M:3] (bias=15)
// fp16 layout: [S:1][E:5][M:10] (bias=15)
//
// Place fp8 in upper byte of fp16, then right-shift exponent+mantissa by 1
// to convert E4 → E5 (both share bias=15). Sign bit stays at bit 15.
// Refer:
// https://github.com/triton-lang/triton/blob/cf34004b8a67d290a962da166f5aa2fc66751326/python/triton/language/extra/cuda/utils.py#L34
uint16_t h = (uint16_t)bits << 8; // place fp8 in upper byte of fp16
uint16_t sign16 = h & 0x8000u; // extract sign at fp16 position
uint16_t nosign = h & 0x7F00u; // exponent + mantissa (no sign)
uint16_t fp16_bits = sign16 | (nosign >> 1); // shift exponent right by 1
// For subnormals: when fp8 exponent=0, the above gives fp16 exponent=0
// and fp16 mantissa = (fp8_mantissa << 7), which correctly represents
// the subnormal fp16 value since both share bias=15.
// Convert fp16 bits to float via __half (works on host and device, CUDA and HIP).
union {
uint16_t u;
__half h;
} cvt = {fp16_bits};
return __half2float(cvt.h);
}
/// Encode float32 → fp8_e4m3b15 bits.
///
/// Algorithm adapted from Triton: float32 → fp16 → bit-manipulate → fp8.
/// The key insight is to convert to fp16 first (which shares bias=15 with e4m3b15),
/// then pack the fp16 bits back into 8 bits by shifting the exponent left by 1.
static MSCCLPP_HOST_DEVICE_INLINE uint8_t fromFloat(float val) {
union {
float f;
uint32_t u;
} in = {val};
// NaN → 0x80 (negative-zero bit pattern = NaN in fnuz).
if ((in.u & 0x7F800000u) == 0x7F800000u && (in.u & 0x007FFFFFu) != 0) return 0x80u;
// Convert float32 → fp16 bits via __half (works on host and device, CUDA and HIP).
__half h_val = __float2half_rn(val);
union {
__half h;
uint16_t u;
} cvt = {h_val};
uint16_t fp16_bits = cvt.u;
// Clamp absolute value to max finite e4m3b15: 0.9375 → fp16 = 0x3B80.
uint16_t abs_fp16 = fp16_bits & 0x7FFFu;
if (abs_fp16 > 0x3B80u) abs_fp16 = 0x3B80u;
// Reconstruct with sign.
uint16_t sign16 = fp16_bits & 0x8000u;
// Triton-style: fp16 → fp8.
// fp16 layout: [S:1][E:5][M:10] (bias=15)
// fp8 layout: [S:1][E:4][M:3] (bias=15)
//
// mad.lo.u32 a0, a0, 2, 0x00800080 → (abs_fp16 * 2 + 0x0080)
// This shifts left by 1 (undoing the right-shift in decode) and adds rounding bias.
// Then: lop3.b32 b0, $1, 0x80008000, a0, 0xea → (sign & 0x8000) | a0
// Finally: prmt for byte extraction.
//
// Simplified for scalar: shift abs_fp16 left by 1, add rounding bias, take upper byte.
uint16_t adjusted = (uint16_t)(abs_fp16 * 2u + 0x0080u);
// The upper byte now contains [E:4][M:3][round_bit].
// Combine with sign and extract.
uint16_t with_sign = sign16 | adjusted;
uint8_t result = (uint8_t)(with_sign >> 8);
// Zero → 0x00 (ensure positive zero, not negative zero which is NaN).
if ((result & 0x7Fu) == 0) result = 0x00u;
return result;
}
};
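
The struct comments above pin the format down precisely (bias 15, exp=15 and 0x80 are NaN, max finite 0.9375, min normal ~6.1e-5, min subnormal ~7.6e-6). As a sanity check, here is a small standalone host-side sketch, not the library's code path, that decodes a few bit patterns directly from the [sign:1][exp:4][mant:3] fields and confirms those constants, including the 0x44 case that the fp16 shift in `toFloat()` maps to bits 0x2200 = 0.01171875:

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Standalone sketch: decode an fp8_e4m3b15 pattern straight from its fields.
// exp==15 and 0x80 are NaN (fnuz), exp==0 is subnormal, bias is 15.
float decode_e4m3b15(uint8_t bits) {
  if (bits == 0x80u || ((bits >> 3) & 0xFu) == 15u) return NAN;
  int sign = (bits >> 7) & 0x1;
  int exp = (bits >> 3) & 0xF;
  int mant = bits & 0x7;
  float mag = (exp == 0) ? std::ldexp(mant / 8.0f, 1 - 15)            // subnormal
                         : std::ldexp(1.0f + mant / 8.0f, exp - 15);  // normal
  return sign ? -mag : mag;
}

int main() {
  assert(decode_e4m3b15(0x77) == 0.9375f);    // max finite: exp=14, mant=7
  assert(decode_e4m3b15(0x08) == 0x1p-14f);   // min normal (~6.1e-5)
  assert(decode_e4m3b15(0x01) == 0x1p-17f);   // min subnormal (~7.6e-6)
  assert(std::isnan(decode_e4m3b15(0x80)));   // negative zero is NaN (fnuz)
  assert(decode_e4m3b15(0x44) == 0.01171875f);  // 1.5 * 2^-7, matches the fp16 route
  std::puts("e4m3b15 decode checks passed");
  return 0;
}
```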
/// Packed 2x fp8_e4m3b15 storage.
struct alignas(2) __fp8x2_e4m3b15 {
uint16_t __x;
};
/// Packed 4x fp8_e4m3b15 storage.
struct alignas(4) __fp8x4_e4m3b15 {
uint32_t __x;
};
namespace mscclpp {
/// Data types supported by mscclpp operations.
enum class DataType {
INT32, // 32-bit signed integer.
UINT32, // 32-bit unsigned integer.
FLOAT16, // IEEE 754 half precision.
FLOAT32, // IEEE 754 single precision.
BFLOAT16, // bfloat16 precision.
FLOAT8_E4M3, // float8 with E4M3 layout.
FLOAT8_E5M2, // float8 with E5M2 layout.
UINT8, // 8-bit unsigned integer.
INT32, // 32-bit signed integer.
UINT32, // 32-bit unsigned integer.
FLOAT16, // IEEE 754 half precision.
FLOAT32, // IEEE 754 single precision.
BFLOAT16, // bfloat16 precision.
FLOAT8_E4M3, // float8 with E4M3 layout.
FLOAT8_E5M2, // float8 with E5M2 layout.
UINT8, // 8-bit unsigned integer.
FLOAT8_E4M3B15, // float8 with E4M3 layout, bias=15 (software, no HW accel).
AUTO = 255, // Sentinel: resolve to the input dtype at runtime.
};
/// Word array.
@@ -97,6 +230,7 @@ struct alignas(Bytes) Words<Bytes, false> {};
template <typename T, int N, typename StorageT>
union alignas(sizeof(T) * N) VectorTypeImpl {
static_assert(N > 0, "N must be greater than 0");
static_assert(sizeof(StorageT) >= sizeof(T) * N, "StorageT must cover the full vector size");
T data[N];
Words<sizeof(T) * N> words;
@@ -127,13 +261,14 @@ union alignas(sizeof(T) * N) VectorTypeImpl {
MSCCLPP_HOST_DEVICE_INLINE const T& operator[](int i) const { return data[i]; }
};
// Helper template to get the appropriate vector type for a given element type and count
// Helper template to get the appropriate vector type for a given element type and count.
template <typename T, int N>
struct VectorTypeHelper {
using type =
VectorTypeImpl<T, N,
typename std::conditional_t<N * sizeof(T) == 4, uint32_t,
typename std::conditional_t<N * sizeof(T) == 8, uint2, uint4>>>;
static constexpr int Bytes = N * sizeof(T);
using type = VectorTypeImpl<
T, N,
std::conditional_t<Bytes == 4, uint32_t,
std::conditional_t<Bytes == 8, uint2, std::conditional_t<Bytes <= 16, uint4, Words<Bytes>>>>>;
};
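
The helper now selects storage purely by total byte count, falling back to `Words<Bytes>` once a vector exceeds 16 bytes, which is what makes the `f32x8`/`f32x16`/`f16x16` aliases added below possible. A minimal standalone mimic of that selection, with local stand-ins for `uint2`/`uint4`/`Words` so it compiles outside the library (names here are placeholders, not the real types):

```cpp
#include <cstdint>
#include <type_traits>

// Stand-ins for the CUDA vector types and the Words<Bytes> fallback;
// only their sizes matter for this sketch.
struct uint2_t { uint32_t x, y; };
struct uint4_t { uint32_t x, y, z, w; };
template <int Bytes> struct WordsMimic { uint32_t w[Bytes / 4]; };

// Mimic of the byte-count-based selection shown in VectorTypeHelper.
template <typename T, int N>
struct StorageFor {
  static constexpr int Bytes = N * static_cast<int>(sizeof(T));
  using type = std::conditional_t<
      Bytes == 4, uint32_t,
      std::conditional_t<Bytes == 8, uint2_t,
                         std::conditional_t<Bytes <= 16, uint4_t, WordsMimic<Bytes>>>>;
};

// 8 B -> uint2-like, 16 B -> uint4-like, 32 B -> Words-style fallback
// (no native 32-byte CUDA storage type exists).
static_assert(std::is_same_v<StorageFor<float, 2>::type, uint2_t>);
static_assert(std::is_same_v<StorageFor<float, 4>::type, uint4_t>);
static_assert(std::is_same_v<StorageFor<float, 8>::type, WordsMimic<32>>);

int main() { return 0; }
```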
/// Vector type - clean user interface (automatically selects appropriate storage type)
@@ -170,6 +305,11 @@ DEFINE_VEC(bf16x4, __bfloat16, 4, uint2);
DEFINE_VEC(f16x8, __half, 8, uint4);
DEFINE_VEC(bf16x8, __bfloat16, 8, uint4);
// Aliases for large vector types (>16 bytes) where no native CUDA storage type exists.
using f32x8 = VectorType<float, 8>;
using f32x16 = VectorType<float, 16>;
using f16x16 = VectorType<__half, 16>;
#if defined(__FP8_TYPES_EXIST__)
DEFINE_VEC(f8_e4m3x2, __fp8_e4m3, 2, __fp8x2_e4m3);
DEFINE_VEC(f8_e4m3x4, __fp8_e4m3, 4, __fp8x4_e4m3);
@@ -181,6 +321,12 @@ DEFINE_VEC(f8_e5m2x4, __fp8_e5m2, 4, __fp8x4_e5m2);
DEFINE_VEC(f8_e5m2x8, __fp8_e5m2, 8, uint2);
DEFINE_VEC(f8_e5m2x16, __fp8_e5m2, 16, uint4);
#endif
// fp8_e4m3b15 vectors (always available — software type, no HW dependency)
DEFINE_VEC(f8_e4m3b15x2, __fp8_e4m3b15, 2, __fp8x2_e4m3b15);
DEFINE_VEC(f8_e4m3b15x4, __fp8_e4m3b15, 4, __fp8x4_e4m3b15);
DEFINE_VEC(f8_e4m3b15x8, __fp8_e4m3b15, 8, uint2);
DEFINE_VEC(f8_e4m3b15x16, __fp8_e4m3b15, 16, uint4);
#undef DEFINE_VEC
#if defined(MSCCLPP_DEVICE_COMPILE)
@@ -254,6 +400,21 @@ MSCCLPP_DEVICE_INLINE __fp8_e5m2 clip(__fp8_e5m2 val) {
}
#endif
// --- f32x2 arithmetic ---
template <bool UseClip = true>
MSCCLPP_DEVICE_INLINE f32x2 operator+(const f32x2& a, const f32x2& b) {
#if defined(MSCCLPP_DEVICE_CUDA) && (__CUDA_ARCH__ >= 1000)
// Blackwell (SM 10.0+): packed float2 add in a single instruction.
return __fadd2_rn(a.storage, b.storage);
#else
f32x2 result;
result.data[0] = a.data[0] + b.data[0];
result.data[1] = a.data[1] + b.data[1];
return result;
#endif
}
template <bool UseClip = true>
MSCCLPP_DEVICE_INLINE f16x2 operator+(const f16x2& a, const f16x2& b) {
__half2 result;
@@ -265,6 +426,18 @@ MSCCLPP_DEVICE_INLINE f16x2 operator+(const f16x2& a, const f16x2& b) {
return result;
}
template <bool UseClip = true>
MSCCLPP_DEVICE_INLINE f16x4 operator+(const f16x4& a, const f16x4& b) {
// Decompose into 2× packed __hadd2 (2 instructions instead of 4 scalar __hadd).
const f16x2* a2 = reinterpret_cast<const f16x2*>(&a);
const f16x2* b2 = reinterpret_cast<const f16x2*>(&b);
f16x4 result;
f16x2* r2 = reinterpret_cast<f16x2*>(&result);
r2[0] = a2[0] + b2[0];
r2[1] = a2[1] + b2[1];
return result;
}
template <bool UseClip = true>
MSCCLPP_DEVICE_INLINE bf16x2 operator+(const bf16x2& a, const bf16x2& b) {
__bfloat162 result;
@@ -449,6 +622,14 @@ MSCCLPP_DEVICE_INLINE T min(const T& a, const T& b) {
return (a < b ? a : b);
}
template <>
MSCCLPP_DEVICE_INLINE f32x2 min(const f32x2& a, const f32x2& b) {
f32x2 result;
result.data[0] = fminf(a.data[0], b.data[0]);
result.data[1] = fminf(a.data[1], b.data[1]);
return result;
}
template <>
MSCCLPP_DEVICE_INLINE f16x2 min(const f16x2& a, const f16x2& b) {
#if defined(MSCCLPP_DEVICE_HIP)
@@ -489,6 +670,51 @@ MSCCLPP_DEVICE_INLINE u8x4 min(const u8x4& a, const u8x4& b) {
#endif
}
/// Convert a vector type From to vector type To.
/// Primary template with auto-decomposition: vectors with N > 4 elements decompose into x4 chunks,
/// vectors with N == 4 decompose into x2 chunks, enabling optimized x2/x4 specializations to be reached.
/// Specialized below for optimized FP8 conversion paths at x2/x4 level.
template <typename To, typename From>
MSCCLPP_DEVICE_INLINE To to(const From& v) {
static_assert(To::Size == From::Size, "to<To, From>: vector sizes must match");
constexpr int N = From::Size;
// Auto-decompose: N > 4 → split into x4 chunks
if constexpr (N > 4 && N % 4 == 0) {
constexpr int nChunks = N / 4;
using FromChunk = VectorType<typename From::ElementType, 4>;
using ToChunk = VectorType<typename To::ElementType, 4>;
const FromChunk* in = reinterpret_cast<const FromChunk*>(&v);
To result;
ToChunk* out = reinterpret_cast<ToChunk*>(&result);
#pragma unroll
for (int c = 0; c < nChunks; ++c) {
out[c] = to<ToChunk>(in[c]);
}
return result;
}
// Auto-decompose: N == 4 → split into 2x x2 chunks
else if constexpr (N == 4) {
using FromChunk = VectorType<typename From::ElementType, 2>;
using ToChunk = VectorType<typename To::ElementType, 2>;
const FromChunk* in = reinterpret_cast<const FromChunk*>(&v);
To result;
ToChunk* out = reinterpret_cast<ToChunk*>(&result);
out[0] = to<ToChunk>(in[0]);
out[1] = to<ToChunk>(in[1]);
return result;
}
// Base case: element-wise conversion
else {
To result;
#pragma unroll
for (int i = 0; i < N; ++i) {
result.data[i] = static_cast<typename To::ElementType>(v.data[i]);
}
return result;
}
}
#if defined(__FP8_TYPES_EXIST__)
template <>
MSCCLPP_DEVICE_INLINE __fp8_e4m3 min(const __fp8_e4m3& a, const __fp8_e4m3& b) {
@@ -551,7 +777,592 @@ MSCCLPP_DEVICE_INLINE f8_e5m2x4 min(const f8_e5m2x4& a, const f8_e5m2x4& b) {
return result;
}
// --- f8_e4m3 -> f32 specializations ---
/// f8_e4m3x2 -> f32x2.
/// NVIDIA: fp8 -> half (via __nv_cvt_fp8x2_to_halfraw2) -> float.
/// HIP gfx942: fp8 -> float (via __builtin_amdgcn_cvt_pk_f32_fp8).
template <>
MSCCLPP_DEVICE_INLINE f32x2 to<f32x2, f8_e4m3x2>(const f8_e4m3x2& v) {
#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
auto f = __builtin_amdgcn_cvt_pk_f32_fp8(v.storage.__x, 0);
f32x2 result;
result.data[0] = f[0];
result.data[1] = f[1];
return result;
#elif defined(MSCCLPP_DEVICE_CUDA) && __CUDA_ARCH__ >= 900
__half2_raw h2 = __nv_cvt_fp8x2_to_halfraw2(bit_cast<__nv_fp8x2_storage_t>(v.storage), __NV_E4M3);
f32x2 result;
result.data[0] = __half2float(bit_cast<__half>(h2.x));
result.data[1] = __half2float(bit_cast<__half>(h2.y));
return result;
#else
f32x2 result;
result.data[0] = float(v.data[0]);
result.data[1] = float(v.data[1]);
return result;
#endif
}
/// f8_e4m3x4 -> f32x4.
template <>
MSCCLPP_DEVICE_INLINE f32x4 to<f32x4, f8_e4m3x4>(const f8_e4m3x4& v) {
#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
auto lo = __builtin_amdgcn_cvt_pk_f32_fp8(v.storage.__x, false);
auto hi = __builtin_amdgcn_cvt_pk_f32_fp8(v.storage.__x, true);
f32x4 result;
result.data[0] = lo[0];
result.data[1] = lo[1];
result.data[2] = hi[0];
result.data[3] = hi[1];
return result;
#else
const f8_e4m3x2* pair = reinterpret_cast<const f8_e4m3x2*>(&v);
f32x2 lo = to<f32x2>(pair[0]);
f32x2 hi = to<f32x2>(pair[1]);
f32x4 result;
result.data[0] = lo.data[0];
result.data[1] = lo.data[1];
result.data[2] = hi.data[0];
result.data[3] = hi.data[1];
return result;
#endif
}
// --- f8_e5m2 -> f32 specializations ---
/// f8_e5m2x2 -> f32x2.
/// NVIDIA: fp8 -> half (via __nv_cvt_fp8x2_to_halfraw2) -> float.
/// HIP gfx942: bf8 -> float (via __builtin_amdgcn_cvt_pk_f32_bf8).
template <>
MSCCLPP_DEVICE_INLINE f32x2 to<f32x2, f8_e5m2x2>(const f8_e5m2x2& v) {
#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
auto f = __builtin_amdgcn_cvt_pk_f32_bf8(v.storage.__x, 0);
f32x2 result;
result.data[0] = f[0];
result.data[1] = f[1];
return result;
#elif defined(MSCCLPP_DEVICE_CUDA) && __CUDA_ARCH__ >= 900
__half2_raw h2 = __nv_cvt_fp8x2_to_halfraw2(bit_cast<__nv_fp8x2_storage_t>(v.storage), __NV_E5M2);
f32x2 result;
result.data[0] = __half2float(bit_cast<__half>(h2.x));
result.data[1] = __half2float(bit_cast<__half>(h2.y));
return result;
#else
f32x2 result;
result.data[0] = float(v.data[0]);
result.data[1] = float(v.data[1]);
return result;
#endif
}
/// f8_e5m2x4 -> f32x4.
template <>
MSCCLPP_DEVICE_INLINE f32x4 to<f32x4, f8_e5m2x4>(const f8_e5m2x4& v) {
#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
auto lo = __builtin_amdgcn_cvt_pk_f32_bf8(v.storage.__x, false);
auto hi = __builtin_amdgcn_cvt_pk_f32_bf8(v.storage.__x, true);
f32x4 result;
result.data[0] = lo[0];
result.data[1] = lo[1];
result.data[2] = hi[0];
result.data[3] = hi[1];
return result;
#else
const f8_e5m2x2* pair = reinterpret_cast<const f8_e5m2x2*>(&v);
f32x2 lo = to<f32x2>(pair[0]);
f32x2 hi = to<f32x2>(pair[1]);
f32x4 result;
result.data[0] = lo.data[0];
result.data[1] = lo.data[1];
result.data[2] = hi.data[0];
result.data[3] = hi.data[1];
return result;
#endif
}
// --- f32 -> f8_e4m3 specializations (downcast) ---
/// f32x2 -> f8_e4m3x2.
/// HIP gfx942: float -> fp8 (via __builtin_amdgcn_cvt_pk_fp8_f32).
/// NVIDIA SM90+: float -> half -> fp8 (via __nv_cvt_halfraw2_to_fp8x2).
/// NVIDIA pre-SM90: float -> half -> fp8 (via __nv_cvt_halfraw_to_fp8, element-wise).
template <>
MSCCLPP_DEVICE_INLINE f8_e4m3x2 to<f8_e4m3x2, f32x2>(const f32x2& v) {
#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
uint32_t packed = __builtin_amdgcn_cvt_pk_fp8_f32(v.data[0], v.data[1], 0, false);
return bit_cast<f8_e4m3x2>(static_cast<__hip_fp8x2_storage_t>(packed));
#elif defined(MSCCLPP_DEVICE_CUDA) && __CUDA_ARCH__ >= 900
__half2_raw h2;
h2.x = bit_cast<unsigned short>(__float2half_rn(v.data[0]));
h2.y = bit_cast<unsigned short>(__float2half_rn(v.data[1]));
__nv_fp8x2_storage_t fp8x2 = __nv_cvt_halfraw2_to_fp8x2(h2, __NV_SATFINITE, __NV_E4M3);
return bit_cast<f8_e4m3x2>(fp8x2);
#elif defined(MSCCLPP_DEVICE_CUDA)
__half_raw h0, h1;
h0.x = bit_cast<unsigned short>(__float2half_rn(v.data[0]));
h1.x = bit_cast<unsigned short>(__float2half_rn(v.data[1]));
f8_e4m3x2 result;
result.data[0] = bit_cast<__fp8_e4m3>(__nv_cvt_halfraw_to_fp8(h0, __NV_SATFINITE, __NV_E4M3));
result.data[1] = bit_cast<__fp8_e4m3>(__nv_cvt_halfraw_to_fp8(h1, __NV_SATFINITE, __NV_E4M3));
return result;
#else
f8_e4m3x2 result;
result.data[0] = static_cast<__fp8_e4m3>(v.data[0]);
result.data[1] = static_cast<__fp8_e4m3>(v.data[1]);
return result;
#endif
}
/// f32x4 -> f8_e4m3x4.
template <>
MSCCLPP_DEVICE_INLINE f8_e4m3x4 to<f8_e4m3x4, f32x4>(const f32x4& v) {
#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
uint32_t packed = __builtin_amdgcn_cvt_pk_fp8_f32(v.data[0], v.data[1], 0, false);
packed = __builtin_amdgcn_cvt_pk_fp8_f32(v.data[2], v.data[3], packed, true);
return bit_cast<f8_e4m3x4>(packed);
#else
f32x2 lo, hi;
lo.data[0] = v.data[0];
lo.data[1] = v.data[1];
hi.data[0] = v.data[2];
hi.data[1] = v.data[3];
f8_e4m3x2 lo_fp8 = to<f8_e4m3x2>(lo);
f8_e4m3x2 hi_fp8 = to<f8_e4m3x2>(hi);
f8_e4m3x4 result;
result.data[0] = lo_fp8.data[0];
result.data[1] = lo_fp8.data[1];
result.data[2] = hi_fp8.data[0];
result.data[3] = hi_fp8.data[1];
return result;
#endif
}
// --- f32 -> f8_e5m2 specializations (downcast) ---
/// f32x2 -> f8_e5m2x2.
/// HIP gfx942: float -> bf8 (via __builtin_amdgcn_cvt_pk_bf8_f32).
/// NVIDIA SM90+: float -> half -> fp8 (via __nv_cvt_halfraw2_to_fp8x2 with __NV_E5M2).
/// NVIDIA pre-SM90: float -> half -> fp8 (via __nv_cvt_halfraw_to_fp8, element-wise).
template <>
MSCCLPP_DEVICE_INLINE f8_e5m2x2 to<f8_e5m2x2, f32x2>(const f32x2& v) {
#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
uint32_t packed = __builtin_amdgcn_cvt_pk_bf8_f32(v.data[0], v.data[1], 0, false);
return bit_cast<f8_e5m2x2>(static_cast<__hip_fp8x2_storage_t>(packed));
#elif defined(MSCCLPP_DEVICE_CUDA) && __CUDA_ARCH__ >= 900
__half2_raw h2;
h2.x = bit_cast<unsigned short>(__float2half_rn(v.data[0]));
h2.y = bit_cast<unsigned short>(__float2half_rn(v.data[1]));
__nv_fp8x2_storage_t fp8x2 = __nv_cvt_halfraw2_to_fp8x2(h2, __NV_SATFINITE, __NV_E5M2);
return bit_cast<f8_e5m2x2>(fp8x2);
#elif defined(MSCCLPP_DEVICE_CUDA)
__half_raw h0, h1;
h0.x = bit_cast<unsigned short>(__float2half_rn(v.data[0]));
h1.x = bit_cast<unsigned short>(__float2half_rn(v.data[1]));
f8_e5m2x2 result;
result.data[0] = bit_cast<__fp8_e5m2>(__nv_cvt_halfraw_to_fp8(h0, __NV_SATFINITE, __NV_E5M2));
result.data[1] = bit_cast<__fp8_e5m2>(__nv_cvt_halfraw_to_fp8(h1, __NV_SATFINITE, __NV_E5M2));
return result;
#else
f8_e5m2x2 result;
result.data[0] = static_cast<__fp8_e5m2>(v.data[0]);
result.data[1] = static_cast<__fp8_e5m2>(v.data[1]);
return result;
#endif
}
/// f32x4 -> f8_e5m2x4.
template <>
MSCCLPP_DEVICE_INLINE f8_e5m2x4 to<f8_e5m2x4, f32x4>(const f32x4& v) {
#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
uint32_t packed = __builtin_amdgcn_cvt_pk_bf8_f32(v.data[0], v.data[1], 0, false);
packed = __builtin_amdgcn_cvt_pk_bf8_f32(v.data[2], v.data[3], packed, true);
return bit_cast<f8_e5m2x4>(packed);
#else
f32x2 lo, hi;
lo.data[0] = v.data[0];
lo.data[1] = v.data[1];
hi.data[0] = v.data[2];
hi.data[1] = v.data[3];
f8_e5m2x2 lo_fp8 = to<f8_e5m2x2>(lo);
f8_e5m2x2 hi_fp8 = to<f8_e5m2x2>(hi);
f8_e5m2x4 result;
result.data[0] = lo_fp8.data[0];
result.data[1] = lo_fp8.data[1];
result.data[2] = hi_fp8.data[0];
result.data[3] = hi_fp8.data[1];
return result;
#endif
}
// --- f8_e4m3 <-> f16 conversion specializations ---
/// f8_e4m3x2 -> f16x2.
/// NVIDIA SM90+: packed intrinsic (1 instruction).
/// HIP gfx942: fp8 -> float -> half (via AMD builtin).
/// Pre-SM90 / fallback: element-wise scalar conversion.
template <>
MSCCLPP_DEVICE_INLINE f16x2 to<f16x2, f8_e4m3x2>(const f8_e4m3x2& v) {
#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
auto f = __builtin_amdgcn_cvt_pk_f32_fp8(v.storage.__x, 0);
f16x2 result;
result.data[0] = __float2half(f[0]);
result.data[1] = __float2half(f[1]);
return result;
#elif defined(MSCCLPP_DEVICE_CUDA) && __CUDA_ARCH__ >= 900
__half2_raw h2 = __nv_cvt_fp8x2_to_halfraw2(bit_cast<__nv_fp8x2_storage_t>(v.storage), __NV_E4M3);
return bit_cast<f16x2>(h2);
#else
f16x2 result;
result.data[0] = static_cast<__half>(v.data[0]);
result.data[1] = static_cast<__half>(v.data[1]);
return result;
#endif
}
/// f16x2 -> f8_e4m3x2.
/// NVIDIA SM90+: packed intrinsic (1 instruction).
/// HIP gfx942: half -> float -> fp8 (via AMD builtin).
/// Pre-SM90: element-wise scalar conversion.
template <>
MSCCLPP_DEVICE_INLINE f8_e4m3x2 to<f8_e4m3x2, f16x2>(const f16x2& v) {
#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
float f0 = __half2float(v.data[0]);
float f1 = __half2float(v.data[1]);
uint32_t packed = __builtin_amdgcn_cvt_pk_fp8_f32(f0, f1, 0, false);
return bit_cast<f8_e4m3x2>(static_cast<__hip_fp8x2_storage_t>(packed));
#elif defined(MSCCLPP_DEVICE_CUDA) && __CUDA_ARCH__ >= 900
__half2_raw h2 = bit_cast<__half2_raw>(v);
__nv_fp8x2_storage_t fp8x2 = __nv_cvt_halfraw2_to_fp8x2(h2, __NV_SATFINITE, __NV_E4M3);
return bit_cast<f8_e4m3x2>(fp8x2);
#elif defined(MSCCLPP_DEVICE_CUDA)
__half_raw h0, h1;
h0.x = bit_cast<unsigned short>(v.data[0]);
h1.x = bit_cast<unsigned short>(v.data[1]);
f8_e4m3x2 result;
result.data[0] = bit_cast<__fp8_e4m3>(__nv_cvt_halfraw_to_fp8(h0, __NV_SATFINITE, __NV_E4M3));
result.data[1] = bit_cast<__fp8_e4m3>(__nv_cvt_halfraw_to_fp8(h1, __NV_SATFINITE, __NV_E4M3));
return result;
#else
f8_e4m3x2 result;
result.data[0] = static_cast<__fp8_e4m3>(v.data[0]);
result.data[1] = static_cast<__fp8_e4m3>(v.data[1]);
return result;
#endif
}
#endif // defined(__FP8_TYPES_EXIST__)
// --- fp8_e4m3b15 <-> fp16 direct conversion specializations ---
// These are the PRIMARY conversions: fp8_b15 <-> fp16 is just a 1-bit exponent shift
// (E4 bias=15 <-> E5 bias=15). The fp8 -> fp16 direction is exact since fp16 has 10
// mantissa bits vs fp8's 3; fp16 -> fp8 clamps and rounds. fp32 conversions are
// derived by routing through fp16.
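// Illustrative example (not part of the shipped code): e4m3b15 byte 0x74 encodes
// 0.75 (sign 0, exp 14, mant 4). The decode below places the byte at bits [15:8]
// (0x7400), masks off the sign and shifts right by one (0x3A00), which is exactly
// the fp16 bit pattern of 0.75 because both formats use bias 15.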
/// f8_e4m3b15x2 -> f16x2.
/// Direct fp8 -> fp16 via branch-free bit manipulation.
template <>
MSCCLPP_DEVICE_INLINE f16x2 to<f16x2, f8_e4m3b15x2>(const f8_e4m3b15x2& v) {
#if defined(MSCCLPP_DEVICE_CUDA)
uint16_t in = v.storage.__x;
// Spread 2 fp8 bytes into packed fp16 pair, adjust exponent E4->E5.
uint32_t a0 = ((uint32_t)(in & 0xFFu) << 8) | ((uint32_t)(in >> 8) << 24);
uint32_t b0 = (a0 & 0x7f007f00u) >> 1;
uint32_t out0 = b0 | (a0 & 0x80008000u);
__half2 h;
asm("mov.b32 %0, %1;" : "=r"(*reinterpret_cast<uint32_t*>(&h)) : "r"(out0));
return h;
#elif defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
// gfx942: same bit manipulation as CUDA, store packed fp16 bits via words[].
uint16_t in = v.storage.__x;
uint32_t a0 = ((uint32_t)(in & 0xFFu) << 8) | ((uint32_t)(in >> 8) << 24);
uint32_t b0 = (a0 & 0x7f007f00u) >> 1;
uint32_t out0 = b0 | (a0 & 0x80008000u);
f16x2 result;
result.words[0] = out0;
return result;
#else
f16x2 result;
result.data[0] = __float2half(float(v.data[0]));
result.data[1] = __float2half(float(v.data[1]));
return result;
#endif
}
/// f8_e4m3b15x4 -> f16x4.
/// Uses __byte_perm + lop3 for branch-free vectorized conversion.
template <>
MSCCLPP_DEVICE_INLINE f16x4 to<f16x4, f8_e4m3b15x4>(const f8_e4m3b15x4& v) {
#if defined(MSCCLPP_DEVICE_CUDA)
uint32_t in = v.storage.__x;
uint32_t a0 = __byte_perm(0u, in, 0x5746u);
uint32_t a0_shr = a0 >> 1;
uint32_t a0_sign = a0 & 0x80008000u;
uint32_t out0;
asm("lop3.b32 %0, %1, %2, %3, 0xEA;" : "=r"(out0) : "r"(a0_shr), "r"(0x3f803f80u), "r"(a0_sign));
uint32_t a1 = __byte_perm(a0, 0u, 0x2301u);
uint32_t a1_shr = a1 >> 1;
uint32_t a1_sign = a1 & 0x80008000u;
uint32_t out1;
asm("lop3.b32 %0, %1, %2, %3, 0xEA;" : "=r"(out1) : "r"(a1_shr), "r"(0x3f803f80u), "r"(a1_sign));
f16x4 result;
asm("mov.b32 %0, %1;" : "=r"(result.words[0]) : "r"(out0));
asm("mov.b32 %0, %1;" : "=r"(result.words[1]) : "r"(out1));
return result;
#elif defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
// gfx942: __byte_perm + bitwise E4→E5 shift (no lop3), store via words[].
uint32_t in = v.storage.__x;
uint32_t a0 = __byte_perm(0u, in, 0x5746u);
uint32_t out0 = ((a0 >> 1) & 0x3f803f80u) | (a0 & 0x80008000u);
uint32_t a1 = __byte_perm(a0, 0u, 0x2301u);
uint32_t out1 = ((a1 >> 1) & 0x3f803f80u) | (a1 & 0x80008000u);
f16x4 result;
result.words[0] = out0;
result.words[1] = out1;
return result;
#else
f16x4 result;
#pragma unroll
for (int i = 0; i < 4; ++i) {
result.data[i] = __float2half(float(v.data[i]));
}
return result;
#endif
}
/// f16x2 -> f8_e4m3b15x2.
/// Direct fp16 -> fp8 via clamp + exponent shift E5->E4 + pack.
template <>
MSCCLPP_DEVICE_INLINE f8_e4m3b15x2 to<f8_e4m3b15x2, f16x2>(const f16x2& v) {
#if defined(MSCCLPP_DEVICE_CUDA)
uint32_t in0;
asm("mov.b32 %0, %1;" : "=r"(in0) : "r"(*reinterpret_cast<const uint32_t*>(&v)));
// Clamp abs to max finite e4m3b15 (0x3B80 = 0.9375 in fp16).
uint32_t lo = in0 & 0xFFFFu, hi = in0 >> 16;
uint32_t alo = lo & 0x7FFFu, ahi = hi & 0x7FFFu;
alo = alo < 0x3B80u ? alo : 0x3B80u;
ahi = ahi < 0x3B80u ? ahi : 0x3B80u;
uint32_t a0 = alo | (ahi << 16);
a0 = a0 * 2u + 0x00800080u;
uint32_t b0 = a0 | (in0 & 0x80008000u);
uint16_t packed = (uint16_t)(((b0 >> 8) & 0xFFu) | ((b0 >> 16) & 0xFF00u));
return bit_cast<f8_e4m3b15x2>(packed);
#elif defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
// gfx942: read packed fp16 bits, clamp via v_pk_min_u16, shift E5→E4, pack.
uint32_t in0 = v.words[0];
uint32_t abs0 = in0 & 0x7fff7fffu;
uint32_t a0;
asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a0) : "v"(abs0), "v"(0x3B803B80u));
a0 = a0 * 2u + 0x00800080u;
uint32_t b0 = a0 | (in0 & 0x80008000u);
uint16_t packed = (uint16_t)(((b0 >> 8) & 0xFFu) | ((b0 >> 16) & 0xFF00u));
return bit_cast<f8_e4m3b15x2>(packed);
#else
f8_e4m3b15x2 result;
result.data[0] = __fp8_e4m3b15(__half2float(v.data[0]));
result.data[1] = __fp8_e4m3b15(__half2float(v.data[1]));
return result;
#endif
}
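// Illustrative example (not part of the shipped code): fp16 0.75 is 0x3A00. The
// encode above clamps |x| to 0x3B80 (0.9375, the largest finite e4m3b15), shifts
// the bits left by one (0x7400), adds 0x80 to round the truncated mantissa half up
// (0x7480), re-applies the sign, and keeps byte [15:8], giving 0x74, which is the
// e4m3b15 encoding of 0.75.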
/// f16x4 -> f8_e4m3b15x4.
/// Uses __vminu2 + lop3 + __byte_perm for branch-free vectorized conversion.
template <>
MSCCLPP_DEVICE_INLINE f8_e4m3b15x4 to<f8_e4m3b15x4, f16x4>(const f16x4& v) {
#if defined(MSCCLPP_DEVICE_CUDA)
uint32_t in0, in1;
asm("mov.b32 %0, %1;" : "=r"(in0) : "r"(v.words[0]));
asm("mov.b32 %0, %1;" : "=r"(in1) : "r"(v.words[1]));
uint32_t abs0 = in0 & 0x7fff7fffu;
uint32_t abs1 = in1 & 0x7fff7fffu;
uint32_t a0 = __vminu2(abs0, 0x3B803B80u);
uint32_t a1 = __vminu2(abs1, 0x3B803B80u);
a0 = a0 * 2u + 0x00800080u;
a1 = a1 * 2u + 0x00800080u;
uint32_t b0, b1;
asm("lop3.b32 %0, %1, %2, %3, 0xf8;" : "=r"(b0) : "r"(a0), "r"(in0), "r"(0x80008000u));
asm("lop3.b32 %0, %1, %2, %3, 0xf8;" : "=r"(b1) : "r"(a1), "r"(in1), "r"(0x80008000u));
uint32_t packed = __byte_perm(b0, b1, 0x7531u);
return bit_cast<f8_e4m3b15x4>(packed);
#elif defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
// gfx942: read packed fp16 bits, clamp via v_pk_min_u16, shift E5→E4, __byte_perm pack.
uint32_t in0 = v.words[0], in1 = v.words[1];
uint32_t abs0 = in0 & 0x7fff7fffu, abs1 = in1 & 0x7fff7fffu;
uint32_t a0, a1;
asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a0) : "v"(abs0), "v"(0x3B803B80u));
asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a1) : "v"(abs1), "v"(0x3B803B80u));
a0 = a0 * 2u + 0x00800080u;
a1 = a1 * 2u + 0x00800080u;
uint32_t b0 = a0 | (in0 & 0x80008000u);
uint32_t b1 = a1 | (in1 & 0x80008000u);
uint32_t packed = __byte_perm(b0, b1, 0x7531u);
return bit_cast<f8_e4m3b15x4>(packed);
#else
f8_e4m3b15x4 result;
#pragma unroll
for (int i = 0; i < 4; ++i) {
result.data[i] = __fp8_e4m3b15(__half2float(v.data[i]));
}
return result;
#endif
}
// --- fp8_e4m3b15 <-> f32 conversion specializations (software, always available) ---
/// f8_e4m3b15x2 -> f32x2.
/// Routes through fp16: fp8→fp16 (bit manip) then fp16→f32.
template <>
MSCCLPP_DEVICE_INLINE f32x2 to<f32x2, f8_e4m3b15x2>(const f8_e4m3b15x2& v) {
#if defined(MSCCLPP_DEVICE_CUDA)
f16x2 h = to<f16x2, f8_e4m3b15x2>(v);
float2 f2 = __half22float2(h);
return bit_cast<f32x2>(f2);
#elif defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
f16x2 h = to<f16x2, f8_e4m3b15x2>(v);
f32x2 result;
result.data[0] = __half2float(h.data[0]);
result.data[1] = __half2float(h.data[1]);
return result;
#else
f32x2 result;
result.data[0] = float(v.data[0]);
result.data[1] = float(v.data[1]);
return result;
#endif
}
/// f8_e4m3b15x4 -> f32x4.
/// Routes through fp16: fp8→fp16 (bit manip) then fp16→f32.
template <>
MSCCLPP_DEVICE_INLINE f32x4 to<f32x4, f8_e4m3b15x4>(const f8_e4m3b15x4& v) {
#if defined(MSCCLPP_DEVICE_CUDA)
f16x4 h = to<f16x4, f8_e4m3b15x4>(v);
__half2 h0, h1;
asm("mov.b32 %0, %1;" : "=r"(*reinterpret_cast<uint32_t*>(&h0)) : "r"(h.words[0]));
asm("mov.b32 %0, %1;" : "=r"(*reinterpret_cast<uint32_t*>(&h1)) : "r"(h.words[1]));
float2 f0 = __half22float2(h0);
float2 f1 = __half22float2(h1);
f32x4 result;
result.data[0] = f0.x;
result.data[1] = f0.y;
result.data[2] = f1.x;
result.data[3] = f1.y;
return result;
#elif defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
f16x4 h = to<f16x4, f8_e4m3b15x4>(v);
f32x4 result;
result.data[0] = __half2float(h.data[0]);
result.data[1] = __half2float(h.data[1]);
result.data[2] = __half2float(h.data[2]);
result.data[3] = __half2float(h.data[3]);
return result;
#else
f32x4 result;
#pragma unroll
for (int i = 0; i < 4; ++i) {
result.data[i] = float(v.data[i]);
}
return result;
#endif
}
/// f32x2 -> f8_e4m3b15x2.
/// Routes through fp16: f32→fp16 then fp16→fp8 (clamp + exponent shift + pack).
template <>
MSCCLPP_DEVICE_INLINE f8_e4m3b15x2 to<f8_e4m3b15x2, f32x2>(const f32x2& v) {
#if defined(MSCCLPP_DEVICE_CUDA)
float2 f2 = {v.data[0], v.data[1]};
__half2 h = __float22half2_rn(f2);
return to<f8_e4m3b15x2, f16x2>(h);
#elif defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
f16x2 h;
h.data[0] = __float2half_rn(v.data[0]);
h.data[1] = __float2half_rn(v.data[1]);
return to<f8_e4m3b15x2, f16x2>(h);
#else
f8_e4m3b15x2 result;
result.data[0] = __fp8_e4m3b15(v.data[0]);
result.data[1] = __fp8_e4m3b15(v.data[1]);
return result;
#endif
}
/// f32x4 -> f8_e4m3b15x4.
/// Routes through fp16: f32→fp16 then fp16→fp8 (clamp + exponent shift + pack).
template <>
MSCCLPP_DEVICE_INLINE f8_e4m3b15x4 to<f8_e4m3b15x4, f32x4>(const f32x4& v) {
#if defined(MSCCLPP_DEVICE_CUDA)
float2 f01 = {v.data[0], v.data[1]};
float2 f23 = {v.data[2], v.data[3]};
__half2 h01 = __float22half2_rn(f01);
__half2 h23 = __float22half2_rn(f23);
f16x4 h;
asm("mov.b32 %0, %1;" : "=r"(h.words[0]) : "r"(*reinterpret_cast<uint32_t*>(&h01)));
asm("mov.b32 %0, %1;" : "=r"(h.words[1]) : "r"(*reinterpret_cast<uint32_t*>(&h23)));
return to<f8_e4m3b15x4, f16x4>(h);
#elif defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
f16x4 h;
h.words[0] = __builtin_bit_cast(uint32_t, __builtin_amdgcn_cvt_pkrtz(v.data[0], v.data[1]));
h.words[1] = __builtin_bit_cast(uint32_t, __builtin_amdgcn_cvt_pkrtz(v.data[2], v.data[3]));
return to<f8_e4m3b15x4, f16x4>(h);
#else
f8_e4m3b15x4 result;
#pragma unroll
for (int i = 0; i < 4; ++i) {
result.data[i] = __fp8_e4m3b15(v.data[i]);
}
return result;
#endif
}
// --- fp8_e4m3b15 arithmetic (software, always available) ---
template <bool UseClip = true>
MSCCLPP_DEVICE_INLINE __fp8_e4m3b15 operator+(const __fp8_e4m3b15& a, const __fp8_e4m3b15& b) {
return __fp8_e4m3b15(float(a) + float(b));
}
template <bool UseClip = true>
MSCCLPP_DEVICE_INLINE f8_e4m3b15x2 operator+(const f8_e4m3b15x2& a, const f8_e4m3b15x2& b) {
f8_e4m3b15x2 result;
result.data[0] = __fp8_e4m3b15(float(a.data[0]) + float(b.data[0]));
result.data[1] = __fp8_e4m3b15(float(a.data[1]) + float(b.data[1]));
return result;
}
template <bool UseClip = true>
MSCCLPP_DEVICE_INLINE f8_e4m3b15x4 operator+(const f8_e4m3b15x4& a, const f8_e4m3b15x4& b) {
f8_e4m3b15x4 result;
#pragma unroll
for (int i = 0; i < 4; ++i) {
result.data[i] = __fp8_e4m3b15(float(a.data[i]) + float(b.data[i]));
}
return result;
}
// --- fp8_e4m3b15 min (software) ---
template <>
MSCCLPP_DEVICE_INLINE __fp8_e4m3b15 min(const __fp8_e4m3b15& a, const __fp8_e4m3b15& b) {
return __fp8_e4m3b15(fminf(float(a), float(b)));
}
MSCCLPP_DEVICE_INLINE f8_e4m3b15x2 min(const f8_e4m3b15x2& a, const f8_e4m3b15x2& b) {
f8_e4m3b15x2 result;
result.data[0] = mscclpp::min(a.data[0], b.data[0]);
result.data[1] = mscclpp::min(a.data[1], b.data[1]);
return result;
}
MSCCLPP_DEVICE_INLINE f8_e4m3b15x4 min(const f8_e4m3b15x4& a, const f8_e4m3b15x4& b) {
f8_e4m3b15x4 result;
#pragma unroll
for (int i = 0; i < 4; ++i) {
result.data[i] = mscclpp::min(a.data[i], b.data[i]);
}
return result;
}
#endif // MSCCLPP_DEVICE_COMPILE
} // namespace mscclpp

View File

@@ -24,4 +24,7 @@ set_target_properties(mscclpp_py PROPERTIES OUTPUT_NAME _mscclpp)
set_target_properties(mscclpp_py PROPERTIES INSTALL_RPATH "\$ORIGIN/lib")
target_link_libraries(mscclpp_py PRIVATE dlpack mscclpp mscclpp_collectives ${GPU_LIBRARIES})
target_include_directories(mscclpp_py SYSTEM PRIVATE ${GPU_INCLUDE_DIRS})
if(MSCCLPP_USE_ROCM)
target_compile_definitions(mscclpp_py PRIVATE MSCCLPP_USE_ROCM)
endif()
install(TARGETS mscclpp_py LIBRARY DESTINATION .)

View File

@@ -75,15 +75,17 @@ void register_algorithm(nb::module_& m) {
[](Algorithm& self, std::shared_ptr<Communicator> comm, uintptr_t input, uintptr_t output,
size_t inputSize, size_t outputSize, DataType dtype, ReduceOp op, uintptr_t stream,
std::shared_ptr<Executor> executor, int nBlocks, int nThreadsPerBlock, bool symmetricMemory,
std::unordered_map<std::string, uintptr_t> extras) {
std::unordered_map<std::string, uintptr_t> extras, int32_t accumDtype) {
return self.execute(comm, reinterpret_cast<const void*>(input), reinterpret_cast<void*>(output),
inputSize, outputSize, dtype, op, reinterpret_cast<cudaStream_t>(stream), executor,
nBlocks, nThreadsPerBlock, symmetricMemory, extras);
nBlocks, nThreadsPerBlock, symmetricMemory, extras,
static_cast<DataType>(accumDtype));
},
nb::arg("comm"), nb::arg("input"), nb::arg("output"), nb::arg("input_size"), nb::arg("output_size"),
nb::arg("dtype"), nb::arg("op") = ReduceOp::NOP, nb::arg("stream") = 0, nb::arg("executor") = nullptr,
nb::arg("n_blocks") = 0, nb::arg("n_threads_per_block") = 0, nb::arg("symmetric_memory") = false,
nb::arg("extras") = std::unordered_map<std::string, uintptr_t>())
nb::arg("extras") = std::unordered_map<std::string, uintptr_t>(),
nb::arg("accum_dtype") = static_cast<int32_t>(DataType::AUTO))
.def("reset", &Algorithm::reset);
nb::class_<Algorithm::Constraint>(algorithmClass, "Constraint")

View File

@@ -47,7 +47,8 @@ void register_core(nb::module_& m) {
.value("bfloat16", DataType::BFLOAT16)
.value("float8_e4m3", DataType::FLOAT8_E4M3)
.value("float8_e5m2", DataType::FLOAT8_E5M2)
.value("uint8", DataType::UINT8);
.value("uint8", DataType::UINT8)
.value("float8_e4m3b15", DataType::FLOAT8_E4M3B15);
nb::class_<Bootstrap>(m, "CppBootstrap")
.def("get_rank", &Bootstrap::getRank)

View File

@@ -4,6 +4,7 @@
#include <nanobind/nanobind.h>
#include <nanobind/stl/function.h>
#include <nanobind/stl/shared_ptr.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/unordered_map.h>
#include <nanobind/stl/vector.h>

View File

@@ -34,6 +34,19 @@ static DLDataType getDlType(std::string type) {
return DLDataType{kDLBfloat, 16, 1};
} else if (type == "torch.float16") {
return DLDataType{kDLFloat, 16, 1};
} else if (type == "torch.float8_e4m3fn") {
return DLDataType{kDLFloat8_e4m3fn, 8, 1};
} else if (type == "torch.float8_e4m3fnuz") {
return DLDataType{kDLFloat8_e4m3fnuz, 8, 1};
} else if (type == "torch.float8_e5m2") {
return DLDataType{kDLFloat8_e5m2, 8, 1};
} else if (type == "torch.float8_e5m2fnuz") {
return DLDataType{kDLFloat8_e5m2fnuz, 8, 1};
} else if (type == "torch.uint8") {
return DLDataType{kDLUInt, 8, 1};
} else if (type == "fp8_e4m3b15") {
// No standard DLPack code for fp8_e4m3b15; store as raw uint8 bytes.
return DLDataType{kDLUInt, 8, 1};
} else {
throw Error("Unsupported type: " + type, ErrorCode::InvalidUsage);
}

View File

@@ -177,6 +177,7 @@ class Algorithm:
nthreads_per_block=0,
symmetric_memory: bool = False,
extras: Optional[Dict[str, int]] = None,
accum_dtype: Optional[CppDataType] = None,
) -> int:
"""Execute the collective algorithm.
@@ -194,10 +195,14 @@ class Algorithm:
nthreads_per_block: Number of threads per block (0 for auto-selection).
symmetric_memory: Whether to use symmetric memory optimization (default: False).
extras: Additional algorithm-specific parameters.
accum_dtype: Data type for accumulation during reduction. If None, defaults to
the same as dtype. Use DataType.float32 for high-precision FP8 accumulation.
Returns:
The result code (0 for success).
"""
merged_extras = dict(extras) if extras is not None else {}
accum_dtype = accum_dtype if accum_dtype is not None else dtype
return self._algorithm.execute(
comm,
int(input_buffer),
@@ -211,7 +216,8 @@ class Algorithm:
nblocks,
nthreads_per_block,
symmetric_memory,
extras if extras is not None else {},
merged_extras,
int(accum_dtype),
)
def reset(self):

View File

@@ -140,7 +140,7 @@ class MemoryChannel:
for tb_id in tb_list:
tb_chunk_id = get_program().setup_remote_chunk(self.src_rank, tb_id, remote_chunk, self.channel_type)
tb_channel_ids = get_program().setup_channel(tb, self)
tb_channel_ids = get_program().setup_channel(tb_id, self)
op = GetOperation(
src_buff=[RemoteChunk(src_chunk.buffer, src_chunk.index, src_chunk.size, tb_chunk_id)],
dst_buff=[LocalChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size)],

View File

@@ -745,7 +745,7 @@ class ReduceOperation(BaseOperation):
remote_dst_buff=self.remote_dst_buff + other.dst_buff,
channel_ids=self.channel_ids,
put_channel_ids=self.put_channel_ids + other.channel_ids,
channel_type=self.channel_type,
channel_type=other.channel_type,
reduce_operation=self.reduce_operation,
tbg_info=self.tbg_info,
packet=self.packet,

View File

@@ -1,5 +1,5 @@
mpi4py==4.1.1
cupy==13.6.0
mpi4py
cupy
prettytable
netifaces
pytest

View File

@@ -0,0 +1,397 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Correctness test for FP8 allreduce with different accumulation types.
#
# Verifies that FP8 allreduce with higher-precision accumulation produces
# results at least as accurate as native FP8 accumulation, by comparing
# against a float32 reference.
#
# Usage:
# mpirun -np 8 pytest python/test/test_fp8_accum.py -v
import cupy as cp
import numpy as np
import pytest
from mscclpp import CommGroup, GpuBuffer, DataType, ReduceOp, is_nvls_supported
from mscclpp.ext import AlgorithmCollectionBuilder
from .mscclpp_mpi import MpiGroup, parametrize_mpi_groups, mpi_group
# FP8 E4M3 (hardware) requires SM >= 89 (Ada / Hopper) on NVIDIA GPUs.
# On AMD/ROCm (e.g. MI300X), FP8 is supported natively — no skip needed.
_is_hip = hasattr(cp.cuda.runtime, "is_hip") and cp.cuda.runtime.is_hip
_skip_fp8 = not _is_hip and int(cp.cuda.Device().compute_capability) < 89
pytestmark = pytest.mark.skipif(_skip_fp8, reason="FP8 accum tests require SM >= 89 on CUDA")
# ---------------------------------------------------------------------------
# FP8 E4M3FN helpers (bias=7, no infinity, NaN = exp=15 & mant=7)
# ---------------------------------------------------------------------------
def e4m3fn_to_float(uint8_array):
"""Decode a cupy uint8 array of E4M3FN bit patterns to float32."""
bits = uint8_array.astype(cp.int32)
sign = (bits >> 7) & 1
exp = (bits >> 3) & 0xF
mant = bits & 0x7
# Normal: (-1)^s * 2^(exp-7) * (1 + mant/8)
normal_val = cp.ldexp(cp.float32(1.0) + mant.astype(cp.float32) / cp.float32(8.0), (exp - 7).astype(cp.int32))
# Subnormal (exp==0): (-1)^s * 2^(-6) * (mant/8)
subnormal_val = cp.ldexp(mant.astype(cp.float32) / cp.float32(8.0), cp.int32(-6))
result = cp.where(exp == 0, subnormal_val, normal_val)
result = cp.where(sign == 1, -result, result)
# Zero
result = cp.where((exp == 0) & (mant == 0), cp.float32(0.0), result)
# NaN: exp==15 & mant==7
nan_mask = (exp == 15) & (mant == 7)
result = cp.where(nan_mask, cp.float32(float("nan")), result)
return result
def float_to_e4m3fn(f32_array, chunk_size=65536):
"""Encode a cupy float32 array to uint8 E4M3FN bit patterns.
Uses a lookup-table approach: precompute all 128 positive E4M3FN values,
then find nearest match per element via chunked broadcast comparison.
"""
# Build lookup table of all 128 positive E4M3FN values (0x00..0x7F)
all_bytes = cp.arange(128, dtype=cp.uint8)
all_floats = e4m3fn_to_float(all_bytes) # (128,) float32
# Mark NaN entries as inf so they're never selected as nearest
all_floats = cp.where(cp.isnan(all_floats), cp.float32(float("inf")), all_floats)
# Clamp input and extract sign
clamped = f32_array.astype(cp.float32)
clamped = cp.clip(clamped, -448.0, 448.0)
signs = (clamped < 0).astype(cp.uint8)
absval = cp.abs(clamped)
result = cp.zeros(absval.shape, dtype=cp.uint8)
n = absval.size
absval_flat = absval.ravel()
result_flat = result.ravel()
for start in range(0, n, chunk_size):
end = min(start + chunk_size, n)
chunk = absval_flat[start:end]
# (chunk_size, 128) difference matrix
diffs = cp.abs(chunk[:, None] - all_floats[None, :])
result_flat[start:end] = cp.argmin(diffs, axis=1).astype(cp.uint8)
# Combine with sign bit
result = result_flat.reshape(absval.shape)
result = result | (signs << 7)
# Handle exact zero
result = cp.where(absval == 0, cp.uint8(0), result)
return result
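# Illustrative round trip (assumption about helper behavior only, not exercised by
# the test suite): values exactly representable in E4M3FN survive encode/decode
# unchanged, e.g.
#   vals = cp.asarray([0.0, 0.5, 1.5, -3.0, 448.0], dtype=cp.float32)
#   assert cp.allclose(e4m3fn_to_float(float_to_e4m3fn(vals)), vals)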
# ---------------------------------------------------------------------------
# FP8 E4M3B15 helpers (bias=15, max=0.9375, NaN = exp==15 or bits==0x80)
# ---------------------------------------------------------------------------
def e4m3b15_to_float(uint8_array):
"""Decode a cupy uint8 array of E4M3B15 bit patterns to float32."""
bits = uint8_array.astype(cp.int32)
sign = (bits >> 7) & 1
exp = (bits >> 3) & 0xF
mant = bits & 0x7
# Normal: (-1)^s * 2^(exp-15) * (1 + mant/8)
normal_val = cp.ldexp(cp.float32(1.0) + mant.astype(cp.float32) / cp.float32(8.0), (exp - 15).astype(cp.int32))
# Subnormal (exp==0): (-1)^s * 2^(-14) * (mant/8)
subnormal_val = cp.ldexp(mant.astype(cp.float32) / cp.float32(8.0), cp.int32(-14))
result = cp.where(exp == 0, subnormal_val, normal_val)
result = cp.where(sign == 1, -result, result)
# Zero
result = cp.where((exp == 0) & (mant == 0), cp.float32(0.0), result)
# NaN: exp==15 or negative zero (0x80)
nan_mask = (exp == 15) | (uint8_array.astype(cp.int32) == 0x80)
result = cp.where(nan_mask, cp.float32(float("nan")), result)
return result
def float_to_e4m3b15(f32_array, chunk_size=65536):
"""Encode a cupy float32 array to uint8 E4M3B15 bit patterns.
Same lookup-table approach as float_to_e4m3fn.
"""
# Build lookup table of all 128 positive E4M3B15 values (0x00..0x7F)
all_bytes = cp.arange(128, dtype=cp.uint8)
all_floats = e4m3b15_to_float(all_bytes) # (128,) float32
# Mark NaN entries as inf so they're never selected as nearest
all_floats = cp.where(cp.isnan(all_floats), cp.float32(float("inf")), all_floats)
# Clamp input and extract sign
clamped = f32_array.astype(cp.float32)
clamped = cp.clip(clamped, -0.9375, 0.9375)
signs = (clamped < 0).astype(cp.uint8)
absval = cp.abs(clamped)
result = cp.zeros(absval.shape, dtype=cp.uint8)
n = absval.size
absval_flat = absval.ravel()
result_flat = result.ravel()
for start in range(0, n, chunk_size):
end = min(start + chunk_size, n)
chunk = absval_flat[start:end]
# (chunk_size, 128) difference matrix
diffs = cp.abs(chunk[:, None] - all_floats[None, :])
result_flat[start:end] = cp.argmin(diffs, axis=1).astype(cp.uint8)
# Combine with sign bit
result = result_flat.reshape(absval.shape)
result = result | (signs << 7)
# Handle exact zero
result = cp.where(absval == 0, cp.uint8(0), result)
return result
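# Illustrative round trip (assumption about helper behavior only, not exercised by
# the test suite): values exactly representable in E4M3B15 survive encode/decode
# unchanged, e.g.
#   vals = cp.asarray([0.0, 0.5, -0.25, 0.9375], dtype=cp.float32)
#   assert cp.allclose(e4m3b15_to_float(float_to_e4m3b15(vals)), vals)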
# ---------------------------------------------------------------------------
# Shared test helpers
# ---------------------------------------------------------------------------
def setup_algorithms(mpi_group):
"""Build default algorithms and return (comm_group, algo_map, scratch_buf)."""
comm_group = CommGroup(mpi_group.comm)
scratch = GpuBuffer(1 << 27, dtype=cp.uint8) # 128 MB
AlgorithmCollectionBuilder.reset()
builder = AlgorithmCollectionBuilder()
algorithms = builder.build_default_algorithms(
scratch_buffer=scratch.data.ptr,
scratch_buffer_size=scratch.nbytes,
rank=comm_group.my_rank,
)
algo_map = {a.name: a for a in algorithms}
return comm_group, algo_map, scratch
def run_allreduce(algo, comm_group, buffer, dtype, accum_dtype=None, nblocks=0, nthreads_per_block=0):
"""Run allreduce in-place on buffer and return a copy of the result."""
ret = algo.execute(
comm=comm_group.communicator,
input_buffer=buffer.data.ptr,
output_buffer=buffer.data.ptr,
input_size=buffer.nbytes,
output_size=buffer.nbytes,
dtype=dtype,
op=ReduceOp.SUM,
stream=cp.cuda.get_current_stream().ptr,
nblocks=nblocks,
nthreads_per_block=nthreads_per_block,
symmetric_memory=True,
accum_dtype=accum_dtype,
)
cp.cuda.Device().synchronize()
assert ret == 0, f"Allreduce failed with error code {ret}"
return buffer.copy()
# ---------------------------------------------------------------------------
# Test: FP8 E4M3 accumulation correctness
# ---------------------------------------------------------------------------
@parametrize_mpi_groups(8)
@pytest.mark.parametrize(
"algo_name",
[
"default_allreduce_packet",
"default_allreduce_nvls_packet",
"default_allreduce_fullmesh",
"default_allreduce_rsag_zero_copy",
"default_allreduce_allpair_packet",
],
)
@pytest.mark.parametrize("size", [1024, 4096, 16384, 65536, 262144, 1048576])
def test_fp8_e4m3_accum(mpi_group: MpiGroup, algo_name: str, size: int):
"""Verify that FP8 E4M3 allreduce with higher-precision accumulation is at
least as accurate as native FP8 accumulation, across all algorithm variants."""
rank = mpi_group.comm.rank
world_size = mpi_group.comm.size
comm_group, algo_map, scratch = setup_algorithms(mpi_group)
if algo_name not in algo_map:
pytest.skip(f"{algo_name} not available")
if "nvls" in algo_name and not is_nvls_supported():
pytest.skip(f"{algo_name} requires NVLS which is not supported on this platform")
algo = algo_map[algo_name]
buf = GpuBuffer(size, dtype=cp.uint8)
accum_configs = [
("fp8_native", DataType.float8_e4m3),
("float16", DataType.float16),
("float32", DataType.float32),
]
# rsag_zero_copy and fullmesh need explicit block/thread counts
if "rsag" in algo_name:
nb = max(1, min(32, size // (world_size * 32)))
nt = 1024
elif "fullmesh" in algo_name:
nb = 35
nt = 512
else:
nb = 0
nt = 0
errors = {}
for accum_label, accum_dtype in accum_configs:
# Generate deterministic per-rank data (use numpy to avoid hipRAND issues on ROCm)
rng = np.random.RandomState(42 + rank)
src_f32 = cp.asarray(rng.randn(size).astype(np.float32))
src_f32 = cp.clip(src_f32, -240.0, 240.0)
src_fp8 = float_to_e4m3fn(src_f32)
# Copy into symmetric buffer
buf[:] = src_fp8
cp.cuda.Device().synchronize()
# Run allreduce
result = run_allreduce(
algo,
comm_group,
buf,
dtype=DataType.float8_e4m3,
accum_dtype=accum_dtype,
nblocks=nb,
nthreads_per_block=nt,
)
result_f32 = e4m3fn_to_float(result)
# Compute float32 reference: sum all ranks' quantized FP8 inputs in float32
ref_f32 = cp.zeros(size, dtype=cp.float32)
for r in range(world_size):
rng_r = np.random.RandomState(42 + r)
rank_data = cp.asarray(rng_r.randn(size).astype(np.float32))
rank_data = cp.clip(rank_data, -240.0, 240.0)
rank_data_fp8 = float_to_e4m3fn(rank_data)
ref_f32 += e4m3fn_to_float(rank_data_fp8)
# Compute errors
abs_err = cp.abs(result_f32 - ref_f32)
mean_abs_err = float(cp.mean(abs_err))
errors[accum_label] = mean_abs_err
# Reset between runs
algo.reset()
# Higher-precision accumulation should be at least as accurate as native fp8
assert (
errors["float16"] <= errors["fp8_native"] + 1e-6
), f"float16 accum ({errors['float16']:.6f}) worse than native ({errors['fp8_native']:.6f})"
assert (
errors["float32"] <= errors["fp8_native"] + 1e-6
), f"float32 accum ({errors['float32']:.6f}) worse than native ({errors['fp8_native']:.6f})"
# ---------------------------------------------------------------------------
# Test: FP8 E4M3B15 accumulation correctness
# ---------------------------------------------------------------------------
@parametrize_mpi_groups(8)
@pytest.mark.parametrize(
"algo_name",
[
"default_allreduce_packet",
"default_allreduce_nvls_packet",
"default_allreduce_rsag_zero_copy",
"default_allreduce_fullmesh",
"default_allreduce_allpair_packet",
],
)
@pytest.mark.parametrize("size", [1024, 4096, 65536])
def test_fp8_e4m3b15_accum(mpi_group: MpiGroup, algo_name: str, size: int):
"""Verify that FP8 E4M3B15 allreduce with higher-precision accumulation is at
least as accurate as native E4M3B15 accumulation."""
rank = mpi_group.comm.rank
world_size = mpi_group.comm.size
comm_group, algo_map, scratch = setup_algorithms(mpi_group)
if algo_name not in algo_map:
pytest.skip(f"{algo_name} not available")
if "nvls" in algo_name and not is_nvls_supported():
pytest.skip(f"{algo_name} requires NVLS which is not supported on this platform")
algo = algo_map[algo_name]
buf = GpuBuffer(size, dtype=cp.uint8)
accum_configs = [
("e4m3b15_native", DataType.float8_e4m3b15),
("float16", DataType.float16),
("float32", DataType.float32),
]
# rsag_zero_copy needs explicit block/thread counts, scaled to data size
if "rsag" in algo_name:
nb = max(1, min(32, size // (world_size * 32)))
nt = 1024
else:
nb = 0
nt = 0
errors = {}
for accum_label, accum_dtype in accum_configs:
# Generate deterministic per-rank random uint8 values in valid e4m3b15 range
rng = np.random.RandomState(42 + rank)
raw = cp.asarray(rng.randint(0, 0x78, (size,)).astype(np.uint8))
signs = cp.asarray(rng.randint(0, 2, (size,)).astype(np.uint8)) << 7
src_uint8 = raw | signs
# Fix negative zero -> positive zero
src_uint8 = cp.where(src_uint8 == 0x80, cp.uint8(0), src_uint8)
# Copy into symmetric buffer
buf[:] = src_uint8
cp.cuda.Device().synchronize()
# Run allreduce
result = run_allreduce(
algo,
comm_group,
buf,
dtype=DataType.float8_e4m3b15,
accum_dtype=accum_dtype,
nblocks=nb,
nthreads_per_block=nt,
)
# Decode result
result_f32 = e4m3b15_to_float(result)
# Compute float32 reference
ref_f32 = cp.zeros(size, dtype=cp.float32)
for r in range(world_size):
rng_r = np.random.RandomState(42 + r)
raw_r = cp.asarray(rng_r.randint(0, 0x78, (size,)).astype(np.uint8))
signs_r = cp.asarray(rng_r.randint(0, 2, (size,)).astype(np.uint8)) << 7
bits_r = raw_r | signs_r
bits_r = cp.where(bits_r == 0x80, cp.uint8(0), bits_r)
ref_f32 += e4m3b15_to_float(bits_r)
# Clamp reference to e4m3b15 representable range
ref_f32 = cp.clip(ref_f32, -0.9375, 0.9375)
# Compute errors (only on valid entries)
valid = ~cp.isnan(result_f32) & ~cp.isnan(ref_f32)
abs_err = cp.abs(result_f32[valid] - ref_f32[valid])
mean_abs_err = float(cp.mean(abs_err)) if abs_err.size > 0 else 0.0
errors[accum_label] = mean_abs_err
algo.reset()
# Higher-precision accumulation should be at least as accurate as native
assert (
errors["float16"] <= errors["e4m3b15_native"] + 1e-8
), f"float16 accum ({errors['float16']:.8f}) worse than native ({errors['e4m3b15_native']:.8f})"
assert (
errors["float32"] <= errors["e4m3b15_native"] + 1e-8
), f"float32 accum ({errors['float32']:.8f}) worse than native ({errors['e4m3b15_native']:.8f})"

View File

@@ -41,7 +41,9 @@ NativeAlgorithm::NativeAlgorithm(std::string name, std::string collective, InitF
CommResult NativeAlgorithm::execute(std::shared_ptr<Communicator> comm, const void* input, void* output,
size_t inputSize, size_t outputSize, DataType dtype, ReduceOp op,
cudaStream_t stream, std::shared_ptr<Executor>, int nBlocks, int nThreadsPerBlock,
bool symmetricMemory, const std::unordered_map<std::string, uintptr_t>& extras) {
bool symmetricMemory, const std::unordered_map<std::string, uintptr_t>& extras,
DataType accumDtype) {
if (accumDtype == DataType::AUTO) accumDtype = dtype;
if (!initialized_) {
initFunc_(comm);
initialized_ = true;
@@ -53,7 +55,7 @@ CommResult NativeAlgorithm::execute(std::shared_ptr<Communicator> comm, const vo
contexts_[ctxKey] = ctx;
}
return kernelLaunchFunc_(contexts_[ctxKey], input, output, inputSize, outputSize, dtype, op, stream, nBlocks,
nThreadsPerBlock, extras);
nThreadsPerBlock, extras, accumDtype);
}
const std::string& NativeAlgorithm::name() const { return name_; }
@@ -77,10 +79,7 @@ const CollectiveBufferMode& NativeAlgorithm::bufferMode() const { return bufferM
Algorithm::Constraint NativeAlgorithm::constraint() const { return constraint_; }
void NativeAlgorithm::reset() {
contexts_.clear();
initialized_ = false;
}
void NativeAlgorithm::reset() { contexts_.clear(); }
void AlgorithmCollection::registerAlgorithm(const std::string collective, const std::string algoName,
std::shared_ptr<Algorithm> algorithm) {
@@ -166,7 +165,7 @@ Algorithm::Constraint DslAlgorithm::constraint() const { return constraint_; }
CommResult DslAlgorithm::execute(std::shared_ptr<Communicator> comm, const void* input, void* output, size_t inputSize,
size_t outputSize, DataType dtype, ReduceOp, cudaStream_t stream,
std::shared_ptr<Executor> executor, int, int, bool,
const std::unordered_map<std::string, uintptr_t>&) {
const std::unordered_map<std::string, uintptr_t>&, DataType) {
if (!executor) {
THROW(EXEC, Error, ErrorCode::InvalidUsage, "Executor is null in DslAlgorithm::execute");
}
@@ -192,6 +191,10 @@ CommResult DslAlgorithm::execute(std::shared_ptr<Communicator> comm, const void*
plan_, stream);
break;
#endif
case DataType::FLOAT8_E4M3B15:
executor->execute(rank, (__fp8_e4m3b15*)input, (__fp8_e4m3b15*)output, inputSize, outputSize,
DataType::FLOAT8_E4M3B15, plan_, stream);
break;
case DataType::INT32:
case DataType::UINT32:
executor->execute(rank, (int*)input, (int*)output, inputSize, outputSize, DataType::UINT32, plan_, stream);

View File

@@ -56,7 +56,7 @@ Endpoint::Impl::Impl(const EndpointConfig& config, Context::Impl& contextImpl)
}
ibQp_ = contextImpl.getIbContext(config_.transport)
->createQp(config_.ib.port, gidIndex, config_.ib.maxCqSize, config_.ib.maxCqPollNum,
->createQp(config_.ib.port, config_.ib.gidIndex, config_.ib.maxCqSize, config_.ib.maxCqPollNum,
config_.ib.maxSendWr, maxRecvWr, config_.ib.maxWrPerSend, ibNoAtomic_);
ibQpInfo_ = ibQp_->getInfo();
} else if (config_.transport == Transport::Ethernet) {

View File

@@ -82,6 +82,12 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo
case DataType::FLOAT8_E5M2:
// FP8 is not supported in CUDA execution kernel.
break;
case DataType::FLOAT8_E4M3B15:
// fp8_e4m3b15 is a software type not supported in the CUDA execution kernel.
break;
case DataType::AUTO:
// AUTO is a sentinel resolved before reaching this point; nothing to do.
break;
}
}

View File

@@ -210,7 +210,7 @@ MSCCLPP_DEVICE_INLINE void handleReadReduceSend(const Operation& op, void* input
sizeof(int4);
void* remoteMemory = static_cast<char*>(memoryChannelBufferPtrs_[op.inputBufferRefs[index + 1].id]);
val = mscclpp::read<int4>(remoteMemory, srcOffset + idx);
tmp = cal_vector<T, OpType>(tmp, val);
tmp = calVector<T, OpType>(tmp, val);
}
output4[outputOffset4 + idx] = tmp;
if constexpr (SendToRemote) {
@@ -353,9 +353,9 @@ MSCCLPP_DEVICE_INLINE void handleReduceSendPackets(const Operation& op, void* in
for (uint32_t index = 0; index < nSrcs; ++index) {
PacketType* pkt = (PacketType*)((char*)scratch + scratchOffset_ + 2 * inputOffsets[index]);
PacketPayload<PacketType> val = pkt[idx].read(flag_);
data = cal_vector<T, OpType>(data, val);
data = calVector<T, OpType>(data, val);
}
data = cal_vector<T, OpType>(data, srcPacketPayload[idx]);
data = calVector<T, OpType>(data, srcPacketPayload[idx]);
dstPacketPayload[idx] = data;
if constexpr (SendToRemote) {
@@ -394,9 +394,9 @@ MSCCLPP_DEVICE_INLINE void handleReduceCopySendPackets(const Operation& op, void
for (uint32_t index = 0; index < nSrcs; ++index) {
PacketType* pkt = (PacketType*)((char*)scratch + scratchOffset_ + 2 * inputOffsets[index]);
PacketPayload<PacketType> val = pkt[idx].read(flag_);
data = cal_vector<T, OpType>(data, val);
data = calVector<T, OpType>(data, val);
}
data = cal_vector<T, OpType>(data, srcPacketPayload[idx]);
data = calVector<T, OpType>(data, srcPacketPayload[idx]);
dstPacketPayload[idx] = data;
PacketType* dst_val = &dstPkt[idx];
dst_val->write(data, flag_);
@@ -464,7 +464,7 @@ MSCCLPP_DEVICE_INLINE void handleReduceSend(const Operation& op, void* input, vo
size_t buffOffset =
(inputOffsets[index] + getOffset<ReuseScratch>(outputBufferRefs[index].type, offset)) / sizeof(int4);
int4 val = buff4[buffOffset + idx];
tmp = cal_vector<T, OpType>(tmp, val);
tmp = calVector<T, OpType>(tmp, val);
}
dst4[dstOffset4 + idx] = tmp;
if constexpr (SendToRemote) {
@@ -899,6 +899,17 @@ class ExecutionKernel {
#endif
break;
#endif // __FP8_TYPES_EXIST__
case DataType::FLOAT8_E4M3B15:
executionKernel<__fp8_e4m3b15, PacketType, ReuseScratch><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
rank, (__fp8_e4m3b15*)src, (__fp8_e4m3b15*)dst, (__fp8_e4m3b15*)scratch, scratchOffset, scratchChunkSize,
plan, semaphores, localMemoryIdBegin, flag
#if defined(ENABLE_NPKIT)
,
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
#else
);
#endif
break;
case DataType::UINT8:
executionKernel<uint8_t, PacketType, ReuseScratch><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
rank, (uint8_t*)src, (uint8_t*)dst, (uint8_t*)scratch, scratchOffset, scratchChunkSize, plan, semaphores,
@@ -910,6 +921,10 @@ class ExecutionKernel {
);
#endif
break;
case DataType::AUTO:
// AUTO is a sentinel that must be resolved before reaching this point.
assert(false && "DataType::AUTO must be resolved before kernel launch");
break;
}
}
#else // !defined(MSCCLPP_DEVICE_HIP)

View File

@@ -14,7 +14,7 @@ namespace mscclpp {
// Generic element-wise calculation helper
template <typename T, ReduceOp OpType>
MSCCLPP_DEVICE_INLINE T cal_elements(const T& a, const T& b) {
MSCCLPP_DEVICE_INLINE T calElements(const T& a, const T& b) {
if constexpr (OpType == SUM) {
return a + b;
} else if constexpr (OpType == MIN) {
@@ -24,56 +24,168 @@ MSCCLPP_DEVICE_INLINE T cal_elements(const T& a, const T& b) {
}
// Generic vector reduction helpers
template <typename T, ReduceOp OpType>
MSCCLPP_DEVICE_INLINE int4 cal_vector_helper(const int4& a, const int4& b) {
int4 ret;
ret.w = bit_cast<int, T>(cal_elements<T, OpType>(bit_cast<T, int>(a.w), bit_cast<T, int>(b.w)));
ret.x = bit_cast<int, T>(cal_elements<T, OpType>(bit_cast<T, int>(a.x), bit_cast<T, int>(b.x)));
ret.y = bit_cast<int, T>(cal_elements<T, OpType>(bit_cast<T, int>(a.y), bit_cast<T, int>(b.y)));
ret.z = bit_cast<int, T>(cal_elements<T, OpType>(bit_cast<T, int>(a.z), bit_cast<T, int>(b.z)));
return ret;
}
template <typename T, ReduceOp OpType>
MSCCLPP_DEVICE_INLINE uint2 cal_vector_helper(const uint2& a, const uint2& b) {
MSCCLPP_DEVICE_INLINE uint2 calVectorHelper(const uint2& a, const uint2& b) {
uint2 ret;
ret.x = bit_cast<uint32_t, T>(cal_elements<T, OpType>(bit_cast<T, uint32_t>(a.x), bit_cast<T, uint32_t>(b.x)));
ret.y = bit_cast<uint32_t, T>(cal_elements<T, OpType>(bit_cast<T, uint32_t>(a.y), bit_cast<T, uint32_t>(b.y)));
ret.x = bit_cast<uint32_t, T>(calElements<T, OpType>(bit_cast<T, uint32_t>(a.x), bit_cast<T, uint32_t>(b.x)));
ret.y = bit_cast<uint32_t, T>(calElements<T, OpType>(bit_cast<T, uint32_t>(a.y), bit_cast<T, uint32_t>(b.y)));
return ret;
}
template <typename T, ReduceOp OpType>
MSCCLPP_DEVICE_INLINE int cal_vector_helper(const int& a, const int& b) {
return bit_cast<int, T>(cal_elements<T, OpType>(bit_cast<T, int>(a), bit_cast<T, int>(b)));
/// f32x2 specialization for uint2: uses packed f32x2 operator+ (Blackwell __fadd2_rn when available).
template <>
MSCCLPP_DEVICE_INLINE uint2 calVectorHelper<f32x2, SUM>(const uint2& a, const uint2& b) {
f32x2 fa = bit_cast<f32x2, uint2>(a);
f32x2 fb = bit_cast<f32x2, uint2>(b);
f32x2 fr = fa + fb;
return bit_cast<uint2, f32x2>(fr);
}
template <>
MSCCLPP_DEVICE_INLINE uint2 calVectorHelper<f32x2, MIN>(const uint2& a, const uint2& b) {
f32x2 fa = bit_cast<f32x2, uint2>(a);
f32x2 fb = bit_cast<f32x2, uint2>(b);
f32x2 fr = mscclpp::min(fa, fb);
return bit_cast<uint2, f32x2>(fr);
}
template <typename T, ReduceOp OpType>
MSCCLPP_DEVICE_INLINE uint32_t cal_vector_helper(const uint32_t& a, const uint32_t& b) {
return bit_cast<uint32_t, T>(cal_elements<T, OpType>(bit_cast<T, uint32_t>(a), bit_cast<T, uint32_t>(b)));
MSCCLPP_DEVICE_INLINE int4 calVectorHelper(const int4& a, const int4& b) {
int4 ret;
ret.w = bit_cast<int, T>(calElements<T, OpType>(bit_cast<T, int>(a.w), bit_cast<T, int>(b.w)));
ret.x = bit_cast<int, T>(calElements<T, OpType>(bit_cast<T, int>(a.x), bit_cast<T, int>(b.x)));
ret.y = bit_cast<int, T>(calElements<T, OpType>(bit_cast<T, int>(a.y), bit_cast<T, int>(b.y)));
ret.z = bit_cast<int, T>(calElements<T, OpType>(bit_cast<T, int>(a.z), bit_cast<T, int>(b.z)));
return ret;
}
// cal_vector wrapper - converts scalar types to vector types and calls cal_vector_helper
/// f32x2 specialization for int4: process as two uint2 pairs using packed f32x2 arithmetic.
template <>
MSCCLPP_DEVICE_INLINE int4 calVectorHelper<f32x2, SUM>(const int4& a, const int4& b) {
uint2 lo_a = {(uint32_t)a.x, (uint32_t)a.y};
uint2 hi_a = {(uint32_t)a.z, (uint32_t)a.w};
uint2 lo_b = {(uint32_t)b.x, (uint32_t)b.y};
uint2 hi_b = {(uint32_t)b.z, (uint32_t)b.w};
uint2 lo_r = calVectorHelper<f32x2, SUM>(lo_a, lo_b);
uint2 hi_r = calVectorHelper<f32x2, SUM>(hi_a, hi_b);
return {(int)lo_r.x, (int)lo_r.y, (int)hi_r.x, (int)hi_r.y};
}
template <>
MSCCLPP_DEVICE_INLINE int4 calVectorHelper<f32x2, MIN>(const int4& a, const int4& b) {
uint2 lo_a = {(uint32_t)a.x, (uint32_t)a.y};
uint2 hi_a = {(uint32_t)a.z, (uint32_t)a.w};
uint2 lo_b = {(uint32_t)b.x, (uint32_t)b.y};
uint2 hi_b = {(uint32_t)b.z, (uint32_t)b.w};
uint2 lo_r = calVectorHelper<f32x2, MIN>(lo_a, lo_b);
uint2 hi_r = calVectorHelper<f32x2, MIN>(hi_a, hi_b);
return {(int)lo_r.x, (int)lo_r.y, (int)hi_r.x, (int)hi_r.y};
}
template <typename T, ReduceOp OpType>
MSCCLPP_DEVICE_INLINE int calVectorHelper(const int& a, const int& b) {
return bit_cast<int, T>(calElements<T, OpType>(bit_cast<T, int>(a), bit_cast<T, int>(b)));
}
template <typename T, ReduceOp OpType>
MSCCLPP_DEVICE_INLINE uint32_t calVectorHelper(const uint32_t& a, const uint32_t& b) {
return bit_cast<uint32_t, T>(calElements<T, OpType>(bit_cast<T, uint32_t>(a), bit_cast<T, uint32_t>(b)));
}
/// f32x2 specialization for uint32_t: a single float packed in 32 bits (scalar fallback).
template <>
MSCCLPP_DEVICE_INLINE uint32_t calVectorHelper<f32x2, SUM>(const uint32_t& a, const uint32_t& b) {
float fa = bit_cast<float, uint32_t>(a);
float fb = bit_cast<float, uint32_t>(b);
return bit_cast<uint32_t, float>(fa + fb);
}
template <>
MSCCLPP_DEVICE_INLINE uint32_t calVectorHelper<f32x2, MIN>(const uint32_t& a, const uint32_t& b) {
float fa = bit_cast<float, uint32_t>(a);
float fb = bit_cast<float, uint32_t>(b);
return bit_cast<uint32_t, float>(fminf(fa, fb));
}
// calVector wrapper converts scalar types to vector types and calls calVectorHelper
template <typename T, ReduceOp OpType, typename DataType>
MSCCLPP_DEVICE_INLINE DataType cal_vector(const DataType& a, const DataType& b) {
MSCCLPP_DEVICE_INLINE DataType calVector(const DataType& a, const DataType& b) {
// Define the vectorized computation type based on the element type
static_assert(sizeof(DataType) % sizeof(T) == 0, "DataType size must be multiple of T size");
static_assert(sizeof(DataType) >= 4, "DataType size must be at least 4 bytes");
using CompType = typename std::conditional_t<
std::is_same_v<T, __half>, f16x2,
std::is_same_v<T, float>, f32x2,
std::conditional_t<
std::is_same_v<T, __bfloat16>, bf16x2,
std::conditional_t<std::is_same_v<T, uint8_t>, u8x4,
std::is_same_v<T, __half>, f16x2,
std::conditional_t<
std::is_same_v<T, __bfloat16>, bf16x2,
std::conditional_t<
std::is_same_v<T, uint8_t>, u8x4,
std::conditional_t<std::is_same_v<T, __fp8_e4m3b15>, f8_e4m3b15x4,
#if defined(__FP8_TYPES_EXIST__)
std::conditional_t<std::is_same_v<T, __fp8_e4m3>, f8_e4m3x4,
std::conditional_t<std::is_same_v<T, __fp8_e5m2>, f8_e5m2x4,
#endif
T
#if defined(__FP8_TYPES_EXIST__)
>>>>>;
std::conditional_t<std::is_same_v<T, __fp8_e4m3>, f8_e4m3x4,
std::conditional_t<std::is_same_v<T, __fp8_e5m2>, f8_e5m2x4, T>>
#else
>>>;
T
#endif
return cal_vector_helper<CompType, OpType>(a, b);
>>>>>;
return calVectorHelper<CompType, OpType>(a, b);
}
/// Upcast a packed DataType (containing T elements) to a packed AccDataType (containing AccumT elements).
/// Uses the optimized to<>() specializations when available (e.g. FP8 -> float hardware intrinsics).
/// When AccumT == T, this is a no-op identity.
template <typename T, typename AccumT, typename AccDataType, typename DataType>
MSCCLPP_DEVICE_INLINE AccDataType upcastVector(const DataType& val) {
if constexpr (std::is_same_v<T, AccumT>) {
return val;
} else {
constexpr int nElems = sizeof(DataType) / sizeof(T);
using FromVec = VectorType<T, nElems>;
using ToVec = VectorType<AccumT, nElems>;
ToVec result = mscclpp::to<ToVec>(reinterpret_cast<const FromVec&>(val));
return reinterpret_cast<const AccDataType&>(result);
}
}
/// Downcast a packed AccDataType (containing AccumT elements) back to DataType (containing T elements).
/// Uses the optimized to<>() specializations when available.
/// When AccumT == T, this is a no-op identity.
template <typename T, typename AccumT, typename DataType, typename AccDataType>
MSCCLPP_DEVICE_INLINE DataType downcastVector(const AccDataType& val) {
if constexpr (std::is_same_v<T, AccumT>) {
return val;
} else {
constexpr int nElems = sizeof(DataType) / sizeof(T);
using FromVec = VectorType<T, nElems>;
using ToVec = VectorType<AccumT, nElems>;
FromVec result = mscclpp::to<FromVec>(reinterpret_cast<const ToVec&>(val));
return reinterpret_cast<const DataType&>(result);
}
}
/// Accumulate `val` (packed T elements in DataType) into `acc` (packed AccumT elements in AccDataType).
/// When AccumT == T, falls back to the standard calVector.
/// Otherwise, upcasts val to AccumT, reduces element-wise, and returns the AccumT accumulator.
template <typename T, typename AccumT, ReduceOp OpType, typename AccDataType, typename DataType>
MSCCLPP_DEVICE_INLINE AccDataType calVectorAccum(const AccDataType& acc, const DataType& val) {
if constexpr (std::is_same_v<T, AccumT>) {
return calVector<T, OpType>(acc, val);
} else {
constexpr int nElems = sizeof(DataType) / sizeof(T);
using FromVec = VectorType<T, nElems>;
using ToVec = VectorType<AccumT, nElems>;
ToVec fv = mscclpp::to<ToVec>(reinterpret_cast<const FromVec&>(val));
const ToVec& fa = reinterpret_cast<const ToVec&>(acc);
ToVec fr;
#pragma unroll
for (int i = 0; i < nElems; ++i) {
fr.data[i] = calElements<AccumT, OpType>(fa.data[i], fv.data[i]);
}
return reinterpret_cast<const AccDataType&>(fr);
}
}
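// Illustrative usage sketch (an assumption, not part of this change): a reduce loop
// that keeps packed fp8 values in an fp32 accumulator before writing back as fp8.
// `dst`, `srcs`, `nSrcs` and `idx` are placeholder names for the caller's packed
// uint32_t buffers and indexing.
//
//   using AccVec = VectorType<float, sizeof(uint32_t) / sizeof(__fp8_e4m3)>;
//   AccVec acc = upcastVector<__fp8_e4m3, float, AccVec>(dst[idx]);
//   for (uint32_t s = 0; s < nSrcs; ++s) {
//     acc = calVectorAccum<__fp8_e4m3, float, SUM, AccVec>(acc, srcs[s][idx]);
//   }
//   dst[idx] = downcastVector<__fp8_e4m3, float, uint32_t>(acc);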
#endif // defined(MSCCLPP_DEVICE_COMPILE)

View File

@@ -183,7 +183,8 @@ std::shared_ptr<Algorithm> AllgatherFullmesh::build() {
[self](const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
[[maybe_unused]] size_t outputSize, [[maybe_unused]] DataType dtype, [[maybe_unused]] ReduceOp op,
cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
const std::unordered_map<std::string, uintptr_t>& extras) -> CommResult {
const std::unordered_map<std::string, uintptr_t>& extras,
[[maybe_unused]] DataType accumDtype) -> CommResult {
return self->allgatherKernelFunc(ctx, input, output, inputSize, stream, nBlocks, nThreadsPerBlock, extras);
},
[self](std::shared_ptr<mscclpp::Communicator> comm, const void* input, void* output, size_t inputSize,

View File

@@ -212,7 +212,8 @@ std::shared_ptr<Algorithm> AllgatherFullmesh2::build() {
[self](const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
[[maybe_unused]] size_t outputSize, [[maybe_unused]] mscclpp::DataType dtype, [[maybe_unused]] ReduceOp op,
cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
const std::unordered_map<std::string, uintptr_t>& extras) -> mscclpp::CommResult {
const std::unordered_map<std::string, uintptr_t>& extras,
[[maybe_unused]] mscclpp::DataType accumDtype) -> mscclpp::CommResult {
return self->allgatherKernelFunc(ctx, input, output, inputSize, stream, nBlocks, nThreadsPerBlock, extras);
},
[self](std::shared_ptr<mscclpp::Communicator> comm, const void* input, void* output, size_t inputSize,

View File

@@ -2,6 +2,7 @@
// Licensed under the MIT license.
#include <collective_utils.hpp>
#include <type_traits>
#include "allreduce/allreduce_allpair_packet.hpp"
#include "allreduce/common.hpp"
@@ -11,7 +12,7 @@
namespace mscclpp {
namespace collective {
template <ReduceOp OpType, typename T>
template <ReduceOp OpType, typename T, typename AccumT = T>
__global__ void allreduceAllPairs(T* buff, T* scratch, T* resultBuff, DeviceHandle<MemoryChannel>* memoryChannels,
size_t channelDataOffset, size_t scratchBufferSize, int rank, int nRanksPerNode,
int worldSize, size_t nelems, uint32_t numScratchBuff, void* flags,
@@ -43,13 +44,16 @@ __global__ void allreduceAllPairs(T* buff, T* scratch, T* resultBuff, DeviceHand
// step 2: Reduce Data
for (size_t idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nelems; idx += blockDim.x * gridDim.x) {
uint32_t data = src[idx];
using AccRaw = std::conditional_t<std::is_same_v<T, AccumT>, uint32_t,
mscclpp::VectorType<AccumT, sizeof(uint32_t) / sizeof(T)>>;
AccRaw acc = mscclpp::upcastVector<T, AccumT, AccRaw>(data);
for (int index = 0; index < nPeers; index++) {
const int remoteRank = index < rank ? index : index + 1;
LL8Packet* dstPkt = (LL8Packet*)scratchBuff + remoteRank * nelems;
uint32_t val = dstPkt[idx].read(flag, -1);
data = cal_vector<T, OpType>(val, data);
acc = mscclpp::calVectorAccum<T, AccumT, OpType, AccRaw>(acc, val);
}
dst[idx] = data;
dst[idx] = mscclpp::downcastVector<T, AccumT, uint32_t>(acc);
}
__syncthreads();
if (threadIdx.x == 0) {
@@ -67,7 +71,7 @@ inline std::pair<int, int> getDefaultBlockNumAndThreadNum(size_t inputSize, int
return {(worldSize - 1) * 4, 512};
}
template <ReduceOp OpType, typename T>
template <ReduceOp OpType, typename T, typename AccumT = T>
struct AllpairAdapter {
static cudaError_t call(const void* buff, void* scratch, void* resultBuff, void* memoryChannels, void*,
DeviceHandle<SwitchChannel>*, DeviceHandle<SwitchChannel>*, size_t channelInOffset, size_t,
@@ -76,7 +80,12 @@ struct AllpairAdapter {
int nThreadsPerBlock = 0) {
using ChannelType = DeviceHandle<MemoryChannel>;
const size_t nelems = inputSize / sizeof(T);
allreduceAllPairs<OpType, T><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
// Round nBlocks down to a multiple of nPeers so every block maps to a valid peer.
const int nPeers = worldSize - 1;
if (nPeers > 0) {
nBlocks = (nBlocks / nPeers) * nPeers;
}
allreduceAllPairs<OpType, T, AccumT><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
(T*)buff, (T*)scratch, (T*)resultBuff, (ChannelType*)memoryChannels, channelInOffset, scratchBufferSize, rank,
nRanksPerNode, worldSize, nelems, numScratchBuff, flags, flagSize);
return cudaGetLastError();
@@ -94,18 +103,24 @@ void AllreduceAllpairPacket::initialize(std::shared_ptr<Communicator> comm) {
CommResult AllreduceAllpairPacket::allreduceKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output,
size_t inputSize, [[maybe_unused]] DataType dtype, ReduceOp op,
cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
const std::unordered_map<std::string, uintptr_t>&) {
const std::unordered_map<std::string, uintptr_t>&,
DataType accumDtype) {
auto algoCtx = std::static_pointer_cast<AlgorithmCtx>(ctx);
std::pair<int, int> blockAndThreadNum{nBlocks, nThreadsPerBlock};
if (blockAndThreadNum.first == 0 || blockAndThreadNum.second == 0) {
blockAndThreadNum = getDefaultBlockNumAndThreadNum(inputSize, algoCtx->workSize);
}
// nBlocks must be at least nPeers for allpair — each block maps to one peer.
const int nPeers = algoCtx->nRanksPerNode - 1;
if (nPeers > 0 && blockAndThreadNum.first < nPeers) {
return CommResult::CommInvalidArgument;
}
size_t sendBytes;
CUdeviceptr sendBasePtr;
MSCCLPP_CUTHROW(cuMemGetAddressRange(&sendBasePtr, &sendBytes, (CUdeviceptr)input));
size_t channelInOffset = (char*)input - (char*)sendBasePtr;
AllreduceFunc allreduce = dispatch<AllpairAdapter>(op, dtype);
AllreduceFunc allreduce = dispatch<AllpairAdapter>(op, dtype, accumDtype);
if (!allreduce) {
WARN("Unsupported operation or data type for allreduce: op=%d, dtype=%d", op, static_cast<int>(dtype));
return CommResult::CommInvalidArgument;
@@ -161,9 +176,9 @@ std::shared_ptr<Algorithm> AllreduceAllpairPacket::build() {
[self](std::shared_ptr<Communicator> comm) { self->initialize(comm); },
[self](const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
[[maybe_unused]] size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks,
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras) {
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras, DataType accumDtype) {
return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock,
extras);
extras, accumDtype);
},
[self](std::shared_ptr<Communicator> comm, const void* input, void* output, size_t inputSize,
[[maybe_unused]] size_t outputSize,

View File

@@ -9,7 +9,7 @@
namespace mscclpp {
namespace collective {
template <ReduceOp OpType, typename T>
template <ReduceOp OpType, typename T, typename AccumT = T>
__global__ void __launch_bounds__(512, 1)
allreduceFullmesh(T* buff, T* scratch, T* resultBuff, DeviceHandle<MemoryChannel>* memoryChannels,
DeviceHandle<MemoryChannel>* memoryOutChannels, size_t channelOutDataOffset, int rank,
@@ -26,6 +26,10 @@ __global__ void __launch_bounds__(512, 1)
int4* scratch4 = reinterpret_cast<int4*>((char*)scratch);
int4* resultBuff4 = reinterpret_cast<int4*>(resultBuff);
// AccumVec: wider vector for mixed-precision accumulation. When AccumT==T, this is just int4 (no-op).
constexpr int nElemsPerInt4 = sizeof(int4) / sizeof(T);
using AccumVec = std::conditional_t<std::is_same_v<T, AccumT>, int4, mscclpp::VectorType<AccumT, nElemsPerInt4>>;
// Distribute `nInt4PerRank` across all blocks with the unit size `unitNInt4`
constexpr size_t unitNInt4 = 512;
const size_t maxNInt4PerBlock =
@@ -81,12 +85,14 @@ __global__ void __launch_bounds__(512, 1)
__syncthreads();
for (size_t idx = threadIdx.x; idx < nInt4PerChunk; idx += blockDim.x) {
int4 data = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock];
int4 rawData = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock];
AccumVec acc = mscclpp::upcastVector<T, AccumT, AccumVec>(rawData);
for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1;
int4 val = scratch4[chunkSizePerRank * remoteRank + blockOffset + idx];
data = cal_vector<T, OpType>(val, data);
acc = mscclpp::calVectorAccum<T, AccumT, OpType, AccumVec>(acc, val);
}
int4 data = mscclpp::downcastVector<T, AccumT, int4>(acc);
resultBuff4[nInt4PerRank * rank + idx + offsetOfThisBlock] = data;
for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
outChannels[peerIdx].write(nInt4PerRank * rank + idx + offsetOfThisBlock + channelOutDataOffset / sizeof(int4),
@@ -121,12 +127,14 @@ __global__ void __launch_bounds__(512, 1)
__syncthreads();
for (size_t idx = threadIdx.x; idx < restNInt4; idx += blockDim.x) {
int4 data = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock];
int4 rawData = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock];
AccumVec acc = mscclpp::upcastVector<T, AccumT, AccumVec>(rawData);
for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1;
int4 val = scratch4[chunkSizePerRank * remoteRank + blockOffset + idx];
data = cal_vector<T, OpType>(val, data);
acc = mscclpp::calVectorAccum<T, AccumT, OpType, AccumVec>(acc, val);
}
int4 data = mscclpp::downcastVector<T, AccumT, int4>(acc);
resultBuff4[nInt4PerRank * rank + idx + offsetOfThisBlock] = data;
for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
outChannels[peerIdx].write(nInt4PerRank * rank + idx + offsetOfThisBlock + channelOutDataOffset / sizeof(int4),
@@ -144,7 +152,7 @@ __global__ void __launch_bounds__(512, 1)
}
}
template <ReduceOp OpType, typename T>
template <ReduceOp OpType, typename T, typename AccumT = T>
struct AllreduceAllconnectAdapter {
static cudaError_t call(const void* input, void* scratch, void* output, void* memoryChannels, void* memoryOutChannels,
DeviceHandle<SwitchChannel>*, DeviceHandle<SwitchChannel>*, size_t,
@@ -155,7 +163,7 @@ struct AllreduceAllconnectAdapter {
size_t nelems = inputSize / sizeof(T);
if (nBlocks == 0) nBlocks = 35;
if (nThreadsPerBlock == 0) nThreadsPerBlock = 512;
allreduceFullmesh<OpType, T><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
allreduceFullmesh<OpType, T, AccumT><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
(T*)input, (T*)scratch, (T*)output, (ChannelType*)memoryChannels, (ChannelType*)memoryOutChannels,
channelOutDataOffset, rank, nRanksPerNode, worldSize, nelems);
return cudaGetLastError();
@@ -174,10 +182,10 @@ void AllreduceFullmesh::initialize(std::shared_ptr<Communicator> comm) {
localScratchMemory_ = std::move(localMemory);
}
CommResult AllreduceFullmesh::allreduceKernelFunc(const std::shared_ptr<void> ctx_void, const void* input, void* output,
size_t inputSize, DataType dtype, ReduceOp op, cudaStream_t stream,
int nBlocks, int nThreadsPerBlock,
const std::unordered_map<std::string, uintptr_t>&) {
CommResult AllreduceFullmesh::allreduceKernelFunc(
const std::shared_ptr<void> ctx_void, const void* input, void* output, size_t inputSize, DataType dtype,
ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
[[maybe_unused]] const std::unordered_map<std::string, uintptr_t>& extras, DataType accumDtype) {
auto ctx = std::static_pointer_cast<AlgorithmCtx>(ctx_void);
size_t recvBytes;
CUdeviceptr recvBasePtr;
@@ -198,13 +206,20 @@ CommResult AllreduceFullmesh::allreduceKernelFunc(const std::shared_ptr<void> ct
}
inputChannelHandles = this->memoryChannelsMap_[input].second;
AllreduceFunc allreduce = dispatch<AllreduceAllconnectAdapter>(op, dtype);
AllreduceFunc allreduce = dispatch<AllreduceAllconnectAdapter>(op, dtype, accumDtype);
if (!allreduce) {
WARN("Unsupported operation or data type for allreduce: op=%d, dtype=%d", static_cast<int>(op),
static_cast<int>(dtype));
return CommResult::CommInvalidArgument;
}
std::pair<int, int> numBlocksAndThreads = {nBlocks, nThreadsPerBlock};
if (numBlocksAndThreads.first > 64) {
WARN("AllreduceFullmesh: number of blocks exceeds maximum supported blocks, which is 64");
return mscclpp::CommResult::CommInvalidArgument;
}
if (numBlocksAndThreads.first == 0 || numBlocksAndThreads.second == 0) {
numBlocksAndThreads = {35, 512};
}
cudaError_t error =
allreduce(input, this->scratchBuffer_, output, inputChannelHandles.get(), ctx->memoryChannelDeviceHandles.get(),
nullptr, nullptr, 0, channelOutOffset, 0, ctx->rank, ctx->nRanksPerNode, ctx->workSize, inputSize,
@@ -261,9 +276,10 @@ std::shared_ptr<Algorithm> AllreduceFullmesh::build() {
[self](std::shared_ptr<mscclpp::Communicator> comm) { self->initialize(comm); },
[self](const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
[[maybe_unused]] size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks,
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras) -> CommResult {
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras,
DataType accumDtype) -> CommResult {
return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock,
extras);
extras, accumDtype);
},
[self](std::shared_ptr<Communicator> comm, const void* input, void* output, size_t inputSize,
[[maybe_unused]] size_t outputSize,
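The AccumVec change above widens each packed vector element before the peer reduction and narrows it back only once, at the final store. Below is a minimal standalone sketch of that flow, using __half inputs accumulated in float; the helper names and memory layout are illustrative, not the mscclpp device API.

#include <cstdint>
#include <cuda_fp16.h>

__device__ inline float2 upcastHalf2(uint32_t raw) {
  // Reinterpret the packed word as two halves and widen them to float.
  return __half22float2(*reinterpret_cast<const __half2*>(&raw));
}

__device__ inline uint32_t downcastHalf2(float2 acc) {
  // Narrow back to two halves only once, after all peers are accumulated.
  __half2 h = __float22half2_rn(acc);
  return *reinterpret_cast<uint32_t*>(&h);
}

// peerWords holds nPeers contiguous slabs of nWordsPerPeer packed words; slab 0 is the local data.
__global__ void reduceWidened(const uint32_t* peerWords, uint32_t* out, int nPeers, int nWordsPerPeer) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= nWordsPerPeer) return;
  float2 acc = upcastHalf2(peerWords[idx]);
  for (int p = 1; p < nPeers; ++p) {
    float2 v = upcastHalf2(peerWords[p * nWordsPerPeer + idx]);
    acc.x += v.x;  // accumulate in float rather than in the element type
    acc.y += v.y;
  }
  out[idx] = downcastHalf2(acc);
}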

View File

@@ -146,7 +146,7 @@ __global__ void __launch_bounds__(1024, 1)
#endif
}
template <ReduceOp OpType, typename T>
template <ReduceOp OpType, typename T, typename AccumT = T>
struct NvlsBlockPipelineAdapter {
static cudaError_t call(const void* input, void* scratch, void* output, void* memoryChannels, void*,
DeviceHandle<SwitchChannel>* nvlsChannels, DeviceHandle<SwitchChannel>*, size_t, size_t,
@@ -155,6 +155,9 @@ struct NvlsBlockPipelineAdapter {
// uint8_t is not supported for NVLS (no hardware support for byte-level reduction)
if constexpr (std::is_same_v<T, uint8_t>) {
return cudaErrorNotSupported;
} else if constexpr (std::is_same_v<T, __fp8_e4m3b15>) {
// fp8_e4m3b15 is a software-only type with no hardware NVLS support.
return cudaErrorNotSupported;
} else
#if defined(__CUDA_ARCH__)  // Skip __CUDA_ARCH__ < 1000 since FP8 is not yet supported for NVLS
if constexpr (std::is_same_v<T, __fp8_e4m3> || std::is_same_v<T, __fp8_e5m2>) {
@@ -187,9 +190,10 @@ void AllreduceNvlsBlockPipeline::initialize(std::shared_ptr<Communicator> comm)
CommResult AllreduceNvlsBlockPipeline::allreduceKernelFunc(const std::shared_ptr<void> ctx_void, const void* input,
void* output, size_t inputSize, DataType dtype, ReduceOp op,
cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
const std::unordered_map<std::string, uintptr_t>&) {
const std::unordered_map<std::string, uintptr_t>& extras,
DataType accumDtype) {
auto ctx = std::static_pointer_cast<AlgorithmCtx>(ctx_void);
AllreduceFunc allreduce = dispatch<NvlsBlockPipelineAdapter>(op, dtype);
AllreduceFunc allreduce = dispatch<NvlsBlockPipelineAdapter>(op, dtype, accumDtype);
if (!allreduce) {
WARN("Unsupported operation or data type for allreduce, dtype=%d", static_cast<int>(dtype));
return CommResult::CommInvalidArgument;
@@ -235,9 +239,9 @@ std::shared_ptr<Algorithm> AllreduceNvlsBlockPipeline::build() {
[self](std::shared_ptr<Communicator> comm) { self->initialize(comm); },
[self](const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
[[maybe_unused]] size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks,
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras) {
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras, DataType accumDtype) {
return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock,
extras);
extras, accumDtype);
},
[self](std::shared_ptr<Communicator> comm, const void* input, void* output, size_t inputSize,
[[maybe_unused]] size_t outputSize,

View File

@@ -1,15 +1,17 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#include <type_traits>
#include "allreduce/allreduce_nvls_packet.hpp"
#include "allreduce/common.hpp"
#include "collective_utils.hpp"
#include "debug.h"
#include "logger.hpp"
namespace mscclpp {
namespace collective {
template <ReduceOp OpType, typename T>
template <ReduceOp OpType, typename T, typename AccumT = T>
__global__ void __launch_bounds__(1024, 1)
allreduceNvlsPacket([[maybe_unused]] const T* input, [[maybe_unused]] T* scratch, [[maybe_unused]] T* output,
[[maybe_unused]] mscclpp::DeviceHandle<mscclpp::SwitchChannel>* multicast,
@@ -31,15 +33,16 @@ __global__ void __launch_bounds__(1024, 1)
mscclpp::SwitchChannelDeviceHandle::multimemStore(*(mscclpp::f32x2*)(&pkt), multiPkt + i);
}
for (uint32_t i = tid; i < nPktPerRank * worldSize; i += blockDim.x * gridDim.x) {
uint data = src[i];
// When T == AccumT, stay with raw uint to avoid type mismatch in identity path.
using AccRaw =
std::conditional_t<std::is_same_v<T, AccumT>, uint, mscclpp::VectorType<AccumT, sizeof(uint) / sizeof(T)>>;
AccRaw acc = mscclpp::upcastVector<T, AccumT, AccRaw>(src[i]);
for (int peer = 0; peer < worldSize; peer++) {
if (peer == rank) {
continue;
}
if (peer == rank) continue;
uint val = scratchPkt[peer * worldSize * nPktPerRank + i].read(flag);
data = cal_vector<T, OpType>(data, val);
acc = mscclpp::calVectorAccum<T, AccumT, OpType, AccRaw>(acc, val);
}
dst[i] = data;
dst[i] = mscclpp::downcastVector<T, AccumT, uint>(acc);
}
__syncthreads();
if (threadIdx.x == 0) {
@@ -62,13 +65,13 @@ inline std::pair<int, int> getDefaultBlockNumAndThreadNum(size_t inputSize) {
return {blockNum, threadNum};
}
template <ReduceOp OpType, typename T>
template <ReduceOp OpType, typename T, typename AccumT = T>
struct AllreduceNvlsPacketAdapter {
static cudaError_t call(const void* input, void* scratch, void* output, void*, void*,
DeviceHandle<SwitchChannel>* nvlsChannels, DeviceHandle<SwitchChannel>*, size_t, size_t,
size_t scratchBufferSize, int rank, int, int worldSize, size_t inputSize, cudaStream_t stream,
void* flags, uint32_t flagBufferSize, uint32_t, int nBlocks, int nThreadsPerBlock) {
allreduceNvlsPacket<OpType, T><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
allreduceNvlsPacket<OpType, T, AccumT><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
(const T*)input, (T*)scratch, (T*)output, nvlsChannels, inputSize / sizeof(T), scratchBufferSize, rank,
worldSize, flags, flagBufferSize);
return cudaGetLastError();
@@ -78,6 +81,8 @@ struct AllreduceNvlsPacketAdapter {
void AllreduceNvlsPacket::initialize(std::shared_ptr<Communicator> comm) {
int nSwitchChannels = 1;
this->nvlsConnections_ = setupNvlsConnections(comm, nvlsBufferSize_, nSwitchChannels);
this->switchChannels_ =
setupNvlsChannels(this->nvlsConnections_, this->scratchBuffer_, this->scratchBufferSize_, nSwitchChannels);
}
AlgorithmCtxKey AllreduceNvlsPacket::generateAllreduceContextKey(const void*, void*, size_t, DataType, bool) {
@@ -92,9 +97,7 @@ std::shared_ptr<void> AllreduceNvlsPacket::initAllreduceContext(std::shared_ptr<
ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode();
// setup channels
int nSwitchChannels = 1;
ctx->switchChannels =
setupNvlsChannels(this->nvlsConnections_, this->scratchBuffer_, this->scratchBufferSize_, nSwitchChannels);
ctx->switchChannels = this->switchChannels_;
ctx->switchChannelDeviceHandles = setupNvlsChannelDeviceHandles(ctx->switchChannels);
return ctx;
}
@@ -102,19 +105,20 @@ std::shared_ptr<void> AllreduceNvlsPacket::initAllreduceContext(std::shared_ptr<
CommResult AllreduceNvlsPacket::allreduceKernelFunc(const std::shared_ptr<void> ctx_void, const void* input,
void* output, size_t inputSize, mscclpp::DataType dtype,
ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
const std::unordered_map<std::string, uintptr_t>&) {
const std::unordered_map<std::string, uintptr_t>&,
mscclpp::DataType accumDtype) {
auto ctx = std::static_pointer_cast<AlgorithmCtx>(ctx_void);
std::pair<int, int> blockAndThreadNum = {nBlocks, nThreadsPerBlock};
if (blockAndThreadNum.first == 0 || blockAndThreadNum.second == 0) {
blockAndThreadNum = getDefaultBlockNumAndThreadNum(inputSize);
}
if (blockAndThreadNum.first > maxBlockNum_) {
WARN("Block number %d exceeds the maximum limit %d", blockAndThreadNum.first, maxBlockNum_);
WARN(ALGO, "Block number ", blockAndThreadNum.first, " exceeds the maximum limit ", maxBlockNum_);
return CommResult::CommInvalidArgument;
}
AllreduceFunc allreduce = dispatch<AllreduceNvlsPacketAdapter>(op, dtype);
AllreduceFunc allreduce = dispatch<AllreduceNvlsPacketAdapter>(op, dtype, accumDtype);
if (!allreduce) {
WARN("Unsupported operation or data type for allreduce, dtype=%d", static_cast<int>(dtype));
WARN(ALGO, "Unsupported operation or data type for allreduce, dtype=", static_cast<int>(dtype));
return CommResult::CommInvalidArgument;
}
cudaError_t error =
@@ -122,7 +126,7 @@ CommResult AllreduceNvlsPacket::allreduceKernelFunc(const std::shared_ptr<void>
0, 0, this->scratchBufferSize_, ctx->rank, ctx->nRanksPerNode, ctx->workSize, inputSize, stream,
(void*)flagBuffer_, (uint32_t)flagBufferSize_, 0, blockAndThreadNum.first, blockAndThreadNum.second);
if (error != cudaSuccess) {
WARN("AllreduceNvlsPacket failed with error: %s", cudaGetErrorString(error));
WARN(ALGO, "AllreduceNvlsPacket failed with error: ", cudaGetErrorString(error));
return CommResult::CommUnhandledCudaError;
}
return CommResult::CommSuccess;
@@ -136,9 +140,10 @@ std::shared_ptr<mscclpp::Algorithm> AllreduceNvlsPacket::build() {
[self](std::shared_ptr<mscclpp::Communicator> comm) { self->initialize(comm); },
[self](const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
[[maybe_unused]] size_t outputSize, mscclpp::DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks,
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras) {
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras,
mscclpp::DataType accumDtype) {
return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock,
extras);
extras, accumDtype);
},
[self](std::shared_ptr<mscclpp::Communicator> comm, const void* input, void* output, size_t inputSize,
[[maybe_unused]] size_t outputSize,
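The AccRaw alias above keeps the accumulator as the raw packed word whenever AccumT equals T, so the identity path compiles to the original code with no conversions. A small host-side sketch of that type selection (C++17), with a stand-in VectorType rather than the mscclpp template:

#include <cstdint>
#include <type_traits>

template <typename E, int N>
struct VectorType { E data[N]; };  // stand-in for mscclpp::VectorType

template <typename T, typename AccumT>
using AccRaw = std::conditional_t<std::is_same_v<T, AccumT>, uint32_t,
                                  VectorType<AccumT, sizeof(uint32_t) / sizeof(T)>>;

// One packed word of 1-byte FP8 elements accumulated in float carries four floats;
// when T == AccumT the word stays a plain uint32_t and no conversion code is emitted.
static_assert(std::is_same_v<AccRaw<uint8_t, float>, VectorType<float, 4>>);
static_assert(std::is_same_v<AccRaw<float, float>, uint32_t>);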

View File

@@ -109,7 +109,7 @@ __global__ void __launch_bounds__(1024, 1)
#endif
}
template <ReduceOp OpType, typename T>
template <ReduceOp OpType, typename T, typename AccumT = T>
struct NvlsWarpPipelineAdapter {
static cudaError_t call(const void* input, void* scratch, void* output, void* memoryChannels, void*,
DeviceHandle<SwitchChannel>* nvlsChannels, DeviceHandle<SwitchChannel>*, size_t, size_t,
@@ -118,6 +118,9 @@ struct NvlsWarpPipelineAdapter {
// uint8_t is not supported for NVLS (no hardware support for byte-level reduction)
if constexpr (std::is_same_v<T, uint8_t>) {
return cudaErrorNotSupported;
} else if constexpr (std::is_same_v<T, __fp8_e4m3b15>) {
// fp8_e4m3b15 is a software-only type with no hardware NVLS support.
return cudaErrorNotSupported;
} else
#if defined(__CUDA_ARCH__)  // Skip __CUDA_ARCH__ < 1000 since FP8 is not yet supported for NVLS
if constexpr (std::is_same_v<T, __fp8_e4m3> || std::is_same_v<T, __fp8_e5m2>) {
@@ -147,12 +150,12 @@ void AllreduceNvlsWarpPipeline::initialize(std::shared_ptr<Communicator> comm) {
this->nvlsConnections_ = setupNvlsConnections(comm, nvlsBufferSize_, nSwitchChannels_);
}
CommResult AllreduceNvlsWarpPipeline::allreduceKernelFunc(const std::shared_ptr<void> ctx_void, const void* input,
void* output, size_t inputSize, DataType dtype, ReduceOp op,
cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
const std::unordered_map<std::string, uintptr_t>&) {
CommResult AllreduceNvlsWarpPipeline::allreduceKernelFunc(
const std::shared_ptr<void> ctx_void, const void* input, void* output, size_t inputSize, DataType dtype,
ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
[[maybe_unused]] const std::unordered_map<std::string, uintptr_t>& extras, DataType accumDtype) {
auto ctx = std::static_pointer_cast<AlgorithmCtx>(ctx_void);
AllreduceFunc allreduce = dispatch<NvlsWarpPipelineAdapter>(op, dtype);
AllreduceFunc allreduce = dispatch<NvlsWarpPipelineAdapter>(op, dtype, accumDtype);
if (!allreduce) {
WARN("Unsupported operation or data type for allreduce, dtype=%d", static_cast<int>(dtype));
return CommResult::CommInvalidArgument;
@@ -198,9 +201,9 @@ std::shared_ptr<Algorithm> AllreduceNvlsWarpPipeline::build() {
[self](std::shared_ptr<Communicator> comm) { self->initialize(comm); },
[self](const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
[[maybe_unused]] size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks,
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras) {
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras, DataType accumDtype) {
return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock,
extras);
extras, accumDtype);
},
[self](std::shared_ptr<Communicator> comm, const void* input, void* output, size_t inputSize,
[[maybe_unused]] size_t outputSize,

View File

@@ -67,7 +67,7 @@ __global__ void __launch_bounds__(1024, 1)
#endif
}
template <ReduceOp OpType, typename T>
template <ReduceOp OpType, typename T, typename AccumT = T>
struct NvlsAdapter {
static cudaError_t call(const void*, void*, void*, void* memoryChannels, void*,
mscclpp::DeviceHandle<mscclpp::SwitchChannel>* nvlsChannels,
@@ -77,6 +77,9 @@ struct NvlsAdapter {
// uint8_t is not supported for NVLS (no hardware support for byte-level reduction)
if constexpr (std::is_same_v<T, uint8_t>) {
return cudaErrorNotSupported;
} else if constexpr (std::is_same_v<T, __fp8_e4m3b15>) {
// fp8_e4m3b15 is a software-only type with no hardware NVLS support.
return cudaErrorNotSupported;
} else
#if (!defined(__CUDA_ARCH_SPECIFIC__) && !defined(__CUDA_ARCH_FAMILY_SPECIFIC__)) || (__CUDA_ARCH__ < 1000)
if constexpr (std::is_same_v<T, __fp8_e4m3> || std::is_same_v<T, __fp8_e5m2>) {
@@ -114,13 +117,14 @@ void AllreduceNvls::initialize(std::shared_ptr<mscclpp::Communicator> comm) {
CommResult AllreduceNvls::allreduceKernelFunc(const std::shared_ptr<void> ctx_void, const void* input, void* output,
size_t inputSize, mscclpp::DataType dtype, ReduceOp op,
cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
const std::unordered_map<std::string, uintptr_t>&) {
[[maybe_unused]] const std::unordered_map<std::string, uintptr_t>& extras,
mscclpp::DataType accumDtype) {
if (!symmetricMemory_) {
WARN("AllreduceNvls requires symmetric memory for now.");
return CommResult::CommInvalidArgument;
}
auto ctx = std::static_pointer_cast<AlgorithmCtx>(ctx_void);
AllreduceFunc allreduce = dispatch<NvlsAdapter>(op, dtype);
AllreduceFunc allreduce = dispatch<NvlsAdapter>(op, dtype, accumDtype);
if (!allreduce) {
WARN("Unsupported operation or data type for allreduce, dtype=%d", static_cast<int>(dtype));
return CommResult::CommInvalidArgument;
@@ -203,9 +207,10 @@ std::shared_ptr<mscclpp::Algorithm> AllreduceNvls::build() {
[self](std::shared_ptr<mscclpp::Communicator> comm) { self->initialize(comm); },
[self](const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
[[maybe_unused]] size_t outputSize, mscclpp::DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks,
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras) {
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras,
mscclpp::DataType accumDtype) {
return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock,
extras);
extras, accumDtype);
},
[self](std::shared_ptr<mscclpp::Communicator> comm, const void* input, void* output, size_t inputSize,
[[maybe_unused]] size_t outputSize,

View File

@@ -2,16 +2,17 @@
// Licensed under the MIT License.
#include <mscclpp/algorithm.hpp>
#include <type_traits>
#include "allreduce/allreduce_packet.hpp"
#include "allreduce/common.hpp"
#include "collective_utils.hpp"
#include "debug.h"
#include "logger.hpp"
namespace mscclpp {
namespace collective {
template <ReduceOp OpType, typename T>
template <ReduceOp OpType, typename T, typename AccumT = T>
__global__ void __launch_bounds__(1024, 1)
allreducePacket(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels,
size_t channelDataOffset, size_t scratchBufferSize, int rank, int nRanksPerNode, int worldSize,
@@ -92,12 +93,21 @@ __global__ void __launch_bounds__(1024, 1)
// step 2: get data from scratch buffer, reduce data and write result to remote scratch buffer
for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nPktsPerRank; idx += blockDim.x * gridDim.x) {
uint2 data = src[idx];
for (int index = 0; index < nPeers; index++) {
const int remoteRank = index < rank ? index : index + 1;
mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)scratchBuff + remoteRank * nPktsPerRank;
uint2 val = dstPkt[idx].read(flag);
data.x = cal_vector<T, OpType>(val.x, data.x);
data.y = cal_vector<T, OpType>(val.y, data.y);
{
// When T == AccumT, stay with raw uint32_t to avoid type mismatch in identity path.
using AccRaw = std::conditional_t<std::is_same_v<T, AccumT>, uint32_t,
mscclpp::VectorType<AccumT, sizeof(uint32_t) / sizeof(T)>>;
AccRaw accX = mscclpp::upcastVector<T, AccumT, AccRaw>(data.x);
AccRaw accY = mscclpp::upcastVector<T, AccumT, AccRaw>(data.y);
for (int index = 0; index < nPeers; index++) {
const int remoteRank = index < rank ? index : index + 1;
mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)scratchBuff + remoteRank * nPktsPerRank;
uint2 val = dstPkt[idx].read(flag);
accX = mscclpp::calVectorAccum<T, AccumT, OpType, AccRaw>(accX, val.x);
accY = mscclpp::calVectorAccum<T, AccumT, OpType, AccRaw>(accY, val.y);
}
data.x = mscclpp::downcastVector<T, AccumT, uint32_t>(accX);
data.y = mscclpp::downcastVector<T, AccumT, uint32_t>(accY);
}
dst[idx].x = data.x;
@@ -142,7 +152,7 @@ __global__ void __launch_bounds__(1024, 1)
#endif
}
template <ReduceOp OpType, typename T>
template <ReduceOp OpType, typename T, typename AccumT = T>
struct PacketAdapter {
static cudaError_t call(const void* buff, void* scratch, void* resultBuff, void* memoryChannels, void*,
DeviceHandle<SwitchChannel>*, DeviceHandle<SwitchChannel>*, size_t channelInOffset, size_t,
@@ -155,12 +165,12 @@ struct PacketAdapter {
nBlocks = nBlocks / (worldSize - 1) * (worldSize - 1);
#if defined(ENABLE_NPKIT)
size_t sharedMemSize = sizeof(NpKitEvent) * NPKIT_SHM_NUM_EVENTS;
allreducePacket<OpType><<<nBlocks, nThreadsPerBlock, sharedMemSize, stream>>>(
allreducePacket<OpType, T, AccumT><<<nBlocks, nThreadsPerBlock, sharedMemSize, stream>>>(
(T*)buff, (T*)scratch, (T*)resultBuff, (ChannelType*)memoryChannels, channelInOffset, scratchBufferSize, rank,
nRanksPerNode, worldSize, nelems, flags, flagBufferSize, numScratchBuff, NpKit::GetGpuEventCollectContexts(),
NpKit::GetCpuTimestamp());
#else
allreducePacket<OpType><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
allreducePacket<OpType, T, AccumT><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
(T*)buff, (T*)scratch, (T*)resultBuff, (ChannelType*)memoryChannels, channelInOffset, scratchBufferSize, rank,
nRanksPerNode, worldSize, nelems, flags, flagBufferSize, numScratchBuff);
#endif
@@ -186,18 +196,22 @@ inline std::pair<int, int> getDefaultBlockNumAndThreadNum(size_t inputSize, int
}
}
#if defined(__FP8_TYPES_EXIST__)
// FP8-specific tuning for 32KB-256KB range
if (dtype == DataType::FLOAT8_E4M3 || dtype == DataType::FLOAT8_E5M2) {
if (inputSize < (64 << 10)) {
nThreadsPerBlock = 64;
} else if (inputSize >= (64 << 10) && inputSize <= (128 << 10)) {
nThreadsPerBlock = 128;
} else if (inputSize >= (128 << 10) && inputSize <= (256 << 10)) {
nThreadsPerBlock = 256;
{
bool isFp8 = dtype == DataType::FLOAT8_E4M3B15;
#if defined(__FP8_TYPES_EXIST__)
isFp8 = isFp8 || dtype == DataType::FLOAT8_E4M3 || dtype == DataType::FLOAT8_E5M2;
#endif
if (isFp8) {
if (inputSize < (64 << 10)) {
nThreadsPerBlock = 64;
} else if (inputSize >= (64 << 10) && inputSize <= (128 << 10)) {
nThreadsPerBlock = 128;
} else if (inputSize >= (128 << 10) && inputSize <= (256 << 10)) {
nThreadsPerBlock = 256;
}
}
}
#endif
#endif
return {nBlocks, nThreadsPerBlock};
}
@@ -213,7 +227,8 @@ void AllreducePacket::initialize(std::shared_ptr<Communicator> comm) {
CommResult AllreducePacket::allreduceKernelFunc(const std::shared_ptr<void> ctx_void, const void* input, void* output,
size_t inputSize, [[maybe_unused]] DataType dtype, ReduceOp op,
cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
const std::unordered_map<std::string, uintptr_t>&) {
const std::unordered_map<std::string, uintptr_t>&,
DataType accumDtype) {
auto ctx = std::static_pointer_cast<AlgorithmCtx>(ctx_void);
std::pair<int, int> blockAndThreadNum = {nBlocks, nThreadsPerBlock};
if (blockAndThreadNum.first == 0 || blockAndThreadNum.second == 0) {
@@ -225,9 +240,10 @@ CommResult AllreducePacket::allreduceKernelFunc(const std::shared_ptr<void> ctx_
MSCCLPP_CUTHROW(cuMemGetAddressRange(&sendBasePtr, &sendBytes, (CUdeviceptr)input));
size_t channelInOffset = (char*)input - (char*)sendBasePtr;
AllreduceFunc allreduce = dispatch<PacketAdapter>(op, dtype);
AllreduceFunc allreduce = dispatch<PacketAdapter>(op, dtype, accumDtype);
if (!allreduce) {
WARN("Unsupported operation or data type for allreduce: op=%d, dtype=%d", op, static_cast<int>(dtype));
WARN(ALGO, "Unsupported operation or data type for allreduce: op=", static_cast<int>(op),
", dtype=", static_cast<int>(dtype));
return CommResult::CommInvalidArgument;
}
cudaError_t error =
@@ -236,7 +252,7 @@ CommResult AllreducePacket::allreduceKernelFunc(const std::shared_ptr<void> ctx_
stream, (void*)flagBuffer_, (uint32_t)flagBufferSize_, this->nSegmentsForScratchBuffer_,
blockAndThreadNum.first, blockAndThreadNum.second);
if (error != cudaSuccess) {
WARN("AllreducePacket failed with error: %s", cudaGetErrorString(error));
WARN(ALGO, "AllreducePacket failed with error: ", cudaGetErrorString(error));
return CommResult::CommUnhandledCudaError;
}
return CommResult::CommSuccess;
@@ -280,9 +296,9 @@ std::shared_ptr<Algorithm> AllreducePacket::build() {
"default_allreduce_packet", "allreduce", [self](std::shared_ptr<Communicator> comm) { self->initialize(comm); },
[self](const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
[[maybe_unused]] size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks,
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras) {
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras, DataType accumDtype) {
return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock,
extras);
extras, accumDtype);
},
[self](std::shared_ptr<Communicator> comm, const void* input, void* output, size_t inputSize,
[[maybe_unused]] size_t outputSize,
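For reference, the new FP8 branch in getDefaultBlockNumAndThreadNum picks the thread count purely from the input size. Pulled out as a standalone helper with concrete example sizes (illustrative only, not part of the library):

#include <cstddef>

inline int fp8ThreadsPerBlock(size_t inputBytes, int defaultThreads = 1024) {
  if (inputBytes < (64 << 10)) return 64;     // e.g. a 32 KiB input  -> 64 threads
  if (inputBytes <= (128 << 10)) return 128;  // e.g. a 100 KiB input -> 128 threads
  if (inputBytes <= (256 << 10)) return 256;  // e.g. a 200 KiB input -> 256 threads
  return defaultThreads;                      // larger inputs keep the default setting
}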

View File

@@ -87,7 +87,7 @@ __global__ void __launch_bounds__(1024, 1)
int rankIdx = (rank + i + 1) % nRanksPerNode;
int peerIdx = rankIdx < rank ? rankIdx : rankIdx - 1;
int4 data = mscclpp::read<int4>(((void**)remoteMemories)[peerIdx], offset);
tmp = cal_vector<T, OpType>(data, tmp);
tmp = calVector<T, OpType>(data, tmp);
}
for (uint32_t i = 0; i < nPeers; i++) {
int rankIdx = (rank + i + 1) % nRanksPerNode;
@@ -123,7 +123,7 @@ __global__ void __launch_bounds__(1024, 1)
}
}
template <ReduceOp OpType, typename T>
template <ReduceOp OpType, typename T, typename AccumT = T>
struct AllreduceRsAgAdapter {
static cudaError_t call(const void* input, void* scratch, void* output, void* memoryChannels, void* remoteMemories,
DeviceHandle<SwitchChannel>* switchChannel, DeviceHandle<SwitchChannel>*, size_t, size_t,
@@ -166,9 +166,9 @@ void AllreduceRsAg::initialize(std::shared_ptr<Communicator> comm) {
CommResult AllreduceRsAg::allreduceKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output,
size_t inputSize, DataType dtype, ReduceOp op, cudaStream_t stream,
int nBlocks, int nThreadsPerBlock,
const std::unordered_map<std::string, uintptr_t>&) {
const std::unordered_map<std::string, uintptr_t>&, DataType accumDtype) {
auto algoCtx = std::static_pointer_cast<AlgorithmCtx>(ctx);
AllreduceFunc allreduce = dispatch<AllreduceRsAgAdapter>(op, dtype);
AllreduceFunc allreduce = dispatch<AllreduceRsAgAdapter>(op, dtype, accumDtype);
if (!allreduce) {
WARN(ALGO, "Unsupported operation or data type for allreduce: op=", static_cast<int>(op),
", dtype=", static_cast<int>(dtype));
@@ -213,9 +213,10 @@ std::shared_ptr<Algorithm> AllreduceRsAg::build() {
[self](std::shared_ptr<mscclpp::Communicator> comm) { self->initialize(comm); },
[self](const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
[[maybe_unused]] size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks,
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras) -> CommResult {
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras,
DataType accumDtype) -> CommResult {
return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock,
extras);
extras, accumDtype);
},
[self](std::shared_ptr<Communicator> comm, const void* input, void* output, size_t inputSize,
[[maybe_unused]] size_t outputSize,

View File

@@ -168,7 +168,7 @@ __global__ void __launch_bounds__(1024, 1)
uint32_t peerSlotOffset =
baseOffset + remoteRankId * nInt4PerIter + threadIdInPut + putStep * blockDim.x * nblocksForPut;
int4 data = scratch4[peerSlotOffset];
tmp = cal_vector<T, OpType>(data, tmp);
tmp = calVector<T, OpType>(data, tmp);
}
storeVec(resultBuff, myChunkOffset, tmp, nelems);
// Broadcast reduced result to all peers' scratch at SCATTER_AG_OFFSET + rank * nInt4PerIter
@@ -220,7 +220,7 @@ __global__ void __launch_bounds__(1024, 1)
}
}
template <ReduceOp OpType, typename T>
template <ReduceOp OpType, typename T, typename AccumT = T>
struct AllreduceRsAgPipelineAdapter {
static cudaError_t call(const void* input, void* scratch, void* output, void* memoryChannels, void* remoteMemories,
DeviceHandle<SwitchChannel>* switchChannel, DeviceHandle<SwitchChannel>*, size_t, size_t,
@@ -274,12 +274,12 @@ void AllreduceRsAgPipeline::initialize(std::shared_ptr<Communicator> comm) {
cudaMemcpyHostToDevice);
}
CommResult AllreduceRsAgPipeline::allreduceKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output,
size_t inputSize, DataType dtype, ReduceOp op,
cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
const std::unordered_map<std::string, uintptr_t>&) {
CommResult AllreduceRsAgPipeline::allreduceKernelFunc(
const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize, DataType dtype, ReduceOp op,
cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
[[maybe_unused]] const std::unordered_map<std::string, uintptr_t>& extras, DataType accumDtype) {
auto algoCtx = std::static_pointer_cast<AlgorithmCtx>(ctx);
AllreduceFunc allreduce = dispatch<AllreduceRsAgPipelineAdapter>(op, dtype);
AllreduceFunc allreduce = dispatch<AllreduceRsAgPipelineAdapter>(op, dtype, accumDtype);
if (!allreduce) {
WARN(ALGO, "Unsupported operation or data type for allreduce: op=", static_cast<int>(op),
", dtype=", static_cast<int>(dtype));
@@ -320,9 +320,10 @@ std::shared_ptr<Algorithm> AllreduceRsAgPipeline::build() {
[self](std::shared_ptr<mscclpp::Communicator> comm) { self->initialize(comm); },
[self](const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
[[maybe_unused]] size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks,
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras) -> CommResult {
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras,
DataType accumDtype) -> CommResult {
return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock,
extras);
extras, accumDtype);
},
[self](std::shared_ptr<Communicator> comm, const void* input, void* output, size_t inputSize,
[[maybe_unused]] size_t outputSize,

View File

@@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#include <type_traits>
#include "allreduce/allreduce_rsag_zero_copy.hpp"
#include "allreduce/common.hpp"
#include "collective_utils.hpp"
@@ -36,7 +38,7 @@ __device__ mscclpp::DeviceSyncer globalSyncer;
// the extra copy steps of the standard RSAG. The NRanksPerNode template
// parameter enables compile-time unrolling of peer loops (supports 4 or 8).
template <int NRanksPerNode, ReduceOp OpType, typename T>
template <int NRanksPerNode, ReduceOp OpType, typename T, typename AccumT = T>
__global__ void __launch_bounds__(1024, 1)
allreduceRsAgZeroCopy(T* buff, T* scratch, T* resultBuff, DeviceHandle<BaseMemoryChannel>* memoryChannels,
DeviceHandle<SwitchChannel>* switchChannels, void* remoteMemories, int rank, int worldSize,
@@ -73,19 +75,26 @@ __global__ void __launch_bounds__(1024, 1)
}
__syncthreads();
int4 data[NPeers];
// AccumVec: when AccumT != T, use a wider per-element accumulator type.
// For AccumT == T, this is just int4 (no-op conversion).
constexpr int nElemsPerInt4 = sizeof(int4) / sizeof(T);
// When T == AccumT, stay with raw int4 to avoid type mismatch in identity path.
using AccumVec = std::conditional_t<std::is_same_v<T, AccumT>, int4, mscclpp::VectorType<AccumT, nElemsPerInt4>>;
for (uint32_t idx = threadIdx.x; idx < nInt4PerBlock; idx += blockDim.x) {
uint32_t offset = idx + offset4 + rank * nInt4PerRank;
if (offset >= nInt4Total) continue;
int4 tmp = buff4[offset];
int4 tmp_raw = buff4[offset];
#pragma unroll
for (int i = 0; i < NPeers; i++) {
int rankIdx = (rank + i + 1) % NRanksPerNode;
int peerIdx = rankIdx < rank ? rankIdx : rankIdx - 1;
data[i] = mscclpp::read<int4>(((void**)remoteMemories)[peerIdx], offset);
}
AccumVec acc = mscclpp::upcastVector<T, AccumT, AccumVec>(tmp_raw);
for (int i = 0; i < NPeers; i++) {
tmp = cal_vector<T, OpType>(data[i], tmp);
acc = mscclpp::calVectorAccum<T, AccumT, OpType, AccumVec>(acc, data[i]);
}
int4 tmp = mscclpp::downcastVector<T, AccumT, int4>(acc);
#pragma unroll
for (int i = 0; i < NPeers; i++) {
int rankIdx = (rank + i + 1) % NRanksPerNode;
@@ -102,7 +111,7 @@ __global__ void __launch_bounds__(1024, 1)
}
}
template <ReduceOp OpType, typename T>
template <ReduceOp OpType, typename T, typename AccumT = T>
struct AllreduceRsAgZeroCopyAdapter {
static cudaError_t call(const void* input, void* scratch, void* output, void* memoryChannels, void* remoteMemories,
DeviceHandle<SwitchChannel>* switchChannel, DeviceHandle<SwitchChannel>*, size_t, size_t,
@@ -118,11 +127,11 @@ struct AllreduceRsAgZeroCopyAdapter {
}
}
if (nRanksPerNode == 4) {
allreduceRsAgZeroCopy<4, OpType, T>
allreduceRsAgZeroCopy<4, OpType, T, AccumT>
<<<nBlocks, nThreadsPerBlock, 0, stream>>>((T*)input, (T*)scratch, (T*)output, (ChannelType*)memoryChannels,
switchChannel, remoteMemories, rank, worldSize, nelems);
} else if (nRanksPerNode == 8) {
allreduceRsAgZeroCopy<8, OpType, T>
allreduceRsAgZeroCopy<8, OpType, T, AccumT>
<<<nBlocks, nThreadsPerBlock, 0, stream>>>((T*)input, (T*)scratch, (T*)output, (ChannelType*)memoryChannels,
switchChannel, remoteMemories, rank, worldSize, nelems);
} else {
@@ -145,9 +154,10 @@ void AllreduceRsAgZeroCopy::initialize(std::shared_ptr<Communicator> comm) {
CommResult AllreduceRsAgZeroCopy::allreduceKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output,
size_t inputSize, DataType dtype, ReduceOp op,
cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
const std::unordered_map<std::string, uintptr_t>&) {
const std::unordered_map<std::string, uintptr_t>&,
DataType accumDtype) {
auto algoCtx = std::static_pointer_cast<AlgorithmCtx>(ctx);
AllreduceFunc allreduce = dispatch<AllreduceRsAgZeroCopyAdapter>(op, dtype);
AllreduceFunc allreduce = dispatch<AllreduceRsAgZeroCopyAdapter>(op, dtype, accumDtype);
if (!allreduce) {
WARN(ALGO, "Unsupported operation or data type for allreduce: op=", static_cast<int>(op),
", dtype=", static_cast<int>(dtype));
@@ -220,9 +230,10 @@ std::shared_ptr<Algorithm> AllreduceRsAgZeroCopy::build() {
[self](std::shared_ptr<mscclpp::Communicator> comm) { self->initialize(comm); },
[self](const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
[[maybe_unused]] size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks,
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras) -> CommResult {
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras,
DataType accumDtype) -> CommResult {
return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock,
extras);
extras, accumDtype);
},
[self](std::shared_ptr<Communicator> comm, const void* input, void* output, size_t inputSize,
[[maybe_unused]] size_t outputSize,
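The NRanksPerNode template parameter above turns a runtime rank count into a compile-time constant so the peer loops can be fully unrolled, with only the 4- and 8-rank cases instantiated. A hypothetical toy kernel showing the same dispatch pattern (names, layout, and launch config are placeholders, not the RSAG zero-copy kernel):

#include <cuda_runtime.h>

template <int NRanksPerNode>
__global__ void toyReduce(const float* const* peerBuffs, float* out, int n) {
  constexpr int NPeers = NRanksPerNode - 1;
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= n) return;
  float acc = out[idx];
#pragma unroll
  for (int i = 0; i < NPeers; ++i) acc += peerBuffs[i][idx];  // trip count known at compile time
  out[idx] = acc;
}

inline cudaError_t launchToyReduce(const float* const* peerBuffs, float* out, int n,
                                   int nRanksPerNode, cudaStream_t stream) {
  const int threads = 256;
  const int blocks = (n + threads - 1) / threads;
  if (nRanksPerNode == 4) {
    toyReduce<4><<<blocks, threads, 0, stream>>>(peerBuffs, out, n);
  } else if (nRanksPerNode == 8) {
    toyReduce<8><<<blocks, threads, 0, stream>>>(peerBuffs, out, n);
  } else {
    return cudaErrorNotSupported;  // mirrors the 4/8-only instantiation above
  }
  return cudaGetLastError();
}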

View File

@@ -20,7 +20,7 @@ class AllreduceAllpairPacket : public AlgorithmBuilder {
void initialize(std::shared_ptr<Communicator> comm);
CommResult allreduceKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
const std::unordered_map<std::string, uintptr_t>& extras);
const std::unordered_map<std::string, uintptr_t>& extras, DataType accumDtype);
std::shared_ptr<void> initAllreduceContext(std::shared_ptr<Communicator> comm, const void*, void* output, size_t,
DataType);

View File

@@ -16,7 +16,7 @@ class AllreduceFullmesh : public mscclpp::AlgorithmBuilder {
void initialize(std::shared_ptr<Communicator> comm);
CommResult allreduceKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
const std::unordered_map<std::string, uintptr_t>& extras);
const std::unordered_map<std::string, uintptr_t>& extras, DataType accumDtype);
std::shared_ptr<void> initAllreduceContext(std::shared_ptr<Communicator> comm, const void*, void* output, size_t,
DataType);

View File

@@ -19,7 +19,7 @@ class AllreduceNvlsBlockPipeline : public AlgorithmBuilder {
void initialize(std::shared_ptr<Communicator> comm);
CommResult allreduceKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
const std::unordered_map<std::string, uintptr_t>& extras);
const std::unordered_map<std::string, uintptr_t>& extras, DataType accumDtype);
std::shared_ptr<void> initAllreduceContext(std::shared_ptr<Communicator> comm, const void*, void* output, size_t,
DataType);

View File

@@ -21,7 +21,8 @@ class AllreduceNvlsPacket : public mscclpp::AlgorithmBuilder {
void initialize(std::shared_ptr<mscclpp::Communicator> comm);
CommResult allreduceKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
mscclpp::DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks,
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras);
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras,
mscclpp::DataType accumDtype);
std::shared_ptr<void> initAllreduceContext(std::shared_ptr<mscclpp::Communicator> comm, const void*, void* output,
size_t, mscclpp::DataType);
@@ -34,6 +35,7 @@ class AllreduceNvlsPacket : public mscclpp::AlgorithmBuilder {
uintptr_t flagBuffer_;
size_t flagBufferSize_;
std::vector<std::shared_ptr<NvlsConnection>> nvlsConnections_;
std::vector<SwitchChannel> switchChannels_;
};
} // namespace collective
} // namespace mscclpp

View File

@@ -19,7 +19,7 @@ class AllreduceNvlsWarpPipeline : public AlgorithmBuilder {
void initialize(std::shared_ptr<Communicator> comm);
CommResult allreduceKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
const std::unordered_map<std::string, uintptr_t>& extras);
const std::unordered_map<std::string, uintptr_t>& extras, DataType accumDtype);
std::shared_ptr<void> initAllreduceContext(std::shared_ptr<Communicator> comm, const void*, void* output, size_t,
DataType);

View File

@@ -19,7 +19,7 @@ class AllreduceNvls : public AlgorithmBuilder {
void initialize(std::shared_ptr<Communicator> comm);
CommResult allreduceKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
const std::unordered_map<std::string, uintptr_t>& extras);
const std::unordered_map<std::string, uintptr_t>& extras, DataType accumDtype);
std::shared_ptr<void> initAllreduceContext(std::shared_ptr<Communicator> comm, const void*, void* output, size_t,
DataType);

View File

@@ -20,7 +20,7 @@ class AllreducePacket : public AlgorithmBuilder {
void initialize(std::shared_ptr<Communicator> comm);
CommResult allreduceKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
const std::unordered_map<std::string, uintptr_t>& extras);
const std::unordered_map<std::string, uintptr_t>& extras, DataType accumDtype);
std::shared_ptr<void> initAllreduceContext(std::shared_ptr<Communicator> comm, const void*, void* output, size_t,
DataType);

View File

@@ -19,7 +19,7 @@ class AllreduceRsAg : public mscclpp::AlgorithmBuilder {
void initialize(std::shared_ptr<Communicator> comm);
CommResult allreduceKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
const std::unordered_map<std::string, uintptr_t>& extras);
const std::unordered_map<std::string, uintptr_t>& extras, DataType accumDtype);
std::shared_ptr<void> initAllreduceContext(std::shared_ptr<Communicator> comm, const void*, void* output, size_t,
DataType);

View File

@@ -19,7 +19,7 @@ class AllreduceRsAgPipeline : public mscclpp::AlgorithmBuilder {
void initialize(std::shared_ptr<Communicator> comm);
CommResult allreduceKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
const std::unordered_map<std::string, uintptr_t>& extras);
const std::unordered_map<std::string, uintptr_t>& extras, DataType accumDtype);
std::shared_ptr<void> initAllreduceContext(std::shared_ptr<Communicator> comm, const void*, void* output, size_t,
DataType);

View File

@@ -18,7 +18,7 @@ class AllreduceRsAgZeroCopy : public mscclpp::AlgorithmBuilder {
void initialize(std::shared_ptr<Communicator> comm);
CommResult allreduceKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
const std::unordered_map<std::string, uintptr_t>& extras);
const std::unordered_map<std::string, uintptr_t>& extras, DataType accumDtype);
std::shared_ptr<void> initAllreduceContext(std::shared_ptr<Communicator> comm, const void*, void* output, size_t,
DataType);

View File

@@ -1,8 +1,8 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#ifndef MSCCLPP_ALLREDUCE_COMMOM_HPP_
#define MSCCLPP_ALLREDUCE_COMMOM_HPP_
#ifndef MSCCLPP_ALLREDUCE_COMMON_HPP_
#define MSCCLPP_ALLREDUCE_COMMON_HPP_
#include <cmath>
#include <mscclpp/algorithm.hpp>
@@ -77,55 +77,51 @@ using AllreduceFunc =
mscclpp::DeviceHandle<mscclpp::SwitchChannel>*, size_t, size_t, size_t, int, int, int,
size_t, cudaStream_t, void*, uint32_t, uint32_t, int, int)>;
template <template <ReduceOp, typename> class Adapter>
AllreduceFunc dispatch(ReduceOp op, mscclpp::DataType dtype) {
if (op == SUM) {
if (dtype == mscclpp::DataType::FLOAT16) {
return Adapter<SUM, half>::call;
} else if (dtype == mscclpp::DataType::FLOAT32) {
return Adapter<SUM, float>::call;
#if defined(__CUDA_BF16_TYPES_EXIST__)
} else if (dtype == mscclpp::DataType::BFLOAT16) {
return Adapter<SUM, __bfloat16>::call;
#endif
#if defined(__FP8_TYPES_EXIST__)
} else if (dtype == mscclpp::DataType::FLOAT8_E4M3) {
return Adapter<SUM, __fp8_e4m3>::call;
} else if (dtype == mscclpp::DataType::FLOAT8_E5M2) {
return Adapter<SUM, __fp8_e5m2>::call;
#endif
} else if (dtype == mscclpp::DataType::INT32 || dtype == mscclpp::DataType::UINT32) {
return Adapter<SUM, int>::call;
} else if (dtype == mscclpp::DataType::UINT8) {
return Adapter<SUM, uint8_t>::call;
} else {
return nullptr;
}
} else if (op == MIN) {
if (dtype == mscclpp::DataType::FLOAT16) {
return Adapter<MIN, half>::call;
} else if (dtype == mscclpp::DataType::FLOAT32) {
return Adapter<MIN, float>::call;
#if defined(__CUDA_BF16_TYPES_EXIST__)
} else if (dtype == mscclpp::DataType::BFLOAT16) {
return Adapter<MIN, __bfloat16>::call;
#endif
#if defined(__FP8_TYPES_EXIST__)
} else if (dtype == mscclpp::DataType::FLOAT8_E4M3) {
return Adapter<MIN, __fp8_e4m3>::call;
} else if (dtype == mscclpp::DataType::FLOAT8_E5M2) {
return Adapter<MIN, __fp8_e5m2>::call;
#endif
} else if (dtype == mscclpp::DataType::INT32 || dtype == mscclpp::DataType::UINT32) {
return Adapter<MIN, int>::call;
} else if (dtype == mscclpp::DataType::UINT8) {
return Adapter<MIN, uint8_t>::call;
} else {
return nullptr;
}
/// Dispatch helper for FP8 types with a configurable accumulation type.
template <ReduceOp Op, typename FP8T, template <ReduceOp, typename, typename> class Adapter>
AllreduceFunc dispatchFp8Accum(mscclpp::DataType accumDtype, mscclpp::DataType dtype) {
if (accumDtype == mscclpp::DataType::FLOAT32) {
return Adapter<Op, FP8T, float>::call;
} else if (accumDtype == mscclpp::DataType::FLOAT16) {
return Adapter<Op, FP8T, half>::call;
} else if (accumDtype == dtype) {
return Adapter<Op, FP8T, FP8T>::call;
}
return nullptr;
}
template <ReduceOp Op, template <ReduceOp, typename, typename> class Adapter>
AllreduceFunc dispatchByDtype(mscclpp::DataType dtype, mscclpp::DataType accumDtype) {
if (dtype == mscclpp::DataType::FLOAT16) {
return Adapter<Op, half, half>::call;
} else if (dtype == mscclpp::DataType::FLOAT32) {
return Adapter<Op, float, float>::call;
#if defined(__CUDA_BF16_TYPES_EXIST__)
} else if (dtype == mscclpp::DataType::BFLOAT16) {
return Adapter<Op, __bfloat16, __bfloat16>::call;
#endif
#if defined(__FP8_TYPES_EXIST__)
} else if (dtype == mscclpp::DataType::FLOAT8_E4M3) {
return dispatchFp8Accum<Op, __fp8_e4m3, Adapter>(accumDtype, dtype);
} else if (dtype == mscclpp::DataType::FLOAT8_E5M2) {
return dispatchFp8Accum<Op, __fp8_e5m2, Adapter>(accumDtype, dtype);
#endif
} else if (dtype == mscclpp::DataType::FLOAT8_E4M3B15) {
return dispatchFp8Accum<Op, __fp8_e4m3b15, Adapter>(accumDtype, dtype);
} else if (dtype == mscclpp::DataType::INT32 || dtype == mscclpp::DataType::UINT32) {
return Adapter<Op, int, int>::call;
} else if (dtype == mscclpp::DataType::UINT8) {
return Adapter<Op, uint8_t, uint8_t>::call;
}
return nullptr;
}
template <template <ReduceOp, typename, typename> class Adapter>
AllreduceFunc dispatch(ReduceOp op, mscclpp::DataType dtype, mscclpp::DataType accumDtype) {
if (op == SUM) return dispatchByDtype<SUM, Adapter>(dtype, accumDtype);
if (op == MIN) return dispatchByDtype<MIN, Adapter>(dtype, accumDtype);
return nullptr;
}
} // namespace collective
} // namespace mscclpp
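dispatch now threads accumDtype through dispatchByDtype and, for FP8 inputs, dispatchFp8Accum, so the adapter is instantiated with an independent accumulation type. A compilable toy version of the same three-level template dispatch; the enum values, stand-in element types, and adapter body are placeholders, not the mscclpp definitions:

#include <cstdio>
#include <functional>

enum ReduceOp { SUM, MIN };
enum class DataType { FLOAT16, FLOAT32, FLOAT8_E4M3 };
using Fn = std::function<void()>;

template <ReduceOp Op, typename T, typename AccumT>
struct ToyAdapter {
  static void call() { std::printf("sizeof(T)=%zu sizeof(AccumT)=%zu\n", sizeof(T), sizeof(AccumT)); }
};

template <ReduceOp Op, template <ReduceOp, typename, typename> class Adapter>
Fn dispatchByDtype(DataType dtype, DataType accumDtype) {
  if (dtype == DataType::FLOAT16) return Adapter<Op, short, short>::call;  // short stands in for half
  if (dtype == DataType::FLOAT32) return Adapter<Op, float, float>::call;
  if (dtype == DataType::FLOAT8_E4M3) {  // FP8: honor the requested accumulation type
    if (accumDtype == DataType::FLOAT32) return Adapter<Op, unsigned char, float>::call;
    return Adapter<Op, unsigned char, unsigned char>::call;
  }
  return nullptr;
}

template <template <ReduceOp, typename, typename> class Adapter>
Fn dispatch(ReduceOp op, DataType dtype, DataType accumDtype) {
  if (op == SUM) return dispatchByDtype<SUM, Adapter>(dtype, accumDtype);
  if (op == MIN) return dispatchByDtype<MIN, Adapter>(dtype, accumDtype);
  return nullptr;
}

int main() {
  Fn fn = dispatch<ToyAdapter>(SUM, DataType::FLOAT8_E4M3, DataType::FLOAT32);
  if (fn) fn();  // prints sizeof(T)=1 sizeof(AccumT)=4
}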

View File

@@ -15,7 +15,8 @@ static bool isNvlsSupportedForDataType(const AlgorithmSelectorConfig& config, Da
bool nvlsSupported = config.nvlsSupported;
// NVLS does not support uint8_t (no hardware support for byte-level reduction)
if (dtype == DataType::UINT8) {
// NVLS also does not support float8_e4m3b15 (software-defined type with no hardware NVLS reduction support)
if (dtype == DataType::UINT8 || dtype == DataType::FLOAT8_E4M3B15) {
return false;
}

View File

@@ -43,6 +43,7 @@ inline size_t getDataTypeSize(mscclpp::DataType dtype) {
case mscclpp::DataType::UINT8:
case mscclpp::DataType::FLOAT8_E4M3:
case mscclpp::DataType::FLOAT8_E5M2:
case mscclpp::DataType::FLOAT8_E4M3B15:
return 1;
case mscclpp::DataType::FLOAT16:
case mscclpp::DataType::BFLOAT16:
@@ -76,6 +77,10 @@ static inline ncclDataType_t mscclppToNcclDataType(mscclpp::DataType dtype) {
case mscclpp::DataType::FLOAT8_E5M2:
return ncclFloat8e5m2;
#endif
case mscclpp::DataType::FLOAT8_E4M3B15:
// float8_e4m3b15 has no NCCL equivalent; NCCL cannot reduce this type correctly.
THROW(mscclpp::LogSubsys::NCCL, mscclpp::Error, mscclpp::ErrorCode::InvalidUsage,
"FLOAT8_E4M3B15 (float8_e4m3b15) has no NCCL equivalent and cannot be used with NCCL collectives");
default:
THROW(mscclpp::LogSubsys::NCCL, mscclpp::Error, mscclpp::ErrorCode::InvalidUsage,
"Unsupported mscclpp::DataType: " + std::to_string(static_cast<int>(dtype)));

View File

@@ -83,17 +83,17 @@ static inline int mscclppNcclDlopenInit() {
const char* ncclLibPath = mscclpp::env()->ncclSharedLibPath.c_str();
if (ncclLibPath != nullptr && ncclLibPath[0] != '\0') {
if (std::filesystem::is_directory(ncclLibPath)) {
WARN(MSCCLPP_NCCL, "The value of the environment variable %s is a directory", ncclLibPath);
WARN(MSCCLPP_NCCL, "MSCCLPP_NCCL_LIB_PATH points to a directory: ", ncclLibPath);
return dlopenError;
}
mscclppNcclDlHandle = dlopen(ncclLibPath, RTLD_LAZY | RTLD_NODELETE | RTLD_DEEPBIND);
if (!mscclppNcclDlHandle) {
WARN(MSCCLPP_NCCL, "Cannot open the shared library specified by MSCCLPP_NCCL_LIB_PATH: %s\n", dlerror());
WARN(MSCCLPP_NCCL, "Cannot open the shared library specified by MSCCLPP_NCCL_LIB_PATH: ", dlerror());
return dlopenError;
}
} else {
WARN(MSCCLPP_NCCL, "The value of MSCCLPP_NCCL_LIB_PATH is empty!\n");
WARN(MSCCLPP_NCCL, "The value of MSCCLPP_NCCL_LIB_PATH is empty!");
return dlopenError;
}
@@ -270,19 +270,18 @@ static std::shared_ptr<mscclpp::Algorithm> algoSelector(
return mscclpp::nccl::selectSingleNodeAllreduce(algoMap, request, config);
}
INFO(MSCCLPP_NCCL, "No suitable algorithm found for collective '%s', fallback to nccl/rccl",
request.collective.c_str());
INFO(MSCCLPP_NCCL, "No suitable algorithm found for collective '", request.collective, "', fallback to nccl/rccl");
return nullptr;
}
NCCL_API ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank) {
INFO(MSCCLPP_NCCL, "Initializing NCCL communicator for rank %d, world_size=%d", rank, nranks);
INFO(MSCCLPP_NCCL, "Initializing NCCL communicator for rank ", rank, ", world_size=", nranks);
if (comm == nullptr) {
WARN(MSCCLPP_NCCL, "comm is nullptr");
return ncclInvalidArgument;
}
if (nranks < 0 || rank < 0 || rank >= nranks) {
WARN(MSCCLPP_NCCL, "nranks is %d, rank is %d", nranks, rank);
WARN(MSCCLPP_NCCL, "nranks is ", nranks, ", rank is ", rank);
return ncclInvalidArgument;
}
std::shared_ptr<mscclpp::TcpBootstrap> bootstrap = std::make_shared<mscclpp::TcpBootstrap>(rank, nranks);
@@ -560,8 +559,8 @@ NCCL_API ncclResult_t ncclBroadcast(const void* sendbuff, void* recvbuff, size_t
return ncclInvalidArgument;
}
INFO(MSCCLPP_NCCL, "rank %d broadcast sendbuff %p recvbuff %p count %ld, dtype %d, comm: %p", rank, sendbuff,
recvbuff, count, datatype, comm);
INFO(MSCCLPP_NCCL, "rank ", rank, " broadcast sendbuff ", sendbuff, " recvbuff ", recvbuff, " count ", count,
", dtype ", datatype, ", comm: ", (void*)comm);
const char* fallbackList = mscclpp::env()->forceNcclFallbackOperation.c_str();
if (mscclppNcclDlopenSharedLib == true && mscclppNcclInFallbackList("broadcast", fallbackList)) {
@@ -619,8 +618,8 @@ NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t
}
// Declaring variables
int rank = comm->comm->bootstrap()->getRank();
INFO(MSCCLPP_NCCL, "rank %d allreduce sendbuff %p recvbuff %p count %ld, dtype %d comm is %p", rank, sendbuff,
recvbuff, count, datatype, comm);
INFO(MSCCLPP_NCCL, "rank ", rank, " allreduce sendbuff ", sendbuff, " recvbuff ", recvbuff, " count ", count,
", dtype ", datatype, " comm is ", (void*)comm);
const char* fallbackList = mscclpp::env()->forceNcclFallbackOperation.c_str();
if (mscclppNcclDlopenSharedLib && mscclppNcclInFallbackList("allreduce", fallbackList)) {
@@ -673,8 +672,8 @@ NCCL_API ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff, si
return ncclInvalidArgument;
}
INFO(MSCCLPP_NCCL, "ReduceScatter recvcount: %ld, datatype: %d, op: %d, messageSize: %ld", recvcount, datatype, op,
bytes * comm->comm->bootstrap()->getNranks());
INFO(MSCCLPP_NCCL, "ReduceScatter recvcount: ", recvcount, ", datatype: ", datatype, ", op: ", op,
", messageSize: ", bytes * comm->comm->bootstrap()->getNranks());
const char* fallbackList = mscclpp::env()->forceNcclFallbackOperation.c_str();
if (mscclppNcclDlopenSharedLib == true && mscclppNcclInFallbackList("reducescatter", fallbackList)) {
@@ -730,8 +729,8 @@ NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t
int rank = comm->comm->bootstrap()->getRank();
int nRank = comm->comm->bootstrap()->getNranks();
INFO(MSCCLPP_NCCL, "rank %d allgather sendbuff %p recvbuff %p count %ld, dtype %d, comm %p", rank, sendbuff, recvbuff,
sendcount, datatype, comm);
INFO(MSCCLPP_NCCL, "rank ", rank, " allgather sendbuff ", sendbuff, " recvbuff ", recvbuff, " count ", sendcount,
", dtype ", datatype, ", comm ", (void*)comm);
const char* fallbackList = mscclpp::env()->forceNcclFallbackOperation.c_str();
if (mscclppNcclDlopenSharedLib == true && mscclppNcclInFallbackList("allgather", fallbackList)) {
@@ -866,20 +865,20 @@ ncclResult_t ncclMemAlloc(void** ptr, size_t size) {
}
} catch (const mscclpp::Error& e) {
if (e.getErrorCode() == mscclpp::ErrorCode::InvalidUsage) {
WARN(MSCCLPP_NCCL, "Invalid usage: %s", e.what());
WARN(MSCCLPP_NCCL, "Invalid usage: ", e.what());
return ncclInvalidUsage;
} else {
WARN(MSCCLPP_NCCL, "Internal error: %s", e.what());
WARN(MSCCLPP_NCCL, "Internal error: ", e.what());
return ncclInternalError;
}
} catch (const mscclpp::CudaError& e) {
WARN(MSCCLPP_NCCL, "Cuda error: %s", e.what());
WARN(MSCCLPP_NCCL, "Cuda error: ", e.what());
return ncclUnhandledCudaError;
} catch (const mscclpp::CuError& e) {
WARN(MSCCLPP_NCCL, "Cu error: %s", e.what());
WARN(MSCCLPP_NCCL, "Cu error: ", e.what());
return ncclUnhandledCudaError;
} catch (const mscclpp::BaseError& e) {
WARN(MSCCLPP_NCCL, "Base error: %s", e.what());
WARN(MSCCLPP_NCCL, "Base error: ", e.what());
return ncclInternalError;
}
ptrMap[sharedPtr.get()] = sharedPtr;

View File

@@ -30,6 +30,12 @@ fi
if [ "${PLATFORM}" == "rocm" ]; then
export CXX=/opt/rocm/bin/hipcc
fi
PIP_CMAKE_ARGS_FILE="/root/mscclpp/pip_cmake_args.txt"
if [ -f "${PIP_CMAKE_ARGS_FILE}" ]; then
export CMAKE_ARGS="$(cat ${PIP_CMAKE_ARGS_FILE})"
echo "Using CMAKE_ARGS: ${CMAKE_ARGS}"
fi
cd /root/mscclpp && pip3 install .
pip3 install setuptools_scm
python3 -m setuptools_scm --force-write-version-files

View File

@@ -8,18 +8,32 @@
#include "mp_unit_tests.hpp"
#include "utils_internal.hpp"
// Skip the current test if HostNoAtomic mode is not supported.
// On CUDA, HostNoAtomic requires GDRCopy for BAR1 signal forwarding.
// On ROCm, HostNoAtomic uses direct volatile writes and does not need GDRCopy.
// Skip the current test if the given IB mode will require GDRCopy on CUDA but it is unavailable.
// On CUDA, HostNoAtomic requires GDRCopy for BAR1 signal forwarding. When IbMode::Host or
// IbMode::Default is used and the IB device does not support RDMA atomics, the endpoint falls
// back to no-atomic mode, which also requires GDRCopy.
// On ROCm, no-atomic mode uses direct volatile writes and does not need GDRCopy.
#if defined(MSCCLPP_USE_CUDA)
#define REQUIRE_HOST_NO_ATOMIC \
do { \
if (!mscclpp::gdrEnabled()) { \
SKIP_TEST() << "HostNoAtomic requires GDRCopy: " << mscclpp::gdrStatusMessage(); \
} \
} while (0)
inline void requireGdrForIbMode(IbMode mode, mscclpp::Transport ibTransport) {
if (mscclpp::gdrEnabled()) return; // GDRCopy available — nothing to skip.
if (mode == IbMode::HostNoAtomic) {
SKIP_TEST() << "HostNoAtomic requires GDRCopy on CUDA: " << mscclpp::gdrStatusMessage();
}
// For Host/Default modes: check whether the IB device lacks RDMA atomics,
// which would cause an automatic fallback to no-atomic mode.
if (mode == IbMode::Host || mode == IbMode::Default) {
std::string devName = mscclpp::getIBDeviceName(ibTransport);
mscclpp::IbCtx ibCtx(devName);
if (!ibCtx.supportsRdmaAtomics()) {
SKIP_TEST() << "IB device " << devName
<< " lacks RDMA atomics; Host mode falls back to HostNoAtomic which requires GDRCopy: "
<< mscclpp::gdrStatusMessage();
}
}
}
#define REQUIRE_GDR_FOR_IB_MODE(mode) requireGdrForIbMode((mode), ibTransport)
#else
#define REQUIRE_HOST_NO_ATOMIC // No extra requirements on non-CUDA platforms.
#define REQUIRE_GDR_FOR_IB_MODE(mode) // No extra requirements on non-CUDA platforms.
#endif
void PortChannelOneToOneTest::SetUp() {
@@ -254,6 +268,7 @@ TEST(PortChannelOneToOneTest, PingPong) {
TEST(PortChannelOneToOneTest, PingPongIbHostMode) {
REQUIRE_IBVERBS;
REQUIRE_GDR_FOR_IB_MODE(IbMode::Host);
testPingPong(PingPongTestParams{
.useIPC = false, .useIB = true, .useEthernet = false, .waitWithPoll = false, .ibMode = IbMode::Host});
}
@@ -270,6 +285,7 @@ TEST(PortChannelOneToOneTest, PingPongWithPoll) {
TEST(PortChannelOneToOneTest, PingPongIbHostModeWithPoll) {
REQUIRE_IBVERBS;
REQUIRE_GDR_FOR_IB_MODE(IbMode::Host);
testPingPong(PingPongTestParams{
.useIPC = false, .useIB = true, .useEthernet = false, .waitWithPoll = true, .ibMode = IbMode::Host});
}
@@ -281,13 +297,14 @@ PERF_TEST(PortChannelOneToOneTest, PingPongPerf) {
PERF_TEST(PortChannelOneToOneTest, PingPongPerfIbHostMode) {
REQUIRE_IBVERBS;
REQUIRE_GDR_FOR_IB_MODE(IbMode::Host);
testPingPongPerf(PingPongTestParams{
.useIPC = false, .useIB = true, .useEthernet = false, .waitWithPoll = false, .ibMode = IbMode::Host});
}
PERF_TEST(PortChannelOneToOneTest, PingPongPerfIbHostNoAtomicMode) {
REQUIRE_IBVERBS;
REQUIRE_HOST_NO_ATOMIC;
REQUIRE_GDR_FOR_IB_MODE(IbMode::HostNoAtomic);
testPingPongPerf(PingPongTestParams{
.useIPC = false, .useIB = true, .useEthernet = false, .waitWithPoll = false, .ibMode = IbMode::HostNoAtomic});
}
@@ -469,6 +486,7 @@ TEST(PortChannelOneToOneTest, PacketPingPong) { testPacketPingPong(false, IbMode
TEST(PortChannelOneToOneTest, PacketPingPongIbHostMode) {
REQUIRE_IBVERBS;
REQUIRE_GDR_FOR_IB_MODE(IbMode::Host);
testPacketPingPong(true, IbMode::Host);
}
@@ -476,25 +494,26 @@ PERF_TEST(PortChannelOneToOneTest, PacketPingPongPerf) { testPacketPingPongPerf(
PERF_TEST(PortChannelOneToOneTest, PacketPingPongPerfIbHostMode) {
REQUIRE_IBVERBS;
REQUIRE_GDR_FOR_IB_MODE(IbMode::Host);
testPacketPingPongPerf(true, IbMode::Host);
}
PERF_TEST(PortChannelOneToOneTest, PacketPingPongPerfIbHostNoAtomicMode) {
REQUIRE_IBVERBS;
REQUIRE_HOST_NO_ATOMIC;
REQUIRE_GDR_FOR_IB_MODE(IbMode::HostNoAtomic);
testPacketPingPongPerf(true, IbMode::HostNoAtomic);
}
TEST(PortChannelOneToOneTest, PingPongIbHostNoAtomicMode) {
REQUIRE_IBVERBS;
REQUIRE_HOST_NO_ATOMIC;
REQUIRE_GDR_FOR_IB_MODE(IbMode::HostNoAtomic);
testPingPong(PingPongTestParams{
.useIPC = false, .useIB = true, .useEthernet = false, .waitWithPoll = false, .ibMode = IbMode::HostNoAtomic});
}
TEST(PortChannelOneToOneTest, PacketPingPongIbHostNoAtomicMode) {
REQUIRE_IBVERBS;
REQUIRE_HOST_NO_ATOMIC;
REQUIRE_GDR_FOR_IB_MODE(IbMode::HostNoAtomic);
testPacketPingPong(true, IbMode::HostNoAtomic);
}
@@ -570,13 +589,14 @@ PERF_TEST(PortChannelOneToOneTest, Bandwidth) {
PERF_TEST(PortChannelOneToOneTest, BandwidthIbHostMode) {
REQUIRE_IBVERBS;
REQUIRE_GDR_FOR_IB_MODE(IbMode::Host);
testBandwidth(PingPongTestParams{
.useIPC = false, .useIB = true, .useEthernet = false, .waitWithPoll = false, .ibMode = IbMode::Host});
}
PERF_TEST(PortChannelOneToOneTest, BandwidthIbHostNoAtomicMode) {
REQUIRE_IBVERBS;
REQUIRE_HOST_NO_ATOMIC;
REQUIRE_GDR_FOR_IB_MODE(IbMode::HostNoAtomic);
testBandwidth(PingPongTestParams{
.useIPC = false, .useIB = true, .useEthernet = false, .waitWithPoll = false, .ibMode = IbMode::HostNoAtomic});
}

View File

@@ -14,25 +14,25 @@ def parse_npkit_event_header(npkit_event_header_path):
"NOP",
"BARRIER",
"PUT",
"PUT_PACKET",
"READ_PUT_PACKET",
"PUT_PACKETS",
"READ_PUT_PACKETS",
"PUT_WITH_SIGNAL",
"PUT_WITH_SIGNAL_AND_FLUSH",
"GET",
"COPY",
"COPY_PACKET",
"TRANSFORM_TO_PACKET",
"COPY_PACKETS",
"UNPACK_PACKETS",
"SIGNAL",
"WAIT",
"FLUSH",
"REDUCE",
"REDUCE_PACKET",
"REDUCE_PACKETS",
"REDUCE_COPY_PACKETS",
"REDUCE_SEND",
"REDUCE_SEND_PACKET",
"REDUCE_SEND_PACKETS",
"REDUCE_COPY_SEND_PACKETS",
"READ_REDUCE_COPY",
"READ_REDUCE_COPY_SEND",
"READ_REDUCE",
"READ_REDUCE_SEND",
"MULTI_LOAD_REDUCE_STORE",
"RELAXED_SIGNAL",
"RELAXED_WAIT",