Commit Graph

179 Commits

Binyang Li
5380a4ac6e Add MSCCLPP_IB_GID_INDEX env (#780)
Use `MSCCLPP_IB_GID_INDEX` to control the IB GID index.
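The lookup can be sketched in Python (illustrative only; the real handling is in C++, and the fallback value here is an assumption):

```python
import os

def ib_gid_index(default: int = -1) -> int:
    """Resolve the IB GID index from MSCCLPP_IB_GID_INDEX.

    Falls back to `default` (an assumed sentinel, not the library's actual
    default) when the variable is unset or malformed.
    """
    raw = os.environ.get("MSCCLPP_IB_GID_INDEX")
    if raw is None:
        return default
    try:
        return int(raw)
    except ValueError:
        return default

os.environ.pop("MSCCLPP_IB_GID_INDEX", None)
assert ib_gid_index() == -1          # unset: fall back
os.environ["MSCCLPP_IB_GID_INDEX"] = "3"
assert ib_gid_index() == 3           # env override wins
```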

---------

Co-authored-by: Changho Hwang <changhohwang@microsoft.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-04-13 09:59:42 -07:00
Changho Hwang
d63f9403c0 IB host-no-atomic: GDRCopy + mlx5dv Data Direct for memory-consistent low-latency signaling (#753)
Major enhancements to the IB signal forwarding mechanism
(`host-no-atomic` mode): primarily adds support for GDRCopy and MLX5
Direct Verbs and refactors the signal forwarding path for IB
HostNoAtomic mode. The changes fix memory consistency issues and reduce
signaling latency.
- GDRCopy and MLX5 Direct Verbs MR integration
- Signal forwarding path redesign
- Semaphore and connection API updates
- Environment (`MSCCLPP_FORCE_DISABLE_GDR`) and documentation updates
2026-04-09 09:24:30 +00:00
Binyang Li
8896cd909a Add ROCm FP8 E4M3B15 support (#774)
## Summary

Add ROCm (gfx942) support for the FP8 E4M3B15 data type, including
optimized conversion routines between FP8 E4M3B15 and FP16/FP32 using
inline assembly.

Extends the allpair packet and fullmesh allreduce kernels to support
higher-precision accumulation (e.g., FP16/FP32) when reducing FP8 data,
improving numerical accuracy.

Adds Python tests to verify that higher-precision accumulation is at
least as accurate as native FP8 accumulation across all algorithm
variants.

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-04-08 09:53:45 -07:00
Binyang Li
96a72bbd3e Support E4M3B15 datatype (#765)
## Summary

- **Add `fp8_e4m3b15` datatype**: A software-defined FP8 type with 4
exponent bits, 3 mantissa bits, and bias=15 (max finite value: 0.9375).
Implemented entirely in software with no HW dependency, using
Triton-style bit manipulation through fp16 as an intermediate for
efficient conversion.
- **Add mixed-precision accumulation for allreduce**: All allreduce
algorithm variants (packet, NVLS packet, fullmesh, RSAG zero-copy, and
others) now support a configurable `accumDtype` parameter, enabling FP8
inputs to be reduced in float16 or float32 for higher accuracy.
- **Propagate `accumDtype` through the full API**: The new parameter is
threaded from `Algorithm::execute()` → `NativeAlgorithm` → `KernelFunc`
→ dispatch → CUDA kernels, with `DataType::AUTO` as the default
(resolves to input dtype at runtime).
- **Add FP8 accumulation correctness tests**: New `test_fp8_accum.py`
validates that higher-precision accumulation produces results at least
as accurate as native FP8 accumulation across multiple algorithms and
sizes. Skipped on CUDA SM < 89 (pre-Hopper); runs on HIP/ROCm.
- **Add `test_fp8_accum.py` to CI**: Azure Pipeline `ut.yml` now runs
FP8 accumulation tests alongside existing pytests.
- **NCCL shim logging cleanup**: Migrated `printf`-style `WARN`/`INFO`
calls to streaming-style logging.
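The software-defined format can be sketched in pure Python. This is an illustrative model, not the repo's fp16-based conversion path; treating exponent code 15 as non-finite is an assumption inferred from the stated max finite value 0.9375 (= 2^-1 × 1.875):

```python
BIAS = 15

def decode_e4m3b15(b: int) -> float:
    """Decode one fp8_e4m3b15 byte (1 sign, 4 exponent, 3 mantissa bits)."""
    s = -1.0 if (b >> 7) & 1 else 1.0
    e = (b >> 3) & 0xF
    m = b & 0x7
    if e == 15:                       # assumed reserved (non-finite)
        return float("nan")
    if e == 0:                        # subnormal: no implicit leading 1
        return s * (m / 8.0) * 2.0 ** (1 - BIAS)
    return s * (1.0 + m / 8.0) * 2.0 ** (e - BIAS)

def encode_e4m3b15(x: float) -> int:
    """Round-to-nearest encode by brute force over all 256 codes."""
    finite = [(abs(decode_e4m3b15(b) - x), b) for b in range(256)
              if decode_e4m3b15(b) == decode_e4m3b15(b)]  # skip NaNs
    return min(finite)[1]

# Max finite value: exponent code 14, mantissa 7 -> 2**-1 * 1.875 = 0.9375
assert decode_e4m3b15(encode_e4m3b15(0.9375)) == 0.9375
assert decode_e4m3b15(encode_e4m3b15(-0.25)) == -0.25
```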

## Key files

| Area | Files |
|------|-------|
| New datatype + vector ops | `include/mscclpp/gpu_data_types.hpp` |
| Accumulation reduce helpers | `src/core/include/reduce_kernel.hpp` |
| Algorithm API (`accumDtype`) | `include/mscclpp/algorithm.hpp`,
`src/core/algorithm.cc` |
| Allreduce kernels | `src/ext/collectives/allreduce/*.cu` |
| Dispatch + common | `src/ext/collectives/include/allreduce/common.hpp`
|
| Python bindings | `python/csrc/algorithm.cpp`,
`python/mscclpp/_core/algorithm.py` |
| Tests | `python/test/test_fp8_accum.py` |
| CI | `.azure-pipelines/templates/ut.yml` |

## Test plan

- [x] CI passes on H100 (CUDA SM 90) — full FP8 E4M3 + E4M3B15
accumulation tests
- [x] CI passes on A100 (CUDA SM 80) — FP8 tests correctly skipped
- [x] CI passes on MI300X (ROCm) — FP8 tests run via HIP
- [x] Existing `test_mscclpp.py` tests continue to pass
- [x] NCCL shim builds and runs correctly with new `accumDtype` defaults

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-04-07 13:37:02 -07:00
Binyang Li
4f3638b60d Use PTX red for D2D semaphore signal (#768)
## Summary
- Replace the two-step `signal()` implementation (`incOutbound()` +
`atomicStore()`) with a single fire-and-forget PTX
`red.release.sys.global.add.u64` instruction
- This eliminates one local atomic fetch-add and replaces a remote store
with a remote atomic add that has no return value — more efficient on
both NVIDIA (PTX `red`) and AMD (compiler optimizes `(void)fetch_add` to
fire-and-forget `flat_atomic_add_x2`)
- Add a C++ perf test (`PERF_TEST`) in `mp_unit` for signal+wait
ping-pong latency

### Performance (H100, 2 ranks, signal+wait round-trip)

```
SemaphorePerfTest.SignalPingPong:
  Store-based (old): 2.595 us/iter
  Red-based   (new): 2.345 us/iter
  Speedup:           1.11x
```
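The before/after semantics can be modeled in plain Python (a host-side sketch of the device-side logic, not the PTX itself; counter names are illustrative):

```python
class SemaphoreModel:
    """Models the D2D semaphore counters touched by signal()/wait()."""
    def __init__(self):
        self.local_outbound = 0   # sender-side counter
        self.remote_inbound = 0   # receiver-side counter

    def signal_old(self):
        # Two steps: a local atomic fetch-add, then a remote atomic store.
        self.local_outbound += 1
        self.remote_inbound = self.local_outbound

    def signal_new(self):
        # One fire-and-forget remote atomic add (PTX `red` on NVIDIA):
        # no local counter read, no returned value.
        self.remote_inbound += 1

    def wait(self, expected: int) -> bool:
        return self.remote_inbound >= expected

old, new = SemaphoreModel(), SemaphoreModel()
for _ in range(3):
    old.signal_old()
    new.signal_new()
assert old.wait(3) and new.wait(3)   # same observable state, fewer ops
```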

## Test plan
- [x] Builds successfully (`make mp_unit_tests`)
- [x] `mpirun -np 2 ./build/bin/mp_unit_tests --filter
"SemaphorePerfTest"` — 1.11x speedup

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-31 15:34:43 -07:00
Xingbo Wu
69565a2f32 Do threadInit/cudaSetDevice before other cuda calls (#757)
I recently encountered a strange memory usage issue.
After starting the proxy service on a CUDA device X > 0, an unexpected
thread context appears on both GPU X and GPU 0, with GPU 0's share at
about 500 MB. Note that when the device is 0, there is no extra memory
usage.
The image clearly shows that when 8 ranks each use one GPU and start
proxies, GPU 0 sees 7 extra threads, each consuming 500 MB of extra
memory.
<img width="1247" height="1367" alt="Screenshot 2026-02-28 000153"
src="https://github.com/user-attachments/assets/cfd0d47f-319b-4ebb-bf19-dec66062e6f4"
/>


After tracking down when it happens, I identified the root cause in the
proxy thread initialization.

    // never capture in a proxy thread
    auto mode = cudaStreamCaptureModeRelaxed;
    MSCCLPP_CUDATHROW(cudaThreadExchangeStreamCaptureMode(&mode));

    pimpl_->threadInit();

The call to cudaThreadExchangeStreamCaptureMode() actually triggers some
resource allocation on the "current device", which is still 0 for the
starting thread.
The later threadInit() comes too late to set the correct GPU.

The fix is simple: call threadInit() before the first CUDA call:

    pimpl_->threadInit();
    // never capture in a proxy thread
    auto mode = cudaStreamCaptureModeRelaxed;
    MSCCLPP_CUDATHROW(cudaThreadExchangeStreamCaptureMode(&mode));

This guarantees that the current device is properly set before calling
any resource-allocating CUDA functions.
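The ordering bug can be modeled with a mock runtime (names and the 500 MB figure are illustrative stand-ins for the CUDA behavior described above):

```python
import threading

class MockCudaRuntime:
    """Charges per-thread context memory to whatever device is current."""
    def __init__(self):
        self.local = threading.local()
        self.mem_per_device = {}

    def current_device(self):
        return getattr(self.local, "device", 0)   # threads start on device 0

    def set_device(self, d):                       # threadInit() analogue
        self.local.device = d

    def any_api_call(self):
        # First API call on a thread allocates context on the current device.
        d = self.current_device()
        self.mem_per_device[d] = self.mem_per_device.get(d, 0) + 500

rt = MockCudaRuntime()

def buggy_proxy(device):
    rt.any_api_call()      # runs while the current device is still 0
    rt.set_device(device)  # too late

def fixed_proxy(device):
    rt.set_device(device)  # threadInit() first
    rt.any_api_call()

t = threading.Thread(target=buggy_proxy, args=(3,)); t.start(); t.join()
assert rt.mem_per_device == {0: 500}       # 500 MB landed on GPU 0

rt.mem_per_device.clear()
t = threading.Thread(target=fixed_proxy, args=(3,)); t.start(); t.join()
assert rt.mem_per_device == {3: 500}       # now charged to the right GPU
```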

This is the memory usage after the fix. The extra memory usages are
gone.

<img width="1242" height="459" alt="Image (1)"
src="https://github.com/user-attachments/assets/4256e4c8-6f1d-4844-9f77-5b2935387df9"
/>

---------

Co-authored-by: Binyang Li <binyli@microsoft.com>
2026-03-02 15:53:59 -08:00
Caio Rocha
4bc1999001 Adding Support to Setting Message Size Range in Native Algorithm API (#758) 2026-02-27 17:50:43 -08:00
Binyang Li
39865c218b address flagBuffer ownership issue (#749)
This pull request updates the handling of the default flag buffer in the
C++ and Python bindings to ensure proper memory management when
interfacing with Python.

Ensures the buffer is not deallocated when ownership is transferred
from C++ to Python.
2026-02-20 13:42:29 -08:00
Binyang Li
4701ae3a95 Update dtype name (#748)
- Change FP8_E4M3/FP8_E5M2 to FLOAT8_E4M3/FLOAT8_E5M2
- Add torch.uint8 to DataType.uint8 mapping
2026-02-18 10:35:44 -08:00
Qinghua Zhou
edc9c38751 Support uint8 data type for Allreduce (#736)
Support uint8 data type for Allreduce.
Current limitation: uint8 is not supported for NVLS.

Performance results with RCCL-test with MSCCLPP on MI300X:


| size (B) | count (elements) | type | redop | root | oop time (us) | oop algbw (GB/s) | oop busbw (GB/s) | oop #wrong | ip time (us) | ip algbw (GB/s) | ip busbw (GB/s) | ip #wrong |
| -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- |
| 1024 | 512 | half | sum | -1 | 5.39 | 0.19 | 0.33 | 0 | 5.45 | 0.19 | 0.33 | 0 |
| 2048 | 1024 | half | sum | -1 | 5.53 | 0.37 | 0.65 | 0 | 5.63 | 0.36 | 0.64 | 0 |
| 4096 | 2048 | half | sum | -1 | 5.55 | 0.74 | 1.29 | 0 | 5.56 | 0.74 | 1.29 | 0 |
| 8192 | 4096 | half | sum | -1 | 5.8 | 1.41 | 2.47 | 0 | 5.84 | 1.4 | 2.46 | 0 |
| 16384 | 8192 | half | sum | -1 | 6.57 | 2.49 | 4.36 | 0 | 6.56 | 2.5 | 4.37 | 0 |
| 32768 | 16384 | half | sum | -1 | 8.02 | 4.09 | 7.15 | 0 | 8.06 | 4.07 | 7.11 | 0 |
| 65536 | 32768 | half | sum | -1 | 8.77 | 7.47 | 13.07 | 0 | 8.82 | 7.43 | 13 | 0 |
| 131072 | 65536 | half | sum | -1 | 9.61 | 13.64 | 23.87 | 0 | 9.78 | 13.4 | 23.45 | 0 |
| 262144 | 131072 | half | sum | -1 | 11.68 | 22.44 | 39.27 | 0 | 12.1 | 21.67 | 37.93 | 0 |
| 524288 | 262144 | half | sum | -1 | 13.77 | 38.08 | 66.64 | 0 | 13.87 | 37.79 | 66.13 | 0 |
| 1048576 | 524288 | half | sum | -1 | 19.11 | 54.87 | 96.03 | 0 | 19.27 | 54.42 | 95.24 | 0 |
| 2097152 | 1048576 | half | sum | -1 | 24.1 | 87 | 152.26 | 0 | 24.24 | 86.52 | 151.41 | 0 |
| 4194304 | 2097152 | half | sum | -1 | 37.16 | 112.87 | 197.52 | 0 | 37.44 | 112.03 | 196.06 | 0 |
| 8388608 | 4194304 | half | sum | -1 | 61.53 | 136.33 | 238.58 | 0 | 61.68 | 135.99 | 237.99 | 0 |
| 16777216 | 8388608 | half | sum | -1 | 108.8 | 154.22 | 269.88 | 0 | 109.2 | 153.6 | 268.79 | 0 |
| 33554432 | 16777216 | half | sum | -1 | 197.8 | 169.68 | 296.94 | 0 | 198.6 | 168.92 | 295.61 | 0 |
| 67108864 | 33554432 | half | sum | -1 | 384.6 | 174.51 | 305.39 | 0 | 385.1 | 174.27 | 304.98 | 0 |
| 134217728 | 67108864 | half | sum | -1 | 754.1 | 177.99 | 311.48 | 0 | 754.9 | 177.78 | 311.12 | 0 |
| 268435456 | 134217728 | half | sum | -1 | 1491.8 | 179.94 | 314.89 | 0 | 1493.2 | 179.77 | 314.6 | 0 |
| 536870912 | 268435456 | half | sum | -1 | 2979.6 | 180.18 | 315.31 | 0 | 2983.9 | 179.92 | 314.87 | 0 |


| size (B) | count (elements) | type | redop | root | oop time (us) | oop algbw (GB/s) | oop busbw (GB/s) | oop #wrong | ip time (us) | ip algbw (GB/s) | ip busbw (GB/s) | ip #wrong |
| -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- |
| 1024 | 1024 | fp8_e4m3 | sum | -1 | 5.4 | 0.19 | 0.33 | 0 | 5.45 | 0.19 | 0.33 | 0 |
| 2048 | 2048 | fp8_e4m3 | sum | -1 | 5.5 | 0.37 | 0.65 | 0 | 5.6 | 0.37 | 0.64 | 0 |
| 4096 | 4096 | fp8_e4m3 | sum | -1 | 5.61 | 0.73 | 1.28 | 0 | 5.68 | 0.72 | 1.26 | 0 |
| 8192 | 8192 | fp8_e4m3 | sum | -1 | 5.96 | 1.38 | 2.41 | 0 | 5.98 | 1.37 | 2.4 | 0 |
| 16384 | 16384 | fp8_e4m3 | sum | -1 | 6.49 | 2.52 | 4.42 | 0 | 6.58 | 2.49 | 4.36 | 0 |
| 32768 | 32768 | fp8_e4m3 | sum | -1 | 8.09 | 4.05 | 7.09 | 0 | 8.15 | 4.02 | 7.03 | 0 |
| 65536 | 65536 | fp8_e4m3 | sum | -1 | 8.58 | 7.64 | 13.37 | 0 | 8.7 | 7.53 | 13.18 | 0 |
| 131072 | 131072 | fp8_e4m3 | sum | -1 | 9.44 | 13.88 | 24.29 | 0 | 9.62 | 13.63 | 23.85 | 0 |
| 262144 | 262144 | fp8_e4m3 | sum | -1 | 10.12 | 25.9 | 45.32 | 0 | 10.37 | 25.27 | 44.22 | 0 |
| 524288 | 524288 | fp8_e4m3 | sum | -1 | 13.73 | 38.19 | 66.82 | 0 | 13.89 | 37.74 | 66.04 | 0 |
| 1048576 | 1048576 | fp8_e4m3 | sum | -1 | 18.66 | 56.2 | 98.34 | 0 | 18.92 | 55.41 | 96.97 | 0 |
| 2097152 | 2097152 | fp8_e4m3 | sum | -1 | 24.54 | 85.46 | 149.56 | 0 | 24.63 | 85.16 | 149.03 | 0 |
| 4194304 | 4194304 | fp8_e4m3 | sum | -1 | 37.79 | 110.98 | 194.21 | 0 | 38.05 | 110.22 | 192.88 | 0 |
| 8388608 | 8388608 | fp8_e4m3 | sum | -1 | 62.22 | 134.82 | 235.94 | 0 | 62.63 | 133.94 | 234.4 | 0 |
| 16777216 | 16777216 | fp8_e4m3 | sum | -1 | 109.9 | 152.62 | 267.09 | 0 | 110.4 | 151.9 | 265.83 | 0 |
| 33554432 | 33554432 | fp8_e4m3 | sum | -1 | 201.1 | 166.82 | 291.94 | 0 | 202.3 | 165.84 | 290.22 | 0 |
| 67108864 | 67108864 | fp8_e4m3 | sum | -1 | 390 | 172.06 | 301.11 | 0 | 390.2 | 171.99 | 300.99 | 0 |
| 134217728 | 134217728 | fp8_e4m3 | sum | -1 | 763.9 | 175.7 | 307.47 | 0 | 764.2 | 175.62 | 307.34 | 0 |
| 268435456 | 268435456 | fp8_e4m3 | sum | -1 | 1509.5 | 177.83 | 311.2 | 0 | 1510.1 | 177.76 | 311.08 | 0 |
| 536870912 | 536870912 | fp8_e4m3 | sum | -1 | 3010.2 | 178.35 | 312.11 | 0 | 3014.2 | 178.11 | 311.7 | 0 |



| size (B) | count (elements) | type | redop | root | oop time (us) | oop algbw (GB/s) | oop busbw (GB/s) | oop #wrong | ip time (us) | ip algbw (GB/s) | ip busbw (GB/s) | ip #wrong |
| -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- |
| 1024 | 1024 | fp8_e5m2 | sum | -1 | 5.41 | 0.19 | 0.33 | 0 | 5.44 | 0.19 | 0.33 | 0 |
| 2048 | 2048 | fp8_e5m2 | sum | -1 | 5.5 | 0.37 | 0.65 | 0 | 5.67 | 0.36 | 0.63 | 0 |
| 4096 | 4096 | fp8_e5m2 | sum | -1 | 5.61 | 0.73 | 1.28 | 0 | 5.69 | 0.72 | 1.26 | 0 |
| 8192 | 8192 | fp8_e5m2 | sum | -1 | 5.96 | 1.37 | 2.4 | 0 | 6 | 1.36 | 2.39 | 0 |
| 16384 | 16384 | fp8_e5m2 | sum | -1 | 6.63 | 2.47 | 4.32 | 0 | 6.59 | 2.49 | 4.35 | 0 |
| 32768 | 32768 | fp8_e5m2 | sum | -1 | 8.07 | 4.06 | 7.1 | 0 | 8.16 | 4.02 | 7.03 | 0 |
| 65536 | 65536 | fp8_e5m2 | sum | -1 | 8.62 | 7.61 | 13.31 | 0 | 8.73 | 7.51 | 13.14 | 0 |
| 131072 | 131072 | fp8_e5m2 | sum | -1 | 9.43 | 13.9 | 24.33 | 0 | 9.6 | 13.66 | 23.9 | 0 |
| 262144 | 262144 | fp8_e5m2 | sum | -1 | 10.11 | 25.94 | 45.39 | 0 | 10.38 | 25.26 | 44.21 | 0 |
| 524288 | 524288 | fp8_e5m2 | sum | -1 | 13.73 | 38.19 | 66.84 | 0 | 13.87 | 37.79 | 66.13 | 0 |
| 1048576 | 1048576 | fp8_e5m2 | sum | -1 | 18.65 | 56.22 | 98.39 | 0 | 18.93 | 55.38 | 96.92 | 0 |
| 2097152 | 2097152 | fp8_e5m2 | sum | -1 | 24.54 | 85.47 | 149.57 | 0 | 24.63 | 85.16 | 149.03 | 0 |
| 4194304 | 4194304 | fp8_e5m2 | sum | -1 | 37.84 | 110.83 | 193.96 | 0 | 38.01 | 110.36 | 193.12 | 0 |
| 8388608 | 8388608 | fp8_e5m2 | sum | -1 | 62.32 | 134.61 | 235.58 | 0 | 62.55 | 134.12 | 234.71 | 0 |
| 16777216 | 16777216 | fp8_e5m2 | sum | -1 | 110 | 152.58 | 267.01 | 0 | 110.3 | 152.12 | 266.21 | 0 |
| 33554432 | 33554432 | fp8_e5m2 | sum | -1 | 201.1 | 166.9 | 292.07 | 0 | 201.8 | 166.26 | 290.96 | 0 |
| 67108864 | 67108864 | fp8_e5m2 | sum | -1 | 390 | 172.07 | 301.12 | 0 | 390.5 | 171.87 | 300.78 | 0 |
| 134217728 | 134217728 | fp8_e5m2 | sum | -1 | 763.9 | 175.69 | 307.46 | 0 | 764.5 | 175.56 | 307.23 | 0 |
| 268435456 | 268435456 | fp8_e5m2 | sum | -1 | 1509.4 | 177.84 | 311.22 | 0 | 1509.8 | 177.8 | 311.14 | 0 |
| 536870912 | 536870912 | fp8_e5m2 | sum | -1 | 3013 | 178.18 | 311.82 | 0 | 3018 | 177.89 | 311.31 | 0 |


| size (B) | count (elements) | type | redop | root | oop time (us) | oop algbw (GB/s) | oop busbw (GB/s) | oop #wrong | ip time (us) | ip algbw (GB/s) | ip busbw (GB/s) | ip #wrong |
| -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- |
| 1024 | 1024 | uint8 | sum | -1 | 5.46 | 0.19 | 0.33 | 0 | 5.46 | 0.19 | 0.33 | 0 |
| 2048 | 2048 | uint8 | sum | -1 | 5.54 | 0.37 | 0.65 | 0 | 5.63 | 0.36 | 0.64 | 0 |
| 4096 | 4096 | uint8 | sum | -1 | 5.61 | 0.73 | 1.28 | 0 | 5.63 | 0.73 | 1.27 | 0 |
| 8192 | 8192 | uint8 | sum | -1 | 5.9 | 1.39 | 2.43 | 0 | 5.9 | 1.39 | 2.43 | 0 |
| 16384 | 16384 | uint8 | sum | -1 | 6.6 | 2.48 | 4.35 | 0 | 6.64 | 2.47 | 4.32 | 0 |
| 32768 | 32768 | uint8 | sum | -1 | 8.99 | 3.65 | 6.38 | 0 | 8.99 | 3.64 | 6.38 | 0 |
| 65536 | 65536 | uint8 | sum | -1 | 9.44 | 6.94 | 12.15 | 0 | 9.58 | 6.84 | 11.98 | 0 |
| 131072 | 131072 | uint8 | sum | -1 | 11.72 | 11.18 | 19.57 | 0 | 11.83 | 11.08 | 19.4 | 0 |
| 262144 | 262144 | uint8 | sum | -1 | 12.29 | 21.32 | 37.31 | 0 | 12.45 | 21.05 | 36.84 | 0 |
| 524288 | 524288 | uint8 | sum | -1 | 13.87 | 37.8 | 66.15 | 0 | 13.93 | 37.64 | 65.88 | 0 |
| 1048576 | 1048576 | uint8 | sum | -1 | 19.11 | 54.88 | 96.04 | 0 | 19.3 | 54.33 | 95.08 | 0 |
| 2097152 | 2097152 | uint8 | sum | -1 | 24.38 | 86.01 | 150.51 | 0 | 24.52 | 85.53 | 149.67 | 0 |
| 4194304 | 4194304 | uint8 | sum | -1 | 37.52 | 111.78 | 195.61 | 0 | 37.76 | 111.08 | 194.39 | 0 |
| 8388608 | 8388608 | uint8 | sum | -1 | 62.4 | 134.44 | 235.26 | 0 | 62.56 | 134.1 | 234.67 | 0 |
| 16777216 | 16777216 | uint8 | sum | -1 | 110.2 | 152.22 | 266.39 | 0 | 110.3 | 152.04 | 266.08 | 0 |
| 33554432 | 33554432 | uint8 | sum | -1 | 199.8 | 167.94 | 293.9 | 0 | 197.5 | 169.88 | 297.29 | 0 |
| 67108864 | 67108864 | uint8 | sum | -1 | 386.3 | 173.73 | 304.03 | 0 | 378.4 | 177.37 | 310.39 | 0 |
| 134217728 | 134217728 | uint8 | sum | -1 | 758 | 177.07 | 309.87 | 0 | 741.1 | 181.12 | 316.95 | 0 |
| 268435456 | 268435456 | uint8 | sum | -1 | 1500.1 | 178.95 | 313.16 | 0 | 1466.2 | 183.09 | 320.4 | 0 |
| 536870912 | 536870912 | uint8 | sum | -1 | 2991.7 | 179.45 | 314.04 | 0 | 2924.8 | 183.56 | 321.23 | 0 |

---------

Co-authored-by: Qinghua Zhou <qinghuahzhou@microsoft.com>
2026-02-13 10:49:25 -08:00
Binyang Li
bd68319e3e Refactor algo selection logic and introduce symmetric_memory env (#741)
This PR refactors the algorithm selection logic in MSCCL++ and
introduces support for symmetric memory configuration through
environment variables.


1. Algorithm selection refactoring
Use a separate class for algorithm selection. This makes it possible to
introduce more complex selection logic based on message size,
architecture, whether a CUDA graph is enabled, and the memory allocation
method.

2. Symmetric memory support
Introduced a symmetricMemory parameter in algorithm context key
generation. Removed the disableChannelCache env as it was ambiguous.

3. New args for build_default_algorithms
Add flag_buffer and flag_buffer_size args to build the default
algorithms. Different algorithms can then share a unified flag buffer,
avoiding application hangs when switching algorithms across message
sizes.
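A selector along these lines might look like the following sketch (thresholds and algorithm names are made up for illustration; the real logic lives in the C++ selector class this PR introduces):

```python
def select_algo(msg_bytes: int, cuda_graph: bool, symmetric_memory: bool) -> str:
    """Pick an allreduce variant from message size and runtime context."""
    if msg_bytes <= 1 << 20:
        return "allreduce_packet"       # latency-bound: packet protocol
    if symmetric_memory and not cuda_graph:
        return "allreduce_nvls"
    return "allreduce_fullmesh"         # bandwidth-bound default

assert select_algo(4096, False, False) == "allreduce_packet"
assert select_algo(1 << 26, True, False) == "allreduce_fullmesh"
```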

---------

Co-authored-by: chhwang <8018170+chhwang@users.noreply.github.com>
Co-authored-by: Qinghua Zhou <qinghuazhou@microsoft.com>
Co-authored-by: Caio Rocha <caiorocha@microsoft.com>
2026-02-12 19:06:18 -08:00
Changho Hwang
42be3660e0 Add a new IB stack impl that doesn't use RDMA atomics (#728)
* Added a configurable InfiniBand (IB) signaling mode. The
`EndpointConfig::Ib::Mode` enum selects the mode (`Default`, `Host`,
`HostNoAtomic`). `Default` is equivalent to `Host` unless overridden by
the `MSCCLPP_IBV_MODE` environment variable. `Host` corresponds to the
previous implementation using RDMA atomics for signaling, while
`HostNoAtomic` uses write-with-immediate instead.
* Corresponding updates in the Python bindings and API.
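The mode resolution can be sketched in Python (an illustrative model; the accepted env values are assumed to match the enum names, which the real C++ implementation may spell differently):

```python
import os
from enum import Enum

class IbMode(Enum):
    DEFAULT = "Default"
    HOST = "Host"
    HOST_NO_ATOMIC = "HostNoAtomic"

def resolve_ib_mode(cfg: IbMode) -> IbMode:
    """Default resolves to Host unless MSCCLPP_IBV_MODE overrides it."""
    if cfg is not IbMode.DEFAULT:
        return cfg
    if os.environ.get("MSCCLPP_IBV_MODE") == "HostNoAtomic":
        return IbMode.HOST_NO_ATOMIC
    return IbMode.HOST

os.environ.pop("MSCCLPP_IBV_MODE", None)
assert resolve_ib_mode(IbMode.DEFAULT) is IbMode.HOST
os.environ["MSCCLPP_IBV_MODE"] = "HostNoAtomic"
assert resolve_ib_mode(IbMode.DEFAULT) is IbMode.HOST_NO_ATOMIC
```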
2026-02-10 01:07:53 +00:00
Qinghua Zhou
620378b4fb Fix cpplint error in main branch (#740)
Fix the legacy cpplint error in main branch.

---------

Co-authored-by: Qinghua Zhou <qinghuahzhou@microsoft.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Binyang Li <binyli@microsoft.com>
2026-02-05 09:25:12 -08:00
Binyang Li
dc747b1522 Refactor reduce kernel (#738)
- Move the common reduce kernel into `reduce_kernel.hpp`
- Implement operator overloading for the vector types
- Clean up the duplicated code in `executor_kernel.hpp` and
`allreduce/common.hpp`
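The operator-overloading idea can be illustrated in Python (a toy model, not the CUDA vector types; `Vec4` and `reduce_sum` are made-up names):

```python
from dataclasses import dataclass

@dataclass
class Vec4:
    """Toy stand-in for the vector types the reduce kernel operates on."""
    x: float
    y: float
    z: float
    w: float

    def __add__(self, o):
        return Vec4(self.x + o.x, self.y + o.y, self.z + o.z, self.w + o.w)

def reduce_sum(chunks):
    """One generic reduction path: operator+ does the per-type work."""
    acc = chunks[0]
    for c in chunks[1:]:
        acc = acc + c
    return acc

r = reduce_sum([Vec4(1, 2, 3, 4), Vec4(10, 20, 30, 40)])
assert r == Vec4(11, 22, 33, 44)
```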
2026-02-05 09:23:43 -08:00
Binyang Li
e21513791a Address comments for PR #692 (#733)
- Rename nanobind-exposed C++ types to Cpp*
- Replace MSCCLPP_EXECUTION_PLAN_DIR / MSCCLPP_NATIVE_CACHE_DIR with
MSCCLPP_CACHE_DIR across C++ and Python.
2026-02-03 10:13:20 -08:00
mahdiehghazim
071dc92d38 fp8 nvls support (e5m2 and e4m3) (#730)
This PR adds FP8 support to the NVLS code. For compilation, we need to
add this flag to the cmake command:

-DMSCCLPP_GPU_ARCHS=100a
2026-01-23 10:38:38 -05:00
Binyang Li
a707273701 Torch integration (#692)
Reorganizes the current native algorithm implementation and the DSL
algorithm implementation.
Provides a unified API for DSL and native algorithms, plus an interface
to tune them.
Provides an interface for PyTorch integration with the native API and DSL.

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
Co-authored-by: chhwang <8018170+chhwang@users.noreply.github.com>
2026-01-21 20:32:24 -08:00
Changho Hwang
105239fc6c Use GpuIpcMem for NVLS connections (#719)
* Now `NvlsConnection` internally reuses `GpuIpcMem` for multicast
memory handling.
* Removed unnecessary barriers from `connectNvlsCollective()` (CUDA API
handles this automatically).
* Updated `GpuIpcMem::map()` and `GpuIpcMem::mapMulticast()` to return a
shared pointer with custom deleter for unmapping, which prevents misuse
of raw pointers and reduces states to be stored in the `GpuIpcMem`
instance.
* Now for `RuntimeIpc` type handles, for consistency with other types,
`cudaIpcOpenMemHandle` will be called in `GpuIpcMem::map()` instead of
the ctor of `GpuIpcMem`.

---------

Co-authored-by: Binyang Li <binyli@microsoft.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
Co-authored-by: Binyang2014 <9415966+Binyang2014@users.noreply.github.com>
2026-01-15 13:16:04 +08:00
Changho Hwang
a02ba3b1bd Add GpuIpcMemHandle (#704)
Add `GpuIpcMemHandle` that is a generic GPU memory handle that covers
all existing methods for GPU memory mapping. This PR fixes issues that
fail to properly fallback to a feasible type of memory handle on the
importing environment. It also consolidates code for creating or
destroying various memory handles into a single RAII wrapper.
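The fallback idea can be sketched generically (hypothetical handle names and a mocked failure; the real handle types and ordering live in the C++ `GpuIpcMemHandle` code):

```python
def create_handle(exporters):
    """Try exporter callables in preference order; first success wins.

    Mirrors the fallback: if the preferred handle type cannot be created
    or imported on this environment, fall back to the next one.
    """
    errors = []
    for name, fn in exporters:
        try:
            return name, fn()
        except OSError as e:
            errors.append((name, e))
    raise RuntimeError(f"no feasible handle type: {errors}")

def fabric():                           # mocked: unsupported here
    raise OSError("not supported on this machine")

def posix_fd():                         # stand-in for an exported fd
    return 42

kind, handle = create_handle([("fabric", fabric), ("posix_fd", posix_fd)])
assert (kind, handle) == ("posix_fd", 42)
```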

---------

Co-authored-by: Binyang Li <binyli@microsoft.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
Co-authored-by: Binyang2014 <9415966+Binyang2014@users.noreply.github.com>
2026-01-14 10:49:31 +08:00
Changho Hwang
fc221e234d Remove UB std:: declarations (#709)
Remove custom declarations inside `std::`, whose behavior is undefined
by the standard.
2026-01-05 11:11:46 +08:00
Changho Hwang
bb555277ad Rename P2P log subsys into GPU (#716) 2026-01-05 11:08:43 +08:00
Binyang Li
ca6a4a3274 Replace __HIP_PLATFORM_AMD__ to use internal macro (#712)
Replace most checks for `__HIP_PLATFORM_AMD__` with
`MSCCLPP_DEVICE_HIP` for device code and `MSCCLPP_USE_ROCM` for host
source files.
2026-01-04 04:47:58 -08:00
Changho Hwang
9e076da3d4 Make IB more configurable (#703)
* Added `port` and `gidIndex` field in the IB endpoint config (and
`deviceIndex` field for future usages)
* Added `MSCCLPP_IBV_SO` env variable to specify a custom libibverbs.so
* Added `--ib_gid_index` CLI option to `mp_unit_tests`
* Other minor fixes
2025-12-18 13:21:07 -08:00
Changho Hwang
ddf84a6b9d Add CudaDeviceGuard (#691)
Add an RAII guard that sets the proper GPU device before a CUDA API call.
We may make this stateful in the future to minimize `cudaGetDevice()`
calls. This PR also fixes a bug in tutorial 01.
2025-11-24 13:38:44 -08:00
Qinghua Zhou
b9428341a2 Revise the mscclpp datatype (#671)
Use mscclpp::DataType to replace the following types in the API interface:
- int dtype;
- ncclDataType_t dtype;

Add data type conversion: convert ncclDataType_t to mscclpp::DataType.
2025-11-17 12:58:47 -08:00
Changho Hwang
1bf4e8c90e connect() APIs changed to return an instance instead of a shared_ptr (#680)
The key purpose is to handle all mscclpp objects' memory internally by
hiding shared pointers from user APIs.
* `Connection` class is now a wrapper of `BaseConnection` class that is
equivalent to the previous `Connection` class
* `connect()` methods now return `Connection` instead of
`std::shared_ptr<Connection>`
* Removed `connectOnSetup()` method
2025-11-15 11:40:40 -08:00
Caio Rocha
7eb3ff701a Supporting New Packet Kernel Operation at Executor (#677)
This PR introduces three new operations to enhance flexibility and
performance at the executor.

One operation can be invoked directly via the DSL API and two operations
are created through fusion of existing operations, reducing overhead and
improving efficiency.

1. Port Channel Put Packet (Direct DSL API Call): Sends data from pkt
format to the remote side in pkt format via the port channel. Both
source and destination buffers must be scratch.

2. Reduce Copy Packet (Fusion):
Reduce Packet+Copy Packet=Reduce Copy Packet
Triggered when the destination buffer of Reduce Packet matches the
source buffer of Copy Packet.
Purpose: Combine reduction and copy into a single step for better
performance.

3. Reduce Copy Send Packet (Fusion):
Reduce Copy Packet+Put Packet=Reduce Copy Send Packet (when dst buffer
of Reduce Copy Packet matches src buffer of Put Packet)
Reduce Copy Packet+Read Put Packet=Reduce Copy Send Packet (when dst pkt
buffer of Reduce Copy Packet matches src buffer of Read Put Packet)
Purpose: Combine reduction, copy, and send operations into one optimized
pipeline.


Fusion Diagram
Reduce Packet + Copy Packet → Reduce Copy Packet
Reduce Copy Packet + Put Packet → Reduce Copy Send Packet
Reduce Copy Packet + Read Put Packet → Reduce Copy Send Packet
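The fusion rules above can be sketched as a single peephole pass (op representation and buffer names are illustrative, not the executor's actual data structures):

```python
def fuse(ops):
    """Fuse adjacent packet ops when the producer's dst feeds the consumer.

    Ops are (name, src, dst) triples.
    """
    rules = {
        ("ReducePacket", "CopyPacket"): "ReduceCopyPacket",
        ("ReduceCopyPacket", "PutPacket"): "ReduceCopySendPacket",
        ("ReduceCopyPacket", "ReadPutPacket"): "ReduceCopySendPacket",
    }
    out = []
    for op in ops:
        if out:
            (n1, s1, d1), (n2, s2, d2) = out[-1], op
            fused = rules.get((n1, n2))
            if fused and d1 == s2:      # dst of producer == src of consumer
                out[-1] = (fused, s1, d2)
                continue
        out.append(op)
    return out

ops = [("ReducePacket", "scratch", "tmp"),
       ("CopyPacket", "tmp", "dst"),
       ("PutPacket", "dst", "remote")]
assert fuse(ops) == [("ReduceCopySendPacket", "scratch", "remote")]
```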

Beyond this, this PR adjusts the two-node AllReduce algorithm:

| Message Size | Latency (µs) |
| -- | -- |
| 1K | 15.34 |
| 2K | 15.88 |
| 4K | 15.71 |
| 8K | 16.01 |
| 16K | 15.88 |
| 32K | 16.21 |
| 64K | 16.90 |
| 128K | 18.24 |
| 256K | 20.39 |
| 512K | 25.26 |
| 1M | 32.74 |
| 2M | 53.64 |
2025-11-13 14:08:44 -08:00
Caio Rocha
eb202780f5 Support Synchronous Initialization for Proxy Service (#679) 2025-11-12 18:35:57 -08:00
Changho Hwang
ffafcaf6d6 IB stack enhancements & bug fixes (#673)
* Always use `ibv_reg_dmabuf_mr` when DMABUF is supported
* Do not check `nvidia-peermem` when unnecessary
* More rigorous check on IB port availability
* Fixed ibverbs wrappers
* Fixed `IbPeerToPeerTest.SimpleAtomicAdd` test
2025-11-07 14:26:17 -08:00
Changho Hwang
960a8ddebd Add a new logger (#668)
* Add `logger.hpp` that will gradually replace `debug.h`
* Minor fixes in `ib.cc`

---------

Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
Co-authored-by: chhwang <8018170+chhwang@users.noreply.github.com>
2025-11-04 10:32:46 -08:00
Binyang Li
5acac93dbc Integrate MSCCL++ DSL to torch workload (#620)
Provides two ways to integrate the MSCCL++ DSL:
1. Integrate with a customized communication group
2. Integrate with the NCCL API

Introduces new Python APIs to make this work:
```python
mscclpp.compile # compile dsl to json based execution plan
mscclpp.ExecutionPlanRegistry.register_plan(plan) # register the compiled plan to the ExecutionPlanRegistry
mscclpp.ExecutionPlanRegistry.set_selector(selector) # the selector returns the best execution plan based on collective, message size, world size, ...
```
Fix #556
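The registry-plus-selector pattern can be sketched generically (hypothetical names and plan shape; this is not the actual `mscclpp.ExecutionPlanRegistry` API):

```python
class PlanRegistry:
    """Holds registered plans and a user-supplied selector function."""
    def __init__(self):
        self.plans = []
        self.selector = None

    def register_plan(self, plan):
        self.plans.append(plan)

    def set_selector(self, fn):
        self.selector = fn

    def lookup(self, coll, msg_size, world_size):
        return self.selector(self.plans, coll, msg_size, world_size)

reg = PlanRegistry()
reg.register_plan({"coll": "allreduce", "max_bytes": 1 << 20, "name": "small"})
reg.register_plan({"coll": "allreduce", "max_bytes": 1 << 30, "name": "large"})
reg.set_selector(lambda plans, coll, size, ws: next(
    p for p in plans if p["coll"] == coll and size <= p["max_bytes"]))
assert reg.lookup("allreduce", 4096, 8)["name"] == "small"
```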

---------

Co-authored-by: Caio Rocha <caiorocha@microsoft.com>
Co-authored-by: Changho Hwang <changhohwang@microsoft.com>
2025-10-29 15:39:00 -07:00
Qinghua Zhou
a38c2ee784 FP8 support for Allreduce (#646)
Add FP8 support for Allreduce on both NVIDIA and AMD platform.
Add new data type: fp8_e4m3 and fp8_e5m2

---------

Co-authored-by: Binyang Li <binyli@microsoft.com>
2025-10-27 14:51:48 -07:00
Binyang Li
2b05908635 Add token pool for cuCreate API (#628)
Create a token pool to allocate tokens. This feature supports
inter-node NVLink and reduces the footprint caused by the cuCreate API.
---------

Co-authored-by: Changho Hwang <changhohwang@microsoft.com>
2025-10-27 11:19:21 -07:00
Changho Hwang
68b1f151f0 Rename nvls* files (#660)
Rename nvls* files to switch_channel*
2025-10-24 11:34:26 -07:00
Changho Hwang
200cdf946e Update EndpointConfig interfaces (#651)
* Separate IB-specific options into a nested struct
* Enable `connect()` by an `Endpoint`, not only by `EndpointConfig`
* Other minor changes
2025-10-22 10:39:39 -07:00
Binyang Li
3d94383696 Add MSCCLPP_GIT_COMMIT macro (#640)
- Add MSCCLPP_GIT_COMMIT macro
- Update docs
2025-10-06 15:57:28 -07:00
Binyang Li
70b8297c56 Revise NCCL API implementation (#617)
- Make the NCCL interface extensible: customers can register their own
algorithms with the NCCL API and provide a customized algorithm
selection function.
- Fall back to NCCL/RCCL if no algorithm is selected by the selection
function
- MSCCLPP interfaces now work at any scale
2025-09-26 10:08:12 -07:00
Binyang Li
5ac427610d Address teardown issue (#638)
Ignore cuda/cu errors during teardown; some pointers may be invalid at that point.
2025-09-25 12:12:40 -07:00
Changho Hwang
43f160c8e6 Fix for safe process teardown (#633)
* `gpuFree*()` functions are usually called during process teardown, so
we let them ignore such errors.
* `AvoidCudaGraphCaptureGuard` is constructed in `gpuFree*()` functions,
so it needs the same fix.
2025-09-10 20:28:04 -07:00
Changho Hwang
571fee16fb Add FifoDeviceHandle::poll() (#630) 2025-09-09 23:32:35 +00:00
Changho Hwang
c4d8781390 Fix memory exchange within a single process (#624) 2025-09-04 12:53:51 -07:00
Binyang Li
bb76d27553 all2all implementation (#609)
Implement single-node all-to-all via the MSCCL++ C++ API.
Performance of kernel 3:
```
       size         count     time   algbw   busbw  #wrong     time   algbw   busbw  #wrong
#        (B)    (elements)     (us)  (GB/s)  (GB/s)             (us)  (GB/s)  (GB/s)
     1048576         32768                                     23.41   44.78   39.19      0
     2097152         65536                                     23.95   87.56   76.61      0
     4194304        131072                                     27.50  152.51  133.45      0
     8388608        262144                                     35.14  238.73  208.89      0
    16777216        524288                                     57.54  291.55  255.11      0
    33554432       1048576                                     109.7  305.81  267.59      0
    67108864       2097152                                     212.3  316.07  276.56      0
   134217728       4194304                                     410.9  326.64  285.81      0
   268435456       8388608                                     784.9  341.99  299.24      0
```

kernel 2
```

#                                        in-place                       out-of-place
#       size         count     time   algbw   busbw  #wrong     time   algbw   busbw  #wrong
#        (B)    (elements)     (us)  (GB/s)  (GB/s)             (us)  (GB/s)  (GB/s)
     1048576         32768                                     23.42   44.77   39.17      0
     2097152         65536                                     24.96   84.02   73.52      0
     4194304        131072                                     28.53  147.03  128.65      0
     8388608        262144                                     36.75  228.28  199.75      0
    16777216        524288                                     58.01  289.20  253.05      0
    33554432       1048576                                     110.4  303.83  265.85      0
    67108864       2097152                                     212.4  315.99  276.49      0
   134217728       4194304                                     407.8  329.12  287.98      0
   268435456       8388608                                     797.4  336.64  294.56      0
```

NCCL:
```
NCCL version 2.21.5+cuda12.4
#
#                                                              out-of-place                       in-place          
#       size         count      type   redop    root     time   algbw   busbw #wrong     time   algbw   busbw #wrong
#        (B)    (elements)                               (us)  (GB/s)  (GB/s)            (us)  (GB/s)  (GB/s)       
     8388608        524288      half    none      -1    38.70  216.75  189.66      0    39.25  213.72  187.00    N/A
    16777216       1048576      half    none      -1    71.39  234.99  205.62      0    68.41  245.25  214.60    N/A
    33554432       2097152      half    none      -1    119.7  280.22  245.20      0    119.8  280.17  245.15    N/A
    67108864       4194304      half    none      -1    211.9  316.66  277.08      0    212.7  315.53  276.09    N/A
   134217728       8388608      half    none      -1    408.4  328.61  287.53      0    393.8  340.87  298.26    N/A
   268435456      16777216      half    none      -1    761.6  352.47  308.41      0    763.3  351.70  307.73    N/A
   536870912      33554432      half    none      -1   1502.5  357.31  312.64      0   1467.3  365.89  320.16    N/A
```
2025-08-14 11:30:40 -07:00
Binyang Li
be6a941fba New DSL implementation (#579)
The PR contains the following changes:
Python side:
- Channel-based DSL implementation: decouples channels from chunks.
- Users create channels explicitly, needing only local_rank,
remote_rank, and channel_type
- Adjust the executor JSON file, adding remote_buffer fields; different
ops can use different channel and remote buffer combinations.
- Reimplement the operation fusion and data dependency check mechanisms
- Add new ops such as semaphore and pipeline
- Clean up code and enhance documentation
C++ side:
- Support the new execution-file JSON format
- Support semaphore and pipeline operations
- Code cleanup; support the non-zero-copy scenario

---------

Co-authored-by: Caio Rocha <caiorocha@microsoft.com>
Co-authored-by: Changho Hwang <changhohwang@microsoft.com>
2025-08-09 00:36:20 -07:00
Changho Hwang
542a10f69e Merge ChannelTrigger with ProxyTrigger (#601) 2025-08-08 19:07:50 +00:00
Changho Hwang
9650e5c37e Update documentation (#576)
Documentation overhaul
2025-08-07 15:37:37 -07:00
Binyang Li
4f6f23dae3 Use smart pointer for IB structure (#585)
Change the IB structures to use smart pointers. Registered memory now
owns the ibMr; ibCtx no longer holds the reference.
- Use smart pointers for IbQp and IbMr
- Update the memoryChannel API, keeping localRegisteredMemory
- Close the fd when registeredMemory is released

---------

Co-authored-by: Changho Hwang <changhohwang@microsoft.com>
2025-08-06 10:01:58 -07:00
Changho Hwang
d55ac96f5e Fixed the local channel test (#597) 2025-08-05 15:33:48 -07:00
Changho Hwang
334b232e36 Fix GpuStreamPool to be aware of the device ID of streams (#590)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-08-04 11:07:31 -07:00
Changho Hwang
c580e4c503 Support CudaIpc connection within a single process (#593)
* Allow CudaIpc connection between GPUs in a single process
* Added an example of connection in a single process
* Minor interface updates

---------

Co-authored-by: Binyang Li <binyli@microsoft.com>
2025-08-02 12:59:36 +08:00
Changho Hwang
aa28b06bf5 Fix relaxedWait() (#594) 2025-08-01 12:51:30 +08:00