mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-19 22:38:56 +00:00
[CuTeDSL] Distributed example, using TMA load to access remote memory rank-by-rank, reducing in cta, broadcast result to all ranks by multimem TMA store (#2970)
This commit is contained in:
691
examples/python/CuTeDSL/distributed/all_reduce_tma.py
Normal file
691
examples/python/CuTeDSL/distributed/all_reduce_tma.py
Normal file
@@ -0,0 +1,691 @@
|
||||
# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
"""
|
||||
A Distributed All-Reduce Example using TMA (Tensor Memory Accelerator).
|
||||
|
||||
This example demonstrates distributed all-reduce across multiple GPUs using TMA
|
||||
for data movement. It serves as a tutorial for TMA-based distributed operations,
|
||||
not as a performance-optimized implementation.
|
||||
|
||||
Tensor Semantics:
|
||||
- Input: Logical shape (world_size, S), where S is the per-rank tensor size
|
||||
- Output: Logical shape (world_size, S), each rank gets the sum of all inputs
|
||||
|
||||
Kernel Parameters:
|
||||
- input: List of world_size tensors, each with shape S (accessible via NVSHMEM)
|
||||
- output: Single tensor with shape S, using multicast address for broadcast
|
||||
|
||||
Algorithm (Two-Shot):
|
||||
1. Each CTA loads data from all ranks at its assigned tile position (TMA Load)
|
||||
2. Accumulates the data locally in registers
|
||||
3. Stores the result via TMA multicast (broadcasts to all ranks)
|
||||
4. Cross-GPU barrier ensures completion before kernel exit
|
||||
|
||||
Tile Assignment:
|
||||
- Total tiles = ceil(S / elems_per_cta)
|
||||
- Each rank processes ceil(total_tiles / world_size) CTAs
|
||||
- CTA i on rank r processes global_tile_id = r * ctas_per_rank + i
|
||||
|
||||
TMA Usage Notes (for tutorial purposes, not perf-optimal):
|
||||
- Uses TMALDG.1D to load from remote GPU memory via NVSHMEM addresses
|
||||
- Uses TMASTG.1D to store to multicast address for broadcasting to all ranks
|
||||
- Supports any input shape by flattening to 1D and tiling linearly
|
||||
- Pipeline with 2 stages overlaps TMA loads across ranks
|
||||
|
||||
To run this example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
torchrun --nproc-per-node 8 examples/python/CuTeDSL/distributed/all_reduce_tma.py --shape 1024,1024
|
||||
torchrun --nproc-per-node 8 examples/python/CuTeDSL/distributed/all_reduce_tma.py --shape 4,6,8,10,12
|
||||
"""
|
||||
|
||||
import cutlass
|
||||
import cutlass.utils as utils
|
||||
import cutlass.cute as cute
|
||||
import cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync
|
||||
|
||||
|
||||
class AllReduceTmaKernel:
    """
    TMA-based distributed All-Reduce kernel.

    This kernel performs an all-reduce operation across multiple GPUs using TMA
    (Tensor Memory Accelerator) for efficient data movement.

    Algorithm (Two-Shot):
    1. Each CTA loads data from all ranks at its assigned tile position
    2. Accumulates the data locally in registers
    3. Stores the result via TMA multicast (broadcasts to all ranks)
    4. Cross-GPU barrier ensures completion before kernel exit

    The input/output tensors can be of any rank, as long as:
    - All input tensors and output tensor share the same layout
    - The layout is compact (no holes in memory)

    We traverse the tensors linearly in codomain (physical offset) order,
    which guarantees consistent logical coordinate access across all tensors.

    CTA organisation (see the class constants below): one 32-thread producer
    warp issues the TMA loads and the final TMA store, and 128 consumer
    threads accumulate the loaded tiles.
    """

    _elems_per_cta: int = 128 * 128  # Elements processed per CTA
    _tma_threads: int = 32  # Warp 0: issues TMA loads and the TMA store
    _consumer_threads: int = 128  # Warps 1-4: read SMEM and accumulate
    _threads_per_cta: int = _tma_threads + _consumer_threads
    _num_stages: int = 2  # Number of SMEM tile buffers in the load pipeline

    def __init__(self, dtype):
        """
        Configure the kernel for one element type.

        Args:
            dtype: CuTe DSL numeric type of the tensors. Its ``width`` (in bits)
                is used to size the per-tile TMA transaction and the SMEM buffer.
        """
        self.dtype = dtype

        # SMEM layout shape (will be converted to Layout in JIT context)
        self.smem_layout_shape = (self._elems_per_cta,)
        self.tiler = (self._elems_per_cta,)

        # TMA transaction bytes (computed from dtype size)
        # dtype.width is in bits, divide by 8 to get bytes
        self.tma_bytes = (dtype.width // 8) * self._elems_per_cta

        # Dynamically create SharedStorage type based on dtype
        elems = self._elems_per_cta
        stages = self._num_stages

        @cute.struct
        class SharedStorage:
            # stages * 2 Int64 slots handed to the pipeline as barrier storage
            # (presumably one full + one empty mbarrier per stage).
            mbar_array: cute.struct.MemRange[cutlass.Int64, stages * 2]
            smem_buffer: cute.struct.Align[
                cute.struct.MemRange[dtype, elems * stages],  # one tile per stage
                128,
            ]

        self._SharedStorage = SharedStorage

    @cute.jit
    def __call__(
        self,
        input_tensors: list[cute.Tensor],
        output_tensor_mc: cute.Tensor,
        flag: cute.Tensor,
        flag_mc: cute.Tensor,
        local_rank: cutlass.Constexpr,
        world_size: cutlass.Constexpr,
    ):
        """
        Host-side JIT function: creates TMA descriptors and launches kernel.

        Args:
            input_tensors: List of input tensors from each rank (world_size tensors)
            output_tensor_mc: Output tensor with multicast address
            flag: Synchronization flag (local view)
            flag_mc: Synchronization flag (multicast view)
            local_rank: This rank's ID
            world_size: Total number of ranks

        Raises:
            AssertionError: If the reference layout is not compact, a tile is not
                a multiple of 16 bytes, the layouts differ between tensors, or
                any tensor's element type differs from ``self.dtype``.
        """
        # ======================================================================
        # Layout validation
        # ======================================================================
        ref_layout = input_tensors[0].layout
        ref_size = cute.size(ref_layout)
        ref_cosize = cute.cosize(ref_layout)

        # Check compact: size == cosize (no holes in memory)
        assert ref_size == ref_cosize, (
            f"Input tensor must be compact: size={ref_size}, cosize={ref_cosize}"
        )
        assert self.tma_bytes % 16 == 0, "Not aligned to 16B, TMA should not be used."

        # Check all input tensors have the same layout
        for i in cutlass.range_constexpr(world_size):
            assert input_tensors[i].layout == ref_layout, (
                f"All input tensors must have the same layout. "
                f"input_tensors[0].layout={ref_layout}, "
                f"input_tensors[{i}].layout={input_tensors[i].layout}"
            )

        # Check output tensor has the same layout
        assert output_tensor_mc.layout == ref_layout, (
            f"Output tensor must have the same layout as input tensors. "
            f"input layout={ref_layout}, output layout={output_tensor_mc.layout}"
        )

        # ======================================================================
        # Extract tensor info
        # ======================================================================
        # Verify dtype matches for every tensor: tma_bytes and the SMEM buffer
        # are both sized from self.dtype, so a single mismatching tensor would
        # make the expected transaction size wrong.
        for i in cutlass.range_constexpr(world_size):
            assert input_tensors[i].element_type == self.dtype, (
                f"Input tensor dtype mismatch: expected {self.dtype}, "
                f"got {input_tensors[i].element_type} (input_tensors[{i}])"
            )
        assert output_tensor_mc.element_type == self.dtype, (
            f"Output tensor dtype mismatch: expected {self.dtype}, "
            f"got {output_tensor_mc.element_type}"
        )

        total_elems = ref_size

        # Flatten layout: treat tensor as 1D in codomain order
        flat_layout = cute.make_layout((total_elems,))

        # SMEM layout (created in JIT context)
        smem_layout = cute.make_layout(self.smem_layout_shape)

        # Create TMA load descriptors (one per rank)
        tma_load_op = cpasync.CopyBulkTensorTileG2SOp()
        tma_load_atoms = []
        tma_load_tensors = []

        for i in cutlass.range_constexpr(world_size):
            flat_input = cute.make_tensor(input_tensors[i].iterator, flat_layout)
            tma_atom, tma_tensor = cpasync.make_tiled_tma_atom(
                tma_load_op,
                flat_input,
                smem_layout,
                self.tiler,
            )
            tma_load_atoms.append(tma_atom)
            tma_load_tensors.append(tma_tensor)

        # Create TMA store descriptor
        tma_store_op = cpasync.CopyBulkTensorTileS2GOp()
        flat_output = cute.make_tensor(output_tensor_mc.iterator, flat_layout)
        tma_store_atom, tma_store_tensor = cpasync.make_tiled_tma_atom(
            tma_store_op,
            flat_output,
            smem_layout,
            self.tiler,
        )

        # Grid calculation: the tiles are split evenly across ranks, one CTA
        # per tile.
        num_tiles_total = cute.ceil_div(total_elems, self._elems_per_cta)
        ctas_per_rank = cute.ceil_div(num_tiles_total, world_size)

        # SMEM size from SharedStorage
        smem_bytes = self._SharedStorage.size_in_bytes()

        # Launch kernel
        self.kernel(
            tma_load_atoms,
            tma_load_tensors,
            tma_store_atom,
            tma_store_tensor,
            flag,
            flag_mc,
            local_rank,
            world_size,
            num_tiles_total,
            ctas_per_rank,
        ).launch(
            grid=[ctas_per_rank, 1, 1],
            block=[self._threads_per_cta, 1, 1],
            smem=smem_bytes,
        )

    @cute.kernel
    def kernel(
        self,
        # TMA atoms and tensors for loading from each rank
        tma_load_atoms: list[cute.CopyAtom],
        tma_load_tensors: list[cute.Tensor],
        # TMA atom and tensor for storing to multicast address
        tma_store_atom: cute.CopyAtom,
        tma_store_tensor: cute.Tensor,
        # Synchronization flags
        flag: cute.Tensor,
        flag_mc: cute.Tensor,
        # Rank info
        local_rank: cutlass.Constexpr,
        world_size: cutlass.Constexpr,
        # Grid info for tile calculation
        num_tiles_total: cutlass.Constexpr,
        ctas_per_rank: cutlass.Constexpr,
    ):
        """
        Device kernel: reduce one tile across all ranks and broadcast the result.

        CTA ``bidx`` on rank ``local_rank`` owns tile
        ``local_rank * ctas_per_rank + bidx``. Warp 0 streams that tile from
        every rank into the staged SMEM buffer, the remaining warps add the
        tiles up in registers and write the sum back to stage 0, and warp 0
        then stores stage 0 through the multicast output descriptor. Finally
        thread 0 signals and waits on the per-CTA flag slot.
        """
        # ======================================================================
        # Thread/Block indexing
        # ======================================================================
        tidx = cute.arch.thread_idx()[0]
        bidx = cute.arch.block_idx()[0]
        warp_idx = cute.arch.warp_idx()
        warp_idx = cute.arch.make_warp_uniform(warp_idx)

        # ======================================================================
        # SMEM allocation
        # ======================================================================
        staged_smem_layout = cute.make_layout((self._elems_per_cta, self._num_stages))

        smem = utils.SmemAllocator()
        storage = smem.allocate(self._SharedStorage)
        mbar_ptr = storage.mbar_array.data_ptr()
        staged_smem_tensor = storage.smem_buffer.get_tensor(staged_smem_layout)

        # ======================================================================
        # TMA Pipeline setup
        # ======================================================================
        producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, 1)
        consumer_group = pipeline.CooperativeGroup(
            pipeline.Agent.Thread, self._consumer_threads
        )

        tma_pipeline = pipeline.PipelineTmaAsync.create(
            barrier_storage=mbar_ptr,
            num_stages=self._num_stages,
            producer_group=producer_group,
            consumer_group=consumer_group,
            tx_count=self.tma_bytes,
            cta_layout_vmnk=cute.make_layout((1, 1, 1, 1)),
        )

        global_tile_id = local_rank * ctas_per_rank + bidx

        # ceil_div rounding can leave trailing CTAs without a tile; those skip
        # the load/accumulate phase below.
        if global_tile_id < num_tiles_total:
            # ======================================================================
            # Warp 0: Producer - TMA Load from all ranks
            # ======================================================================
            if warp_idx == 0:
                producer_state = pipeline.make_pipeline_state(
                    pipeline.PipelineUserType.Producer, self._num_stages
                )

                for rank_i in cutlass.range_constexpr(world_size):
                    tma_pipeline.producer_acquire(producer_state)

                    stage_idx = producer_state.index
                    smem_tile = cute.slice_(staged_smem_tensor, (None, stage_idx))

                    g_tensor_tiled = cute.zipped_divide(
                        tma_load_tensors[rank_i], self.tiler
                    )
                    g_tile = g_tensor_tiled[(None,), global_tile_id]

                    g_tile_flat = cute.group_modes(g_tile, 0, cute.rank(g_tile))
                    s_tile_flat = cute.group_modes(smem_tile, 0, cute.rank(smem_tile))

                    s_part, g_part = cute.nvgpu.cpasync.tma_partition(
                        tma_load_atoms[rank_i],
                        0,
                        cute.make_layout(1),
                        s_tile_flat,
                        g_tile_flat,
                    )

                    cute.copy(
                        tma_load_atoms[rank_i],
                        g_part,
                        s_part,
                        tma_bar_ptr=tma_pipeline.producer_get_barrier(producer_state),
                    )

                    tma_pipeline.producer_commit(producer_state)
                    producer_state.advance()

            # ======================================================================
            # Warp 1-4: Consumer - LDS, ADD, STS
            # ======================================================================
            else:
                consumer_tid = tidx - self._tma_threads

                # Each consumer thread handles vec_size contiguous elements of
                # every chunk_size-element chunk of the tile.
                vec_size = 4
                chunk_size = vec_size * self._consumer_threads

                # ------------------------------------------------------------------
                # Initialize accumulator using stage 0's layout
                # ------------------------------------------------------------------
                # (elems, stages) -> (elems,)
                smem_tensor_wo_stage = cute.slice_(staged_smem_tensor, (None, 0))
                # (elems,) -> ((thr_vec,), (num_chunks,))
                smem_tensor_tiled_by_thr_vec = cute.zipped_divide(
                    smem_tensor_wo_stage, (chunk_size,)
                )
                # ((thr_vec,), (num_chunks,)) -> (((vec, threads),), (num_chunks,))
                smem_tensor_tiled_by_thr_vec_tiled_by_vec = cute.logical_divide(
                    smem_tensor_tiled_by_thr_vec, (vec_size,)
                )
                # (((vec, threads),), (num_chunks,)) -> ((vec,), (num_chunks,))
                per_thread_smem_tensor = cute.slice_(
                    smem_tensor_tiled_by_thr_vec_tiled_by_vec,
                    ((None, consumer_tid), None),
                )

                accum = cute.make_rmem_tensor(per_thread_smem_tensor.layout, self.dtype)
                accum.fill(self.dtype(0.0))

                # ------------------------------------------------------------------
                # Main loop: load from SMEM and accumulate
                # ------------------------------------------------------------------
                consumer_state = pipeline.make_pipeline_state(
                    pipeline.PipelineUserType.Consumer, self._num_stages
                )

                for rank_i in cutlass.range_constexpr(world_size):
                    tma_pipeline.consumer_wait(consumer_state)

                    stage_idx = consumer_state.index
                    smem_tile = cute.slice_(staged_smem_tensor, (None, stage_idx))

                    # (elems,) -> ((thr_vec,), (num_chunks,))
                    smem_tiled_by_thr_vec = cute.zipped_divide(smem_tile, (chunk_size,))
                    # ((thr_vec,), (num_chunks,)) -> (((vec, threads),), (num_chunks,))
                    smem_tiled_by_thr_vec_tiled_by_vec = cute.logical_divide(
                        smem_tiled_by_thr_vec, (vec_size,)
                    )
                    # (((vec, threads),), (num_chunks,)) -> ((vec,), (num_chunks,))
                    per_thread_smem_view = cute.slice_(
                        smem_tiled_by_thr_vec_tiled_by_vec,
                        ((None, consumer_tid), None),
                    )

                    fragment = per_thread_smem_view.load()
                    accum.store(accum.load() + fragment)

                    # Every consumer thread arrives on the stage's empty
                    # barrier itself, matching the consumer group size of
                    # _consumer_threads declared at pipeline creation.
                    tma_pipeline.sync_object_empty.arrive(
                        consumer_state.index, tma_pipeline.consumer_mask
                    )
                    consumer_state.advance()

                # Store accumulated result back to SMEM (stage 0)
                per_thread_smem_tensor.store(accum.load())

        # ======================================================================
        # Sync point: all warps meet here
        # ======================================================================
        cute.arch.sync_threads()

        # ======================================================================
        # Warp 0: TMA Store to multicast output
        # ======================================================================
        # NOTE(review): CTAs whose tile index is past num_tiles_total also reach
        # this store; presumably the TMA unit clips the fully out-of-bounds box
        # so nothing is written - confirm.
        if warp_idx == 0:
            # Fence to ensure SMEM writes are visible
            # NOTE(review): the fence is issued by warp 0 after the CTA barrier,
            # while the SMEM writes it covers were made by the consumer warps -
            # confirm this ordering is sufficient for the async proxy.
            cute.arch.fence_proxy(
                cute.arch.ProxyKind.async_shared,
                space=cute.arch.SharedSpace.shared_cta,
            )

            smem_tile_out = cute.slice_(staged_smem_tensor, (None, 0))

            g_output_tiled = cute.zipped_divide(tma_store_tensor, self.tiler)
            g_output_tile = g_output_tiled[(None,), global_tile_id]

            g_out_flat = cute.group_modes(
                g_output_tile, 0, cute.rank(g_output_tile)
            )
            s_out_flat = cute.group_modes(
                smem_tile_out, 0, cute.rank(smem_tile_out)
            )

            s_part, g_part = cute.nvgpu.cpasync.tma_partition(
                tma_store_atom,
                0,
                cute.make_layout(1),
                s_out_flat,
                g_out_flat,
            )

            cute.copy(tma_store_atom, s_part, g_part)
            cute.arch.cp_async_bulk_commit_group()
            cute.arch.cp_async_bulk_wait_group(0)

            # ==================================================================
            # Cross-GPU barrier synchronization (thread 0 only)
            # ==================================================================
            # This sits outside the tile-range guard on purpose: the wait below
            # expects world_size arrivals per flag slot, so a CTA without a
            # tile must still signal or its peers would spin forever.
            if tidx == 0:
                # Linear block index within the grid; selects this CTA's flag
                # slot (the same slot index on every rank).
                sm_id_linear = (
                    cute.arch.block_idx()[0]
                    + cute.arch.block_idx()[1] * cute.arch.grid_dim()[0]
                    + cute.arch.block_idx()[2]
                    * cute.arch.grid_dim()[0]
                    * cute.arch.grid_dim()[1]
                )

                # Signal completion to all ranks
                utils.distributed.multimem_red_add1(
                    flag_mc.iterator + sm_id_linear,
                    scope="sys",
                    order="release",
                )

                # The same idx ctas wait until all peer ranks' ctas complete
                utils.distributed.spin_lock_atom_cas_relaxed_wait(
                    flag.iterator + sm_id_linear,
                    expected_val=world_size,
                    reset_val=0,
                    scope="sys",
                )
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# HOST-SIDE DRIVER CODE
|
||||
# =============================================================================
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from cuda.core.experimental import Device
|
||||
from cuda.pathfinder import load_nvidia_dynamic_lib
|
||||
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
|
||||
# Import the NVSHMEM Python bindings, replacing a bare ImportError with
# installation instructions (the original exception context is suppressed).
try:
    import nvshmem.core
except ImportError as exc:
    raise ImportError(
        "nvshmem4py is required but not installed. Please install it using:\n"
        "  For CUDA 12: pip install nvshmem4py-cu12\n"
        "  For CUDA 13: pip install nvshmem4py-cu13\n"
        "Note: nvshmem4py version >= 0.1.3 is recommended."
    ) from None

# Check at import time that the NVSHMEM host library can be located and loaded;
# a failure is reported as an ImportError with installation instructions.
try:
    load_nvidia_dynamic_lib("nvshmem_host")
except RuntimeError as exc:
    raise ImportError(
        "nvshmem lib is required but not installed. Please install it using:\n"
        "  For CUDA 12: pip install nvidia-nvshmem-cu12\n"
        "  For CUDA 13: pip install nvidia-nvshmem-cu13\n"
    ) from None
|
||||
|
||||
|
||||
def torchrun_uid_init_bcast():
    """Initialize NVSHMEM using UniqueID with torchrun as launcher.

    Selects the GPU from the ``LOCAL_RANK`` environment variable, creates the
    torch.distributed process group, broadcasts the NVSHMEM unique ID from
    global rank 0 to every process, and initializes NVSHMEM with it.

    The device index comes from ``LOCAL_RANK`` (per-node), while the NVSHMEM
    rank and the unique-ID root come from ``dist.get_rank()`` (job-wide), so
    ranks stay unique when the job spans more than one node. On a single node
    the two values coincide.

    Side effects:
        Sets the current CUDA device and stores a newly created stream in the
        module-level global ``stream``.

    Raises:
        KeyError: If ``LOCAL_RANK`` is not set (i.e. not launched by torchrun).
    """
    global stream

    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    dev = Device(local_rank)
    dev.set_current()
    stream = dev.create_stream()

    dist.init_process_group(backend="cpu:gloo,cuda:nccl")
    num_ranks = dist.get_world_size()
    rank = dist.get_rank()

    # Only the broadcast root generates a real unique ID; everyone else
    # allocates an empty one to receive into.
    uid = nvshmem.core.get_unique_id(empty=(rank != 0))
    # Ship the ID as raw bytes through a CUDA tensor so it can travel over the
    # process group, then copy the received bytes back into the UID object.
    uid_bytes = uid._data.view(np.uint8).copy()
    uid_tensor = torch.from_numpy(uid_bytes).cuda()
    dist.broadcast(uid_tensor, src=0)
    dist.barrier()
    uid._data[:] = uid_tensor.cpu().numpy().view(uid._data.dtype)

    nvshmem.core.init(
        device=dev, uid=uid, rank=rank, nranks=num_ranks, initializer_method="uid"
    )
|
||||
|
||||
|
||||
def torchrun_finalize():
    """Finalize NVSHMEM and destroy process group.

    NVSHMEM is finalized first, then the torch.distributed process group is
    destroyed.
    """
    nvshmem.core.finalize()
    dist.destroy_process_group()
|
||||
|
||||
|
||||
def run_all_reduce_tma(
    shape: tuple,
    skip_ref_check: bool = False,
):
    """
    Run the TMA-based All-Reduce kernel.

    Allocates NVSHMEM input/output/flag tensors, compiles the kernel, executes
    it once between two process-group barriers and, unless disabled, compares
    the output against the sum of all ranks' inputs computed on the host.
    The kernel is always executed; ``skip_ref_check`` only disables the
    comparison.

    Args:
        shape: Tensor shape tuple, e.g., (4, 6, 8, 10)
        skip_ref_check: If True, skip reference result verification

    Raises:
        AssertionError: If verification is enabled and the output does not
            match the reference (raised by ``torch.testing.assert_close``).
    """
    local_rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()

    # Calculate total elements
    total_elems = math.prod(shape)

    if local_rank == 0:
        print("\nRunning TMA All-Reduce test with:")
        print(f"  Tensor shape: {shape}")
        print(f"  Total elements: {total_elems}")
        print(f"  GPU count: {world_size}")

    # Allocate input tensor (symmetric memory, accessible from all ranks)
    local_input_tensor = nvshmem.core.tensor(shape, dtype=torch.float32)
    local_input_tensor.random_(0, 100)

    # Get peer tensors (views into each rank's input)
    peer_input_tensors = [
        nvshmem.core.get_peer_tensor(local_input_tensor, r) for r in range(world_size)
    ]

    if local_rank == 0:
        print(f"  Input tensor ptr: {local_input_tensor.data_ptr():#x}")

    # Allocate output tensor with multicast address
    local_output_tensor = nvshmem.core.tensor(shape, dtype=torch.float32)
    local_output_tensor.fill_(0)
    output_tensor_mc = nvshmem.core.get_multicast_tensor(
        nvshmem.core.Teams.TEAM_NODE, local_output_tensor
    )

    # Allocate synchronization flags
    # Flag size = ctas_per_rank (matches kernel's bidx indexing)
    elems_per_cta = AllReduceTmaKernel._elems_per_cta
    num_tiles = (total_elems + elems_per_cta - 1) // elems_per_cta
    ctas_per_rank = (num_tiles + world_size - 1) // world_size
    local_flag = nvshmem.core.tensor((ctas_per_rank,), dtype=torch.int32)
    local_flag.fill_(0)
    flag_mc = nvshmem.core.get_multicast_tensor(
        nvshmem.core.Teams.TEAM_NODE, local_flag
    )

    if local_rank == 0:
        print(f"  Number of tiles: {num_tiles}")
        print(f"  CTAs per rank: {ctas_per_rank}")
        print("Compiling kernel...")

    # Create kernel instance and compile
    kernel = AllReduceTmaKernel(cutlass.Float32)

    compiled_func = cute.compile(
        kernel,
        [from_dlpack(t) for t in peer_input_tensors],
        from_dlpack(output_tensor_mc),
        from_dlpack(local_flag),
        from_dlpack(flag_mc),
        local_rank,
        world_size,
    )

    if local_rank == 0:
        print("Compilation successful!")
        print("Executing kernel...")

    # Run the kernel regardless of skip_ref_check. The barriers bracket the
    # launch so that every rank has finished initializing its buffers before
    # any rank starts, and so that all ranks are done before results are read
    # or buffers are freed.
    dist.barrier(device_ids=[local_rank])
    compiled_func(
        [from_dlpack(t) for t in peer_input_tensors],
        from_dlpack(output_tensor_mc),
        from_dlpack(local_flag),
        from_dlpack(flag_mc),
    )
    dist.barrier(device_ids=[local_rank])

    if not skip_ref_check:
        if local_rank == 0:
            print("Verifying results...")

        # Compute expected result: sum of all inputs
        expected = sum([t.cpu() for t in peer_input_tensors])

        # Compare with actual output
        torch.testing.assert_close(expected, local_output_tensor.cpu())

        if local_rank == 0:
            print("Results verified successfully!")

    # Cleanup: free the peer views of other ranks' inputs first (this rank's
    # own entry is skipped here and released via local_input_tensor below),
    # then the multicast views, then the local allocations.
    for i in range(world_size):
        if i != local_rank:
            nvshmem.core.free_tensor(peer_input_tensors[i])

    nvshmem.core.free_tensor(output_tensor_mc)
    nvshmem.core.free_tensor(flag_mc)
    nvshmem.core.free_tensor(local_input_tensor)
    nvshmem.core.free_tensor(local_output_tensor)
    nvshmem.core.free_tensor(local_flag)
|
||||
|
||||
|
||||
def parse_shape(shape_str: str) -> tuple:
    """
    Parse shape string into tuple.

    Whitespace around each comma-separated entry is ignored.

    Examples:
        "1024,1024" -> (1024, 1024)
        "2,3,4,5,6,7,8" -> (2, 3, 4, 5, 6, 7, 8)

    Args:
        shape_str: Comma-separated positive integers.

    Returns:
        Tuple of ints, one per dimension.

    Raises:
        ValueError: If an entry is empty or not an integer, or if a dimension
            is zero or negative (such a shape would yield an empty launch grid).
    """
    dims = []
    for token in shape_str.split(","):
        text = token.strip()
        try:
            dim = int(text)
        except ValueError:
            raise ValueError(
                f"Invalid shape {shape_str!r}: {text!r} is not an integer"
            ) from None
        if dim <= 0:
            raise ValueError(
                f"Invalid shape {shape_str!r}: dimensions must be positive, got {dim}"
            )
        dims.append(dim)
    return tuple(dims)
|
||||
|
||||
|
||||
def main():
    """Command-line entry point.

    Parses ``--shape`` and ``--skip_ref_check``, initializes NVSHMEM and the
    process group, runs the all-reduce example, and finalizes both.
    """
    parser = argparse.ArgumentParser(
        description="TMA-based distributed all-reduce example"
    )
    parser.add_argument(
        "--shape",
        default="1024,1024",
        type=str,
        help="Tensor shape as comma-separated values, e.g., '1024,1024' or '4,6,8,10,12'",
    )
    parser.add_argument(
        "--skip_ref_check",
        action="store_true",
        help="Skip reference result verification",
    )

    args = parser.parse_args()
    # Parse the shape before any distributed setup so a malformed value fails
    # fast, without initializing NVSHMEM.
    shape = parse_shape(args.shape)

    torchrun_uid_init_bcast()
    run_all_reduce_tma(
        shape=shape,
        skip_ref_check=args.skip_ref_check,
    )
    torchrun_finalize()
|
||||
|
||||
|
||||
# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user