# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
|
|
Hierarchical Reduction Utilities for CuTe-DSL Kernels
|
|
=====================================================
|
|
|
|
This module provides reusable reduction primitives for GPU kernels that need to
|
|
reduce values across warps, thread blocks, and clusters (SM90+).
|
|
|
|
Overview
|
|
--------
|
|
GPU reductions typically follow a hierarchical pattern:
|
|
|
|
1. **Warp Reduction**: Threads within a warp reduce using shuffle instructions.
|
|
Use `cute.arch.warp_reduction()` from the CuTe-DSL library.
|
|
|
|
2. **Block Reduction**: Multiple warps within a block reduce using shared memory.
|
|
Use `block_reduce()` from this module.
|
|
|
|
3. **Cluster Reduction** (SM90+): Multiple CTAs in a cluster reduce using
|
|
distributed shared memory and mbarrier synchronization.
|
|
Use `cluster_reduce()` from this module.
|
|
|
|
4. **Row Reduction**: Orchestrates all levels based on problem configuration.
|
|
Use `row_reduce()` from this module.
|
|
|
|
Shared Memory Buffer Layout Assumptions
|
|
---------------------------------------
|
|
|
|
For `block_reduce`:
|
|
- Buffer shape: (rows_per_block, warps_per_row)
|
|
- Each warp's lane 0 writes its reduced value to buffer[row_idx, col_idx]
|
|
- Thread mapping: row_idx = warp_idx // warps_per_row
|
|
col_idx = warp_idx % warps_per_row
|
|
|
|
Example for 8 warps, 2 rows, 4 warps per row:
|
|
Warp 0 -> buffer[0, 0] Warp 4 -> buffer[1, 0]
|
|
Warp 1 -> buffer[0, 1] Warp 5 -> buffer[1, 1]
|
|
Warp 2 -> buffer[0, 2] Warp 6 -> buffer[1, 2]
|
|
Warp 3 -> buffer[0, 3] Warp 7 -> buffer[1, 3]
|
|
|
|
For `cluster_reduce`:
|
|
- Buffer shape: (rows_per_block, (warps_per_row, cluster_n))
|
|
- The second dimension is hierarchical: (local_warp_slot, cta_rank)
|
|
- Each CTA contributes to its own slot in the cluster dimension
|
|
|
|
Example for cluster_n=4, 2 warps per row:
|
|
CTA 0, Warp 0 -> buffer[row, (0, 0)]
|
|
CTA 0, Warp 1 -> buffer[row, (1, 0)]
|
|
CTA 1, Warp 0 -> buffer[row, (0, 1)]
|
|
CTA 1, Warp 1 -> buffer[row, (1, 1)]
|
|
... etc for CTAs 2, 3
|
|
|
|
Mbarrier Requirements (Cluster Reduction)
|
|
-----------------------------------------
|
|
For cluster reduction, the caller must:
|
|
1. Allocate an mbarrier in shared memory
|
|
2. Initialize it with `cute.arch.mbarrier_init(mbar_ptr, thread_count)`
|
|
3. Pass the mbarrier pointer to `cluster_reduce()`
|
|
|
|
The cluster_reduce function handles:
|
|
- Setting up the expected transaction count
|
|
- Performing async cross-CTA stores
|
|
- Waiting for all stores to complete
|
|
|
|
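
A minimal caller-side sketch of these three steps (illustrative only; the
buffer shape follows the layout described above, and `threads_per_block`
stands in for the kernel's block size):

.. code-block:: python

    # 1. Allocate an mbarrier in shared memory
    mbar = cute.make_smem_tensor(cute.make_layout((1,)), cute.arch.Mbarrier)

    # 2. Initialize it once per kernel launch
    cute.arch.mbarrier_init(mbar.iterator, threads_per_block)

    # 3. Pass the mbarrier pointer to cluster_reduce()
    result = cluster_reduce(
        warp_val, operator.add, reduction_buffer,
        mbar.iterator, cluster_n=cluster_n, init_val=Float32(0.0)
    )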

Usage Example
-------------

.. code-block:: python

    from reduce import row_reduce, block_reduce, cluster_reduce

    @cute.jit
    def my_kernel(...):
        # Allocate shared memory for reduction
        # Shape depends on warps_per_row and cluster_n
        if cluster_n > 1:
            reduction_buffer = cute.make_smem_tensor(
                cute.make_layout((rows_per_block, (warps_per_row, cluster_n))),
                Float32
            )
        else:
            reduction_buffer = cute.make_smem_tensor(
                cute.make_layout((rows_per_block, warps_per_row)),
                Float32
            )

        # Perform row reduction
        result = row_reduce(
            tensor_ssa,
            cute.ReductionOp.ADD,
            threads_per_row,
            reduction_buffer,
            mbar_ptr,
            cluster_n,
            init_val=Float32(0.0)
        )

References
----------
The cluster synchronization primitives (set_block_rank, store_shared_remote)
are inspired by Quack: https://github.com/Dao-AILab/quack
"""

import operator
from collections.abc import Callable

import cutlass
import cutlass.cute as cute
from cutlass import Float32, Int32
from cutlass._mlir.dialects import llvm
from cutlass.cutlass_dsl import T, dsl_user_op


# =============================================================================
# Inline PTX Operations for Cluster Communication
# =============================================================================
#
# These operations enable cross-CTA communication within a cluster (SM90+).
# They use inline PTX assembly for functionality not yet exposed in MLIR.
# =============================================================================

@dsl_user_op
def set_block_rank(
    smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: Int32, *, loc=None, ip=None
) -> Int32:
    """
    Map a shared memory pointer to the equivalent address in another CTA's
    shared memory within the same cluster.

    This uses the PTX `mapa.shared::cluster` instruction to translate a local
    shared memory address to the corresponding address in a peer CTA's shared
    memory space.

    Args:
        smem_ptr: Pointer to local shared memory
        peer_cta_rank_in_cluster: Target CTA's rank within the cluster (0 to cluster_size-1)

    Returns:
        Int32 representing the mapped address in the peer CTA's shared memory

    Note:
        This operation requires SM90+ with cluster support enabled.
        The cluster must be launched with the appropriate cluster dimensions.
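
    Example:
        A minimal sketch (illustrative; assumes a cluster launch, with `buf`
        a shared memory tensor and `peer` a valid CTA rank as placeholders):

        .. code-block:: python

            peer = Int32(1)
            # Address of buf's first element in CTA `peer`'s shared memory
            remote_addr = set_block_rank(buf.iterator, peer)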
"""
|
|
smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value()
|
|
return Int32(
|
|
llvm.inline_asm(
|
|
T.i32(),
|
|
[smem_ptr_i32, peer_cta_rank_in_cluster.ir_value()],
|
|
"mapa.shared::cluster.u32 $0, $1, $2;",
|
|
"=r,r,r",
|
|
has_side_effects=False,
|
|
is_align_stack=False,
|
|
asm_dialect=llvm.AsmDialect.AD_ATT,
|
|
)
|
|
)


@dsl_user_op
def store_shared_remote(
    val: Float32,
    smem_ptr: cute.Pointer,
    mbar_ptr: cute.Pointer,
    peer_cta_rank_in_cluster: Int32,
    *,
    loc=None,
    ip=None,
) -> None:
    """
    Asynchronously store a Float32 value to shared memory on a remote CTA
    within the cluster, with mbarrier completion tracking.

    This uses the PTX `st.async.shared::cluster` instruction, which:
    1. Translates the local smem address to the peer CTA's address space
    2. Performs an asynchronous store to the remote shared memory
    3. Signals the mbarrier when the store completes

    Args:
        val: The Float32 value to store
        smem_ptr: Pointer to the destination in local shared memory coordinates
        mbar_ptr: Pointer to the mbarrier that tracks completion
        peer_cta_rank_in_cluster: Target CTA's rank within the cluster

    Note:
        - The mbarrier must be initialized with the expected transaction byte count
        - Use `cute.arch.mbarrier_arrive_and_expect_tx()` to set up the transaction
        - Use `cute.arch.mbarrier_wait()` to wait for all stores to complete
        - This operation requires SM90+ with cluster support enabled
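
    Example:
        A sketch of the expected protocol (illustrative; `num_stores`, `dst_ptr`,
        and `peer` are placeholders, see `cluster_reduce` for real usage):

        .. code-block:: python

            # Expect `num_stores` 4-byte transactions on this CTA's mbarrier
            cute.arch.mbarrier_arrive_and_expect_tx(mbar_ptr, num_stores * 4)

            # Asynchronously store `val` into the peer CTA's copy of `dst_ptr`
            store_shared_remote(val, dst_ptr, mbar_ptr, peer_cta_rank_in_cluster=peer)

            # Block until every expected byte has arrived
            cute.arch.mbarrier_wait(mbar_ptr, phase=0)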
    """
    remote_smem_ptr_i32 = set_block_rank(
        smem_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
    ).ir_value()
    remote_mbar_ptr_i32 = set_block_rank(
        mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
    ).ir_value()
    llvm.inline_asm(
        None,
        [remote_smem_ptr_i32, val.ir_value(loc=loc, ip=ip), remote_mbar_ptr_i32],
        "st.async.shared::cluster.mbarrier::complete_tx::bytes.f32 [$0], $1, [$2];",
        "r,f,r",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )


@dsl_user_op
def elem_pointer(x: cute.Tensor, coord, *, loc=None, ip=None) -> cute.Pointer:
    """
    Get a pointer to the element at the specified coordinate in a tensor.

    This is useful for getting the shared memory address of a specific element
    when performing cross-CTA stores in cluster reduction.

    Args:
        x: The tensor (typically a shared memory tensor)
        coord: The coordinate tuple; can be hierarchical, like (row, (col, cluster_idx))

    Returns:
        Pointer to the element at the specified coordinate
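
    Example:
        A sketch with placeholder names (see `cluster_reduce` for real usage):

        .. code-block:: python

            # Pointer to reduction_buffer[row_idx, (col_idx, cta_rank)] in a
            # hierarchical (rows, (warps, cluster)) layout
            ptr = elem_pointer(reduction_buffer, (row_idx, (col_idx, cta_rank)))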
    """
    return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip)


# =============================================================================
# Block-Level Reduction
# =============================================================================

@cute.jit
def block_reduce(
    val: Float32,
    op: Callable,
    reduction_buffer: cute.Tensor,
    init_val: Float32,
) -> Float32:
    """
    Reduce values across all warps within a thread block using shared memory.

    This function assumes each warp has already performed a warp-level reduction
    and is contributing a single value (from lane 0). The function then:
    1. Writes each warp's value to shared memory
    2. Synchronizes the block
    3. Performs a final warp reduction across the collected values

    Args:
        val: The warp-reduced value (only lane 0's value is used)
        op: Binary reduction operator, e.g., `operator.add` or `cute.arch.fmax`
        reduction_buffer: Shared memory tensor with shape (rows_per_block, warps_per_row)
        init_val: Identity element for the reduction (0 for sum, -inf for max)

    Returns:
        The block-reduced result (same value across all threads)

    Buffer Layout:
        - Shape: (rows_per_block, warps_per_row)
        - warps_per_row is inferred from reduction_buffer.shape[1]
        - Thread mapping:
            row_idx = warp_idx // warps_per_row
            col_idx = warp_idx % warps_per_row

    Example:
        For a block with 8 warps processing 2 rows (4 warps per row):

        .. code-block:: python

            reduction_buffer = cute.make_smem_tensor(
                cute.make_layout((2, 4)),  # 2 rows, 4 warps per row
                Float32
            )
            result = block_reduce(warp_val, operator.add, reduction_buffer, Float32(0.0))
    """
    lane_idx = cute.arch.lane_idx()
    warp_idx = cute.arch.warp_idx()
    warps_per_row = cute.size(reduction_buffer.shape[1])
    row_idx = warp_idx // warps_per_row
    col_idx = warp_idx % warps_per_row

    # Lane 0 of each warp writes its value to shared memory
    if lane_idx == 0:
        reduction_buffer[row_idx, col_idx] = val
    cute.arch.barrier()

    # All lanes participate in reading and reducing;
    # only lanes < warps_per_row have valid data
    block_reduce_val = init_val
    if lane_idx < warps_per_row:
        block_reduce_val = reduction_buffer[row_idx, lane_idx]
    return cute.arch.warp_reduction(block_reduce_val, op)


# =============================================================================
# Cluster-Level Reduction (SM90+)
# =============================================================================

@cute.jit
def cluster_reduce(
    val: Float32,
    op: Callable,
    reduction_buffer: cute.Tensor,
    mbar_ptr: cute.Pointer,
    cluster_n: cutlass.Constexpr[int],
    init_val: Float32,
) -> Float32:
    """
    Reduce values across all CTAs within a cluster using distributed shared memory.

    This function extends block reduction to work across multiple CTAs in a cluster
    using asynchronous cross-CTA stores and mbarrier synchronization. It:
    1. Sets up the mbarrier with the expected transaction count
    2. Asynchronously stores each warp's value to all peer CTAs
    3. Waits for all stores to complete
    4. Reduces across all collected values

    Args:
        val: The warp-reduced value (only lane 0's value is used for stores)
        op: Binary reduction operator, e.g., `operator.add` or `cute.arch.fmax`
        reduction_buffer: Shared memory tensor with hierarchical shape
            (rows_per_block, (warps_per_row, cluster_n))
        mbar_ptr: Pointer to an initialized mbarrier in shared memory
        cluster_n: Number of CTAs in the cluster (compile-time constant)
        init_val: Identity element for the reduction (0 for sum, -inf for max)

    Returns:
        The cluster-reduced result (same value across all threads in all CTAs)

    Buffer Layout:
        - Shape: (rows_per_block, (warps_per_row, cluster_n))
        - The second dimension is hierarchical:
            - First level: warps_per_row (local warp slots)
            - Second level: cluster_n (one slot per CTA in the cluster)
        - Access pattern: buffer[row_idx, (col_idx, cta_rank)]

    Requirements:
        - SM90+ with cluster support
        - The mbarrier must be initialized before calling
        - The kernel must be launched with appropriate cluster dimensions

    Example:
        For a cluster of 4 CTAs, each with 2 warps per row:

        .. code-block:: python

            # Allocate the buffer with a cluster dimension
            reduction_buffer = cute.make_smem_tensor(
                cute.make_layout((rows_per_block, (2, 4))),  # 2 warps, 4 CTAs
                Float32
            )

            # Initialize the mbarrier (once per kernel)
            mbar = cute.make_smem_tensor(cute.make_layout((1,)), cute.arch.Mbarrier)
            cute.arch.mbarrier_init(mbar.iterator, thread_count)

            # Perform the cluster reduction
            result = cluster_reduce(
                warp_val, operator.add, reduction_buffer,
                mbar.iterator, cluster_n=4, init_val=Float32(0.0)
            )
    """
    cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
    lane_idx = cute.arch.lane_idx()
    warp_idx = cute.arch.warp_idx()

    rows_per_block = reduction_buffer.shape[0]
    warps_per_row = reduction_buffer.shape[1][0]

    row_idx = warp_idx // warps_per_row
    col_idx = warp_idx % warps_per_row

    # Warp 0, lane 0 sets up the mbarrier with the expected transaction count.
    # Each warp sends cluster_n stores (one to each CTA); each store is 4 bytes.
    if warp_idx == 0:
        with cute.arch.elect_one():
            num_warps = rows_per_block * warps_per_row
            expected_bytes = num_warps * cluster_n * 4  # 4 bytes per Float32
            cute.arch.mbarrier_arrive_and_expect_tx(mbar_ptr, expected_bytes)

    # Each lane < cluster_n writes to a different CTA's shared memory.
    # This distributes the warp's value to every CTA in the cluster.
    if lane_idx < cluster_n:
        store_shared_remote(
            val,
            elem_pointer(reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))),
            mbar_ptr,
            peer_cta_rank_in_cluster=lane_idx,
        )

    # Wait for all cross-CTA stores to complete
    cute.arch.mbarrier_wait(mbar_ptr, phase=0)

    # Each CTA now holds the values from every CTA in the cluster;
    # reduce across all collected slots, 32 lanes at a time.
    num_total = warps_per_row * cluster_n
    num_iter = cute.ceil_div(num_total, 32)

    block_reduce_val = init_val
    for i in cutlass.range_constexpr(num_iter):
        idx = lane_idx + i * 32
        if idx < num_total:
            block_reduce_val = op(block_reduce_val, reduction_buffer[row_idx, idx])

    return cute.arch.warp_reduction(block_reduce_val, op)


# =============================================================================
# Row Reduction (Orchestration Function)
# =============================================================================

@cute.jit
def row_reduce(
    x: cute.TensorSSA,
    op: cute.ReductionOp,
    threads_per_row: cutlass.Constexpr[int],
    reduction_buffer: cute.Tensor,
    mbar_ptr,
    cluster_n: cutlass.Constexpr[int],
    init_val: Float32,
):
    """
    Perform a hierarchical row reduction, automatically selecting the reduction strategy.

    This function orchestrates the full reduction pipeline:
    1. Local reduction: Each thread reduces its portion of the row
    2. Warp reduction: Threads within a warp reduce using shuffles
    3. Block reduction: If needed, warps reduce using shared memory
    4. Cluster reduction: If needed, CTAs reduce using distributed shared memory

    The function automatically selects the appropriate reduction level based on
    `threads_per_row` and `cluster_n`.

    Args:
        x: TensorSSA containing the values to reduce (in registers)
        op: Reduction operation (cute.ReductionOp.ADD or cute.ReductionOp.MAX)
        threads_per_row: Number of threads cooperating on each row (compile-time)
        reduction_buffer: Shared memory tensor for block/cluster reduction
        mbar_ptr: Mbarrier pointer (only used if cluster_n > 1)
        cluster_n: Number of CTAs in the cluster (1 for single-CTA reduction)
        init_val: Identity element for the reduction

    Returns:
        The fully reduced result for each row

    Reduction Strategy Selection:
        - threads_per_row <= 32, cluster_n == 1: warp reduction only
        - threads_per_row > 32, cluster_n == 1: warp + block reduction
        - cluster_n > 1: warp + cluster reduction (handles all cases)

    Example:
        .. code-block:: python

            # Sum reduction across 128 threads per row, single CTA
            result = row_reduce(
                tensor_ssa,
                cute.ReductionOp.ADD,
                threads_per_row=128,
                reduction_buffer=smem_buffer,
                mbar_ptr=None,
                cluster_n=1,
                init_val=Float32(0.0)
            )

            # Max reduction across 256 threads per row, 4 CTAs in cluster
            result = row_reduce(
                tensor_ssa,
                cute.ReductionOp.MAX,
                threads_per_row=256,
                reduction_buffer=smem_buffer,
                mbar_ptr=mbar.iterator,
                cluster_n=4,
                init_val=Float32.neg_inf
            )
    """
    # Step 1: Local reduction - each thread reduces its own register values
    local_val = x.reduce(op, init_val=init_val, reduction_profile=0)

    # Map the ReductionOp enum to a binary operator for warp/block reductions
    warp_op = {
        cute.ReductionOp.ADD: operator.add,
        cute.ReductionOp.MAX: cute.arch.fmax,
    }[op]

    # Step 2: Warp reduction.
    # If threads_per_row < 32, only that many threads take part in the shuffle.
    warp_width = min(threads_per_row, 32)
    warp_val = cute.arch.warp_reduction(local_val, warp_op, threads_in_group=warp_width)

    # Determine whether additional reduction levels are needed
    warps_per_row = max(threads_per_row // 32, 1)

    # Steps 3 & 4: Block or cluster reduction (if needed)
    if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
        if cutlass.const_expr(cluster_n == 1):
            # Single CTA: use block reduction
            return block_reduce(warp_val, warp_op, reduction_buffer, init_val)
        else:
            # Multiple CTAs: use cluster reduction
            return cluster_reduce(
                warp_val, warp_op, reduction_buffer, mbar_ptr, cluster_n, init_val
            )
    else:
        # A single warp handles the entire row: warp reduction is sufficient
        return warp_val