Mirror of https://github.com/microsoft/mscclpp.git (synced 2026-05-11 17:00:22 +00:00)

Commit: merge main
@@ -94,7 +94,27 @@ steps:
      du -sh build/bin/* 2>/dev/null || true
    workingDirectory: '$(System.DefaultWorkingDirectory)'

# 2. Download SSH key + install packages + start VMSS
# 2. Write CMake args for pip install on remote VMs
- task: Bash@3
  name: WritePipCmakeArgs
  displayName: Write pip CMake args
  inputs:
    targetType: 'inline'
    script: |
      set -e
      PIP_CMAKE_ARGS=""
      if [ -n "${{ parameters.gpuArch }}" ]; then
        PIP_CMAKE_ARGS="-DMSCCLPP_GPU_ARCHS=${{ parameters.gpuArch }}"
      fi
      CMAKE_EXTRA_ARGS='${{ parameters.cmakeArgs }}'
      if [ -n "${CMAKE_EXTRA_ARGS}" ]; then
        PIP_CMAKE_ARGS="${PIP_CMAKE_ARGS} ${CMAKE_EXTRA_ARGS}"
      fi
      echo "${PIP_CMAKE_ARGS}" > pip_cmake_args.txt
      echo "pip CMake args: $(cat pip_cmake_args.txt)"
    workingDirectory: '$(System.DefaultWorkingDirectory)'

# 3. Download SSH key + install packages + start VMSS
- task: DownloadSecureFile@1
  name: SshKeyFile
  displayName: Download key file
@@ -120,7 +140,7 @@ steps:
    inlineScript: |
      az vmss start --name ${{ parameters.vmssName }} --resource-group ${{ parameters.resourceGroup }}

# 3. Deploy test environment
# 4. Deploy test environment
- task: Bash@3
  name: DeployTestEnv
  displayName: Deploy Test Env

@@ -28,7 +28,7 @@ steps:
      grep -q NPKIT_EVENT_EXECUTOR_INIT_ENTRY ./npkit_output/npkit_event_trace.json
      grep -q NPKIT_EVENT_EXECUTOR_SIGNAL_ENTRY ./npkit_output/npkit_event_trace.json
      grep -q NPKIT_EVENT_EXECUTOR_WAIT_ENTRY ./npkit_output/npkit_event_trace.json
      grep -q NPKIT_EVENT_EXECUTOR_READ_REDUCE_COPY_SEND_ENTRY ./npkit_output/npkit_event_trace.json
      grep -q NPKIT_EVENT_EXECUTOR_READ_REDUCE_SEND_ENTRY ./npkit_output/npkit_event_trace.json

- template: run-remote-task.yml
  parameters:
@@ -42,14 +42,14 @@ steps:
      grep -q NPKIT_EVENT_EXECUTOR_INIT_ENTRY ./npkit_output/npkit_event_trace.json
      grep -q NPKIT_EVENT_EXECUTOR_SIGNAL_ENTRY ./npkit_output/npkit_event_trace.json
      grep -q NPKIT_EVENT_EXECUTOR_WAIT_ENTRY ./npkit_output/npkit_event_trace.json
      grep -q NPKIT_EVENT_EXECUTOR_READ_REDUCE_COPY_SEND_ENTRY ./npkit_output/npkit_event_trace.json
      grep -q NPKIT_EVENT_EXECUTOR_READ_REDUCE_SEND_ENTRY ./npkit_output/npkit_event_trace.json
      rm -rf ./npkit_dump && mkdir ./npkit_dump && rm -rf ./npkit_output && mkdir ./npkit_output
      mpirun --allow-run-as-root -tag-output -x MSCCLPP_HOME=/root/mscclpp -np 8 python3 -m pytest ./python/test/test_mscclpp.py -x -k 'test_executor[allreduce_packet.json'
      python3 ./tools/npkit/npkit_trace_generator.py --npkit_dump_dir=./npkit_dump --npkit_event_header_path=./include/mscclpp/npkit/npkit_event.hpp --output_dir=./npkit_output
      grep -q NPKIT_EVENT_EXECUTOR_INIT_ENTRY ./npkit_output/npkit_event_trace.json
      grep -q NPKIT_EVENT_EXECUTOR_COPY_PACKET_ENTRY ./npkit_output/npkit_event_trace.json
      grep -q NPKIT_EVENT_EXECUTOR_PUT_PACKET_ENTRY ./npkit_output/npkit_event_trace.json
      grep -q NPKIT_EVENT_EXECUTOR_REDUCE_SEND_PACKET_ENTRY ./npkit_output/npkit_event_trace.json
      grep -q NPKIT_EVENT_EXECUTOR_PUT_PACKETS_ENTRY ./npkit_output/npkit_event_trace.json
      grep -q NPKIT_EVENT_EXECUTOR_REDUCE_SEND_PACKETS_ENTRY ./npkit_output/npkit_event_trace.json
      grep -q NPKIT_EVENT_EXECUTOR_UNPACK_PACKETS_ENTRY ./npkit_output/npkit_event_trace.json

- template: stop.yml
  parameters:
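A hedged Python equivalent of the grep checks above, useful when debugging a trace file locally (illustrative only; it searches the serialized JSON the same way grep does):

    import json
    with open("./npkit_output/npkit_event_trace.json") as f:
        trace_text = json.dumps(json.load(f))
    for name in ("NPKIT_EVENT_EXECUTOR_INIT_ENTRY", "NPKIT_EVENT_EXECUTOR_COPY_PACKET_ENTRY"):
        assert name in trace_text, f"missing {name}"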
@@ -41,6 +41,7 @@ steps:
  displayName: Run pytests
  remoteScript: |
    mpirun --allow-run-as-root -tag-output -x MSCCLPP_HOME=/root/mscclpp -x GPU_MAX_HW_QUEUES=8 -np 8 python3 -m pytest ./python/test/test_mscclpp.py -x
    mpirun --allow-run-as-root -tag-output -x MSCCLPP_HOME=/root/mscclpp -x GPU_MAX_HW_QUEUES=8 -np 8 python3 -m pytest ./python/test/test_fp8_accum.py -x

- template: stop.yml
  parameters:
@@ -30,19 +30,15 @@ find_library(GDRCOPY_LIBRARIES
    ${GDRCOPY_ROOT_DIR}/lib
    /usr/local/lib
    /usr/lib
    /usr/lib/x86_64-linux-gnu
    /usr/lib/aarch64-linux-gnu)
    /usr/lib/x86_64-linux-gnu)

if(GDRCOPY_INCLUDE_DIRS)
  include(CheckCXXSourceCompiles)
  include(CheckSymbolExists)
  set(CMAKE_REQUIRED_INCLUDES ${GDRCOPY_INCLUDE_DIRS})
  set(CMAKE_REQUIRED_LIBRARIES ${GDRCOPY_LIBRARIES})
  check_cxx_source_compiles("
    #include <gdrapi.h>
    int main() { gdr_pin_buffer_v2(0, 0, 0, 0, 0); return 0; }
  " GDRCOPY_HAS_PIN_BUFFER_V2)
  unset(CMAKE_REQUIRED_INCLUDES)
  check_symbol_exists(gdr_pin_buffer_v2 "gdrapi.h" GDRCOPY_HAS_PIN_BUFFER_V2)
  unset(CMAKE_REQUIRED_LIBRARIES)
  unset(CMAKE_REQUIRED_INCLUDES)
  if(NOT GDRCOPY_HAS_PIN_BUFFER_V2)
    message(STATUS "GDRCopy found but too old (gdr_pin_buffer_v2 not available). Requires >= 2.5.")
    set(GDRCOPY_INCLUDE_DIRS GDRCOPY_INCLUDE_DIRS-NOTFOUND)
@@ -52,7 +52,7 @@ RUN OS_ARCH=$(uname -m) && \
# Install GDRCopy userspace library for CUDA targets
ARG TARGET="cuda13.0"
RUN if echo "$TARGET" | grep -q "^cuda"; then \
        GDRCOPY_VERSION="2.5.1" && \
        GDRCOPY_VERSION="2.5.2" && \
        apt-get update -y && \
        apt-get install -y --no-install-recommends devscripts debhelper fakeroot pkg-config dkms && \
        cd /tmp && \
@@ -5,7 +5,7 @@
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SPHINXMULTIVERSION ?= sphinx-multiversion
SPHINXMULTIVERSION ?= python3 build_multiversion.py
SOURCEDIR = .
BUILDDIR = _build

docs/build_multiversion.py (new file, 49 lines)
@@ -0,0 +1,49 @@
#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Wrapper around sphinx-multiversion that patches copy_tree to generate
_version.py in each tag checkout. This is needed because setuptools_scm
generates _version.py at build time, but sphinx-multiversion uses
`git archive`, which only contains committed files.

Usage (called by Makefile):
    python3 build_multiversion.py <sourcedir> <outputdir> [sphinx-opts...]
"""

import os
import re
import subprocess
import sys

import sphinx_multiversion.git as smv_git
from sphinx_multiversion import main as smv_main

# Save the original copy_tree
_original_copy_tree = smv_git.copy_tree


def _patched_copy_tree(gitroot, src, dst, reference, sourcepath="."):
    """Call the original copy_tree, then generate _version.py from the tag name."""
    _original_copy_tree(gitroot, src, dst, reference, sourcepath)

    # Extract version from the tag name (e.g., "v0.9.0" -> "0.9.0")
    refname = getattr(reference, "refname", "") or ""
    match = re.search(r"v(\d+\.\d+\.\d+)", refname)
    if not match:
        return

    version = match.group(1)
    version_py_dir = os.path.join(dst, "python", "mscclpp")
    if os.path.isdir(version_py_dir):
        version_py = os.path.join(version_py_dir, "_version.py")
        if not os.path.exists(version_py):
            with open(version_py, "w") as f:
                f.write(f'__version__ = "{version}"\n')


# Monkey-patch
smv_git.copy_tree = _patched_copy_tree

if __name__ == "__main__":
    sys.exit(smv_main(sys.argv[1:]))
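The tag-to-version extraction in the script above is a plain regex; a quick worked example (illustrative only):

    import re
    assert re.search(r"v(\d+\.\d+\.\d+)", "refs/tags/v0.9.0").group(1) == "0.9.0"
    assert re.search(r"v(\d+\.\d+\.\d+)", "main") is None  # non-release refs are skipped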
@@ -332,7 +332,8 @@ public:
        size_t inputSize, size_t outputSize,
        mscclpp::DataType dtype, mscclpp::ReduceOp op,
        cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
        const std::unordered_map<std::string, uintptr_t>& extras) {
        const std::unordered_map<std::string, uintptr_t>& extras,
        [[maybe_unused]] mscclpp::DataType accumDtype) {
        return self->kernelFunc(ctx, input, output, inputSize, dtype, stream);
      },
      // Context initialization function
@@ -31,7 +31,7 @@
```
If you don't want to build the Python module, you need to set `-DMSCCLPP_BUILD_PYTHON_BINDINGS=OFF` in your `cmake` command (see details in [Install from Source](#install-from-source)).
* (Optional, for benchmarks) MPI
* (Optional, for NVIDIA platforms) [GDRCopy](https://github.com/NVIDIA/gdrcopy) >= 2.5.0
* (Optional, for NVIDIA platforms) [GDRCopy](https://github.com/NVIDIA/gdrcopy) >= 2.5.1
  * GDRCopy is required for IB `HostNoAtomic` mode, which uses CPU-side signal forwarding to GPU memory via BAR1 mappings. This mode is used on platforms where RDMA atomics are not available (e.g., when using Data Direct Virtual Functions).
  * Install GDRCopy from source or via packages. See the [GDRCopy installation guide](https://github.com/NVIDIA/gdrcopy#installation).
* Others
@@ -101,7 +101,8 @@ class AllgatherAlgoBuilder : public mscclpp::AlgorithmBuilder {
      "allgather", "allgather", [self](std::shared_ptr<mscclpp::Communicator> comm) { self->initialize(comm); },
      [self](const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize, size_t outputSize,
             mscclpp::DataType dtype, [[maybe_unused]] mscclpp::ReduceOp op, cudaStream_t stream, int nBlocks,
             int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras) {
             int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras,
             [[maybe_unused]] mscclpp::DataType accumDtype) {
        return self->allgatherKernelFunc(ctx, input, output, inputSize, stream);
      },
      [self](std::shared_ptr<mscclpp::Communicator> comm, const void* input, void* output, size_t inputSize,
@@ -69,7 +69,8 @@ class AllgatherAlgoBuilder : public mscclpp::AlgorithmBuilder {
      "allgather", "allgather", [self](std::shared_ptr<mscclpp::Communicator> comm) { self->initialize(comm); },
      [self](const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize, size_t outputSize,
             mscclpp::DataType dtype, [[maybe_unused]] mscclpp::ReduceOp op, cudaStream_t stream, int nBlocks,
             int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras) {
             int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras,
             [[maybe_unused]] mscclpp::DataType accumDtype) {
        return self->allgatherKernelFunc(ctx, input, output, inputSize, dtype, stream);
      },
      [self](std::shared_ptr<mscclpp::Communicator> comm, const void* input, void* output, size_t inputSize,
@@ -1,193 +1,117 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# MSCCLPP_MASTER_ADDR=<master_ip> MSCCLPP_MASTER_PORT=<port> torchrun --nnodes=1 --nproc_per_node=8 customized_comm_with_tuning.py
|
||||
# torchrun --nnodes=1 --nproc_per_node=8 examples/torch-integration/customized_comm_with_tuning.py
|
||||
|
||||
import os
|
||||
import torch
|
||||
import mscclpp.utils as mscclpp_utils
|
||||
import mscclpp
|
||||
import mscclpp.ext
|
||||
import netifaces as ni
|
||||
import ipaddress
|
||||
|
||||
import netifaces as ni
|
||||
import torch
|
||||
import mscclpp
|
||||
import mscclpp.ext
|
||||
import mscclpp.utils as mscclpp_utils
|
||||
|
||||
def load_algorithms(scratch_buffer: torch.tensor, rank: int) -> mscclpp.AlgorithmCollection:
|
||||
collection_builder = mscclpp.ext.AlgorithmCollectionBuilder()
|
||||
return collection_builder.build_default_algorithms(
|
||||
scratch_buffer=scratch_buffer.data_ptr(), scratch_buffer_size=scratch_buffer.nbytes, rank=rank
|
||||
# -- Helpers ------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_tensor(size_bytes: int, dtype: torch.dtype) -> torch.Tensor:
|
||||
"""Allocate a tensor backed by RawGpuBuffer (symmetric memory)."""
|
||||
# PyTorch's from_dlpack does not support certain float8 DLPack type codes.
|
||||
# Work around by importing as uint8 and reinterpreting via .view().
|
||||
_DLPACK_UNSUPPORTED = (torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz)
|
||||
if dtype in _DLPACK_UNSUPPORTED:
|
||||
dlpack = mscclpp.RawGpuBuffer(size_bytes).to_dlpack(data_type=str(torch.uint8))
|
||||
return torch.utils.dlpack.from_dlpack(dlpack).view(dtype)
|
||||
dlpack = mscclpp.RawGpuBuffer(size_bytes).to_dlpack(data_type=str(dtype))
|
||||
return torch.utils.dlpack.from_dlpack(dlpack)
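# Illustrative usage of the helper above (hypothetical sizes): float8 dtypes take the
# uint8 + .view() workaround, other dtypes go through DLPack directly.
#   t = _make_tensor(1 << 20, torch.float8_e4m3fn)   # 1 MiB -> 2**20 one-byte elements
#   assert t.dtype == torch.float8_e4m3fn and t.numel() == 1 << 20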
|
||||
|
||||
|
||||
def _load_algorithms(scratch: torch.Tensor, rank: int):
|
||||
return mscclpp.ext.AlgorithmCollectionBuilder().build_default_algorithms(
|
||||
scratch_buffer=scratch.data_ptr(),
|
||||
scratch_buffer_size=scratch.nbytes,
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
|
||||
def interfaces_for_ip_netifaces(ip: str):
|
||||
def _interfaces_for_ip(ip: str):
|
||||
target = ipaddress.ip_address(ip)
|
||||
for interface in ni.interfaces():
|
||||
addresses = ni.ifaddresses(interface)
|
||||
if ni.AF_INET in addresses:
|
||||
for link in addresses[ni.AF_INET]:
|
||||
if "addr" in link:
|
||||
addr = ipaddress.ip_address(link["addr"])
|
||||
if addr == target:
|
||||
return interface
|
||||
for iface in ni.interfaces():
|
||||
addrs = ni.ifaddresses(iface)
|
||||
if ni.AF_INET in addrs:
|
||||
for link in addrs[ni.AF_INET]:
|
||||
if "addr" in link and ipaddress.ip_address(link["addr"]) == target:
|
||||
return iface
|
||||
return None
|
||||
|
||||
|
||||
def to_mscclpp_reduce_op(op: torch.distributed.ReduceOp) -> mscclpp.ReduceOp:
|
||||
def _to_mscclpp_op(op) -> mscclpp.ReduceOp:
|
||||
if op == torch.distributed.ReduceOp.SUM:
|
||||
return mscclpp.ReduceOp.SUM
|
||||
elif op == torch.distributed.ReduceOp.MIN:
|
||||
if op == torch.distributed.ReduceOp.MIN:
|
||||
return mscclpp.ReduceOp.MIN
|
||||
else:
|
||||
raise ValueError(f"unsupported op: {op}")
|
||||
raise ValueError(f"unsupported op: {op}")
|
||||
|
||||
|
||||
def _round_pow2(size: int) -> int:
    """Round up to next power-of-2, clamped to [1024, 256 MB]."""
    size = max(size, 1024)
    size = min(size, 256 << 20)
    return 1 << (size - 1).bit_length()

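# Illustrative check of the size bucketing above (plain integer arithmetic, no GPU needed):
assert _round_pow2(1) == 1024           # clamped to the 1 KiB floor
assert _round_pow2(300_000) == 1 << 19  # ~293 KB rounds up to 512 KiB
assert _round_pow2(1 << 20) == 1 << 20  # exact powers of two are unchanged
assert _round_pow2(1 << 30) == 1 << 28  # clamped to the 256 MB ceiling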
# -- CustomizedComm -----------------------------------------------------------
|
||||
|
||||
|
||||
class CustomizedComm:
|
||||
def __init__(self, comm: mscclpp.CommGroup):
|
||||
"""Exposes all_reduce, all_gather, barrier with lazy per-size tuning."""
|
||||
|
||||
_TUNE_N_WARMUP = 5
|
||||
_TUNE_N_GRAPH_LAUNCHES = 10
|
||||
_TUNE_N_OPS_PER_GRAPH = 100
|
||||
_CANDIDATE_NBLOCKS = [4, 8, 16, 24, 32, 48, 64, 128]
|
||||
_CANDIDATE_NTHREADS = [512, 768, 1024]
|
||||
_NBLOCKS_LIMIT = {
|
||||
"default_allreduce_nvls_packet": 16,
|
||||
"default_allreduce_packet": 56,
|
||||
"default_allreduce_allpair_packet": 56,
|
||||
"default_allreduce_fullmesh": 64,
|
||||
"default_allgather_fullmesh2": 32,
|
||||
}
|
||||
|
||||
def __init__(self, comm: mscclpp.CommGroup, symmetric_memory: bool = False):
|
||||
self.comm = comm
|
||||
self.rank = comm.my_rank
|
||||
self.world_size = comm.nranks
|
||||
self.local_rank = comm.my_rank % comm.nranks_per_node
|
||||
self.n_ranks_per_node = comm.nranks_per_node
|
||||
dlpack = mscclpp.RawGpuBuffer(1 << 27).to_dlpack(data_type=str(torch.float16))
|
||||
self.scratch_buffer = torch.utils.dlpack.from_dlpack(dlpack)
|
||||
algorithms = load_algorithms(scratch_buffer=self.scratch_buffer, rank=self.rank)
|
||||
self._algorithm_nvls_packet = [
|
||||
algo
|
||||
for algo in algorithms
|
||||
if algo.collective == "allreduce" and algo.name == "default_allreduce_nvls_packet"
|
||||
][0]
|
||||
self._algorithm_rsag_zero_copy = [
|
||||
algo
|
||||
for algo in algorithms
|
||||
if algo.collective == "allreduce" and algo.name == "default_allreduce_rsag_zero_copy"
|
||||
][0]
|
||||
self._algorithm_packet = [
|
||||
algo for algo in algorithms if algo.collective == "allreduce" and algo.name == "default_allreduce_packet"
|
||||
][0]
|
||||
if mscclpp.is_nvls_supported():
|
||||
self._algorithm_nvls_zero_copy = [
|
||||
algo
|
||||
for algo in algorithms
|
||||
if algo.collective == "allreduce" and algo.name == "default_allreduce_nvls_zero_copy"
|
||||
][0]
|
||||
self._tune(n_warmup=5, n_graph_launches=10, n_ops_per_graph=100)
|
||||
self.symmetric_memory = symmetric_memory
|
||||
self._nvls = mscclpp.is_nvls_supported()
|
||||
|
||||
def _tune(self, n_warmup, n_graph_launches, n_ops_per_graph):
|
||||
sizes = [1 << i for i in range(10, 28)]
|
||||
# Pre-fill with defaults for barrier
|
||||
self.best_configs = {1024: (self._algorithm_nvls_packet, 0, 0)}
|
||||
self._scratch = _make_tensor(1 << 27, torch.float16)
|
||||
self._barrier_tensor = _make_tensor(4096, torch.float32)
|
||||
|
||||
tune_tensor = mscclpp.RawGpuBuffer(1 << 27).to_dlpack(data_type=str(torch.float16))
|
||||
tune_tensor = torch.utils.dlpack.from_dlpack(tune_tensor)
|
||||
tune_tensor.normal_()
|
||||
candidates_nblocks = [4, 8, 16, 24, 32, 48, 64, 128]
|
||||
candidates_nthreads = [512, 768, 1024]
|
||||
algos = _load_algorithms(self._scratch, self.rank)
|
||||
self._algos = {(a.collective, a.name): a for a in algos}
|
||||
|
||||
for size in sizes:
|
||||
algos = []
|
||||
if mscclpp.is_nvls_supported():
|
||||
algos.append(self._algorithm_nvls_zero_copy)
|
||||
if size <= 4 * 1024 * 1024:
|
||||
algos.append(self._algorithm_nvls_packet)
|
||||
algos.append(self._algorithm_packet)
|
||||
if size >= 512 * 1024:
|
||||
algos.append(self._algorithm_rsag_zero_copy)
|
||||
# {collective: {rounded_size: (algo, nblocks, nthreads)}}
|
||||
self._tune_cache: dict[str, dict[int, tuple]] = {"allreduce": {}, "allgather": {}}
|
||||
self._tune_buf = None
|
||||
self._time_buf = None
|
||||
|
||||
best_time = float("inf")
|
||||
best_config = None
|
||||
def _algo(self, collective: str, name: str):
|
||||
return self._algos.get((collective, name))
|
||||
|
||||
for algo in algos:
|
||||
for nb in candidates_nblocks:
|
||||
if algo.name == "default_allreduce_nvls_packet" and nb > 16:
|
||||
continue
|
||||
if algo.name == "default_allreduce_packet" and nb > 56:
|
||||
continue
|
||||
for nt in candidates_nthreads:
|
||||
if self._run_algo(algo, tune_tensor, size, nb, nt) != 0:
|
||||
continue
|
||||
def _default_ar_config(self):
|
||||
"""Fallback allreduce config for barrier / timing sync."""
|
||||
pkt = self._algo("allreduce", "default_allreduce_nvls_packet")
|
||||
if self._nvls and pkt:
|
||||
return (pkt, 0, 0)
|
||||
return (self._algo("allreduce", "default_allreduce_packet"), 0, 0)
|
||||
|
||||
for _ in range(n_warmup):
|
||||
self._run_algo(algo, tune_tensor, size, nb, nt)
|
||||
self.barrier()
|
||||
# -- low-level execute --
|
||||
|
||||
capture_stream = torch.cuda.Stream()
|
||||
capture_stream.wait_stream(torch.cuda.current_stream())
|
||||
|
||||
g = torch.cuda.CUDAGraph()
|
||||
# Warmup on capture stream
|
||||
with torch.cuda.stream(capture_stream):
|
||||
self._run_algo(algo, tune_tensor, size, nb, nt)
|
||||
capture_stream.synchronize()
|
||||
|
||||
with torch.cuda.graph(g, stream=capture_stream):
|
||||
for _ in range(n_ops_per_graph):
|
||||
self._run_algo(algo, tune_tensor, size, nb, nt)
|
||||
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
start_event.record(capture_stream)
|
||||
with torch.cuda.stream(capture_stream):
|
||||
for _ in range(n_graph_launches):
|
||||
g.replay()
|
||||
end_event.record(capture_stream)
|
||||
end_event.synchronize()
|
||||
|
||||
elapsed = start_event.elapsed_time(end_event)
|
||||
|
||||
# Synchronize timing results across all ranks to ensure consistent algorithm selection
|
||||
# replicate the value n times due to algo limitations
|
||||
time_tensor = torch.full((self.world_size,), elapsed, dtype=torch.float64, device="cuda").to(
|
||||
dtype=torch.float32
|
||||
)
|
||||
torch.cuda.current_stream().wait_stream(capture_stream)
|
||||
# TODO: use all_reduce may cause problem if the time elapsed between different algos are too close.
|
||||
# May change to broadcast in the future if that becomes an issue.
|
||||
self.all_reduce(time_tensor, op=torch.distributed.ReduceOp.SUM)
|
||||
avg_time = time_tensor[self.rank].item() / self.world_size
|
||||
|
||||
if avg_time < best_time:
|
||||
best_time = avg_time
|
||||
best_config = (algo, nb, nt)
|
||||
|
||||
if best_config:
|
||||
self.best_configs[size] = best_config
|
||||
if self.rank == 0:
|
||||
print(
|
||||
f"Size {size}: Best Algo {best_config[0].name} nblocks {best_config[1]} nthreads {best_config[2]} Time {(best_time/(n_graph_launches * n_ops_per_graph))*1000:.2f} us"
|
||||
)
|
||||
# reset the algorithms after tuning
|
||||
torch.cuda.synchronize()
|
||||
for algo in algos:
|
||||
algo.reset()
|
||||
|
||||
def _run_algo(self, algo: mscclpp.Algorithm, tensor, size, nblocks, nthreads):
|
||||
return algo.execute(
|
||||
comm=self.comm.communicator,
|
||||
input_buffer=tensor.data_ptr(),
|
||||
output_buffer=tensor.data_ptr(),
|
||||
input_size=size,
|
||||
output_size=size,
|
||||
dtype=mscclpp_utils.torch_dtype_to_mscclpp_dtype(tensor.dtype),
|
||||
op=mscclpp.ReduceOp.SUM,
|
||||
stream=torch.cuda.current_stream().cuda_stream,
|
||||
nblocks=nblocks,
|
||||
nthreads_per_block=nthreads,
|
||||
symmetric_memory=True,
|
||||
)
|
||||
|
||||
def get_tuned_config(self, size):
|
||||
if size < 1024:
|
||||
target_size = 1024
|
||||
elif size > 256 * 1024 * 1024:
|
||||
target_size = 256 * 1024 * 1024
|
||||
else:
|
||||
target_size = 1 << (size - 1).bit_length()
|
||||
return self.best_configs.get(target_size)
|
||||
|
||||
def all_reduce(self, tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM, stream: torch.cuda.Stream = None):
|
||||
assert op == torch.distributed.ReduceOp.SUM
|
||||
config = self.get_tuned_config(tensor.nbytes)
|
||||
algo, nblocks, nthreads = config if config else (self._algorithm_nvls_packet, 0, 0)
|
||||
def _exec_ar(self, tensor, algo, nb, nt, op=mscclpp.ReduceOp.SUM, stream=None, accum_dtype=None, sym=True):
|
||||
s = stream.cuda_stream if stream else torch.cuda.current_stream().cuda_stream
|
||||
ret = algo.execute(
|
||||
comm=self.comm.communicator,
|
||||
input_buffer=tensor.data_ptr(),
|
||||
@@ -195,107 +119,357 @@ class CustomizedComm:
|
||||
input_size=tensor.nbytes,
|
||||
output_size=tensor.nbytes,
|
||||
dtype=mscclpp_utils.torch_dtype_to_mscclpp_dtype(tensor.dtype),
|
||||
op=to_mscclpp_reduce_op(op),
|
||||
stream=stream.cuda_stream if stream is not None else torch.cuda.current_stream().cuda_stream,
|
||||
nblocks=nblocks,
|
||||
nthreads_per_block=nthreads,
|
||||
symmetric_memory=True,
|
||||
op=op,
|
||||
stream=s,
|
||||
nblocks=nb,
|
||||
nthreads_per_block=nt,
|
||||
symmetric_memory=sym,
|
||||
accum_dtype=accum_dtype,
|
||||
)
|
||||
if ret != 0:
|
||||
print(f"Rank {self.rank}: Algo {algo.name} failed with error {ret}")
|
||||
print(f"Rank {self.rank}: {algo.name} failed ({ret})")
|
||||
return ret
|
||||
|
||||
def _exec_ag(self, inp, out, algo, nb, nt, stream=None, sym=None):
|
||||
if sym is None:
|
||||
sym = self.symmetric_memory
|
||||
s = stream.cuda_stream if stream else torch.cuda.current_stream().cuda_stream
|
||||
ret = algo.execute(
|
||||
comm=self.comm.communicator,
|
||||
input_buffer=inp.data_ptr(),
|
||||
output_buffer=out.data_ptr(),
|
||||
input_size=inp.nbytes,
|
||||
output_size=out.nbytes,
|
||||
dtype=mscclpp_utils.torch_dtype_to_mscclpp_dtype(inp.dtype),
|
||||
op=mscclpp.ReduceOp.NOP,
|
||||
stream=s,
|
||||
nblocks=nb,
|
||||
nthreads_per_block=nt,
|
||||
symmetric_memory=sym,
|
||||
)
|
||||
if ret != 0:
|
||||
print(f"Rank {self.rank}: AG {algo.name} failed ({ret})")
|
||||
return ret
|
||||
|
||||
def _barrier_internal(self):
|
||||
a, nb, nt = self._default_ar_config()
|
||||
self._exec_ar(self._barrier_tensor, a, nb, nt, sym=True)
|
||||
|
||||
# -- lazy tuning --
|
||||
|
||||
def _ensure_tune_bufs(self):
|
||||
if self._tune_buf is None:
|
||||
self._tune_buf = _make_tensor(1 << 27, torch.float16)
|
||||
self._tune_buf.normal_()
|
||||
self._time_buf = _make_tensor(4096, torch.float32)
|
||||
return self._tune_buf
|
||||
|
||||
def _ar_candidates(self, size: int):
|
||||
out = []
|
||||
if size <= 4 << 20:
|
||||
a = self._algo("allreduce", "default_allreduce_nvls_packet")
|
||||
if self._nvls and a:
|
||||
out.append(a)
|
||||
a = self._algo("allreduce", "default_allreduce_packet")
|
||||
if a:
|
||||
out.append(a)
|
||||
a = self._algo("allreduce", "default_allreduce_allpair_packet")
|
||||
if a:
|
||||
out.append(a)
|
||||
if size >= 512 << 10:
|
||||
a = self._algo("allreduce", "default_allreduce_nvls_zero_copy")
|
||||
if self._nvls and self.symmetric_memory and a:
|
||||
out.append(a)
|
||||
a = self._algo("allreduce", "default_allreduce_rsag_zero_copy")
|
||||
if a:
|
||||
out.append(a)
|
||||
if torch.version.hip is not None:
|
||||
a = self._algo("allreduce", "default_allreduce_fullmesh")
|
||||
if a:
|
||||
out.append(a)
|
||||
return out
|
||||
|
||||
def _ag_candidates(self):
|
||||
a = self._algo("allgather", "default_allgather_fullmesh2")
|
||||
return [a] if a else []
|
||||
|
||||
def _run_tune(self, collective, algo, buf, size, nb, nt):
|
||||
"""Single tune invocation for either collective."""
|
||||
if collective == "allreduce":
|
||||
return algo.execute(
|
||||
comm=self.comm.communicator,
|
||||
input_buffer=buf.data_ptr(),
|
||||
output_buffer=buf.data_ptr(),
|
||||
input_size=size,
|
||||
output_size=size,
|
||||
dtype=mscclpp_utils.torch_dtype_to_mscclpp_dtype(buf.dtype),
|
||||
op=mscclpp.ReduceOp.SUM,
|
||||
stream=torch.cuda.current_stream().cuda_stream,
|
||||
nblocks=nb,
|
||||
nthreads_per_block=nt,
|
||||
symmetric_memory=True,
|
||||
)
|
||||
else:
|
||||
total = size * self.world_size
|
||||
out_ptr = buf.data_ptr()
|
||||
return algo.execute(
|
||||
comm=self.comm.communicator,
|
||||
input_buffer=out_ptr + self.rank * size,
|
||||
output_buffer=out_ptr,
|
||||
input_size=size,
|
||||
output_size=total,
|
||||
dtype=mscclpp_utils.torch_dtype_to_mscclpp_dtype(buf.dtype),
|
||||
op=mscclpp.ReduceOp.NOP,
|
||||
stream=torch.cuda.current_stream().cuda_stream,
|
||||
nblocks=nb,
|
||||
nthreads_per_block=nt,
|
||||
symmetric_memory=False,
|
||||
)
|
||||
|
||||
def _tune_size(self, collective: str, target_size: int):
|
||||
"""Auto-tune one (collective, target_size) pair and cache result."""
|
||||
buf = self._ensure_tune_bufs()
|
||||
cands = self._ar_candidates(target_size) if collective == "allreduce" else self._ag_candidates()
|
||||
|
||||
best_time, best_cfg = float("inf"), None
|
||||
used = set()
|
||||
run = lambda a, nb, nt: self._run_tune(collective, a, buf, target_size, nb, nt)
|
||||
|
||||
for algo in cands:
|
||||
nb_limit = self._NBLOCKS_LIMIT.get(algo.name, 128)
|
||||
for nb in self._CANDIDATE_NBLOCKS:
|
||||
if nb > nb_limit:
|
||||
continue
|
||||
for nt in self._CANDIDATE_NTHREADS:
|
||||
# Feasibility — sync result across ranks so all agree
|
||||
ret = run(algo, nb, nt)
|
||||
torch.cuda.synchronize()
|
||||
self._time_buf[0] = float(ret)
|
||||
self._exec_ar(self._time_buf[:1], *self._default_ar_config(), sym=True)
|
||||
if self._time_buf[0].item() != 0:
|
||||
continue
|
||||
used.add(algo)
|
||||
|
||||
# Warmup
|
||||
for _ in range(self._TUNE_N_WARMUP):
|
||||
run(algo, nb, nt)
|
||||
|
||||
# CUDA-graph timed benchmark
|
||||
cs = torch.cuda.Stream()
|
||||
cs.wait_stream(torch.cuda.current_stream())
|
||||
g = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(g, stream=cs):
|
||||
for _ in range(self._TUNE_N_OPS_PER_GRAPH):
|
||||
run(algo, nb, nt)
|
||||
|
||||
start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
|
||||
start.record(cs)
|
||||
with torch.cuda.stream(cs):
|
||||
for _ in range(self._TUNE_N_GRAPH_LAUNCHES):
|
||||
g.replay()
|
||||
end.record(cs)
|
||||
end.synchronize()
|
||||
elapsed = start.elapsed_time(end)
|
||||
|
||||
# Cross-rank timing sync
|
||||
self._time_buf.fill_(elapsed)
|
||||
torch.cuda.current_stream().wait_stream(cs)
|
||||
self._exec_ar(self._time_buf, *self._default_ar_config(), sym=True)
|
||||
avg = self._time_buf[self.rank].item() / self.world_size
|
||||
|
||||
if avg < best_time:
|
||||
best_time, best_cfg = avg, (algo, nb, nt)
|
||||
|
||||
if best_cfg:
|
||||
self._tune_cache[collective][target_size] = best_cfg
|
||||
if self.rank == 0:
|
||||
n = self._TUNE_N_GRAPH_LAUNCHES * self._TUNE_N_OPS_PER_GRAPH
|
||||
print(
|
||||
f"[tune] {collective} size={target_size}: {best_cfg[0].name} "
|
||||
f"nb={best_cfg[1]} nt={best_cfg[2]} time={best_time / n * 1000:.2f}us",
|
||||
flush=True,
|
||||
)
|
||||
else:
|
||||
fb = (
|
||||
self._default_ar_config()
|
||||
if collective == "allreduce"
|
||||
else ((self._ag_candidates()[0], 32, 512) if self._ag_candidates() else None)
|
||||
)
|
||||
self._tune_cache[collective][target_size] = fb
|
||||
|
||||
torch.cuda.synchronize()
|
||||
self._barrier_internal()
|
||||
for a in used:
|
||||
a.reset()
|
||||
|
||||
# -- public API --
|
||||
|
||||
def all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, stream=None, accum_dtype=None):
|
||||
sz = _round_pow2(tensor.nbytes)
|
||||
if sz not in self._tune_cache["allreduce"]:
|
||||
self._tune_size("allreduce", sz)
|
||||
a, nb, nt = self._tune_cache["allreduce"][sz]
|
||||
self._exec_ar(
|
||||
tensor, a, nb, nt, op=_to_mscclpp_op(op), stream=stream, accum_dtype=accum_dtype, sym=self.symmetric_memory
|
||||
)
|
||||
|
||||
def all_gather(self, output_tensor, input_tensor, stream=None):
|
||||
sz = _round_pow2(input_tensor.nbytes)
|
||||
if sz not in self._tune_cache["allgather"]:
|
||||
self._tune_size("allgather", sz)
|
||||
a, nb, nt = self._tune_cache["allgather"][sz]
|
||||
self._exec_ag(input_tensor, output_tensor, a, nb, nt, stream=stream, sym=self.symmetric_memory)
|
||||
|
||||
def barrier(self):
|
||||
tensor = torch.empty(self.world_size, dtype=torch.float, device=torch.device("cuda"))
|
||||
self.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM, stream=torch.cuda.current_stream())
|
||||
|
||||
def benchmark(self, n_warmup=10, n_graph_launches=10, n_iter_per_graph=100):
|
||||
low = 5 * 1024
|
||||
high = 80 * 1024 * 1024
|
||||
sizes = []
|
||||
curr = low
|
||||
while curr <= high:
|
||||
sizes.append(curr)
|
||||
curr *= 2
|
||||
|
||||
if self.rank == 0:
|
||||
print(f"{'Size (Bytes)':<20} {'Time (us)':<20} {'AlgoBW (GB/s)':<20}")
|
||||
|
||||
dtype = torch.float16
|
||||
capture_stream = torch.cuda.Stream()
|
||||
|
||||
# Allocate a single large RawGpuBuffer (symmetric memory) and reuse it for all sizes.
|
||||
# Cannot allocate per-size tensors with symmetric memory.
|
||||
bench_buf = mscclpp.RawGpuBuffer(1 << 27).to_dlpack(data_type=str(dtype))
|
||||
bench_buf = torch.utils.dlpack.from_dlpack(bench_buf)
|
||||
bench_buf.normal_()
|
||||
|
||||
for size in sizes:
|
||||
n_elements = size // bench_buf.element_size()
|
||||
tensor = bench_buf[:n_elements]
|
||||
|
||||
capture_stream.wait_stream(torch.cuda.current_stream())
|
||||
# Capture Graph
|
||||
g = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(g, stream=capture_stream):
|
||||
for _ in range(n_iter_per_graph):
|
||||
self.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM)
|
||||
|
||||
# warmup: Execute the graph once to prime the driver
|
||||
with torch.cuda.stream(capture_stream):
|
||||
for _ in range(n_warmup):
|
||||
g.replay()
|
||||
self.barrier()
|
||||
capture_stream.synchronize()
|
||||
|
||||
# Benchmark
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
start_event.record(capture_stream)
|
||||
with torch.cuda.stream(capture_stream):
|
||||
for _ in range(n_graph_launches):
|
||||
g.replay()
|
||||
end_event.record(capture_stream)
|
||||
end_event.synchronize()
|
||||
|
||||
# Get elapsed time in milliseconds
|
||||
elapsed_ms = start_event.elapsed_time(end_event)
|
||||
avg_time_ms = elapsed_ms / (n_graph_launches * n_iter_per_graph)
|
||||
time_us = avg_time_ms * 1000
|
||||
|
||||
alg_bw = size / (avg_time_ms * 1e-3) if avg_time_ms > 0 else 0
|
||||
if self.rank == 0:
|
||||
print(f"{size:<20} {time_us:<20.2f} {alg_bw / 1e9:<20.2f}")
|
||||
self._barrier_internal()
|
||||
|
||||
def destroy(self):
|
||||
self._algorithm_nvls_nonzero_copy = None
|
||||
self._algorithm_nvls_packet = None
|
||||
self.scratch_buffer = None
|
||||
self.comm = None
|
||||
self._algos.clear()
|
||||
self._tune_cache = {"allreduce": {}, "allgather": {}}
|
||||
self._tune_buf = self._time_buf = self._barrier_tensor = self._scratch = self.comm = None
|
||||
|
||||
|
||||
def init_dist() -> CustomizedComm:
|
||||
rank = int(os.environ["RANK"])
|
||||
world = int(os.environ["WORLD_SIZE"])
|
||||
master_addr = os.environ["MSCCLPP_MASTER_ADDR"]
|
||||
master_port = os.environ["MSCCLPP_MASTER_PORT"]
|
||||
interface = interfaces_for_ip_netifaces(master_addr)
|
||||
if interface is None:
|
||||
raise ValueError(f"Cannot find network interface for IP address {master_addr}")
|
||||
interfaceIpPortTrio = f"{interface}:{master_addr}:{master_port}"
|
||||
mscclpp_group = mscclpp.CommGroup(interfaceIpPortTrio=interfaceIpPortTrio, rank=rank, size=world)
|
||||
return CustomizedComm(mscclpp_group)
|
||||
# -- Benchmarks (standalone) --------------------------------------------------
|
||||
|
||||
|
||||
def _bench_sizes(low=5 * 1024, high=80 << 20):
|
||||
sizes, c = [], low
|
||||
while c <= high:
|
||||
sizes.append(c)
|
||||
c *= 2
|
||||
return sizes
|
||||
|
||||
|
||||
def benchmark_allreduce(
|
||||
comm: CustomizedComm, dtype=torch.float16, accum_dtype=None, n_warmup=10, n_graph_launches=10, n_iter=100
|
||||
):
|
||||
sizes = _bench_sizes()
|
||||
if comm.rank == 0:
|
||||
print(f"\n{'='*60}\nAllreduce Benchmark\n{'='*60}")
|
||||
print(f"{'Nelements':<18} {'Size(B)':<18} {'Time(us)':<18} {'AlgoBW(GB/s)':<18}")
|
||||
|
||||
cs = torch.cuda.Stream()
|
||||
buf = _make_tensor(1 << 27, dtype)
|
||||
buf.normal_() if dtype in (torch.float16, torch.float32, torch.bfloat16) else buf.fill_(0)
|
||||
|
||||
for size in sizes:
|
||||
nelems = size // buf.element_size()
|
||||
t = buf[: size // buf.element_size()]
|
||||
comm.all_reduce(t, accum_dtype=accum_dtype)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
cs.wait_stream(torch.cuda.current_stream())
|
||||
g = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(g, stream=cs):
|
||||
for _ in range(n_iter):
|
||||
comm.all_reduce(t, accum_dtype=accum_dtype)
|
||||
with torch.cuda.stream(cs):
|
||||
for _ in range(n_warmup):
|
||||
g.replay()
|
||||
comm.barrier()
|
||||
cs.synchronize()
|
||||
|
||||
s, e = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
|
||||
s.record(cs)
|
||||
with torch.cuda.stream(cs):
|
||||
for _ in range(n_graph_launches):
|
||||
g.replay()
|
||||
e.record(cs)
|
||||
e.synchronize()
|
||||
|
||||
ms = s.elapsed_time(e) / (n_graph_launches * n_iter)
|
||||
if comm.rank == 0:
|
||||
print(f"{nelems:<18} {size:<18} {ms*1000:<18.2f} {size/(ms*1e-3)/1e9:<18.2f}")
|
||||
|
||||
|
||||
def benchmark_allgather(comm: CustomizedComm, dtype=torch.float16, n_warmup=10, n_graph_launches=10, n_iter=100):
|
||||
sizes = _bench_sizes()
|
||||
if comm.rank == 0:
|
||||
print(f"\n{'='*60}\nAllgather Benchmark\n{'='*60}")
|
||||
print(f"{'PerRank(B)':<18} {'Total(B)':<18} {'Time(us)':<18} {'AlgoBW(GB/s)':<18}")
|
||||
|
||||
cs = torch.cuda.Stream()
|
||||
buf = _make_tensor(1 << 27, dtype)
|
||||
buf.normal_() if dtype in (torch.float16, torch.float32, torch.bfloat16) else buf.fill_(0)
|
||||
|
||||
for prs in sizes:
|
||||
total = prs * comm.world_size
|
||||
if total > buf.nbytes:
|
||||
break
|
||||
nt = total // buf.element_size()
|
||||
npr = prs // buf.element_size()
|
||||
out = buf[:nt]
|
||||
inp = out[comm.rank * npr : (comm.rank + 1) * npr]
|
||||
|
||||
comm.all_gather(out, inp)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
cs.wait_stream(torch.cuda.current_stream())
|
||||
g = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(g, stream=cs):
|
||||
for _ in range(n_iter):
|
||||
comm.all_gather(out, inp)
|
||||
with torch.cuda.stream(cs):
|
||||
for _ in range(n_warmup):
|
||||
g.replay()
|
||||
comm.barrier()
|
||||
cs.synchronize()
|
||||
|
||||
s, e = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
|
||||
s.record(cs)
|
||||
with torch.cuda.stream(cs):
|
||||
for _ in range(n_graph_launches):
|
||||
g.replay()
|
||||
e.record(cs)
|
||||
e.synchronize()
|
||||
|
||||
ms = s.elapsed_time(e) / (n_graph_launches * n_iter)
|
||||
if comm.rank == 0:
|
||||
print(f"{prs:<18} {total:<18} {ms*1000:<18.2f} {total/(ms*1e-3)/1e9:<18.2f}")
|
||||
|
||||
|
||||
# -- Bootstrap & main ---------------------------------------------------------
|
||||
|
||||
|
||||
def init_dist() -> mscclpp.CommGroup:
|
||||
addr = os.environ.get("MSCCLPP_MASTER_ADDR")
|
||||
if addr:
|
||||
rank, world = int(os.environ["RANK"]), int(os.environ["WORLD_SIZE"])
|
||||
port = os.environ["MSCCLPP_MASTER_PORT"]
|
||||
iface = _interfaces_for_ip(addr)
|
||||
if not iface:
|
||||
raise ValueError(f"No interface for {addr}")
|
||||
return mscclpp.CommGroup(interfaceIpPortTrio=f"{iface}:{addr}:{port}", rank=rank, size=world)
|
||||
import torch.distributed as dist
|
||||
|
||||
dist.init_process_group(backend="gloo")
|
||||
return mscclpp.CommGroup(torch_group=dist.group.WORLD)
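# Launch sketches for the two bootstrap paths above (placeholder values):
#   1) Explicit TCP bootstrap:
#      MSCCLPP_MASTER_ADDR=10.0.0.1 MSCCLPP_MASTER_PORT=50000 \
#        torchrun --nnodes=1 --nproc_per_node=8 customized_comm_with_tuning.py
#   2) torch.distributed (gloo) bootstrap, no extra env vars:
#      torchrun --nnodes=1 --nproc_per_node=8 examples/torch-integration/customized_comm_with_tuning.py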
|
||||
|
||||
|
||||
def main():
|
||||
local = int(os.environ["LOCAL_RANK"])
|
||||
torch.cuda.set_device(local)
|
||||
comm = init_dist()
|
||||
comm.benchmark(n_warmup=5, n_graph_launches=10, n_iter_per_graph=100)
|
||||
comm.barrier()
|
||||
|
||||
dtype_str = os.environ.get("DTYPE", "float16")
|
||||
dtype = getattr(torch, dtype_str, torch.float16)
|
||||
accum_map = {"float32": mscclpp.DataType.float32, "float16": mscclpp.DataType.float16}
|
||||
accum_str = os.environ.get("ACCUM_DTYPE")
|
||||
accum_dtype = accum_map.get(accum_str) if accum_str else None
|
||||
|
||||
comm_group = init_dist()
|
||||
cc = CustomizedComm(comm_group)
|
||||
|
||||
print(f"rank {local} starting benchmarks with dtype={dtype} accum_dtype={accum_dtype}...")
|
||||
benchmark_allreduce(cc, dtype=dtype, accum_dtype=accum_dtype)
|
||||
cc.barrier()
|
||||
torch.cuda.synchronize()
|
||||
comm.destroy()
|
||||
print(f"rank {local} All-reduce operation completed successfully.")
|
||||
|
||||
benchmark_allgather(cc, dtype=dtype)
|
||||
cc.barrier()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
cc.destroy()
|
||||
print(f"rank {local} completed successfully.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
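For reference, the AlgoBW column printed by the benchmarks above is plain bytes-over-time; a worked example with hypothetical numbers:

    size_bytes = 64 * 1024 * 1024          # 64 MiB allreduce
    avg_ms = 0.5                           # 500 us per operation
    algo_bw_gbs = size_bytes / (avg_ms * 1e-3) / 1e9
    print(f"{algo_bw_gbs:.1f} GB/s")       # ~134.2 GB/s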
@@ -1,19 +1,20 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

# LD_PRELOAD=<MSCCLPP_REPO>/build/lib/nccl/libmscclpp_nccl.so torchrun --nnodes=1 --nproc_per_node=8 dsl_with_nccl_api.py
# LD_PRELOAD=<MSCCLPP_REPO>/build/lib/libmscclpp_nccl.so torchrun --nnodes=1 --nproc_per_node=8 dsl_with_nccl_api.py

import os
from typing import Any, Dict
import torch, torch.distributed as dist
import mscclpp
import mscclpp.ext
from mscclpp.language.collectives import AllReduce
from mscclpp.language.channel import SwitchChannel, MemoryChannel, BufferType, SyncType
from mscclpp.language.program import CollectiveProgram
from mscclpp.language.rank import Rank
from mscclpp.language.utils import AlgoSpec


def allreduce_nvls(spec: mscclpp.AlgoSpec) -> CollectiveProgram:
def allreduce_nvls(spec: AlgoSpec) -> CollectiveProgram:
    gpu_size = spec.world_size
    with CollectiveProgram.from_spec(spec) as program:
        # Creating Channels
@@ -63,8 +64,8 @@ def allreduce_nvls(spec: mscclpp.AlgoSpec) -> CollectiveProgram:
    return program


def setup_plan(algo_collection_builder: mscclpp.AlgorithmCollectionBuilder, rank: int, world_size: int):
    spec = mscclpp.AlgoSpec(
def setup_plan(algo_collection_builder: mscclpp.ext.AlgorithmCollectionBuilder, rank: int, world_size: int):
    spec = AlgoSpec(
        name="allreduce_nvls",
        collective=AllReduce(8, 1, True),
        nranks_per_node=8,
@@ -94,10 +95,10 @@ def init_dist():
    rank = int(os.environ["RANK"])
    world = int(os.environ["WORLD_SIZE"])
    local = int(os.environ["LOCAL_RANK"])
    algorithm_collection_builder = mscclpp.AlgorithmCollectionBuilder()
    algorithm_collection_builder = mscclpp.ext.AlgorithmCollectionBuilder()
    setup_plan(algorithm_collection_builder, rank, world)
    algorithm_collection_builder.set_algorithm_selector(selector)
    dist.init_process_group(backend="nccl", device_id=local)
    dist.init_process_group(backend="nccl", device_id=torch.device("cuda", local))
    return rank, world, local
@@ -103,12 +103,14 @@ class Algorithm {
  /// @param nThreadsPerBlock Number of threads per block (0 for auto-selection).
  /// @param symmetricMemory Whether to use symmetric memory optimization.
  /// @param extras Additional parameters for algorithm-specific customization.
  /// @param accumDtype Data type for accumulation during reduction. DataType::AUTO resolves to dtype.
  /// @return The result of the operation.
  virtual CommResult execute(std::shared_ptr<Communicator> comm, const void* input, void* output, size_t inputSize,
                             size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream,
                             std::shared_ptr<Executor> executor, int nBlocks = 0, int nThreadsPerBlock = 0,
                             bool symmetricMemory = false,
                             const std::unordered_map<std::string, uintptr_t>& extras = {}) = 0;
                             const std::unordered_map<std::string, uintptr_t>& extras = {},
                             DataType accumDtype = DataType::AUTO) = 0;

  /// Reset the algorithm state, clearing any cached contexts.
  virtual void reset() = 0;
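On the Python side this new parameter surfaces as accum_dtype; a minimal hedged sketch, assuming an already-constructed mscclpp Algorithm `algo`, Communicator `comm`, and a CUDA torch tensor `buf` (imports as in the torch-integration example above):

    ret = algo.execute(
        comm=comm,
        input_buffer=buf.data_ptr(),
        output_buffer=buf.data_ptr(),
        input_size=buf.nbytes,
        output_size=buf.nbytes,
        dtype=mscclpp.DataType.float16,
        op=mscclpp.ReduceOp.SUM,
        stream=torch.cuda.current_stream().cuda_stream,
        nblocks=0,                              # 0 lets the algorithm auto-select
        nthreads_per_block=0,
        symmetric_memory=True,
        accum_dtype=mscclpp.DataType.float32,   # accumulate fp16 data in fp32
    )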
@@ -186,10 +188,11 @@ class NativeAlgorithm : public Algorithm {
  /// @param nBlocks Number of CUDA blocks.
  /// @param nThreadsPerBlock Number of threads per block.
  /// @param extras Additional algorithm-specific parameters.
  /// @param accumDtype Data type for accumulation (resolved from input dtype if sentinel).
  /// @return The result of the operation.
  using KernelFunc =
      std::function<CommResult(const std::shared_ptr<void>, const void*, void*, size_t, size_t, DataType, ReduceOp,
                               cudaStream_t, int, int, const std::unordered_map<std::string, uintptr_t>&)>;
                               cudaStream_t, int, int, const std::unordered_map<std::string, uintptr_t>&, DataType)>;

  /// Function type for creating algorithm contexts.
  /// @param comm The communicator.
@@ -233,8 +236,8 @@ class NativeAlgorithm : public Algorithm {
  CommResult execute(std::shared_ptr<Communicator> comm, const void* input, void* output, size_t inputSize,
                     size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream,
                     std::shared_ptr<Executor> executor, int nBlocks = 0, int nThreadsPerBlock = 0,
                     bool symmetricMemory = false,
                     const std::unordered_map<std::string, uintptr_t>& extras = {}) override;
                     bool symmetricMemory = false, const std::unordered_map<std::string, uintptr_t>& extras = {},
                     DataType accumDtype = DataType::AUTO) override;
  const std::string& name() const override;
  const std::string& collective() const override;
  const std::pair<size_t, size_t>& messageRange() const override;
@@ -285,8 +288,8 @@ class DslAlgorithm : public Algorithm, public AlgorithmBuilder, public std::enab
  CommResult execute(std::shared_ptr<Communicator> comm, const void* input, void* output, size_t inputSize,
                     size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream,
                     std::shared_ptr<Executor> executor, int nBlocks = 0, int nThreadsPerBlock = 0,
                     bool symmetricMemory = false,
                     const std::unordered_map<std::string, uintptr_t>& extras = {}) override;
                     bool symmetricMemory = false, const std::unordered_map<std::string, uintptr_t>& extras = {},
                     DataType accumDtype = DataType::AUTO) override;
  AlgorithmType type() const override { return AlgorithmType::DSL; }
  Constraint constraint() const override;
  void reset() override;
@@ -64,18 +64,151 @@ using __bfloat162 = __nv_bfloat162;
|
||||
|
||||
#endif
|
||||
|
||||
/// Software float8 with 4 exponent bits, 3 mantissa bits, exponent bias = 15.
|
||||
/// Format (MSB first): [sign:1][exponent:4][mantissa:3]
|
||||
/// No infinities; exp=15 is NaN. Negative zero is NaN (fnuz convention).
|
||||
/// Max finite value: 0.9375, min normal: ~6.1e-5, min subnormal: ~7.6e-6.
|
||||
struct alignas(1) __fp8_e4m3b15 {
|
||||
uint8_t __x;
|
||||
|
||||
__fp8_e4m3b15() = default;
|
||||
|
||||
/// Construct from raw bits (use __fp8_e4m3b15::fromRaw() for clarity).
|
||||
MSCCLPP_HOST_DEVICE_INLINE explicit __fp8_e4m3b15(uint8_t raw) : __x(raw) {}
|
||||
|
||||
/// Construct from float32 (explicit to avoid ambiguous conversion chains).
|
||||
MSCCLPP_HOST_DEVICE_INLINE explicit __fp8_e4m3b15(float val) : __x(fromFloat(val)) {}
|
||||
|
||||
/// Convert to float32.
|
||||
MSCCLPP_HOST_DEVICE_INLINE operator float() const { return toFloat(__x); }
|
||||
|
||||
/// Construct from a raw bit pattern without conversion.
|
||||
static MSCCLPP_HOST_DEVICE_INLINE __fp8_e4m3b15 fromRaw(uint8_t bits) {
|
||||
__fp8_e4m3b15 r;
|
||||
r.__x = bits;
|
||||
return r;
|
||||
}
|
||||
|
||||
private:
|
||||
/// Decode fp8_e4m3b15 bits → float32.
|
||||
///
|
||||
/// Uses bit manipulation through fp16 as intermediate, adapted from the Triton compiler.
|
||||
/// fp8_e4m3b15 is identical to fp8_e4m3fn (NVIDIA) except exponent bias is 15 vs 7.
|
||||
/// Algorithm: reinterpret fp8 bits into an fp16 bit pattern with exponent shifted by -8,
|
||||
/// then convert fp16 → float32.
|
||||
static MSCCLPP_HOST_DEVICE_INLINE float toFloat(uint8_t bits) {
|
||||
// Handle special values: negative zero (0x80) → NaN, exponent=15 → NaN.
|
||||
uint32_t exp = (bits >> 3) & 0xFu;
|
||||
if (bits == 0x80 || exp == 15) {
|
||||
union {
|
||||
uint32_t u;
|
||||
float f;
|
||||
} nan_val = {0x7FC00000u};
|
||||
return nan_val.f;
|
||||
}
|
||||
if (bits == 0) return 0.0f;
|
||||
|
||||
// Triton-style bit manipulation: fp8 → fp16 → fp32.
|
||||
// fp8 layout: [S:1][E:4][M:3] (bias=15)
|
||||
// fp16 layout: [S:1][E:5][M:10] (bias=15)
|
||||
//
|
||||
// Place fp8 in upper byte of fp16, then right-shift exponent+mantissa by 1
|
||||
// to convert E4 → E5 (both share bias=15). Sign bit stays at bit 15.
|
||||
// Refer:
|
||||
// https://github.com/triton-lang/triton/blob/cf34004b8a67d290a962da166f5aa2fc66751326/python/triton/language/extra/cuda/utils.py#L34
|
||||
uint16_t h = (uint16_t)bits << 8; // place fp8 in upper byte of fp16
|
||||
uint16_t sign16 = h & 0x8000u; // extract sign at fp16 position
|
||||
uint16_t nosign = h & 0x7F00u; // exponent + mantissa (no sign)
|
||||
uint16_t fp16_bits = sign16 | (nosign >> 1); // shift exponent right by 1
|
||||
|
||||
// For subnormals: when fp8 exponent=0, the above gives fp16 exponent=0
|
||||
// and fp16 mantissa = (fp8_mantissa << 7), which correctly represents
|
||||
// the subnormal fp16 value since both share bias=15.
|
||||
|
||||
// Convert fp16 bits to float via __half (works on host and device, CUDA and HIP).
|
||||
union {
|
||||
uint16_t u;
|
||||
__half h;
|
||||
} cvt = {fp16_bits};
|
||||
return __half2float(cvt.h);
|
||||
}
|
||||
|
||||
/// Encode float32 → fp8_e4m3b15 bits.
|
||||
///
|
||||
/// Algorithm adapted from Triton: float32 → fp16 → bit-manipulate → fp8.
|
||||
/// The key insight is to convert to fp16 first (which shares bias=15 with e4m3b15),
|
||||
/// then pack the fp16 bits back into 8 bits by shifting the exponent left by 1.
|
||||
static MSCCLPP_HOST_DEVICE_INLINE uint8_t fromFloat(float val) {
|
||||
union {
|
||||
float f;
|
||||
uint32_t u;
|
||||
} in = {val};
|
||||
|
||||
// NaN → 0x80 (negative-zero bit pattern = NaN in fnuz).
|
||||
if ((in.u & 0x7F800000u) == 0x7F800000u && (in.u & 0x007FFFFFu) != 0) return 0x80u;
|
||||
|
||||
// Convert float32 → fp16 bits via __half (works on host and device, CUDA and HIP).
|
||||
__half h_val = __float2half_rn(val);
|
||||
union {
|
||||
__half h;
|
||||
uint16_t u;
|
||||
} cvt = {h_val};
|
||||
uint16_t fp16_bits = cvt.u;
|
||||
|
||||
// Clamp absolute value to max finite e4m3b15: 0.9375 → fp16 = 0x3B80.
|
||||
uint16_t abs_fp16 = fp16_bits & 0x7FFFu;
|
||||
if (abs_fp16 > 0x3B80u) abs_fp16 = 0x3B80u;
|
||||
|
||||
// Reconstruct with sign.
|
||||
uint16_t sign16 = fp16_bits & 0x8000u;
|
||||
|
||||
// Triton-style: fp16 → fp8.
|
||||
// fp16 layout: [S:1][E:5][M:10] (bias=15)
|
||||
// fp8 layout: [S:1][E:4][M:3] (bias=15)
|
||||
//
|
||||
// mad.lo.u32 a0, a0, 2, 0x00800080 → (abs_fp16 * 2 + 0x0080)
|
||||
// This shifts left by 1 (undoing the right-shift in decode) and adds rounding bias.
|
||||
// Then: lop3.b32 b0, $1, 0x80008000, a0, 0xea → (sign & 0x8000) | a0
|
||||
// Finally: prmt for byte extraction.
|
||||
//
|
||||
// Simplified for scalar: shift abs_fp16 left by 1, add rounding bias, take upper byte.
|
||||
uint16_t adjusted = (uint16_t)(abs_fp16 * 2u + 0x0080u);
|
||||
// The upper byte now contains [E:4][M:3][round_bit].
|
||||
// Combine with sign and extract.
|
||||
uint16_t with_sign = sign16 | adjusted;
|
||||
uint8_t result = (uint8_t)(with_sign >> 8);
|
||||
|
||||
// Zero → 0x00 (ensure positive zero, not negative zero which is NaN).
|
||||
if ((result & 0x7Fu) == 0) result = 0x00u;
|
||||
|
||||
return result;
|
||||
}
|
||||
};
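A reference sketch modeling the decode path above in NumPy, to sanity-check the bit math and the stated value range (illustrative only, not part of the header):

    import numpy as np

    def fp8_e4m3b15_to_float(bits: int) -> float:
        exp = (bits >> 3) & 0xF
        if bits == 0x80 or exp == 15:              # negative zero and exp==15 encode NaN (fnuz)
            return float("nan")
        if bits == 0:
            return 0.0
        h = (bits << 8) & 0xFFFF                   # place the fp8 byte in the upper byte of an fp16
        fp16 = (h & 0x8000) | ((h & 0x7F00) >> 1)  # shift exponent+mantissa right by 1
        return float(np.array([fp16], dtype=np.uint16).view(np.float16)[0])

    assert fp8_e4m3b15_to_float(0x77) == 0.9375      # max finite: exp=14, mantissa=7
    assert fp8_e4m3b15_to_float(0x08) == 2.0 ** -14  # min normal (~6.1e-5)
    assert fp8_e4m3b15_to_float(0x01) == 2.0 ** -17  # min subnormal (~7.6e-6)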
|
||||
|
||||
/// Packed 2x fp8_e4m3b15 storage.
|
||||
struct alignas(2) __fp8x2_e4m3b15 {
|
||||
uint16_t __x;
|
||||
};
|
||||
|
||||
/// Packed 4x fp8_e4m3b15 storage.
|
||||
struct alignas(4) __fp8x4_e4m3b15 {
|
||||
uint32_t __x;
|
||||
};
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
/// Data types supported by mscclpp operations.
|
||||
enum class DataType {
|
||||
INT32, // 32-bit signed integer.
|
||||
UINT32, // 32-bit unsigned integer.
|
||||
FLOAT16, // IEEE 754 half precision.
|
||||
FLOAT32, // IEEE 754 single precision.
|
||||
BFLOAT16, // bfloat16 precision.
|
||||
FLOAT8_E4M3, // float8 with E4M3 layout.
|
||||
FLOAT8_E5M2, // float8 with E5M2 layout.
|
||||
UINT8, // 8-bit unsigned integer.
|
||||
INT32, // 32-bit signed integer.
|
||||
UINT32, // 32-bit unsigned integer.
|
||||
FLOAT16, // IEEE 754 half precision.
|
||||
FLOAT32, // IEEE 754 single precision.
|
||||
BFLOAT16, // bfloat16 precision.
|
||||
FLOAT8_E4M3, // float8 with E4M3 layout.
|
||||
FLOAT8_E5M2, // float8 with E5M2 layout.
|
||||
UINT8, // 8-bit unsigned integer.
|
||||
FLOAT8_E4M3B15, // float8 with E4M3 layout, bias=15 (software, no HW accel).
|
||||
AUTO = 255, // Sentinel: resolve to the input dtype at runtime.
|
||||
};
|
||||
|
||||
/// Word array.
|
||||
@@ -97,6 +230,7 @@ struct alignas(Bytes) Words<Bytes, false> {};
|
||||
template <typename T, int N, typename StorageT>
|
||||
union alignas(sizeof(T) * N) VectorTypeImpl {
|
||||
static_assert(N > 0, "N must be greater than 0");
|
||||
static_assert(sizeof(StorageT) >= sizeof(T) * N, "StorageT must cover the full vector size");
|
||||
|
||||
T data[N];
|
||||
Words<sizeof(T) * N> words;
|
||||
@@ -127,13 +261,14 @@ union alignas(sizeof(T) * N) VectorTypeImpl {
|
||||
MSCCLPP_HOST_DEVICE_INLINE const T& operator[](int i) const { return data[i]; }
|
||||
};
|
||||
|
||||
// Helper template to get the appropriate vector type for a given element type and count
|
||||
// Helper template to get the appropriate vector type for a given element type and count.
|
||||
template <typename T, int N>
|
||||
struct VectorTypeHelper {
|
||||
using type =
|
||||
VectorTypeImpl<T, N,
|
||||
typename std::conditional_t<N * sizeof(T) == 4, uint32_t,
|
||||
typename std::conditional_t<N * sizeof(T) == 8, uint2, uint4>>>;
|
||||
static constexpr int Bytes = N * sizeof(T);
|
||||
using type = VectorTypeImpl<
|
||||
T, N,
|
||||
std::conditional_t<Bytes == 4, uint32_t,
|
||||
std::conditional_t<Bytes == 8, uint2, std::conditional_t<Bytes <= 16, uint4, Words<Bytes>>>>>;
|
||||
};
|
||||
|
||||
/// Vector type - clean user interface (automatically selects appropriate storage type)
|
||||
@@ -170,6 +305,11 @@ DEFINE_VEC(bf16x4, __bfloat16, 4, uint2);
|
||||
DEFINE_VEC(f16x8, __half, 8, uint4);
|
||||
DEFINE_VEC(bf16x8, __bfloat16, 8, uint4);
|
||||
|
||||
// Aliases for large vector types (>16 bytes) where no native CUDA storage type exists.
|
||||
using f32x8 = VectorType<float, 8>;
|
||||
using f32x16 = VectorType<float, 16>;
|
||||
using f16x16 = VectorType<__half, 16>;
|
||||
|
||||
#if defined(__FP8_TYPES_EXIST__)
|
||||
DEFINE_VEC(f8_e4m3x2, __fp8_e4m3, 2, __fp8x2_e4m3);
|
||||
DEFINE_VEC(f8_e4m3x4, __fp8_e4m3, 4, __fp8x4_e4m3);
|
||||
@@ -181,6 +321,12 @@ DEFINE_VEC(f8_e5m2x4, __fp8_e5m2, 4, __fp8x4_e5m2);
|
||||
DEFINE_VEC(f8_e5m2x8, __fp8_e5m2, 8, uint2);
|
||||
DEFINE_VEC(f8_e5m2x16, __fp8_e5m2, 16, uint4);
|
||||
#endif
|
||||
|
||||
// fp8_e4m3b15 vectors (always available — software type, no HW dependency)
|
||||
DEFINE_VEC(f8_e4m3b15x2, __fp8_e4m3b15, 2, __fp8x2_e4m3b15);
|
||||
DEFINE_VEC(f8_e4m3b15x4, __fp8_e4m3b15, 4, __fp8x4_e4m3b15);
|
||||
DEFINE_VEC(f8_e4m3b15x8, __fp8_e4m3b15, 8, uint2);
|
||||
DEFINE_VEC(f8_e4m3b15x16, __fp8_e4m3b15, 16, uint4);
|
||||
#undef DEFINE_VEC
|
||||
|
||||
#if defined(MSCCLPP_DEVICE_COMPILE)
|
||||
@@ -254,6 +400,21 @@ MSCCLPP_DEVICE_INLINE __fp8_e5m2 clip(__fp8_e5m2 val) {
|
||||
}
|
||||
#endif
|
||||
|
||||
// --- f32x2 arithmetic ---
|
||||
|
||||
template <bool UseClip = true>
|
||||
MSCCLPP_DEVICE_INLINE f32x2 operator+(const f32x2& a, const f32x2& b) {
|
||||
#if defined(MSCCLPP_DEVICE_CUDA) && (__CUDA_ARCH__ >= 1000)
|
||||
// Blackwell (SM 10.0+): packed float2 add in a single instruction.
|
||||
return __fadd2_rn(a.storage, b.storage);
|
||||
#else
|
||||
f32x2 result;
|
||||
result.data[0] = a.data[0] + b.data[0];
|
||||
result.data[1] = a.data[1] + b.data[1];
|
||||
return result;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <bool UseClip = true>
|
||||
MSCCLPP_DEVICE_INLINE f16x2 operator+(const f16x2& a, const f16x2& b) {
|
||||
__half2 result;
|
||||
@@ -265,6 +426,18 @@ MSCCLPP_DEVICE_INLINE f16x2 operator+(const f16x2& a, const f16x2& b) {
|
||||
return result;
|
||||
}
|
||||
|
||||
template <bool UseClip = true>
|
||||
MSCCLPP_DEVICE_INLINE f16x4 operator+(const f16x4& a, const f16x4& b) {
|
||||
// Decompose into 2× packed __hadd2 (2 instructions instead of 4 scalar __hadd).
|
||||
const f16x2* a2 = reinterpret_cast<const f16x2*>(&a);
|
||||
const f16x2* b2 = reinterpret_cast<const f16x2*>(&b);
|
||||
f16x4 result;
|
||||
f16x2* r2 = reinterpret_cast<f16x2*>(&result);
|
||||
r2[0] = a2[0] + b2[0];
|
||||
r2[1] = a2[1] + b2[1];
|
||||
return result;
|
||||
}
|
||||
|
||||
template <bool UseClip = true>
|
||||
MSCCLPP_DEVICE_INLINE bf16x2 operator+(const bf16x2& a, const bf16x2& b) {
|
||||
__bfloat162 result;
|
||||
@@ -449,6 +622,14 @@ MSCCLPP_DEVICE_INLINE T min(const T& a, const T& b) {
|
||||
return (a < b ? a : b);
|
||||
}
|
||||
|
||||
template <>
|
||||
MSCCLPP_DEVICE_INLINE f32x2 min(const f32x2& a, const f32x2& b) {
|
||||
f32x2 result;
|
||||
result.data[0] = fminf(a.data[0], b.data[0]);
|
||||
result.data[1] = fminf(a.data[1], b.data[1]);
|
||||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
MSCCLPP_DEVICE_INLINE f16x2 min(const f16x2& a, const f16x2& b) {
|
||||
#if defined(MSCCLPP_DEVICE_HIP)
|
||||
@@ -489,6 +670,51 @@ MSCCLPP_DEVICE_INLINE u8x4 min(const u8x4& a, const u8x4& b) {
|
||||
#endif
|
||||
}
|
||||
|
||||
/// Convert a vector type From to vector type To.
|
||||
/// Primary template with auto-decomposition: vectors with N > 4 elements decompose into x4 chunks,
|
||||
/// vectors with N == 4 decompose into x2 chunks, enabling optimized x2/x4 specializations to be reached.
|
||||
/// Specialized below for optimized FP8 conversion paths at x2/x4 level.
|
||||
template <typename To, typename From>
|
||||
MSCCLPP_DEVICE_INLINE To to(const From& v) {
|
||||
static_assert(To::Size == From::Size, "to<To, From>: vector sizes must match");
|
||||
constexpr int N = From::Size;
|
||||
|
||||
// Auto-decompose: N > 4 → split into x4 chunks
|
||||
if constexpr (N > 4 && N % 4 == 0) {
|
||||
constexpr int nChunks = N / 4;
|
||||
using FromChunk = VectorType<typename From::ElementType, 4>;
|
||||
using ToChunk = VectorType<typename To::ElementType, 4>;
|
||||
const FromChunk* in = reinterpret_cast<const FromChunk*>(&v);
|
||||
To result;
|
||||
ToChunk* out = reinterpret_cast<ToChunk*>(&result);
|
||||
#pragma unroll
|
||||
for (int c = 0; c < nChunks; ++c) {
|
||||
out[c] = to<ToChunk>(in[c]);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
// Auto-decompose: N == 4 → split into 2x x2 chunks
|
||||
else if constexpr (N == 4) {
|
||||
using FromChunk = VectorType<typename From::ElementType, 2>;
|
||||
using ToChunk = VectorType<typename To::ElementType, 2>;
|
||||
const FromChunk* in = reinterpret_cast<const FromChunk*>(&v);
|
||||
To result;
|
||||
ToChunk* out = reinterpret_cast<ToChunk*>(&result);
|
||||
out[0] = to<ToChunk>(in[0]);
|
||||
out[1] = to<ToChunk>(in[1]);
|
||||
return result;
|
||||
}
|
||||
// Base case: element-wise conversion
|
||||
else {
|
||||
To result;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N; ++i) {
|
||||
result.data[i] = static_cast<typename To::ElementType>(v.data[i]);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
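A brief usage sketch may help here (not part of this change): the primary template lets callers convert whole vectors in one call, and the x2/x4 specializations defined below are picked up automatically. The kernel below is only an illustration; it assumes FP8 hardware types are available (__FP8_TYPES_EXIST__) and that this header is included.

__global__ void upcastFp8ToFloat(const mscclpp::f8_e4m3x4* in, mscclpp::f32x4* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;
  // Sizes match (4 == 4). The explicit to<f32x4, f8_e4m3x4> specialization is
  // selected when defined; otherwise the primary template splits the x4 vector
  // into two x2 chunks so the packed fp8 -> f32 converters are still reached.
  out[i] = mscclpp::to<mscclpp::f32x4>(in[i]);
}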
#if defined(__FP8_TYPES_EXIST__)
|
||||
template <>
|
||||
MSCCLPP_DEVICE_INLINE __fp8_e4m3 min(const __fp8_e4m3& a, const __fp8_e4m3& b) {
|
||||
@@ -551,7 +777,592 @@ MSCCLPP_DEVICE_INLINE f8_e5m2x4 min(const f8_e5m2x4& a, const f8_e5m2x4& b) {
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// --- f8_e4m3 -> f32 specializations ---
|
||||
|
||||
/// f8_e4m3x2 -> f32x2.
|
||||
/// NVIDIA: fp8 -> half (via __nv_cvt_fp8x2_to_halfraw2) -> float.
|
||||
/// HIP gfx942: fp8 -> float (via __builtin_amdgcn_cvt_pk_f32_fp8).
|
||||
template <>
|
||||
MSCCLPP_DEVICE_INLINE f32x2 to<f32x2, f8_e4m3x2>(const f8_e4m3x2& v) {
|
||||
#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
|
||||
auto f = __builtin_amdgcn_cvt_pk_f32_fp8(v.storage.__x, 0);
|
||||
f32x2 result;
|
||||
result.data[0] = f[0];
|
||||
result.data[1] = f[1];
|
||||
return result;
|
||||
#elif defined(MSCCLPP_DEVICE_CUDA) && __CUDA_ARCH__ >= 900
|
||||
__half2_raw h2 = __nv_cvt_fp8x2_to_halfraw2(bit_cast<__nv_fp8x2_storage_t>(v.storage), __NV_E4M3);
|
||||
f32x2 result;
|
||||
result.data[0] = __half2float(bit_cast<__half>(h2.x));
|
||||
result.data[1] = __half2float(bit_cast<__half>(h2.y));
|
||||
return result;
|
||||
#else
|
||||
f32x2 result;
|
||||
result.data[0] = float(v.data[0]);
|
||||
result.data[1] = float(v.data[1]);
|
||||
return result;
|
||||
#endif
|
||||
}
|
||||
|
||||
/// f8_e4m3x4 -> f32x4.
|
||||
template <>
|
||||
MSCCLPP_DEVICE_INLINE f32x4 to<f32x4, f8_e4m3x4>(const f8_e4m3x4& v) {
|
||||
#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
|
||||
auto lo = __builtin_amdgcn_cvt_pk_f32_fp8(v.storage.__x, false);
|
||||
auto hi = __builtin_amdgcn_cvt_pk_f32_fp8(v.storage.__x, true);
|
||||
f32x4 result;
|
||||
result.data[0] = lo[0];
|
||||
result.data[1] = lo[1];
|
||||
result.data[2] = hi[0];
|
||||
result.data[3] = hi[1];
|
||||
return result;
|
||||
#else
|
||||
const f8_e4m3x2* pair = reinterpret_cast<const f8_e4m3x2*>(&v);
|
||||
f32x2 lo = to<f32x2>(pair[0]);
|
||||
f32x2 hi = to<f32x2>(pair[1]);
|
||||
f32x4 result;
|
||||
result.data[0] = lo.data[0];
|
||||
result.data[1] = lo.data[1];
|
||||
result.data[2] = hi.data[0];
|
||||
result.data[3] = hi.data[1];
|
||||
return result;
|
||||
#endif
|
||||
}
|
||||
|
||||
// --- f8_e5m2 -> f32 specializations ---
|
||||
|
||||
/// f8_e5m2x2 -> f32x2.
|
||||
/// NVIDIA: fp8 -> half (via __nv_cvt_fp8x2_to_halfraw2) -> float.
|
||||
/// HIP gfx942: bf8 -> float (via __builtin_amdgcn_cvt_pk_f32_bf8).
|
||||
template <>
|
||||
MSCCLPP_DEVICE_INLINE f32x2 to<f32x2, f8_e5m2x2>(const f8_e5m2x2& v) {
|
||||
#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
|
||||
auto f = __builtin_amdgcn_cvt_pk_f32_bf8(v.storage.__x, 0);
|
||||
f32x2 result;
|
||||
result.data[0] = f[0];
|
||||
result.data[1] = f[1];
|
||||
return result;
|
||||
#elif defined(MSCCLPP_DEVICE_CUDA) && __CUDA_ARCH__ >= 900
|
||||
__half2_raw h2 = __nv_cvt_fp8x2_to_halfraw2(bit_cast<__nv_fp8x2_storage_t>(v.storage), __NV_E5M2);
|
||||
f32x2 result;
|
||||
result.data[0] = __half2float(bit_cast<__half>(h2.x));
|
||||
result.data[1] = __half2float(bit_cast<__half>(h2.y));
|
||||
return result;
|
||||
#else
|
||||
f32x2 result;
|
||||
result.data[0] = float(v.data[0]);
|
||||
result.data[1] = float(v.data[1]);
|
||||
return result;
|
||||
#endif
|
||||
}
|
||||
|
||||
/// f8_e5m2x4 -> f32x4.
|
||||
template <>
|
||||
MSCCLPP_DEVICE_INLINE f32x4 to<f32x4, f8_e5m2x4>(const f8_e5m2x4& v) {
|
||||
#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
|
||||
auto lo = __builtin_amdgcn_cvt_pk_f32_bf8(v.storage.__x, false);
|
||||
auto hi = __builtin_amdgcn_cvt_pk_f32_bf8(v.storage.__x, true);
|
||||
f32x4 result;
|
||||
result.data[0] = lo[0];
|
||||
result.data[1] = lo[1];
|
||||
result.data[2] = hi[0];
|
||||
result.data[3] = hi[1];
|
||||
return result;
|
||||
#else
|
||||
const f8_e5m2x2* pair = reinterpret_cast<const f8_e5m2x2*>(&v);
|
||||
f32x2 lo = to<f32x2>(pair[0]);
|
||||
f32x2 hi = to<f32x2>(pair[1]);
|
||||
f32x4 result;
|
||||
result.data[0] = lo.data[0];
|
||||
result.data[1] = lo.data[1];
|
||||
result.data[2] = hi.data[0];
|
||||
result.data[3] = hi.data[1];
|
||||
return result;
|
||||
#endif
|
||||
}
|
||||
|
||||
// --- f32 -> f8_e4m3 specializations (downcast) ---
|
||||
|
||||
/// f32x2 -> f8_e4m3x2.
|
||||
/// HIP gfx942: float -> fp8 (via __builtin_amdgcn_cvt_pk_fp8_f32).
|
||||
/// NVIDIA SM90+: float -> half -> fp8 (via __nv_cvt_halfraw2_to_fp8x2).
|
||||
/// NVIDIA pre-SM90: float -> half -> fp8 (via __nv_cvt_halfraw_to_fp8, element-wise).
|
||||
template <>
|
||||
MSCCLPP_DEVICE_INLINE f8_e4m3x2 to<f8_e4m3x2, f32x2>(const f32x2& v) {
|
||||
#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
|
||||
uint32_t packed = __builtin_amdgcn_cvt_pk_fp8_f32(v.data[0], v.data[1], 0, false);
|
||||
return bit_cast<f8_e4m3x2>(static_cast<__hip_fp8x2_storage_t>(packed));
|
||||
#elif defined(MSCCLPP_DEVICE_CUDA) && __CUDA_ARCH__ >= 900
|
||||
__half2_raw h2;
|
||||
h2.x = bit_cast<unsigned short>(__float2half_rn(v.data[0]));
|
||||
h2.y = bit_cast<unsigned short>(__float2half_rn(v.data[1]));
|
||||
__nv_fp8x2_storage_t fp8x2 = __nv_cvt_halfraw2_to_fp8x2(h2, __NV_SATFINITE, __NV_E4M3);
|
||||
return bit_cast<f8_e4m3x2>(fp8x2);
|
||||
#elif defined(MSCCLPP_DEVICE_CUDA)
|
||||
__half_raw h0, h1;
|
||||
h0.x = bit_cast<unsigned short>(__float2half_rn(v.data[0]));
|
||||
h1.x = bit_cast<unsigned short>(__float2half_rn(v.data[1]));
|
||||
f8_e4m3x2 result;
|
||||
result.data[0] = bit_cast<__fp8_e4m3>(__nv_cvt_halfraw_to_fp8(h0, __NV_SATFINITE, __NV_E4M3));
|
||||
result.data[1] = bit_cast<__fp8_e4m3>(__nv_cvt_halfraw_to_fp8(h1, __NV_SATFINITE, __NV_E4M3));
|
||||
return result;
|
||||
#else
|
||||
f8_e4m3x2 result;
|
||||
result.data[0] = static_cast<__fp8_e4m3>(v.data[0]);
|
||||
result.data[1] = static_cast<__fp8_e4m3>(v.data[1]);
|
||||
return result;
|
||||
#endif
|
||||
}
|
||||
|
||||
/// f32x4 -> f8_e4m3x4.
|
||||
template <>
|
||||
MSCCLPP_DEVICE_INLINE f8_e4m3x4 to<f8_e4m3x4, f32x4>(const f32x4& v) {
|
||||
#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
|
||||
uint32_t packed = __builtin_amdgcn_cvt_pk_fp8_f32(v.data[0], v.data[1], 0, false);
|
||||
packed = __builtin_amdgcn_cvt_pk_fp8_f32(v.data[2], v.data[3], packed, true);
|
||||
return bit_cast<f8_e4m3x4>(packed);
|
||||
#else
|
||||
f32x2 lo, hi;
|
||||
lo.data[0] = v.data[0];
|
||||
lo.data[1] = v.data[1];
|
||||
hi.data[0] = v.data[2];
|
||||
hi.data[1] = v.data[3];
|
||||
f8_e4m3x2 lo_fp8 = to<f8_e4m3x2>(lo);
|
||||
f8_e4m3x2 hi_fp8 = to<f8_e4m3x2>(hi);
|
||||
f8_e4m3x4 result;
|
||||
result.data[0] = lo_fp8.data[0];
|
||||
result.data[1] = lo_fp8.data[1];
|
||||
result.data[2] = hi_fp8.data[0];
|
||||
result.data[3] = hi_fp8.data[1];
|
||||
return result;
|
||||
#endif
|
||||
}
|
||||
|
||||
// --- f32 -> f8_e5m2 specializations (downcast) ---
|
||||
|
||||
/// f32x2 -> f8_e5m2x2.
|
||||
/// HIP gfx942: float -> bf8 (via __builtin_amdgcn_cvt_pk_bf8_f32).
|
||||
/// NVIDIA SM90+: float -> half -> fp8 (via __nv_cvt_halfraw2_to_fp8x2 with __NV_E5M2).
|
||||
/// NVIDIA pre-SM90: float -> half -> fp8 (via __nv_cvt_halfraw_to_fp8, element-wise).
|
||||
template <>
|
||||
MSCCLPP_DEVICE_INLINE f8_e5m2x2 to<f8_e5m2x2, f32x2>(const f32x2& v) {
|
||||
#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
|
||||
uint32_t packed = __builtin_amdgcn_cvt_pk_bf8_f32(v.data[0], v.data[1], 0, false);
|
||||
return bit_cast<f8_e5m2x2>(static_cast<__hip_fp8x2_storage_t>(packed));
|
||||
#elif defined(MSCCLPP_DEVICE_CUDA) && __CUDA_ARCH__ >= 900
|
||||
__half2_raw h2;
|
||||
h2.x = bit_cast<unsigned short>(__float2half_rn(v.data[0]));
|
||||
h2.y = bit_cast<unsigned short>(__float2half_rn(v.data[1]));
|
||||
__nv_fp8x2_storage_t fp8x2 = __nv_cvt_halfraw2_to_fp8x2(h2, __NV_SATFINITE, __NV_E5M2);
|
||||
return bit_cast<f8_e5m2x2>(fp8x2);
|
||||
#elif defined(MSCCLPP_DEVICE_CUDA)
|
||||
__half_raw h0, h1;
|
||||
h0.x = bit_cast<unsigned short>(__float2half_rn(v.data[0]));
|
||||
h1.x = bit_cast<unsigned short>(__float2half_rn(v.data[1]));
|
||||
f8_e5m2x2 result;
|
||||
result.data[0] = bit_cast<__fp8_e5m2>(__nv_cvt_halfraw_to_fp8(h0, __NV_SATFINITE, __NV_E5M2));
|
||||
result.data[1] = bit_cast<__fp8_e5m2>(__nv_cvt_halfraw_to_fp8(h1, __NV_SATFINITE, __NV_E5M2));
|
||||
return result;
|
||||
#else
|
||||
f8_e5m2x2 result;
|
||||
result.data[0] = static_cast<__fp8_e5m2>(v.data[0]);
|
||||
result.data[1] = static_cast<__fp8_e5m2>(v.data[1]);
|
||||
return result;
|
||||
#endif
|
||||
}
|
||||
|
||||
/// f32x4 -> f8_e5m2x4.
|
||||
template <>
|
||||
MSCCLPP_DEVICE_INLINE f8_e5m2x4 to<f8_e5m2x4, f32x4>(const f32x4& v) {
|
||||
#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
|
||||
uint32_t packed = __builtin_amdgcn_cvt_pk_bf8_f32(v.data[0], v.data[1], 0, false);
|
||||
packed = __builtin_amdgcn_cvt_pk_bf8_f32(v.data[2], v.data[3], packed, true);
|
||||
return bit_cast<f8_e5m2x4>(packed);
|
||||
#else
|
||||
f32x2 lo, hi;
|
||||
lo.data[0] = v.data[0];
|
||||
lo.data[1] = v.data[1];
|
||||
hi.data[0] = v.data[2];
|
||||
hi.data[1] = v.data[3];
|
||||
f8_e5m2x2 lo_fp8 = to<f8_e5m2x2>(lo);
|
||||
f8_e5m2x2 hi_fp8 = to<f8_e5m2x2>(hi);
|
||||
f8_e5m2x4 result;
|
||||
result.data[0] = lo_fp8.data[0];
|
||||
result.data[1] = lo_fp8.data[1];
|
||||
result.data[2] = hi_fp8.data[0];
|
||||
result.data[3] = hi_fp8.data[1];
|
||||
return result;
|
||||
#endif
|
||||
}
|
||||
|
||||
// --- f8_e4m3 <-> f16 conversion specializations ---
|
||||
|
||||
/// f8_e4m3x2 -> f16x2.
|
||||
/// NVIDIA SM90+: packed intrinsic (1 instruction).
|
||||
/// HIP gfx942: fp8 -> float -> half (via AMD builtin).
|
||||
/// Pre-SM90 / fallback: element-wise scalar conversion.
|
||||
template <>
|
||||
MSCCLPP_DEVICE_INLINE f16x2 to<f16x2, f8_e4m3x2>(const f8_e4m3x2& v) {
|
||||
#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
|
||||
auto f = __builtin_amdgcn_cvt_pk_f32_fp8(v.storage.__x, 0);
|
||||
f16x2 result;
|
||||
result.data[0] = __float2half(f[0]);
|
||||
result.data[1] = __float2half(f[1]);
|
||||
return result;
|
||||
#elif defined(MSCCLPP_DEVICE_CUDA) && __CUDA_ARCH__ >= 900
|
||||
__half2_raw h2 = __nv_cvt_fp8x2_to_halfraw2(bit_cast<__nv_fp8x2_storage_t>(v.storage), __NV_E4M3);
|
||||
return bit_cast<f16x2>(h2);
|
||||
#else
|
||||
f16x2 result;
|
||||
result.data[0] = static_cast<__half>(v.data[0]);
|
||||
result.data[1] = static_cast<__half>(v.data[1]);
|
||||
return result;
|
||||
#endif
|
||||
}
|
||||
|
||||
/// f16x2 -> f8_e4m3x2.
|
||||
/// NVIDIA SM90+: packed intrinsic (1 instruction).
|
||||
/// HIP gfx942: half -> float -> fp8 (via AMD builtin).
|
||||
/// Pre-SM90: element-wise scalar conversion.
|
||||
template <>
|
||||
MSCCLPP_DEVICE_INLINE f8_e4m3x2 to<f8_e4m3x2, f16x2>(const f16x2& v) {
|
||||
#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
|
||||
float f0 = __half2float(v.data[0]);
|
||||
float f1 = __half2float(v.data[1]);
|
||||
uint32_t packed = __builtin_amdgcn_cvt_pk_fp8_f32(f0, f1, 0, false);
|
||||
return bit_cast<f8_e4m3x2>(static_cast<__hip_fp8x2_storage_t>(packed));
|
||||
#elif defined(MSCCLPP_DEVICE_CUDA) && __CUDA_ARCH__ >= 900
|
||||
__half2_raw h2 = bit_cast<__half2_raw>(v);
|
||||
__nv_fp8x2_storage_t fp8x2 = __nv_cvt_halfraw2_to_fp8x2(h2, __NV_SATFINITE, __NV_E4M3);
|
||||
return bit_cast<f8_e4m3x2>(fp8x2);
|
||||
#elif defined(MSCCLPP_DEVICE_CUDA)
|
||||
__half_raw h0, h1;
|
||||
h0.x = bit_cast<unsigned short>(v.data[0]);
|
||||
h1.x = bit_cast<unsigned short>(v.data[1]);
|
||||
f8_e4m3x2 result;
|
||||
result.data[0] = bit_cast<__fp8_e4m3>(__nv_cvt_halfraw_to_fp8(h0, __NV_SATFINITE, __NV_E4M3));
|
||||
result.data[1] = bit_cast<__fp8_e4m3>(__nv_cvt_halfraw_to_fp8(h1, __NV_SATFINITE, __NV_E4M3));
|
||||
return result;
|
||||
#else
|
||||
f8_e4m3x2 result;
|
||||
result.data[0] = static_cast<__fp8_e4m3>(v.data[0]);
|
||||
result.data[1] = static_cast<__fp8_e4m3>(v.data[1]);
|
||||
return result;
|
||||
#endif
|
||||
}
|
||||
|
||||
#endif // defined(__FP8_TYPES_EXIST__)
|
||||
|
||||
// --- fp8_e4m3b15 <-> fp16 direct conversion specializations ---
|
||||
// These are the PRIMARY conversions: fp8_b15 <-> fp16 is just a 1-bit exponent shift
|
||||
// (E4 bias=15 <-> E5 bias=15), no precision loss since fp16 has 10 mantissa bits
|
||||
// vs fp8's 3. fp32 conversions are derived by routing through fp16.
|
||||
|
||||
/// f8_e4m3b15x2 -> f16x2.
|
||||
/// Direct fp8 -> fp16 via branch-free bit manipulation.
|
||||
template <>
|
||||
MSCCLPP_DEVICE_INLINE f16x2 to<f16x2, f8_e4m3b15x2>(const f8_e4m3b15x2& v) {
|
||||
#if defined(MSCCLPP_DEVICE_CUDA)
|
||||
uint16_t in = v.storage.__x;
|
||||
// Spread 2 fp8 bytes into packed fp16 pair, adjust exponent E4->E5.
|
||||
uint32_t a0 = ((uint32_t)(in & 0xFFu) << 8) | ((uint32_t)(in >> 8) << 24);
|
||||
uint32_t b0 = (a0 & 0x7f007f00u) >> 1;
|
||||
uint32_t out0 = b0 | (a0 & 0x80008000u);
|
||||
__half2 h;
|
||||
asm("mov.b32 %0, %1;" : "=r"(*reinterpret_cast<uint32_t*>(&h)) : "r"(out0));
|
||||
return h;
|
||||
#elif defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
|
||||
// gfx942: same bit manipulation as CUDA, store packed fp16 bits via words[].
|
||||
uint16_t in = v.storage.__x;
|
||||
uint32_t a0 = ((uint32_t)(in & 0xFFu) << 8) | ((uint32_t)(in >> 8) << 24);
|
||||
uint32_t b0 = (a0 & 0x7f007f00u) >> 1;
|
||||
uint32_t out0 = b0 | (a0 & 0x80008000u);
|
||||
f16x2 result;
|
||||
result.words[0] = out0;
|
||||
return result;
|
||||
#else
|
||||
f16x2 result;
|
||||
result.data[0] = __float2half(float(v.data[0]));
|
||||
result.data[1] = __float2half(float(v.data[1]));
|
||||
return result;
|
||||
#endif
|
||||
}
|
||||
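For reference, the same E4 -> E5 re-biasing can be written out on a single scalar. This is a hedged sketch with a hypothetical helper name, not part of the change; it mirrors the bit arithmetic of the packed path above (both formats use bias 15, so the 7 magnitude bits only move down by one).

MSCCLPP_DEVICE_INLINE uint16_t fp8b15BitsToHalfBits(uint8_t fp8) {
  uint16_t hi = static_cast<uint16_t>(fp8) << 8;               // s eeee mmm 00000000
  uint16_t mag = static_cast<uint16_t>((hi & 0x7f00u) >> 1);   // 0 0eeee mmm 0000000
  return static_cast<uint16_t>(mag | (hi & 0x8000u));          // reattach the sign bit
}
// Example: 0x44 encodes 2^(8-15) * 1.5 = 0.01171875 in e4m3b15; the sketch maps it
// to fp16 bits 0x2200, which encode exactly the same value, so the upcast is lossless.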
|
||||
/// f8_e4m3b15x4 -> f16x4.
|
||||
/// Uses __byte_perm + lop3 for branch-free vectorized conversion.
|
||||
template <>
|
||||
MSCCLPP_DEVICE_INLINE f16x4 to<f16x4, f8_e4m3b15x4>(const f8_e4m3b15x4& v) {
|
||||
#if defined(MSCCLPP_DEVICE_CUDA)
|
||||
uint32_t in = v.storage.__x;
|
||||
uint32_t a0 = __byte_perm(0u, in, 0x5746u);
|
||||
uint32_t a0_shr = a0 >> 1;
|
||||
uint32_t a0_sign = a0 & 0x80008000u;
|
||||
uint32_t out0;
|
||||
asm("lop3.b32 %0, %1, %2, %3, 0xEA;" : "=r"(out0) : "r"(a0_shr), "r"(0x3f803f80u), "r"(a0_sign));
|
||||
uint32_t a1 = __byte_perm(a0, 0u, 0x2301u);
|
||||
uint32_t a1_shr = a1 >> 1;
|
||||
uint32_t a1_sign = a1 & 0x80008000u;
|
||||
uint32_t out1;
|
||||
asm("lop3.b32 %0, %1, %2, %3, 0xEA;" : "=r"(out1) : "r"(a1_shr), "r"(0x3f803f80u), "r"(a1_sign));
|
||||
f16x4 result;
|
||||
asm("mov.b32 %0, %1;" : "=r"(result.words[0]) : "r"(out0));
|
||||
asm("mov.b32 %0, %1;" : "=r"(result.words[1]) : "r"(out1));
|
||||
return result;
|
||||
#elif defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
|
||||
// gfx942: __byte_perm + bitwise E4→E5 shift (no lop3), store via words[].
|
||||
uint32_t in = v.storage.__x;
|
||||
uint32_t a0 = __byte_perm(0u, in, 0x5746u);
|
||||
uint32_t out0 = ((a0 >> 1) & 0x3f803f80u) | (a0 & 0x80008000u);
|
||||
uint32_t a1 = __byte_perm(a0, 0u, 0x2301u);
|
||||
uint32_t out1 = ((a1 >> 1) & 0x3f803f80u) | (a1 & 0x80008000u);
|
||||
f16x4 result;
|
||||
result.words[0] = out0;
|
||||
result.words[1] = out1;
|
||||
return result;
|
||||
#else
|
||||
f16x4 result;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
result.data[i] = __float2half(float(v.data[i]));
|
||||
}
|
||||
return result;
|
||||
#endif
|
||||
}
|
||||
|
||||
/// f16x2 -> f8_e4m3b15x2.
|
||||
/// Direct fp16 -> fp8 via clamp + exponent shift E5->E4 + pack.
|
||||
template <>
|
||||
MSCCLPP_DEVICE_INLINE f8_e4m3b15x2 to<f8_e4m3b15x2, f16x2>(const f16x2& v) {
|
||||
#if defined(MSCCLPP_DEVICE_CUDA)
|
||||
uint32_t in0;
|
||||
asm("mov.b32 %0, %1;" : "=r"(in0) : "r"(*reinterpret_cast<const uint32_t*>(&v)));
|
||||
// Clamp abs to max finite e4m3b15 (0x3B80 = 0.9375 in fp16).
|
||||
uint32_t lo = in0 & 0xFFFFu, hi = in0 >> 16;
|
||||
uint32_t alo = lo & 0x7FFFu, ahi = hi & 0x7FFFu;
|
||||
alo = alo < 0x3B80u ? alo : 0x3B80u;
|
||||
ahi = ahi < 0x3B80u ? ahi : 0x3B80u;
|
||||
uint32_t a0 = alo | (ahi << 16);
|
||||
a0 = a0 * 2u + 0x00800080u;
|
||||
uint32_t b0 = a0 | (in0 & 0x80008000u);
|
||||
uint16_t packed = (uint16_t)(((b0 >> 8) & 0xFFu) | ((b0 >> 16) & 0xFF00u));
|
||||
return bit_cast<f8_e4m3b15x2>(packed);
|
||||
#elif defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
|
||||
// gfx942: read packed fp16 bits, clamp via v_pk_min_u16, shift E5→E4, pack.
|
||||
uint32_t in0 = v.words[0];
|
||||
uint32_t abs0 = in0 & 0x7fff7fffu;
|
||||
uint32_t a0;
|
||||
asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a0) : "v"(abs0), "v"(0x3B803B80u));
|
||||
a0 = a0 * 2u + 0x00800080u;
|
||||
uint32_t b0 = a0 | (in0 & 0x80008000u);
|
||||
uint16_t packed = (uint16_t)(((b0 >> 8) & 0xFFu) | ((b0 >> 16) & 0xFF00u));
|
||||
return bit_cast<f8_e4m3b15x2>(packed);
|
||||
#else
|
||||
f8_e4m3b15x2 result;
|
||||
result.data[0] = __fp8_e4m3b15(__half2float(v.data[0]));
|
||||
result.data[1] = __fp8_e4m3b15(__half2float(v.data[1]));
|
||||
return result;
|
||||
#endif
|
||||
}
|
||||
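The downcast direction, written out on one scalar as a hedged sketch (hypothetical helper, mirrors the packed arithmetic above): clamp the magnitude to the largest finite e4m3b15 (fp16 bits 0x3B80 = 0.9375), realign E5 -> E4 by doubling, add 0x0080 to round at the first dropped mantissa bit, then keep the high byte plus the sign.

MSCCLPP_DEVICE_INLINE uint8_t halfBitsToFp8b15Bits(uint16_t h) {
  uint32_t mag = h & 0x7fffu;
  if (mag > 0x3B80u) mag = 0x3B80u;           // saturate to the max finite value (0.9375)
  uint32_t rounded = mag * 2u + 0x0080u;      // E5 -> E4 realignment plus round-to-nearest
  uint8_t out = static_cast<uint8_t>(rounded >> 8);
  return static_cast<uint8_t>(out | ((h & 0x8000u) >> 8));  // sign moves from bit 15 to bit 7
}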
|
||||
/// f16x4 -> f8_e4m3b15x4.
|
||||
/// Uses __vminu2 + lop3 + __byte_perm for branch-free vectorized conversion.
|
||||
template <>
|
||||
MSCCLPP_DEVICE_INLINE f8_e4m3b15x4 to<f8_e4m3b15x4, f16x4>(const f16x4& v) {
|
||||
#if defined(MSCCLPP_DEVICE_CUDA)
|
||||
uint32_t in0, in1;
|
||||
asm("mov.b32 %0, %1;" : "=r"(in0) : "r"(v.words[0]));
|
||||
asm("mov.b32 %0, %1;" : "=r"(in1) : "r"(v.words[1]));
|
||||
uint32_t abs0 = in0 & 0x7fff7fffu;
|
||||
uint32_t abs1 = in1 & 0x7fff7fffu;
|
||||
uint32_t a0 = __vminu2(abs0, 0x3B803B80u);
|
||||
uint32_t a1 = __vminu2(abs1, 0x3B803B80u);
|
||||
a0 = a0 * 2u + 0x00800080u;
|
||||
a1 = a1 * 2u + 0x00800080u;
|
||||
uint32_t b0, b1;
|
||||
asm("lop3.b32 %0, %1, %2, %3, 0xf8;" : "=r"(b0) : "r"(a0), "r"(in0), "r"(0x80008000u));
|
||||
asm("lop3.b32 %0, %1, %2, %3, 0xf8;" : "=r"(b1) : "r"(a1), "r"(in1), "r"(0x80008000u));
|
||||
uint32_t packed = __byte_perm(b0, b1, 0x7531u);
|
||||
return bit_cast<f8_e4m3b15x4>(packed);
|
||||
#elif defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
|
||||
// gfx942: read packed fp16 bits, clamp via v_pk_min_u16, shift E5→E4, __byte_perm pack.
|
||||
uint32_t in0 = v.words[0], in1 = v.words[1];
|
||||
uint32_t abs0 = in0 & 0x7fff7fffu, abs1 = in1 & 0x7fff7fffu;
|
||||
uint32_t a0, a1;
|
||||
asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a0) : "v"(abs0), "v"(0x3B803B80u));
|
||||
asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a1) : "v"(abs1), "v"(0x3B803B80u));
|
||||
a0 = a0 * 2u + 0x00800080u;
|
||||
a1 = a1 * 2u + 0x00800080u;
|
||||
uint32_t b0 = a0 | (in0 & 0x80008000u);
|
||||
uint32_t b1 = a1 | (in1 & 0x80008000u);
|
||||
uint32_t packed = __byte_perm(b0, b1, 0x7531u);
|
||||
return bit_cast<f8_e4m3b15x4>(packed);
|
||||
#else
|
||||
f8_e4m3b15x4 result;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
result.data[i] = __fp8_e4m3b15(__half2float(v.data[i]));
|
||||
}
|
||||
return result;
|
||||
#endif
|
||||
}
|
||||
|
||||
// --- fp8_e4m3b15 <-> f32 conversion specializations (software, always available) ---
|
||||
|
||||
/// f8_e4m3b15x2 -> f32x2.
|
||||
/// Routes through fp16: fp8→fp16 (bit manip) then fp16→f32.
|
||||
template <>
|
||||
MSCCLPP_DEVICE_INLINE f32x2 to<f32x2, f8_e4m3b15x2>(const f8_e4m3b15x2& v) {
|
||||
#if defined(MSCCLPP_DEVICE_CUDA)
|
||||
f16x2 h = to<f16x2, f8_e4m3b15x2>(v);
|
||||
float2 f2 = __half22float2(h);
|
||||
return bit_cast<f32x2>(f2);
|
||||
#elif defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
|
||||
f16x2 h = to<f16x2, f8_e4m3b15x2>(v);
|
||||
f32x2 result;
|
||||
result.data[0] = __half2float(h.data[0]);
|
||||
result.data[1] = __half2float(h.data[1]);
|
||||
return result;
|
||||
#else
|
||||
f32x2 result;
|
||||
result.data[0] = float(v.data[0]);
|
||||
result.data[1] = float(v.data[1]);
|
||||
return result;
|
||||
#endif
|
||||
}
|
||||
|
||||
/// f8_e4m3b15x4 -> f32x4.
|
||||
/// Routes through fp16: fp8→fp16 (bit manip) then fp16→f32.
|
||||
template <>
|
||||
MSCCLPP_DEVICE_INLINE f32x4 to<f32x4, f8_e4m3b15x4>(const f8_e4m3b15x4& v) {
|
||||
#if defined(MSCCLPP_DEVICE_CUDA)
|
||||
f16x4 h = to<f16x4, f8_e4m3b15x4>(v);
|
||||
__half2 h0, h1;
|
||||
asm("mov.b32 %0, %1;" : "=r"(*reinterpret_cast<uint32_t*>(&h0)) : "r"(h.words[0]));
|
||||
asm("mov.b32 %0, %1;" : "=r"(*reinterpret_cast<uint32_t*>(&h1)) : "r"(h.words[1]));
|
||||
float2 f0 = __half22float2(h0);
|
||||
float2 f1 = __half22float2(h1);
|
||||
f32x4 result;
|
||||
result.data[0] = f0.x;
|
||||
result.data[1] = f0.y;
|
||||
result.data[2] = f1.x;
|
||||
result.data[3] = f1.y;
|
||||
return result;
|
||||
#elif defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
|
||||
f16x4 h = to<f16x4, f8_e4m3b15x4>(v);
|
||||
f32x4 result;
|
||||
result.data[0] = __half2float(h.data[0]);
|
||||
result.data[1] = __half2float(h.data[1]);
|
||||
result.data[2] = __half2float(h.data[2]);
|
||||
result.data[3] = __half2float(h.data[3]);
|
||||
return result;
|
||||
#else
|
||||
f32x4 result;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
result.data[i] = float(v.data[i]);
|
||||
}
|
||||
return result;
|
||||
#endif
|
||||
}
|
||||
|
||||
/// f32x2 -> f8_e4m3b15x2.
|
||||
/// Routes through fp16: f32→fp16 then fp16→fp8 (clamp + exponent shift + pack).
|
||||
template <>
|
||||
MSCCLPP_DEVICE_INLINE f8_e4m3b15x2 to<f8_e4m3b15x2, f32x2>(const f32x2& v) {
|
||||
#if defined(MSCCLPP_DEVICE_CUDA)
|
||||
float2 f2 = {v.data[0], v.data[1]};
|
||||
__half2 h = __float22half2_rn(f2);
|
||||
return to<f8_e4m3b15x2, f16x2>(h);
|
||||
#elif defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
|
||||
f16x2 h;
|
||||
h.data[0] = __float2half_rn(v.data[0]);
|
||||
h.data[1] = __float2half_rn(v.data[1]);
|
||||
return to<f8_e4m3b15x2, f16x2>(h);
|
||||
#else
|
||||
f8_e4m3b15x2 result;
|
||||
result.data[0] = __fp8_e4m3b15(v.data[0]);
|
||||
result.data[1] = __fp8_e4m3b15(v.data[1]);
|
||||
return result;
|
||||
#endif
|
||||
}
|
||||
|
||||
/// f32x4 -> f8_e4m3b15x4.
|
||||
/// Routes through fp16: f32→fp16 then fp16→fp8 (clamp + exponent shift + pack).
|
||||
template <>
|
||||
MSCCLPP_DEVICE_INLINE f8_e4m3b15x4 to<f8_e4m3b15x4, f32x4>(const f32x4& v) {
|
||||
#if defined(MSCCLPP_DEVICE_CUDA)
|
||||
float2 f01 = {v.data[0], v.data[1]};
|
||||
float2 f23 = {v.data[2], v.data[3]};
|
||||
__half2 h01 = __float22half2_rn(f01);
|
||||
__half2 h23 = __float22half2_rn(f23);
|
||||
f16x4 h;
|
||||
asm("mov.b32 %0, %1;" : "=r"(h.words[0]) : "r"(*reinterpret_cast<uint32_t*>(&h01)));
|
||||
asm("mov.b32 %0, %1;" : "=r"(h.words[1]) : "r"(*reinterpret_cast<uint32_t*>(&h23)));
|
||||
return to<f8_e4m3b15x4, f16x4>(h);
|
||||
#elif defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
|
||||
f16x4 h;
|
||||
h.words[0] = __builtin_bit_cast(uint32_t, __builtin_amdgcn_cvt_pkrtz(v.data[0], v.data[1]));
|
||||
h.words[1] = __builtin_bit_cast(uint32_t, __builtin_amdgcn_cvt_pkrtz(v.data[2], v.data[3]));
|
||||
return to<f8_e4m3b15x4, f16x4>(h);
|
||||
#else
|
||||
f8_e4m3b15x4 result;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
result.data[i] = __fp8_e4m3b15(v.data[i]);
|
||||
}
|
||||
return result;
|
||||
#endif
|
||||
}
|
||||
|
||||
// --- fp8_e4m3b15 arithmetic (software, always available) ---
|
||||
|
||||
template <bool UseClip = true>
|
||||
MSCCLPP_DEVICE_INLINE __fp8_e4m3b15 operator+(const __fp8_e4m3b15& a, const __fp8_e4m3b15& b) {
|
||||
return __fp8_e4m3b15(float(a) + float(b));
|
||||
}
|
||||
|
||||
template <bool UseClip = true>
|
||||
MSCCLPP_DEVICE_INLINE f8_e4m3b15x2 operator+(const f8_e4m3b15x2& a, const f8_e4m3b15x2& b) {
|
||||
f8_e4m3b15x2 result;
|
||||
result.data[0] = __fp8_e4m3b15(float(a.data[0]) + float(b.data[0]));
|
||||
result.data[1] = __fp8_e4m3b15(float(a.data[1]) + float(b.data[1]));
|
||||
return result;
|
||||
}
|
||||
|
||||
template <bool UseClip = true>
|
||||
MSCCLPP_DEVICE_INLINE f8_e4m3b15x4 operator+(const f8_e4m3b15x4& a, const f8_e4m3b15x4& b) {
|
||||
f8_e4m3b15x4 result;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
result.data[i] = __fp8_e4m3b15(float(a.data[i]) + float(b.data[i]));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// --- fp8_e4m3b15 min (software) ---
|
||||
|
||||
template <>
|
||||
MSCCLPP_DEVICE_INLINE __fp8_e4m3b15 min(const __fp8_e4m3b15& a, const __fp8_e4m3b15& b) {
|
||||
return __fp8_e4m3b15(fminf(float(a), float(b)));
|
||||
}
|
||||
|
||||
MSCCLPP_DEVICE_INLINE f8_e4m3b15x2 min(const f8_e4m3b15x2& a, const f8_e4m3b15x2& b) {
|
||||
f8_e4m3b15x2 result;
|
||||
result.data[0] = mscclpp::min(a.data[0], b.data[0]);
|
||||
result.data[1] = mscclpp::min(a.data[1], b.data[1]);
|
||||
return result;
|
||||
}
|
||||
|
||||
MSCCLPP_DEVICE_INLINE f8_e4m3b15x4 min(const f8_e4m3b15x4& a, const f8_e4m3b15x4& b) {
|
||||
f8_e4m3b15x4 result;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
result.data[i] = mscclpp::min(a.data[i], b.data[i]);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
#endif // MSCCLPP_DEVICE_COMPILE
|
||||
} // namespace mscclpp
|
||||
|
||||
|
||||
@@ -24,4 +24,7 @@ set_target_properties(mscclpp_py PROPERTIES OUTPUT_NAME _mscclpp)
|
||||
set_target_properties(mscclpp_py PROPERTIES INSTALL_RPATH "\$ORIGIN/lib")
|
||||
target_link_libraries(mscclpp_py PRIVATE dlpack mscclpp mscclpp_collectives ${GPU_LIBRARIES})
|
||||
target_include_directories(mscclpp_py SYSTEM PRIVATE ${GPU_INCLUDE_DIRS})
|
||||
if(MSCCLPP_USE_ROCM)
|
||||
target_compile_definitions(mscclpp_py PRIVATE MSCCLPP_USE_ROCM)
|
||||
endif()
|
||||
install(TARGETS mscclpp_py LIBRARY DESTINATION .)
|
||||
|
||||
@@ -75,15 +75,17 @@ void register_algorithm(nb::module_& m) {
|
||||
[](Algorithm& self, std::shared_ptr<Communicator> comm, uintptr_t input, uintptr_t output,
|
||||
size_t inputSize, size_t outputSize, DataType dtype, ReduceOp op, uintptr_t stream,
|
||||
std::shared_ptr<Executor> executor, int nBlocks, int nThreadsPerBlock, bool symmetricMemory,
|
||||
std::unordered_map<std::string, uintptr_t> extras) {
|
||||
std::unordered_map<std::string, uintptr_t> extras, int32_t accumDtype) {
|
||||
return self.execute(comm, reinterpret_cast<const void*>(input), reinterpret_cast<void*>(output),
|
||||
inputSize, outputSize, dtype, op, reinterpret_cast<cudaStream_t>(stream), executor,
|
||||
nBlocks, nThreadsPerBlock, symmetricMemory, extras);
|
||||
nBlocks, nThreadsPerBlock, symmetricMemory, extras,
|
||||
static_cast<DataType>(accumDtype));
|
||||
},
|
||||
nb::arg("comm"), nb::arg("input"), nb::arg("output"), nb::arg("input_size"), nb::arg("output_size"),
|
||||
nb::arg("dtype"), nb::arg("op") = ReduceOp::NOP, nb::arg("stream") = 0, nb::arg("executor") = nullptr,
|
||||
nb::arg("n_blocks") = 0, nb::arg("n_threads_per_block") = 0, nb::arg("symmetric_memory") = false,
|
||||
nb::arg("extras") = std::unordered_map<std::string, uintptr_t>())
|
||||
nb::arg("extras") = std::unordered_map<std::string, uintptr_t>(),
|
||||
nb::arg("accum_dtype") = static_cast<int32_t>(DataType::AUTO))
|
||||
.def("reset", &Algorithm::reset);
|
||||
|
||||
nb::class_<Algorithm::Constraint>(algorithmClass, "Constraint")
|
||||
|
||||
@@ -47,7 +47,8 @@ void register_core(nb::module_& m) {
|
||||
.value("bfloat16", DataType::BFLOAT16)
|
||||
.value("float8_e4m3", DataType::FLOAT8_E4M3)
|
||||
.value("float8_e5m2", DataType::FLOAT8_E5M2)
|
||||
.value("uint8", DataType::UINT8);
|
||||
.value("uint8", DataType::UINT8)
|
||||
.value("float8_e4m3b15", DataType::FLOAT8_E4M3B15);
|
||||
|
||||
nb::class_<Bootstrap>(m, "CppBootstrap")
|
||||
.def("get_rank", &Bootstrap::getRank)
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/function.h>
|
||||
#include <nanobind/stl/shared_ptr.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
#include <nanobind/stl/unordered_map.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
|
||||
@@ -34,6 +34,19 @@ static DLDataType getDlType(std::string type) {
|
||||
return DLDataType{kDLBfloat, 16, 1};
|
||||
} else if (type == "torch.float16") {
|
||||
return DLDataType{kDLFloat, 16, 1};
|
||||
} else if (type == "torch.float8_e4m3fn") {
|
||||
return DLDataType{kDLFloat8_e4m3fn, 8, 1};
|
||||
} else if (type == "torch.float8_e4m3fnuz") {
|
||||
return DLDataType{kDLFloat8_e4m3fnuz, 8, 1};
|
||||
} else if (type == "torch.float8_e5m2") {
|
||||
return DLDataType{kDLFloat8_e5m2, 8, 1};
|
||||
} else if (type == "torch.float8_e5m2fnuz") {
|
||||
return DLDataType{kDLFloat8_e5m2fnuz, 8, 1};
|
||||
} else if (type == "torch.uint8") {
|
||||
return DLDataType{kDLUInt, 8, 1};
|
||||
} else if (type == "fp8_e4m3b15") {
|
||||
// No standard DLPack code for fp8_e4m3b15; store as raw uint8 bytes.
|
||||
return DLDataType{kDLUInt, 8, 1};
|
||||
} else {
|
||||
throw Error("Unsupported type: " + type, ErrorCode::InvalidUsage);
|
||||
}
|
||||
|
||||
@@ -177,6 +177,7 @@ class Algorithm:
|
||||
nthreads_per_block=0,
|
||||
symmetric_memory: bool = False,
|
||||
extras: Optional[Dict[str, int]] = None,
|
||||
accum_dtype: Optional[CppDataType] = None,
|
||||
) -> int:
|
||||
"""Execute the collective algorithm.
|
||||
|
||||
@@ -194,10 +195,14 @@ class Algorithm:
|
||||
nthreads_per_block: Number of threads per block (0 for auto-selection).
|
||||
symmetric_memory: Whether to use symmetric memory optimization (default: False).
|
||||
extras: Additional algorithm-specific parameters.
|
||||
accum_dtype: Data type for accumulation during reduction. If None, defaults to
|
||||
the same as dtype. Use DataType.float32 for high-precision FP8 accumulation.
|
||||
|
||||
Returns:
|
||||
The result code (0 for success).
|
||||
"""
|
||||
merged_extras = dict(extras) if extras is not None else {}
|
||||
accum_dtype = accum_dtype if accum_dtype is not None else dtype
|
||||
return self._algorithm.execute(
|
||||
comm,
|
||||
int(input_buffer),
|
||||
@@ -211,7 +216,8 @@ class Algorithm:
|
||||
nblocks,
|
||||
nthreads_per_block,
|
||||
symmetric_memory,
|
||||
extras if extras is not None else {},
|
||||
merged_extras,
|
||||
int(accum_dtype),
|
||||
)
|
||||
|
||||
def reset(self):
|
||||
|
||||
@@ -140,7 +140,7 @@ class MemoryChannel:
|
||||
|
||||
for tb_id in tb_list:
|
||||
tb_chunk_id = get_program().setup_remote_chunk(self.src_rank, tb_id, remote_chunk, self.channel_type)
|
||||
tb_channel_ids = get_program().setup_channel(tb, self)
|
||||
tb_channel_ids = get_program().setup_channel(tb_id, self)
|
||||
op = GetOperation(
|
||||
src_buff=[RemoteChunk(src_chunk.buffer, src_chunk.index, src_chunk.size, tb_chunk_id)],
|
||||
dst_buff=[LocalChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size)],
|
||||
|
||||
@@ -745,7 +745,7 @@ class ReduceOperation(BaseOperation):
|
||||
remote_dst_buff=self.remote_dst_buff + other.dst_buff,
|
||||
channel_ids=self.channel_ids,
|
||||
put_channel_ids=self.put_channel_ids + other.channel_ids,
|
||||
channel_type=self.channel_type,
|
||||
channel_type=other.channel_type,
|
||||
reduce_operation=self.reduce_operation,
|
||||
tbg_info=self.tbg_info,
|
||||
packet=self.packet,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
mpi4py==4.1.1
|
||||
cupy==13.6.0
|
||||
mpi4py
|
||||
cupy
|
||||
prettytable
|
||||
netifaces
|
||||
pytest
|
||||
|
||||
397
python/test/test_fp8_accum.py
Normal file
@@ -0,0 +1,397 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# Correctness test for FP8 allreduce with different accumulation types.
|
||||
#
|
||||
# Verifies that FP8 allreduce with higher-precision accumulation produces
|
||||
# results at least as accurate as native FP8 accumulation, by comparing
|
||||
# against a float32 reference.
|
||||
#
|
||||
# Usage:
|
||||
# mpirun -np 8 pytest python/test/test_fp8_accum.py -v
|
||||
|
||||
import cupy as cp
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mscclpp import CommGroup, GpuBuffer, DataType, ReduceOp, is_nvls_supported
|
||||
from mscclpp.ext import AlgorithmCollectionBuilder
|
||||
from .mscclpp_mpi import MpiGroup, parametrize_mpi_groups, mpi_group
|
||||
|
||||
# FP8 E4M3 (hardware) requires SM >= 89 (Ada / Hopper) on NVIDIA GPUs.
|
||||
# On AMD/ROCm (e.g. MI300X), FP8 is supported natively — no skip needed.
|
||||
_is_hip = hasattr(cp.cuda.runtime, "is_hip") and cp.cuda.runtime.is_hip
|
||||
_skip_fp8 = not _is_hip and int(cp.cuda.Device().compute_capability) < 89
|
||||
pytestmark = pytest.mark.skipif(_skip_fp8, reason="FP8 accum tests require SM >= 89 on CUDA")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FP8 E4M3FN helpers (bias=7, no infinity, NaN = exp=15 & mant=7)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def e4m3fn_to_float(uint8_array):
|
||||
"""Decode a cupy uint8 array of E4M3FN bit patterns to float32."""
|
||||
bits = uint8_array.astype(cp.int32)
|
||||
sign = (bits >> 7) & 1
|
||||
exp = (bits >> 3) & 0xF
|
||||
mant = bits & 0x7
|
||||
|
||||
# Normal: (-1)^s * 2^(exp-7) * (1 + mant/8)
|
||||
normal_val = cp.ldexp(cp.float32(1.0) + mant.astype(cp.float32) / cp.float32(8.0), (exp - 7).astype(cp.int32))
|
||||
# Subnormal (exp==0): (-1)^s * 2^(-6) * (mant/8)
|
||||
subnormal_val = cp.ldexp(mant.astype(cp.float32) / cp.float32(8.0), cp.int32(-6))
|
||||
|
||||
result = cp.where(exp == 0, subnormal_val, normal_val)
|
||||
result = cp.where(sign == 1, -result, result)
|
||||
# Zero
|
||||
result = cp.where((exp == 0) & (mant == 0), cp.float32(0.0), result)
|
||||
# NaN: exp==15 & mant==7
|
||||
nan_mask = (exp == 15) & (mant == 7)
|
||||
result = cp.where(nan_mask, cp.float32(float("nan")), result)
|
||||
return result
|
||||
|
||||
|
||||
def float_to_e4m3fn(f32_array, chunk_size=65536):
|
||||
"""Encode a cupy float32 array to uint8 E4M3FN bit patterns.
|
||||
|
||||
Uses a lookup-table approach: precompute all 128 positive E4M3FN values,
|
||||
then find nearest match per element via chunked broadcast comparison.
|
||||
"""
|
||||
# Build lookup table of all 128 positive E4M3FN values (0x00..0x7F)
|
||||
all_bytes = cp.arange(128, dtype=cp.uint8)
|
||||
all_floats = e4m3fn_to_float(all_bytes) # (128,) float32
|
||||
# Mark NaN entries as inf so they're never selected as nearest
|
||||
all_floats = cp.where(cp.isnan(all_floats), cp.float32(float("inf")), all_floats)
|
||||
|
||||
# Clamp input and extract sign
|
||||
clamped = f32_array.astype(cp.float32)
|
||||
clamped = cp.clip(clamped, -448.0, 448.0)
|
||||
signs = (clamped < 0).astype(cp.uint8)
|
||||
absval = cp.abs(clamped)
|
||||
|
||||
result = cp.zeros(absval.shape, dtype=cp.uint8)
|
||||
n = absval.size
|
||||
absval_flat = absval.ravel()
|
||||
result_flat = result.ravel()
|
||||
|
||||
for start in range(0, n, chunk_size):
|
||||
end = min(start + chunk_size, n)
|
||||
chunk = absval_flat[start:end]
|
||||
# (chunk_size, 128) difference matrix
|
||||
diffs = cp.abs(chunk[:, None] - all_floats[None, :])
|
||||
result_flat[start:end] = cp.argmin(diffs, axis=1).astype(cp.uint8)
|
||||
|
||||
# Combine with sign bit
|
||||
result = result_flat.reshape(absval.shape)
|
||||
result = result | (signs << 7)
|
||||
# Handle exact zero
|
||||
result = cp.where(absval == 0, cp.uint8(0), result)
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FP8 E4M3B15 helpers (bias=15, max=0.9375, NaN = exp==15 or bits==0x80)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def e4m3b15_to_float(uint8_array):
|
||||
"""Decode a cupy uint8 array of E4M3B15 bit patterns to float32."""
|
||||
bits = uint8_array.astype(cp.int32)
|
||||
sign = (bits >> 7) & 1
|
||||
exp = (bits >> 3) & 0xF
|
||||
mant = bits & 0x7
|
||||
|
||||
# Normal: (-1)^s * 2^(exp-15) * (1 + mant/8)
|
||||
normal_val = cp.ldexp(cp.float32(1.0) + mant.astype(cp.float32) / cp.float32(8.0), (exp - 15).astype(cp.int32))
|
||||
# Subnormal (exp==0): (-1)^s * 2^(-14) * (mant/8)
|
||||
subnormal_val = cp.ldexp(mant.astype(cp.float32) / cp.float32(8.0), cp.int32(-14))
|
||||
|
||||
result = cp.where(exp == 0, subnormal_val, normal_val)
|
||||
result = cp.where(sign == 1, -result, result)
|
||||
# Zero
|
||||
result = cp.where((exp == 0) & (mant == 0), cp.float32(0.0), result)
|
||||
# NaN: exp==15 or negative zero (0x80)
|
||||
nan_mask = (exp == 15) | (uint8_array.astype(cp.int32) == 0x80)
|
||||
result = cp.where(nan_mask, cp.float32(float("nan")), result)
|
||||
return result
|
||||
|
||||
|
||||
def float_to_e4m3b15(f32_array, chunk_size=65536):
|
||||
"""Encode a cupy float32 array to uint8 E4M3B15 bit patterns.
|
||||
|
||||
Same lookup-table approach as float_to_e4m3fn.
|
||||
"""
|
||||
# Build lookup table of all 128 positive E4M3B15 values (0x00..0x7F)
|
||||
all_bytes = cp.arange(128, dtype=cp.uint8)
|
||||
all_floats = e4m3b15_to_float(all_bytes) # (128,) float32
|
||||
# Mark NaN entries as inf so they're never selected as nearest
|
||||
all_floats = cp.where(cp.isnan(all_floats), cp.float32(float("inf")), all_floats)
|
||||
|
||||
# Clamp input and extract sign
|
||||
clamped = f32_array.astype(cp.float32)
|
||||
clamped = cp.clip(clamped, -0.9375, 0.9375)
|
||||
signs = (clamped < 0).astype(cp.uint8)
|
||||
absval = cp.abs(clamped)
|
||||
|
||||
result = cp.zeros(absval.shape, dtype=cp.uint8)
|
||||
n = absval.size
|
||||
absval_flat = absval.ravel()
|
||||
result_flat = result.ravel()
|
||||
|
||||
for start in range(0, n, chunk_size):
|
||||
end = min(start + chunk_size, n)
|
||||
chunk = absval_flat[start:end]
|
||||
# (chunk_size, 128) difference matrix
|
||||
diffs = cp.abs(chunk[:, None] - all_floats[None, :])
|
||||
result_flat[start:end] = cp.argmin(diffs, axis=1).astype(cp.uint8)
|
||||
|
||||
# Combine with sign bit
|
||||
result = result_flat.reshape(absval.shape)
|
||||
result = result | (signs << 7)
|
||||
# Handle exact zero
|
||||
result = cp.where(absval == 0, cp.uint8(0), result)
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared test helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def setup_algorithms(mpi_group):
|
||||
"""Build default algorithms and return (comm_group, algo_map, scratch_buf)."""
|
||||
comm_group = CommGroup(mpi_group.comm)
|
||||
scratch = GpuBuffer(1 << 27, dtype=cp.uint8) # 128 MB
|
||||
AlgorithmCollectionBuilder.reset()
|
||||
builder = AlgorithmCollectionBuilder()
|
||||
algorithms = builder.build_default_algorithms(
|
||||
scratch_buffer=scratch.data.ptr,
|
||||
scratch_buffer_size=scratch.nbytes,
|
||||
rank=comm_group.my_rank,
|
||||
)
|
||||
algo_map = {a.name: a for a in algorithms}
|
||||
return comm_group, algo_map, scratch
|
||||
|
||||
|
||||
def run_allreduce(algo, comm_group, buffer, dtype, accum_dtype=None, nblocks=0, nthreads_per_block=0):
|
||||
"""Run allreduce in-place on buffer and return a copy of the result."""
|
||||
ret = algo.execute(
|
||||
comm=comm_group.communicator,
|
||||
input_buffer=buffer.data.ptr,
|
||||
output_buffer=buffer.data.ptr,
|
||||
input_size=buffer.nbytes,
|
||||
output_size=buffer.nbytes,
|
||||
dtype=dtype,
|
||||
op=ReduceOp.SUM,
|
||||
stream=cp.cuda.get_current_stream().ptr,
|
||||
nblocks=nblocks,
|
||||
nthreads_per_block=nthreads_per_block,
|
||||
symmetric_memory=True,
|
||||
accum_dtype=accum_dtype,
|
||||
)
|
||||
cp.cuda.Device().synchronize()
|
||||
assert ret == 0, f"Allreduce failed with error code {ret}"
|
||||
return buffer.copy()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: FP8 E4M3 accumulation correctness
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@parametrize_mpi_groups(8)
|
||||
@pytest.mark.parametrize(
|
||||
"algo_name",
|
||||
[
|
||||
"default_allreduce_packet",
|
||||
"default_allreduce_nvls_packet",
|
||||
"default_allreduce_fullmesh",
|
||||
"default_allreduce_rsag_zero_copy",
|
||||
"default_allreduce_allpair_packet",
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("size", [1024, 4096, 16384, 65536, 262144, 1048576])
|
||||
def test_fp8_e4m3_accum(mpi_group: MpiGroup, algo_name: str, size: int):
|
||||
"""Verify that FP8 E4M3 allreduce with higher-precision accumulation is at
|
||||
least as accurate as native FP8 accumulation, across all algorithm variants."""
|
||||
rank = mpi_group.comm.rank
|
||||
world_size = mpi_group.comm.size
|
||||
|
||||
comm_group, algo_map, scratch = setup_algorithms(mpi_group)
|
||||
if algo_name not in algo_map:
|
||||
pytest.skip(f"{algo_name} not available")
|
||||
if "nvls" in algo_name and not is_nvls_supported():
|
||||
pytest.skip(f"{algo_name} requires NVLS which is not supported on this platform")
|
||||
algo = algo_map[algo_name]
|
||||
|
||||
buf = GpuBuffer(size, dtype=cp.uint8)
|
||||
|
||||
accum_configs = [
|
||||
("fp8_native", DataType.float8_e4m3),
|
||||
("float16", DataType.float16),
|
||||
("float32", DataType.float32),
|
||||
]
|
||||
|
||||
# rsag_zero_copy and fullmesh need explicit block/thread counts
|
||||
if "rsag" in algo_name:
|
||||
nb = max(1, min(32, size // (world_size * 32)))
|
||||
nt = 1024
|
||||
elif "fullmesh" in algo_name:
|
||||
nb = 35
|
||||
nt = 512
|
||||
else:
|
||||
nb = 0
|
||||
nt = 0
|
||||
|
||||
errors = {}
|
||||
for accum_label, accum_dtype in accum_configs:
|
||||
# Generate deterministic per-rank data (use numpy to avoid hipRAND issues on ROCm)
|
||||
rng = np.random.RandomState(42 + rank)
|
||||
src_f32 = cp.asarray(rng.randn(size).astype(np.float32))
|
||||
src_f32 = cp.clip(src_f32, -240.0, 240.0)
|
||||
src_fp8 = float_to_e4m3fn(src_f32)
|
||||
|
||||
# Copy into symmetric buffer
|
||||
buf[:] = src_fp8
|
||||
cp.cuda.Device().synchronize()
|
||||
|
||||
# Run allreduce
|
||||
result = run_allreduce(
|
||||
algo,
|
||||
comm_group,
|
||||
buf,
|
||||
dtype=DataType.float8_e4m3,
|
||||
accum_dtype=accum_dtype,
|
||||
nblocks=nb,
|
||||
nthreads_per_block=nt,
|
||||
)
|
||||
result_f32 = e4m3fn_to_float(result)
|
||||
|
||||
# Compute float32 reference: sum all ranks' quantized FP8 inputs in float32
|
||||
ref_f32 = cp.zeros(size, dtype=cp.float32)
|
||||
for r in range(world_size):
|
||||
rng_r = np.random.RandomState(42 + r)
|
||||
rank_data = cp.asarray(rng_r.randn(size).astype(np.float32))
|
||||
rank_data = cp.clip(rank_data, -240.0, 240.0)
|
||||
rank_data_fp8 = float_to_e4m3fn(rank_data)
|
||||
ref_f32 += e4m3fn_to_float(rank_data_fp8)
|
||||
|
||||
# Compute errors
|
||||
abs_err = cp.abs(result_f32 - ref_f32)
|
||||
mean_abs_err = float(cp.mean(abs_err))
|
||||
errors[accum_label] = mean_abs_err
|
||||
|
||||
# Reset between runs
|
||||
algo.reset()
|
||||
|
||||
# Higher-precision accumulation should be at least as accurate as native fp8
|
||||
assert (
|
||||
errors["float16"] <= errors["fp8_native"] + 1e-6
|
||||
), f"float16 accum ({errors['float16']:.6f}) worse than native ({errors['fp8_native']:.6f})"
|
||||
assert (
|
||||
errors["float32"] <= errors["fp8_native"] + 1e-6
|
||||
), f"float32 accum ({errors['float32']:.6f}) worse than native ({errors['fp8_native']:.6f})"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: FP8 E4M3B15 accumulation correctness
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@parametrize_mpi_groups(8)
|
||||
@pytest.mark.parametrize(
|
||||
"algo_name",
|
||||
[
|
||||
"default_allreduce_packet",
|
||||
"default_allreduce_nvls_packet",
|
||||
"default_allreduce_rsag_zero_copy",
|
||||
"default_allreduce_fullmesh",
|
||||
"default_allreduce_allpair_packet",
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("size", [1024, 4096, 65536])
|
||||
def test_fp8_e4m3b15_accum(mpi_group: MpiGroup, algo_name: str, size: int):
|
||||
"""Verify that FP8 E4M3B15 allreduce with higher-precision accumulation is at
|
||||
least as accurate as native E4M3B15 accumulation."""
|
||||
rank = mpi_group.comm.rank
|
||||
world_size = mpi_group.comm.size
|
||||
|
||||
comm_group, algo_map, scratch = setup_algorithms(mpi_group)
|
||||
if algo_name not in algo_map:
|
||||
pytest.skip(f"{algo_name} not available")
|
||||
if "nvls" in algo_name and not is_nvls_supported():
|
||||
pytest.skip(f"{algo_name} requires NVLS which is not supported on this platform")
|
||||
|
||||
algo = algo_map[algo_name]
|
||||
buf = GpuBuffer(size, dtype=cp.uint8)
|
||||
|
||||
accum_configs = [
|
||||
("e4m3b15_native", DataType.float8_e4m3b15),
|
||||
("float16", DataType.float16),
|
||||
("float32", DataType.float32),
|
||||
]
|
||||
|
||||
# rsag_zero_copy needs explicit block/thread counts, scaled to data size
|
||||
if "rsag" in algo_name:
|
||||
nb = max(1, min(32, size // (world_size * 32)))
|
||||
nt = 1024
|
||||
else:
|
||||
nb = 0
|
||||
nt = 0
|
||||
|
||||
errors = {}
|
||||
for accum_label, accum_dtype in accum_configs:
|
||||
# Generate deterministic per-rank random uint8 values in valid e4m3b15 range
|
||||
rng = np.random.RandomState(42 + rank)
|
||||
raw = cp.asarray(rng.randint(0, 0x78, (size,)).astype(np.uint8))
|
||||
signs = cp.asarray(rng.randint(0, 2, (size,)).astype(np.uint8)) << 7
|
||||
src_uint8 = raw | signs
|
||||
# Fix negative zero -> positive zero
|
||||
src_uint8 = cp.where(src_uint8 == 0x80, cp.uint8(0), src_uint8)
|
||||
|
||||
# Copy into symmetric buffer
|
||||
buf[:] = src_uint8
|
||||
cp.cuda.Device().synchronize()
|
||||
|
||||
# Run allreduce
|
||||
result = run_allreduce(
|
||||
algo,
|
||||
comm_group,
|
||||
buf,
|
||||
dtype=DataType.float8_e4m3b15,
|
||||
accum_dtype=accum_dtype,
|
||||
nblocks=nb,
|
||||
nthreads_per_block=nt,
|
||||
)
|
||||
|
||||
# Decode result
|
||||
result_f32 = e4m3b15_to_float(result)
|
||||
|
||||
# Compute float32 reference
|
||||
ref_f32 = cp.zeros(size, dtype=cp.float32)
|
||||
for r in range(world_size):
|
||||
rng_r = np.random.RandomState(42 + r)
|
||||
raw_r = cp.asarray(rng_r.randint(0, 0x78, (size,)).astype(np.uint8))
|
||||
signs_r = cp.asarray(rng_r.randint(0, 2, (size,)).astype(np.uint8)) << 7
|
||||
bits_r = raw_r | signs_r
|
||||
bits_r = cp.where(bits_r == 0x80, cp.uint8(0), bits_r)
|
||||
ref_f32 += e4m3b15_to_float(bits_r)
|
||||
|
||||
# Clamp reference to e4m3b15 representable range
|
||||
ref_f32 = cp.clip(ref_f32, -0.9375, 0.9375)
|
||||
|
||||
# Compute errors (only on valid entries)
|
||||
valid = ~cp.isnan(result_f32) & ~cp.isnan(ref_f32)
|
||||
abs_err = cp.abs(result_f32[valid] - ref_f32[valid])
|
||||
mean_abs_err = float(cp.mean(abs_err)) if abs_err.size > 0 else 0.0
|
||||
errors[accum_label] = mean_abs_err
|
||||
|
||||
algo.reset()
|
||||
|
||||
# Higher-precision accumulation should be at least as accurate as native
|
||||
assert (
|
||||
errors["float16"] <= errors["e4m3b15_native"] + 1e-8
|
||||
), f"float16 accum ({errors['float16']:.8f}) worse than native ({errors['e4m3b15_native']:.8f})"
|
||||
assert (
|
||||
errors["float32"] <= errors["e4m3b15_native"] + 1e-8
|
||||
), f"float32 accum ({errors['float32']:.8f}) worse than native ({errors['e4m3b15_native']:.8f})"
|
||||
@@ -41,7 +41,9 @@ NativeAlgorithm::NativeAlgorithm(std::string name, std::string collective, InitF
|
||||
CommResult NativeAlgorithm::execute(std::shared_ptr<Communicator> comm, const void* input, void* output,
|
||||
size_t inputSize, size_t outputSize, DataType dtype, ReduceOp op,
|
||||
cudaStream_t stream, std::shared_ptr<Executor>, int nBlocks, int nThreadsPerBlock,
|
||||
bool symmetricMemory, const std::unordered_map<std::string, uintptr_t>& extras) {
|
||||
bool symmetricMemory, const std::unordered_map<std::string, uintptr_t>& extras,
|
||||
DataType accumDtype) {
|
||||
if (accumDtype == DataType::AUTO) accumDtype = dtype;
|
||||
if (!initialized_) {
|
||||
initFunc_(comm);
|
||||
initialized_ = true;
|
||||
@@ -53,7 +55,7 @@ CommResult NativeAlgorithm::execute(std::shared_ptr<Communicator> comm, const vo
|
||||
contexts_[ctxKey] = ctx;
|
||||
}
|
||||
return kernelLaunchFunc_(contexts_[ctxKey], input, output, inputSize, outputSize, dtype, op, stream, nBlocks,
|
||||
nThreadsPerBlock, extras);
|
||||
nThreadsPerBlock, extras, accumDtype);
|
||||
}
|
||||
|
||||
const std::string& NativeAlgorithm::name() const { return name_; }
|
||||
@@ -77,10 +79,7 @@ const CollectiveBufferMode& NativeAlgorithm::bufferMode() const { return bufferM
|
||||
|
||||
Algorithm::Constraint NativeAlgorithm::constraint() const { return constraint_; }
|
||||
|
||||
void NativeAlgorithm::reset() {
|
||||
contexts_.clear();
|
||||
initialized_ = false;
|
||||
}
|
||||
void NativeAlgorithm::reset() { contexts_.clear(); }
|
||||
|
||||
void AlgorithmCollection::registerAlgorithm(const std::string collective, const std::string algoName,
|
||||
std::shared_ptr<Algorithm> algorithm) {
|
||||
@@ -166,7 +165,7 @@ Algorithm::Constraint DslAlgorithm::constraint() const { return constraint_; }
|
||||
CommResult DslAlgorithm::execute(std::shared_ptr<Communicator> comm, const void* input, void* output, size_t inputSize,
|
||||
size_t outputSize, DataType dtype, ReduceOp, cudaStream_t stream,
|
||||
std::shared_ptr<Executor> executor, int, int, bool,
|
||||
const std::unordered_map<std::string, uintptr_t>&) {
|
||||
const std::unordered_map<std::string, uintptr_t>&, DataType) {
|
||||
if (!executor) {
|
||||
THROW(EXEC, Error, ErrorCode::InvalidUsage, "Executor is null in DslAlgorithm::execute");
|
||||
}
|
||||
@@ -192,6 +191,10 @@ CommResult DslAlgorithm::execute(std::shared_ptr<Communicator> comm, const void*
|
||||
plan_, stream);
|
||||
break;
|
||||
#endif
|
||||
case DataType::FLOAT8_E4M3B15:
|
||||
executor->execute(rank, (__fp8_e4m3b15*)input, (__fp8_e4m3b15*)output, inputSize, outputSize,
|
||||
DataType::FLOAT8_E4M3B15, plan_, stream);
|
||||
break;
|
||||
case DataType::INT32:
|
||||
case DataType::UINT32:
|
||||
executor->execute(rank, (int*)input, (int*)output, inputSize, outputSize, DataType::UINT32, plan_, stream);
|
||||
|
||||
@@ -56,7 +56,7 @@ Endpoint::Impl::Impl(const EndpointConfig& config, Context::Impl& contextImpl)
|
||||
}
|
||||
|
||||
ibQp_ = contextImpl.getIbContext(config_.transport)
|
||||
->createQp(config_.ib.port, gidIndex, config_.ib.maxCqSize, config_.ib.maxCqPollNum,
|
||||
->createQp(config_.ib.port, config_.ib.gidIndex, config_.ib.maxCqSize, config_.ib.maxCqPollNum,
|
||||
config_.ib.maxSendWr, maxRecvWr, config_.ib.maxWrPerSend, ibNoAtomic_);
|
||||
ibQpInfo_ = ibQp_->getInfo();
|
||||
} else if (config_.transport == Transport::Ethernet) {
|
||||
|
||||
@@ -82,6 +82,12 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo
|
||||
case DataType::FLOAT8_E5M2:
|
||||
// FP8 is not supported in CUDA execution kernel.
|
||||
break;
|
||||
case DataType::FLOAT8_E4M3B15:
|
||||
// fp8_e4m3b15 is a software type not supported in the CUDA execution kernel.
|
||||
break;
|
||||
case DataType::AUTO:
|
||||
// AUTO is a sentinel resolved before reaching this point; nothing to do.
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -210,7 +210,7 @@ MSCCLPP_DEVICE_INLINE void handleReadReduceSend(const Operation& op, void* input
|
||||
sizeof(int4);
|
||||
void* remoteMemory = static_cast<char*>(memoryChannelBufferPtrs_[op.inputBufferRefs[index + 1].id]);
|
||||
val = mscclpp::read<int4>(remoteMemory, srcOffset + idx);
|
||||
tmp = cal_vector<T, OpType>(tmp, val);
|
||||
tmp = calVector<T, OpType>(tmp, val);
|
||||
}
|
||||
output4[outputOffset4 + idx] = tmp;
|
||||
if constexpr (SendToRemote) {
|
||||
@@ -353,9 +353,9 @@ MSCCLPP_DEVICE_INLINE void handleReduceSendPackets(const Operation& op, void* in
|
||||
for (uint32_t index = 0; index < nSrcs; ++index) {
|
||||
PacketType* pkt = (PacketType*)((char*)scratch + scratchOffset_ + 2 * inputOffsets[index]);
|
||||
PacketPayload<PacketType> val = pkt[idx].read(flag_);
|
||||
data = cal_vector<T, OpType>(data, val);
|
||||
data = calVector<T, OpType>(data, val);
|
||||
}
|
||||
data = cal_vector<T, OpType>(data, srcPacketPayload[idx]);
|
||||
data = calVector<T, OpType>(data, srcPacketPayload[idx]);
|
||||
dstPacketPayload[idx] = data;
|
||||
|
||||
if constexpr (SendToRemote) {
|
||||
@@ -394,9 +394,9 @@ MSCCLPP_DEVICE_INLINE void handleReduceCopySendPackets(const Operation& op, void
|
||||
for (uint32_t index = 0; index < nSrcs; ++index) {
|
||||
PacketType* pkt = (PacketType*)((char*)scratch + scratchOffset_ + 2 * inputOffsets[index]);
|
||||
PacketPayload<PacketType> val = pkt[idx].read(flag_);
|
||||
data = cal_vector<T, OpType>(data, val);
|
||||
data = calVector<T, OpType>(data, val);
|
||||
}
|
||||
data = cal_vector<T, OpType>(data, srcPacketPayload[idx]);
|
||||
data = calVector<T, OpType>(data, srcPacketPayload[idx]);
|
||||
dstPacketPayload[idx] = data;
|
||||
PacketType* dst_val = &dstPkt[idx];
|
||||
dst_val->write(data, flag_);
|
||||
@@ -464,7 +464,7 @@ MSCCLPP_DEVICE_INLINE void handleReduceSend(const Operation& op, void* input, vo
|
||||
size_t buffOffset =
|
||||
(inputOffsets[index] + getOffset<ReuseScratch>(outputBufferRefs[index].type, offset)) / sizeof(int4);
|
||||
int4 val = buff4[buffOffset + idx];
|
||||
tmp = cal_vector<T, OpType>(tmp, val);
|
||||
tmp = calVector<T, OpType>(tmp, val);
|
||||
}
|
||||
dst4[dstOffset4 + idx] = tmp;
|
||||
if constexpr (SendToRemote) {
|
||||
@@ -899,6 +899,17 @@ class ExecutionKernel {
|
||||
#endif
|
||||
break;
|
||||
#endif // __FP8_TYPES_EXIST__
|
||||
case DataType::FLOAT8_E4M3B15:
|
||||
executionKernel<__fp8_e4m3b15, PacketType, ReuseScratch><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
|
||||
rank, (__fp8_e4m3b15*)src, (__fp8_e4m3b15*)dst, (__fp8_e4m3b15*)scratch, scratchOffset, scratchChunkSize,
|
||||
plan, semaphores, localMemoryIdBegin, flag
|
||||
#if defined(ENABLE_NPKIT)
|
||||
,
|
||||
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
|
||||
#else
|
||||
);
|
||||
#endif
|
||||
break;
|
||||
case DataType::UINT8:
|
||||
executionKernel<uint8_t, PacketType, ReuseScratch><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
|
||||
rank, (uint8_t*)src, (uint8_t*)dst, (uint8_t*)scratch, scratchOffset, scratchChunkSize, plan, semaphores,
|
||||
@@ -910,6 +921,10 @@ class ExecutionKernel {
|
||||
);
|
||||
#endif
|
||||
break;
|
||||
case DataType::AUTO:
|
||||
// AUTO is a sentinel that must be resolved before reaching this point.
|
||||
assert(false && "DataType::AUTO must be resolved before kernel launch");
|
||||
break;
|
||||
}
|
||||
}
|
||||
#else // !defined(MSCCLPP_DEVICE_HIP)
|
||||
|
||||
@@ -14,7 +14,7 @@ namespace mscclpp {
|
||||
|
||||
// Generic element-wise calculation helper
|
||||
template <typename T, ReduceOp OpType>
|
||||
MSCCLPP_DEVICE_INLINE T cal_elements(const T& a, const T& b) {
|
||||
MSCCLPP_DEVICE_INLINE T calElements(const T& a, const T& b) {
|
||||
if constexpr (OpType == SUM) {
|
||||
return a + b;
|
||||
} else if constexpr (OpType == MIN) {
|
||||
@@ -24,56 +24,168 @@ MSCCLPP_DEVICE_INLINE T cal_elements(const T& a, const T& b) {
|
||||
}
|
||||
|
||||
// Generic vector reduction helpers
|
||||
template <typename T, ReduceOp OpType>
|
||||
MSCCLPP_DEVICE_INLINE int4 cal_vector_helper(const int4& a, const int4& b) {
|
||||
int4 ret;
|
||||
ret.w = bit_cast<int, T>(cal_elements<T, OpType>(bit_cast<T, int>(a.w), bit_cast<T, int>(b.w)));
|
||||
ret.x = bit_cast<int, T>(cal_elements<T, OpType>(bit_cast<T, int>(a.x), bit_cast<T, int>(b.x)));
|
||||
ret.y = bit_cast<int, T>(cal_elements<T, OpType>(bit_cast<T, int>(a.y), bit_cast<T, int>(b.y)));
|
||||
ret.z = bit_cast<int, T>(cal_elements<T, OpType>(bit_cast<T, int>(a.z), bit_cast<T, int>(b.z)));
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <typename T, ReduceOp OpType>
|
||||
MSCCLPP_DEVICE_INLINE uint2 cal_vector_helper(const uint2& a, const uint2& b) {
|
||||
MSCCLPP_DEVICE_INLINE uint2 calVectorHelper(const uint2& a, const uint2& b) {
|
||||
uint2 ret;
|
||||
ret.x = bit_cast<uint32_t, T>(cal_elements<T, OpType>(bit_cast<T, uint32_t>(a.x), bit_cast<T, uint32_t>(b.x)));
|
||||
ret.y = bit_cast<uint32_t, T>(cal_elements<T, OpType>(bit_cast<T, uint32_t>(a.y), bit_cast<T, uint32_t>(b.y)));
|
||||
ret.x = bit_cast<uint32_t, T>(calElements<T, OpType>(bit_cast<T, uint32_t>(a.x), bit_cast<T, uint32_t>(b.x)));
|
||||
ret.y = bit_cast<uint32_t, T>(calElements<T, OpType>(bit_cast<T, uint32_t>(a.y), bit_cast<T, uint32_t>(b.y)));
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <typename T, ReduceOp OpType>
|
||||
MSCCLPP_DEVICE_INLINE int cal_vector_helper(const int& a, const int& b) {
|
||||
return bit_cast<int, T>(cal_elements<T, OpType>(bit_cast<T, int>(a), bit_cast<T, int>(b)));
|
||||
/// f32x2 specialization for uint2: uses packed f32x2 operator+ (Blackwell __fadd2_rn when available).
|
||||
template <>
|
||||
MSCCLPP_DEVICE_INLINE uint2 calVectorHelper<f32x2, SUM>(const uint2& a, const uint2& b) {
|
||||
f32x2 fa = bit_cast<f32x2, uint2>(a);
|
||||
f32x2 fb = bit_cast<f32x2, uint2>(b);
|
||||
f32x2 fr = fa + fb;
|
||||
return bit_cast<uint2, f32x2>(fr);
|
||||
}
|
||||
|
||||
template <>
|
||||
MSCCLPP_DEVICE_INLINE uint2 calVectorHelper<f32x2, MIN>(const uint2& a, const uint2& b) {
|
||||
f32x2 fa = bit_cast<f32x2, uint2>(a);
|
||||
f32x2 fb = bit_cast<f32x2, uint2>(b);
|
||||
f32x2 fr = mscclpp::min(fa, fb);
|
||||
return bit_cast<uint2, f32x2>(fr);
|
||||
}
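// Illustrative sketch, not part of the change above: the packed uint2 path is
// equivalent to reducing the two float lanes separately with calElements (the
// names a, b, packed and lane0 below are hypothetical).
//
//   uint2 packed = calVectorHelper<f32x2, SUM>(a, b);
//   float lane0 = calElements<float, SUM>(bit_cast<float, uint32_t>(a.x),
//                                         bit_cast<float, uint32_t>(b.x));
//   // packed.x == bit_cast<uint32_t, float>(lane0); the .y lane is handled the same way.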
|
||||
template <typename T, ReduceOp OpType>
|
||||
MSCCLPP_DEVICE_INLINE uint32_t cal_vector_helper(const uint32_t& a, const uint32_t& b) {
|
||||
return bit_cast<uint32_t, T>(cal_elements<T, OpType>(bit_cast<T, uint32_t>(a), bit_cast<T, uint32_t>(b)));
|
||||
MSCCLPP_DEVICE_INLINE int4 calVectorHelper(const int4& a, const int4& b) {
|
||||
int4 ret;
|
||||
ret.w = bit_cast<int, T>(calElements<T, OpType>(bit_cast<T, int>(a.w), bit_cast<T, int>(b.w)));
|
||||
ret.x = bit_cast<int, T>(calElements<T, OpType>(bit_cast<T, int>(a.x), bit_cast<T, int>(b.x)));
|
||||
ret.y = bit_cast<int, T>(calElements<T, OpType>(bit_cast<T, int>(a.y), bit_cast<T, int>(b.y)));
|
||||
ret.z = bit_cast<int, T>(calElements<T, OpType>(bit_cast<T, int>(a.z), bit_cast<T, int>(b.z)));
|
||||
return ret;
|
||||
}
|
||||
|
||||
// cal_vector wrapper - converts scalar types to vector types and calls cal_vector_helper
|
||||
/// f32x2 specialization for int4: process as two uint2 pairs using packed f32x2 arithmetic.
|
||||
template <>
|
||||
MSCCLPP_DEVICE_INLINE int4 calVectorHelper<f32x2, SUM>(const int4& a, const int4& b) {
|
||||
uint2 lo_a = {(uint32_t)a.x, (uint32_t)a.y};
|
||||
uint2 hi_a = {(uint32_t)a.z, (uint32_t)a.w};
|
||||
uint2 lo_b = {(uint32_t)b.x, (uint32_t)b.y};
|
||||
uint2 hi_b = {(uint32_t)b.z, (uint32_t)b.w};
|
||||
uint2 lo_r = calVectorHelper<f32x2, SUM>(lo_a, lo_b);
|
||||
uint2 hi_r = calVectorHelper<f32x2, SUM>(hi_a, hi_b);
|
||||
return {(int)lo_r.x, (int)lo_r.y, (int)hi_r.x, (int)hi_r.y};
|
||||
}
|
||||
|
||||
template <>
|
||||
MSCCLPP_DEVICE_INLINE int4 calVectorHelper<f32x2, MIN>(const int4& a, const int4& b) {
|
||||
uint2 lo_a = {(uint32_t)a.x, (uint32_t)a.y};
|
||||
uint2 hi_a = {(uint32_t)a.z, (uint32_t)a.w};
|
||||
uint2 lo_b = {(uint32_t)b.x, (uint32_t)b.y};
|
||||
uint2 hi_b = {(uint32_t)b.z, (uint32_t)b.w};
|
||||
uint2 lo_r = calVectorHelper<f32x2, MIN>(lo_a, lo_b);
|
||||
uint2 hi_r = calVectorHelper<f32x2, MIN>(hi_a, hi_b);
|
||||
return {(int)lo_r.x, (int)lo_r.y, (int)hi_r.x, (int)hi_r.y};
|
||||
}
|
||||
|
||||
template <typename T, ReduceOp OpType>
|
||||
MSCCLPP_DEVICE_INLINE int calVectorHelper(const int& a, const int& b) {
|
||||
return bit_cast<int, T>(calElements<T, OpType>(bit_cast<T, int>(a), bit_cast<T, int>(b)));
|
||||
}
|
||||
|
||||
template <typename T, ReduceOp OpType>
|
||||
MSCCLPP_DEVICE_INLINE uint32_t calVectorHelper(const uint32_t& a, const uint32_t& b) {
|
||||
return bit_cast<uint32_t, T>(calElements<T, OpType>(bit_cast<T, uint32_t>(a), bit_cast<T, uint32_t>(b)));
|
||||
}
|
||||
|
||||
/// f32x2 specialization for uint32_t: a single float packed in 32 bits (scalar fallback).
|
||||
template <>
|
||||
MSCCLPP_DEVICE_INLINE uint32_t calVectorHelper<f32x2, SUM>(const uint32_t& a, const uint32_t& b) {
|
||||
float fa = bit_cast<float, uint32_t>(a);
|
||||
float fb = bit_cast<float, uint32_t>(b);
|
||||
return bit_cast<uint32_t, float>(fa + fb);
|
||||
}
|
||||
|
||||
template <>
|
||||
MSCCLPP_DEVICE_INLINE uint32_t calVectorHelper<f32x2, MIN>(const uint32_t& a, const uint32_t& b) {
|
||||
float fa = bit_cast<float, uint32_t>(a);
|
||||
float fb = bit_cast<float, uint32_t>(b);
|
||||
return bit_cast<uint32_t, float>(fminf(fa, fb));
|
||||
}
|
||||
|
||||
// calVector wrapper - converts scalar types to vector types and calls calVectorHelper
template <typename T, ReduceOp OpType, typename DataType>
|
||||
MSCCLPP_DEVICE_INLINE DataType cal_vector(const DataType& a, const DataType& b) {
|
||||
MSCCLPP_DEVICE_INLINE DataType calVector(const DataType& a, const DataType& b) {
|
||||
// Define the vectorized computation type based on the element type
|
||||
static_assert(sizeof(DataType) % sizeof(T) == 0, "DataType size must be multiple of T size");
|
||||
static_assert(sizeof(DataType) >= 4, "DataType size must be at least 4 bytes");
|
||||
using CompType = typename std::conditional_t<
|
||||
std::is_same_v<T, __half>, f16x2,
|
||||
std::is_same_v<T, float>, f32x2,
|
||||
std::conditional_t<
|
||||
std::is_same_v<T, __bfloat16>, bf16x2,
|
||||
std::conditional_t<std::is_same_v<T, uint8_t>, u8x4,
|
||||
std::is_same_v<T, __half>, f16x2,
|
||||
std::conditional_t<
|
||||
std::is_same_v<T, __bfloat16>, bf16x2,
|
||||
std::conditional_t<
|
||||
std::is_same_v<T, uint8_t>, u8x4,
|
||||
std::conditional_t<std::is_same_v<T, __fp8_e4m3b15>, f8_e4m3b15x4,
|
||||
#if defined(__FP8_TYPES_EXIST__)
|
||||
std::conditional_t<std::is_same_v<T, __fp8_e4m3>, f8_e4m3x4,
|
||||
std::conditional_t<std::is_same_v<T, __fp8_e5m2>, f8_e5m2x4,
|
||||
#endif
|
||||
T
|
||||
#if defined(__FP8_TYPES_EXIST__)
|
||||
>>>>>;
|
||||
std::conditional_t<std::is_same_v<T, __fp8_e4m3>, f8_e4m3x4,
|
||||
std::conditional_t<std::is_same_v<T, __fp8_e5m2>, f8_e5m2x4, T>>
|
||||
#else
|
||||
>>>;
|
||||
T
|
||||
#endif
|
||||
return cal_vector_helper<CompType, OpType>(a, b);
|
||||
>>>>>;
|
||||
return calVectorHelper<CompType, OpType>(a, b);
|
||||
}
|
||||
|
||||
/// Upcast a packed DataType (containing T elements) to a packed AccDataType (containing AccumT elements).
|
||||
/// Uses the optimized to<>() specializations when available (e.g. FP8 -> float hardware intrinsics).
|
||||
/// When AccumT == T, this is a no-op identity.
|
||||
template <typename T, typename AccumT, typename AccDataType, typename DataType>
|
||||
MSCCLPP_DEVICE_INLINE AccDataType upcastVector(const DataType& val) {
|
||||
if constexpr (std::is_same_v<T, AccumT>) {
|
||||
return val;
|
||||
} else {
|
||||
constexpr int nElems = sizeof(DataType) / sizeof(T);
|
||||
using FromVec = VectorType<T, nElems>;
|
||||
using ToVec = VectorType<AccumT, nElems>;
|
||||
ToVec result = mscclpp::to<ToVec>(reinterpret_cast<const FromVec&>(val));
|
||||
return reinterpret_cast<const AccDataType&>(result);
|
||||
}
|
||||
}
|
||||
|
||||
/// Downcast a packed AccDataType (containing AccumT elements) back to DataType (containing T elements).
|
||||
/// Uses the optimized to<>() specializations when available.
|
||||
/// When AccumT == T, this is a no-op identity.
|
||||
template <typename T, typename AccumT, typename DataType, typename AccDataType>
|
||||
MSCCLPP_DEVICE_INLINE DataType downcastVector(const AccDataType& val) {
|
||||
if constexpr (std::is_same_v<T, AccumT>) {
|
||||
return val;
|
||||
} else {
|
||||
constexpr int nElems = sizeof(DataType) / sizeof(T);
|
||||
using FromVec = VectorType<T, nElems>;
|
||||
using ToVec = VectorType<AccumT, nElems>;
|
||||
FromVec result = mscclpp::to<FromVec>(reinterpret_cast<const ToVec&>(val));
|
||||
return reinterpret_cast<const DataType&>(result);
|
||||
}
|
||||
}
|
||||
|
||||
/// Accumulate `val` (packed T elements in DataType) into `acc` (packed AccumT elements in AccDataType).
|
||||
/// When AccumT == T, falls back to the standard calVector.
|
||||
/// Otherwise, upcasts val to AccumT, reduces element-wise, and returns the AccumT accumulator.
|
||||
template <typename T, typename AccumT, ReduceOp OpType, typename AccDataType, typename DataType>
|
||||
MSCCLPP_DEVICE_INLINE AccDataType calVectorAccum(const AccDataType& acc, const DataType& val) {
|
||||
if constexpr (std::is_same_v<T, AccumT>) {
|
||||
return calVector<T, OpType>(acc, val);
|
||||
} else {
|
||||
constexpr int nElems = sizeof(DataType) / sizeof(T);
|
||||
using FromVec = VectorType<T, nElems>;
|
||||
using ToVec = VectorType<AccumT, nElems>;
|
||||
|
||||
ToVec fv = mscclpp::to<ToVec>(reinterpret_cast<const FromVec&>(val));
|
||||
const ToVec& fa = reinterpret_cast<const ToVec&>(acc);
|
||||
ToVec fr;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < nElems; ++i) {
|
||||
fr.data[i] = calElements<AccumT, OpType>(fa.data[i], fv.data[i]);
|
||||
}
|
||||
return reinterpret_cast<const AccDataType&>(fr);
|
||||
}
|
||||
}
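// Illustrative sketch, not part of the change above: how upcastVector, calVectorAccum
// and downcastVector are meant to compose in a reduction loop, assuming T = __half
// packed two-per-uint32_t and AccumT = float. The names localWord, peerWords and
// nPeers below are hypothetical.
//
//   using AccRaw = VectorType<float, sizeof(uint32_t) / sizeof(__half)>;  // two floats
//   AccRaw acc = upcastVector<__half, float, AccRaw>(localWord);          // fp16 -> fp32
//   for (int p = 0; p < nPeers; ++p) {
//     acc = calVectorAccum<__half, float, SUM, AccRaw>(acc, peerWords[p]);
//   }
//   uint32_t out = downcastVector<__half, float, uint32_t>(acc);          // fp32 -> fp16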
|
||||
#endif // defined(MSCCLPP_DEVICE_COMPILE)
|
||||
|
||||
@@ -183,7 +183,8 @@ std::shared_ptr<Algorithm> AllgatherFullmesh::build() {
|
||||
[self](const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
|
||||
[[maybe_unused]] size_t outputSize, [[maybe_unused]] DataType dtype, [[maybe_unused]] ReduceOp op,
|
||||
cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
|
||||
const std::unordered_map<std::string, uintptr_t>& extras) -> CommResult {
|
||||
const std::unordered_map<std::string, uintptr_t>& extras,
|
||||
[[maybe_unused]] DataType accumDtype) -> CommResult {
|
||||
return self->allgatherKernelFunc(ctx, input, output, inputSize, stream, nBlocks, nThreadsPerBlock, extras);
|
||||
},
|
||||
[self](std::shared_ptr<mscclpp::Communicator> comm, const void* input, void* output, size_t inputSize,
|
||||
|
||||
@@ -212,7 +212,8 @@ std::shared_ptr<Algorithm> AllgatherFullmesh2::build() {
|
||||
[self](const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
|
||||
[[maybe_unused]] size_t outputSize, [[maybe_unused]] mscclpp::DataType dtype, [[maybe_unused]] ReduceOp op,
|
||||
cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
|
||||
const std::unordered_map<std::string, uintptr_t>& extras) -> mscclpp::CommResult {
|
||||
const std::unordered_map<std::string, uintptr_t>& extras,
|
||||
[[maybe_unused]] mscclpp::DataType accumDtype) -> mscclpp::CommResult {
|
||||
return self->allgatherKernelFunc(ctx, input, output, inputSize, stream, nBlocks, nThreadsPerBlock, extras);
|
||||
},
|
||||
[self](std::shared_ptr<mscclpp::Communicator> comm, const void* input, void* output, size_t inputSize,
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <collective_utils.hpp>
|
||||
#include <type_traits>
|
||||
|
||||
#include "allreduce/allreduce_allpair_packet.hpp"
|
||||
#include "allreduce/common.hpp"
|
||||
@@ -11,7 +12,7 @@
|
||||
namespace mscclpp {
|
||||
namespace collective {
|
||||
|
||||
template <ReduceOp OpType, typename T>
|
||||
template <ReduceOp OpType, typename T, typename AccumT = T>
|
||||
__global__ void allreduceAllPairs(T* buff, T* scratch, T* resultBuff, DeviceHandle<MemoryChannel>* memoryChannels,
|
||||
size_t channelDataOffset, size_t scratchBufferSize, int rank, int nRanksPerNode,
|
||||
int worldSize, size_t nelems, uint32_t numScratchBuff, void* flags,
|
||||
@@ -43,13 +44,16 @@ __global__ void allreduceAllPairs(T* buff, T* scratch, T* resultBuff, DeviceHand
|
||||
// step 2: Reduce Data
|
||||
for (size_t idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nelems; idx += blockDim.x * gridDim.x) {
|
||||
uint32_t data = src[idx];
|
||||
using AccRaw = std::conditional_t<std::is_same_v<T, AccumT>, uint32_t,
|
||||
mscclpp::VectorType<AccumT, sizeof(uint32_t) / sizeof(T)>>;
|
||||
AccRaw acc = mscclpp::upcastVector<T, AccumT, AccRaw>(data);
|
||||
for (int index = 0; index < nPeers; index++) {
|
||||
const int remoteRank = index < rank ? index : index + 1;
|
||||
LL8Packet* dstPkt = (LL8Packet*)scratchBuff + remoteRank * nelems;
|
||||
uint32_t val = dstPkt[idx].read(flag, -1);
|
||||
data = cal_vector<T, OpType>(val, data);
|
||||
acc = mscclpp::calVectorAccum<T, AccumT, OpType, AccRaw>(acc, val);
|
||||
}
|
||||
dst[idx] = data;
|
||||
dst[idx] = mscclpp::downcastVector<T, AccumT, uint32_t>(acc);
|
||||
}
|
||||
__syncthreads();
|
||||
if (threadIdx.x == 0) {
|
||||
@@ -67,7 +71,7 @@ inline std::pair<int, int> getDefaultBlockNumAndThreadNum(size_t inputSize, int
|
||||
return {(worldSize - 1) * 4, 512};
|
||||
}
|
||||
|
||||
template <ReduceOp OpType, typename T>
|
||||
template <ReduceOp OpType, typename T, typename AccumT = T>
|
||||
struct AllpairAdapter {
|
||||
static cudaError_t call(const void* buff, void* scratch, void* resultBuff, void* memoryChannels, void*,
|
||||
DeviceHandle<SwitchChannel>*, DeviceHandle<SwitchChannel>*, size_t channelInOffset, size_t,
|
||||
@@ -76,7 +80,12 @@ struct AllpairAdapter {
|
||||
int nThreadsPerBlock = 0) {
|
||||
using ChannelType = DeviceHandle<MemoryChannel>;
|
||||
const size_t nelems = inputSize / sizeof(T);
|
||||
allreduceAllPairs<OpType, T><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
|
||||
// Round nBlocks down to a multiple of nPeers so every block maps to a valid peer.
const int nPeers = worldSize - 1;
|
||||
if (nPeers > 0) {
|
||||
nBlocks = (nBlocks / nPeers) * nPeers;
|
||||
}
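// Worked example (illustrative): with worldSize = 8, nPeers = 7, so a requested
// nBlocks of 24 is rounded down to 21 and each peer is covered by exactly 3 blocks.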
allreduceAllPairs<OpType, T, AccumT><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
|
||||
(T*)buff, (T*)scratch, (T*)resultBuff, (ChannelType*)memoryChannels, channelInOffset, scratchBufferSize, rank,
|
||||
nRanksPerNode, worldSize, nelems, numScratchBuff, flags, flagSize);
|
||||
return cudaGetLastError();
|
||||
@@ -94,18 +103,24 @@ void AllreduceAllpairPacket::initialize(std::shared_ptr<Communicator> comm) {
|
||||
CommResult AllreduceAllpairPacket::allreduceKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output,
|
||||
size_t inputSize, [[maybe_unused]] DataType dtype, ReduceOp op,
|
||||
cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
|
||||
const std::unordered_map<std::string, uintptr_t>&) {
|
||||
const std::unordered_map<std::string, uintptr_t>&,
|
||||
DataType accumDtype) {
|
||||
auto algoCtx = std::static_pointer_cast<AlgorithmCtx>(ctx);
|
||||
std::pair<int, int> blockAndThreadNum{nBlocks, nThreadsPerBlock};
|
||||
if (blockAndThreadNum.first == 0 || blockAndThreadNum.second == 0) {
|
||||
blockAndThreadNum = getDefaultBlockNumAndThreadNum(inputSize, algoCtx->workSize);
|
||||
}
|
||||
// nBlocks must be at least nPeers for allpair, since each block maps to one peer.
const int nPeers = algoCtx->nRanksPerNode - 1;
|
||||
if (nPeers > 0 && blockAndThreadNum.first < nPeers) {
|
||||
return CommResult::CommInvalidArgument;
|
||||
}
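// Worked example (illustrative): with nRanksPerNode = 8, nPeers = 7, so a launch
// configured with fewer than 7 blocks is rejected as CommInvalidArgument.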
size_t sendBytes;
|
||||
CUdeviceptr sendBasePtr;
|
||||
MSCCLPP_CUTHROW(cuMemGetAddressRange(&sendBasePtr, &sendBytes, (CUdeviceptr)input));
|
||||
size_t channelInOffset = (char*)input - (char*)sendBasePtr;
|
||||
|
||||
AllreduceFunc allreduce = dispatch<AllpairAdapter>(op, dtype);
|
||||
AllreduceFunc allreduce = dispatch<AllpairAdapter>(op, dtype, accumDtype);
|
||||
if (!allreduce) {
|
||||
WARN("Unsupported operation or data type for allreduce: op=%d, dtype=%d", op, static_cast<int>(dtype));
|
||||
return CommResult::CommInvalidArgument;
|
||||
@@ -161,9 +176,9 @@ std::shared_ptr<Algorithm> AllreduceAllpairPacket::build() {
|
||||
[self](std::shared_ptr<Communicator> comm) { self->initialize(comm); },
|
||||
[self](const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
|
||||
[[maybe_unused]] size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks,
|
||||
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras) {
|
||||
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras, DataType accumDtype) {
|
||||
return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock,
|
||||
extras);
|
||||
extras, accumDtype);
|
||||
},
|
||||
[self](std::shared_ptr<Communicator> comm, const void* input, void* output, size_t inputSize,
|
||||
[[maybe_unused]] size_t outputSize,
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
namespace mscclpp {
|
||||
namespace collective {
|
||||
|
||||
template <ReduceOp OpType, typename T>
|
||||
template <ReduceOp OpType, typename T, typename AccumT = T>
|
||||
__global__ void __launch_bounds__(512, 1)
|
||||
allreduceFullmesh(T* buff, T* scratch, T* resultBuff, DeviceHandle<MemoryChannel>* memoryChannels,
|
||||
DeviceHandle<MemoryChannel>* memoryOutChannels, size_t channelOutDataOffset, int rank,
|
||||
@@ -26,6 +26,10 @@ __global__ void __launch_bounds__(512, 1)
|
||||
int4* scratch4 = reinterpret_cast<int4*>((char*)scratch);
|
||||
int4* resultBuff4 = reinterpret_cast<int4*>(resultBuff);
|
||||
|
||||
// AccumVec: wider vector for mixed-precision accumulation. When AccumT==T, this is just int4 (no-op).
|
||||
constexpr int nElemsPerInt4 = sizeof(int4) / sizeof(T);
|
||||
using AccumVec = std::conditional_t<std::is_same_v<T, AccumT>, int4, mscclpp::VectorType<AccumT, nElemsPerInt4>>;
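// Illustrative size check: for T = __half with AccumT = float, nElemsPerInt4 = 8, so
// AccumVec is VectorType<float, 8> (32 bytes) while the raw int4 load stays 16 bytes;
// when AccumT == T the conditional collapses back to int4 and nothing widens.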
|
||||
// Distribute `nInt4PerRank` across all blocks with the unit size `unitNInt4`
|
||||
constexpr size_t unitNInt4 = 512;
|
||||
const size_t maxNInt4PerBlock =
|
||||
@@ -81,12 +85,14 @@ __global__ void __launch_bounds__(512, 1)
|
||||
__syncthreads();
|
||||
|
||||
for (size_t idx = threadIdx.x; idx < nInt4PerChunk; idx += blockDim.x) {
|
||||
int4 data = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock];
|
||||
int4 rawData = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock];
|
||||
AccumVec acc = mscclpp::upcastVector<T, AccumT, AccumVec>(rawData);
|
||||
for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
|
||||
const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1;
|
||||
int4 val = scratch4[chunkSizePerRank * remoteRank + blockOffset + idx];
|
||||
data = cal_vector<T, OpType>(val, data);
|
||||
acc = mscclpp::calVectorAccum<T, AccumT, OpType, AccumVec>(acc, val);
|
||||
}
|
||||
int4 data = mscclpp::downcastVector<T, AccumT, int4>(acc);
|
||||
resultBuff4[nInt4PerRank * rank + idx + offsetOfThisBlock] = data;
|
||||
for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
|
||||
outChannels[peerIdx].write(nInt4PerRank * rank + idx + offsetOfThisBlock + channelOutDataOffset / sizeof(int4),
|
||||
@@ -121,12 +127,14 @@ __global__ void __launch_bounds__(512, 1)
|
||||
__syncthreads();
|
||||
|
||||
for (size_t idx = threadIdx.x; idx < restNInt4; idx += blockDim.x) {
|
||||
int4 data = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock];
|
||||
int4 rawData = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock];
|
||||
AccumVec acc = mscclpp::upcastVector<T, AccumT, AccumVec>(rawData);
|
||||
for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
|
||||
const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1;
|
||||
int4 val = scratch4[chunkSizePerRank * remoteRank + blockOffset + idx];
|
||||
data = cal_vector<T, OpType>(val, data);
|
||||
acc = mscclpp::calVectorAccum<T, AccumT, OpType, AccumVec>(acc, val);
|
||||
}
|
||||
int4 data = mscclpp::downcastVector<T, AccumT, int4>(acc);
|
||||
resultBuff4[nInt4PerRank * rank + idx + offsetOfThisBlock] = data;
|
||||
for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
|
||||
outChannels[peerIdx].write(nInt4PerRank * rank + idx + offsetOfThisBlock + channelOutDataOffset / sizeof(int4),
|
||||
@@ -144,7 +152,7 @@ __global__ void __launch_bounds__(512, 1)
|
||||
}
|
||||
}
|
||||
|
||||
template <ReduceOp OpType, typename T>
|
||||
template <ReduceOp OpType, typename T, typename AccumT = T>
|
||||
struct AllreduceAllconnectAdapter {
|
||||
static cudaError_t call(const void* input, void* scratch, void* output, void* memoryChannels, void* memoryOutChannels,
|
||||
DeviceHandle<SwitchChannel>*, DeviceHandle<SwitchChannel>*, size_t,
|
||||
@@ -155,7 +163,7 @@ struct AllreduceAllconnectAdapter {
|
||||
size_t nelems = inputSize / sizeof(T);
|
||||
if (nBlocks == 0) nBlocks = 35;
|
||||
if (nThreadsPerBlock == 0) nThreadsPerBlock = 512;
|
||||
allreduceFullmesh<OpType, T><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
|
||||
allreduceFullmesh<OpType, T, AccumT><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
|
||||
(T*)input, (T*)scratch, (T*)output, (ChannelType*)memoryChannels, (ChannelType*)memoryOutChannels,
|
||||
channelOutDataOffset, rank, nRanksPerNode, worldSize, nelems);
|
||||
return cudaGetLastError();
|
||||
@@ -174,10 +182,10 @@ void AllreduceFullmesh::initialize(std::shared_ptr<Communicator> comm) {
|
||||
localScratchMemory_ = std::move(localMemory);
|
||||
}
|
||||
|
||||
CommResult AllreduceFullmesh::allreduceKernelFunc(const std::shared_ptr<void> ctx_void, const void* input, void* output,
|
||||
size_t inputSize, DataType dtype, ReduceOp op, cudaStream_t stream,
|
||||
int nBlocks, int nThreadsPerBlock,
|
||||
const std::unordered_map<std::string, uintptr_t>&) {
|
||||
CommResult AllreduceFullmesh::allreduceKernelFunc(
|
||||
const std::shared_ptr<void> ctx_void, const void* input, void* output, size_t inputSize, DataType dtype,
|
||||
ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
|
||||
[[maybe_unused]] const std::unordered_map<std::string, uintptr_t>& extras, DataType accumDtype) {
|
||||
auto ctx = std::static_pointer_cast<AlgorithmCtx>(ctx_void);
|
||||
size_t recvBytes;
|
||||
CUdeviceptr recvBasePtr;
|
||||
@@ -198,13 +206,20 @@ CommResult AllreduceFullmesh::allreduceKernelFunc(const std::shared_ptr<void> ct
|
||||
}
|
||||
inputChannelHandles = this->memoryChannelsMap_[input].second;
|
||||
|
||||
AllreduceFunc allreduce = dispatch<AllreduceAllconnectAdapter>(op, dtype);
|
||||
AllreduceFunc allreduce = dispatch<AllreduceAllconnectAdapter>(op, dtype, accumDtype);
|
||||
if (!allreduce) {
|
||||
WARN("Unsupported operation or data type for allreduce: op=%d, dtype=%d", static_cast<int>(op),
|
||||
static_cast<int>(dtype));
|
||||
return CommResult::CommInvalidArgument;
|
||||
}
|
||||
std::pair<int, int> numBlocksAndThreads = {nBlocks, nThreadsPerBlock};
|
||||
if (numBlocksAndThreads.first > 64) {
|
||||
WARN("AllreduceFullmesh: number of blocks exceeds maximum supported blocks, which is 64");
|
||||
return mscclpp::CommResult::CommInvalidArgument;
|
||||
}
|
||||
if (numBlocksAndThreads.first == 0 || numBlocksAndThreads.second == 0) {
|
||||
numBlocksAndThreads = {35, 512};
|
||||
}
|
||||
cudaError_t error =
|
||||
allreduce(input, this->scratchBuffer_, output, inputChannelHandles.get(), ctx->memoryChannelDeviceHandles.get(),
|
||||
nullptr, nullptr, 0, channelOutOffset, 0, ctx->rank, ctx->nRanksPerNode, ctx->workSize, inputSize,
|
||||
@@ -261,9 +276,10 @@ std::shared_ptr<Algorithm> AllreduceFullmesh::build() {
|
||||
[self](std::shared_ptr<mscclpp::Communicator> comm) { self->initialize(comm); },
|
||||
[self](const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
|
||||
[[maybe_unused]] size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks,
|
||||
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras) -> CommResult {
|
||||
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras,
|
||||
DataType accumDtype) -> CommResult {
|
||||
return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock,
|
||||
extras);
|
||||
extras, accumDtype);
|
||||
},
|
||||
[self](std::shared_ptr<Communicator> comm, const void* input, void* output, size_t inputSize,
|
||||
[[maybe_unused]] size_t outputSize,
|
||||
|
||||
@@ -146,7 +146,7 @@ __global__ void __launch_bounds__(1024, 1)
|
||||
#endif
|
||||
}
|
||||
|
||||
template <ReduceOp OpType, typename T>
|
||||
template <ReduceOp OpType, typename T, typename AccumT = T>
|
||||
struct NvlsBlockPipelineAdapter {
|
||||
static cudaError_t call(const void* input, void* scratch, void* output, void* memoryChannels, void*,
|
||||
DeviceHandle<SwitchChannel>* nvlsChannels, DeviceHandle<SwitchChannel>*, size_t, size_t,
|
||||
@@ -155,6 +155,9 @@ struct NvlsBlockPipelineAdapter {
|
||||
// uint8_t is not supported for NVLS (no hardware support for byte-level reduction)
|
||||
if constexpr (std::is_same_v<T, uint8_t>) {
|
||||
return cudaErrorNotSupported;
|
||||
} else if constexpr (std::is_same_v<T, __fp8_e4m3b15>) {
|
||||
// fp8_e4m3b15 is a software-only type with no hardware NVLS support.
|
||||
return cudaErrorNotSupported;
|
||||
} else
|
||||
#if defined(__CUDA_ARCH__) // Skip the __CUDA_ARCH__ < 1000 since FP8 has not been supported for NVLS
|
||||
if constexpr (std::is_same_v<T, __fp8_e4m3> || std::is_same_v<T, __fp8_e5m2>) {
|
||||
@@ -187,9 +190,10 @@ void AllreduceNvlsBlockPipeline::initialize(std::shared_ptr<Communicator> comm)
|
||||
CommResult AllreduceNvlsBlockPipeline::allreduceKernelFunc(const std::shared_ptr<void> ctx_void, const void* input,
|
||||
void* output, size_t inputSize, DataType dtype, ReduceOp op,
|
||||
cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
|
||||
const std::unordered_map<std::string, uintptr_t>&) {
|
||||
const std::unordered_map<std::string, uintptr_t>& extras,
|
||||
DataType accumDtype) {
|
||||
auto ctx = std::static_pointer_cast<AlgorithmCtx>(ctx_void);
|
||||
AllreduceFunc allreduce = dispatch<NvlsBlockPipelineAdapter>(op, dtype);
|
||||
AllreduceFunc allreduce = dispatch<NvlsBlockPipelineAdapter>(op, dtype, accumDtype);
|
||||
if (!allreduce) {
|
||||
WARN("Unsupported operation or data type for allreduce, dtype=%d", static_cast<int>(dtype));
|
||||
return CommResult::CommInvalidArgument;
|
||||
@@ -235,9 +239,9 @@ std::shared_ptr<Algorithm> AllreduceNvlsBlockPipeline::build() {
|
||||
[self](std::shared_ptr<Communicator> comm) { self->initialize(comm); },
|
||||
[self](const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
|
||||
[[maybe_unused]] size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks,
|
||||
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras) {
|
||||
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras, DataType accumDtype) {
|
||||
return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock,
|
||||
extras);
|
||||
extras, accumDtype);
|
||||
},
|
||||
[self](std::shared_ptr<Communicator> comm, const void* input, void* output, size_t inputSize,
|
||||
[[maybe_unused]] size_t outputSize,
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#include "allreduce/allreduce_nvls_packet.hpp"
|
||||
#include "allreduce/common.hpp"
|
||||
#include "collective_utils.hpp"
|
||||
#include "debug.h"
|
||||
#include "logger.hpp"
|
||||
|
||||
namespace mscclpp {
|
||||
namespace collective {
|
||||
|
||||
template <ReduceOp OpType, typename T>
|
||||
template <ReduceOp OpType, typename T, typename AccumT = T>
|
||||
__global__ void __launch_bounds__(1024, 1)
|
||||
allreduceNvlsPacket([[maybe_unused]] const T* input, [[maybe_unused]] T* scratch, [[maybe_unused]] T* output,
|
||||
[[maybe_unused]] mscclpp::DeviceHandle<mscclpp::SwitchChannel>* multicast,
|
||||
@@ -31,15 +33,16 @@ __global__ void __launch_bounds__(1024, 1)
|
||||
mscclpp::SwitchChannelDeviceHandle::multimemStore(*(mscclpp::f32x2*)(&pkt), multiPkt + i);
|
||||
}
|
||||
for (uint32_t i = tid; i < nPktPerRank * worldSize; i += blockDim.x * gridDim.x) {
|
||||
uint data = src[i];
|
||||
// When T == AccumT, stay with raw uint to avoid type mismatch in identity path.
|
||||
using AccRaw =
|
||||
std::conditional_t<std::is_same_v<T, AccumT>, uint, mscclpp::VectorType<AccumT, sizeof(uint) / sizeof(T)>>;
|
||||
AccRaw acc = mscclpp::upcastVector<T, AccumT, AccRaw>(src[i]);
|
||||
for (int peer = 0; peer < worldSize; peer++) {
|
||||
if (peer == rank) {
|
||||
continue;
|
||||
}
|
||||
if (peer == rank) continue;
|
||||
uint val = scratchPkt[peer * worldSize * nPktPerRank + i].read(flag);
|
||||
data = cal_vector<T, OpType>(data, val);
|
||||
acc = mscclpp::calVectorAccum<T, AccumT, OpType, AccRaw>(acc, val);
|
||||
}
|
||||
dst[i] = data;
|
||||
dst[i] = mscclpp::downcastVector<T, AccumT, uint>(acc);
|
||||
}
|
||||
__syncthreads();
|
||||
if (threadIdx.x == 0) {
|
||||
@@ -62,13 +65,13 @@ inline std::pair<int, int> getDefaultBlockNumAndThreadNum(size_t inputSize) {
|
||||
return {blockNum, threadNum};
|
||||
}
|
||||
|
||||
template <ReduceOp OpType, typename T>
|
||||
template <ReduceOp OpType, typename T, typename AccumT = T>
|
||||
struct AllreduceNvlsPacketAdapter {
|
||||
static cudaError_t call(const void* input, void* scratch, void* output, void*, void*,
|
||||
DeviceHandle<SwitchChannel>* nvlsChannels, DeviceHandle<SwitchChannel>*, size_t, size_t,
|
||||
size_t scratchBufferSize, int rank, int, int worldSize, size_t inputSize, cudaStream_t stream,
|
||||
void* flags, uint32_t flagBufferSize, uint32_t, int nBlocks, int nThreadsPerBlock) {
|
||||
allreduceNvlsPacket<OpType, T><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
|
||||
allreduceNvlsPacket<OpType, T, AccumT><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
|
||||
(const T*)input, (T*)scratch, (T*)output, nvlsChannels, inputSize / sizeof(T), scratchBufferSize, rank,
|
||||
worldSize, flags, flagBufferSize);
|
||||
return cudaGetLastError();
|
||||
@@ -78,6 +81,8 @@ struct AllreduceNvlsPacketAdapter {
|
||||
void AllreduceNvlsPacket::initialize(std::shared_ptr<Communicator> comm) {
|
||||
int nSwitchChannels = 1;
|
||||
this->nvlsConnections_ = setupNvlsConnections(comm, nvlsBufferSize_, nSwitchChannels);
|
||||
this->switchChannels_ =
|
||||
setupNvlsChannels(this->nvlsConnections_, this->scratchBuffer_, this->scratchBufferSize_, nSwitchChannels);
|
||||
}
|
||||
|
||||
AlgorithmCtxKey AllreduceNvlsPacket::generateAllreduceContextKey(const void*, void*, size_t, DataType, bool) {
|
||||
@@ -92,9 +97,7 @@ std::shared_ptr<void> AllreduceNvlsPacket::initAllreduceContext(std::shared_ptr<
|
||||
ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode();
|
||||
|
||||
// setup channels
|
||||
int nSwitchChannels = 1;
|
||||
ctx->switchChannels =
|
||||
setupNvlsChannels(this->nvlsConnections_, this->scratchBuffer_, this->scratchBufferSize_, nSwitchChannels);
|
||||
ctx->switchChannels = this->switchChannels_;
|
||||
ctx->switchChannelDeviceHandles = setupNvlsChannelDeviceHandles(ctx->switchChannels);
|
||||
return ctx;
|
||||
}
|
||||
@@ -102,19 +105,20 @@ std::shared_ptr<void> AllreduceNvlsPacket::initAllreduceContext(std::shared_ptr<
|
||||
CommResult AllreduceNvlsPacket::allreduceKernelFunc(const std::shared_ptr<void> ctx_void, const void* input,
|
||||
void* output, size_t inputSize, mscclpp::DataType dtype,
|
||||
ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
|
||||
const std::unordered_map<std::string, uintptr_t>&) {
|
||||
const std::unordered_map<std::string, uintptr_t>&,
|
||||
mscclpp::DataType accumDtype) {
|
||||
auto ctx = std::static_pointer_cast<AlgorithmCtx>(ctx_void);
|
||||
std::pair<int, int> blockAndThreadNum = {nBlocks, nThreadsPerBlock};
|
||||
if (blockAndThreadNum.first == 0 || blockAndThreadNum.second == 0) {
|
||||
blockAndThreadNum = getDefaultBlockNumAndThreadNum(inputSize);
|
||||
}
|
||||
if (blockAndThreadNum.first > maxBlockNum_) {
|
||||
WARN("Block number %d exceeds the maximum limit %d", blockAndThreadNum.first, maxBlockNum_);
|
||||
WARN(ALGO, "Block number ", blockAndThreadNum.first, " exceeds the maximum limit ", maxBlockNum_);
|
||||
return CommResult::CommInvalidArgument;
|
||||
}
|
||||
AllreduceFunc allreduce = dispatch<AllreduceNvlsPacketAdapter>(op, dtype);
|
||||
AllreduceFunc allreduce = dispatch<AllreduceNvlsPacketAdapter>(op, dtype, accumDtype);
|
||||
if (!allreduce) {
|
||||
WARN("Unsupported operation or data type for allreduce, dtype=%d", static_cast<int>(dtype));
|
||||
WARN(ALGO, "Unsupported operation or data type for allreduce, dtype=", static_cast<int>(dtype));
|
||||
return CommResult::CommInvalidArgument;
|
||||
}
|
||||
cudaError_t error =
|
||||
@@ -122,7 +126,7 @@ CommResult AllreduceNvlsPacket::allreduceKernelFunc(const std::shared_ptr<void>
|
||||
0, 0, this->scratchBufferSize_, ctx->rank, ctx->nRanksPerNode, ctx->workSize, inputSize, stream,
|
||||
(void*)flagBuffer_, (uint32_t)flagBufferSize_, 0, blockAndThreadNum.first, blockAndThreadNum.second);
|
||||
if (error != cudaSuccess) {
|
||||
WARN("AllreduceNvlsPacket failed with error: %s", cudaGetErrorString(error));
|
||||
WARN(ALGO, "AllreduceNvlsPacket failed with error: ", cudaGetErrorString(error));
|
||||
return CommResult::CommUnhandledCudaError;
|
||||
}
|
||||
return CommResult::CommSuccess;
|
||||
@@ -136,9 +140,10 @@ std::shared_ptr<mscclpp::Algorithm> AllreduceNvlsPacket::build() {
|
||||
[self](std::shared_ptr<mscclpp::Communicator> comm) { self->initialize(comm); },
|
||||
[self](const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
|
||||
[[maybe_unused]] size_t outputSize, mscclpp::DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks,
|
||||
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras) {
|
||||
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras,
|
||||
mscclpp::DataType accumDtype) {
|
||||
return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock,
|
||||
extras);
|
||||
extras, accumDtype);
|
||||
},
|
||||
[self](std::shared_ptr<mscclpp::Communicator> comm, const void* input, void* output, size_t inputSize,
|
||||
[[maybe_unused]] size_t outputSize,
|
||||
|
||||
@@ -109,7 +109,7 @@ __global__ void __launch_bounds__(1024, 1)
|
||||
#endif
|
||||
}
|
||||
|
||||
template <ReduceOp OpType, typename T>
|
||||
template <ReduceOp OpType, typename T, typename AccumT = T>
|
||||
struct NvlsWarpPipelineAdapter {
|
||||
static cudaError_t call(const void* input, void* scratch, void* output, void* memoryChannels, void*,
|
||||
DeviceHandle<SwitchChannel>* nvlsChannels, DeviceHandle<SwitchChannel>*, size_t, size_t,
|
||||
@@ -118,6 +118,9 @@ struct NvlsWarpPipelineAdapter {
|
||||
// uint8_t is not supported for NVLS (no hardware support for byte-level reduction)
|
||||
if constexpr (std::is_same_v<T, uint8_t>) {
|
||||
return cudaErrorNotSupported;
|
||||
} else if constexpr (std::is_same_v<T, __fp8_e4m3b15>) {
|
||||
// fp8_e4m3b15 is a software-only type with no hardware NVLS support.
|
||||
return cudaErrorNotSupported;
|
||||
} else
|
||||
#if defined(__CUDA_ARCH__) // Skip the __CUDA_ARCH__ < 1000 since FP8 has not been supported for NVLS
|
||||
if constexpr (std::is_same_v<T, __fp8_e4m3> || std::is_same_v<T, __fp8_e5m2>) {
|
||||
@@ -147,12 +150,12 @@ void AllreduceNvlsWarpPipeline::initialize(std::shared_ptr<Communicator> comm) {
|
||||
this->nvlsConnections_ = setupNvlsConnections(comm, nvlsBufferSize_, nSwitchChannels_);
|
||||
}
|
||||
|
||||
CommResult AllreduceNvlsWarpPipeline::allreduceKernelFunc(const std::shared_ptr<void> ctx_void, const void* input,
|
||||
void* output, size_t inputSize, DataType dtype, ReduceOp op,
|
||||
cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
|
||||
const std::unordered_map<std::string, uintptr_t>&) {
|
||||
CommResult AllreduceNvlsWarpPipeline::allreduceKernelFunc(
|
||||
const std::shared_ptr<void> ctx_void, const void* input, void* output, size_t inputSize, DataType dtype,
|
||||
ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
|
||||
[[maybe_unused]] const std::unordered_map<std::string, uintptr_t>& extras, DataType accumDtype) {
|
||||
auto ctx = std::static_pointer_cast<AlgorithmCtx>(ctx_void);
|
||||
AllreduceFunc allreduce = dispatch<NvlsWarpPipelineAdapter>(op, dtype);
|
||||
AllreduceFunc allreduce = dispatch<NvlsWarpPipelineAdapter>(op, dtype, accumDtype);
|
||||
if (!allreduce) {
|
||||
WARN("Unsupported operation or data type for allreduce, dtype=%d", static_cast<int>(dtype));
|
||||
return CommResult::CommInvalidArgument;
|
||||
@@ -198,9 +201,9 @@ std::shared_ptr<Algorithm> AllreduceNvlsWarpPipeline::build() {
|
||||
[self](std::shared_ptr<Communicator> comm) { self->initialize(comm); },
|
||||
[self](const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
|
||||
[[maybe_unused]] size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks,
|
||||
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras) {
|
||||
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras, DataType accumDtype) {
|
||||
return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock,
|
||||
extras);
|
||||
extras, accumDtype);
|
||||
},
|
||||
[self](std::shared_ptr<Communicator> comm, const void* input, void* output, size_t inputSize,
|
||||
[[maybe_unused]] size_t outputSize,
|
||||
|
||||
@@ -67,7 +67,7 @@ __global__ void __launch_bounds__(1024, 1)
|
||||
#endif
|
||||
}
|
||||
|
||||
template <ReduceOp OpType, typename T>
|
||||
template <ReduceOp OpType, typename T, typename AccumT = T>
|
||||
struct NvlsAdapter {
|
||||
static cudaError_t call(const void*, void*, void*, void* memoryChannels, void*,
|
||||
mscclpp::DeviceHandle<mscclpp::SwitchChannel>* nvlsChannels,
|
||||
@@ -77,6 +77,9 @@ struct NvlsAdapter {
|
||||
// uint8_t is not supported for NVLS (no hardware support for byte-level reduction)
|
||||
if constexpr (std::is_same_v<T, uint8_t>) {
|
||||
return cudaErrorNotSupported;
|
||||
} else if constexpr (std::is_same_v<T, __fp8_e4m3b15>) {
|
||||
// fp8_e4m3b15 is a software-only type with no hardware NVLS support.
|
||||
return cudaErrorNotSupported;
|
||||
} else
|
||||
#if (!defined(__CUDA_ARCH_SPECIFIC__) && !defined(__CUDA_ARCH_FAMILY_SPECIFIC__)) || (__CUDA_ARCH__ < 1000)
|
||||
if constexpr (std::is_same_v<T, __fp8_e4m3> || std::is_same_v<T, __fp8_e5m2>) {
|
||||
@@ -114,13 +117,14 @@ void AllreduceNvls::initialize(std::shared_ptr<mscclpp::Communicator> comm) {
|
||||
CommResult AllreduceNvls::allreduceKernelFunc(const std::shared_ptr<void> ctx_void, const void* input, void* output,
|
||||
size_t inputSize, mscclpp::DataType dtype, ReduceOp op,
|
||||
cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
|
||||
const std::unordered_map<std::string, uintptr_t>&) {
|
||||
[[maybe_unused]] const std::unordered_map<std::string, uintptr_t>& extras,
|
||||
mscclpp::DataType accumDtype) {
|
||||
if (!symmetricMemory_) {
|
||||
WARN("AllreduceNvls requires symmetric memory for now.");
|
||||
return CommResult::CommInvalidArgument;
|
||||
}
|
||||
auto ctx = std::static_pointer_cast<AlgorithmCtx>(ctx_void);
|
||||
AllreduceFunc allreduce = dispatch<NvlsAdapter>(op, dtype);
|
||||
AllreduceFunc allreduce = dispatch<NvlsAdapter>(op, dtype, accumDtype);
|
||||
if (!allreduce) {
|
||||
WARN("Unsupported operation or data type for allreduce, dtype=%d", static_cast<int>(dtype));
|
||||
return CommResult::CommInvalidArgument;
|
||||
@@ -203,9 +207,10 @@ std::shared_ptr<mscclpp::Algorithm> AllreduceNvls::build() {
|
||||
[self](std::shared_ptr<mscclpp::Communicator> comm) { self->initialize(comm); },
|
||||
[self](const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
|
||||
[[maybe_unused]] size_t outputSize, mscclpp::DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks,
|
||||
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras) {
|
||||
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras,
|
||||
mscclpp::DataType accumDtype) {
|
||||
return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock,
|
||||
extras);
|
||||
extras, accumDtype);
|
||||
},
|
||||
[self](std::shared_ptr<mscclpp::Communicator> comm, const void* input, void* output, size_t inputSize,
|
||||
[[maybe_unused]] size_t outputSize,
|
||||
|
||||
@@ -2,16 +2,17 @@
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <mscclpp/algorithm.hpp>
|
||||
#include <type_traits>
|
||||
|
||||
#include "allreduce/allreduce_packet.hpp"
|
||||
#include "allreduce/common.hpp"
|
||||
#include "collective_utils.hpp"
|
||||
#include "debug.h"
|
||||
#include "logger.hpp"
|
||||
|
||||
namespace mscclpp {
|
||||
namespace collective {
|
||||
|
||||
template <ReduceOp OpType, typename T>
|
||||
template <ReduceOp OpType, typename T, typename AccumT = T>
|
||||
__global__ void __launch_bounds__(1024, 1)
|
||||
allreducePacket(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels,
|
||||
size_t channelDataOffset, size_t scratchBufferSize, int rank, int nRanksPerNode, int worldSize,
|
||||
@@ -92,12 +93,21 @@ __global__ void __launch_bounds__(1024, 1)
|
||||
// step 2: get data from scratch buffer, reduce data and write result to remote scratch buffer
|
||||
for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nPktsPerRank; idx += blockDim.x * gridDim.x) {
|
||||
uint2 data = src[idx];
|
||||
for (int index = 0; index < nPeers; index++) {
|
||||
const int remoteRank = index < rank ? index : index + 1;
|
||||
mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)scratchBuff + remoteRank * nPktsPerRank;
|
||||
uint2 val = dstPkt[idx].read(flag);
|
||||
data.x = cal_vector<T, OpType>(val.x, data.x);
|
||||
data.y = cal_vector<T, OpType>(val.y, data.y);
|
||||
{
|
||||
// When T == AccumT, stay with raw uint32_t to avoid type mismatch in identity path.
|
||||
using AccRaw = std::conditional_t<std::is_same_v<T, AccumT>, uint32_t,
|
||||
mscclpp::VectorType<AccumT, sizeof(uint32_t) / sizeof(T)>>;
|
||||
AccRaw accX = mscclpp::upcastVector<T, AccumT, AccRaw>(data.x);
|
||||
AccRaw accY = mscclpp::upcastVector<T, AccumT, AccRaw>(data.y);
|
||||
for (int index = 0; index < nPeers; index++) {
|
||||
const int remoteRank = index < rank ? index : index + 1;
|
||||
mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)scratchBuff + remoteRank * nPktsPerRank;
|
||||
uint2 val = dstPkt[idx].read(flag);
|
||||
accX = mscclpp::calVectorAccum<T, AccumT, OpType, AccRaw>(accX, val.x);
|
||||
accY = mscclpp::calVectorAccum<T, AccumT, OpType, AccRaw>(accY, val.y);
|
||||
}
|
||||
data.x = mscclpp::downcastVector<T, AccumT, uint32_t>(accX);
|
||||
data.y = mscclpp::downcastVector<T, AccumT, uint32_t>(accY);
|
||||
}
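// Illustrative note: an LLPacket payload is a uint2, i.e. two independent 32-bit lanes,
// so the accumulator is kept per lane (accX, accY) and downcast only once after all
// peers have been folded in.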
|
||||
|
||||
dst[idx].x = data.x;
|
||||
@@ -142,7 +152,7 @@ __global__ void __launch_bounds__(1024, 1)
|
||||
#endif
|
||||
}
|
||||
|
||||
template <ReduceOp OpType, typename T>
|
||||
template <ReduceOp OpType, typename T, typename AccumT = T>
|
||||
struct PacketAdapter {
|
||||
static cudaError_t call(const void* buff, void* scratch, void* resultBuff, void* memoryChannels, void*,
|
||||
DeviceHandle<SwitchChannel>*, DeviceHandle<SwitchChannel>*, size_t channelInOffset, size_t,
|
||||
@@ -155,12 +165,12 @@ struct PacketAdapter {
|
||||
nBlocks = nBlocks / (worldSize - 1) * (worldSize - 1);
|
||||
#if defined(ENABLE_NPKIT)
|
||||
size_t sharedMemSize = sizeof(NpKitEvent) * NPKIT_SHM_NUM_EVENTS;
|
||||
allreducePacket<OpType><<<nBlocks, nThreadsPerBlock, sharedMemSize, stream>>>(
|
||||
allreducePacket<OpType, T, AccumT><<<nBlocks, nThreadsPerBlock, sharedMemSize, stream>>>(
|
||||
(T*)buff, (T*)scratch, (T*)resultBuff, (ChannelType*)memoryChannels, channelInOffset, scratchBufferSize, rank,
|
||||
nRanksPerNode, worldSize, nelems, flags, flagBufferSize, numScratchBuff, NpKit::GetGpuEventCollectContexts(),
|
||||
NpKit::GetCpuTimestamp());
|
||||
#else
|
||||
allreducePacket<OpType><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
|
||||
allreducePacket<OpType, T, AccumT><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
|
||||
(T*)buff, (T*)scratch, (T*)resultBuff, (ChannelType*)memoryChannels, channelInOffset, scratchBufferSize, rank,
|
||||
nRanksPerNode, worldSize, nelems, flags, flagBufferSize, numScratchBuff);
|
||||
#endif
|
||||
@@ -186,18 +196,22 @@ inline std::pair<int, int> getDefaultBlockNumAndThreadNum(size_t inputSize, int
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(__FP8_TYPES_EXIST__)
|
||||
// FP8-specific tuning for 32KB-256KB range
|
||||
if (dtype == DataType::FLOAT8_E4M3 || dtype == DataType::FLOAT8_E5M2) {
|
||||
if (inputSize < (64 << 10)) {
|
||||
nThreadsPerBlock = 64;
|
||||
} else if (inputSize >= (64 << 10) && inputSize <= (128 << 10)) {
|
||||
nThreadsPerBlock = 128;
|
||||
} else if (inputSize >= (128 << 10) && inputSize <= (256 << 10)) {
|
||||
nThreadsPerBlock = 256;
|
||||
{
|
||||
bool isFp8 = dtype == DataType::FLOAT8_E4M3B15;
|
||||
#if defined(__FP8_TYPES_EXIST__)
|
||||
isFp8 = isFp8 || dtype == DataType::FLOAT8_E4M3 || dtype == DataType::FLOAT8_E5M2;
|
||||
#endif
|
||||
if (isFp8) {
|
||||
if (inputSize < (64 << 10)) {
|
||||
nThreadsPerBlock = 64;
|
||||
} else if (inputSize >= (64 << 10) && inputSize <= (128 << 10)) {
|
||||
nThreadsPerBlock = 128;
|
||||
} else if (inputSize >= (128 << 10) && inputSize <= (256 << 10)) {
|
||||
nThreadsPerBlock = 256;
|
||||
}
|
||||
}
|
||||
}
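// Worked example (illustrative): a 96 KiB FP8 input falls in the 64 KiB to 128 KiB
// bucket above, so nThreadsPerBlock is set to 128; inputs below 64 KiB use 64 threads.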
#endif
|
||||
#endif
|
||||
return {nBlocks, nThreadsPerBlock};
|
||||
}
|
||||
@@ -213,7 +227,8 @@ void AllreducePacket::initialize(std::shared_ptr<Communicator> comm) {
|
||||
CommResult AllreducePacket::allreduceKernelFunc(const std::shared_ptr<void> ctx_void, const void* input, void* output,
|
||||
size_t inputSize, [[maybe_unused]] DataType dtype, ReduceOp op,
|
||||
cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
|
||||
const std::unordered_map<std::string, uintptr_t>&) {
|
||||
const std::unordered_map<std::string, uintptr_t>&,
|
||||
DataType accumDtype) {
|
||||
auto ctx = std::static_pointer_cast<AlgorithmCtx>(ctx_void);
|
||||
std::pair<int, int> blockAndThreadNum = {nBlocks, nThreadsPerBlock};
|
||||
if (blockAndThreadNum.first == 0 || blockAndThreadNum.second == 0) {
|
||||
@@ -225,9 +240,10 @@ CommResult AllreducePacket::allreduceKernelFunc(const std::shared_ptr<void> ctx_
|
||||
MSCCLPP_CUTHROW(cuMemGetAddressRange(&sendBasePtr, &sendBytes, (CUdeviceptr)input));
|
||||
size_t channelInOffset = (char*)input - (char*)sendBasePtr;
|
||||
|
||||
AllreduceFunc allreduce = dispatch<PacketAdapter>(op, dtype);
|
||||
AllreduceFunc allreduce = dispatch<PacketAdapter>(op, dtype, accumDtype);
|
||||
if (!allreduce) {
|
||||
WARN("Unsupported operation or data type for allreduce: op=%d, dtype=%d", op, static_cast<int>(dtype));
|
||||
WARN(ALGO, "Unsupported operation or data type for allreduce: op=", static_cast<int>(op),
|
||||
", dtype=", static_cast<int>(dtype));
|
||||
return CommResult::CommInvalidArgument;
|
||||
}
|
||||
cudaError_t error =
|
||||
@@ -236,7 +252,7 @@ CommResult AllreducePacket::allreduceKernelFunc(const std::shared_ptr<void> ctx_
|
||||
stream, (void*)flagBuffer_, (uint32_t)flagBufferSize_, this->nSegmentsForScratchBuffer_,
|
||||
blockAndThreadNum.first, blockAndThreadNum.second);
|
||||
if (error != cudaSuccess) {
|
||||
WARN("AllreducePacket failed with error: %s", cudaGetErrorString(error));
|
||||
WARN(ALGO, "AllreducePacket failed with error: ", cudaGetErrorString(error));
|
||||
return CommResult::CommUnhandledCudaError;
|
||||
}
|
||||
return CommResult::CommSuccess;
|
||||
@@ -280,9 +296,9 @@ std::shared_ptr<Algorithm> AllreducePacket::build() {
|
||||
"default_allreduce_packet", "allreduce", [self](std::shared_ptr<Communicator> comm) { self->initialize(comm); },
|
||||
[self](const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
|
||||
[[maybe_unused]] size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks,
|
||||
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras) {
|
||||
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras, DataType accumDtype) {
|
||||
return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock,
|
||||
extras);
|
||||
extras, accumDtype);
|
||||
},
|
||||
[self](std::shared_ptr<Communicator> comm, const void* input, void* output, size_t inputSize,
|
||||
[[maybe_unused]] size_t outputSize,
|
||||
|
||||
@@ -87,7 +87,7 @@ __global__ void __launch_bounds__(1024, 1)
|
||||
int rankIdx = (rank + i + 1) % nRanksPerNode;
|
||||
int peerIdx = rankIdx < rank ? rankIdx : rankIdx - 1;
|
||||
int4 data = mscclpp::read<int4>(((void**)remoteMemories)[peerIdx], offset);
|
||||
tmp = cal_vector<T, OpType>(data, tmp);
|
||||
tmp = calVector<T, OpType>(data, tmp);
|
||||
}
|
||||
for (uint32_t i = 0; i < nPeers; i++) {
|
||||
int rankIdx = (rank + i + 1) % nRanksPerNode;
|
||||
@@ -123,7 +123,7 @@ __global__ void __launch_bounds__(1024, 1)
|
||||
}
|
||||
}
|
||||
|
||||
template <ReduceOp OpType, typename T>
|
||||
template <ReduceOp OpType, typename T, typename AccumT = T>
|
||||
struct AllreduceRsAgAdapter {
|
||||
static cudaError_t call(const void* input, void* scratch, void* output, void* memoryChannels, void* remoteMemories,
|
||||
DeviceHandle<SwitchChannel>* switchChannel, DeviceHandle<SwitchChannel>*, size_t, size_t,
|
||||
@@ -166,9 +166,9 @@ void AllreduceRsAg::initialize(std::shared_ptr<Communicator> comm) {
|
||||
CommResult AllreduceRsAg::allreduceKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output,
|
||||
size_t inputSize, DataType dtype, ReduceOp op, cudaStream_t stream,
|
||||
int nBlocks, int nThreadsPerBlock,
|
||||
const std::unordered_map<std::string, uintptr_t>&) {
|
||||
const std::unordered_map<std::string, uintptr_t>&, DataType accumDtype) {
|
||||
auto algoCtx = std::static_pointer_cast<AlgorithmCtx>(ctx);
|
||||
AllreduceFunc allreduce = dispatch<AllreduceRsAgAdapter>(op, dtype);
|
||||
AllreduceFunc allreduce = dispatch<AllreduceRsAgAdapter>(op, dtype, accumDtype);
|
||||
if (!allreduce) {
|
||||
WARN(ALGO, "Unsupported operation or data type for allreduce: op=", static_cast<int>(op),
|
||||
", dtype=", static_cast<int>(dtype));
|
||||
@@ -213,9 +213,10 @@ std::shared_ptr<Algorithm> AllreduceRsAg::build() {
|
||||
[self](std::shared_ptr<mscclpp::Communicator> comm) { self->initialize(comm); },
|
||||
[self](const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
|
||||
[[maybe_unused]] size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks,
|
||||
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras) -> CommResult {
|
||||
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras,
|
||||
DataType accumDtype) -> CommResult {
|
||||
return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock,
|
||||
extras);
|
||||
extras, accumDtype);
|
||||
},
|
||||
[self](std::shared_ptr<Communicator> comm, const void* input, void* output, size_t inputSize,
|
||||
[[maybe_unused]] size_t outputSize,
|
||||
|
||||
@@ -168,7 +168,7 @@ __global__ void __launch_bounds__(1024, 1)
|
||||
uint32_t peerSlotOffset =
|
||||
baseOffset + remoteRankId * nInt4PerIter + threadIdInPut + putStep * blockDim.x * nblocksForPut;
|
||||
int4 data = scratch4[peerSlotOffset];
|
||||
tmp = cal_vector<T, OpType>(data, tmp);
|
||||
tmp = calVector<T, OpType>(data, tmp);
|
||||
}
|
||||
storeVec(resultBuff, myChunkOffset, tmp, nelems);
|
||||
// Broadcast reduced result to all peers' scratch at SCATTER_AG_OFFSET + rank * nInt4PerIter
|
||||
@@ -220,7 +220,7 @@ __global__ void __launch_bounds__(1024, 1)
|
||||
}
|
||||
}
|
||||
|
||||
template <ReduceOp OpType, typename T>
|
||||
template <ReduceOp OpType, typename T, typename AccumT = T>
|
||||
struct AllreduceRsAgPipelineAdapter {
|
||||
static cudaError_t call(const void* input, void* scratch, void* output, void* memoryChannels, void* remoteMemories,
|
||||
DeviceHandle<SwitchChannel>* switchChannel, DeviceHandle<SwitchChannel>*, size_t, size_t,
|
||||
@@ -274,12 +274,12 @@ void AllreduceRsAgPipeline::initialize(std::shared_ptr<Communicator> comm) {
|
||||
cudaMemcpyHostToDevice);
|
||||
}
|
||||
|
||||
CommResult AllreduceRsAgPipeline::allreduceKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output,
|
||||
size_t inputSize, DataType dtype, ReduceOp op,
|
||||
cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
|
||||
const std::unordered_map<std::string, uintptr_t>&) {
|
||||
CommResult AllreduceRsAgPipeline::allreduceKernelFunc(
|
||||
const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize, DataType dtype, ReduceOp op,
|
||||
cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
|
||||
[[maybe_unused]] const std::unordered_map<std::string, uintptr_t>& extras, DataType accumDtype) {
|
||||
auto algoCtx = std::static_pointer_cast<AlgorithmCtx>(ctx);
|
||||
AllreduceFunc allreduce = dispatch<AllreduceRsAgPipelineAdapter>(op, dtype);
|
||||
AllreduceFunc allreduce = dispatch<AllreduceRsAgPipelineAdapter>(op, dtype, accumDtype);
|
||||
if (!allreduce) {
|
||||
WARN(ALGO, "Unsupported operation or data type for allreduce: op=", static_cast<int>(op),
|
||||
", dtype=", static_cast<int>(dtype));
|
||||
@@ -320,9 +320,10 @@ std::shared_ptr<Algorithm> AllreduceRsAgPipeline::build() {
|
||||
[self](std::shared_ptr<mscclpp::Communicator> comm) { self->initialize(comm); },
|
||||
[self](const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
|
||||
[[maybe_unused]] size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks,
|
||||
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras) -> CommResult {
|
||||
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras,
|
||||
DataType accumDtype) -> CommResult {
|
||||
return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock,
|
||||
extras);
|
||||
extras, accumDtype);
|
||||
},
|
||||
[self](std::shared_ptr<Communicator> comm, const void* input, void* output, size_t inputSize,
|
||||
[[maybe_unused]] size_t outputSize,
|
||||
|
||||
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

#include <type_traits>

#include "allreduce/allreduce_rsag_zero_copy.hpp"
#include "allreduce/common.hpp"
#include "collective_utils.hpp"
@@ -36,7 +38,7 @@ __device__ mscclpp::DeviceSyncer globalSyncer;
// the extra copy steps of the standard RSAG. The NRanksPerNode template
// parameter enables compile-time unrolling of peer loops (supports 4 or 8).

template <int NRanksPerNode, ReduceOp OpType, typename T>
template <int NRanksPerNode, ReduceOp OpType, typename T, typename AccumT = T>
__global__ void __launch_bounds__(1024, 1)
allreduceRsAgZeroCopy(T* buff, T* scratch, T* resultBuff, DeviceHandle<BaseMemoryChannel>* memoryChannels,
DeviceHandle<SwitchChannel>* switchChannels, void* remoteMemories, int rank, int worldSize,
@@ -73,19 +75,26 @@ __global__ void __launch_bounds__(1024, 1)
}
__syncthreads();
int4 data[NPeers];
// AccumInt4: when AccumT != T, use a wider accumulator type.
// For AccumT == T, this is just int4 (no-op conversion).
constexpr int nElemsPerInt4 = sizeof(int4) / sizeof(T);
// When T == AccumT, stay with raw int4 to avoid type mismatch in identity path.
using AccumVec = std::conditional_t<std::is_same_v<T, AccumT>, int4, mscclpp::VectorType<AccumT, nElemsPerInt4>>;
for (uint32_t idx = threadIdx.x; idx < nInt4PerBlock; idx += blockDim.x) {
uint32_t offset = idx + offset4 + rank * nInt4PerRank;
if (offset >= nInt4Total) continue;
int4 tmp = buff4[offset];
int4 tmp_raw = buff4[offset];
#pragma unroll
for (int i = 0; i < NPeers; i++) {
int rankIdx = (rank + i + 1) % NRanksPerNode;
int peerIdx = rankIdx < rank ? rankIdx : rankIdx - 1;
data[i] = mscclpp::read<int4>(((void**)remoteMemories)[peerIdx], offset);
}
AccumVec acc = mscclpp::upcastVector<T, AccumT, AccumVec>(tmp_raw);
for (int i = 0; i < NPeers; i++) {
tmp = cal_vector<T, OpType>(data[i], tmp);
acc = mscclpp::calVectorAccum<T, AccumT, OpType, AccumVec>(acc, data[i]);
}
int4 tmp = mscclpp::downcastVector<T, AccumT, int4>(acc);
#pragma unroll
for (int i = 0; i < NPeers; i++) {
int rankIdx = (rank + i + 1) % NRanksPerNode;
@@ -102,7 +111,7 @@ __global__ void __launch_bounds__(1024, 1)
}
}

template <ReduceOp OpType, typename T>
template <ReduceOp OpType, typename T, typename AccumT = T>
struct AllreduceRsAgZeroCopyAdapter {
static cudaError_t call(const void* input, void* scratch, void* output, void* memoryChannels, void* remoteMemories,
DeviceHandle<SwitchChannel>* switchChannel, DeviceHandle<SwitchChannel>*, size_t, size_t,
@@ -118,11 +127,11 @@ struct AllreduceRsAgZeroCopyAdapter {
}
}
if (nRanksPerNode == 4) {
allreduceRsAgZeroCopy<4, OpType, T>
allreduceRsAgZeroCopy<4, OpType, T, AccumT>
<<<nBlocks, nThreadsPerBlock, 0, stream>>>((T*)input, (T*)scratch, (T*)output, (ChannelType*)memoryChannels,
switchChannel, remoteMemories, rank, worldSize, nelems);
} else if (nRanksPerNode == 8) {
allreduceRsAgZeroCopy<8, OpType, T>
allreduceRsAgZeroCopy<8, OpType, T, AccumT>
<<<nBlocks, nThreadsPerBlock, 0, stream>>>((T*)input, (T*)scratch, (T*)output, (ChannelType*)memoryChannels,
switchChannel, remoteMemories, rank, worldSize, nelems);
} else {
@@ -145,9 +154,10 @@ void AllreduceRsAgZeroCopy::initialize(std::shared_ptr<Communicator> comm) {
CommResult AllreduceRsAgZeroCopy::allreduceKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output,
size_t inputSize, DataType dtype, ReduceOp op,
cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
const std::unordered_map<std::string, uintptr_t>&) {
const std::unordered_map<std::string, uintptr_t>&,
DataType accumDtype) {
auto algoCtx = std::static_pointer_cast<AlgorithmCtx>(ctx);
AllreduceFunc allreduce = dispatch<AllreduceRsAgZeroCopyAdapter>(op, dtype);
AllreduceFunc allreduce = dispatch<AllreduceRsAgZeroCopyAdapter>(op, dtype, accumDtype);
if (!allreduce) {
WARN(ALGO, "Unsupported operation or data type for allreduce: op=", static_cast<int>(op),
", dtype=", static_cast<int>(dtype));
@@ -220,9 +230,10 @@ std::shared_ptr<Algorithm> AllreduceRsAgZeroCopy::build() {
[self](std::shared_ptr<mscclpp::Communicator> comm) { self->initialize(comm); },
[self](const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
[[maybe_unused]] size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks,
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras) -> CommResult {
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras,
DataType accumDtype) -> CommResult {
return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock,
extras);
extras, accumDtype);
},
[self](std::shared_ptr<Communicator> comm, const void* input, void* output, size_t inputSize,
[[maybe_unused]] size_t outputSize,

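The kernel change above reduces each element into a wider accumulator type (AccumT) and converts back to the storage type only once, which matters for narrow types such as FP8 where repeated in-type additions lose small contributions. A self-contained sketch of the same idea in plain C++, using double as a stand-in wider accumulator (the mscclpp device helpers do the vectorized equivalent on the GPU):

#include <cstdio>
#include <vector>

// Reduce values of type T while accumulating in AccumT, then downcast once.
template <typename T, typename AccumT = T>
T reduceSum(const std::vector<T>& values) {
  AccumT acc = AccumT(0);
  for (const T& v : values) acc += static_cast<AccumT>(v);  // upcast each element
  return static_cast<T>(acc);                               // single downcast at the end
}

int main() {
  // One large value followed by many small ones: accumulating in float drops the
  // small contributions, while a wider accumulator keeps them.
  std::vector<float> vals(1000, 1.0e-4f);
  vals[0] = 1.0e4f;
  std::printf("accumulate in float : %.4f\n", reduceSum<float, float>(vals));
  std::printf("accumulate in double: %.4f\n", reduceSum<float, double>(vals));
  return 0;
}
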
@@ -20,7 +20,7 @@ class AllreduceAllpairPacket : public AlgorithmBuilder {
void initialize(std::shared_ptr<Communicator> comm);
CommResult allreduceKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
const std::unordered_map<std::string, uintptr_t>& extras);
const std::unordered_map<std::string, uintptr_t>& extras, DataType accumDtype);

std::shared_ptr<void> initAllreduceContext(std::shared_ptr<Communicator> comm, const void*, void* output, size_t,
DataType);

@@ -16,7 +16,7 @@ class AllreduceFullmesh : public mscclpp::AlgorithmBuilder {
void initialize(std::shared_ptr<Communicator> comm);
CommResult allreduceKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
const std::unordered_map<std::string, uintptr_t>& extras);
const std::unordered_map<std::string, uintptr_t>& extras, DataType accumDtype);

std::shared_ptr<void> initAllreduceContext(std::shared_ptr<Communicator> comm, const void*, void* output, size_t,
DataType);

@@ -19,7 +19,7 @@ class AllreduceNvlsBlockPipeline : public AlgorithmBuilder {
void initialize(std::shared_ptr<Communicator> comm);
CommResult allreduceKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
const std::unordered_map<std::string, uintptr_t>& extras);
const std::unordered_map<std::string, uintptr_t>& extras, DataType accumDtype);

std::shared_ptr<void> initAllreduceContext(std::shared_ptr<Communicator> comm, const void*, void* output, size_t,
DataType);

@@ -21,7 +21,8 @@ class AllreduceNvlsPacket : public mscclpp::AlgorithmBuilder {
void initialize(std::shared_ptr<mscclpp::Communicator> comm);
CommResult allreduceKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
mscclpp::DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks,
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras);
int nThreadsPerBlock, const std::unordered_map<std::string, uintptr_t>& extras,
mscclpp::DataType accumDtype);

std::shared_ptr<void> initAllreduceContext(std::shared_ptr<mscclpp::Communicator> comm, const void*, void* output,
size_t, mscclpp::DataType);
@@ -34,6 +35,7 @@ class AllreduceNvlsPacket : public mscclpp::AlgorithmBuilder {
uintptr_t flagBuffer_;
size_t flagBufferSize_;
std::vector<std::shared_ptr<NvlsConnection>> nvlsConnections_;
std::vector<SwitchChannel> switchChannels_;
};
} // namespace collective
} // namespace mscclpp

@@ -19,7 +19,7 @@ class AllreduceNvlsWarpPipeline : public AlgorithmBuilder {
void initialize(std::shared_ptr<Communicator> comm);
CommResult allreduceKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
const std::unordered_map<std::string, uintptr_t>& extras);
const std::unordered_map<std::string, uintptr_t>& extras, DataType accumDtype);

std::shared_ptr<void> initAllreduceContext(std::shared_ptr<Communicator> comm, const void*, void* output, size_t,
DataType);

@@ -19,7 +19,7 @@ class AllreduceNvls : public AlgorithmBuilder {
void initialize(std::shared_ptr<Communicator> comm);
CommResult allreduceKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
const std::unordered_map<std::string, uintptr_t>& extras);
const std::unordered_map<std::string, uintptr_t>& extras, DataType accumDtype);

std::shared_ptr<void> initAllreduceContext(std::shared_ptr<Communicator> comm, const void*, void* output, size_t,
DataType);

@@ -20,7 +20,7 @@ class AllreducePacket : public AlgorithmBuilder {
void initialize(std::shared_ptr<Communicator> comm);
CommResult allreduceKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
const std::unordered_map<std::string, uintptr_t>& extras);
const std::unordered_map<std::string, uintptr_t>& extras, DataType accumDtype);

std::shared_ptr<void> initAllreduceContext(std::shared_ptr<Communicator> comm, const void*, void* output, size_t,
DataType);

@@ -19,7 +19,7 @@ class AllreduceRsAg : public mscclpp::AlgorithmBuilder {
void initialize(std::shared_ptr<Communicator> comm);
CommResult allreduceKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
const std::unordered_map<std::string, uintptr_t>& extras);
const std::unordered_map<std::string, uintptr_t>& extras, DataType accumDtype);

std::shared_ptr<void> initAllreduceContext(std::shared_ptr<Communicator> comm, const void*, void* output, size_t,
DataType);

@@ -19,7 +19,7 @@ class AllreduceRsAgPipeline : public mscclpp::AlgorithmBuilder {
void initialize(std::shared_ptr<Communicator> comm);
CommResult allreduceKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
const std::unordered_map<std::string, uintptr_t>& extras);
const std::unordered_map<std::string, uintptr_t>& extras, DataType accumDtype);

std::shared_ptr<void> initAllreduceContext(std::shared_ptr<Communicator> comm, const void*, void* output, size_t,
DataType);

@@ -18,7 +18,7 @@ class AllreduceRsAgZeroCopy : public mscclpp::AlgorithmBuilder {
void initialize(std::shared_ptr<Communicator> comm);
CommResult allreduceKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
const std::unordered_map<std::string, uintptr_t>& extras);
const std::unordered_map<std::string, uintptr_t>& extras, DataType accumDtype);

std::shared_ptr<void> initAllreduceContext(std::shared_ptr<Communicator> comm, const void*, void* output, size_t,
DataType);

@@ -1,8 +1,8 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

#ifndef MSCCLPP_ALLREDUCE_COMMOM_HPP_
#define MSCCLPP_ALLREDUCE_COMMOM_HPP_
#ifndef MSCCLPP_ALLREDUCE_COMMON_HPP_
#define MSCCLPP_ALLREDUCE_COMMON_HPP_

#include <cmath>
#include <mscclpp/algorithm.hpp>
@@ -77,55 +77,51 @@ using AllreduceFunc =
mscclpp::DeviceHandle<mscclpp::SwitchChannel>*, size_t, size_t, size_t, int, int, int,
size_t, cudaStream_t, void*, uint32_t, uint32_t, int, int)>;

template <template <ReduceOp, typename> class Adapter>
AllreduceFunc dispatch(ReduceOp op, mscclpp::DataType dtype) {
if (op == SUM) {
if (dtype == mscclpp::DataType::FLOAT16) {
return Adapter<SUM, half>::call;
} else if (dtype == mscclpp::DataType::FLOAT32) {
return Adapter<SUM, float>::call;
#if defined(__CUDA_BF16_TYPES_EXIST__)
} else if (dtype == mscclpp::DataType::BFLOAT16) {
return Adapter<SUM, __bfloat16>::call;
#endif
#if defined(__FP8_TYPES_EXIST__)
} else if (dtype == mscclpp::DataType::FLOAT8_E4M3) {
return Adapter<SUM, __fp8_e4m3>::call;
} else if (dtype == mscclpp::DataType::FLOAT8_E5M2) {
return Adapter<SUM, __fp8_e5m2>::call;
#endif
} else if (dtype == mscclpp::DataType::INT32 || dtype == mscclpp::DataType::UINT32) {
return Adapter<SUM, int>::call;
} else if (dtype == mscclpp::DataType::UINT8) {
return Adapter<SUM, uint8_t>::call;
} else {
return nullptr;
}
} else if (op == MIN) {
if (dtype == mscclpp::DataType::FLOAT16) {
return Adapter<MIN, half>::call;
} else if (dtype == mscclpp::DataType::FLOAT32) {
return Adapter<MIN, float>::call;
#if defined(__CUDA_BF16_TYPES_EXIST__)
} else if (dtype == mscclpp::DataType::BFLOAT16) {
return Adapter<MIN, __bfloat16>::call;
#endif
#if defined(__FP8_TYPES_EXIST__)
} else if (dtype == mscclpp::DataType::FLOAT8_E4M3) {
return Adapter<MIN, __fp8_e4m3>::call;
} else if (dtype == mscclpp::DataType::FLOAT8_E5M2) {
return Adapter<MIN, __fp8_e5m2>::call;
#endif
} else if (dtype == mscclpp::DataType::INT32 || dtype == mscclpp::DataType::UINT32) {
return Adapter<MIN, int>::call;
} else if (dtype == mscclpp::DataType::UINT8) {
return Adapter<MIN, uint8_t>::call;
} else {
return nullptr;
}
/// Dispatch helper for FP8 types with a configurable accumulation type.
template <ReduceOp Op, typename FP8T, template <ReduceOp, typename, typename> class Adapter>
AllreduceFunc dispatchFp8Accum(mscclpp::DataType accumDtype, mscclpp::DataType dtype) {
if (accumDtype == mscclpp::DataType::FLOAT32) {
return Adapter<Op, FP8T, float>::call;
} else if (accumDtype == mscclpp::DataType::FLOAT16) {
return Adapter<Op, FP8T, half>::call;
} else if (accumDtype == dtype) {
return Adapter<Op, FP8T, FP8T>::call;
}
return nullptr;
}

template <ReduceOp Op, template <ReduceOp, typename, typename> class Adapter>
AllreduceFunc dispatchByDtype(mscclpp::DataType dtype, mscclpp::DataType accumDtype) {
if (dtype == mscclpp::DataType::FLOAT16) {
return Adapter<Op, half, half>::call;
} else if (dtype == mscclpp::DataType::FLOAT32) {
return Adapter<Op, float, float>::call;
#if defined(__CUDA_BF16_TYPES_EXIST__)
} else if (dtype == mscclpp::DataType::BFLOAT16) {
return Adapter<Op, __bfloat16, __bfloat16>::call;
#endif
#if defined(__FP8_TYPES_EXIST__)
} else if (dtype == mscclpp::DataType::FLOAT8_E4M3) {
return dispatchFp8Accum<Op, __fp8_e4m3, Adapter>(accumDtype, dtype);
} else if (dtype == mscclpp::DataType::FLOAT8_E5M2) {
return dispatchFp8Accum<Op, __fp8_e5m2, Adapter>(accumDtype, dtype);
#endif
} else if (dtype == mscclpp::DataType::FLOAT8_E4M3B15) {
return dispatchFp8Accum<Op, __fp8_e4m3b15, Adapter>(accumDtype, dtype);
} else if (dtype == mscclpp::DataType::INT32 || dtype == mscclpp::DataType::UINT32) {
return Adapter<Op, int, int>::call;
} else if (dtype == mscclpp::DataType::UINT8) {
return Adapter<Op, uint8_t, uint8_t>::call;
}
return nullptr;
}

template <template <ReduceOp, typename, typename> class Adapter>
AllreduceFunc dispatch(ReduceOp op, mscclpp::DataType dtype, mscclpp::DataType accumDtype) {
if (op == SUM) return dispatchByDtype<SUM, Adapter>(dtype, accumDtype);
if (op == MIN) return dispatchByDtype<MIN, Adapter>(dtype, accumDtype);
return nullptr;
}
} // namespace collective
} // namespace mscclpp

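The new dispatch above is a runtime-to-compile-time bridge: the enum triple (op, dtype, accumDtype) selects the static call of an adapter template instantiated for exactly those types, and FP8 inputs get a separately chosen accumulator. A minimal, self-contained sketch of that pattern with hypothetical names (not the mscclpp types):

#include <cstdio>

enum class Op { SUM, MIN };
enum class Dtype { F32, F8 };

using Fn = void (*)(int);

// Adapter instantiated at compile time for a concrete (op, element type, accumulator type).
template <Op OpV, typename T, typename AccumT>
struct Adapter {
  static void call(int n) {
    std::printf("launch: n=%d sizeof(T)=%zu sizeof(AccumT)=%zu\n", n, sizeof(T), sizeof(AccumT));
  }
};

// FP8-like inputs get a configurable accumulator; other types accumulate in their own type.
// unsigned char is only a stand-in for an FP8 storage type here.
template <Op OpV>
Fn dispatchByDtype(Dtype dtype, Dtype accumDtype) {
  if (dtype == Dtype::F32) return Adapter<OpV, float, float>::call;
  if (dtype == Dtype::F8) {
    return accumDtype == Dtype::F32 ? Adapter<OpV, unsigned char, float>::call
                                    : Adapter<OpV, unsigned char, unsigned char>::call;
  }
  return nullptr;
}

// Runtime enums select a compile-time instantiation and hand back its function pointer.
Fn dispatch(Op op, Dtype dtype, Dtype accumDtype) {
  if (op == Op::SUM) return dispatchByDtype<Op::SUM>(dtype, accumDtype);
  if (op == Op::MIN) return dispatchByDtype<Op::MIN>(dtype, accumDtype);
  return nullptr;
}

int main() {
  Fn fn = dispatch(Op::SUM, Dtype::F8, Dtype::F32);
  if (fn) fn(1024);
  return 0;
}
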
@@ -15,7 +15,8 @@ static bool isNvlsSupportedForDataType(const AlgorithmSelectorConfig& config, Da
bool nvlsSupported = config.nvlsSupported;

// NVLS does not support uint8_t (no hardware support for byte-level reduction)
if (dtype == DataType::UINT8) {
// NVLS also does not support float8_e4m3b15 (software-defined type with no hardware NVLS reduction support)
if (dtype == DataType::UINT8 || dtype == DataType::FLOAT8_E4M3B15) {
return false;
}

@@ -43,6 +43,7 @@ inline size_t getDataTypeSize(mscclpp::DataType dtype) {
case mscclpp::DataType::UINT8:
case mscclpp::DataType::FLOAT8_E4M3:
case mscclpp::DataType::FLOAT8_E5M2:
case mscclpp::DataType::FLOAT8_E4M3B15:
return 1;
case mscclpp::DataType::FLOAT16:
case mscclpp::DataType::BFLOAT16:
@@ -76,6 +77,10 @@ static inline ncclDataType_t mscclppToNcclDataType(mscclpp::DataType dtype) {
case mscclpp::DataType::FLOAT8_E5M2:
return ncclFloat8e5m2;
#endif
case mscclpp::DataType::FLOAT8_E4M3B15:
// float8_e4m3b15 has no NCCL equivalent; NCCL cannot reduce this type correctly.
THROW(mscclpp::LogSubsys::NCCL, mscclpp::Error, mscclpp::ErrorCode::InvalidUsage,
"FLOAT8_E4M3B15 (float8_e4m3b15) has no NCCL equivalent and cannot be used with NCCL collectives");
default:
THROW(mscclpp::LogSubsys::NCCL, mscclpp::Error, mscclpp::ErrorCode::InvalidUsage,
"Unsupported mscclpp::DataType: " + std::to_string(static_cast<int>(dtype)));

@@ -83,17 +83,17 @@ static inline int mscclppNcclDlopenInit() {
const char* ncclLibPath = mscclpp::env()->ncclSharedLibPath.c_str();
if (ncclLibPath != nullptr && ncclLibPath[0] != '\0') {
if (std::filesystem::is_directory(ncclLibPath)) {
WARN(MSCCLPP_NCCL, "The value of the environment variable %s is a directory", ncclLibPath);
WARN(MSCCLPP_NCCL, "MSCCLPP_NCCL_LIB_PATH points to a directory: ", ncclLibPath);
return dlopenError;
}

mscclppNcclDlHandle = dlopen(ncclLibPath, RTLD_LAZY | RTLD_NODELETE | RTLD_DEEPBIND);
if (!mscclppNcclDlHandle) {
WARN(MSCCLPP_NCCL, "Cannot open the shared library specified by MSCCLPP_NCCL_LIB_PATH: %s\n", dlerror());
WARN(MSCCLPP_NCCL, "Cannot open the shared library specified by MSCCLPP_NCCL_LIB_PATH: ", dlerror());
return dlopenError;
}
} else {
WARN(MSCCLPP_NCCL, "The value of MSCCLPP_NCCL_LIB_PATH is empty!\n");
WARN(MSCCLPP_NCCL, "The value of MSCCLPP_NCCL_LIB_PATH is empty!");
return dlopenError;
}

@@ -270,19 +270,18 @@ static std::shared_ptr<mscclpp::Algorithm> algoSelector(
return mscclpp::nccl::selectSingleNodeAllreduce(algoMap, request, config);
}

INFO(MSCCLPP_NCCL, "No suitable algorithm found for collective '%s', fallback to nccl/rccl",
request.collective.c_str());
INFO(MSCCLPP_NCCL, "No suitable algorithm found for collective '", request.collective, "', fallback to nccl/rccl");
return nullptr;
}

NCCL_API ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank) {
INFO(MSCCLPP_NCCL, "Initializing NCCL communicator for rank %d, world_size=%d", rank, nranks);
INFO(MSCCLPP_NCCL, "Initializing NCCL communicator for rank ", rank, ", world_size=", nranks);
if (comm == nullptr) {
WARN(MSCCLPP_NCCL, "comm is nullptr");
return ncclInvalidArgument;
}
if (nranks < 0 || rank < 0 || rank >= nranks) {
WARN(MSCCLPP_NCCL, "nranks is %d, rank is %d", nranks, rank);
WARN(MSCCLPP_NCCL, "nranks is ", nranks, ", rank is ", rank);
return ncclInvalidArgument;
}
std::shared_ptr<mscclpp::TcpBootstrap> bootstrap = std::make_shared<mscclpp::TcpBootstrap>(rank, nranks);
@@ -560,8 +559,8 @@ NCCL_API ncclResult_t ncclBroadcast(const void* sendbuff, void* recvbuff, size_t
return ncclInvalidArgument;
}

INFO(MSCCLPP_NCCL, "rank %d broadcast sendbuff %p recvbuff %p count %ld, dtype %d, comm: %p", rank, sendbuff,
recvbuff, count, datatype, comm);
INFO(MSCCLPP_NCCL, "rank ", rank, " broadcast sendbuff ", sendbuff, " recvbuff ", recvbuff, " count ", count,
", dtype ", datatype, ", comm: ", (void*)comm);

const char* fallbackList = mscclpp::env()->forceNcclFallbackOperation.c_str();
if (mscclppNcclDlopenSharedLib == true && mscclppNcclInFallbackList("broadcast", fallbackList)) {
@@ -619,8 +618,8 @@ NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t
}
// Declarating variables
int rank = comm->comm->bootstrap()->getRank();
INFO(MSCCLPP_NCCL, "rank %d allreduce sendbuff %p recvbuff %p count %ld, dtype %d comm is %p", rank, sendbuff,
recvbuff, count, datatype, comm);
INFO(MSCCLPP_NCCL, "rank ", rank, " allreduce sendbuff ", sendbuff, " recvbuff ", recvbuff, " count ", count,
", dtype ", datatype, " comm is ", (void*)comm);

const char* fallbackList = mscclpp::env()->forceNcclFallbackOperation.c_str();
if (mscclppNcclDlopenSharedLib && mscclppNcclInFallbackList("allreduce", fallbackList)) {
@@ -673,8 +672,8 @@ NCCL_API ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff, si
return ncclInvalidArgument;
}

INFO(MSCCLPP_NCCL, "ReduceScatter recvcount: %ld, datatype: %d, op: %d, messageSize: %ld", recvcount, datatype, op,
bytes * comm->comm->bootstrap()->getNranks());
INFO(MSCCLPP_NCCL, "ReduceScatter recvcount: ", recvcount, ", datatype: ", datatype, ", op: ", op,
", messageSize: ", bytes * comm->comm->bootstrap()->getNranks());

const char* fallbackList = mscclpp::env()->forceNcclFallbackOperation.c_str();
if (mscclppNcclDlopenSharedLib == true && mscclppNcclInFallbackList("reducescatter", fallbackList)) {
@@ -730,8 +729,8 @@ NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t

int rank = comm->comm->bootstrap()->getRank();
int nRank = comm->comm->bootstrap()->getNranks();
INFO(MSCCLPP_NCCL, "rank %d allgather sendbuff %p recvbuff %p count %ld, dtype %d, comm %p", rank, sendbuff, recvbuff,
sendcount, datatype, comm);
INFO(MSCCLPP_NCCL, "rank ", rank, " allgather sendbuff ", sendbuff, " recvbuff ", recvbuff, " count ", sendcount,
", dtype ", datatype, ", comm ", (void*)comm);

const char* fallbackList = mscclpp::env()->forceNcclFallbackOperation.c_str();
if (mscclppNcclDlopenSharedLib == true && mscclppNcclInFallbackList("allgather", fallbackList)) {
@@ -866,20 +865,20 @@ ncclResult_t ncclMemAlloc(void** ptr, size_t size) {
}
} catch (const mscclpp::Error& e) {
if (e.getErrorCode() == mscclpp::ErrorCode::InvalidUsage) {
WARN(MSCCLPP_NCCL, "Invalid usage: %s", e.what());
WARN(MSCCLPP_NCCL, "Invalid usage: ", e.what());
return ncclInvalidUsage;
} else {
WARN(MSCCLPP_NCCL, "Internal error: %s", e.what());
WARN(MSCCLPP_NCCL, "Internal error: ", e.what());
return ncclInternalError;
}
} catch (const mscclpp::CudaError& e) {
WARN(MSCCLPP_NCCL, "Cuda error: %s", e.what());
WARN(MSCCLPP_NCCL, "Cuda error: ", e.what());
return ncclUnhandledCudaError;
} catch (const mscclpp::CuError& e) {
WARN(MSCCLPP_NCCL, "Cu error: %s", e.what());
WARN(MSCCLPP_NCCL, "Cu error: ", e.what());
return ncclUnhandledCudaError;
} catch (const mscclpp::BaseError& e) {
WARN(MSCCLPP_NCCL, "Base error: %s", e.what());
WARN(MSCCLPP_NCCL, "Base error: ", e.what());
return ncclInternalError;
}
ptrMap[sharedPtr.get()] = sharedPtr;

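The logging changes above drop printf-style format strings ("%d", "%p", "%ld") in favor of passing the values directly, which suggests a logger that streams each argument. A small sketch of that style, as an illustration only and not mscclpp's actual INFO/WARN implementation:

#include <iostream>
#include <sstream>

// Variadic, stream-based logging: each argument is inserted with operator<<,
// so no format specifiers are needed and the call is type-safe.
template <typename... Args>
void logInfo(const char* subsys, Args&&... args) {
  std::ostringstream oss;
  (oss << ... << args);  // C++17 fold expression: stream every argument in order
  std::cout << "[" << subsys << "] " << oss.str() << "\n";
}

int main() {
  int rank = 3, nranks = 8;
  const void* sendbuff = &rank;
  // Same call shape as the updated INFO/WARN calls: values interleaved with string pieces.
  logInfo("MSCCLPP_NCCL", "Initializing NCCL communicator for rank ", rank, ", world_size=", nranks);
  logInfo("MSCCLPP_NCCL", "rank ", rank, " allreduce sendbuff ", sendbuff);
  return 0;
}
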
@@ -30,6 +30,12 @@ fi
if [ "${PLATFORM}" == "rocm" ]; then
export CXX=/opt/rocm/bin/hipcc
fi

PIP_CMAKE_ARGS_FILE="/root/mscclpp/pip_cmake_args.txt"
if [ -f "${PIP_CMAKE_ARGS_FILE}" ]; then
export CMAKE_ARGS="$(cat ${PIP_CMAKE_ARGS_FILE})"
echo "Using CMAKE_ARGS: ${CMAKE_ARGS}"
fi
cd /root/mscclpp && pip3 install .
pip3 install setuptools_scm
python3 -m setuptools_scm --force-write-version-files

@@ -8,18 +8,32 @@
#include "mp_unit_tests.hpp"
#include "utils_internal.hpp"

// Skip the current test if HostNoAtomic mode is not supported.
// On CUDA, HostNoAtomic requires GDRCopy for BAR1 signal forwarding.
// On ROCm, HostNoAtomic uses direct volatile writes and does not need GDRCopy.
// Skip the current test if the given IB mode will require GDRCopy on CUDA but it is unavailable.
// On CUDA, HostNoAtomic requires GDRCopy for BAR1 signal forwarding. When IbMode::Host or
// IbMode::Default is used and the IB device does not support RDMA atomics, the endpoint falls
// back to no-atomic mode, which also requires GDRCopy.
// On ROCm, no-atomic mode uses direct volatile writes and does not need GDRCopy.
#if defined(MSCCLPP_USE_CUDA)
#define REQUIRE_HOST_NO_ATOMIC \
do { \
if (!mscclpp::gdrEnabled()) { \
SKIP_TEST() << "HostNoAtomic requires GDRCopy: " << mscclpp::gdrStatusMessage(); \
} \
} while (0)
inline void requireGdrForIbMode(IbMode mode, mscclpp::Transport ibTransport) {
if (mscclpp::gdrEnabled()) return; // GDRCopy available — nothing to skip.
if (mode == IbMode::HostNoAtomic) {
SKIP_TEST() << "HostNoAtomic requires GDRCopy on CUDA: " << mscclpp::gdrStatusMessage();
}
// For Host/Default modes: check whether the IB device lacks RDMA atomics,
// which would cause an automatic fallback to no-atomic mode.
if (mode == IbMode::Host || mode == IbMode::Default) {
std::string devName = mscclpp::getIBDeviceName(ibTransport);
mscclpp::IbCtx ibCtx(devName);
if (!ibCtx.supportsRdmaAtomics()) {
SKIP_TEST() << "IB device " << devName
<< " lacks RDMA atomics; Host mode falls back to HostNoAtomic which requires GDRCopy: "
<< mscclpp::gdrStatusMessage();
}
}
}
#define REQUIRE_GDR_FOR_IB_MODE(mode) requireGdrForIbMode((mode), ibTransport)
#else
#define REQUIRE_HOST_NO_ATOMIC // No extra requirements on non-CUDA platforms.
#define REQUIRE_GDR_FOR_IB_MODE(mode) // No extra requirements on non-CUDA platforms.
#endif

void PortChannelOneToOneTest::SetUp() {
@@ -254,6 +268,7 @@ TEST(PortChannelOneToOneTest, PingPong) {

TEST(PortChannelOneToOneTest, PingPongIbHostMode) {
REQUIRE_IBVERBS;
REQUIRE_GDR_FOR_IB_MODE(IbMode::Host);
testPingPong(PingPongTestParams{
.useIPC = false, .useIB = true, .useEthernet = false, .waitWithPoll = false, .ibMode = IbMode::Host});
}
@@ -270,6 +285,7 @@ TEST(PortChannelOneToOneTest, PingPongWithPoll) {

TEST(PortChannelOneToOneTest, PingPongIbHostModeWithPoll) {
REQUIRE_IBVERBS;
REQUIRE_GDR_FOR_IB_MODE(IbMode::Host);
testPingPong(PingPongTestParams{
.useIPC = false, .useIB = true, .useEthernet = false, .waitWithPoll = true, .ibMode = IbMode::Host});
}
@@ -281,13 +297,14 @@ PERF_TEST(PortChannelOneToOneTest, PingPongPerf) {

PERF_TEST(PortChannelOneToOneTest, PingPongPerfIbHostMode) {
REQUIRE_IBVERBS;
REQUIRE_GDR_FOR_IB_MODE(IbMode::Host);
testPingPongPerf(PingPongTestParams{
.useIPC = false, .useIB = true, .useEthernet = false, .waitWithPoll = false, .ibMode = IbMode::Host});
}

PERF_TEST(PortChannelOneToOneTest, PingPongPerfIbHostNoAtomicMode) {
REQUIRE_IBVERBS;
REQUIRE_HOST_NO_ATOMIC;
REQUIRE_GDR_FOR_IB_MODE(IbMode::HostNoAtomic);
testPingPongPerf(PingPongTestParams{
.useIPC = false, .useIB = true, .useEthernet = false, .waitWithPoll = false, .ibMode = IbMode::HostNoAtomic});
}
@@ -469,6 +486,7 @@ TEST(PortChannelOneToOneTest, PacketPingPong) { testPacketPingPong(false, IbMode

TEST(PortChannelOneToOneTest, PacketPingPongIbHostMode) {
REQUIRE_IBVERBS;
REQUIRE_GDR_FOR_IB_MODE(IbMode::Host);
testPacketPingPong(true, IbMode::Host);
}

@@ -476,25 +494,26 @@ PERF_TEST(PortChannelOneToOneTest, PacketPingPongPerf) { testPacketPingPongPerf(

PERF_TEST(PortChannelOneToOneTest, PacketPingPongPerfIbHostMode) {
REQUIRE_IBVERBS;
REQUIRE_GDR_FOR_IB_MODE(IbMode::Host);
testPacketPingPongPerf(true, IbMode::Host);
}

PERF_TEST(PortChannelOneToOneTest, PacketPingPongPerfIbHostNoAtomicMode) {
REQUIRE_IBVERBS;
REQUIRE_HOST_NO_ATOMIC;
REQUIRE_GDR_FOR_IB_MODE(IbMode::HostNoAtomic);
testPacketPingPongPerf(true, IbMode::HostNoAtomic);
}

TEST(PortChannelOneToOneTest, PingPongIbHostNoAtomicMode) {
REQUIRE_IBVERBS;
REQUIRE_HOST_NO_ATOMIC;
REQUIRE_GDR_FOR_IB_MODE(IbMode::HostNoAtomic);
testPingPong(PingPongTestParams{
.useIPC = false, .useIB = true, .useEthernet = false, .waitWithPoll = false, .ibMode = IbMode::HostNoAtomic});
}

TEST(PortChannelOneToOneTest, PacketPingPongIbHostNoAtomicMode) {
REQUIRE_IBVERBS;
REQUIRE_HOST_NO_ATOMIC;
REQUIRE_GDR_FOR_IB_MODE(IbMode::HostNoAtomic);
testPacketPingPong(true, IbMode::HostNoAtomic);
}

@@ -570,13 +589,14 @@ PERF_TEST(PortChannelOneToOneTest, Bandwidth) {

PERF_TEST(PortChannelOneToOneTest, BandwidthIbHostMode) {
REQUIRE_IBVERBS;
REQUIRE_GDR_FOR_IB_MODE(IbMode::Host);
testBandwidth(PingPongTestParams{
.useIPC = false, .useIB = true, .useEthernet = false, .waitWithPoll = false, .ibMode = IbMode::Host});
}

PERF_TEST(PortChannelOneToOneTest, BandwidthIbHostNoAtomicMode) {
REQUIRE_IBVERBS;
REQUIRE_HOST_NO_ATOMIC;
REQUIRE_GDR_FOR_IB_MODE(IbMode::HostNoAtomic);
testBandwidth(PingPongTestParams{
.useIPC = false, .useIB = true, .useEthernet = false, .waitWithPoll = false, .ibMode = IbMode::HostNoAtomic});
}

@@ -14,25 +14,25 @@ def parse_npkit_event_header(npkit_event_header_path):
"NOP",
"BARRIER",
"PUT",
"PUT_PACKET",
"READ_PUT_PACKET",
"PUT_PACKETS",
"READ_PUT_PACKETS",
"PUT_WITH_SIGNAL",
"PUT_WITH_SIGNAL_AND_FLUSH",
"GET",
"COPY",
"COPY_PACKET",
"TRANSFORM_TO_PACKET",
"COPY_PACKETS",
"UNPACK_PACKETS",
"SIGNAL",
"WAIT",
"FLUSH",
"REDUCE",
"REDUCE_PACKET",
"REDUCE_PACKETS",
"REDUCE_COPY_PACKETS",
"REDUCE_SEND",
"REDUCE_SEND_PACKET",
"REDUCE_SEND_PACKETS",
"REDUCE_COPY_SEND_PACKETS",
"READ_REDUCE_COPY",
"READ_REDUCE_COPY_SEND",
"READ_REDUCE",
"READ_REDUCE_SEND",
"MULTI_LOAD_REDUCE_STORE",
"RELAXED_SIGNAL",
"RELAXED_WAIT",