mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-19 22:38:56 +00:00
[examples][CuTeDSL] init commit for distirbuted examples (#2806)
* init commit for distirbuted examples * better OOB protection * and try import to nvshmem for better error message and a READMME.md to introduce nvshmem and multimem instructions * add some lamport explanation * enhance f8 output and warn that f8 output can have nan in it * tell user why we need complicate data conversions in ref check part * tell user we don't support nvshmem device function --------- Co-authored-by: bangyus <bangyus@nvidia.com>
This commit is contained in:
139
examples/python/CuTeDSL/distributed/README.md
Normal file
139
examples/python/CuTeDSL/distributed/README.md
Normal file
@@ -0,0 +1,139 @@
|
||||
# CuTeDSL Distributed Examples
|
||||
|
||||
This directory contains distributed examples using CuTeDSL with NVSHMEM for multi-GPU communication. Currently, we do not support to use NVSHMEM for any device side copy/put/get impl, only use the host side setup and allocations.
|
||||
|
||||
## NVSHMEM Dependency
|
||||
|
||||
These examples require two components:
|
||||
|
||||
1. **NVSHMEM4Py** (`nvshmem4py-cu12` / `nvshmem4py-cu13`): A Python package that provides the official Python binding for NVIDIA's NVSHMEM. See the [NVSHMEM4Py Documentation](https://docs.nvidia.com/nvshmem/api/api/language_bindings/python/index.html).
|
||||
|
||||
2. **NVSHMEM Library** (`nvidia-nvshmem-cu12` / `nvidia-nvshmem-cu13`): The underlying native library that contains the actual NVSHMEM implementation.
|
||||
|
||||
### Overview
|
||||
|
||||
**NVSHMEM4Py** (`nvshmem4py-cu12` / `nvshmem4py-cu13`) is a Python binding library that provides a Pythonic interface to NVSHMEM functionality. In these examples, we use it primarily for:
|
||||
|
||||
- Allocating tensors that support peer-to-peer (P2P) communication across GPUs
|
||||
- Allocating multicast (MC) tensors that can leverage `multimem` instructions for efficient collective operations
|
||||
|
||||
**nvidia-nvshmem** (`nvidia-nvshmem-cu12` / `nvidia-nvshmem-cu13`) is the underlying library that wraps NVSHMEM functions into dynamic libraries (`.so` files). NVSHMEM4Py dynamically loads and calls these libraries at runtime.
|
||||
|
||||
### Installation
|
||||
|
||||
For CUDA 12:
|
||||
```bash
|
||||
pip install nvshmem4py-cu12 nvidia-nvshmem-cu12
|
||||
```
|
||||
|
||||
For CUDA 13:
|
||||
```bash
|
||||
pip install nvshmem4py-cu13 nvidia-nvshmem-cu13
|
||||
```
|
||||
|
||||
> **Note:** `nvshmem4py` version >= 0.1.3 is recommended.
|
||||
|
||||
### Key APIs Used
|
||||
|
||||
We primarily use the following APIs from `nvshmem.core`:
|
||||
|
||||
| API | Description |
|
||||
|-----|-------------|
|
||||
| `nvshmem.core.tensor(shape, dtype)` | Allocates a symmetric tensor that supports P2P communication |
|
||||
| `nvshmem.core.get_peer_tensor(tensor, pe)` | Returns a tensor handle for accessing the given tensor on a remote PE (processing element) |
|
||||
| `nvshmem.core.get_multicast_tensor(tensor)` | Returns a tensor that can be accessed using `multimem` instructions for efficient multicast operations |
|
||||
| `nvshmem.core.free_tensor(tensor)` | Explicitly frees the allocated symmetric memory |
|
||||
|
||||
### Memory Management
|
||||
|
||||
NVSHMEM requires **manual memory management**. Unlike PyTorch tensors that are garbage-collected automatically, NVSHMEM symmetric memory must be explicitly freed using `nvshmem.core.free_tensor()` to avoid memory leaks.
|
||||
|
||||
Example:
|
||||
```python
|
||||
import nvshmem.core
|
||||
|
||||
# init the environment
|
||||
# refer to the torchrun_uid_init_bcast() in example
|
||||
|
||||
# Allocate symmetric tensor
|
||||
local_tensor = nvshmem.core.tensor((M, N), dtype=torch.float32)
|
||||
|
||||
# Get peer tensors for P2P access
|
||||
tensor_list = [nvshmem.core.get_peer_tensor(local_tensor, rank) for rank in range(world_size)]
|
||||
|
||||
# ... use tensors ...
|
||||
|
||||
# Explicitly free memory when done
|
||||
for t in tensor_list:
|
||||
nvshmem.core.free_tensor(t)
|
||||
|
||||
# finalize the environment
|
||||
# refer to the torchrun_finalize() in example
|
||||
|
||||
```
|
||||
|
||||
## Multimem Instructions
|
||||
|
||||
These examples demonstrate the use of NVIDIA's `multimem` PTX instructions for efficient multi-GPU collective operations. The `multimem` instructions operate on multicast (MC) addresses obtained via `nvshmem.core.get_multicast_tensor()`, enabling hardware-accelerated communication across multiple GPUs.
|
||||
|
||||
### Why Multimem is Fast: NVLS (NVLink SHARP)
|
||||
|
||||
The `multimem` instructions leverage **NVLS (NVLink SHARP)** technology to perform **in-network computation**. When multiple GPUs map the same symmetric memory region, `multimem` instructions can operate on a multicast address to perform hardware-accelerated reduction or broadcast operations directly in the NVLink/NVSwitch fabric, without requiring data to traverse to GPU memory first.
|
||||
|
||||
**Key benefits:**
|
||||
- **In-network computation**: Reduction and broadcast operations happen in the NVSwitch hardware, not in GPU compute units
|
||||
- **Reduced memory traffic**: Data is processed in-flight within the interconnect, minimizing HBM bandwidth consumption
|
||||
- **Lower latency**: Single instruction replaces multiple loads/stores and arithmetic operations
|
||||
|
||||
### Instruction Categories
|
||||
|
||||
We use three types of `multimem` instructions in these examples:
|
||||
|
||||
#### 1. `multimem.ld_reduce` - Reduction
|
||||
|
||||
Reads data from a multicast address and returns the **reduced result** (e.g., sum) across all GPUs:
|
||||
|
||||
```
|
||||
multimem.ld_reduce.sys.relaxed.global.add.v4.f32 {$0, $1, $2, $3}, [$4];
|
||||
```
|
||||
|
||||
This instruction reads from a multicast address and performs a sum reduction (`.add`) across all GPUs that have mapped this address via NVLS.
|
||||
|
||||
**Accumulator Precision**: For lower-precision data types, you can specify a higher accumulator precision to improve numerical accuracy:
|
||||
- **FP16 / BF16**: Can use FP32 accumulator (`.acc::f32`)
|
||||
- **FP8 (E4M3 / E5M2)**: Can use FP16 accumulator (`.acc::f16`)
|
||||
|
||||
Example with FP16 using FP32 accumulator:
|
||||
```
|
||||
multimem.ld_reduce.sys.relaxed.global.add.acc::f32.v4.f16x2 {$0, $1, $2, $3}, [$4];
|
||||
```
|
||||
|
||||
#### 2. `multimem.st` - Broadcast via Store
|
||||
|
||||
Stores data to a multicast address, which **broadcasts** the data to all participating GPUs:
|
||||
|
||||
```
|
||||
multimem.st.sys.relaxed.global.v4.f32 [$1], {$2, $3, $4, $5};
|
||||
```
|
||||
|
||||
This writes data to a multicast address, and the data becomes visible to all GPUs that have mapped this address via NVLS.
|
||||
|
||||
#### 3. `multimem.red` - Broadcast via Atomic Reduction
|
||||
|
||||
Performs an atomic reduction operation on a multicast address. This is commonly used for **signaling/synchronization** across GPUs:
|
||||
|
||||
```
|
||||
multimem.red.release.sys.global.add.u32 [$0], 1;
|
||||
```
|
||||
|
||||
This atomically adds a value to a multicast address. When used with synchronization patterns (e.g., spin locks), it enables efficient inter-GPU barriers where all GPUs can observe the updated value.
|
||||
|
||||
## Future Work
|
||||
|
||||
The `nvidia-nvshmem-cu12/cu13` packages include LLVM IR bitcode libraries that could potentially be integrated into CuTeDSL in the future. This would enable calling NVSHMEM functions directly from within CuTeDSL kernels, allowing for more fine-grained control over communication patterns at the kernel level.
|
||||
|
||||
## References
|
||||
|
||||
- [NVSHMEM4Py Documentation](https://docs.nvidia.com/nvshmem/api/api/language_bindings/python/index.html)
|
||||
- [NVSHMEM API Reference](https://docs.nvidia.com/nvshmem/api/api/language_bindings/python/index.html)
|
||||
- [multimem PTX instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-multimem)
|
||||
@@ -0,0 +1,426 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import os
|
||||
import torch
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
import torch.distributed as dist
|
||||
import torch.distributed._symmetric_memory as symm_mem
|
||||
import cuda.bindings.driver as cuda
|
||||
from cuda.core.experimental import Device
|
||||
from cuda.pathfinder import load_nvidia_dynamic_lib
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.cute.testing as testing
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
from cutlass.cutlass_dsl import T
|
||||
from cutlass._mlir.dialects import vector
|
||||
|
||||
try:
|
||||
import nvshmem.core
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"nvshmem4py is required but not installed. Please install it using:\n"
|
||||
" For CUDA 12: pip install nvshmem4py-cu12\n"
|
||||
" For CUDA 13: pip install nvshmem4py-cu13\n"
|
||||
"Note: nvshmem4py version >= 0.1.3 is recommended."
|
||||
) from None
|
||||
|
||||
try:
|
||||
load_nvidia_dynamic_lib("nvshmem_host")
|
||||
except RuntimeError as exc:
|
||||
raise ImportError(
|
||||
"nvshmem lib is required but not installed. Please install it using:\n"
|
||||
" For CUDA 12: pip install nvidia-nvshmem-cu12\n"
|
||||
" For CUDA 13: pip install nvidia-nvshmem-cu13\n"
|
||||
) from None
|
||||
|
||||
"""
|
||||
A Distributed One-Shot All-Reduce Example using CuTe DSL and fine-grained memory control. This is a mirrored version of the
|
||||
existing tensorrt_llm kernel:
|
||||
https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cu
|
||||
|
||||
In Lamport terminology this is a classic flag-based busy-wait: every participant keeps polling the shared slot until the
|
||||
flag changes from the sentinel (negative zero) to real data, which indicates that the Lamport-style logical ordering has
|
||||
advanced and the payload is safe to consume.
|
||||
|
||||
This example kernel demonstrates a one-shot all-reduce operation using the CuTe DSL with fine-grained memory control.
|
||||
It uses dedicated communication buffers for data exchange, and these buffers act as ping-pong buffers. During the
|
||||
process, the kernel uses one buffer for communication and initializes the next buffer to all negative zeros.
|
||||
|
||||
In this kernel, each thread is only responsible for 128bits of data. The kernel will write it's local data to every
|
||||
buffer at different ranks, then read the data from the local rank buffer. The buffer itself behaves as a barrier,
|
||||
if kernel read negtive 0, then it means data are not ready or not visible yet so that the kernel will read the data again.
|
||||
|
||||
If the input tensors from each device are not remotely accessible, this kernel can be used to perform the one-shot all-reduce
|
||||
since it uses communication buffers for data exchange.
|
||||
|
||||
The .SYS memory scope and .VOLATILE memory order are used to ensure that the data will be visible at the system scope.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
torchrun --nproc-per-node 8 examples/distributed/all_reduce_one_shot_lamport.py --M 8192 --N 8192
|
||||
torchrun --nproc-per-node 8 examples/distributed/all_reduce_one_shot_lamport.py \
|
||||
--M 8192 --N 8192 --benchmark --warmup_iterations 2 --iterations 10
|
||||
"""
|
||||
|
||||
|
||||
PING_PONG_SIZE = 3
|
||||
|
||||
|
||||
class AllReduceOneShotLamportKernel:
|
||||
@cute.jit
|
||||
def __call__(
|
||||
self,
|
||||
rank: cutlass.Constexpr,
|
||||
world_size: cutlass.Constexpr,
|
||||
signal: cutlass.Int32,
|
||||
local_input: cute.Tensor,
|
||||
local_output: cute.Tensor,
|
||||
buffers: list[cute.Tensor],
|
||||
stream: cuda.CUstream,
|
||||
):
|
||||
copy_bits = 128
|
||||
dtype = local_input.element_type
|
||||
vector_size = copy_bits // dtype.width
|
||||
|
||||
thr_layout = cute.make_ordered_layout((4, 32), order=(1, 0))
|
||||
val_layout = cute.make_ordered_layout((1, vector_size), order=(1, 0))
|
||||
tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout)
|
||||
|
||||
grouped_buffers = [cute.group_modes(buffer, 0, 2) for buffer in buffers]
|
||||
tiled_buffers = [
|
||||
cute.zipped_divide(buffer, (tiler_mn, world_size, PING_PONG_SIZE))
|
||||
for buffer in grouped_buffers
|
||||
]
|
||||
tiled_input = cute.zipped_divide(local_input, tiler_mn)
|
||||
tiled_output = cute.zipped_divide(local_output, tiler_mn)
|
||||
|
||||
self.kernel(
|
||||
tiled_buffers,
|
||||
tiled_input,
|
||||
tiled_output,
|
||||
thr_layout,
|
||||
val_layout,
|
||||
signal,
|
||||
rank,
|
||||
).launch(
|
||||
grid=[cute.size(tiled_input, mode=[1]), 1, 1],
|
||||
block=[cute.size(tv_layout, mode=[0]), 1, 1],
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
# GPU device kernel
|
||||
@cute.kernel
|
||||
def kernel(
|
||||
self,
|
||||
buffers: list[cute.Tensor],
|
||||
local_input: cute.Tensor,
|
||||
local_output: cute.Tensor,
|
||||
thr_layout: cute.Layout,
|
||||
val_layout: cute.Layout,
|
||||
signal: cutlass.Int32,
|
||||
rank: cutlass.Constexpr,
|
||||
):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
bidx, _, _ = cute.arch.block_idx()
|
||||
ping = signal % 3
|
||||
pong = (signal + 1) % 3
|
||||
|
||||
buffer_local = buffers[rank]
|
||||
cta_coord = ((None, None), bidx)
|
||||
local_tile_in = local_input[cta_coord]
|
||||
local_tile_out = local_output[cta_coord]
|
||||
|
||||
ping_coord = (((None, None), None, ping), bidx)
|
||||
pong_coord = (((None, None), None, pong), bidx)
|
||||
|
||||
read_buffer = buffer_local[ping_coord]
|
||||
clear_buffer = buffer_local[pong_coord]
|
||||
|
||||
write_coord = (((None, None), rank, ping), bidx)
|
||||
write_buffers = [buffer[write_coord] for buffer in buffers]
|
||||
|
||||
# assume all buffers have the same element type with input
|
||||
copy_atom_load = cute.make_copy_atom(
|
||||
cute.nvgpu.CopyUniversalOp(),
|
||||
buffers[0].element_type,
|
||||
num_bits_per_copy=128,
|
||||
memory_scope=cute.nvgpu.common.MemoryScope.SYS,
|
||||
memory_order=cute.nvgpu.common.MemoryOrder.VOLATILE,
|
||||
)
|
||||
copy_atom_store = cute.make_copy_atom(
|
||||
cute.nvgpu.CopyUniversalOp(),
|
||||
buffers[0].element_type,
|
||||
num_bits_per_copy=128,
|
||||
memory_scope=cute.nvgpu.common.MemoryScope.SYS,
|
||||
memory_order=cute.nvgpu.common.MemoryOrder.VOLATILE,
|
||||
)
|
||||
tiled_copy = cute.make_tiled_copy_tv(copy_atom_load, thr_layout, val_layout)
|
||||
thr_copy = tiled_copy.get_slice(tidx)
|
||||
|
||||
thr_write_buffer_list = [
|
||||
thr_copy.partition_D(tensor) for tensor in write_buffers
|
||||
]
|
||||
thr_read_buffer = thr_copy.partition_S(read_buffer)
|
||||
|
||||
thr_clear_buffer = thr_copy.partition_D(clear_buffer)
|
||||
|
||||
thr_in = thr_copy.partition_S(local_tile_in)
|
||||
thr_out = thr_copy.partition_D(local_tile_out)
|
||||
|
||||
frg_in = cute.make_fragment_like(thr_in)
|
||||
frg_clear = cute.make_fragment_like(thr_clear_buffer)
|
||||
frg_acc = cute.make_fragment_like(thr_out)
|
||||
frg_acc.fill(0.0)
|
||||
|
||||
# clear a next buffer to be all negtive 0
|
||||
clear_tensor = frg_clear.load()
|
||||
frg_size = cute.size(clear_tensor.shape)
|
||||
neg0_i32_vec = cute.full_like(clear_tensor, 0x80000000, cutlass.Int32)
|
||||
neg0_f32_vec = vector.bitcast(T.vector(frg_size, T.f32()), neg0_i32_vec)
|
||||
neg0_f32_tensor = cute.TensorSSA(
|
||||
neg0_f32_vec, clear_tensor.shape, cutlass.Float32
|
||||
)
|
||||
frg_clear.store(neg0_f32_tensor)
|
||||
cute.copy(copy_atom_store, frg_clear, thr_clear_buffer)
|
||||
|
||||
# read local data to the register
|
||||
cute.copy(copy_atom_load, thr_in, frg_in)
|
||||
|
||||
# write local data to every buffer at different ranks
|
||||
for thr_write_buffer in thr_write_buffer_list:
|
||||
cute.copy(copy_atom_store, frg_in, thr_write_buffer)
|
||||
|
||||
frg_in_vector_neg0_i32 = cute.full_like(
|
||||
frg_in, cutlass.Int32(0x80000000), cutlass.Int32
|
||||
)
|
||||
frg_in_size = cute.size(frg_in.shape)
|
||||
|
||||
# loop over each buffer and accumulate the data
|
||||
for i in cutlass.range_constexpr(len(buffers)):
|
||||
read_coord = (None, 0, 0, i)
|
||||
cute.copy(copy_atom_load, thr_read_buffer[read_coord], frg_in[None, 0, 0])
|
||||
frg_vector = frg_in.load()
|
||||
frg_vector_i32 = cute.TensorSSA(
|
||||
vector.bitcast(T.vector(frg_in_size, T.i32()), frg_vector),
|
||||
frg_in.shape,
|
||||
cutlass.Int32,
|
||||
)
|
||||
isNotNeg0 = cute.all_(
|
||||
cute.TensorSSA(
|
||||
frg_vector_i32 != frg_in_vector_neg0_i32,
|
||||
frg_in.shape,
|
||||
cutlass.Boolean,
|
||||
)
|
||||
)
|
||||
# if the data is negtive 0, it means data are not ready or not visible yet, so we need to read the data again
|
||||
while not isNotNeg0:
|
||||
cute.copy(
|
||||
copy_atom_load, thr_read_buffer[read_coord], frg_in[None, 0, 0]
|
||||
)
|
||||
frg_vector = frg_in.load()
|
||||
frg_vector_i32 = cute.TensorSSA(
|
||||
vector.bitcast(T.vector(frg_in_size, T.i32()), frg_vector),
|
||||
frg_in.shape,
|
||||
cutlass.Int32,
|
||||
)
|
||||
isNotNeg0 = cute.all_(
|
||||
cute.TensorSSA(
|
||||
frg_vector_i32 != frg_in_vector_neg0_i32,
|
||||
frg_in.shape,
|
||||
cutlass.Boolean,
|
||||
)
|
||||
)
|
||||
frg_acc.store(frg_in.load() + frg_acc.load())
|
||||
|
||||
cute.copy(copy_atom_store, frg_acc, thr_out)
|
||||
|
||||
|
||||
def run_all_reduce_one_shot(
|
||||
M,
|
||||
N,
|
||||
warmup_iterations=2,
|
||||
iterations=10,
|
||||
skip_ref_check=False,
|
||||
benchmark=True,
|
||||
):
|
||||
rank = torch.distributed.get_rank()
|
||||
world_size = torch.distributed.get_world_size()
|
||||
if rank == 0:
|
||||
print("\nRunning Elementwise Add test with:")
|
||||
print(f"Tensor dimensions: [{M}, {N}]")
|
||||
print(f"GPU count: {world_size}")
|
||||
|
||||
# init buffer tensors to be neg 0
|
||||
local_buffer_tensor = nvshmem.core.tensor([PING_PONG_SIZE, world_size, M, N,], dtype=torch.float32).neg_()
|
||||
buffer_tensor_list = [nvshmem.core.get_peer_tensor(local_buffer_tensor, rank).permute(2, 3, 1, 0) for rank in range(world_size)]
|
||||
signal = cutlass.Int32(0)
|
||||
input_tensor = torch.randn([M, N], device=f"cuda:{rank}")
|
||||
output_tensor = torch.zeros([M, N], device=f"cuda:{rank}")
|
||||
stream = cutlass.cuda.default_stream()
|
||||
all_reduce_one_shot_lamport_kernel = AllReduceOneShotLamportKernel()
|
||||
|
||||
compiled_func = cute.compile(
|
||||
all_reduce_one_shot_lamport_kernel,
|
||||
rank,
|
||||
world_size,
|
||||
signal,
|
||||
from_dlpack(input_tensor, assumed_align=32),
|
||||
from_dlpack(output_tensor, assumed_align=32),
|
||||
[from_dlpack(t, assumed_align=32) for t in buffer_tensor_list],
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
if not skip_ref_check:
|
||||
compiled_func(
|
||||
signal,
|
||||
from_dlpack(input_tensor, assumed_align=32),
|
||||
from_dlpack(output_tensor, assumed_align=32),
|
||||
[from_dlpack(t, assumed_align=32) for t in buffer_tensor_list],
|
||||
stream,
|
||||
)
|
||||
if rank == 0:
|
||||
print("Verifying results...")
|
||||
dist.all_reduce(input_tensor, op=dist.ReduceOp.SUM)
|
||||
dist.barrier(device_ids=[rank])
|
||||
torch.testing.assert_close(input_tensor.cpu(), output_tensor.cpu())
|
||||
if rank == 0:
|
||||
print("Results verified successfully!")
|
||||
|
||||
for t in buffer_tensor_list:
|
||||
nvshmem.core.free_tensor(t)
|
||||
|
||||
if not benchmark:
|
||||
return
|
||||
|
||||
free_func_and_tensor_pairs = []
|
||||
def add_free_func_and_tensor(free_func, tensor):
|
||||
free_func_and_tensor_pairs.append((free_func, tensor))
|
||||
|
||||
def generate_tensors():
|
||||
local_buffer = nvshmem.core.tensor([PING_PONG_SIZE, world_size, M, N,], dtype=torch.float32).neg_()
|
||||
buffer_tensor_list = [nvshmem.core.get_peer_tensor(local_buffer, rank).permute(2, 3, 1, 0) for rank in range(world_size)]
|
||||
input_tensor = torch.randn([M, N], device=f"cuda:{rank}")
|
||||
output_tensor = torch.zeros([M, N], device=f"cuda:{rank}")
|
||||
|
||||
ja = testing.JitArguments(
|
||||
cutlass.Int32(0),
|
||||
from_dlpack(input_tensor, assumed_align=32),
|
||||
from_dlpack(output_tensor, assumed_align=32),
|
||||
[from_dlpack(t, assumed_align=32) for t in buffer_tensor_list],
|
||||
stream=stream
|
||||
)
|
||||
for tensor in buffer_tensor_list:
|
||||
add_free_func_and_tensor(nvshmem.core.free_tensor, tensor)
|
||||
|
||||
return ja
|
||||
avg_time_us = testing.benchmark(
|
||||
compiled_func,
|
||||
workspace_generator=generate_tensors,
|
||||
workspace_count=10,
|
||||
warmup_iterations=warmup_iterations,
|
||||
iterations=iterations,
|
||||
)
|
||||
|
||||
# Print execution results
|
||||
if rank == 0:
|
||||
print(f"Kernel execution time: {avg_time_us / 1e3:.4f} ms")
|
||||
print(
|
||||
f"Achieved memory throughput: {((world_size + 1) * output_tensor.numel() * 32 // 8) / (avg_time_us / 1e6) / 1e9:.2f} GB/s"
|
||||
)
|
||||
|
||||
for free_func, tensor in free_func_and_tensor_pairs:
|
||||
free_func(tensor)
|
||||
|
||||
def torchrun_uid_init_bcast():
|
||||
"""
|
||||
Initialize NVSHMEM using UniqueID with `torchrun` as the launcher
|
||||
|
||||
It uses torch.distributed.broadcast on a NumPy array to handle the broadcasting
|
||||
"""
|
||||
# Set Torch device
|
||||
local_rank = int(os.environ['LOCAL_RANK'])
|
||||
torch.cuda.set_device(local_rank)
|
||||
|
||||
# nvshmem4py requires a cuda.core Device at init time
|
||||
dev = Device(local_rank)
|
||||
dev.set_current()
|
||||
global stream
|
||||
stream = dev.create_stream()
|
||||
|
||||
# Initialize torch.distributed process group
|
||||
dist.init_process_group(
|
||||
backend="cpu:gloo,cuda:nccl",
|
||||
)
|
||||
|
||||
# Extract rank, nranks from process group
|
||||
num_ranks = dist.get_world_size()
|
||||
|
||||
# Create an empty uniqueid for all ranks
|
||||
uid = nvshmem.core.get_unique_id(empty=(local_rank != 0))
|
||||
uid_bytes = uid._data.view(np.uint8).copy()
|
||||
uid_tensor = torch.from_numpy(uid_bytes).cuda()
|
||||
dist.broadcast(uid_tensor, src=0)
|
||||
dist.barrier()
|
||||
uid._data[:] = uid_tensor.cpu().numpy().view(uid._data.dtype)
|
||||
|
||||
nvshmem.core.init(device=dev, uid=uid, rank=local_rank, nranks=num_ranks, initializer_method="uid")
|
||||
|
||||
|
||||
def torchrun_finalize():
|
||||
nvshmem.core.finalize()
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="example of elementwise add to demonstrate the numpy/pytorch as input for kernels"
|
||||
)
|
||||
parser.add_argument("--M", default=1024, type=int)
|
||||
parser.add_argument("--N", default=1024, type=int)
|
||||
parser.add_argument("--warmup_iterations", default=2, type=int)
|
||||
parser.add_argument("--iterations", default=10, type=int)
|
||||
parser.add_argument("--skip_ref_check", action="store_true")
|
||||
parser.add_argument("--benchmark", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
torchrun_uid_init_bcast()
|
||||
|
||||
run_all_reduce_one_shot(args.M, args.N, args.warmup_iterations, args.iterations, args.skip_ref_check, args.benchmark)
|
||||
|
||||
torchrun_finalize()
|
||||
|
||||
return
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
312
examples/python/CuTeDSL/distributed/all_reduce_simple.py
Normal file
312
examples/python/CuTeDSL/distributed/all_reduce_simple.py
Normal file
@@ -0,0 +1,312 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
import os
|
||||
import time
|
||||
import importlib
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from cuda.core.experimental import Device
|
||||
from cuda.pathfinder import load_nvidia_dynamic_lib
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.cute.testing as testing
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
|
||||
try:
|
||||
import nvshmem.core
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"nvshmem4py is required but not installed. Please install it using:\n"
|
||||
" For CUDA 12: pip install nvshmem4py-cu12\n"
|
||||
" For CUDA 13: pip install nvshmem4py-cu13\n"
|
||||
"Note: nvshmem4py version >= 0.1.3 is recommended."
|
||||
) from None
|
||||
|
||||
try:
|
||||
load_nvidia_dynamic_lib("nvshmem_host")
|
||||
except RuntimeError as exc:
|
||||
raise ImportError(
|
||||
"nvshmem lib is required but not installed. Please install it using:\n"
|
||||
" For CUDA 12: pip install nvidia-nvshmem-cu12\n"
|
||||
" For CUDA 13: pip install nvidia-nvshmem-cu13\n"
|
||||
) from None
|
||||
|
||||
"""
|
||||
A Distributed All-Reduce Addition Example using CuTe DSL and PyTorch Symmetric Memory.
|
||||
|
||||
This example kernel demonstrates distributed all-reduce across multiple GPUs using the SIMT copy
|
||||
of CuTe DSL and PyTorch's symmetric memory feature. Basic CuTe layout calculation is derived
|
||||
from the elementwise_add.py example.
|
||||
|
||||
This kernel is a simple version of all-reduce. It will directly copy data from remote memory to
|
||||
registers, then accumulate the data and finally store the accumulated data back to local global memory.
|
||||
If the input tensors from each device are remotely accessible, then this kernel can be used to perform the all-reduce.
|
||||
|
||||
On the host side, we use `torch.distributed._symmetric_memory` to manage the symmetric memory. We use `symm_mem.empty`
|
||||
and `symm_mem.rendezvous` to create a symmetric tensor. Then we use `get_buffer` to get tensors that are accessible from all devices.
|
||||
In this way, we can hide the details of CUDA driver API calls to enable access to remote memory.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
t = symm_mem.empty((M, N), device=torch.device(f"cuda:{rank}"))
|
||||
hdl = symm_mem.rendezvous(t, dist.group.WORLD)
|
||||
# get tensors from other devices from the symmetric memory
|
||||
tensor_list = [hdl.get_buffer(rank, t.shape, t.dtype) for rank in range(world_size)]
|
||||
|
||||
To run this example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
torchrun --nproc-per-node 8 examples/distributed/all_reduce_simple.py --M 1024 --N 512
|
||||
torchrun --nproc-per-node 8 examples/distributed/all_reduce_simple.py \
|
||||
--M 1024 --N 1024 --benchmark --warmup_iterations 2 --iterations 100
|
||||
"""
|
||||
|
||||
|
||||
@cute.kernel
|
||||
def all_reduce_simple_kernel(
|
||||
inputs: list[cute.Tensor],
|
||||
gOut: cute.Tensor,
|
||||
thr_layout: cute.Layout,
|
||||
val_layout: cute.Layout,
|
||||
):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
bidx, _, _ = cute.arch.block_idx()
|
||||
|
||||
# slice for CTAs
|
||||
# logical id -> address
|
||||
blk_coord = ((None, None), bidx)
|
||||
local_tile_out = gOut[blk_coord]
|
||||
local_tile_list = [t[blk_coord] for t in inputs]
|
||||
|
||||
assert all(t.element_type == inputs[0].element_type for t in inputs)
|
||||
|
||||
copy_atom_load = cute.make_copy_atom(
|
||||
cute.nvgpu.CopyUniversalOp(),
|
||||
inputs[0].element_type,
|
||||
)
|
||||
copy_atom_store = cute.make_copy_atom(
|
||||
cute.nvgpu.CopyUniversalOp(),
|
||||
inputs[0].element_type,
|
||||
)
|
||||
tiled_copy = cute.make_tiled_copy_tv(copy_atom_load, thr_layout, val_layout)
|
||||
thr_copy = tiled_copy.get_slice(tidx)
|
||||
|
||||
thr_tensor_list = [thr_copy.partition_S(tensor) for tensor in local_tile_list]
|
||||
thr_out = thr_copy.partition_D(local_tile_out)
|
||||
frg_tensor_list = [cute.make_fragment_like(tensor) for tensor in thr_tensor_list]
|
||||
frg_acc = cute.make_fragment_like(thr_out)
|
||||
frg_acc.fill(0.0)
|
||||
|
||||
# load the frg at the same offset from all devices and accumulate the result in frg_acc
|
||||
for thr, frg in zip(thr_tensor_list, frg_tensor_list):
|
||||
cute.copy(copy_atom_load, thr, frg)
|
||||
tmp = frg.load() + frg_acc.load()
|
||||
frg_acc.store(tmp)
|
||||
|
||||
# copy from register memory to global memory
|
||||
cute.copy(copy_atom_store, frg_acc, thr_out)
|
||||
|
||||
|
||||
@cute.jit
|
||||
def all_reduce_simple(
|
||||
inputs: list[cute.Tensor], output: cute.Tensor, copy_bits: cutlass.Constexpr = 128
|
||||
):
|
||||
dtype = inputs[0].element_type
|
||||
vector_size = copy_bits // dtype.width
|
||||
|
||||
thr_layout = cute.make_ordered_layout((4, 32), order=(1, 0))
|
||||
val_layout = cute.make_ordered_layout((4, vector_size), order=(1, 0))
|
||||
tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout)
|
||||
|
||||
divided_inputs = [cute.zipped_divide(tensor, tiler_mn) for tensor in inputs]
|
||||
gOut = cute.zipped_divide(output, tiler_mn) # ((Tile),(Rest))
|
||||
all_reduce_simple_kernel(
|
||||
divided_inputs,
|
||||
gOut,
|
||||
thr_layout,
|
||||
val_layout,
|
||||
).launch(
|
||||
grid=[cute.size(gOut, mode=[1]), 1, 1],
|
||||
block=[cute.size(tv_layout, mode=[0]), 1, 1],
|
||||
)
|
||||
|
||||
|
||||
def run_all_reduce_simple(
|
||||
M,
|
||||
N,
|
||||
warmup_iterations=2,
|
||||
iterations=10,
|
||||
skip_ref_check=False,
|
||||
benchmark=True,
|
||||
):
|
||||
rank = torch.distributed.get_rank()
|
||||
world_size = torch.distributed.get_world_size()
|
||||
if rank == 0:
|
||||
print("\nRunning Elementwise Add test with:")
|
||||
print(f"Tensor dimensions: [{M}, {N}]")
|
||||
print(f"GPU count: {world_size}")
|
||||
|
||||
local_tensor = nvshmem.core.tensor((M, N), dtype=torch.float32)
|
||||
local_tensor.random_(0, 100)
|
||||
tensor_list = [nvshmem.core.get_peer_tensor(local_tensor, rank) for rank in range(world_size)]
|
||||
output = torch.zeros((M, N), device=f"cuda:{rank}")
|
||||
|
||||
if rank == 0:
|
||||
print("Compiling kernel with cute.compile ...")
|
||||
start_time = time.time()
|
||||
compiled_func = cute.compile(all_reduce_simple, [from_dlpack(t) for t in tensor_list], from_dlpack(output))
|
||||
compilation_time = time.time() - start_time
|
||||
if rank == 0:
|
||||
print(f"Compilation time: {compilation_time:.4f} seconds")
|
||||
print("Executing vector add kernel...")
|
||||
|
||||
if not skip_ref_check:
|
||||
dist.barrier(device_ids=[rank])
|
||||
compiled_func([from_dlpack(t) for t in tensor_list], from_dlpack(output))
|
||||
if rank == 0:
|
||||
print("Verifying results...")
|
||||
dist.barrier(device_ids=[rank])
|
||||
torch.testing.assert_close(sum([t.cpu() for t in tensor_list]), output.cpu())
|
||||
if rank == 0:
|
||||
print("Results verified successfully!")
|
||||
|
||||
for t in tensor_list:
|
||||
nvshmem.core.free_tensor(t)
|
||||
|
||||
if not benchmark:
|
||||
return
|
||||
|
||||
free_func_and_tensor_pairs = []
|
||||
def add_free_func_and_tensor(free_func, tensor):
|
||||
free_func_and_tensor_pairs.append((free_func, tensor))
|
||||
|
||||
def generate_tensors():
|
||||
local_tensor = nvshmem.core.tensor((M, N), dtype=torch.float32)
|
||||
local_tensor.random_(0, 100)
|
||||
tensor_list = [nvshmem.core.get_peer_tensor(local_tensor, rank) for rank in range(world_size)]
|
||||
output = torch.zeros((M, N), device=f"cuda:{rank}")
|
||||
|
||||
ja = testing.JitArguments(
|
||||
[from_dlpack(t) for t in tensor_list],
|
||||
from_dlpack(output),
|
||||
)
|
||||
for tensor in tensor_list:
|
||||
add_free_func_and_tensor(nvshmem.core.free_tensor, tensor)
|
||||
return ja
|
||||
|
||||
avg_time_us = testing.benchmark(
|
||||
compiled_func,
|
||||
workspace_generator=generate_tensors,
|
||||
workspace_count=10,
|
||||
warmup_iterations=warmup_iterations,
|
||||
iterations=iterations,
|
||||
)
|
||||
|
||||
# Print execution results
|
||||
if rank == 0:
|
||||
print(f"Kernel execution time: {avg_time_us / 1e3:.4f} ms")
|
||||
print(
|
||||
f"Achieved memory throughput: {((world_size + 1) * output.numel() * 32 // 8) / (avg_time_us / 1e6) / 1e9:.2f} GB/s"
|
||||
)
|
||||
print(f"First few elements of result: \n{output[:3, :3]}")
|
||||
|
||||
for free_func, tensor in free_func_and_tensor_pairs:
|
||||
free_func(tensor)
|
||||
|
||||
|
||||
def torchrun_uid_init_bcast():
|
||||
"""
|
||||
Initialize NVSHMEM using UniqueID with `torchrun` as the launcher
|
||||
|
||||
It uses torch.distributed.broadcast on a NumPy array to handle the broadcasting
|
||||
"""
|
||||
# Set Torch device
|
||||
local_rank = int(os.environ['LOCAL_RANK'])
|
||||
torch.cuda.set_device(local_rank)
|
||||
|
||||
# nvshmem4py requires a cuda.core Device at init time
|
||||
dev = Device(local_rank)
|
||||
dev.set_current()
|
||||
global stream
|
||||
stream = dev.create_stream()
|
||||
|
||||
# Initialize torch.distributed process group
|
||||
dist.init_process_group(
|
||||
backend="cpu:gloo,cuda:nccl",
|
||||
)
|
||||
|
||||
# Extract rank, nranks from process group
|
||||
num_ranks = dist.get_world_size()
|
||||
|
||||
# Create an empty uniqueid for all ranks
|
||||
uid = nvshmem.core.get_unique_id(empty=(local_rank != 0))
|
||||
uid_bytes = uid._data.view(np.uint8).copy()
|
||||
uid_tensor = torch.from_numpy(uid_bytes).cuda()
|
||||
dist.broadcast(uid_tensor, src=0)
|
||||
dist.barrier()
|
||||
uid._data[:] = uid_tensor.cpu().numpy().view(uid._data.dtype)
|
||||
|
||||
nvshmem.core.init(device=dev, uid=uid, rank=local_rank, nranks=num_ranks, initializer_method="uid")
|
||||
|
||||
|
||||
def torchrun_finalize():
|
||||
nvshmem.core.finalize()
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="example of elementwise add to demonstrate the numpy/pytorch as input for kernels"
|
||||
)
|
||||
parser.add_argument("--M", default=1024, type=int)
|
||||
parser.add_argument("--N", default=1024, type=int)
|
||||
parser.add_argument("--warmup_iterations", default=2, type=int)
|
||||
parser.add_argument("--iterations", default=10, type=int)
|
||||
parser.add_argument("--skip_ref_check", action="store_true")
|
||||
parser.add_argument("--benchmark", action="store_true")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
torchrun_uid_init_bcast()
|
||||
|
||||
run_all_reduce_simple(args.M, args.N, args.warmup_iterations, args.iterations, args.skip_ref_check, args.benchmark)
|
||||
|
||||
torchrun_finalize()
|
||||
|
||||
return
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,393 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
import os
|
||||
import time
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from cuda.core.experimental import Device
|
||||
from cuda.pathfinder import load_nvidia_dynamic_lib
|
||||
|
||||
import cutlass
|
||||
import cutlass.utils as utils
|
||||
import cutlass.cute as cute
|
||||
import cutlass.cute.testing as testing
|
||||
import cutlass.torch as cutlass_torch
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
|
||||
try:
|
||||
import nvshmem.core
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"nvshmem4py is required but not installed. Please install it using:\n"
|
||||
" For CUDA 12: pip install nvshmem4py-cu12\n"
|
||||
" For CUDA 13: pip install nvshmem4py-cu13\n"
|
||||
"Note: nvshmem4py version >= 0.1.3 is recommended."
|
||||
) from None
|
||||
|
||||
try:
|
||||
load_nvidia_dynamic_lib("nvshmem_host")
|
||||
except RuntimeError as exc:
|
||||
raise ImportError(
|
||||
"nvshmem lib is required but not installed. Please install it using:\n"
|
||||
" For CUDA 12: pip install nvidia-nvshmem-cu12\n"
|
||||
" For CUDA 13: pip install nvidia-nvshmem-cu13\n"
|
||||
) from None
|
||||
|
||||
|
||||
"""
|
||||
A Distributed Two-Shot All-Reduce Example using CuTe DSL and PyTorch Symmetric Memory.
|
||||
|
||||
This example kernel demonstrates how to leverage the multimem feature to do a two-shot all-reduce.
|
||||
The multimem instruction is operated on symmetric memory, it can offload the broadcast and reduce
|
||||
to the Nvlink Switch so that the nvlink traffic will be reduced.
|
||||
|
||||
When calling a 'multimem.ld_reduce addrA', the corresponding data from each remote device will be sent to the NVLS
|
||||
and return the reduced data as result. And for 'multimem.st dataA addrA', the data will be sent to the NVLS once and
|
||||
the data will be broadcast to each remote device. So the memory traffic and instruction count is reduced by 8 times
|
||||
with multimem.
|
||||
|
||||
In this example, we are using two-shot styled all-reduce which means each device computes a portion
|
||||
of data and stores them to each device. Compared to the one-shot styled all-reduce, the two-shot one can
|
||||
maximize the performance of throughput. The input and output are symmetric memory so we don't need extra
|
||||
communication buffers here. We use the `sm_wise_inter_gpu_multimem_barrier` to synchronize the data
|
||||
between each device. It is to make sure that each device has done the data transfer.
|
||||
|
||||
To run this example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
torchrun --nproc-per-node 8 examples/distributed/all_reduce_two_shot_multimem.py --M 1024 --N 512
|
||||
torchrun --nproc-per-node 8 examples/distributed/all_reduce_two_shot_multimem.py \
|
||||
--M 1024 --N 1024 --benchmark --warmup_iterations 2 --iterations 100
|
||||
"""
|
||||
|
||||
|
||||
@cute.kernel
|
||||
def all_reduce_multimem_kernel(
|
||||
gIn: cute.Tensor,
|
||||
gOut: cute.Tensor,
|
||||
flag: cute.Tensor,
|
||||
flag_mc: cute.Tensor,
|
||||
thr_layout: cute.Layout,
|
||||
val_layout: cute.Layout,
|
||||
local_rank: cutlass.Constexpr,
|
||||
world_size: cutlass.Constexpr,
|
||||
):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
bidx, _, _ = cute.arch.block_idx()
|
||||
|
||||
# slice for CTAs
|
||||
# logical id -> address
|
||||
|
||||
num_ctas = cute.size(gIn, mode=[1])
|
||||
chunk_size = num_ctas // world_size
|
||||
blk_idx = local_rank * chunk_size + bidx
|
||||
|
||||
blk_coord = ((None, None), blk_idx)
|
||||
local_tile_out = gOut[blk_coord]
|
||||
local_tile_in = gIn[blk_coord]
|
||||
|
||||
assert gIn.element_type == gOut.element_type
|
||||
|
||||
copy_atom_load = cute.make_copy_atom(
|
||||
cute.nvgpu.CopyUniversalOp(),
|
||||
gIn.element_type,
|
||||
num_bits_per_copy=128,
|
||||
)
|
||||
tiled_copy = cute.make_tiled_copy_tv(copy_atom_load, thr_layout, val_layout)
|
||||
thr_copy = tiled_copy.get_slice(tidx)
|
||||
|
||||
thr_in = thr_copy.partition_S(local_tile_in)
|
||||
thr_out = thr_copy.partition_D(local_tile_out)
|
||||
|
||||
(_, rest_m), _, _ = thr_in.shape
|
||||
(_, rest_m_stride), _, _ = thr_in.stride
|
||||
|
||||
for i in cutlass.range_constexpr(rest_m):
|
||||
x, y, z, w = utils.distributed.multimem_ld_reduce_4xf32(
|
||||
thr_in[(None, i), 0, 0].iterator
|
||||
)
|
||||
utils.distributed.multimem_st_4xb32(
|
||||
thr_out[(None, i), 0, 0].iterator, x, y, z, w
|
||||
)
|
||||
|
||||
# Ensure all threads in cta have finish issue multimem.ld_reduce and multimem.st instructions
|
||||
cute.arch.sync_threads()
|
||||
|
||||
if tidx == 0:
|
||||
# Linear id of current SM.
|
||||
sm_id_linear = (
|
||||
cute.arch.block_idx()[0]
|
||||
+ cute.arch.block_idx()[1] * cute.arch.grid_dim()[0]
|
||||
+ cute.arch.block_idx()[2]
|
||||
* cute.arch.grid_dim()[0]
|
||||
* cute.arch.grid_dim()[1]
|
||||
)
|
||||
# Release flag with sys scope
|
||||
utils.distributed.multimem_red_add1(
|
||||
flag_mc.iterator + sm_id_linear,
|
||||
scope="sys",
|
||||
order="release",
|
||||
)
|
||||
# Relaxed spin-lock wait flag with sys scope
|
||||
utils.distributed.spin_lock_atom_cas_relaxed_wait(
|
||||
flag.iterator + sm_id_linear,
|
||||
expected_val=world_size,
|
||||
reset_val=0,
|
||||
scope="sys",
|
||||
)
|
||||
|
||||
@cute.jit
|
||||
def all_reduce_multimem(
|
||||
mIn: cute.Tensor,
|
||||
mOut: cute.Tensor,
|
||||
flag: cute.Tensor,
|
||||
flag_mc: cute.Tensor,
|
||||
local_rank: cutlass.Constexpr,
|
||||
world_size: cutlass.Constexpr,
|
||||
copy_bits: cutlass.Constexpr = 128,
|
||||
):
|
||||
dtype = mIn.element_type
|
||||
vector_size = copy_bits // dtype.width
|
||||
|
||||
# we choose a 128x128 tile for a CTA
|
||||
thr_layout = cute.make_ordered_layout((4, 32), order=(1, 0))
|
||||
val_layout = cute.make_ordered_layout((32, vector_size), order=(1, 0))
|
||||
tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout)
|
||||
|
||||
gIn = cute.zipped_divide(mIn, tiler_mn)
|
||||
gOut = cute.zipped_divide(mOut, tiler_mn)
|
||||
|
||||
all_reduce_multimem_kernel(
|
||||
gIn,
|
||||
gOut,
|
||||
flag,
|
||||
flag_mc,
|
||||
thr_layout,
|
||||
val_layout,
|
||||
local_rank,
|
||||
world_size,
|
||||
).launch(
|
||||
grid=[cute.size(gOut, mode=[1]) // world_size, 1, 1],
|
||||
block=[cute.size(tv_layout, mode=[0]), 1, 1],
|
||||
)
|
||||
|
||||
|
||||
def run_all_reduce_multimem(
|
||||
M,
|
||||
N,
|
||||
warmup_iterations=2,
|
||||
iterations=10,
|
||||
skip_ref_check=False,
|
||||
benchmark=True,
|
||||
):
|
||||
local_rank = torch.distributed.get_rank()
|
||||
world_size = torch.distributed.get_world_size()
|
||||
|
||||
tile_m = 128
|
||||
tile_n = 128
|
||||
|
||||
if local_rank == 0:
|
||||
print("\nRunning Elementwise Add test with:")
|
||||
print(f"Tensor dimensions: [{M}, {N}]")
|
||||
print(f"GPU count: {world_size}")
|
||||
|
||||
local_input_tensor = nvshmem.core.tensor((M, N), dtype=torch.float32)
|
||||
input_tensor = nvshmem.core.get_multicast_tensor(nvshmem.core.Teams.TEAM_NODE, local_input_tensor)
|
||||
|
||||
local_output_tensor = nvshmem.core.tensor((M, N), dtype=torch.float32)
|
||||
output_tensor = nvshmem.core.get_multicast_tensor(nvshmem.core.Teams.TEAM_NODE, local_output_tensor)
|
||||
|
||||
local_flag = nvshmem.core.tensor((M*N//(tile_m*tile_n)), dtype=torch.int32)
|
||||
flag_mc = nvshmem.core.get_multicast_tensor(nvshmem.core.Teams.TEAM_NODE, local_flag)
|
||||
|
||||
if local_rank == 0:
|
||||
print("Compiling kernel with cute.compile ...")
|
||||
start_time = time.time()
|
||||
compiled_func = cute.compile(
|
||||
all_reduce_multimem,
|
||||
from_dlpack(input_tensor),
|
||||
from_dlpack(output_tensor),
|
||||
from_dlpack(local_flag),
|
||||
from_dlpack(flag_mc),
|
||||
local_rank,
|
||||
world_size,
|
||||
)
|
||||
compilation_time = time.time() - start_time
|
||||
if local_rank == 0:
|
||||
print(f"Compilation time: {compilation_time:.4f} seconds")
|
||||
print("Executing all-reduce two shot multimem kernel...")
|
||||
|
||||
if not skip_ref_check:
|
||||
dist.barrier(device_ids=[local_rank])
|
||||
compiled_func(
|
||||
from_dlpack(input_tensor),
|
||||
from_dlpack(output_tensor),
|
||||
from_dlpack(local_flag),
|
||||
from_dlpack(flag_mc),
|
||||
)
|
||||
dist.barrier(device_ids=[local_rank])
|
||||
if local_rank == 0:
|
||||
print("Verifying results...")
|
||||
|
||||
local_buffers = [nvshmem.core.get_peer_tensor(local_input_tensor, local_rank) for local_rank in range(world_size)]
|
||||
torch.testing.assert_close(sum([buffer.cpu() for buffer in local_buffers]), local_output_tensor.cpu())
|
||||
if local_rank == 0:
|
||||
print("Results verified successfully!")
|
||||
for i in range(world_size):
|
||||
if i != local_rank:
|
||||
nvshmem.core.free_tensor(local_buffers[i])
|
||||
|
||||
# always free the multicast tensors first
|
||||
nvshmem.core.free_tensor(input_tensor)
|
||||
nvshmem.core.free_tensor(output_tensor)
|
||||
nvshmem.core.free_tensor(flag_mc)
|
||||
nvshmem.core.free_tensor(local_input_tensor)
|
||||
nvshmem.core.free_tensor(local_output_tensor)
|
||||
nvshmem.core.free_tensor(local_flag)
|
||||
|
||||
if not benchmark:
|
||||
return
|
||||
|
||||
free_func_and_tensor_pairs = []
|
||||
def add_free_func_and_tensor(free_func, tensor):
|
||||
free_func_and_tensor_pairs.append((free_func, tensor))
|
||||
|
||||
def generate_tensors():
|
||||
local_input_tensor = nvshmem.core.tensor((M, N), dtype=torch.float32)
|
||||
input_tensor_mc = nvshmem.core.get_multicast_tensor(nvshmem.core.Teams.TEAM_NODE, local_input_tensor)
|
||||
|
||||
local_output_tensor = nvshmem.core.tensor((M, N), dtype=torch.float32)
|
||||
output_tensor_mc = nvshmem.core.get_multicast_tensor(nvshmem.core.Teams.TEAM_NODE, local_output_tensor)
|
||||
|
||||
local_flag = nvshmem.core.tensor((M*N//(tile_m*tile_n)), dtype=torch.int32)
|
||||
flag_mc = nvshmem.core.get_multicast_tensor(nvshmem.core.Teams.TEAM_NODE, local_flag)
|
||||
|
||||
ja = testing.JitArguments(
|
||||
from_dlpack(input_tensor_mc),
|
||||
from_dlpack(output_tensor_mc),
|
||||
from_dlpack(local_flag),
|
||||
from_dlpack(flag_mc),
|
||||
)
|
||||
tensors_to_free = [input_tensor_mc, output_tensor_mc, flag_mc, local_input_tensor, local_output_tensor, local_flag]
|
||||
for tensor in tensors_to_free:
|
||||
add_free_func_and_tensor(nvshmem.core.free_tensor, tensor)
|
||||
return ja
|
||||
|
||||
dist.barrier(device_ids=[local_rank])
|
||||
avg_time_us = testing.benchmark(
|
||||
compiled_func,
|
||||
workspace_generator=generate_tensors,
|
||||
workspace_count=10,
|
||||
warmup_iterations=warmup_iterations,
|
||||
iterations=iterations,
|
||||
)
|
||||
dist.barrier(device_ids=[local_rank])
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Print execution results
|
||||
if local_rank == 0:
|
||||
print(f"Kernel execution time: {avg_time_us / 1e3:.4f} ms")
|
||||
print(
|
||||
f"Achieved memory throughput: {((world_size + 1) * output_tensor.numel() * 32 // 8) / (avg_time_us / 1e6) / 1e9:.2f} GB/s"
|
||||
)
|
||||
|
||||
for free_func, tensor in free_func_and_tensor_pairs:
|
||||
free_func(tensor)
|
||||
return
|
||||
|
||||
|
||||
def torchrun_uid_init_bcast():
|
||||
"""
|
||||
Initialize NVSHMEM using UniqueID with `torchrun` as the launcher
|
||||
|
||||
It uses torch.distributed.broadcast on a NumPy array to handle the broadcasting
|
||||
"""
|
||||
# Set Torch device
|
||||
local_rank = int(os.environ['LOCAL_RANK'])
|
||||
torch.cuda.set_device(local_rank)
|
||||
|
||||
# nvshmem4py requires a cuda.core Device at init time
|
||||
dev = Device(local_rank)
|
||||
dev.set_current()
|
||||
global stream
|
||||
stream = dev.create_stream()
|
||||
|
||||
# Initialize torch.distributed process group
|
||||
dist.init_process_group(
|
||||
backend="cpu:gloo,cuda:nccl",
|
||||
)
|
||||
|
||||
# Extract rank, nranks from process group
|
||||
num_ranks = dist.get_world_size()
|
||||
|
||||
# Create an empty uniqueid for all ranks
|
||||
uid = nvshmem.core.get_unique_id(empty=(local_rank != 0))
|
||||
uid_bytes = uid._data.view(np.uint8).copy()
|
||||
uid_tensor = torch.from_numpy(uid_bytes).cuda()
|
||||
dist.broadcast(uid_tensor, src=0)
|
||||
dist.barrier()
|
||||
uid._data[:] = uid_tensor.cpu().numpy().view(uid._data.dtype)
|
||||
|
||||
nvshmem.core.init(device=dev, uid=uid, rank=local_rank, nranks=num_ranks, initializer_method="uid")
|
||||
|
||||
|
||||
def torchrun_finalize():
|
||||
nvshmem.core.finalize()
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="example of elementwise add to demonstrate the numpy/pytorch as input for kernels"
|
||||
)
|
||||
parser.add_argument("--M", default=1024, type=int)
|
||||
parser.add_argument("--N", default=1024, type=int)
|
||||
parser.add_argument("--warmup_iterations", default=2, type=int)
|
||||
parser.add_argument("--iterations", default=10, type=int)
|
||||
parser.add_argument("--skip_ref_check", action="store_true")
|
||||
parser.add_argument("--benchmark", action="store_true")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
torchrun_uid_init_bcast()
|
||||
|
||||
run_all_reduce_multimem(args.M, args.N, args.warmup_iterations, args.iterations, args.skip_ref_check, args.benchmark)
|
||||
|
||||
torchrun_finalize()
|
||||
|
||||
return
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user