cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/reduce/reduce.py
Junkai-Wu cb37157db5 v4.5 tag update (#3202)
* Python DSL examples reorganization.

* v4.5 tag update.
2026-05-05 20:55:27 -04:00


# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
Hierarchical Reduction Utilities for CuTe-DSL Kernels
=====================================================
This module provides reusable reduction primitives for GPU kernels that need to
reduce values across warps, thread blocks, and clusters (SM90+).
Overview
--------
GPU reductions typically follow a hierarchical pattern:
1. **Warp Reduction**: Threads within a warp reduce using shuffle instructions.
Use `cute.arch.warp_reduction()` from the CuTe-DSL library.
2. **Block Reduction**: Multiple warps within a block reduce using shared memory.
Use `block_reduce()` from this module.
3. **Cluster Reduction** (SM90+): Multiple CTAs in a cluster reduce using
distributed shared memory and mbarrier synchronization.
Use `cluster_reduce()` from this module.
4. **Row Reduction**: Orchestrates all levels based on problem configuration.
Use `row_reduce()` from this module.
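
The four levels can be modeled on the host in plain Python to see how partial
results flow. This is a sketch for intuition, not part of the module:
`warp_reduce` and `row_reduce_model` are hypothetical helpers, simulated lanes
are list entries, and the warp step uses an XOR (butterfly) shuffle pattern.
The model reduces within each CTA and then across CTAs, whereas
`cluster_reduce()` gathers every warp's value from every CTA in one step; for an
associative, commutative operator both orders give the same result, up to
floating-point rounding.

```python
import operator
from functools import reduce

WARP_SIZE = 32


def warp_reduce(lane_vals, op):
    # Butterfly (XOR-shuffle) pattern: after log2(len(lane_vals)) rounds,
    # every simulated lane holds the reduction of all lanes.
    vals = list(lane_vals)
    offset = 1
    while offset < len(vals):
        vals = [op(vals[i], vals[i ^ offset]) for i in range(len(vals))]
        offset *= 2
    return vals


def row_reduce_model(row, op, init, warps_per_row, cluster_n, elems_per_thread):
    # One row is split across cluster_n CTAs x warps_per_row warps x 32 lanes,
    # with elems_per_thread consecutive elements held by each thread.
    threads = cluster_n * warps_per_row * WARP_SIZE
    assert len(row) == threads * elems_per_thread
    e = elems_per_thread
    # Local step: each thread reduces its own register values.
    per_thread = [reduce(op, row[t * e:(t + 1) * e], init) for t in range(threads)]
    # Warp step: shuffle reduction, then keep lane 0's copy for each warp.
    per_warp = [
        warp_reduce(per_thread[w * WARP_SIZE:(w + 1) * WARP_SIZE], op)[0]
        for w in range(threads // WARP_SIZE)
    ]
    # Block step: reduce the warps that belong to each CTA.
    per_cta = [
        reduce(op, per_warp[c * warps_per_row:(c + 1) * warps_per_row], init)
        for c in range(cluster_n)
    ]
    # Cluster step: reduce the per-CTA partial results.
    return reduce(op, per_cta, init)


row = list(range(2 * 4 * WARP_SIZE * 3))
print(row_reduce_model(row, operator.add, 0, warps_per_row=4, cluster_n=2, elems_per_thread=3))  # 294528
```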

Shared Memory Buffer Layout Assumptions
---------------------------------------

For `block_reduce`:

- Buffer shape: (rows_per_block, warps_per_row)
- Lane 0 of each warp writes its reduced value to buffer[row_idx, col_idx]
- Warp mapping: row_idx = warp_idx // warps_per_row,
  col_idx = warp_idx % warps_per_row

Example for 8 warps, 2 rows, 4 warps per row::

    Warp 0 -> buffer[0, 0]    Warp 4 -> buffer[1, 0]
    Warp 1 -> buffer[0, 1]    Warp 5 -> buffer[1, 1]
    Warp 2 -> buffer[0, 2]    Warp 6 -> buffer[1, 2]
    Warp 3 -> buffer[0, 3]    Warp 7 -> buffer[1, 3]
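
The warp-to-slot mapping is plain integer arithmetic. A small sketch (the
helper `block_slot` is hypothetical) reproduces the table and makes explicit a
limit that `block_reduce()` implies: it reads each row back with one lane per
column and no loop, so `warps_per_row` cannot exceed 32.

```python
WARP_SIZE = 32


def block_slot(warp_idx, warps_per_row):
    # block_reduce reads a row back with one lane per column, so a single
    # warp can gather at most 32 columns.
    assert 1 <= warps_per_row <= WARP_SIZE
    return warp_idx // warps_per_row, warp_idx % warps_per_row


for w in range(8):
    print(f'Warp {w} -> buffer{block_slot(w, 4)}')
```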

For `cluster_reduce`:

- Buffer shape: (rows_per_block, (warps_per_row, cluster_n))
- The second mode is hierarchical: (local_warp_slot, cta_rank)
- Each CTA writes to its own slot (index cta_rank) of the cluster dimension,
  both in its own buffer and in every peer CTA's buffer

Example for cluster_n=4, 2 warps per row::

    CTA 0, Warp 0 -> buffer[row, (0, 0)]
    CTA 0, Warp 1 -> buffer[row, (1, 0)]
    CTA 1, Warp 0 -> buffer[row, (0, 1)]
    CTA 1, Warp 1 -> buffer[row, (1, 1)]
    ... etc. for CTAs 2 and 3
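
The hierarchical column mode can be flattened by hand. This sketch assumes the
default compact column-major strides that `cute.make_layout` produces when no
stride is given; `flat_col`, `smem_offset` and `gathered_slots` are
hypothetical helpers. `gathered_slots` mirrors the final loop of
`cluster_reduce()`, which indexes the hierarchical mode with a single flat
index `lane_idx + i * 32`, and confirms that every (warp slot, CTA) pair is
read exactly once.

```python
WARP_SIZE = 32


def flat_col(col_idx, cta_rank, warps_per_row):
    # Colexicographic flattening of the (warps_per_row, cluster_n) mode:
    # the local warp slot varies fastest.
    return col_idx + cta_rank * warps_per_row


def smem_offset(row_idx, col_idx, cta_rank, rows_per_block, warps_per_row):
    # Element offset for compact column-major strides
    # (1, (rows_per_block, rows_per_block * warps_per_row)).
    return row_idx + rows_per_block * flat_col(col_idx, cta_rank, warps_per_row)


def gathered_slots(warps_per_row, cluster_n):
    # Mirrors the final loop of cluster_reduce: lane l reads flat columns
    # l, l + 32, l + 64, ... that are below warps_per_row * cluster_n.
    num_total = warps_per_row * cluster_n
    num_iter = -(-num_total // WARP_SIZE)  # ceil_div
    slots = []
    for lane in range(WARP_SIZE):
        for i in range(num_iter):
            idx = lane + i * WARP_SIZE
            if idx < num_total:
                cta_rank, col_idx = divmod(idx, warps_per_row)
                slots.append((col_idx, cta_rank))
    return sorted(slots)


# cluster_n = 4, 2 warps per row: CTA 1, warp slot 1 is flat column 3.
print(flat_col(1, 1, warps_per_row=2))  # 3
```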
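
Mbarrier Requirements (Cluster Reduction)
-----------------------------------------

For cluster reduction, the caller must:

1. Allocate an mbarrier in shared memory.
2. Initialize it with `cute.arch.mbarrier_init(mbar_ptr, 1)`. The arrival
   count is 1 because `cluster_reduce` arrives on the mbarrier from a single
   elected thread of warp 0.
3. Make the initialized mbarrier visible to the whole cluster (a fence followed
   by a cluster-wide barrier) before any CTA can issue remote stores to it.
4. Pass the mbarrier pointer to `cluster_reduce()`.

The `cluster_reduce` function handles:

- Setting up the expected transaction count
- Performing async cross-CTA stores
- Waiting for all stores to complete

Because `cluster_reduce` always waits on mbarrier phase 0, one mbarrier supports
a single `cluster_reduce` call.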
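
For an mbarrier phase to complete, both the pending arrival count and the
pending transaction bytes must reach zero. The toy model here is a sketch under
simplifying assumptions, not PTX semantics: `ToyMbarrier`, `expected_tx_bytes`
and `simulate` are hypothetical, and only arrival and byte counting are modeled
(no phase parity, no memory ordering). It shows why the arrival count must be
1: `cluster_reduce()` arrives from a single elected thread, so a larger count
leaves the phase pending forever.

```python
from dataclasses import dataclass

FLOAT32_BYTES = 4


@dataclass
class ToyMbarrier:
    # Hypothetical host-side model of a single mbarrier phase: the phase
    # completes once no arrivals are pending and the tx byte count is zero.
    pending_arrivals: int
    tx_bytes: int = 0
    phase_done: bool = False

    def arrive_and_expect_tx(self, nbytes):
        self.tx_bytes += nbytes
        self.pending_arrivals -= 1
        self._check()

    def complete_tx(self, nbytes):
        # Effect of one st.async ... mbarrier::complete_tx::bytes store.
        self.tx_bytes -= nbytes
        self._check()

    def _check(self):
        if self.pending_arrivals == 0 and self.tx_bytes == 0:
            self.phase_done = True


def expected_tx_bytes(rows_per_block, warps_per_row, cluster_n):
    # Same formula as expected_bytes in cluster_reduce: every warp of every
    # CTA sends one Float32 to each CTA.
    return rows_per_block * warps_per_row * cluster_n * FLOAT32_BYTES


def simulate(arrival_count, rows_per_block=2, warps_per_row=2, cluster_n=4):
    bar = ToyMbarrier(pending_arrivals=arrival_count)
    nbytes = expected_tx_bytes(rows_per_block, warps_per_row, cluster_n)
    num_stores = nbytes // FLOAT32_BYTES
    # The model tolerates stores that land before the local arrive.
    for _ in range(num_stores // 2):
        bar.complete_tx(FLOAT32_BYTES)
    bar.arrive_and_expect_tx(nbytes)  # the one elected thread of warp 0
    for _ in range(num_stores - num_stores // 2):
        bar.complete_tx(FLOAT32_BYTES)
    return bar.phase_done


print(simulate(arrival_count=1))    # True
print(simulate(arrival_count=128))  # False: the phase never completes
```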

Usage Example
-------------

.. code-block:: python

    from reduce import row_reduce, block_reduce, cluster_reduce

    @cute.kernel
    def my_kernel(...):
        # Allocate shared memory for the reduction.
        # The shape depends on warps_per_row and cluster_n (compile-time).
        if cutlass.const_expr(cluster_n > 1):
            reduction_buffer = cute.make_smem_tensor(
                cute.make_layout((rows_per_block, (warps_per_row, cluster_n))),
                Float32,
            )
        else:
            reduction_buffer = cute.make_smem_tensor(
                cute.make_layout((rows_per_block, warps_per_row)),
                Float32,
            )

        # Perform row reduction
        result = row_reduce(
            tensor_ssa,
            cute.ReductionOp.ADD,
            threads_per_row,
            reduction_buffer,
            mbar_ptr,
            cluster_n,
            init_val=Float32(0.0),
        )
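
Both the buffer shape in the usage example and the path that `row_reduce()`
takes follow from `threads_per_row` and `cluster_n`. A host-side sketch of that
compile-time selection, with `reduction_plan` as a hypothetical helper and
`threads_per_row` counted per CTA. The multiple-of-32 check is an addition of
this sketch; the module itself does not validate its inputs.

```python
WARP_SIZE = 32


def reduction_plan(threads_per_row, cluster_n, rows_per_block):
    # Mirrors the compile-time branches of row_reduce. threads_per_row counts
    # the threads of one CTA that share a row.
    warp_width = min(threads_per_row, WARP_SIZE)
    warps_per_row = max(threads_per_row // WARP_SIZE, 1)
    if warps_per_row > 1 or cluster_n > 1:
        # Extra check not performed by the module: the block and cluster
        # paths map each warp to exactly one row.
        if threads_per_row % WARP_SIZE != 0:
            raise ValueError('threads_per_row must be a multiple of 32')
        if cluster_n == 1:
            return warp_width, 'warp + block', (rows_per_block, warps_per_row)
        return warp_width, 'warp + cluster', (rows_per_block, (warps_per_row, cluster_n))
    return warp_width, 'warp only', None  # no shared-memory buffer needed


print(reduction_plan(128, 1, rows_per_block=2))  # (32, 'warp + block', (2, 4))
```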

References
----------

The cluster synchronization primitives (set_block_rank, store_shared_remote)
are inspired by Quack: https://github.com/Dao-AILab/quack
"""

import operator
from collections.abc import Callable

import cutlass
import cutlass.cute as cute
from cutlass import Float32, Int32
from cutlass._mlir.dialects import llvm
from cutlass.cutlass_dsl import T, dsl_user_op


# =============================================================================
# Inline PTX Operations for Cluster Communication
# =============================================================================
#
# These operations enable cross-CTA communication within a cluster (SM90+).
# They use inline PTX assembly for functionality not yet exposed in MLIR.
# =============================================================================


@dsl_user_op
def set_block_rank(
    smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: Int32, *, loc=None, ip=None
) -> Int32:
    """
    Map a shared memory pointer to the equivalent address in another CTA's
    shared memory within the same cluster.

    This uses the PTX `mapa.shared::cluster` instruction to translate a local
    shared memory address to the corresponding address in a peer CTA's shared
    memory.

    Args:
        smem_ptr: Pointer to local shared memory
        peer_cta_rank_in_cluster: Target CTA's rank within the cluster
            (0 to cluster_size - 1)

    Returns:
        Int32 holding the 32-bit `shared::cluster` address of the corresponding
        location in the peer CTA's shared memory

    Note:
        This operation requires SM90+ with cluster support enabled.
        The kernel must be launched with the appropriate cluster dimensions.
    """
    smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value()
    return Int32(
        llvm.inline_asm(
            T.i32(),
            [smem_ptr_i32, peer_cta_rank_in_cluster.ir_value()],
            "mapa.shared::cluster.u32 $0, $1, $2;",
            "=r,r,r",
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )


@dsl_user_op
def store_shared_remote(
    val: Float32,
    smem_ptr: cute.Pointer,
    mbar_ptr: cute.Pointer,
    peer_cta_rank_in_cluster: Int32,
    *,
    loc=None,
    ip=None,
) -> None:
    """
    Asynchronously store a Float32 value to shared memory on a remote CTA
    within the cluster, with mbarrier completion tracking.

    Both `smem_ptr` and `mbar_ptr` are first mapped into the peer CTA's shared
    memory with `set_block_rank` (PTX `mapa.shared::cluster`). The PTX
    `st.async.shared::cluster.mbarrier::complete_tx::bytes` instruction then:

    1. Performs an asynchronous 4-byte store to the peer CTA's shared memory
    2. Signals completion of those 4 bytes on the peer CTA's mbarrier

    Args:
        val: The Float32 value to store
        smem_ptr: Pointer to the destination, given as a local shared memory
            address; the same offset is used in the peer CTA
        mbar_ptr: Pointer to the local mbarrier; the store signals the
            mbarrier at the same offset in the peer CTA
        peer_cta_rank_in_cluster: Target CTA's rank within the cluster

    Note:
        - The peer CTA's mbarrier must already be initialized, and that CTA
          must register the expected transaction bytes with
          `cute.arch.mbarrier_arrive_and_expect_tx()`
        - Use `cute.arch.mbarrier_wait()` on the receiving CTA to wait for
          all stores to complete
        - This operation requires SM90+ with cluster support enabled
    """
    remote_smem_ptr_i32 = set_block_rank(
        smem_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
    ).ir_value()
    remote_mbar_ptr_i32 = set_block_rank(
        mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
    ).ir_value()
    llvm.inline_asm(
        None,
        [remote_smem_ptr_i32, val.ir_value(loc=loc, ip=ip), remote_mbar_ptr_i32],
        "st.async.shared::cluster.mbarrier::complete_tx::bytes.f32 [$0], $1, [$2];",
        "r,f,r",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )


@dsl_user_op
def elem_pointer(x: cute.Tensor, coord, *, loc=None, ip=None) -> cute.Pointer:
    """
    Get a pointer to the element at the specified coordinate in a tensor.

    This is useful for getting the shared memory address of a specific element
    when performing cross-CTA stores in cluster reduction.

    Args:
        x: The tensor (typically a shared memory tensor)
        coord: The coordinate tuple; it can be hierarchical, like
            (row, (col, cluster_idx))

    Returns:
        Pointer to the element at the specified coordinate
    """
    return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip)


# =============================================================================
# Block-Level Reduction
# =============================================================================


@cute.jit
def block_reduce(
    val: Float32,
    op: Callable,
    reduction_buffer: cute.Tensor,
    init_val: Float32,
) -> Float32:
    """
    Reduce values across all warps within a thread block using shared memory.

    This function assumes each warp has already performed a warp-level reduction
    and contributes a single value (from lane 0). The function then:

    1. Writes each warp's value to shared memory
    2. Synchronizes the block
    3. Performs a final warp reduction across the collected values

    Args:
        val: The warp-reduced value (only lane 0's value is used)
        op: Binary reduction operator, e.g., `operator.add` or `cute.arch.fmax`
        reduction_buffer: Shared memory tensor with shape
            (rows_per_block, warps_per_row)
        init_val: Identity element for the reduction (0 for sum, -inf for max)

    Returns:
        The block-reduced result, identical across all threads that work on
        the same row

    Buffer Layout:
        - Shape: (rows_per_block, warps_per_row)
        - warps_per_row is inferred from reduction_buffer.shape[1] and must not
          exceed 32, because a single warp reads each row back with one lane
          per column
        - Warp mapping:
          row_idx = warp_idx // warps_per_row
          col_idx = warp_idx % warps_per_row

    Example:
        For a block with 8 warps processing 2 rows (4 warps per row):

        .. code-block:: python

            reduction_buffer = cute.make_smem_tensor(
                cute.make_layout((2, 4)),  # 2 rows, 4 warps per row
                Float32,
            )
            result = block_reduce(warp_val, operator.add, reduction_buffer, Float32(0.0))
    """
    lane_idx = cute.arch.lane_idx()
    warp_idx = cute.arch.warp_idx()
    warps_per_row = cute.size(reduction_buffer.shape[1])
    row_idx = warp_idx // warps_per_row
    col_idx = warp_idx % warps_per_row

    # Lane 0 of each warp writes its value to shared memory
    if lane_idx == 0:
        reduction_buffer[row_idx, col_idx] = val
    cute.arch.barrier()

    # Every lane joins the final warp reduction; lanes >= warps_per_row have
    # no buffer entry to read and contribute init_val instead
    block_reduce_val = init_val
    if lane_idx < warps_per_row:
        block_reduce_val = reduction_buffer[row_idx, lane_idx]
    return cute.arch.warp_reduction(block_reduce_val, op)


# =============================================================================
# Cluster-Level Reduction (SM90+)
# =============================================================================


@cute.jit
def cluster_reduce(
    val: Float32,
    op: Callable,
    reduction_buffer: cute.Tensor,
    mbar_ptr: cute.Pointer,
    cluster_n: cutlass.Constexpr[int],
    init_val: Float32,
) -> Float32:
    """
    Reduce values across all CTAs within a cluster using distributed shared memory.

    This function extends block reduction across the CTAs of a cluster using
    asynchronous cross-CTA stores and mbarrier synchronization. It:

    1. Arrives on the mbarrier and sets the expected transaction count
    2. Asynchronously stores each warp's value to every CTA in the cluster,
       including its own
    3. Waits for all incoming stores to complete
    4. Reduces across all collected values

    Args:
        val: The warp-reduced value. Lanes 0 to cluster_n - 1 each store their
            own copy, so the value must be present in those lanes (as it is
            after `cute.arch.warp_reduction`, which leaves the result in every
            lane)
        op: Binary reduction operator, e.g., `operator.add` or `cute.arch.fmax`
        reduction_buffer: Shared memory tensor with hierarchical shape
            (rows_per_block, (warps_per_row, cluster_n))
        mbar_ptr: Pointer to an initialized mbarrier in shared memory
        cluster_n: Number of CTAs in the cluster (compile-time constant)
        init_val: Identity element for the reduction (0 for sum, -inf for max)

    Returns:
        The cluster-reduced result, identical across all threads that work on
        the same row in every CTA of the cluster

    Buffer Layout:
        - Shape: (rows_per_block, (warps_per_row, cluster_n))
        - The second mode is hierarchical:
            - First level: warps_per_row (local warp slots)
            - Second level: cluster_n (one slot per CTA in the cluster)
        - Access pattern: buffer[row_idx, (col_idx, cta_rank)]

    Requirements:
        - SM90+ with cluster support
        - The mbarrier must be initialized with an arrival count of 1, and the
          initialization must be visible to every CTA in the cluster before
          this function is called
        - The kernel must be launched with the appropriate cluster dimensions
        - The function waits on mbarrier phase 0, so one mbarrier supports a
          single call

    Example:
        For a cluster of 4 CTAs, each with 2 warps per row:

        .. code-block:: python

            # Allocate buffer with cluster dimension
            reduction_buffer = cute.make_smem_tensor(
                cute.make_layout((rows_per_block, (2, 4))),  # 2 warps, 4 CTAs
                Float32,
            )
            # Initialize the mbarrier once per kernel, from one thread, with
            # one arrival per phase; then synchronize the cluster
            mbar = cute.make_smem_tensor(cute.make_layout((1,)), cute.arch.Mbarrier)
            cute.arch.mbarrier_init(mbar.iterator, 1)
            # Perform cluster reduction
            result = cluster_reduce(
                warp_val, operator.add, reduction_buffer,
                mbar.iterator, cluster_n=4, init_val=Float32(0.0)
            )
    """
    cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
    lane_idx = cute.arch.lane_idx()
    warp_idx = cute.arch.warp_idx()
    rows_per_block = reduction_buffer.shape[0]
    warps_per_row = reduction_buffer.shape[1][0]
    row_idx = warp_idx // warps_per_row
    col_idx = warp_idx % warps_per_row

    # One elected thread of warp 0 arrives on the mbarrier and sets the
    # expected transaction count. Every warp of every CTA sends one 4-byte
    # store to each CTA, so each CTA receives num_warps * cluster_n stores.
    if warp_idx == 0:
        with cute.arch.elect_one():
            num_warps = rows_per_block * warps_per_row
            expected_bytes = num_warps * cluster_n * 4  # 4 bytes per Float32
            cute.arch.mbarrier_arrive_and_expect_tx(mbar_ptr, expected_bytes)

    # Lane i (i < cluster_n) writes this warp's value into CTA i's buffer at
    # slot (col_idx, cta_rank_in_cluster), distributing it across the cluster
    if lane_idx < cluster_n:
        store_shared_remote(
            val,
            elem_pointer(reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))),
            mbar_ptr,
            peer_cta_rank_in_cluster=lane_idx,
        )

    # Wait for all cross-CTA stores to complete
    cute.arch.mbarrier_wait(mbar_ptr, phase=0)

    # Each CTA now holds the values from all CTAs in the cluster. Each lane
    # reduces a strided subset of the num_total slots, then the warp combines
    # the per-lane partial results.
    num_total = warps_per_row * cluster_n
    num_iter = cute.ceil_div(num_total, 32)
    block_reduce_val = init_val
    for i in cutlass.range_constexpr(num_iter):
        idx = lane_idx + i * 32
        if idx < num_total:
            block_reduce_val = op(block_reduce_val, reduction_buffer[row_idx, idx])
    return cute.arch.warp_reduction(block_reduce_val, op)


# =============================================================================
# Row Reduction (Orchestration Function)
# =============================================================================


@cute.jit
def row_reduce(
    x: cute.TensorSSA,
    op: cute.ReductionOp,
    threads_per_row: cutlass.Constexpr[int],
    reduction_buffer: cute.Tensor,
    mbar_ptr,
    cluster_n: cutlass.Constexpr[int],
    init_val: Float32,
):
    """
    Perform a hierarchical row reduction, selecting the reduction levels at
    compile time.

    This function orchestrates the full reduction pipeline:

    1. Local reduction: Each thread reduces its portion of the row
    2. Warp reduction: Threads within a warp reduce using shuffles
    3. Block reduction: If needed, warps reduce using shared memory
    4. Cluster reduction: If needed, CTAs reduce using distributed shared memory

    The levels are chosen from `threads_per_row` and `cluster_n`, both of which
    are compile-time constants.

    Args:
        x: TensorSSA containing the values to reduce (in registers)
        op: Reduction operation (cute.ReductionOp.ADD or cute.ReductionOp.MAX)
        threads_per_row: Number of threads in each CTA cooperating on a row
            (compile-time)
        reduction_buffer: Shared memory tensor for block/cluster reduction
        mbar_ptr: Mbarrier pointer (only used if cluster_n > 1)
        cluster_n: Number of CTAs in the cluster (1 for single-CTA reduction)
        init_val: Identity element for the reduction

    Returns:
        The fully reduced result for each row

    Reduction Strategy Selection:
        - threads_per_row <= 32, cluster_n == 1: Warp reduction only
        - threads_per_row > 32, cluster_n == 1: Warp + block reduction
        - cluster_n > 1: Warp + cluster reduction

        The block and cluster paths assume each warp works on exactly one row,
        so threads_per_row should be a multiple of 32 whenever they are used.

    Example:
        .. code-block:: python

            # Sum reduction across 128 threads per row, single CTA
            result = row_reduce(
                tensor_ssa,
                cute.ReductionOp.ADD,
                threads_per_row=128,
                reduction_buffer=smem_buffer,
                mbar_ptr=None,
                cluster_n=1,
                init_val=Float32(0.0),
            )
            # Max reduction across 256 threads per row, 4 CTAs in cluster
            result = row_reduce(
                tensor_ssa,
                cute.ReductionOp.MAX,
                threads_per_row=256,
                reduction_buffer=smem_buffer,
                mbar_ptr=mbar.iterator,
                cluster_n=4,
                init_val=Float32(float("-inf")),
            )
    """
    # Step 1: Local reduction - each thread reduces its register values
    local_val = x.reduce(op, init_val=init_val, reduction_profile=0)

    # Map ReductionOp enum to binary operator for warp/block reductions
    warp_op = {
        cute.ReductionOp.ADD: operator.add,
        cute.ReductionOp.MAX: cute.arch.fmax,
    }[op]

    # Step 2: Warp reduction
    # If threads_per_row < 32, reduce only within groups of threads_per_row lanes
    warp_width = min(threads_per_row, 32)
    warp_val = cute.arch.warp_reduction(local_val, warp_op, threads_in_group=warp_width)

    # Determine whether additional reduction levels are needed
    warps_per_row = max(threads_per_row // 32, 1)

    # Steps 3 & 4: Block or cluster reduction (if needed)
    if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
        if cutlass.const_expr(cluster_n == 1):
            # Single CTA: use block reduction
            return block_reduce(warp_val, warp_op, reduction_buffer, init_val)
        else:
            # Multiple CTAs: use cluster reduction
            return cluster_reduce(
                warp_val, warp_op, reduction_buffer, mbar_ptr, cluster_n, init_val
            )
    else:
        # A single warp handles the entire row: warp reduction is sufficient
        return warp_val