From fd76507e9a6b57a6de7ad832deb4a15dc4d60195 Mon Sep 17 00:00:00 2001 From: Ekow Wellington <34079588+ekwhoa@users.noreply.github.com> Date: Tue, 31 Mar 2026 14:27:33 -0500 Subject: [PATCH 01/21] Install default plans under MSCCLPP_CACHE_DIR/default (#769) ### Summary Update the installer to place bundled default execution plans under `/default`, which is where the runtime already looks for bundled plans. ### Background The C++ runtime treats `MSCCLPP_CACHE_DIR` as the cache *root* and loads bundled default plans from `/default`. When `MSCCLPP_CACHE_DIR` was set, the installer instead wrote bundled plans directly into the cache root, causing the runtime to miss them. This surfaced while running benchmarking tests with a non-default `MSCCLPP_CACHE_DIR`, where the bundled plans were not being discovered. ### Change This PR updates the installer to always install bundled default plans into `/default`, preserving the existing runtime contract. ### Scope - Installer-only change - No runtime behavior changes ### Validation Manual inspection of the updated install path. Successful build --------- Co-authored-by: Ekow Wellington --- docs/dsl/quick_start.md | 4 ++++ docs/dsl/results.md | 3 +++ python/mscclpp/__main__.py | 2 +- 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/docs/dsl/quick_start.md b/docs/dsl/quick_start.md index 6c32ec32..afccd48e 100644 --- a/docs/dsl/quick_start.md +++ b/docs/dsl/quick_start.md @@ -12,6 +12,10 @@ After finishing the installation in the quick start section, you can add the fol python3 -m mscclpp --install ``` +This installs bundled default execution plans into `~/.cache/mscclpp/default` by default. +If `MSCCLPP_CACHE_DIR` is set, bundled default plans are installed into `MSCCLPP_CACHE_DIR/default`. +`MSCCLPP_CACHE_DIR` specifies the cache root directory, so it should be set without `default` in the path. + ## Your First Algorithm: AllGather Let's walk through a simple AllGather algorithm to understand the DSL basics. This example demonstrates the key concepts without diving into all the advanced features. diff --git a/docs/dsl/results.md b/docs/dsl/results.md index 99f19476..a1adad2a 100644 --- a/docs/dsl/results.md +++ b/docs/dsl/results.md @@ -59,6 +59,9 @@ After installation, the generated JSON execution plan can be found at: ~/.cache/mscclpp/default/ ``` +If `MSCCLPP_CACHE_DIR` is set, bundled default plans are installed under `MSCCLPP_CACHE_DIR/default/`. +`MSCCLPP_CACHE_DIR` specifies the cache root directory, so it should be set without `default` in the path. 
+ **Performance Results:** The figure below shows the performance characteristics for small message sizes in a two-node configuration: diff --git a/python/mscclpp/__main__.py b/python/mscclpp/__main__.py index d57cb362..6a6f5f28 100644 --- a/python/mscclpp/__main__.py +++ b/python/mscclpp/__main__.py @@ -57,7 +57,7 @@ default_algo_configs = [ def create_default_plans(): - plan_dir = os.environ.get("MSCCLPP_CACHE_DIR", Path.home() / ".cache/mscclpp/default") + plan_dir = os.path.join(os.environ.get("MSCCLPP_CACHE_DIR", Path.home() / ".cache/mscclpp"), "default") plan_path = Path(plan_dir) if plan_path.exists(): shutil.rmtree(plan_path) From 4f3638b60db4640eb5f0cd4c1c92e05a72227474 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Tue, 31 Mar 2026 15:34:43 -0700 Subject: [PATCH 02/21] Use PTX red for D2D semaphore signal (#768) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary - Replace the two-step `signal()` implementation (`incOutbound()` + `atomicStore()`) with a single fire-and-forget PTX `red.release.sys.global.add.u64` instruction - This eliminates one local atomic fetch-add and replaces a remote store with a remote atomic add that has no return value — more efficient on both NVIDIA (PTX `red`) and AMD (compiler optimizes `(void)fetch_add` to fire-and-forget `flat_atomic_add_x2`) - Add a C++ perf test (`PERF_TEST`) in `mp_unit` for signal+wait ping-pong latency ### Performance (H100, 2 ranks, signal+wait round-trip) ``` SemaphorePerfTest.SignalPingPong: Store-based (old): 2.595 us/iter Red-based (new): 2.345 us/iter Speedup: 1.11x ``` ## Test plan - [x] Builds successfully (`make mp_unit_tests`) - [x] `mpirun -np 2 ./build/bin/mp_unit_tests --filter "SemaphorePerfTest"` — 1.11x speedup 🤖 Generated with [Claude Code](https://claude.com/claude-code) --------- Co-authored-by: Claude Opus 4.6 --- include/mscclpp/semaphore.hpp | 1 - include/mscclpp/semaphore_device.hpp | 34 ++++--------- python/csrc/semaphore_py.cpp | 1 - src/core/semaphore.cc | 5 +- test/mp_unit/CMakeLists.txt | 1 + test/mp_unit/mp_unit_tests.hpp | 6 +++ test/mp_unit/semaphore_perf_tests.cu | 73 ++++++++++++++++++++++++++++ 7 files changed, 91 insertions(+), 30 deletions(-) create mode 100644 test/mp_unit/semaphore_perf_tests.cu diff --git a/include/mscclpp/semaphore.hpp b/include/mscclpp/semaphore.hpp index 27f9aefa..85787c95 100644 --- a/include/mscclpp/semaphore.hpp +++ b/include/mscclpp/semaphore.hpp @@ -82,7 +82,6 @@ class MemoryDevice2DeviceSemaphore { private: Semaphore semaphore_; detail::UniqueGpuPtr expectedInboundToken_; - detail::UniqueGpuPtr outboundToken_; public: /// Constructor. diff --git a/include/mscclpp/semaphore_device.hpp b/include/mscclpp/semaphore_device.hpp index f1b01e89..a790a6e1 100644 --- a/include/mscclpp/semaphore_device.hpp +++ b/include/mscclpp/semaphore_device.hpp @@ -82,19 +82,20 @@ struct MemoryDevice2DeviceSemaphoreDeviceHandle { /// Signal remote device, ensures prior memory ops complete. MSCCLPP_DEVICE_INLINE void signal() { - auto outbound = incOutbound(); -#if defined(MSCCLPP_DEVICE_CUDA) && (__CUDA_ARCH__ == 800) - // Using memoryOrderSeqCst is faster for A100. 
- atomicStore(remoteInboundToken, outbound, memoryOrderSeqCst); -#else - atomicStore(remoteInboundToken, outbound, memoryOrderRelease); +#if defined(MSCCLPP_DEVICE_CUDA) + asm volatile("red.release.sys.global.add.u64 [%0], %1;" ::"l"(remoteInboundToken), "l"((uint64_t)1) : "memory"); +#elif defined(MSCCLPP_DEVICE_HIP) + (void)atomicFetchAdd(remoteInboundToken, (uint64_t)1, memoryOrderRelease); #endif } /// Relaxed signal; no memory completion guarantee. Use it only for synchronizing execution, not data. MSCCLPP_DEVICE_INLINE void relaxedSignal() { - auto outbound = incOutbound(); - atomicStore(remoteInboundToken, outbound, memoryOrderRelaxed); +#if defined(MSCCLPP_DEVICE_CUDA) + asm volatile("red.relaxed.sys.global.add.u64 [%0], %1;" ::"l"(remoteInboundToken), "l"((uint64_t)1) : "memory"); +#elif defined(MSCCLPP_DEVICE_HIP) + (void)atomicFetchAdd(remoteInboundToken, (uint64_t)1, memoryOrderRelaxed); +#endif } /// Thread-safe read of expected inbound value. @@ -121,27 +122,12 @@ struct MemoryDevice2DeviceSemaphoreDeviceHandle { return atomicLoad(inboundToken, memoryOrderRelaxed); } - /// Thread-safe read of outbound value. - /// @return The outbound value. - MSCCLPP_DEVICE_INLINE uint64_t loadOutbound() { - return atomicLoad(outboundToken, memoryOrderRelaxed); - } - - /// Thread-safe increment of outbound value. - /// @return The incremented outbound value. - MSCCLPP_DEVICE_INLINE uint64_t incOutbound() { - return atomicFetchAdd(outboundToken, 1, memoryOrderRelaxed) + 1; - } #endif // defined(MSCCLPP_DEVICE_COMPILE) /// A local memory space where the remote device will write its semaphore value and the local device will read it. uint64_t* inboundToken; - /// A local memory space where the local device stores the semaphore value to be written to the remote device. - uint64_t* outboundToken; - - /// A remote memory space where the local device writes its outboundToken on. This is inboundToken of the - /// remote device. + /// A remote memory space where the local device atomically increments. This is inboundToken of the remote device. uint64_t* remoteInboundToken; /// A local memory space where the local device stores the expected value of the inboundToken to wait for. 
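A minimal device-side sketch (not part of this patch) of the PTX distinction the summary relies on: `atom` must return the prior value to the issuing SM, while `red` is fire-and-forget. The `red` form below is the one the new `signal()` emits; the `atom` form and both function names are only for contrast.

```cuda
#include <cstdint>

// Value-returning remote add: `atom` has a destination register, so the prior
// value must be shipped back to the issuing SM even if the caller discards it.
__device__ uint64_t valueReturningAdd(uint64_t* remote, uint64_t v) {
  uint64_t old;
  asm volatile("atom.release.sys.global.add.u64 %0, [%1], %2;"
               : "=l"(old)
               : "l"(remote), "l"(v)
               : "memory");
  return old;
}

// Fire-and-forget remote add: `red` has no destination operand, so the add is
// issued toward the remote location and the thread continues immediately.
__device__ void fireAndForgetAdd(uint64_t* remote) {
  asm volatile("red.release.sys.global.add.u64 [%0], %1;" ::"l"(remote), "l"((uint64_t)1) : "memory");
}
```

On AMD the same effect comes from the compiler lowering a discarded `fetch_add` to a fire-and-forget `flat_atomic_add_x2`, as noted in the summary.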
diff --git a/python/csrc/semaphore_py.cpp b/python/csrc/semaphore_py.cpp index 36d559f2..17c06a7d 100644 --- a/python/csrc/semaphore_py.cpp +++ b/python/csrc/semaphore_py.cpp @@ -43,7 +43,6 @@ void register_semaphore(nb::module_& m) { nb::class_(memoryDevice2DeviceSemaphore, "DeviceHandle") .def(nb::init<>()) .def_rw("inbound_token", &MemoryDevice2DeviceSemaphore::DeviceHandle::inboundToken) - .def_rw("outbound_token", &MemoryDevice2DeviceSemaphore::DeviceHandle::outboundToken) .def_rw("remote_inbound_token", &MemoryDevice2DeviceSemaphore::DeviceHandle::remoteInboundToken) .def_rw("expected_inbound_token", &MemoryDevice2DeviceSemaphore::DeviceHandle::expectedInboundToken) .def_prop_ro("raw", [](const MemoryDevice2DeviceSemaphore::DeviceHandle& self) -> nb::bytes { diff --git a/src/core/semaphore.cc b/src/core/semaphore.cc index c6eb1e23..bea43327 100644 --- a/src/core/semaphore.cc +++ b/src/core/semaphore.cc @@ -183,9 +183,7 @@ MSCCLPP_API_CPP void Host2HostSemaphore::wait(int64_t maxSpinCount) { } MSCCLPP_API_CPP MemoryDevice2DeviceSemaphore::MemoryDevice2DeviceSemaphore(const Semaphore& semaphore) - : semaphore_(semaphore), - expectedInboundToken_(detail::gpuCallocUnique()), - outboundToken_(detail::gpuCallocUnique()) { + : semaphore_(semaphore), expectedInboundToken_(detail::gpuCallocUnique()) { if (connection().localDevice().type != DeviceType::GPU) { throw Error("Local endpoint device type of MemoryDevice2DeviceSemaphore should be GPU", ErrorCode::InvalidUsage); } @@ -202,7 +200,6 @@ MSCCLPP_API_CPP MemoryDevice2DeviceSemaphore::DeviceHandle MemoryDevice2DeviceSe device.remoteInboundToken = reinterpret_cast(semaphore_.remoteMemory().data()); device.inboundToken = reinterpret_cast(semaphore_.localMemory().data()); device.expectedInboundToken = expectedInboundToken_.get(); - device.outboundToken = outboundToken_.get(); return device; }; diff --git a/test/mp_unit/CMakeLists.txt b/test/mp_unit/CMakeLists.txt index b99bb09d..d4004e8e 100644 --- a/test/mp_unit/CMakeLists.txt +++ b/test/mp_unit/CMakeLists.txt @@ -8,6 +8,7 @@ target_sources(mp_unit_tests PRIVATE communicator_tests.cu port_channel_tests.cu memory_channel_tests.cu + semaphore_perf_tests.cu switch_channel_tests.cu executor_tests.cc ) diff --git a/test/mp_unit/mp_unit_tests.hpp b/test/mp_unit/mp_unit_tests.hpp index 03e4cbde..5f95d660 100644 --- a/test/mp_unit/mp_unit_tests.hpp +++ b/test/mp_unit/mp_unit_tests.hpp @@ -176,6 +176,12 @@ class MemoryChannelOneToOneTest : public CommunicatorTestBase { std::unordered_map> memorySemaphores; }; +class SemaphorePerfTest : public CommunicatorTestBase { + protected: + void SetUp() override; + void TearDown() override; +}; + class SwitchChannelTest : public CommunicatorTestBase { protected: void SetUp() override; diff --git a/test/mp_unit/semaphore_perf_tests.cu b/test/mp_unit/semaphore_perf_tests.cu new file mode 100644 index 00000000..92560539 --- /dev/null +++ b/test/mp_unit/semaphore_perf_tests.cu @@ -0,0 +1,73 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +#include +#include + +#include "mp_unit_tests.hpp" + +void SemaphorePerfTest::SetUp() { + // Need at least two ranks within a node + if (gEnv->nRanksPerNode < 2) { + SKIP_TEST(); + } + setNumRanksToUse(2); + CommunicatorTestBase::SetUp(); +} + +void SemaphorePerfTest::TearDown() { CommunicatorTestBase::TearDown(); } + +__constant__ mscclpp::MemoryDevice2DeviceSemaphoreDeviceHandle gSemaphorePerfTestHandle; + +__global__ void kernelSemaphorePingPong(int rank, int nIters) { + mscclpp::MemoryDevice2DeviceSemaphoreDeviceHandle& sem = gSemaphorePerfTestHandle; + + // Warmup + for (int i = 0; i < 10; i++) { + if ((rank ^ (i & 1)) == 0) { + sem.signal(); + } else { + sem.wait(); + } + } + + // Timed iterations — alternating signal/wait like the memory channel ping-pong + for (int i = 0; i < nIters; i++) { + if ((rank ^ (i & 1)) == 0) { + sem.signal(); + } else { + sem.wait(); + } + } +} + +PERF_TEST(SemaphorePerfTest, SignalPingPong) { + if (gEnv->rank >= numRanksToUse) return; + + connectMesh(/*useIpc=*/true, /*useIb=*/false, /*useEthernet=*/false); + + int peerRank = (gEnv->rank == 0) ? 1 : 0; + auto d2dSemaphore = std::make_shared(*communicator, connections[peerRank]); + + auto devHandle = d2dSemaphore->deviceHandle(); + MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gSemaphorePerfTestHandle, &devHandle, sizeof(devHandle))); + + const int nIters = 1000; + const std::string testName = ::mscclpp::test::currentTestName(); + + // Warmup run + kernelSemaphorePingPong<<<1, 1>>>(gEnv->rank, nIters); + MSCCLPP_CUDATHROW(cudaDeviceSynchronize()); + + communicator->bootstrap()->barrier(); + + // Timed run + mscclpp::Timer timer; + kernelSemaphorePingPong<<<1, 1>>>(gEnv->rank, nIters); + MSCCLPP_CUDATHROW(cudaDeviceSynchronize()); + communicator->bootstrap()->barrier(); + + if (gEnv->rank == 0) { + std::cout << testName << ": " << std::setprecision(4) << (float)timer.elapsed() / (float)nIters << " us/iter\n"; + } +} From d2f7056cf4d1956cb452ee475b331f8e19e1d886 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Tue, 31 Mar 2026 22:30:35 -0700 Subject: [PATCH 03/21] Add unit testing framework readme (#766) --- test/README.md | 130 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 130 insertions(+) create mode 100644 test/README.md diff --git a/test/README.md b/test/README.md new file mode 100644 index 00000000..a69b66ad --- /dev/null +++ b/test/README.md @@ -0,0 +1,130 @@ +# MSCCL++ C++ Test Framework + +A lightweight, GTest-like test framework with MPI support for testing MSCCL++ C++ APIs. Defined in `framework.hpp` / `framework.cc`. + +## Adding a New Test (Step-by-Step) + +### Single-process test (unit/) + +1. **Create the test file** `test/unit/my_feature_tests.cc` (or `.cu` for CUDA): + + ```cpp + #include "../framework.hpp" + #include + + TEST(MyFeatureTest, BasicUsage) { + EXPECT_EQ(myFunction(), 42); + } + ``` + +2. **Register it in CMake** — add the filename to `test/unit/CMakeLists.txt`: + + ```cmake + target_sources(unit_tests PRIVATE + ... + my_feature_tests.cc # <-- add here + ) + ``` + +3. **Build and run**: + + ```bash + cmake --build build -j + ./build/test/unit_tests --filter=MyFeatureTest + ``` + +### Multi-process test (mp_unit/) + +1. **Create the test file** `test/mp_unit/my_feature_tests.cc` (or `.cu`): + + ```cpp + #include "mp_unit_tests.hpp" + + TEST(MyFeatureTest, MultiRank) { + int rank = gEnv->rank; + EXPECT_GE(rank, 0); + } + ``` + + Use fixtures from `mp_unit_tests.hpp` (e.g., `CommunicatorTest`) if you need pre-established connections. + +2. 
**Register it in CMake** — add the filename to `test/mp_unit/CMakeLists.txt`: + + ```cmake + target_sources(mp_unit_tests PRIVATE + ... + my_feature_tests.cc # <-- add here + ) + ``` + +3. **Build and run**: + + ```bash + cmake --build build -j + mpirun -np 2 ./build/test/mp_unit_tests --filter=MyFeatureTest + ``` + +### Notes + +- No separate test registration step is needed — `TEST()` auto-registers via static initialization. +- The `test_framework` static library is built from `framework.cc` in the top-level `test/CMakeLists.txt` and linked into both `unit_tests` and `mp_unit_tests`. You do not need to modify it. +- Use `.cu` extension for files that contain CUDA kernel code; use `.cc` for host-only tests. +- Each test binary needs a `main()` that calls `RUN_ALL_TESTS()`. See `unit/unit_tests_main.cc` (single-process) and `mp_unit/mp_unit_tests.cc` (multi-process with `Environment` setup). +- Additional run options: `--filter=-Pattern` (exclude), `--exclude-perf-tests` (skip `PERF_TEST`s). + +## Macros + +| Macro | Behavior | +|---|---| +| `TEST(Suite, Name)` | Register a test. If `Suite` is a defined class, it's used as a fixture. | +| `PERF_TEST(Suite, Name)` | Same as `TEST` but marked as perf (skippable via `--exclude-perf-tests`). | +| `EXPECT_*` | Non-fatal assertions: `EXPECT_TRUE`, `EXPECT_FALSE`, `EXPECT_EQ`, `EXPECT_NE`, `EXPECT_LT`, `EXPECT_LE`, `EXPECT_GT`, `EXPECT_GE` | +| `ASSERT_*` | Fatal assertions (abort test on failure): same variants as `EXPECT_*`, plus `ASSERT_NO_THROW` | +| `FAIL()` | Fail immediately. Supports streaming: `FAIL() << "reason";` | +| `SKIP_TEST()` | Skip the current test. Supports streaming: `SKIP_TEST() << "reason";` | +| `CUDA_CHECK(call)` | Check a CUDA API return code, throw on error. | + +## Fixtures + +Define a class inheriting from `mscclpp::test::TestCase` with `SetUp()` / `TearDown()`, then use the class name as the suite name: + +```cpp +class MyFixture : public mscclpp::test::TestCase { + public: + void SetUp() override { /* per-test setup */ } + void TearDown() override { /* per-test cleanup */ } + protected: + int sharedState_ = 0; +}; + +TEST(MyFixture, SomeTest) { + sharedState_ = 42; + EXPECT_EQ(sharedState_, 42); +} +``` + +See `mp_unit/mp_unit_tests.hpp` (`BootstrapTest`, `CommunicatorTest`, etc.) for real fixture examples. + +## Global Environments + +Register an `Environment` subclass for one-time global setup/teardown (e.g., MPI bootstrap): + +```cpp +class MyEnv : public mscclpp::test::Environment { + public: + void SetUp() override { /* global init */ } + void TearDown() override { /* global cleanup */ } +}; + +// In main(), before RUN_ALL_TESTS(): +mscclpp::test::TestRegistry::instance().addEnvironment(new MyEnv()); +``` + +See `mp_unit/mp_unit_tests.cc` for the `MultiProcessTestEnv` example. 
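Combining the macros above with the timer utilities listed in the next section, a perf test can report its own per-iteration latency. A minimal sketch, assuming `Timer`, `isMainRank()`, and `currentTestName()` behave as described below:

```cpp
#include <iostream>

#include "../framework.hpp"

PERF_TEST(MyFeatureTest, TightLoopLatency) {
  constexpr int nIters = 1000;
  mscclpp::test::utils::Timer timer;
  timer.start();
  for (int i = 0; i < nIters; ++i) {
    // ... operation under measurement ...
  }
  timer.stop();
  if (mscclpp::test::utils::isMainRank()) {
    std::cout << mscclpp::test::currentTestName() << ": "
              << timer.elapsedMilliseconds() / static_cast<double>(nIters) << " ms/iter" << std::endl;
  }
}
```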
+ +## Utilities + +- `mscclpp::test::utils::isMainRank()` — true on MPI rank 0 +- `mscclpp::test::utils::getMPIRank()` / `getMPISize()` +- `mscclpp::test::utils::Timer` — high-resolution timer with `start()`, `stop()`, `elapsedMilliseconds()` +- `mscclpp::test::currentTestName()` — returns `"Suite.Name"` for the running test \ No newline at end of file From be9126ca1b36c4817de622a0aebd87e5382b9a6b Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Wed, 1 Apr 2026 16:25:19 -0700 Subject: [PATCH 04/21] Fix run-remote.sh to support multi-command scripts (#770) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary - Fix `run-remote.sh` to correctly execute multi-command scripts (e.g., multiple `mpirun` calls) - The old approach piped decoded script through `base64 -d | bash`, which feeds the script via bash's **stdin**. When `mpirun` (or its child processes) runs, it can consume the remaining stdin, causing bash to never see subsequent commands — only the first command would execute. - The fix decodes the script to a **temp file** and runs `bash -euxo pipefail "$TMP"` instead, so bash reads commands from the file and `mpirun` consuming stdin has no effect. - Applied to both the docker path (pssh + docker exec) and the non-docker path (pssh only). 🤖 Generated with [Claude Code](https://claude.com/claude-code) --- test/deploy/run-remote.sh | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/test/deploy/run-remote.sh b/test/deploy/run-remote.sh index b646ea92..2468243e 100755 --- a/test/deploy/run-remote.sh +++ b/test/deploy/run-remote.sh @@ -97,11 +97,14 @@ if $USE_DOCKER; then INNER+=" cd /root/mscclpp;" INNER+=" export LD_LIBRARY_PATH=/root/mscclpp/build/lib:\\\$LD_LIBRARY_PATH;" INNER+=" CMD_B64='${CMD_B64}';" - INNER+=" printf '%s' \\\"\\\$CMD_B64\\\" | base64 -d | bash -euxo pipefail" + INNER+=" TMP=\\\$(mktemp);" + INNER+=" printf '%s' \\\"\\\$CMD_B64\\\" | base64 -d > \\\"\\\$TMP\\\";" + INNER+=" bash -euxo pipefail \\\"\\\$TMP\\\";" + INNER+=" rm -f \\\"\\\$TMP\\\"" parallel-ssh -i "${PSSH_COMMON[@]}" \ "sudo docker exec mscclpp-test bash -c \"${INNER}\"" else parallel-ssh -i "${PSSH_COMMON[@]}" \ - "set -euxo pipefail; CMD_B64='${CMD_B64}'; printf '%s' \"\$CMD_B64\" | base64 -d | bash -euxo pipefail" + "set -euxo pipefail; CMD_B64='${CMD_B64}'; TMP=\$(mktemp); printf '%s' \"\$CMD_B64\" | base64 -d > \"\$TMP\"; bash -euxo pipefail \"\$TMP\"; rm -f \"\$TMP\"" fi From fa95e82e18c5f963b059aefe20939d5ca8a63df2 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Tue, 7 Apr 2026 08:41:51 -0700 Subject: [PATCH 05/21] Fix CI/CD pipeline issues (#773) This pull request updates the deployment pipeline to allow custom CMake arguments to be passed to the pip install process on remote VMs. --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .azure-pipelines/templates/deploy.yml | 24 ++++++++++++++++++++++-- .azure-pipelines/templates/ut-npkit.yml | 10 +++++----- test/deploy/setup.sh | 6 ++++++ tools/npkit/npkit_trace_generator.py | 16 ++++++++-------- 4 files changed, 41 insertions(+), 15 deletions(-) diff --git a/.azure-pipelines/templates/deploy.yml b/.azure-pipelines/templates/deploy.yml index fc116acf..2f642f1d 100644 --- a/.azure-pipelines/templates/deploy.yml +++ b/.azure-pipelines/templates/deploy.yml @@ -94,7 +94,27 @@ steps: du -sh build/bin/* 2>/dev/null || true workingDirectory: '$(System.DefaultWorkingDirectory)' -# 2. Download SSH key + install packages + start VMSS +# 2. 
Write CMake args for pip install on remote VMs +- task: Bash@3 + name: WritePipCmakeArgs + displayName: Write pip CMake args + inputs: + targetType: 'inline' + script: | + set -e + PIP_CMAKE_ARGS="" + if [ -n "${{ parameters.gpuArch }}" ]; then + PIP_CMAKE_ARGS="-DMSCCLPP_GPU_ARCHS=${{ parameters.gpuArch }}" + fi + CMAKE_EXTRA_ARGS='${{ parameters.cmakeArgs }}' + if [ -n "${CMAKE_EXTRA_ARGS}" ]; then + PIP_CMAKE_ARGS="${PIP_CMAKE_ARGS} ${CMAKE_EXTRA_ARGS}" + fi + echo "${PIP_CMAKE_ARGS}" > pip_cmake_args.txt + echo "pip CMake args: $(cat pip_cmake_args.txt)" + workingDirectory: '$(System.DefaultWorkingDirectory)' + +# 3. Download SSH key + install packages + start VMSS - task: DownloadSecureFile@1 name: SshKeyFile displayName: Download key file @@ -120,7 +140,7 @@ steps: inlineScript: | az vmss start --name ${{ parameters.vmssName }} --resource-group ${{ parameters.resourceGroup }} -# 3. Deploy test environment +# 4. Deploy test environment - task: Bash@3 name: DeployTestEnv displayName: Deploy Test Env diff --git a/.azure-pipelines/templates/ut-npkit.yml b/.azure-pipelines/templates/ut-npkit.yml index e53b5cf5..1bd89caf 100644 --- a/.azure-pipelines/templates/ut-npkit.yml +++ b/.azure-pipelines/templates/ut-npkit.yml @@ -28,7 +28,7 @@ steps: grep -q NPKIT_EVENT_EXECUTOR_INIT_ENTRY ./npkit_output/npkit_event_trace.json grep -q NPKIT_EVENT_EXECUTOR_SIGNAL_ENTRY ./npkit_output/npkit_event_trace.json grep -q NPKIT_EVENT_EXECUTOR_WAIT_ENTRY ./npkit_output/npkit_event_trace.json - grep -q NPKIT_EVENT_EXECUTOR_READ_REDUCE_COPY_SEND_ENTRY ./npkit_output/npkit_event_trace.json + grep -q NPKIT_EVENT_EXECUTOR_READ_REDUCE_SEND_ENTRY ./npkit_output/npkit_event_trace.json - template: run-remote-task.yml parameters: @@ -42,14 +42,14 @@ steps: grep -q NPKIT_EVENT_EXECUTOR_INIT_ENTRY ./npkit_output/npkit_event_trace.json grep -q NPKIT_EVENT_EXECUTOR_SIGNAL_ENTRY ./npkit_output/npkit_event_trace.json grep -q NPKIT_EVENT_EXECUTOR_WAIT_ENTRY ./npkit_output/npkit_event_trace.json - grep -q NPKIT_EVENT_EXECUTOR_READ_REDUCE_COPY_SEND_ENTRY ./npkit_output/npkit_event_trace.json + grep -q NPKIT_EVENT_EXECUTOR_READ_REDUCE_SEND_ENTRY ./npkit_output/npkit_event_trace.json rm -rf ./npkit_dump && mkdir ./npkit_dump && rm -rf ./npkit_output && mkdir ./npkit_output mpirun --allow-run-as-root -tag-output -x MSCCLPP_HOME=/root/mscclpp -np 8 python3 -m pytest ./python/test/test_mscclpp.py -x -k 'test_executor[allreduce_packet.json' python3 ./tools/npkit/npkit_trace_generator.py --npkit_dump_dir=./npkit_dump --npkit_event_header_path=./include/mscclpp/npkit/npkit_event.hpp --output_dir=./npkit_output grep -q NPKIT_EVENT_EXECUTOR_INIT_ENTRY ./npkit_output/npkit_event_trace.json - grep -q NPKIT_EVENT_EXECUTOR_COPY_PACKET_ENTRY ./npkit_output/npkit_event_trace.json - grep -q NPKIT_EVENT_EXECUTOR_PUT_PACKET_ENTRY ./npkit_output/npkit_event_trace.json - grep -q NPKIT_EVENT_EXECUTOR_REDUCE_SEND_PACKET_ENTRY ./npkit_output/npkit_event_trace.json + grep -q NPKIT_EVENT_EXECUTOR_PUT_PACKETS_ENTRY ./npkit_output/npkit_event_trace.json + grep -q NPKIT_EVENT_EXECUTOR_REDUCE_SEND_PACKETS_ENTRY ./npkit_output/npkit_event_trace.json + grep -q NPKIT_EVENT_EXECUTOR_UNPACK_PACKETS_ENTRY ./npkit_output/npkit_event_trace.json - template: stop.yml parameters: diff --git a/test/deploy/setup.sh b/test/deploy/setup.sh index 80cd10b1..d4996cc2 100644 --- a/test/deploy/setup.sh +++ b/test/deploy/setup.sh @@ -30,6 +30,12 @@ fi if [ "${PLATFORM}" == "rocm" ]; then export CXX=/opt/rocm/bin/hipcc fi + 
+PIP_CMAKE_ARGS_FILE="/root/mscclpp/pip_cmake_args.txt" +if [ -f "${PIP_CMAKE_ARGS_FILE}" ]; then + export CMAKE_ARGS="$(cat ${PIP_CMAKE_ARGS_FILE})" + echo "Using CMAKE_ARGS: ${CMAKE_ARGS}" +fi cd /root/mscclpp && pip3 install . pip3 install setuptools_scm python3 -m setuptools_scm --force-write-version-files diff --git a/tools/npkit/npkit_trace_generator.py b/tools/npkit/npkit_trace_generator.py index c5ed6191..294516e6 100644 --- a/tools/npkit/npkit_trace_generator.py +++ b/tools/npkit/npkit_trace_generator.py @@ -14,25 +14,25 @@ def parse_npkit_event_header(npkit_event_header_path): "NOP", "BARRIER", "PUT", - "PUT_PACKET", - "READ_PUT_PACKET", + "PUT_PACKETS", + "READ_PUT_PACKETS", "PUT_WITH_SIGNAL", "PUT_WITH_SIGNAL_AND_FLUSH", "GET", "COPY", - "COPY_PACKET", - "TRANSFORM_TO_PACKET", + "COPY_PACKETS", + "UNPACK_PACKETS", "SIGNAL", "WAIT", "FLUSH", "REDUCE", - "REDUCE_PACKET", + "REDUCE_PACKETS", "REDUCE_COPY_PACKETS", "REDUCE_SEND", - "REDUCE_SEND_PACKET", + "REDUCE_SEND_PACKETS", "REDUCE_COPY_SEND_PACKETS", - "READ_REDUCE_COPY", - "READ_REDUCE_COPY_SEND", + "READ_REDUCE", + "READ_REDUCE_SEND", "MULTI_LOAD_REDUCE_STORE", "RELAXED_SIGNAL", "RELAXED_WAIT", From 96a72bbd3e71df14f8afca6b4daaf907bbad8e8e Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Tue, 7 Apr 2026 13:37:02 -0700 Subject: [PATCH 06/21] Support E4M3B15 datatype (#765) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary - **Add `fp8_e4m3b15` datatype**: A software-defined FP8 type with 4 exponent bits, 3 mantissa bits, and bias=15 (max finite value: 0.9375). Implemented entirely in software with no HW dependency, using Triton-style bit manipulation through fp16 as intermediate for efficient conversion. - **Add mixed-precision accumulation for allreduce**: All allreduce algorithm variants (packet, NVLS packet, fullmesh, RSAG zero-copy, and others) now support a configurable `accumDtype` parameter, enabling FP8 inputs to be reduced in float16 or float32 for higher accuracy. - **Propagate `accumDtype` through the full API**: The new parameter is threaded from `Algorithm::execute()` → `NativeAlgorithm` → `KernelFunc` → dispatch → CUDA kernels, with `DataType::AUTO` as the default (resolves to input dtype at runtime). - **Add FP8 accumulation correctness tests**: New `test_fp8_accum.py` validates that higher-precision accumulation produces results at least as accurate as native FP8 accumulation across multiple algorithms and sizes. Skipped on CUDA SM < 89 (pre-Hopper); runs on HIP/ROCm. - **Add `test_fp8_accum.py` to CI**: Azure Pipeline `ut.yml` now runs FP8 accumulation tests alongside existing pytests. - **NCCL shim logging cleanup**: Migrated `printf`-style `WARN`/`INFO` calls to streaming-style logging. 
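As a concrete illustration of the bias-15 layout described above (a host-side sketch, not code from this patch; the helper name is hypothetical), decoding follows directly from the `[sign:1][exp:4][mantissa:3]` format:

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>

// Decode one fp8_e4m3b15 byte: bias = 15, exp=15 and the negative-zero pattern
// 0x80 are NaN (fnuz convention), and there are no infinities.
float decodeE4M3B15(uint8_t bits) {
  int exp = (bits >> 3) & 0xF;
  if (bits == 0x80 || exp == 15) return NAN;
  int mant = bits & 0x7;
  float mag = (exp == 0) ? std::ldexp(mant / 8.0f, 1 - 15)            // subnormal
                         : std::ldexp(1.0f + mant / 8.0f, exp - 15);  // normal
  return (bits & 0x80) ? -mag : mag;
}

int main() {
  std::printf("%g\n", decodeE4M3B15(0x77));  // 0.9375   (max finite: exp=14, mant=7)
  std::printf("%g\n", decodeE4M3B15(0x08));  // ~6.1e-05 (min normal: exp=1, mant=0)
  std::printf("%g\n", decodeE4M3B15(0x01));  // ~7.6e-06 (min subnormal)
}
```

The implementation in `gpu_data_types.hpp` reaches the same values by routing through fp16, since e4m3b15 and fp16 share the bias-15 exponent.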
## Key files | Area | Files | |------|-------| | New datatype + vector ops | `include/mscclpp/gpu_data_types.hpp` | | Accumulation reduce helpers | `src/core/include/reduce_kernel.hpp` | | Algorithm API (`accumDtype`) | `include/mscclpp/algorithm.hpp`, `src/core/algorithm.cc` | | Allreduce kernels | `src/ext/collectives/allreduce/*.cu` | | Dispatch + common | `src/ext/collectives/include/allreduce/common.hpp` | | Python bindings | `python/csrc/algorithm.cpp`, `python/mscclpp/_core/algorithm.py` | | Tests | `python/test/test_fp8_accum.py` | | CI | `.azure-pipelines/templates/ut.yml` | ## Test plan - [x] CI passes on H100 (CUDA SM 90) — full FP8 E4M3 + E4M3B15 accumulation tests - [x] CI passes on A100 (CUDA SM 80) — FP8 tests correctly skipped - [x] CI passes on MI300X (ROCm) — FP8 tests run via HIP - [x] Existing `test_mscclpp.py` tests continue to pass - [x] NCCL shim builds and runs correctly with new `accumDtype` defaults 🤖 Generated with [Claude Code](https://claude.com/claude-code) --------- Co-authored-by: Claude Opus 4.6 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .azure-pipelines/templates/ut.yml | 1 + docs/guide/mscclpp-torch-integration.md | 3 +- .../customized_allgather.cu | 3 +- .../torch-integration/customized_allgather.cu | 3 +- include/mscclpp/algorithm.hpp | 15 +- include/mscclpp/gpu_data_types.hpp | 771 +++++++++++++++++- python/csrc/algorithm.cpp | 8 +- python/csrc/core_py.cpp | 3 +- python/csrc/gpu_utils_py.cpp | 13 + python/mscclpp/_core/algorithm.py | 8 +- python/test/test_fp8_accum.py | 391 +++++++++ src/core/algorithm.cc | 17 +- src/core/executor/execution_kernel.cu | 6 + src/core/include/execution_kernel.hpp | 27 +- src/core/include/reduce_kernel.hpp | 174 +++- .../allgather/allgather_fullmesh.cu | 3 +- .../allgather/allgather_fullmesh_2.cu | 3 +- .../allreduce/allreduce_allpair_packet.cu | 13 +- .../allreduce/allreduce_fullmesh.cu | 37 +- .../allreduce_nvls_block_pipeline.cu | 14 +- .../allreduce/allreduce_nvls_packet.cu | 45 +- .../allreduce/allreduce_nvls_warp_pipeline.cu | 19 +- .../allreduce/allreduce_nvls_zero_copy.cu | 15 +- .../collectives/allreduce/allreduce_packet.cu | 68 +- .../collectives/allreduce/allreduce_rsag.cu | 13 +- .../allreduce/allreduce_rsag_pipeline.cu | 19 +- .../allreduce/allreduce_rsag_zero_copy.cu | 31 +- .../allreduce/allreduce_allpair_packet.hpp | 2 +- .../include/allreduce/allreduce_fullmesh.hpp | 2 +- .../allreduce_nvls_block_pipeline.hpp | 2 +- .../allreduce/allreduce_nvls_packet.hpp | 4 +- .../allreduce_nvls_warp_pipeline.hpp | 2 +- .../allreduce/allreduce_nvls_zero_copy.hpp | 2 +- .../include/allreduce/allreduce_packet.hpp | 2 +- .../include/allreduce/allreduce_rsag.hpp | 2 +- .../allreduce/allreduce_rsag_pipeline.hpp | 2 +- .../allreduce/allreduce_rsag_zero_copy.hpp | 2 +- .../collectives/include/allreduce/common.hpp | 92 +-- src/ext/nccl/algorithm_selector.cc | 3 +- src/ext/nccl/datatype_conversion.hpp | 5 + src/ext/nccl/nccl.cc | 39 +- 41 files changed, 1623 insertions(+), 261 deletions(-) create mode 100644 python/test/test_fp8_accum.py diff --git a/.azure-pipelines/templates/ut.yml b/.azure-pipelines/templates/ut.yml index 9d17e923..743c66e6 100644 --- a/.azure-pipelines/templates/ut.yml +++ b/.azure-pipelines/templates/ut.yml @@ -41,6 +41,7 @@ steps: displayName: Run pytests remoteScript: | mpirun --allow-run-as-root -tag-output -x MSCCLPP_HOME=/root/mscclpp -x GPU_MAX_HW_QUEUES=8 -np 8 python3 -m pytest ./python/test/test_mscclpp.py -x + mpirun --allow-run-as-root -tag-output -x 
MSCCLPP_HOME=/root/mscclpp -x GPU_MAX_HW_QUEUES=8 -np 8 python3 -m pytest ./python/test/test_fp8_accum.py -x - template: stop.yml parameters: diff --git a/docs/guide/mscclpp-torch-integration.md b/docs/guide/mscclpp-torch-integration.md index 1c966155..b4e4fcdf 100644 --- a/docs/guide/mscclpp-torch-integration.md +++ b/docs/guide/mscclpp-torch-integration.md @@ -332,7 +332,8 @@ public: size_t inputSize, size_t outputSize, mscclpp::DataType dtype, mscclpp::ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock, - const std::unordered_map& extras) { + const std::unordered_map& extras, + [[maybe_unused]] mscclpp::DataType accumDtype) { return self->kernelFunc(ctx, input, output, inputSize, dtype, stream); }, // Context initialization function diff --git a/examples/customized-collective-algorithm/customized_allgather.cu b/examples/customized-collective-algorithm/customized_allgather.cu index e78c4777..02df3685 100644 --- a/examples/customized-collective-algorithm/customized_allgather.cu +++ b/examples/customized-collective-algorithm/customized_allgather.cu @@ -101,7 +101,8 @@ class AllgatherAlgoBuilder : public mscclpp::AlgorithmBuilder { "allgather", "allgather", [self](std::shared_ptr comm) { self->initialize(comm); }, [self](const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, size_t outputSize, mscclpp::DataType dtype, [[maybe_unused]] mscclpp::ReduceOp op, cudaStream_t stream, int nBlocks, - int nThreadsPerBlock, const std::unordered_map& extras) { + int nThreadsPerBlock, const std::unordered_map& extras, + [[maybe_unused]] mscclpp::DataType accumDtype) { return self->allgatherKernelFunc(ctx, input, output, inputSize, stream); }, [self](std::shared_ptr comm, const void* input, void* output, size_t inputSize, diff --git a/examples/torch-integration/customized_allgather.cu b/examples/torch-integration/customized_allgather.cu index d48c4410..907b3ada 100644 --- a/examples/torch-integration/customized_allgather.cu +++ b/examples/torch-integration/customized_allgather.cu @@ -69,7 +69,8 @@ class AllgatherAlgoBuilder : public mscclpp::AlgorithmBuilder { "allgather", "allgather", [self](std::shared_ptr comm) { self->initialize(comm); }, [self](const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, size_t outputSize, mscclpp::DataType dtype, [[maybe_unused]] mscclpp::ReduceOp op, cudaStream_t stream, int nBlocks, - int nThreadsPerBlock, const std::unordered_map& extras) { + int nThreadsPerBlock, const std::unordered_map& extras, + [[maybe_unused]] mscclpp::DataType accumDtype) { return self->allgatherKernelFunc(ctx, input, output, inputSize, dtype, stream); }, [self](std::shared_ptr comm, const void* input, void* output, size_t inputSize, diff --git a/include/mscclpp/algorithm.hpp b/include/mscclpp/algorithm.hpp index 65b1ab3c..531cb857 100644 --- a/include/mscclpp/algorithm.hpp +++ b/include/mscclpp/algorithm.hpp @@ -103,12 +103,14 @@ class Algorithm { /// @param nThreadsPerBlock Number of threads per block (0 for auto-selection). /// @param symmetricMemory Whether to use symmetric memory optimization. /// @param extras Additional parameters for algorithm-specific customization. + /// @param accumDtype Data type for accumulation during reduction. DataType::AUTO resolves to dtype. /// @return The result of the operation. 
virtual CommResult execute(std::shared_ptr comm, const void* input, void* output, size_t inputSize, size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, std::shared_ptr executor, int nBlocks = 0, int nThreadsPerBlock = 0, bool symmetricMemory = false, - const std::unordered_map& extras = {}) = 0; + const std::unordered_map& extras = {}, + DataType accumDtype = DataType::AUTO) = 0; /// Reset the algorithm state, clearing any cached contexts. virtual void reset() = 0; @@ -186,10 +188,11 @@ class NativeAlgorithm : public Algorithm { /// @param nBlocks Number of CUDA blocks. /// @param nThreadsPerBlock Number of threads per block. /// @param extras Additional algorithm-specific parameters. + /// @param accumDtype Data type for accumulation (resolved from input dtype if sentinel). /// @return The result of the operation. using KernelFunc = std::function, const void*, void*, size_t, size_t, DataType, ReduceOp, - cudaStream_t, int, int, const std::unordered_map&)>; + cudaStream_t, int, int, const std::unordered_map&, DataType)>; /// Function type for creating algorithm contexts. /// @param comm The communicator. @@ -233,8 +236,8 @@ class NativeAlgorithm : public Algorithm { CommResult execute(std::shared_ptr comm, const void* input, void* output, size_t inputSize, size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, std::shared_ptr executor, int nBlocks = 0, int nThreadsPerBlock = 0, - bool symmetricMemory = false, - const std::unordered_map& extras = {}) override; + bool symmetricMemory = false, const std::unordered_map& extras = {}, + DataType accumDtype = DataType::AUTO) override; const std::string& name() const override; const std::string& collective() const override; const std::pair& messageRange() const override; @@ -285,8 +288,8 @@ class DslAlgorithm : public Algorithm, public AlgorithmBuilder, public std::enab CommResult execute(std::shared_ptr comm, const void* input, void* output, size_t inputSize, size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, std::shared_ptr executor, int nBlocks = 0, int nThreadsPerBlock = 0, - bool symmetricMemory = false, - const std::unordered_map& extras = {}) override; + bool symmetricMemory = false, const std::unordered_map& extras = {}, + DataType accumDtype = DataType::AUTO) override; AlgorithmType type() const override { return AlgorithmType::DSL; } Constraint constraint() const override; void reset() override; diff --git a/include/mscclpp/gpu_data_types.hpp b/include/mscclpp/gpu_data_types.hpp index 1cecbea6..fa31a28f 100644 --- a/include/mscclpp/gpu_data_types.hpp +++ b/include/mscclpp/gpu_data_types.hpp @@ -64,18 +64,151 @@ using __bfloat162 = __nv_bfloat162; #endif +/// Software float8 with 4 exponent bits, 3 mantissa bits, exponent bias = 15. +/// Format (MSB first): [sign:1][exponent:4][mantissa:3] +/// No infinities; exp=15 is NaN. Negative zero is NaN (fnuz convention). +/// Max finite value: 0.9375, min normal: ~6.1e-5, min subnormal: ~7.6e-6. +struct alignas(1) __fp8_e4m3b15 { + uint8_t __x; + + __fp8_e4m3b15() = default; + + /// Construct from raw bits (use __fp8_e4m3b15::fromRaw() for clarity). + MSCCLPP_HOST_DEVICE_INLINE explicit __fp8_e4m3b15(uint8_t raw) : __x(raw) {} + + /// Construct from float32 (explicit to avoid ambiguous conversion chains). + MSCCLPP_HOST_DEVICE_INLINE explicit __fp8_e4m3b15(float val) : __x(fromFloat(val)) {} + + /// Convert to float32. 
+ MSCCLPP_HOST_DEVICE_INLINE operator float() const { return toFloat(__x); } + + /// Construct from a raw bit pattern without conversion. + static MSCCLPP_HOST_DEVICE_INLINE __fp8_e4m3b15 fromRaw(uint8_t bits) { + __fp8_e4m3b15 r; + r.__x = bits; + return r; + } + + private: + /// Decode fp8_e4m3b15 bits → float32. + /// + /// Uses bit manipulation through fp16 as intermediate, adapted from the Triton compiler. + /// fp8_e4m3b15 is identical to fp8_e4m3fn (NVIDIA) except exponent bias is 15 vs 7. + /// Algorithm: reinterpret fp8 bits into an fp16 bit pattern with exponent shifted by -8, + /// then convert fp16 → float32. + static MSCCLPP_HOST_DEVICE_INLINE float toFloat(uint8_t bits) { + // Handle special values: negative zero (0x80) → NaN, exponent=15 → NaN. + uint32_t exp = (bits >> 3) & 0xFu; + if (bits == 0x80 || exp == 15) { + union { + uint32_t u; + float f; + } nan_val = {0x7FC00000u}; + return nan_val.f; + } + if (bits == 0) return 0.0f; + + // Triton-style bit manipulation: fp8 → fp16 → fp32. + // fp8 layout: [S:1][E:4][M:3] (bias=15) + // fp16 layout: [S:1][E:5][M:10] (bias=15) + // + // Place fp8 in upper byte of fp16, then right-shift exponent+mantissa by 1 + // to convert E4 → E5 (both share bias=15). Sign bit stays at bit 15. + // Refer: + // https://github.com/triton-lang/triton/blob/cf34004b8a67d290a962da166f5aa2fc66751326/python/triton/language/extra/cuda/utils.py#L34 + uint16_t h = (uint16_t)bits << 8; // place fp8 in upper byte of fp16 + uint16_t sign16 = h & 0x8000u; // extract sign at fp16 position + uint16_t nosign = h & 0x7F00u; // exponent + mantissa (no sign) + uint16_t fp16_bits = sign16 | (nosign >> 1); // shift exponent right by 1 + + // For subnormals: when fp8 exponent=0, the above gives fp16 exponent=0 + // and fp16 mantissa = (fp8_mantissa << 7), which correctly represents + // the subnormal fp16 value since both share bias=15. + + // Convert fp16 bits to float via __half (works on host and device, CUDA and HIP). + union { + uint16_t u; + __half h; + } cvt = {fp16_bits}; + return __half2float(cvt.h); + } + + /// Encode float32 → fp8_e4m3b15 bits. + /// + /// Algorithm adapted from Triton: float32 → fp16 → bit-manipulate → fp8. + /// The key insight is to convert to fp16 first (which shares bias=15 with e4m3b15), + /// then pack the fp16 bits back into 8 bits by shifting the exponent left by 1. + static MSCCLPP_HOST_DEVICE_INLINE uint8_t fromFloat(float val) { + union { + float f; + uint32_t u; + } in = {val}; + + // NaN → 0x80 (negative-zero bit pattern = NaN in fnuz). + if ((in.u & 0x7F800000u) == 0x7F800000u && (in.u & 0x007FFFFFu) != 0) return 0x80u; + + // Convert float32 → fp16 bits via __half (works on host and device, CUDA and HIP). + __half h_val = __float2half_rn(val); + union { + __half h; + uint16_t u; + } cvt = {h_val}; + uint16_t fp16_bits = cvt.u; + + // Clamp absolute value to max finite e4m3b15: 0.9375 → fp16 = 0x3B80. + uint16_t abs_fp16 = fp16_bits & 0x7FFFu; + if (abs_fp16 > 0x3B80u) abs_fp16 = 0x3B80u; + + // Reconstruct with sign. + uint16_t sign16 = fp16_bits & 0x8000u; + + // Triton-style: fp16 → fp8. + // fp16 layout: [S:1][E:5][M:10] (bias=15) + // fp8 layout: [S:1][E:4][M:3] (bias=15) + // + // mad.lo.u32 a0, a0, 2, 0x00800080 → (abs_fp16 * 2 + 0x0080) + // This shifts left by 1 (undoing the right-shift in decode) and adds rounding bias. + // Then: lop3.b32 b0, $1, 0x80008000, a0, 0xea → (sign & 0x8000) | a0 + // Finally: prmt for byte extraction. 
+ // + // Simplified for scalar: shift abs_fp16 left by 1, add rounding bias, take upper byte. + uint16_t adjusted = (uint16_t)(abs_fp16 * 2u + 0x0080u); + // The upper byte now contains [E:4][M:3][round_bit]. + // Combine with sign and extract. + uint16_t with_sign = sign16 | adjusted; + uint8_t result = (uint8_t)(with_sign >> 8); + + // Zero → 0x00 (ensure positive zero, not negative zero which is NaN). + if ((result & 0x7Fu) == 0) result = 0x00u; + + return result; + } +}; + +/// Packed 2x fp8_e4m3b15 storage. +struct alignas(2) __fp8x2_e4m3b15 { + uint16_t __x; +}; + +/// Packed 4x fp8_e4m3b15 storage. +struct alignas(4) __fp8x4_e4m3b15 { + uint32_t __x; +}; + namespace mscclpp { /// Data types supported by mscclpp operations. enum class DataType { - INT32, // 32-bit signed integer. - UINT32, // 32-bit unsigned integer. - FLOAT16, // IEEE 754 half precision. - FLOAT32, // IEEE 754 single precision. - BFLOAT16, // bfloat16 precision. - FLOAT8_E4M3, // float8 with E4M3 layout. - FLOAT8_E5M2, // float8 with E5M2 layout. - UINT8, // 8-bit unsigned integer. + INT32, // 32-bit signed integer. + UINT32, // 32-bit unsigned integer. + FLOAT16, // IEEE 754 half precision. + FLOAT32, // IEEE 754 single precision. + BFLOAT16, // bfloat16 precision. + FLOAT8_E4M3, // float8 with E4M3 layout. + FLOAT8_E5M2, // float8 with E5M2 layout. + UINT8, // 8-bit unsigned integer. + FLOAT8_E4M3B15, // float8 with E4M3 layout, bias=15 (software, no HW accel). + AUTO = 255, // Sentinel: resolve to the input dtype at runtime. }; /// Word array. @@ -97,6 +230,7 @@ struct alignas(Bytes) Words {}; template union alignas(sizeof(T) * N) VectorTypeImpl { static_assert(N > 0, "N must be greater than 0"); + static_assert(sizeof(StorageT) >= sizeof(T) * N, "StorageT must cover the full vector size"); T data[N]; Words words; @@ -127,13 +261,14 @@ union alignas(sizeof(T) * N) VectorTypeImpl { MSCCLPP_HOST_DEVICE_INLINE const T& operator[](int i) const { return data[i]; } }; -// Helper template to get the appropriate vector type for a given element type and count +// Helper template to get the appropriate vector type for a given element type and count. template struct VectorTypeHelper { - using type = - VectorTypeImpl>>; + static constexpr int Bytes = N * sizeof(T); + using type = VectorTypeImpl< + T, N, + std::conditional_t>>>>; }; /// Vector type - clean user interface (automatically selects appropriate storage type) @@ -170,6 +305,11 @@ DEFINE_VEC(bf16x4, __bfloat16, 4, uint2); DEFINE_VEC(f16x8, __half, 8, uint4); DEFINE_VEC(bf16x8, __bfloat16, 8, uint4); +// Aliases for large vector types (>16 bytes) where no native CUDA storage type exists. 
+using f32x8 = VectorType; +using f32x16 = VectorType; +using f16x16 = VectorType<__half, 16>; + #if defined(__FP8_TYPES_EXIST__) DEFINE_VEC(f8_e4m3x2, __fp8_e4m3, 2, __fp8x2_e4m3); DEFINE_VEC(f8_e4m3x4, __fp8_e4m3, 4, __fp8x4_e4m3); @@ -181,6 +321,12 @@ DEFINE_VEC(f8_e5m2x4, __fp8_e5m2, 4, __fp8x4_e5m2); DEFINE_VEC(f8_e5m2x8, __fp8_e5m2, 8, uint2); DEFINE_VEC(f8_e5m2x16, __fp8_e5m2, 16, uint4); #endif + +// fp8_e4m3b15 vectors (always available — software type, no HW dependency) +DEFINE_VEC(f8_e4m3b15x2, __fp8_e4m3b15, 2, __fp8x2_e4m3b15); +DEFINE_VEC(f8_e4m3b15x4, __fp8_e4m3b15, 4, __fp8x4_e4m3b15); +DEFINE_VEC(f8_e4m3b15x8, __fp8_e4m3b15, 8, uint2); +DEFINE_VEC(f8_e4m3b15x16, __fp8_e4m3b15, 16, uint4); #undef DEFINE_VEC #if defined(MSCCLPP_DEVICE_COMPILE) @@ -254,6 +400,21 @@ MSCCLPP_DEVICE_INLINE __fp8_e5m2 clip(__fp8_e5m2 val) { } #endif +// --- f32x2 arithmetic --- + +template +MSCCLPP_DEVICE_INLINE f32x2 operator+(const f32x2& a, const f32x2& b) { +#if defined(MSCCLPP_DEVICE_CUDA) && (__CUDA_ARCH__ >= 1000) + // Blackwell (SM 10.0+): packed float2 add in a single instruction. + return __fadd2_rn(a.storage, b.storage); +#else + f32x2 result; + result.data[0] = a.data[0] + b.data[0]; + result.data[1] = a.data[1] + b.data[1]; + return result; +#endif +} + template MSCCLPP_DEVICE_INLINE f16x2 operator+(const f16x2& a, const f16x2& b) { __half2 result; @@ -265,6 +426,18 @@ MSCCLPP_DEVICE_INLINE f16x2 operator+(const f16x2& a, const f16x2& b) { return result; } +template +MSCCLPP_DEVICE_INLINE f16x4 operator+(const f16x4& a, const f16x4& b) { + // Decompose into 2× packed __hadd2 (2 instructions instead of 4 scalar __hadd). + const f16x2* a2 = reinterpret_cast(&a); + const f16x2* b2 = reinterpret_cast(&b); + f16x4 result; + f16x2* r2 = reinterpret_cast(&result); + r2[0] = a2[0] + b2[0]; + r2[1] = a2[1] + b2[1]; + return result; +} + template MSCCLPP_DEVICE_INLINE bf16x2 operator+(const bf16x2& a, const bf16x2& b) { __bfloat162 result; @@ -449,6 +622,14 @@ MSCCLPP_DEVICE_INLINE T min(const T& a, const T& b) { return (a < b ? a : b); } +template <> +MSCCLPP_DEVICE_INLINE f32x2 min(const f32x2& a, const f32x2& b) { + f32x2 result; + result.data[0] = fminf(a.data[0], b.data[0]); + result.data[1] = fminf(a.data[1], b.data[1]); + return result; +} + template <> MSCCLPP_DEVICE_INLINE f16x2 min(const f16x2& a, const f16x2& b) { #if defined(MSCCLPP_DEVICE_HIP) @@ -489,6 +670,51 @@ MSCCLPP_DEVICE_INLINE u8x4 min(const u8x4& a, const u8x4& b) { #endif } +/// Convert a vector type From to vector type To. +/// Primary template with auto-decomposition: vectors with N > 4 elements decompose into x4 chunks, +/// vectors with N == 4 decompose into x2 chunks, enabling optimized x2/x4 specializations to be reached. +/// Specialized below for optimized FP8 conversion paths at x2/x4 level. 
+template +MSCCLPP_DEVICE_INLINE To to(const From& v) { + static_assert(To::Size == From::Size, "to: vector sizes must match"); + constexpr int N = From::Size; + + // Auto-decompose: N > 4 → split into x4 chunks + if constexpr (N > 4 && N % 4 == 0) { + constexpr int nChunks = N / 4; + using FromChunk = VectorType; + using ToChunk = VectorType; + const FromChunk* in = reinterpret_cast(&v); + To result; + ToChunk* out = reinterpret_cast(&result); +#pragma unroll + for (int c = 0; c < nChunks; ++c) { + out[c] = to(in[c]); + } + return result; + } + // Auto-decompose: N == 4 → split into 2x x2 chunks + else if constexpr (N == 4) { + using FromChunk = VectorType; + using ToChunk = VectorType; + const FromChunk* in = reinterpret_cast(&v); + To result; + ToChunk* out = reinterpret_cast(&result); + out[0] = to(in[0]); + out[1] = to(in[1]); + return result; + } + // Base case: element-wise conversion + else { + To result; +#pragma unroll + for (int i = 0; i < N; ++i) { + result.data[i] = static_cast(v.data[i]); + } + return result; + } +} + #if defined(__FP8_TYPES_EXIST__) template <> MSCCLPP_DEVICE_INLINE __fp8_e4m3 min(const __fp8_e4m3& a, const __fp8_e4m3& b) { @@ -551,7 +777,526 @@ MSCCLPP_DEVICE_INLINE f8_e5m2x4 min(const f8_e5m2x4& a, const f8_e5m2x4& b) { return result; } + +// --- f8_e4m3 -> f32 specializations --- + +/// f8_e4m3x2 -> f32x2. +/// NVIDIA: fp8 -> half (via __nv_cvt_fp8x2_to_halfraw2) -> float. +/// HIP gfx942: fp8 -> float (via __builtin_amdgcn_cvt_pk_f32_fp8). +template <> +MSCCLPP_DEVICE_INLINE f32x2 to(const f8_e4m3x2& v) { +#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__) + auto f = __builtin_amdgcn_cvt_pk_f32_fp8(v.storage.__x, 0); + f32x2 result; + result.data[0] = f[0]; + result.data[1] = f[1]; + return result; +#elif defined(MSCCLPP_DEVICE_CUDA) && __CUDA_ARCH__ >= 900 + __half2_raw h2 = __nv_cvt_fp8x2_to_halfraw2(bit_cast<__nv_fp8x2_storage_t>(v.storage), __NV_E4M3); + f32x2 result; + result.data[0] = __half2float(bit_cast<__half>(h2.x)); + result.data[1] = __half2float(bit_cast<__half>(h2.y)); + return result; +#else + f32x2 result; + result.data[0] = float(v.data[0]); + result.data[1] = float(v.data[1]); + return result; +#endif +} + +/// f8_e4m3x4 -> f32x4. +template <> +MSCCLPP_DEVICE_INLINE f32x4 to(const f8_e4m3x4& v) { +#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__) + auto lo = __builtin_amdgcn_cvt_pk_f32_fp8(v.storage.__x, false); + auto hi = __builtin_amdgcn_cvt_pk_f32_fp8(v.storage.__x, true); + f32x4 result; + result.data[0] = lo[0]; + result.data[1] = lo[1]; + result.data[2] = hi[0]; + result.data[3] = hi[1]; + return result; +#else + const f8_e4m3x2* pair = reinterpret_cast(&v); + f32x2 lo = to(pair[0]); + f32x2 hi = to(pair[1]); + f32x4 result; + result.data[0] = lo.data[0]; + result.data[1] = lo.data[1]; + result.data[2] = hi.data[0]; + result.data[3] = hi.data[1]; + return result; +#endif +} + +// --- f8_e5m2 -> f32 specializations --- + +/// f8_e5m2x2 -> f32x2. +/// NVIDIA: fp8 -> half (via __nv_cvt_fp8x2_to_halfraw2) -> float. +/// HIP gfx942: bf8 -> float (via __builtin_amdgcn_cvt_pk_f32_bf8). 
+template <> +MSCCLPP_DEVICE_INLINE f32x2 to(const f8_e5m2x2& v) { +#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__) + auto f = __builtin_amdgcn_cvt_pk_f32_bf8(v.storage.__x, 0); + f32x2 result; + result.data[0] = f[0]; + result.data[1] = f[1]; + return result; +#elif defined(MSCCLPP_DEVICE_CUDA) && __CUDA_ARCH__ >= 900 + __half2_raw h2 = __nv_cvt_fp8x2_to_halfraw2(bit_cast<__nv_fp8x2_storage_t>(v.storage), __NV_E5M2); + f32x2 result; + result.data[0] = __half2float(bit_cast<__half>(h2.x)); + result.data[1] = __half2float(bit_cast<__half>(h2.y)); + return result; +#else + f32x2 result; + result.data[0] = float(v.data[0]); + result.data[1] = float(v.data[1]); + return result; +#endif +} + +/// f8_e5m2x4 -> f32x4. +template <> +MSCCLPP_DEVICE_INLINE f32x4 to(const f8_e5m2x4& v) { +#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__) + auto lo = __builtin_amdgcn_cvt_pk_f32_bf8(v.storage.__x, false); + auto hi = __builtin_amdgcn_cvt_pk_f32_bf8(v.storage.__x, true); + f32x4 result; + result.data[0] = lo[0]; + result.data[1] = lo[1]; + result.data[2] = hi[0]; + result.data[3] = hi[1]; + return result; +#else + const f8_e5m2x2* pair = reinterpret_cast(&v); + f32x2 lo = to(pair[0]); + f32x2 hi = to(pair[1]); + f32x4 result; + result.data[0] = lo.data[0]; + result.data[1] = lo.data[1]; + result.data[2] = hi.data[0]; + result.data[3] = hi.data[1]; + return result; +#endif +} + +// --- f32 -> f8_e4m3 specializations (downcast) --- + +/// f32x2 -> f8_e4m3x2. +/// HIP gfx942: float -> fp8 (via __builtin_amdgcn_cvt_pk_fp8_f32). +/// NVIDIA SM90+: float -> half -> fp8 (via __nv_cvt_halfraw2_to_fp8x2). +/// NVIDIA pre-SM90: float -> half -> fp8 (via __nv_cvt_halfraw_to_fp8, element-wise). +template <> +MSCCLPP_DEVICE_INLINE f8_e4m3x2 to(const f32x2& v) { +#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__) + uint32_t packed = __builtin_amdgcn_cvt_pk_fp8_f32(v.data[0], v.data[1], 0, false); + return bit_cast(static_cast<__hip_fp8x2_storage_t>(packed)); +#elif defined(MSCCLPP_DEVICE_CUDA) && __CUDA_ARCH__ >= 900 + __half2_raw h2; + h2.x = bit_cast(__float2half_rn(v.data[0])); + h2.y = bit_cast(__float2half_rn(v.data[1])); + __nv_fp8x2_storage_t fp8x2 = __nv_cvt_halfraw2_to_fp8x2(h2, __NV_SATFINITE, __NV_E4M3); + return bit_cast(fp8x2); +#elif defined(MSCCLPP_DEVICE_CUDA) + __half_raw h0, h1; + h0.x = bit_cast(__float2half_rn(v.data[0])); + h1.x = bit_cast(__float2half_rn(v.data[1])); + f8_e4m3x2 result; + result.data[0] = bit_cast<__fp8_e4m3>(__nv_cvt_halfraw_to_fp8(h0, __NV_SATFINITE, __NV_E4M3)); + result.data[1] = bit_cast<__fp8_e4m3>(__nv_cvt_halfraw_to_fp8(h1, __NV_SATFINITE, __NV_E4M3)); + return result; +#else + f8_e4m3x2 result; + result.data[0] = static_cast<__fp8_e4m3>(v.data[0]); + result.data[1] = static_cast<__fp8_e4m3>(v.data[1]); + return result; +#endif +} + +/// f32x4 -> f8_e4m3x4. 
+template <> +MSCCLPP_DEVICE_INLINE f8_e4m3x4 to(const f32x4& v) { +#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__) + uint32_t packed = __builtin_amdgcn_cvt_pk_fp8_f32(v.data[0], v.data[1], 0, false); + packed = __builtin_amdgcn_cvt_pk_fp8_f32(v.data[2], v.data[3], packed, true); + return bit_cast(packed); +#else + f32x2 lo, hi; + lo.data[0] = v.data[0]; + lo.data[1] = v.data[1]; + hi.data[0] = v.data[2]; + hi.data[1] = v.data[3]; + f8_e4m3x2 lo_fp8 = to(lo); + f8_e4m3x2 hi_fp8 = to(hi); + f8_e4m3x4 result; + result.data[0] = lo_fp8.data[0]; + result.data[1] = lo_fp8.data[1]; + result.data[2] = hi_fp8.data[0]; + result.data[3] = hi_fp8.data[1]; + return result; +#endif +} + +// --- f32 -> f8_e5m2 specializations (downcast) --- + +/// f32x2 -> f8_e5m2x2. +/// HIP gfx942: float -> bf8 (via __builtin_amdgcn_cvt_pk_bf8_f32). +/// NVIDIA SM90+: float -> half -> fp8 (via __nv_cvt_halfraw2_to_fp8x2 with __NV_E5M2). +/// NVIDIA pre-SM90: float -> half -> fp8 (via __nv_cvt_halfraw_to_fp8, element-wise). +template <> +MSCCLPP_DEVICE_INLINE f8_e5m2x2 to(const f32x2& v) { +#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__) + uint32_t packed = __builtin_amdgcn_cvt_pk_bf8_f32(v.data[0], v.data[1], 0, false); + return bit_cast(static_cast<__hip_fp8x2_storage_t>(packed)); +#elif defined(MSCCLPP_DEVICE_CUDA) && __CUDA_ARCH__ >= 900 + __half2_raw h2; + h2.x = bit_cast(__float2half_rn(v.data[0])); + h2.y = bit_cast(__float2half_rn(v.data[1])); + __nv_fp8x2_storage_t fp8x2 = __nv_cvt_halfraw2_to_fp8x2(h2, __NV_SATFINITE, __NV_E5M2); + return bit_cast(fp8x2); +#elif defined(MSCCLPP_DEVICE_CUDA) + __half_raw h0, h1; + h0.x = bit_cast(__float2half_rn(v.data[0])); + h1.x = bit_cast(__float2half_rn(v.data[1])); + f8_e5m2x2 result; + result.data[0] = bit_cast<__fp8_e5m2>(__nv_cvt_halfraw_to_fp8(h0, __NV_SATFINITE, __NV_E5M2)); + result.data[1] = bit_cast<__fp8_e5m2>(__nv_cvt_halfraw_to_fp8(h1, __NV_SATFINITE, __NV_E5M2)); + return result; +#else + f8_e5m2x2 result; + result.data[0] = static_cast<__fp8_e5m2>(v.data[0]); + result.data[1] = static_cast<__fp8_e5m2>(v.data[1]); + return result; +#endif +} + +/// f32x4 -> f8_e5m2x4. +template <> +MSCCLPP_DEVICE_INLINE f8_e5m2x4 to(const f32x4& v) { +#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__) + uint32_t packed = __builtin_amdgcn_cvt_pk_bf8_f32(v.data[0], v.data[1], 0, false); + packed = __builtin_amdgcn_cvt_pk_bf8_f32(v.data[2], v.data[3], packed, true); + return bit_cast(packed); +#else + f32x2 lo, hi; + lo.data[0] = v.data[0]; + lo.data[1] = v.data[1]; + hi.data[0] = v.data[2]; + hi.data[1] = v.data[3]; + f8_e5m2x2 lo_fp8 = to(lo); + f8_e5m2x2 hi_fp8 = to(hi); + f8_e5m2x4 result; + result.data[0] = lo_fp8.data[0]; + result.data[1] = lo_fp8.data[1]; + result.data[2] = hi_fp8.data[0]; + result.data[3] = hi_fp8.data[1]; + return result; +#endif +} + +// --- f8_e4m3 <-> f16 conversion specializations --- + +/// f8_e4m3x2 -> f16x2. +/// NVIDIA SM90+: packed intrinsic (1 instruction). +/// HIP gfx942: fp8 -> float -> half (via AMD builtin). +/// Pre-SM90 / fallback: element-wise scalar conversion. 
+template <> +MSCCLPP_DEVICE_INLINE f16x2 to(const f8_e4m3x2& v) { +#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__) + auto f = __builtin_amdgcn_cvt_pk_f32_fp8(v.storage.__x, 0); + f16x2 result; + result.data[0] = __float2half(f[0]); + result.data[1] = __float2half(f[1]); + return result; +#elif defined(MSCCLPP_DEVICE_CUDA) && __CUDA_ARCH__ >= 900 + __half2_raw h2 = __nv_cvt_fp8x2_to_halfraw2(bit_cast<__nv_fp8x2_storage_t>(v.storage), __NV_E4M3); + return bit_cast(h2); +#else + f16x2 result; + result.data[0] = static_cast<__half>(v.data[0]); + result.data[1] = static_cast<__half>(v.data[1]); + return result; +#endif +} + +/// f16x2 -> f8_e4m3x2. +/// NVIDIA SM90+: packed intrinsic (1 instruction). +/// HIP gfx942: half -> float -> fp8 (via AMD builtin). +/// Pre-SM90: element-wise scalar conversion. +template <> +MSCCLPP_DEVICE_INLINE f8_e4m3x2 to(const f16x2& v) { +#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__) + float f0 = __half2float(v.data[0]); + float f1 = __half2float(v.data[1]); + uint32_t packed = __builtin_amdgcn_cvt_pk_fp8_f32(f0, f1, 0, false); + return bit_cast(static_cast<__hip_fp8x2_storage_t>(packed)); +#elif defined(MSCCLPP_DEVICE_CUDA) && __CUDA_ARCH__ >= 900 + __half2_raw h2 = bit_cast<__half2_raw>(v); + __nv_fp8x2_storage_t fp8x2 = __nv_cvt_halfraw2_to_fp8x2(h2, __NV_SATFINITE, __NV_E4M3); + return bit_cast(fp8x2); +#elif defined(MSCCLPP_DEVICE_CUDA) + __half_raw h0, h1; + h0.x = bit_cast(v.data[0]); + h1.x = bit_cast(v.data[1]); + f8_e4m3x2 result; + result.data[0] = bit_cast<__fp8_e4m3>(__nv_cvt_halfraw_to_fp8(h0, __NV_SATFINITE, __NV_E4M3)); + result.data[1] = bit_cast<__fp8_e4m3>(__nv_cvt_halfraw_to_fp8(h1, __NV_SATFINITE, __NV_E4M3)); + return result; +#else + f8_e4m3x2 result; + result.data[0] = static_cast<__fp8_e4m3>(v.data[0]); + result.data[1] = static_cast<__fp8_e4m3>(v.data[1]); + return result; +#endif +} + #endif // defined(__FP8_TYPES_EXIST__) + +// --- fp8_e4m3b15 <-> fp16 direct conversion specializations --- +// These are the PRIMARY conversions: fp8_b15 <-> fp16 is just a 1-bit exponent shift +// (E4 bias=15 <-> E5 bias=15), no precision loss since fp16 has 10 mantissa bits +// vs fp8's 3. fp32 conversions are derived by routing through fp16. + +/// f8_e4m3b15x2 -> f16x2. +/// Direct fp8 -> fp16 via branch-free bit manipulation. +template <> +MSCCLPP_DEVICE_INLINE f16x2 to(const f8_e4m3b15x2& v) { +#if defined(MSCCLPP_DEVICE_CUDA) + uint16_t in = v.storage.__x; + // Spread 2 fp8 bytes into packed fp16 pair, adjust exponent E4->E5. + uint32_t a0 = ((uint32_t)(in & 0xFFu) << 8) | ((uint32_t)(in >> 8) << 24); + uint32_t b0 = (a0 & 0x7f007f00u) >> 1; + uint32_t out0 = b0 | (a0 & 0x80008000u); + __half2 h; + asm("mov.b32 %0, %1;" : "=r"(*reinterpret_cast(&h)) : "r"(out0)); + return h; +#else + f16x2 result; + result.data[0] = __float2half(float(v.data[0])); + result.data[1] = __float2half(float(v.data[1])); + return result; +#endif +} + +/// f8_e4m3b15x4 -> f16x4. +/// Uses __byte_perm + lop3 for branch-free vectorized conversion. 
+template <> +MSCCLPP_DEVICE_INLINE f16x4 to(const f8_e4m3b15x4& v) { +#if defined(MSCCLPP_DEVICE_CUDA) + uint32_t in = v.storage.__x; + uint32_t a0 = __byte_perm(0u, in, 0x5746u); + uint32_t a0_shr = a0 >> 1; + uint32_t a0_sign = a0 & 0x80008000u; + uint32_t out0; + asm("lop3.b32 %0, %1, %2, %3, 0xEA;" : "=r"(out0) : "r"(a0_shr), "r"(0x3f803f80u), "r"(a0_sign)); + uint32_t a1 = __byte_perm(a0, 0u, 0x2301u); + uint32_t a1_shr = a1 >> 1; + uint32_t a1_sign = a1 & 0x80008000u; + uint32_t out1; + asm("lop3.b32 %0, %1, %2, %3, 0xEA;" : "=r"(out1) : "r"(a1_shr), "r"(0x3f803f80u), "r"(a1_sign)); + f16x4 result; + asm("mov.b32 %0, %1;" : "=r"(result.words[0]) : "r"(out0)); + asm("mov.b32 %0, %1;" : "=r"(result.words[1]) : "r"(out1)); + return result; +#else + f16x4 result; +#pragma unroll + for (int i = 0; i < 4; ++i) { + result.data[i] = __float2half(float(v.data[i])); + } + return result; +#endif +} + +/// f16x2 -> f8_e4m3b15x2. +/// Direct fp16 -> fp8 via clamp + exponent shift E5->E4 + pack. +template <> +MSCCLPP_DEVICE_INLINE f8_e4m3b15x2 to(const f16x2& v) { +#if defined(MSCCLPP_DEVICE_CUDA) + uint32_t in0; + asm("mov.b32 %0, %1;" : "=r"(in0) : "r"(*reinterpret_cast(&v))); + // Clamp abs to max finite e4m3b15 (0x3B80 = 0.9375 in fp16). + uint32_t lo = in0 & 0xFFFFu, hi = in0 >> 16; + uint32_t alo = lo & 0x7FFFu, ahi = hi & 0x7FFFu; + alo = alo < 0x3B80u ? alo : 0x3B80u; + ahi = ahi < 0x3B80u ? ahi : 0x3B80u; + uint32_t a0 = alo | (ahi << 16); + a0 = a0 * 2u + 0x00800080u; + uint32_t b0 = a0 | (in0 & 0x80008000u); + uint16_t packed = (uint16_t)(((b0 >> 8) & 0xFFu) | ((b0 >> 16) & 0xFF00u)); + return bit_cast(packed); +#else + f8_e4m3b15x2 result; + result.data[0] = __fp8_e4m3b15(__half2float(v.data[0])); + result.data[1] = __fp8_e4m3b15(__half2float(v.data[1])); + return result; +#endif +} + +/// f16x4 -> f8_e4m3b15x4. +/// Uses __vminu2 + lop3 + __byte_perm for branch-free vectorized conversion. +template <> +MSCCLPP_DEVICE_INLINE f8_e4m3b15x4 to(const f16x4& v) { +#if defined(MSCCLPP_DEVICE_CUDA) + uint32_t in0, in1; + asm("mov.b32 %0, %1;" : "=r"(in0) : "r"(v.words[0])); + asm("mov.b32 %0, %1;" : "=r"(in1) : "r"(v.words[1])); + uint32_t abs0 = in0 & 0x7fff7fffu; + uint32_t abs1 = in1 & 0x7fff7fffu; + uint32_t a0 = __vminu2(abs0, 0x3B803B80u); + uint32_t a1 = __vminu2(abs1, 0x3B803B80u); + a0 = a0 * 2u + 0x00800080u; + a1 = a1 * 2u + 0x00800080u; + uint32_t b0, b1; + asm("lop3.b32 %0, %1, %2, %3, 0xf8;" : "=r"(b0) : "r"(a0), "r"(in0), "r"(0x80008000u)); + asm("lop3.b32 %0, %1, %2, %3, 0xf8;" : "=r"(b1) : "r"(a1), "r"(in1), "r"(0x80008000u)); + uint32_t packed = __byte_perm(b0, b1, 0x7531u); + return bit_cast(packed); +#else + f8_e4m3b15x4 result; +#pragma unroll + for (int i = 0; i < 4; ++i) { + result.data[i] = __fp8_e4m3b15(__half2float(v.data[i])); + } + return result; +#endif +} + +// --- fp8_e4m3b15 <-> f32 conversion specializations --- +// Derived from fp16 conversions: fp8→f32 = fp8→fp16→f32, f32→fp8 = f32→fp16→fp8. + +/// f8_e4m3b15x2 -> f32x2. +/// Routes through fp16: fp8→fp16 (bit manip) then fp16→f32. +template <> +MSCCLPP_DEVICE_INLINE f32x2 to(const f8_e4m3b15x2& v) { +#if defined(MSCCLPP_DEVICE_CUDA) + f16x2 h = to(v); + float2 f2 = __half22float2(h); + return bit_cast(f2); +#else + f32x2 result; + result.data[0] = float(v.data[0]); + result.data[1] = float(v.data[1]); + return result; +#endif +} + +/// f8_e4m3b15x4 -> f32x4. +/// Routes through fp16: fp8→fp16 (bit manip) then fp16→f32. 
+template <> +MSCCLPP_DEVICE_INLINE f32x4 to(const f8_e4m3b15x4& v) { +#if defined(MSCCLPP_DEVICE_CUDA) + f16x4 h = to(v); + __half2 h0, h1; + asm("mov.b32 %0, %1;" : "=r"(*reinterpret_cast(&h0)) : "r"(h.words[0])); + asm("mov.b32 %0, %1;" : "=r"(*reinterpret_cast(&h1)) : "r"(h.words[1])); + float2 f0 = __half22float2(h0); + float2 f1 = __half22float2(h1); + f32x4 result; + result.data[0] = f0.x; + result.data[1] = f0.y; + result.data[2] = f1.x; + result.data[3] = f1.y; + return result; +#else + f32x4 result; +#pragma unroll + for (int i = 0; i < 4; ++i) { + result.data[i] = float(v.data[i]); + } + return result; +#endif +} + +/// f32x2 -> f8_e4m3b15x2. +/// Routes through fp16: f32→fp16 then fp16→fp8 (clamp + exponent shift + pack). +template <> +MSCCLPP_DEVICE_INLINE f8_e4m3b15x2 to(const f32x2& v) { +#if defined(MSCCLPP_DEVICE_CUDA) + float2 f2 = {v.data[0], v.data[1]}; + __half2 h = __float22half2_rn(f2); + return to(h); +#else + f8_e4m3b15x2 result; + result.data[0] = __fp8_e4m3b15(v.data[0]); + result.data[1] = __fp8_e4m3b15(v.data[1]); + return result; +#endif +} + +/// f32x4 -> f8_e4m3b15x4. +/// Routes through fp16: f32→fp16 then fp16→fp8 (clamp + exponent shift + pack). +template <> +MSCCLPP_DEVICE_INLINE f8_e4m3b15x4 to(const f32x4& v) { +#if defined(MSCCLPP_DEVICE_CUDA) + float2 f01 = {v.data[0], v.data[1]}; + float2 f23 = {v.data[2], v.data[3]}; + __half2 h01 = __float22half2_rn(f01); + __half2 h23 = __float22half2_rn(f23); + f16x4 h; + asm("mov.b32 %0, %1;" : "=r"(h.words[0]) : "r"(*reinterpret_cast(&h01))); + asm("mov.b32 %0, %1;" : "=r"(h.words[1]) : "r"(*reinterpret_cast(&h23))); + return to(h); +#else + f8_e4m3b15x4 result; +#pragma unroll + for (int i = 0; i < 4; ++i) { + result.data[i] = __fp8_e4m3b15(v.data[i]); + } + return result; +#endif +} + +// --- fp8_e4m3b15 arithmetic (software, always available) --- + +template +MSCCLPP_DEVICE_INLINE __fp8_e4m3b15 operator+(const __fp8_e4m3b15& a, const __fp8_e4m3b15& b) { + return __fp8_e4m3b15(float(a) + float(b)); +} + +template +MSCCLPP_DEVICE_INLINE f8_e4m3b15x2 operator+(const f8_e4m3b15x2& a, const f8_e4m3b15x2& b) { + f8_e4m3b15x2 result; + result.data[0] = __fp8_e4m3b15(float(a.data[0]) + float(b.data[0])); + result.data[1] = __fp8_e4m3b15(float(a.data[1]) + float(b.data[1])); + return result; +} + +template +MSCCLPP_DEVICE_INLINE f8_e4m3b15x4 operator+(const f8_e4m3b15x4& a, const f8_e4m3b15x4& b) { + f8_e4m3b15x4 result; +#pragma unroll + for (int i = 0; i < 4; ++i) { + result.data[i] = __fp8_e4m3b15(float(a.data[i]) + float(b.data[i])); + } + return result; +} + +// --- fp8_e4m3b15 min (software) --- + +template <> +MSCCLPP_DEVICE_INLINE __fp8_e4m3b15 min(const __fp8_e4m3b15& a, const __fp8_e4m3b15& b) { + return __fp8_e4m3b15(fminf(float(a), float(b))); +} + +MSCCLPP_DEVICE_INLINE f8_e4m3b15x2 min(const f8_e4m3b15x2& a, const f8_e4m3b15x2& b) { + f8_e4m3b15x2 result; + result.data[0] = mscclpp::min(a.data[0], b.data[0]); + result.data[1] = mscclpp::min(a.data[1], b.data[1]); + return result; +} + +MSCCLPP_DEVICE_INLINE f8_e4m3b15x4 min(const f8_e4m3b15x4& a, const f8_e4m3b15x4& b) { + f8_e4m3b15x4 result; +#pragma unroll + for (int i = 0; i < 4; ++i) { + result.data[i] = mscclpp::min(a.data[i], b.data[i]); + } + return result; +} + #endif // MSCCLPP_DEVICE_COMPILE } // namespace mscclpp diff --git a/python/csrc/algorithm.cpp b/python/csrc/algorithm.cpp index 1a93cbc0..1cb3f253 100644 --- a/python/csrc/algorithm.cpp +++ b/python/csrc/algorithm.cpp @@ -75,15 +75,17 @@ void register_algorithm(nb::module_& m) { 
[](Algorithm& self, std::shared_ptr comm, uintptr_t input, uintptr_t output, size_t inputSize, size_t outputSize, DataType dtype, ReduceOp op, uintptr_t stream, std::shared_ptr executor, int nBlocks, int nThreadsPerBlock, bool symmetricMemory, - std::unordered_map extras) { + std::unordered_map extras, int32_t accumDtype) { return self.execute(comm, reinterpret_cast(input), reinterpret_cast(output), inputSize, outputSize, dtype, op, reinterpret_cast(stream), executor, - nBlocks, nThreadsPerBlock, symmetricMemory, extras); + nBlocks, nThreadsPerBlock, symmetricMemory, extras, + static_cast(accumDtype)); }, nb::arg("comm"), nb::arg("input"), nb::arg("output"), nb::arg("input_size"), nb::arg("output_size"), nb::arg("dtype"), nb::arg("op") = ReduceOp::NOP, nb::arg("stream") = 0, nb::arg("executor") = nullptr, nb::arg("n_blocks") = 0, nb::arg("n_threads_per_block") = 0, nb::arg("symmetric_memory") = false, - nb::arg("extras") = std::unordered_map()) + nb::arg("extras") = std::unordered_map(), + nb::arg("accum_dtype") = static_cast(DataType::AUTO)) .def("reset", &Algorithm::reset); nb::class_(algorithmClass, "Constraint") diff --git a/python/csrc/core_py.cpp b/python/csrc/core_py.cpp index 47d76ac4..b8649564 100644 --- a/python/csrc/core_py.cpp +++ b/python/csrc/core_py.cpp @@ -47,7 +47,8 @@ void register_core(nb::module_& m) { .value("bfloat16", DataType::BFLOAT16) .value("float8_e4m3", DataType::FLOAT8_E4M3) .value("float8_e5m2", DataType::FLOAT8_E5M2) - .value("uint8", DataType::UINT8); + .value("uint8", DataType::UINT8) + .value("float8_e4m3b15", DataType::FLOAT8_E4M3B15); nb::class_(m, "CppBootstrap") .def("get_rank", &Bootstrap::getRank) diff --git a/python/csrc/gpu_utils_py.cpp b/python/csrc/gpu_utils_py.cpp index 6995756b..60880456 100644 --- a/python/csrc/gpu_utils_py.cpp +++ b/python/csrc/gpu_utils_py.cpp @@ -34,6 +34,19 @@ static DLDataType getDlType(std::string type) { return DLDataType{kDLBfloat, 16, 1}; } else if (type == "torch.float16") { return DLDataType{kDLFloat, 16, 1}; + } else if (type == "torch.float8_e4m3fn") { + return DLDataType{kDLFloat8_e4m3fn, 8, 1}; + } else if (type == "torch.float8_e4m3fnuz") { + return DLDataType{kDLFloat8_e4m3fnuz, 8, 1}; + } else if (type == "torch.float8_e5m2") { + return DLDataType{kDLFloat8_e5m2, 8, 1}; + } else if (type == "torch.float8_e5m2fnuz") { + return DLDataType{kDLFloat8_e5m2fnuz, 8, 1}; + } else if (type == "torch.uint8") { + return DLDataType{kDLUInt, 8, 1}; + } else if (type == "fp8_e4m3b15") { + // No standard DLPack code for fp8_e4m3b15; store as raw uint8 bytes. + return DLDataType{kDLUInt, 8, 1}; } else { throw Error("Unsupported type: " + type, ErrorCode::InvalidUsage); } diff --git a/python/mscclpp/_core/algorithm.py b/python/mscclpp/_core/algorithm.py index 744cf39e..f12a3027 100644 --- a/python/mscclpp/_core/algorithm.py +++ b/python/mscclpp/_core/algorithm.py @@ -177,6 +177,7 @@ class Algorithm: nthreads_per_block=0, symmetric_memory: bool = False, extras: Optional[Dict[str, int]] = None, + accum_dtype: Optional[CppDataType] = None, ) -> int: """Execute the collective algorithm. @@ -194,10 +195,14 @@ class Algorithm: nthreads_per_block: Number of threads per block (0 for auto-selection). symmetric_memory: Whether to use symmetric memory optimization (default: False). extras: Additional algorithm-specific parameters. + accum_dtype: Data type for accumulation during reduction. If None, defaults to + the same as dtype. Use DataType.float32 for high-precision FP8 accumulation. Returns: The result code (0 for success). 
""" + merged_extras = dict(extras) if extras is not None else {} + accum_dtype = accum_dtype if accum_dtype is not None else dtype return self._algorithm.execute( comm, int(input_buffer), @@ -211,7 +216,8 @@ class Algorithm: nblocks, nthreads_per_block, symmetric_memory, - extras if extras is not None else {}, + merged_extras, + int(accum_dtype), ) def reset(self): diff --git a/python/test/test_fp8_accum.py b/python/test/test_fp8_accum.py new file mode 100644 index 00000000..3a6c67f1 --- /dev/null +++ b/python/test/test_fp8_accum.py @@ -0,0 +1,391 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# Correctness test for FP8 allreduce with different accumulation types. +# +# Verifies that FP8 allreduce with higher-precision accumulation produces +# results at least as accurate as native FP8 accumulation, by comparing +# against a float32 reference. +# +# Usage: +# mpirun -np 8 pytest python/test/test_fp8_accum.py -v + +import cupy as cp +import numpy as np +import pytest + +from mscclpp import CommGroup, GpuBuffer, DataType, ReduceOp, is_nvls_supported +from mscclpp.ext import AlgorithmCollectionBuilder +from .mscclpp_mpi import MpiGroup, parametrize_mpi_groups, mpi_group + +# FP8 E4M3 (hardware) requires SM >= 89 (Ada / Hopper) on NVIDIA GPUs. +# On AMD/ROCm (e.g. MI300X), FP8 is supported natively — no skip needed. +_is_hip = hasattr(cp.cuda.runtime, "is_hip") and cp.cuda.runtime.is_hip +# TODO(binyli): Skip hip for now, will fix it in the next PR +_skip_fp8 = _is_hip or int(cp.cuda.Device().compute_capability) < 89 +pytestmark = pytest.mark.skipif(_skip_fp8, reason="FP8 accum tests require SM >= 89 on CUDA (HIP not yet supported)") + +# --------------------------------------------------------------------------- +# FP8 E4M3FN helpers (bias=7, no infinity, NaN = exp=15 & mant=7) +# --------------------------------------------------------------------------- + + +def e4m3fn_to_float(uint8_array): + """Decode a cupy uint8 array of E4M3FN bit patterns to float32.""" + bits = uint8_array.astype(cp.int32) + sign = (bits >> 7) & 1 + exp = (bits >> 3) & 0xF + mant = bits & 0x7 + + # Normal: (-1)^s * 2^(exp-7) * (1 + mant/8) + normal_val = cp.ldexp(cp.float32(1.0) + mant.astype(cp.float32) / cp.float32(8.0), (exp - 7).astype(cp.int32)) + # Subnormal (exp==0): (-1)^s * 2^(-6) * (mant/8) + subnormal_val = cp.ldexp(mant.astype(cp.float32) / cp.float32(8.0), cp.int32(-6)) + + result = cp.where(exp == 0, subnormal_val, normal_val) + result = cp.where(sign == 1, -result, result) + # Zero + result = cp.where((exp == 0) & (mant == 0), cp.float32(0.0), result) + # NaN: exp==15 & mant==7 + nan_mask = (exp == 15) & (mant == 7) + result = cp.where(nan_mask, cp.float32(float("nan")), result) + return result + + +def float_to_e4m3fn(f32_array, chunk_size=65536): + """Encode a cupy float32 array to uint8 E4M3FN bit patterns. + + Uses a lookup-table approach: precompute all 128 positive E4M3FN values, + then find nearest match per element via chunked broadcast comparison. 
+ """ + # Build lookup table of all 128 positive E4M3FN values (0x00..0x7F) + all_bytes = cp.arange(128, dtype=cp.uint8) + all_floats = e4m3fn_to_float(all_bytes) # (128,) float32 + # Mark NaN entries as inf so they're never selected as nearest + all_floats = cp.where(cp.isnan(all_floats), cp.float32(float("inf")), all_floats) + + # Clamp input and extract sign + clamped = f32_array.astype(cp.float32) + clamped = cp.clip(clamped, -448.0, 448.0) + signs = (clamped < 0).astype(cp.uint8) + absval = cp.abs(clamped) + + result = cp.zeros(absval.shape, dtype=cp.uint8) + n = absval.size + absval_flat = absval.ravel() + result_flat = result.ravel() + + for start in range(0, n, chunk_size): + end = min(start + chunk_size, n) + chunk = absval_flat[start:end] + # (chunk_size, 128) difference matrix + diffs = cp.abs(chunk[:, None] - all_floats[None, :]) + result_flat[start:end] = cp.argmin(diffs, axis=1).astype(cp.uint8) + + # Combine with sign bit + result = result_flat.reshape(absval.shape) + result = result | (signs << 7) + # Handle exact zero + result = cp.where(absval == 0, cp.uint8(0), result) + return result + + +# --------------------------------------------------------------------------- +# FP8 E4M3B15 helpers (bias=15, max=0.9375, NaN = exp==15 or bits==0x80) +# --------------------------------------------------------------------------- + + +def e4m3b15_to_float(uint8_array): + """Decode a cupy uint8 array of E4M3B15 bit patterns to float32.""" + bits = uint8_array.astype(cp.int32) + sign = (bits >> 7) & 1 + exp = (bits >> 3) & 0xF + mant = bits & 0x7 + + # Normal: (-1)^s * 2^(exp-15) * (1 + mant/8) + normal_val = cp.ldexp(cp.float32(1.0) + mant.astype(cp.float32) / cp.float32(8.0), (exp - 15).astype(cp.int32)) + # Subnormal (exp==0): (-1)^s * 2^(-14) * (mant/8) + subnormal_val = cp.ldexp(mant.astype(cp.float32) / cp.float32(8.0), cp.int32(-14)) + + result = cp.where(exp == 0, subnormal_val, normal_val) + result = cp.where(sign == 1, -result, result) + # Zero + result = cp.where((exp == 0) & (mant == 0), cp.float32(0.0), result) + # NaN: exp==15 or negative zero (0x80) + nan_mask = (exp == 15) | (uint8_array.astype(cp.int32) == 0x80) + result = cp.where(nan_mask, cp.float32(float("nan")), result) + return result + + +def float_to_e4m3b15(f32_array, chunk_size=65536): + """Encode a cupy float32 array to uint8 E4M3B15 bit patterns. + + Same lookup-table approach as float_to_e4m3fn. 
+ """ + # Build lookup table of all 128 positive E4M3B15 values (0x00..0x7F) + all_bytes = cp.arange(128, dtype=cp.uint8) + all_floats = e4m3b15_to_float(all_bytes) # (128,) float32 + # Mark NaN entries as inf so they're never selected as nearest + all_floats = cp.where(cp.isnan(all_floats), cp.float32(float("inf")), all_floats) + + # Clamp input and extract sign + clamped = f32_array.astype(cp.float32) + clamped = cp.clip(clamped, -0.9375, 0.9375) + signs = (clamped < 0).astype(cp.uint8) + absval = cp.abs(clamped) + + result = cp.zeros(absval.shape, dtype=cp.uint8) + n = absval.size + absval_flat = absval.ravel() + result_flat = result.ravel() + + for start in range(0, n, chunk_size): + end = min(start + chunk_size, n) + chunk = absval_flat[start:end] + # (chunk_size, 128) difference matrix + diffs = cp.abs(chunk[:, None] - all_floats[None, :]) + result_flat[start:end] = cp.argmin(diffs, axis=1).astype(cp.uint8) + + # Combine with sign bit + result = result_flat.reshape(absval.shape) + result = result | (signs << 7) + # Handle exact zero + result = cp.where(absval == 0, cp.uint8(0), result) + return result + + +# --------------------------------------------------------------------------- +# Shared test helpers +# --------------------------------------------------------------------------- + + +def setup_algorithms(mpi_group): + """Build default algorithms and return (comm_group, algo_map, scratch_buf).""" + comm_group = CommGroup(mpi_group.comm) + scratch = GpuBuffer(1 << 27, dtype=cp.uint8) # 128 MB + AlgorithmCollectionBuilder.reset() + builder = AlgorithmCollectionBuilder() + algorithms = builder.build_default_algorithms( + scratch_buffer=scratch.data.ptr, + scratch_buffer_size=scratch.nbytes, + rank=comm_group.my_rank, + ) + algo_map = {a.name: a for a in algorithms} + return comm_group, algo_map, scratch + + +def run_allreduce(algo, comm_group, buffer, dtype, accum_dtype=None, nblocks=0, nthreads_per_block=0): + """Run allreduce in-place on buffer and return a copy of the result.""" + ret = algo.execute( + comm=comm_group.communicator, + input_buffer=buffer.data.ptr, + output_buffer=buffer.data.ptr, + input_size=buffer.nbytes, + output_size=buffer.nbytes, + dtype=dtype, + op=ReduceOp.SUM, + stream=cp.cuda.get_current_stream().ptr, + nblocks=nblocks, + nthreads_per_block=nthreads_per_block, + symmetric_memory=True, + accum_dtype=accum_dtype, + ) + cp.cuda.Device().synchronize() + assert ret == 0, f"Allreduce failed with error code {ret}" + return buffer.copy() + + +# --------------------------------------------------------------------------- +# Test: FP8 E4M3 accumulation correctness +# --------------------------------------------------------------------------- + + +@parametrize_mpi_groups(8) +@pytest.mark.parametrize( + "algo_name", + [ + "default_allreduce_packet", + "default_allreduce_nvls_packet", + "default_allreduce_fullmesh", + "default_allreduce_rsag_zero_copy", + ], +) +@pytest.mark.parametrize("size", [1024, 4096, 16384, 65536, 262144, 1048576]) +def test_fp8_e4m3_accum(mpi_group: MpiGroup, algo_name: str, size: int): + """Verify that FP8 E4M3 allreduce with higher-precision accumulation is at + least as accurate as native FP8 accumulation, across all algorithm variants.""" + rank = mpi_group.comm.rank + world_size = mpi_group.comm.size + + comm_group, algo_map, scratch = setup_algorithms(mpi_group) + if algo_name not in algo_map: + pytest.skip(f"{algo_name} not available") + algo = algo_map[algo_name] + + buf = GpuBuffer(size, dtype=cp.uint8) + + accum_configs = [ + 
("fp8_native", DataType.float8_e4m3), + ("float16", DataType.float16), + ("float32", DataType.float32), + ] + + # rsag_zero_copy and fullmesh need explicit block/thread counts + if "rsag" in algo_name: + nb = max(1, min(32, size // (world_size * 32))) + nt = 1024 + elif "fullmesh" in algo_name: + nb = 35 + nt = 512 + else: + nb = 0 + nt = 0 + + errors = {} + for accum_label, accum_dtype in accum_configs: + # Generate deterministic per-rank data + cp.random.seed(42 + rank) + src_f32 = cp.random.randn(size).astype(cp.float32) + src_f32 = cp.clip(src_f32, -240.0, 240.0) + src_fp8 = float_to_e4m3fn(src_f32) + + # Copy into symmetric buffer + buf[:] = src_fp8 + cp.cuda.Device().synchronize() + + # Run allreduce + result = run_allreduce( + algo, + comm_group, + buf, + dtype=DataType.float8_e4m3, + accum_dtype=accum_dtype, + nblocks=nb, + nthreads_per_block=nt, + ) + result_f32 = e4m3fn_to_float(result) + + # Compute float32 reference: sum all ranks' quantized FP8 inputs in float32 + ref_f32 = cp.zeros(size, dtype=cp.float32) + for r in range(world_size): + cp.random.seed(42 + r) + rank_data = cp.random.randn(size).astype(cp.float32) + rank_data = cp.clip(rank_data, -240.0, 240.0) + rank_data_fp8 = float_to_e4m3fn(rank_data) + ref_f32 += e4m3fn_to_float(rank_data_fp8) + + # Compute errors + abs_err = cp.abs(result_f32 - ref_f32) + mean_abs_err = float(cp.mean(abs_err)) + errors[accum_label] = mean_abs_err + + # Reset between runs + algo.reset() + + # Higher-precision accumulation should be at least as accurate as native fp8 + assert ( + errors["float16"] <= errors["fp8_native"] + 1e-6 + ), f"float16 accum ({errors['float16']:.6f}) worse than native ({errors['fp8_native']:.6f})" + assert ( + errors["float32"] <= errors["fp8_native"] + 1e-6 + ), f"float32 accum ({errors['float32']:.6f}) worse than native ({errors['fp8_native']:.6f})" + + +# --------------------------------------------------------------------------- +# Test: FP8 E4M3B15 accumulation correctness +# --------------------------------------------------------------------------- + + +@parametrize_mpi_groups(8) +@pytest.mark.parametrize( + "algo_name", + [ + "default_allreduce_packet", + "default_allreduce_nvls_packet", + "default_allreduce_rsag_zero_copy", + ], +) +@pytest.mark.parametrize("size", [1024, 4096, 65536]) +def test_fp8_e4m3b15_accum(mpi_group: MpiGroup, algo_name: str, size: int): + """Verify that FP8 E4M3B15 allreduce with higher-precision accumulation is at + least as accurate as native E4M3B15 accumulation.""" + rank = mpi_group.comm.rank + world_size = mpi_group.comm.size + + comm_group, algo_map, scratch = setup_algorithms(mpi_group) + if algo_name not in algo_map: + pytest.skip(f"{algo_name} not available") + + algo = algo_map[algo_name] + buf = GpuBuffer(size, dtype=cp.uint8) + + accum_configs = [ + ("e4m3b15_native", DataType.float8_e4m3b15), + ("float16", DataType.float16), + ("float32", DataType.float32), + ] + + # rsag_zero_copy needs explicit block/thread counts, scaled to data size + if "rsag" in algo_name: + nb = max(1, min(32, size // (world_size * 32))) + nt = 1024 + else: + nb = 0 + nt = 0 + + errors = {} + for accum_label, accum_dtype in accum_configs: + # Generate deterministic per-rank random uint8 values in valid e4m3b15 range + cp.random.seed(42 + rank) + raw = cp.random.randint(0, 0x78, (size,), dtype=cp.uint8) + signs = cp.random.randint(0, 2, (size,), dtype=cp.uint8).astype(cp.uint8) << 7 + src_uint8 = raw | signs + # Fix negative zero -> positive zero + src_uint8 = cp.where(src_uint8 == 0x80, 
cp.uint8(0), src_uint8) + + # Copy into symmetric buffer + buf[:] = src_uint8 + cp.cuda.Device().synchronize() + + # Run allreduce + result = run_allreduce( + algo, + comm_group, + buf, + dtype=DataType.float8_e4m3b15, + accum_dtype=accum_dtype, + nblocks=nb, + nthreads_per_block=nt, + ) + + # Decode result + result_f32 = e4m3b15_to_float(result) + + # Compute float32 reference + ref_f32 = cp.zeros(size, dtype=cp.float32) + for r in range(world_size): + cp.random.seed(42 + r) + raw_r = cp.random.randint(0, 0x78, (size,), dtype=cp.uint8) + signs_r = cp.random.randint(0, 2, (size,), dtype=cp.uint8).astype(cp.uint8) << 7 + bits_r = raw_r | signs_r + bits_r = cp.where(bits_r == 0x80, cp.uint8(0), bits_r) + ref_f32 += e4m3b15_to_float(bits_r) + + # Clamp reference to e4m3b15 representable range + ref_f32 = cp.clip(ref_f32, -0.9375, 0.9375) + + # Compute errors (only on valid entries) + valid = ~cp.isnan(result_f32) & ~cp.isnan(ref_f32) + abs_err = cp.abs(result_f32[valid] - ref_f32[valid]) + mean_abs_err = float(cp.mean(abs_err)) if abs_err.size > 0 else 0.0 + errors[accum_label] = mean_abs_err + + algo.reset() + + # Higher-precision accumulation should be at least as accurate as native + assert ( + errors["float16"] <= errors["e4m3b15_native"] + 1e-8 + ), f"float16 accum ({errors['float16']:.8f}) worse than native ({errors['e4m3b15_native']:.8f})" + assert ( + errors["float32"] <= errors["e4m3b15_native"] + 1e-8 + ), f"float32 accum ({errors['float32']:.8f}) worse than native ({errors['e4m3b15_native']:.8f})" diff --git a/src/core/algorithm.cc b/src/core/algorithm.cc index 99e7b031..ffa53aa8 100644 --- a/src/core/algorithm.cc +++ b/src/core/algorithm.cc @@ -41,7 +41,9 @@ NativeAlgorithm::NativeAlgorithm(std::string name, std::string collective, InitF CommResult NativeAlgorithm::execute(std::shared_ptr comm, const void* input, void* output, size_t inputSize, size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, std::shared_ptr, int nBlocks, int nThreadsPerBlock, - bool symmetricMemory, const std::unordered_map& extras) { + bool symmetricMemory, const std::unordered_map& extras, + DataType accumDtype) { + if (accumDtype == DataType::AUTO) accumDtype = dtype; if (!initialized_) { initFunc_(comm); initialized_ = true; @@ -53,7 +55,7 @@ CommResult NativeAlgorithm::execute(std::shared_ptr comm, const vo contexts_[ctxKey] = ctx; } return kernelLaunchFunc_(contexts_[ctxKey], input, output, inputSize, outputSize, dtype, op, stream, nBlocks, - nThreadsPerBlock, extras); + nThreadsPerBlock, extras, accumDtype); } const std::string& NativeAlgorithm::name() const { return name_; } @@ -77,10 +79,7 @@ const CollectiveBufferMode& NativeAlgorithm::bufferMode() const { return bufferM Algorithm::Constraint NativeAlgorithm::constraint() const { return constraint_; } -void NativeAlgorithm::reset() { - contexts_.clear(); - initialized_ = false; -} +void NativeAlgorithm::reset() { contexts_.clear(); } void AlgorithmCollection::registerAlgorithm(const std::string collective, const std::string algoName, std::shared_ptr algorithm) { @@ -166,7 +165,7 @@ Algorithm::Constraint DslAlgorithm::constraint() const { return constraint_; } CommResult DslAlgorithm::execute(std::shared_ptr comm, const void* input, void* output, size_t inputSize, size_t outputSize, DataType dtype, ReduceOp, cudaStream_t stream, std::shared_ptr executor, int, int, bool, - const std::unordered_map&) { + const std::unordered_map&, DataType) { if (!executor) { THROW(EXEC, Error, ErrorCode::InvalidUsage, "Executor is null in 
DslAlgorithm::execute"); } @@ -192,6 +191,10 @@ CommResult DslAlgorithm::execute(std::shared_ptr comm, const void* plan_, stream); break; #endif + case DataType::FLOAT8_E4M3B15: + executor->execute(rank, (__fp8_e4m3b15*)input, (__fp8_e4m3b15*)output, inputSize, outputSize, + DataType::FLOAT8_E4M3B15, plan_, stream); + break; case DataType::INT32: case DataType::UINT32: executor->execute(rank, (int*)input, (int*)output, inputSize, outputSize, DataType::UINT32, plan_, stream); diff --git a/src/core/executor/execution_kernel.cu b/src/core/executor/execution_kernel.cu index 2d36bcf5..28ced77f 100644 --- a/src/core/executor/execution_kernel.cu +++ b/src/core/executor/execution_kernel.cu @@ -82,6 +82,12 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo case DataType::FLOAT8_E5M2: // FP8 is not supported in CUDA execution kernel. break; + case DataType::FLOAT8_E4M3B15: + // fp8_e4m3b15 is a software type not supported in the CUDA execution kernel. + break; + case DataType::AUTO: + // AUTO is a sentinel resolved before reaching this point; nothing to do. + break; } } diff --git a/src/core/include/execution_kernel.hpp b/src/core/include/execution_kernel.hpp index 20147c30..87b88888 100644 --- a/src/core/include/execution_kernel.hpp +++ b/src/core/include/execution_kernel.hpp @@ -210,7 +210,7 @@ MSCCLPP_DEVICE_INLINE void handleReadReduceSend(const Operation& op, void* input sizeof(int4); void* remoteMemory = static_cast(memoryChannelBufferPtrs_[op.inputBufferRefs[index + 1].id]); val = mscclpp::read(remoteMemory, srcOffset + idx); - tmp = cal_vector(tmp, val); + tmp = calVector(tmp, val); } output4[outputOffset4 + idx] = tmp; if constexpr (SendToRemote) { @@ -353,9 +353,9 @@ MSCCLPP_DEVICE_INLINE void handleReduceSendPackets(const Operation& op, void* in for (uint32_t index = 0; index < nSrcs; ++index) { PacketType* pkt = (PacketType*)((char*)scratch + scratchOffset_ + 2 * inputOffsets[index]); PacketPayload val = pkt[idx].read(flag_); - data = cal_vector(data, val); + data = calVector(data, val); } - data = cal_vector(data, srcPacketPayload[idx]); + data = calVector(data, srcPacketPayload[idx]); dstPacketPayload[idx] = data; if constexpr (SendToRemote) { @@ -394,9 +394,9 @@ MSCCLPP_DEVICE_INLINE void handleReduceCopySendPackets(const Operation& op, void for (uint32_t index = 0; index < nSrcs; ++index) { PacketType* pkt = (PacketType*)((char*)scratch + scratchOffset_ + 2 * inputOffsets[index]); PacketPayload val = pkt[idx].read(flag_); - data = cal_vector(data, val); + data = calVector(data, val); } - data = cal_vector(data, srcPacketPayload[idx]); + data = calVector(data, srcPacketPayload[idx]); dstPacketPayload[idx] = data; PacketType* dst_val = &dstPkt[idx]; dst_val->write(data, flag_); @@ -464,7 +464,7 @@ MSCCLPP_DEVICE_INLINE void handleReduceSend(const Operation& op, void* input, vo size_t buffOffset = (inputOffsets[index] + getOffset(outputBufferRefs[index].type, offset)) / sizeof(int4); int4 val = buff4[buffOffset + idx]; - tmp = cal_vector(tmp, val); + tmp = calVector(tmp, val); } dst4[dstOffset4 + idx] = tmp; if constexpr (SendToRemote) { @@ -899,6 +899,17 @@ class ExecutionKernel { #endif break; #endif // __FP8_TYPES_EXIST__ + case DataType::FLOAT8_E4M3B15: + executionKernel<__fp8_e4m3b15, PacketType, ReuseScratch><<>>( + rank, (__fp8_e4m3b15*)src, (__fp8_e4m3b15*)dst, (__fp8_e4m3b15*)scratch, scratchOffset, scratchChunkSize, + plan, semaphores, localMemoryIdBegin, flag +#if defined(ENABLE_NPKIT) + , + NpKit::GetGpuEventCollectContexts(), 
NpKit::GetCpuTimestamp()); +#else + ); +#endif + break; case DataType::UINT8: executionKernel<<>>( rank, (uint8_t*)src, (uint8_t*)dst, (uint8_t*)scratch, scratchOffset, scratchChunkSize, plan, semaphores, @@ -910,6 +921,10 @@ class ExecutionKernel { ); #endif break; + case DataType::AUTO: + // AUTO is a sentinel that must be resolved before reaching this point. + assert(false && "DataType::AUTO must be resolved before kernel launch"); + break; } } #else // !defined(MSCCLPP_DEVICE_HIP) diff --git a/src/core/include/reduce_kernel.hpp b/src/core/include/reduce_kernel.hpp index fd9bd1e9..463f827d 100644 --- a/src/core/include/reduce_kernel.hpp +++ b/src/core/include/reduce_kernel.hpp @@ -14,7 +14,7 @@ namespace mscclpp { // Generic element-wise calculation helper template -MSCCLPP_DEVICE_INLINE T cal_elements(const T& a, const T& b) { +MSCCLPP_DEVICE_INLINE T calElements(const T& a, const T& b) { if constexpr (OpType == SUM) { return a + b; } else if constexpr (OpType == MIN) { @@ -24,56 +24,168 @@ MSCCLPP_DEVICE_INLINE T cal_elements(const T& a, const T& b) { } // Generic vector reduction helpers -template -MSCCLPP_DEVICE_INLINE int4 cal_vector_helper(const int4& a, const int4& b) { - int4 ret; - ret.w = bit_cast(cal_elements(bit_cast(a.w), bit_cast(b.w))); - ret.x = bit_cast(cal_elements(bit_cast(a.x), bit_cast(b.x))); - ret.y = bit_cast(cal_elements(bit_cast(a.y), bit_cast(b.y))); - ret.z = bit_cast(cal_elements(bit_cast(a.z), bit_cast(b.z))); - return ret; -} template -MSCCLPP_DEVICE_INLINE uint2 cal_vector_helper(const uint2& a, const uint2& b) { +MSCCLPP_DEVICE_INLINE uint2 calVectorHelper(const uint2& a, const uint2& b) { uint2 ret; - ret.x = bit_cast(cal_elements(bit_cast(a.x), bit_cast(b.x))); - ret.y = bit_cast(cal_elements(bit_cast(a.y), bit_cast(b.y))); + ret.x = bit_cast(calElements(bit_cast(a.x), bit_cast(b.x))); + ret.y = bit_cast(calElements(bit_cast(a.y), bit_cast(b.y))); return ret; } -template -MSCCLPP_DEVICE_INLINE int cal_vector_helper(const int& a, const int& b) { - return bit_cast(cal_elements(bit_cast(a), bit_cast(b))); +/// f32x2 specialization for uint2: uses packed f32x2 operator+ (Blackwell __fadd2_rn when available). +template <> +MSCCLPP_DEVICE_INLINE uint2 calVectorHelper(const uint2& a, const uint2& b) { + f32x2 fa = bit_cast(a); + f32x2 fb = bit_cast(b); + f32x2 fr = fa + fb; + return bit_cast(fr); +} + +template <> +MSCCLPP_DEVICE_INLINE uint2 calVectorHelper(const uint2& a, const uint2& b) { + f32x2 fa = bit_cast(a); + f32x2 fb = bit_cast(b); + f32x2 fr = mscclpp::min(fa, fb); + return bit_cast(fr); } template -MSCCLPP_DEVICE_INLINE uint32_t cal_vector_helper(const uint32_t& a, const uint32_t& b) { - return bit_cast(cal_elements(bit_cast(a), bit_cast(b))); +MSCCLPP_DEVICE_INLINE int4 calVectorHelper(const int4& a, const int4& b) { + int4 ret; + ret.w = bit_cast(calElements(bit_cast(a.w), bit_cast(b.w))); + ret.x = bit_cast(calElements(bit_cast(a.x), bit_cast(b.x))); + ret.y = bit_cast(calElements(bit_cast(a.y), bit_cast(b.y))); + ret.z = bit_cast(calElements(bit_cast(a.z), bit_cast(b.z))); + return ret; } -// cal_vector wrapper - converts scalar types to vector types and calls cal_vector_helper +/// f32x2 specialization for int4: process as two uint2 pairs using packed f32x2 arithmetic. 
+template <> +MSCCLPP_DEVICE_INLINE int4 calVectorHelper(const int4& a, const int4& b) { + uint2 lo_a = {(uint32_t)a.x, (uint32_t)a.y}; + uint2 hi_a = {(uint32_t)a.z, (uint32_t)a.w}; + uint2 lo_b = {(uint32_t)b.x, (uint32_t)b.y}; + uint2 hi_b = {(uint32_t)b.z, (uint32_t)b.w}; + uint2 lo_r = calVectorHelper(lo_a, lo_b); + uint2 hi_r = calVectorHelper(hi_a, hi_b); + return {(int)lo_r.x, (int)lo_r.y, (int)hi_r.x, (int)hi_r.y}; +} + +template <> +MSCCLPP_DEVICE_INLINE int4 calVectorHelper(const int4& a, const int4& b) { + uint2 lo_a = {(uint32_t)a.x, (uint32_t)a.y}; + uint2 hi_a = {(uint32_t)a.z, (uint32_t)a.w}; + uint2 lo_b = {(uint32_t)b.x, (uint32_t)b.y}; + uint2 hi_b = {(uint32_t)b.z, (uint32_t)b.w}; + uint2 lo_r = calVectorHelper(lo_a, lo_b); + uint2 hi_r = calVectorHelper(hi_a, hi_b); + return {(int)lo_r.x, (int)lo_r.y, (int)hi_r.x, (int)hi_r.y}; +} + +template +MSCCLPP_DEVICE_INLINE int calVectorHelper(const int& a, const int& b) { + return bit_cast(calElements(bit_cast(a), bit_cast(b))); +} + +template +MSCCLPP_DEVICE_INLINE uint32_t calVectorHelper(const uint32_t& a, const uint32_t& b) { + return bit_cast(calElements(bit_cast(a), bit_cast(b))); +} + +/// f32x2 specialization for uint32_t: a single float packed in 32 bits (scalar fallback). +template <> +MSCCLPP_DEVICE_INLINE uint32_t calVectorHelper(const uint32_t& a, const uint32_t& b) { + float fa = bit_cast(a); + float fb = bit_cast(b); + return bit_cast(fa + fb); +} + +template <> +MSCCLPP_DEVICE_INLINE uint32_t calVectorHelper(const uint32_t& a, const uint32_t& b) { + float fa = bit_cast(a); + float fb = bit_cast(b); + return bit_cast(fminf(fa, fb)); +} + +// calVector wrapper – converts scalar types to vector types and calls calVectorHelper template -MSCCLPP_DEVICE_INLINE DataType cal_vector(const DataType& a, const DataType& b) { +MSCCLPP_DEVICE_INLINE DataType calVector(const DataType& a, const DataType& b) { // Define the vectorized computation type based on the element type static_assert(sizeof(DataType) % sizeof(T) == 0, "DataType size must be multiple of T size"); static_assert(sizeof(DataType) >= 4, "DataType size must be at least 4 bytes"); using CompType = typename std::conditional_t< - std::is_same_v, f16x2, + std::is_same_v, f32x2, std::conditional_t< - std::is_same_v, bf16x2, - std::conditional_t, u8x4, + std::is_same_v, f16x2, + std::conditional_t< + std::is_same_v, bf16x2, + std::conditional_t< + std::is_same_v, u8x4, + std::conditional_t, f8_e4m3b15x4, #if defined(__FP8_TYPES_EXIST__) - std::conditional_t, f8_e4m3x4, - std::conditional_t, f8_e5m2x4, -#endif - T -#if defined(__FP8_TYPES_EXIST__) - >>>>>; + std::conditional_t, f8_e4m3x4, + std::conditional_t, f8_e5m2x4, T>> #else - >>>; + T #endif - return cal_vector_helper(a, b); + >>>>>; + return calVectorHelper(a, b); +} + +/// Upcast a packed DataType (containing T elements) to a packed AccDataType (containing AccumT elements). +/// Uses the optimized to<>() specializations when available (e.g. FP8 -> float hardware intrinsics). +/// When AccumT == T, this is a no-op identity. +template +MSCCLPP_DEVICE_INLINE AccDataType upcastVector(const DataType& val) { + if constexpr (std::is_same_v) { + return val; + } else { + constexpr int nElems = sizeof(DataType) / sizeof(T); + using FromVec = VectorType; + using ToVec = VectorType; + ToVec result = mscclpp::to(reinterpret_cast(val)); + return reinterpret_cast(result); + } +} + +/// Downcast a packed AccDataType (containing AccumT elements) back to DataType (containing T elements). 
+/// Uses the optimized to<>() specializations when available. +/// When AccumT == T, this is a no-op identity. +template +MSCCLPP_DEVICE_INLINE DataType downcastVector(const AccDataType& val) { + if constexpr (std::is_same_v) { + return val; + } else { + constexpr int nElems = sizeof(DataType) / sizeof(T); + using FromVec = VectorType; + using ToVec = VectorType; + FromVec result = mscclpp::to(reinterpret_cast(val)); + return reinterpret_cast(result); + } +} + +/// Accumulate `val` (packed T elements in DataType) into `acc` (packed AccumT elements in AccDataType). +/// When AccumT == T, falls back to the standard calVector. +/// Otherwise, upcasts val to AccumT, reduces element-wise, and returns the AccumT accumulator. +template +MSCCLPP_DEVICE_INLINE AccDataType calVectorAccum(const AccDataType& acc, const DataType& val) { + if constexpr (std::is_same_v) { + return calVector(acc, val); + } else { + constexpr int nElems = sizeof(DataType) / sizeof(T); + using FromVec = VectorType; + using ToVec = VectorType; + + ToVec fv = mscclpp::to(reinterpret_cast(val)); + const ToVec& fa = reinterpret_cast(acc); + ToVec fr; +#pragma unroll + for (int i = 0; i < nElems; ++i) { + fr.data[i] = calElements(fa.data[i], fv.data[i]); + } + return reinterpret_cast(fr); + } } #endif // defined(MSCCLPP_DEVICE_COMPILE) diff --git a/src/ext/collectives/allgather/allgather_fullmesh.cu b/src/ext/collectives/allgather/allgather_fullmesh.cu index 0b288b38..fb51a342 100644 --- a/src/ext/collectives/allgather/allgather_fullmesh.cu +++ b/src/ext/collectives/allgather/allgather_fullmesh.cu @@ -183,7 +183,8 @@ std::shared_ptr AllgatherFullmesh::build() { [self](const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, [[maybe_unused]] DataType dtype, [[maybe_unused]] ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock, - const std::unordered_map& extras) -> CommResult { + const std::unordered_map& extras, + [[maybe_unused]] DataType accumDtype) -> CommResult { return self->allgatherKernelFunc(ctx, input, output, inputSize, stream, nBlocks, nThreadsPerBlock, extras); }, [self](std::shared_ptr comm, const void* input, void* output, size_t inputSize, diff --git a/src/ext/collectives/allgather/allgather_fullmesh_2.cu b/src/ext/collectives/allgather/allgather_fullmesh_2.cu index cf6027d9..9d169d68 100644 --- a/src/ext/collectives/allgather/allgather_fullmesh_2.cu +++ b/src/ext/collectives/allgather/allgather_fullmesh_2.cu @@ -212,7 +212,8 @@ std::shared_ptr AllgatherFullmesh2::build() { [self](const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, [[maybe_unused]] mscclpp::DataType dtype, [[maybe_unused]] ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock, - const std::unordered_map& extras) -> mscclpp::CommResult { + const std::unordered_map& extras, + [[maybe_unused]] mscclpp::DataType accumDtype) -> mscclpp::CommResult { return self->allgatherKernelFunc(ctx, input, output, inputSize, stream, nBlocks, nThreadsPerBlock, extras); }, [self](std::shared_ptr comm, const void* input, void* output, size_t inputSize, diff --git a/src/ext/collectives/allreduce/allreduce_allpair_packet.cu b/src/ext/collectives/allreduce/allreduce_allpair_packet.cu index 83950d7c..6cbc8977 100644 --- a/src/ext/collectives/allreduce/allreduce_allpair_packet.cu +++ b/src/ext/collectives/allreduce/allreduce_allpair_packet.cu @@ -47,7 +47,7 @@ __global__ void allreduceAllPairs(T* buff, T* scratch, T* 
resultBuff, DeviceHand const int remoteRank = index < rank ? index : index + 1; LL8Packet* dstPkt = (LL8Packet*)scratchBuff + remoteRank * nelems; uint32_t val = dstPkt[idx].read(flag, -1); - data = cal_vector(val, data); + data = calVector(val, data); } dst[idx] = data; } @@ -67,7 +67,7 @@ inline std::pair getDefaultBlockNumAndThreadNum(size_t inputSize, int return {(worldSize - 1) * 4, 512}; } -template +template struct AllpairAdapter { static cudaError_t call(const void* buff, void* scratch, void* resultBuff, void* memoryChannels, void*, DeviceHandle*, DeviceHandle*, size_t channelInOffset, size_t, @@ -94,7 +94,8 @@ void AllreduceAllpairPacket::initialize(std::shared_ptr comm) { CommResult AllreduceAllpairPacket::allreduceKernelFunc(const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, [[maybe_unused]] DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock, - const std::unordered_map&) { + const std::unordered_map&, + DataType accumDtype) { auto algoCtx = std::static_pointer_cast(ctx); std::pair blockAndThreadNum{nBlocks, nThreadsPerBlock}; if (blockAndThreadNum.first == 0 || blockAndThreadNum.second == 0) { @@ -105,7 +106,7 @@ CommResult AllreduceAllpairPacket::allreduceKernelFunc(const std::shared_ptr(op, dtype); + AllreduceFunc allreduce = dispatch(op, dtype, accumDtype); if (!allreduce) { WARN("Unsupported operation or data type for allreduce: op=%d, dtype=%d", op, static_cast(dtype)); return CommResult::CommInvalidArgument; @@ -161,9 +162,9 @@ std::shared_ptr AllreduceAllpairPacket::build() { [self](std::shared_ptr comm) { self->initialize(comm); }, [self](const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, - int nThreadsPerBlock, const std::unordered_map& extras) { + int nThreadsPerBlock, const std::unordered_map& extras, DataType accumDtype) { return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock, - extras); + extras, accumDtype); }, [self](std::shared_ptr comm, const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, diff --git a/src/ext/collectives/allreduce/allreduce_fullmesh.cu b/src/ext/collectives/allreduce/allreduce_fullmesh.cu index 13c63ba1..ee46fd77 100644 --- a/src/ext/collectives/allreduce/allreduce_fullmesh.cu +++ b/src/ext/collectives/allreduce/allreduce_fullmesh.cu @@ -9,7 +9,7 @@ namespace mscclpp { namespace collective { -template +template __global__ void __launch_bounds__(512, 1) allreduceFullmesh(T* buff, T* scratch, T* resultBuff, DeviceHandle* memoryChannels, DeviceHandle* memoryOutChannels, size_t channelOutDataOffset, int rank, @@ -26,6 +26,10 @@ __global__ void __launch_bounds__(512, 1) int4* scratch4 = reinterpret_cast((char*)scratch); int4* resultBuff4 = reinterpret_cast(resultBuff); + // AccumVec: wider vector for mixed-precision accumulation. When AccumT==T, this is just int4 (no-op). 
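// For instance (illustrative; the wide form is assumed to be mscclpp::VectorType<AccumT, nElemsPerInt4>,
// matching the comment above -- the exact template arguments are not visible in this hunk):
//   T = __fp8_e4m3, AccumT = float  ->  nElemsPerInt4 = 16, AccumVec = VectorType<float, 16>:
//     each int4 of fp8 input is widened once, reduced across peers in fp32, then narrowed once on store;
//   T = AccumT = __half             ->  AccumVec stays int4 and upcastVector/downcastVector are
//     identities, so this path behaves exactly as the previous single-precision-type kernel.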
+ constexpr int nElemsPerInt4 = sizeof(int4) / sizeof(T); + using AccumVec = std::conditional_t, int4, mscclpp::VectorType>; + // Distribute `nInt4PerRank` across all blocks with the unit size `unitNInt4` constexpr size_t unitNInt4 = 512; const size_t maxNInt4PerBlock = @@ -81,12 +85,14 @@ __global__ void __launch_bounds__(512, 1) __syncthreads(); for (size_t idx = threadIdx.x; idx < nInt4PerChunk; idx += blockDim.x) { - int4 data = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock]; + int4 rawData = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock]; + AccumVec acc = mscclpp::upcastVector(rawData); for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) { const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1; int4 val = scratch4[chunkSizePerRank * remoteRank + blockOffset + idx]; - data = cal_vector(val, data); + acc = mscclpp::calVectorAccum(acc, val); } + int4 data = mscclpp::downcastVector(acc); resultBuff4[nInt4PerRank * rank + idx + offsetOfThisBlock] = data; for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) { outChannels[peerIdx].write(nInt4PerRank * rank + idx + offsetOfThisBlock + channelOutDataOffset / sizeof(int4), @@ -121,12 +127,14 @@ __global__ void __launch_bounds__(512, 1) __syncthreads(); for (size_t idx = threadIdx.x; idx < restNInt4; idx += blockDim.x) { - int4 data = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock]; + int4 rawData = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock]; + AccumVec acc = mscclpp::upcastVector(rawData); for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) { const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1; int4 val = scratch4[chunkSizePerRank * remoteRank + blockOffset + idx]; - data = cal_vector(val, data); + acc = mscclpp::calVectorAccum(acc, val); } + int4 data = mscclpp::downcastVector(acc); resultBuff4[nInt4PerRank * rank + idx + offsetOfThisBlock] = data; for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) { outChannels[peerIdx].write(nInt4PerRank * rank + idx + offsetOfThisBlock + channelOutDataOffset / sizeof(int4), @@ -144,7 +152,7 @@ __global__ void __launch_bounds__(512, 1) } } -template +template struct AllreduceAllconnectAdapter { static cudaError_t call(const void* input, void* scratch, void* output, void* memoryChannels, void* memoryOutChannels, DeviceHandle*, DeviceHandle*, size_t, @@ -155,7 +163,7 @@ struct AllreduceAllconnectAdapter { size_t nelems = inputSize / sizeof(T); if (nBlocks == 0) nBlocks = 35; if (nThreadsPerBlock == 0) nThreadsPerBlock = 512; - allreduceFullmesh<<>>( + allreduceFullmesh<<>>( (T*)input, (T*)scratch, (T*)output, (ChannelType*)memoryChannels, (ChannelType*)memoryOutChannels, channelOutDataOffset, rank, nRanksPerNode, worldSize, nelems); return cudaGetLastError(); @@ -174,10 +182,10 @@ void AllreduceFullmesh::initialize(std::shared_ptr comm) { localScratchMemory_ = std::move(localMemory); } -CommResult AllreduceFullmesh::allreduceKernelFunc(const std::shared_ptr ctx_void, const void* input, void* output, - size_t inputSize, DataType dtype, ReduceOp op, cudaStream_t stream, - int nBlocks, int nThreadsPerBlock, - const std::unordered_map&) { +CommResult AllreduceFullmesh::allreduceKernelFunc( + const std::shared_ptr ctx_void, const void* input, void* output, size_t inputSize, DataType dtype, + ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock, + [[maybe_unused]] const std::unordered_map& extras, DataType accumDtype) { auto ctx = std::static_pointer_cast(ctx_void); size_t recvBytes; CUdeviceptr recvBasePtr; @@ -198,7 +206,7 @@ CommResult 
AllreduceFullmesh::allreduceKernelFunc(const std::shared_ptr ct } inputChannelHandles = this->memoryChannelsMap_[input].second; - AllreduceFunc allreduce = dispatch(op, dtype); + AllreduceFunc allreduce = dispatch(op, dtype, accumDtype); if (!allreduce) { WARN("Unsupported operation or data type for allreduce: op=%d, dtype=%d", static_cast(op), static_cast(dtype)); @@ -261,9 +269,10 @@ std::shared_ptr AllreduceFullmesh::build() { [self](std::shared_ptr comm) { self->initialize(comm); }, [self](const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, - int nThreadsPerBlock, const std::unordered_map& extras) -> CommResult { + int nThreadsPerBlock, const std::unordered_map& extras, + DataType accumDtype) -> CommResult { return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock, - extras); + extras, accumDtype); }, [self](std::shared_ptr comm, const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, diff --git a/src/ext/collectives/allreduce/allreduce_nvls_block_pipeline.cu b/src/ext/collectives/allreduce/allreduce_nvls_block_pipeline.cu index b542a6a6..2d71cd63 100644 --- a/src/ext/collectives/allreduce/allreduce_nvls_block_pipeline.cu +++ b/src/ext/collectives/allreduce/allreduce_nvls_block_pipeline.cu @@ -146,7 +146,7 @@ __global__ void __launch_bounds__(1024, 1) #endif } -template +template struct NvlsBlockPipelineAdapter { static cudaError_t call(const void* input, void* scratch, void* output, void* memoryChannels, void*, DeviceHandle* nvlsChannels, DeviceHandle*, size_t, size_t, @@ -155,6 +155,9 @@ struct NvlsBlockPipelineAdapter { // uint8_t is not supported for NVLS (no hardware support for byte-level reduction) if constexpr (std::is_same_v) { return cudaErrorNotSupported; + } else if constexpr (std::is_same_v) { + // fp8_e4m3b15 is a software-only type with no hardware NVLS support. 
+ return cudaErrorNotSupported; } else #if defined(__CUDA_ARCH__) // Skip the __CUDA_ARCH__ < 1000 since FP8 has not been supported for NVLS if constexpr (std::is_same_v || std::is_same_v) { @@ -187,9 +190,10 @@ void AllreduceNvlsBlockPipeline::initialize(std::shared_ptr comm) CommResult AllreduceNvlsBlockPipeline::allreduceKernelFunc(const std::shared_ptr ctx_void, const void* input, void* output, size_t inputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock, - const std::unordered_map&) { + const std::unordered_map& extras, + DataType accumDtype) { auto ctx = std::static_pointer_cast(ctx_void); - AllreduceFunc allreduce = dispatch(op, dtype); + AllreduceFunc allreduce = dispatch(op, dtype, accumDtype); if (!allreduce) { WARN("Unsupported operation or data type for allreduce, dtype=%d", static_cast(dtype)); return CommResult::CommInvalidArgument; @@ -235,9 +239,9 @@ std::shared_ptr AllreduceNvlsBlockPipeline::build() { [self](std::shared_ptr comm) { self->initialize(comm); }, [self](const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, - int nThreadsPerBlock, const std::unordered_map& extras) { + int nThreadsPerBlock, const std::unordered_map& extras, DataType accumDtype) { return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock, - extras); + extras, accumDtype); }, [self](std::shared_ptr comm, const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, diff --git a/src/ext/collectives/allreduce/allreduce_nvls_packet.cu b/src/ext/collectives/allreduce/allreduce_nvls_packet.cu index 9824fbcd..a616485e 100644 --- a/src/ext/collectives/allreduce/allreduce_nvls_packet.cu +++ b/src/ext/collectives/allreduce/allreduce_nvls_packet.cu @@ -1,15 +1,17 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +#include + #include "allreduce/allreduce_nvls_packet.hpp" #include "allreduce/common.hpp" #include "collective_utils.hpp" -#include "debug.h" +#include "logger.hpp" namespace mscclpp { namespace collective { -template +template __global__ void __launch_bounds__(1024, 1) allreduceNvlsPacket([[maybe_unused]] const T* input, [[maybe_unused]] T* scratch, [[maybe_unused]] T* output, [[maybe_unused]] mscclpp::DeviceHandle* multicast, @@ -31,15 +33,16 @@ __global__ void __launch_bounds__(1024, 1) mscclpp::SwitchChannelDeviceHandle::multimemStore(*(mscclpp::f32x2*)(&pkt), multiPkt + i); } for (uint32_t i = tid; i < nPktPerRank * worldSize; i += blockDim.x * gridDim.x) { - uint data = src[i]; + // When T == AccumT, stay with raw uint to avoid type mismatch in identity path. 
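// For example (illustrative; the wide form is assumed to be VectorType<AccumT, sizeof(uint)/sizeof(T)>,
// mirroring the comment above -- the exact template arguments are not visible in this hunk):
//   T = AccumT = __half            ->  AccRaw = uint; calVectorAccum falls back to calVector on the
//                                      packed 32-bit payload, so the identity path does no conversions;
//   T = __fp8_e4m3, AccumT = float ->  AccRaw holds four fp32 lanes: the four fp8 values packed in one
//                                      32-bit payload word are widened, reduced across peers, and
//                                      narrowed once by downcastVector on the final store.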
+ using AccRaw = + std::conditional_t, uint, mscclpp::VectorType>; + AccRaw acc = mscclpp::upcastVector(src[i]); for (int peer = 0; peer < worldSize; peer++) { - if (peer == rank) { - continue; - } + if (peer == rank) continue; uint val = scratchPkt[peer * worldSize * nPktPerRank + i].read(flag); - data = cal_vector(data, val); + acc = mscclpp::calVectorAccum(acc, val); } - dst[i] = data; + dst[i] = mscclpp::downcastVector(acc); } __syncthreads(); if (threadIdx.x == 0) { @@ -62,13 +65,13 @@ inline std::pair getDefaultBlockNumAndThreadNum(size_t inputSize) { return {blockNum, threadNum}; } -template +template struct AllreduceNvlsPacketAdapter { static cudaError_t call(const void* input, void* scratch, void* output, void*, void*, DeviceHandle* nvlsChannels, DeviceHandle*, size_t, size_t, size_t scratchBufferSize, int rank, int, int worldSize, size_t inputSize, cudaStream_t stream, void* flags, uint32_t flagBufferSize, uint32_t, int nBlocks, int nThreadsPerBlock) { - allreduceNvlsPacket<<>>( + allreduceNvlsPacket<<>>( (const T*)input, (T*)scratch, (T*)output, nvlsChannels, inputSize / sizeof(T), scratchBufferSize, rank, worldSize, flags, flagBufferSize); return cudaGetLastError(); @@ -78,6 +81,8 @@ struct AllreduceNvlsPacketAdapter { void AllreduceNvlsPacket::initialize(std::shared_ptr comm) { int nSwitchChannels = 1; this->nvlsConnections_ = setupNvlsConnections(comm, nvlsBufferSize_, nSwitchChannels); + this->switchChannels_ = + setupNvlsChannels(this->nvlsConnections_, this->scratchBuffer_, this->scratchBufferSize_, nSwitchChannels); } AlgorithmCtxKey AllreduceNvlsPacket::generateAllreduceContextKey(const void*, void*, size_t, DataType, bool) { @@ -92,9 +97,7 @@ std::shared_ptr AllreduceNvlsPacket::initAllreduceContext(std::shared_ptr< ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode(); // setup channels - int nSwitchChannels = 1; - ctx->switchChannels = - setupNvlsChannels(this->nvlsConnections_, this->scratchBuffer_, this->scratchBufferSize_, nSwitchChannels); + ctx->switchChannels = this->switchChannels_; ctx->switchChannelDeviceHandles = setupNvlsChannelDeviceHandles(ctx->switchChannels); return ctx; } @@ -102,19 +105,20 @@ std::shared_ptr AllreduceNvlsPacket::initAllreduceContext(std::shared_ptr< CommResult AllreduceNvlsPacket::allreduceKernelFunc(const std::shared_ptr ctx_void, const void* input, void* output, size_t inputSize, mscclpp::DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock, - const std::unordered_map&) { + const std::unordered_map&, + mscclpp::DataType accumDtype) { auto ctx = std::static_pointer_cast(ctx_void); std::pair blockAndThreadNum = {nBlocks, nThreadsPerBlock}; if (blockAndThreadNum.first == 0 || blockAndThreadNum.second == 0) { blockAndThreadNum = getDefaultBlockNumAndThreadNum(inputSize); } if (blockAndThreadNum.first > maxBlockNum_) { - WARN("Block number %d exceeds the maximum limit %d", blockAndThreadNum.first, maxBlockNum_); + WARN(ALGO, "Block number ", blockAndThreadNum.first, " exceeds the maximum limit ", maxBlockNum_); return CommResult::CommInvalidArgument; } - AllreduceFunc allreduce = dispatch(op, dtype); + AllreduceFunc allreduce = dispatch(op, dtype, accumDtype); if (!allreduce) { - WARN("Unsupported operation or data type for allreduce, dtype=%d", static_cast(dtype)); + WARN(ALGO, "Unsupported operation or data type for allreduce, dtype=", static_cast(dtype)); return CommResult::CommInvalidArgument; } cudaError_t error = @@ -122,7 +126,7 @@ CommResult AllreduceNvlsPacket::allreduceKernelFunc(const 
std::shared_ptr 0, 0, this->scratchBufferSize_, ctx->rank, ctx->nRanksPerNode, ctx->workSize, inputSize, stream, (void*)flagBuffer_, (uint32_t)flagBufferSize_, 0, blockAndThreadNum.first, blockAndThreadNum.second); if (error != cudaSuccess) { - WARN("AllreduceNvlsPacket failed with error: %s", cudaGetErrorString(error)); + WARN(ALGO, "AllreduceNvlsPacket failed with error: ", cudaGetErrorString(error)); return CommResult::CommUnhandledCudaError; } return CommResult::CommSuccess; @@ -136,9 +140,10 @@ std::shared_ptr AllreduceNvlsPacket::build() { [self](std::shared_ptr comm) { self->initialize(comm); }, [self](const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, mscclpp::DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, - int nThreadsPerBlock, const std::unordered_map& extras) { + int nThreadsPerBlock, const std::unordered_map& extras, + mscclpp::DataType accumDtype) { return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock, - extras); + extras, accumDtype); }, [self](std::shared_ptr comm, const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, diff --git a/src/ext/collectives/allreduce/allreduce_nvls_warp_pipeline.cu b/src/ext/collectives/allreduce/allreduce_nvls_warp_pipeline.cu index bc03ab26..3bb054da 100644 --- a/src/ext/collectives/allreduce/allreduce_nvls_warp_pipeline.cu +++ b/src/ext/collectives/allreduce/allreduce_nvls_warp_pipeline.cu @@ -109,7 +109,7 @@ __global__ void __launch_bounds__(1024, 1) #endif } -template +template struct NvlsWarpPipelineAdapter { static cudaError_t call(const void* input, void* scratch, void* output, void* memoryChannels, void*, DeviceHandle* nvlsChannels, DeviceHandle*, size_t, size_t, @@ -118,6 +118,9 @@ struct NvlsWarpPipelineAdapter { // uint8_t is not supported for NVLS (no hardware support for byte-level reduction) if constexpr (std::is_same_v) { return cudaErrorNotSupported; + } else if constexpr (std::is_same_v) { + // fp8_e4m3b15 is a software-only type with no hardware NVLS support. 
+ return cudaErrorNotSupported; } else #if defined(__CUDA_ARCH__) // Skip the __CUDA_ARCH__ < 1000 since FP8 has not been supported for NVLS if constexpr (std::is_same_v || std::is_same_v) { @@ -147,12 +150,12 @@ void AllreduceNvlsWarpPipeline::initialize(std::shared_ptr comm) { this->nvlsConnections_ = setupNvlsConnections(comm, nvlsBufferSize_, nSwitchChannels_); } -CommResult AllreduceNvlsWarpPipeline::allreduceKernelFunc(const std::shared_ptr ctx_void, const void* input, - void* output, size_t inputSize, DataType dtype, ReduceOp op, - cudaStream_t stream, int nBlocks, int nThreadsPerBlock, - const std::unordered_map&) { +CommResult AllreduceNvlsWarpPipeline::allreduceKernelFunc( + const std::shared_ptr ctx_void, const void* input, void* output, size_t inputSize, DataType dtype, + ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock, + [[maybe_unused]] const std::unordered_map& extras, DataType accumDtype) { auto ctx = std::static_pointer_cast(ctx_void); - AllreduceFunc allreduce = dispatch(op, dtype); + AllreduceFunc allreduce = dispatch(op, dtype, accumDtype); if (!allreduce) { WARN("Unsupported operation or data type for allreduce, dtype=%d", static_cast(dtype)); return CommResult::CommInvalidArgument; @@ -198,9 +201,9 @@ std::shared_ptr AllreduceNvlsWarpPipeline::build() { [self](std::shared_ptr comm) { self->initialize(comm); }, [self](const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, - int nThreadsPerBlock, const std::unordered_map& extras) { + int nThreadsPerBlock, const std::unordered_map& extras, DataType accumDtype) { return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock, - extras); + extras, accumDtype); }, [self](std::shared_ptr comm, const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, diff --git a/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu b/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu index f251bcda..e7f2028f 100644 --- a/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu +++ b/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu @@ -67,7 +67,7 @@ __global__ void __launch_bounds__(1024, 1) #endif } -template +template struct NvlsAdapter { static cudaError_t call(const void*, void*, void*, void* memoryChannels, void*, mscclpp::DeviceHandle* nvlsChannels, @@ -77,6 +77,9 @@ struct NvlsAdapter { // uint8_t is not supported for NVLS (no hardware support for byte-level reduction) if constexpr (std::is_same_v) { return cudaErrorNotSupported; + } else if constexpr (std::is_same_v) { + // fp8_e4m3b15 is a software-only type with no hardware NVLS support. 
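Both NVLS adapters reject `uint8_t` and the software-only `fp8_e4m3b15` type at compile time with `if constexpr`, so the common dispatcher can still instantiate the adapter for every element/accumulator combination while unsupported types fail fast before any launch. A hedged sketch of that guard follows; the `fp8_e4m3b15` struct here is only a placeholder, and the real adapters take many more launch parameters.

```
// Sketch of the if-constexpr type guard used by the NVLS adapters above.
#include <cuda_runtime.h>
#include <cstdint>
#include <type_traits>

struct fp8_e4m3b15 {
  uint8_t bits;  // placeholder for mscclpp's software-only FP8 type
};

template <typename T>
cudaError_t callNvlsAdapter(const T* input, T* output, size_t count, cudaStream_t stream) {
  if constexpr (std::is_same_v<T, uint8_t> || std::is_same_v<T, fp8_e4m3b15>) {
    // NVLS switch reduction has no hardware path for byte types or this FP8
    // variant, so report "not supported" instead of launching a kernel.
    return cudaErrorNotSupported;
  } else {
    // ... launch the NVLS reduction kernel for supported element types ...
    (void)input;
    (void)output;
    (void)count;
    (void)stream;
    return cudaSuccess;
  }
}
```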
+ return cudaErrorNotSupported; } else #if (!defined(__CUDA_ARCH_SPECIFIC__) && !defined(__CUDA_ARCH_FAMILY_SPECIFIC__)) || (__CUDA_ARCH__ < 1000) if constexpr (std::is_same_v || std::is_same_v) { @@ -114,13 +117,14 @@ void AllreduceNvls::initialize(std::shared_ptr comm) { CommResult AllreduceNvls::allreduceKernelFunc(const std::shared_ptr ctx_void, const void* input, void* output, size_t inputSize, mscclpp::DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock, - const std::unordered_map&) { + [[maybe_unused]] const std::unordered_map& extras, + mscclpp::DataType accumDtype) { if (!symmetricMemory_) { WARN("AllreduceNvls requires symmetric memory for now."); return CommResult::CommInvalidArgument; } auto ctx = std::static_pointer_cast(ctx_void); - AllreduceFunc allreduce = dispatch(op, dtype); + AllreduceFunc allreduce = dispatch(op, dtype, accumDtype); if (!allreduce) { WARN("Unsupported operation or data type for allreduce, dtype=%d", static_cast(dtype)); return CommResult::CommInvalidArgument; @@ -203,9 +207,10 @@ std::shared_ptr AllreduceNvls::build() { [self](std::shared_ptr comm) { self->initialize(comm); }, [self](const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, mscclpp::DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, - int nThreadsPerBlock, const std::unordered_map& extras) { + int nThreadsPerBlock, const std::unordered_map& extras, + mscclpp::DataType accumDtype) { return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock, - extras); + extras, accumDtype); }, [self](std::shared_ptr comm, const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, diff --git a/src/ext/collectives/allreduce/allreduce_packet.cu b/src/ext/collectives/allreduce/allreduce_packet.cu index ceb545ee..e2d8ef73 100644 --- a/src/ext/collectives/allreduce/allreduce_packet.cu +++ b/src/ext/collectives/allreduce/allreduce_packet.cu @@ -2,16 +2,17 @@ // Licensed under the MIT License. #include +#include #include "allreduce/allreduce_packet.hpp" #include "allreduce/common.hpp" #include "collective_utils.hpp" -#include "debug.h" +#include "logger.hpp" namespace mscclpp { namespace collective { -template +template __global__ void __launch_bounds__(1024, 1) allreducePacket(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle* memoryChannels, size_t channelDataOffset, size_t scratchBufferSize, int rank, int nRanksPerNode, int worldSize, @@ -92,12 +93,21 @@ __global__ void __launch_bounds__(1024, 1) // step 2: get data from scratch buffer, reduce data and write result to remote scratch buffer for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nPktsPerRank; idx += blockDim.x * gridDim.x) { uint2 data = src[idx]; - for (int index = 0; index < nPeers; index++) { - const int remoteRank = index < rank ? index : index + 1; - mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)scratchBuff + remoteRank * nPktsPerRank; - uint2 val = dstPkt[idx].read(flag); - data.x = cal_vector(val.x, data.x); - data.y = cal_vector(val.y, data.y); + { + // When T == AccumT, stay with raw uint32_t to avoid type mismatch in identity path. + using AccRaw = std::conditional_t, uint32_t, + mscclpp::VectorType>; + AccRaw accX = mscclpp::upcastVector(data.x); + AccRaw accY = mscclpp::upcastVector(data.y); + for (int index = 0; index < nPeers; index++) { + const int remoteRank = index < rank ? 
index : index + 1; + mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)scratchBuff + remoteRank * nPktsPerRank; + uint2 val = dstPkt[idx].read(flag); + accX = mscclpp::calVectorAccum(accX, val.x); + accY = mscclpp::calVectorAccum(accY, val.y); + } + data.x = mscclpp::downcastVector(accX); + data.y = mscclpp::downcastVector(accY); } dst[idx].x = data.x; @@ -142,7 +152,7 @@ __global__ void __launch_bounds__(1024, 1) #endif } -template +template struct PacketAdapter { static cudaError_t call(const void* buff, void* scratch, void* resultBuff, void* memoryChannels, void*, DeviceHandle*, DeviceHandle*, size_t channelInOffset, size_t, @@ -155,12 +165,12 @@ struct PacketAdapter { nBlocks = nBlocks / (worldSize - 1) * (worldSize - 1); #if defined(ENABLE_NPKIT) size_t sharedMemSize = sizeof(NpKitEvent) * NPKIT_SHM_NUM_EVENTS; - allreducePacket<<>>( + allreducePacket<<>>( (T*)buff, (T*)scratch, (T*)resultBuff, (ChannelType*)memoryChannels, channelInOffset, scratchBufferSize, rank, nRanksPerNode, worldSize, nelems, flags, flagBufferSize, numScratchBuff, NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp()); #else - allreducePacket<<>>( + allreducePacket<<>>( (T*)buff, (T*)scratch, (T*)resultBuff, (ChannelType*)memoryChannels, channelInOffset, scratchBufferSize, rank, nRanksPerNode, worldSize, nelems, flags, flagBufferSize, numScratchBuff); #endif @@ -186,18 +196,22 @@ inline std::pair getDefaultBlockNumAndThreadNum(size_t inputSize, int } } -#if defined(__FP8_TYPES_EXIST__) // FP8-specific tuning for 32KB-256KB range - if (dtype == DataType::FLOAT8_E4M3 || dtype == DataType::FLOAT8_E5M2) { - if (inputSize < (64 << 10)) { - nThreadsPerBlock = 64; - } else if (inputSize >= (64 << 10) && inputSize <= (128 << 10)) { - nThreadsPerBlock = 128; - } else if (inputSize >= (128 << 10) && inputSize <= (256 << 10)) { - nThreadsPerBlock = 256; + { + bool isFp8 = dtype == DataType::FLOAT8_E4M3B15; +#if defined(__FP8_TYPES_EXIST__) + isFp8 = isFp8 || dtype == DataType::FLOAT8_E4M3 || dtype == DataType::FLOAT8_E5M2; +#endif + if (isFp8) { + if (inputSize < (64 << 10)) { + nThreadsPerBlock = 64; + } else if (inputSize >= (64 << 10) && inputSize <= (128 << 10)) { + nThreadsPerBlock = 128; + } else if (inputSize >= (128 << 10) && inputSize <= (256 << 10)) { + nThreadsPerBlock = 256; + } } } -#endif #endif return {nBlocks, nThreadsPerBlock}; } @@ -213,7 +227,8 @@ void AllreducePacket::initialize(std::shared_ptr comm) { CommResult AllreducePacket::allreduceKernelFunc(const std::shared_ptr ctx_void, const void* input, void* output, size_t inputSize, [[maybe_unused]] DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock, - const std::unordered_map&) { + const std::unordered_map&, + DataType accumDtype) { auto ctx = std::static_pointer_cast(ctx_void); std::pair blockAndThreadNum = {nBlocks, nThreadsPerBlock}; if (blockAndThreadNum.first == 0 || blockAndThreadNum.second == 0) { @@ -225,9 +240,10 @@ CommResult AllreducePacket::allreduceKernelFunc(const std::shared_ptr ctx_ MSCCLPP_CUTHROW(cuMemGetAddressRange(&sendBasePtr, &sendBytes, (CUdeviceptr)input)); size_t channelInOffset = (char*)input - (char*)sendBasePtr; - AllreduceFunc allreduce = dispatch(op, dtype); + AllreduceFunc allreduce = dispatch(op, dtype, accumDtype); if (!allreduce) { - WARN("Unsupported operation or data type for allreduce: op=%d, dtype=%d", op, static_cast(dtype)); + WARN(ALGO, "Unsupported operation or data type for allreduce: op=", static_cast(op), + ", dtype=", static_cast(dtype)); return 
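The tuning hunk above replaces the preprocessor-only FP8 branch with a runtime `isFp8` flag, so `FLOAT8_E4M3B15` gets the small-message thread-count tuning even when the toolchain lacks compiler FP8 types. The selection logic is summarized below; the `DataType` enum is a stand-in that only lists the members this sketch needs, and the thresholds mirror the diff.

```
// Sketch of the FP8 threads-per-block heuristic from getDefaultBlockNumAndThreadNum.
#include <cstddef>

enum class DataType { FLOAT8_E4M3, FLOAT8_E5M2, FLOAT8_E4M3B15, FLOAT16, FLOAT32 };

inline int fp8ThreadsPerBlock(DataType dtype, size_t inputSize, int fallback) {
  bool isFp8 = dtype == DataType::FLOAT8_E4M3B15;
#if defined(__FP8_TYPES_EXIST__)
  // Compiler-provided FP8 types are only considered when the toolchain has them.
  isFp8 = isFp8 || dtype == DataType::FLOAT8_E4M3 || dtype == DataType::FLOAT8_E5M2;
#endif
  if (!isFp8) return fallback;
  if (inputSize < (64 << 10)) return 64;    // below 64 KB
  if (inputSize <= (128 << 10)) return 128; // 64 KB - 128 KB
  if (inputSize <= (256 << 10)) return 256; // 128 KB - 256 KB
  return fallback;                          // larger sizes keep the default
}
```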
CommResult::CommInvalidArgument; } cudaError_t error = @@ -236,7 +252,7 @@ CommResult AllreducePacket::allreduceKernelFunc(const std::shared_ptr ctx_ stream, (void*)flagBuffer_, (uint32_t)flagBufferSize_, this->nSegmentsForScratchBuffer_, blockAndThreadNum.first, blockAndThreadNum.second); if (error != cudaSuccess) { - WARN("AllreducePacket failed with error: %s", cudaGetErrorString(error)); + WARN(ALGO, "AllreducePacket failed with error: ", cudaGetErrorString(error)); return CommResult::CommUnhandledCudaError; } return CommResult::CommSuccess; @@ -280,9 +296,9 @@ std::shared_ptr AllreducePacket::build() { "default_allreduce_packet", "allreduce", [self](std::shared_ptr comm) { self->initialize(comm); }, [self](const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, - int nThreadsPerBlock, const std::unordered_map& extras) { + int nThreadsPerBlock, const std::unordered_map& extras, DataType accumDtype) { return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock, - extras); + extras, accumDtype); }, [self](std::shared_ptr comm, const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, diff --git a/src/ext/collectives/allreduce/allreduce_rsag.cu b/src/ext/collectives/allreduce/allreduce_rsag.cu index d5be2257..db471b93 100644 --- a/src/ext/collectives/allreduce/allreduce_rsag.cu +++ b/src/ext/collectives/allreduce/allreduce_rsag.cu @@ -87,7 +87,7 @@ __global__ void __launch_bounds__(1024, 1) int rankIdx = (rank + i + 1) % nRanksPerNode; int peerIdx = rankIdx < rank ? rankIdx : rankIdx - 1; int4 data = mscclpp::read(((void**)remoteMemories)[peerIdx], offset); - tmp = cal_vector(data, tmp); + tmp = calVector(data, tmp); } for (uint32_t i = 0; i < nPeers; i++) { int rankIdx = (rank + i + 1) % nRanksPerNode; @@ -123,7 +123,7 @@ __global__ void __launch_bounds__(1024, 1) } } -template +template struct AllreduceRsAgAdapter { static cudaError_t call(const void* input, void* scratch, void* output, void* memoryChannels, void* remoteMemories, DeviceHandle* switchChannel, DeviceHandle*, size_t, size_t, @@ -166,9 +166,9 @@ void AllreduceRsAg::initialize(std::shared_ptr comm) { CommResult AllreduceRsAg::allreduceKernelFunc(const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock, - const std::unordered_map&) { + const std::unordered_map&, DataType accumDtype) { auto algoCtx = std::static_pointer_cast(ctx); - AllreduceFunc allreduce = dispatch(op, dtype); + AllreduceFunc allreduce = dispatch(op, dtype, accumDtype); if (!allreduce) { WARN(ALGO, "Unsupported operation or data type for allreduce: op=", static_cast(op), ", dtype=", static_cast(dtype)); @@ -213,9 +213,10 @@ std::shared_ptr AllreduceRsAg::build() { [self](std::shared_ptr comm) { self->initialize(comm); }, [self](const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, - int nThreadsPerBlock, const std::unordered_map& extras) -> CommResult { + int nThreadsPerBlock, const std::unordered_map& extras, + DataType accumDtype) -> CommResult { return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock, - extras); + extras, accumDtype); }, [self](std::shared_ptr comm, const void* input, 
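Every `allreduceKernelFunc` now calls `dispatch(op, dtype, accumDtype)` rather than `dispatch(op, dtype)`, i.e. the accumulation type becomes a second template dimension of the adapter that the dispatcher resolves at runtime. The toy version below shows the idea only; the real mscclpp dispatch covers far more operations and types and returns the project's `AllreduceFunc`, whose full signature is not reproduced here.

```
// Conceptual sketch: map (op, dtype, accumDtype) to an adapter specialization.
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstddef>
#include <functional>

enum class DataType { FLOAT16, FLOAT32 };
enum class ReduceOp { SUM };

using LaunchFn = std::function<cudaError_t(const void*, void*, size_t, cudaStream_t)>;

template <typename T, typename AccumT>
struct Adapter {
  static cudaError_t call(const void* in, void* out, size_t count, cudaStream_t stream) {
    // ... launch allreduce<T, AccumT> here ...
    (void)in; (void)out; (void)count; (void)stream;
    return cudaSuccess;
  }
};

inline LaunchFn dispatch(ReduceOp op, DataType dtype, DataType accumDtype) {
  if (op != ReduceOp::SUM) return nullptr;  // unsupported op: caller WARNs and bails out
  if (dtype == DataType::FLOAT16 && accumDtype == DataType::FLOAT32)
    return &Adapter<__half, float>::call;   // widened accumulation
  if (dtype == DataType::FLOAT16 && accumDtype == DataType::FLOAT16)
    return &Adapter<__half, __half>::call;  // identity path: accumulate in-type
  if (dtype == DataType::FLOAT32 && accumDtype == DataType::FLOAT32)
    return &Adapter<float, float>::call;
  return nullptr;                           // unsupported combination
}
```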
void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, diff --git a/src/ext/collectives/allreduce/allreduce_rsag_pipeline.cu b/src/ext/collectives/allreduce/allreduce_rsag_pipeline.cu index a230d8cd..eabe3dc5 100644 --- a/src/ext/collectives/allreduce/allreduce_rsag_pipeline.cu +++ b/src/ext/collectives/allreduce/allreduce_rsag_pipeline.cu @@ -168,7 +168,7 @@ __global__ void __launch_bounds__(1024, 1) uint32_t peerSlotOffset = baseOffset + remoteRankId * nInt4PerIter + threadIdInPut + putStep * blockDim.x * nblocksForPut; int4 data = scratch4[peerSlotOffset]; - tmp = cal_vector(data, tmp); + tmp = calVector(data, tmp); } storeVec(resultBuff, myChunkOffset, tmp, nelems); // Broadcast reduced result to all peers' scratch at SCATTER_AG_OFFSET + rank * nInt4PerIter @@ -220,7 +220,7 @@ __global__ void __launch_bounds__(1024, 1) } } -template +template struct AllreduceRsAgPipelineAdapter { static cudaError_t call(const void* input, void* scratch, void* output, void* memoryChannels, void* remoteMemories, DeviceHandle* switchChannel, DeviceHandle*, size_t, size_t, @@ -274,12 +274,12 @@ void AllreduceRsAgPipeline::initialize(std::shared_ptr comm) { cudaMemcpyHostToDevice); } -CommResult AllreduceRsAgPipeline::allreduceKernelFunc(const std::shared_ptr ctx, const void* input, void* output, - size_t inputSize, DataType dtype, ReduceOp op, - cudaStream_t stream, int nBlocks, int nThreadsPerBlock, - const std::unordered_map&) { +CommResult AllreduceRsAgPipeline::allreduceKernelFunc( + const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, DataType dtype, ReduceOp op, + cudaStream_t stream, int nBlocks, int nThreadsPerBlock, + [[maybe_unused]] const std::unordered_map& extras, DataType accumDtype) { auto algoCtx = std::static_pointer_cast(ctx); - AllreduceFunc allreduce = dispatch(op, dtype); + AllreduceFunc allreduce = dispatch(op, dtype, accumDtype); if (!allreduce) { WARN(ALGO, "Unsupported operation or data type for allreduce: op=", static_cast(op), ", dtype=", static_cast(dtype)); @@ -320,9 +320,10 @@ std::shared_ptr AllreduceRsAgPipeline::build() { [self](std::shared_ptr comm) { self->initialize(comm); }, [self](const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, - int nThreadsPerBlock, const std::unordered_map& extras) -> CommResult { + int nThreadsPerBlock, const std::unordered_map& extras, + DataType accumDtype) -> CommResult { return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock, - extras); + extras, accumDtype); }, [self](std::shared_ptr comm, const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, diff --git a/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu b/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu index caac07ae..f95ba7e3 100644 --- a/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu +++ b/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +#include + #include "allreduce/allreduce_rsag_zero_copy.hpp" #include "allreduce/common.hpp" #include "collective_utils.hpp" @@ -36,7 +38,7 @@ __device__ mscclpp::DeviceSyncer globalSyncer; // the extra copy steps of the standard RSAG. The NRanksPerNode template // parameter enables compile-time unrolling of peer loops (supports 4 or 8). 
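As the comment above notes, `NRanksPerNode` is a compile-time template parameter so the per-peer loops can be fully unrolled; the adapter in the next hunk then picks the 4-rank or 8-rank instantiation from the runtime rank count. A simplified sketch of that selection is below (the kernel body and the real channel/scratch arguments are omitted).

```
// Sketch: dispatch a runtime rank count onto a compile-time template parameter.
#include <cuda_runtime.h>

template <int NRanksPerNode, typename T>
__global__ void rsAgZeroCopyKernel(T* buff, T* result, size_t nelems) {
  constexpr int NPeers = NRanksPerNode - 1;  // known at compile time, so loops unroll
  // ... reduce-scatter + all-gather over NPeers peers ...
  (void)NPeers; (void)buff; (void)result; (void)nelems;
}

template <typename T>
cudaError_t launchRsAgZeroCopy(T* buff, T* result, size_t nelems, int nRanksPerNode,
                               int nBlocks, int nThreads, cudaStream_t stream) {
  if (nRanksPerNode == 4) {
    rsAgZeroCopyKernel<4, T><<<nBlocks, nThreads, 0, stream>>>(buff, result, nelems);
  } else if (nRanksPerNode == 8) {
    rsAgZeroCopyKernel<8, T><<<nBlocks, nThreads, 0, stream>>>(buff, result, nelems);
  } else {
    return cudaErrorInvalidValue;  // only the instantiated rank counts are supported
  }
  return cudaGetLastError();
}
```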
-template +template __global__ void __launch_bounds__(1024, 1) allreduceRsAgZeroCopy(T* buff, T* scratch, T* resultBuff, DeviceHandle* memoryChannels, DeviceHandle* switchChannels, void* remoteMemories, int rank, int worldSize, @@ -73,19 +75,26 @@ __global__ void __launch_bounds__(1024, 1) } __syncthreads(); int4 data[NPeers]; + // AccumInt4: when AccumT != T, use a wider accumulator type. + // For AccumT == T, this is just int4 (no-op conversion). + constexpr int nElemsPerInt4 = sizeof(int4) / sizeof(T); + // When T == AccumT, stay with raw int4 to avoid type mismatch in identity path. + using AccumVec = std::conditional_t, int4, mscclpp::VectorType>; for (uint32_t idx = threadIdx.x; idx < nInt4PerBlock; idx += blockDim.x) { uint32_t offset = idx + offset4 + rank * nInt4PerRank; if (offset >= nInt4Total) continue; - int4 tmp = buff4[offset]; + int4 tmp_raw = buff4[offset]; #pragma unroll for (int i = 0; i < NPeers; i++) { int rankIdx = (rank + i + 1) % NRanksPerNode; int peerIdx = rankIdx < rank ? rankIdx : rankIdx - 1; data[i] = mscclpp::read(((void**)remoteMemories)[peerIdx], offset); } + AccumVec acc = mscclpp::upcastVector(tmp_raw); for (int i = 0; i < NPeers; i++) { - tmp = cal_vector(data[i], tmp); + acc = mscclpp::calVectorAccum(acc, data[i]); } + int4 tmp = mscclpp::downcastVector(acc); #pragma unroll for (int i = 0; i < NPeers; i++) { int rankIdx = (rank + i + 1) % NRanksPerNode; @@ -102,7 +111,7 @@ __global__ void __launch_bounds__(1024, 1) } } -template +template struct AllreduceRsAgZeroCopyAdapter { static cudaError_t call(const void* input, void* scratch, void* output, void* memoryChannels, void* remoteMemories, DeviceHandle* switchChannel, DeviceHandle*, size_t, size_t, @@ -118,11 +127,11 @@ struct AllreduceRsAgZeroCopyAdapter { } } if (nRanksPerNode == 4) { - allreduceRsAgZeroCopy<4, OpType, T> + allreduceRsAgZeroCopy<4, OpType, T, AccumT> <<>>((T*)input, (T*)scratch, (T*)output, (ChannelType*)memoryChannels, switchChannel, remoteMemories, rank, worldSize, nelems); } else if (nRanksPerNode == 8) { - allreduceRsAgZeroCopy<8, OpType, T> + allreduceRsAgZeroCopy<8, OpType, T, AccumT> <<>>((T*)input, (T*)scratch, (T*)output, (ChannelType*)memoryChannels, switchChannel, remoteMemories, rank, worldSize, nelems); } else { @@ -145,9 +154,10 @@ void AllreduceRsAgZeroCopy::initialize(std::shared_ptr comm) { CommResult AllreduceRsAgZeroCopy::allreduceKernelFunc(const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock, - const std::unordered_map&) { + const std::unordered_map&, + DataType accumDtype) { auto algoCtx = std::static_pointer_cast(ctx); - AllreduceFunc allreduce = dispatch(op, dtype); + AllreduceFunc allreduce = dispatch(op, dtype, accumDtype); if (!allreduce) { WARN(ALGO, "Unsupported operation or data type for allreduce: op=", static_cast(op), ", dtype=", static_cast(dtype)); @@ -220,9 +230,10 @@ std::shared_ptr AllreduceRsAgZeroCopy::build() { [self](std::shared_ptr comm) { self->initialize(comm); }, [self](const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, - int nThreadsPerBlock, const std::unordered_map& extras) -> CommResult { + int nThreadsPerBlock, const std::unordered_map& extras, + DataType accumDtype) -> CommResult { return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, op, stream, nBlocks, nThreadsPerBlock, - extras); + 
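These hunks also migrate warnings from the printf-style `WARN("...%d", x)` of `debug.h` to the stream-style `WARN(ALGO, "part ", x)` calls of `logger.hpp`, which take a category plus a list of parts to concatenate. The actual `logger.hpp` is not part of this patch and may differ; the sketch below is only one plausible shape for such a variadic interface.

```
// Hypothetical sketch of a stream-style warning helper comparable to the
// WARN(ALGO, ...) call sites above; the real mscclpp logger.hpp is not shown here.
#include <iostream>
#include <sstream>
#include <utility>

enum class LogCategory { ALGO, NET, INIT };

template <typename... Args>
void warnLog(LogCategory category, Args&&... args) {
  std::ostringstream oss;
  (oss << ... << std::forward<Args>(args));  // C++17 fold: concatenate all parts
  std::cerr << "[WARN][" << static_cast<int>(category) << "] " << oss.str() << "\n";
}

// Usage mirrors the diff's call sites:
//   warnLog(LogCategory::ALGO, "Block number ", nBlocks, " exceeds the maximum limit ", maxBlocks);
```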
extras, accumDtype); }, [self](std::shared_ptr comm, const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize, diff --git a/src/ext/collectives/include/allreduce/allreduce_allpair_packet.hpp b/src/ext/collectives/include/allreduce/allreduce_allpair_packet.hpp index bd402cfa..362308b2 100644 --- a/src/ext/collectives/include/allreduce/allreduce_allpair_packet.hpp +++ b/src/ext/collectives/include/allreduce/allreduce_allpair_packet.hpp @@ -20,7 +20,7 @@ class AllreduceAllpairPacket : public AlgorithmBuilder { void initialize(std::shared_ptr comm); CommResult allreduceKernelFunc(const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock, - const std::unordered_map& extras); + const std::unordered_map& extras, DataType accumDtype); std::shared_ptr initAllreduceContext(std::shared_ptr comm, const void*, void* output, size_t, DataType); diff --git a/src/ext/collectives/include/allreduce/allreduce_fullmesh.hpp b/src/ext/collectives/include/allreduce/allreduce_fullmesh.hpp index fa811b15..a54352b3 100644 --- a/src/ext/collectives/include/allreduce/allreduce_fullmesh.hpp +++ b/src/ext/collectives/include/allreduce/allreduce_fullmesh.hpp @@ -16,7 +16,7 @@ class AllreduceFullmesh : public mscclpp::AlgorithmBuilder { void initialize(std::shared_ptr comm); CommResult allreduceKernelFunc(const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock, - const std::unordered_map& extras); + const std::unordered_map& extras, DataType accumDtype); std::shared_ptr initAllreduceContext(std::shared_ptr comm, const void*, void* output, size_t, DataType); diff --git a/src/ext/collectives/include/allreduce/allreduce_nvls_block_pipeline.hpp b/src/ext/collectives/include/allreduce/allreduce_nvls_block_pipeline.hpp index 8b9b04ae..81b74add 100644 --- a/src/ext/collectives/include/allreduce/allreduce_nvls_block_pipeline.hpp +++ b/src/ext/collectives/include/allreduce/allreduce_nvls_block_pipeline.hpp @@ -19,7 +19,7 @@ class AllreduceNvlsBlockPipeline : public AlgorithmBuilder { void initialize(std::shared_ptr comm); CommResult allreduceKernelFunc(const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock, - const std::unordered_map& extras); + const std::unordered_map& extras, DataType accumDtype); std::shared_ptr initAllreduceContext(std::shared_ptr comm, const void*, void* output, size_t, DataType); diff --git a/src/ext/collectives/include/allreduce/allreduce_nvls_packet.hpp b/src/ext/collectives/include/allreduce/allreduce_nvls_packet.hpp index 65a48923..fb0c63b8 100644 --- a/src/ext/collectives/include/allreduce/allreduce_nvls_packet.hpp +++ b/src/ext/collectives/include/allreduce/allreduce_nvls_packet.hpp @@ -21,7 +21,8 @@ class AllreduceNvlsPacket : public mscclpp::AlgorithmBuilder { void initialize(std::shared_ptr comm); CommResult allreduceKernelFunc(const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, mscclpp::DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, - int nThreadsPerBlock, const std::unordered_map& extras); + int nThreadsPerBlock, const std::unordered_map& extras, + mscclpp::DataType accumDtype); std::shared_ptr initAllreduceContext(std::shared_ptr comm, const void*, void* output, size_t, mscclpp::DataType); @@ -34,6 +35,7 @@ class 
AllreduceNvlsPacket : public mscclpp::AlgorithmBuilder { uintptr_t flagBuffer_; size_t flagBufferSize_; std::vector> nvlsConnections_; + std::vector switchChannels_; }; } // namespace collective } // namespace mscclpp diff --git a/src/ext/collectives/include/allreduce/allreduce_nvls_warp_pipeline.hpp b/src/ext/collectives/include/allreduce/allreduce_nvls_warp_pipeline.hpp index e392b54e..8f02a873 100644 --- a/src/ext/collectives/include/allreduce/allreduce_nvls_warp_pipeline.hpp +++ b/src/ext/collectives/include/allreduce/allreduce_nvls_warp_pipeline.hpp @@ -19,7 +19,7 @@ class AllreduceNvlsWarpPipeline : public AlgorithmBuilder { void initialize(std::shared_ptr comm); CommResult allreduceKernelFunc(const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock, - const std::unordered_map& extras); + const std::unordered_map& extras, DataType accumDtype); std::shared_ptr initAllreduceContext(std::shared_ptr comm, const void*, void* output, size_t, DataType); diff --git a/src/ext/collectives/include/allreduce/allreduce_nvls_zero_copy.hpp b/src/ext/collectives/include/allreduce/allreduce_nvls_zero_copy.hpp index d0593500..d53ea180 100644 --- a/src/ext/collectives/include/allreduce/allreduce_nvls_zero_copy.hpp +++ b/src/ext/collectives/include/allreduce/allreduce_nvls_zero_copy.hpp @@ -19,7 +19,7 @@ class AllreduceNvls : public AlgorithmBuilder { void initialize(std::shared_ptr comm); CommResult allreduceKernelFunc(const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock, - const std::unordered_map& extras); + const std::unordered_map& extras, DataType accumDtype); std::shared_ptr initAllreduceContext(std::shared_ptr comm, const void*, void* output, size_t, DataType); diff --git a/src/ext/collectives/include/allreduce/allreduce_packet.hpp b/src/ext/collectives/include/allreduce/allreduce_packet.hpp index f0438dea..de7ca471 100644 --- a/src/ext/collectives/include/allreduce/allreduce_packet.hpp +++ b/src/ext/collectives/include/allreduce/allreduce_packet.hpp @@ -20,7 +20,7 @@ class AllreducePacket : public AlgorithmBuilder { void initialize(std::shared_ptr comm); CommResult allreduceKernelFunc(const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock, - const std::unordered_map& extras); + const std::unordered_map& extras, DataType accumDtype); std::shared_ptr initAllreduceContext(std::shared_ptr comm, const void*, void* output, size_t, DataType); diff --git a/src/ext/collectives/include/allreduce/allreduce_rsag.hpp b/src/ext/collectives/include/allreduce/allreduce_rsag.hpp index 6e033f67..1fd663da 100644 --- a/src/ext/collectives/include/allreduce/allreduce_rsag.hpp +++ b/src/ext/collectives/include/allreduce/allreduce_rsag.hpp @@ -19,7 +19,7 @@ class AllreduceRsAg : public mscclpp::AlgorithmBuilder { void initialize(std::shared_ptr comm); CommResult allreduceKernelFunc(const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock, - const std::unordered_map& extras); + const std::unordered_map& extras, DataType accumDtype); std::shared_ptr initAllreduceContext(std::shared_ptr comm, const void*, void* output, size_t, DataType); diff --git 
a/src/ext/collectives/include/allreduce/allreduce_rsag_pipeline.hpp b/src/ext/collectives/include/allreduce/allreduce_rsag_pipeline.hpp index 2a740ac0..7629f2fe 100644 --- a/src/ext/collectives/include/allreduce/allreduce_rsag_pipeline.hpp +++ b/src/ext/collectives/include/allreduce/allreduce_rsag_pipeline.hpp @@ -19,7 +19,7 @@ class AllreduceRsAgPipeline : public mscclpp::AlgorithmBuilder { void initialize(std::shared_ptr comm); CommResult allreduceKernelFunc(const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock, - const std::unordered_map& extras); + const std::unordered_map& extras, DataType accumDtype); std::shared_ptr initAllreduceContext(std::shared_ptr comm, const void*, void* output, size_t, DataType); diff --git a/src/ext/collectives/include/allreduce/allreduce_rsag_zero_copy.hpp b/src/ext/collectives/include/allreduce/allreduce_rsag_zero_copy.hpp index 6153a0e4..05bf2ef3 100644 --- a/src/ext/collectives/include/allreduce/allreduce_rsag_zero_copy.hpp +++ b/src/ext/collectives/include/allreduce/allreduce_rsag_zero_copy.hpp @@ -18,7 +18,7 @@ class AllreduceRsAgZeroCopy : public mscclpp::AlgorithmBuilder { void initialize(std::shared_ptr comm); CommResult allreduceKernelFunc(const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock, - const std::unordered_map& extras); + const std::unordered_map& extras, DataType accumDtype); std::shared_ptr initAllreduceContext(std::shared_ptr comm, const void*, void* output, size_t, DataType); diff --git a/src/ext/collectives/include/allreduce/common.hpp b/src/ext/collectives/include/allreduce/common.hpp index 9bfac69a..1e0e7e69 100644 --- a/src/ext/collectives/include/allreduce/common.hpp +++ b/src/ext/collectives/include/allreduce/common.hpp @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -#ifndef MSCCLPP_ALLREDUCE_COMMOM_HPP_ -#define MSCCLPP_ALLREDUCE_COMMOM_HPP_ +#ifndef MSCCLPP_ALLREDUCE_COMMON_HPP_ +#define MSCCLPP_ALLREDUCE_COMMON_HPP_ #include #include @@ -77,55 +77,51 @@ using AllreduceFunc = mscclpp::DeviceHandle*, size_t, size_t, size_t, int, int, int, size_t, cudaStream_t, void*, uint32_t, uint32_t, int, int)>; -template