# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import argparse
import torch
from typing import Type, Tuple
import cutlass
from cutlass.cute import experimental as cute_ext
from cutlass.base_dsl.typing import Numeric, Constexpr
from cutlass import cute as cute
from cutlass import utils
from cutlass import torch as cutlass_torch
import cutlass.utils.blackwell_helpers as sm100_utils
import cutlass.cute.testing as testing
# ====================================================================================================
#
# This kernel implements a batched dense GEMM operation: D = A @ B
# where:
# - A has shape (M, K, L) and is stored in global memory
# - B has shape (N, K, L) and is stored in global memory
# - D has shape (M, N, L) and is the output in global memory
# - L is the batch dimension
#
# The kernel is written in the LIR (Low-level Intermediate Representation) DSL, a Python DSL
# for writing high-performance, Blackwell (SM100)-targeted kernels on top of CuTe abstractions.
#
# KEY CONCEPTS:
# - TMA (Tensor Memory Accelerator): Hardware feature for high-bandwidth GMEM <-> SMEM transfers
# - UMMA/MMA: Unified Matrix Multiply-Accumulate hardware units on SM100
# - TMEM: Tensor Memory - Blackwell's specialized memory for MMA accumulators
# - SMEM: Shared Memory - CTA-local memory for staging data
# - RMEM: Register Memory - Per-thread registers
#
# DATA FLOW:
# GMEM (A,B) --TMA--> SMEM (bufferA, bufferB) --MMA--> TMEM (accumulators)
# TMEM --copy--> RMEM (bufferRAcc) --epilogue--> RMEM (bufferRD) --copy--> SMEM (bufferC) --TMA--> GMEM (D)
#
# WARP SPECIALIZATION:
# This kernel uses 6 warps (192 threads) with specialized roles:
# - Warp 5: TMA load producer (loads A, B tiles from GMEM to SMEM)
# - Warp 4: MMA compute (performs matrix multiply-accumulate)
# - Warps 0-3: Epilogue (TMEM->RMEM->SMEM) and TMA store (warp 0 only)
#
# PIPELINE ARCHITECTURE:
# The kernel uses software pipelining to overlap memory transfers with compute:
# - mainloop_pipe: TMAToUMMAPipeline - synchronizes TMA loads with MMA operations
# - acc_pipe: UMMAtoAsyncPipeline - synchronizes MMA with TMEM->RMEM copies
# - tma_store_pipe: TMAStorePipeline - synchronizes SMEM writes with TMA stores
#
# ====================================================================================================
# ====================================================================================================
# KERNEL CLASS DEFINITION
# ====================================================================================================
class DenseGemmKernel:
"""
Dense GEMM kernel class for Blackwell (SM100) GPUs.
This class encapsulates all the configuration and logic for a high-performance
batched matrix multiplication: D = A @ B (with optional epilogue operation).
The design follows LIR conventions:
1. __init__: Store configuration parameters
2. __call__: JIT-decorated host launcher that computes grid and calls kernel
3. kernel: Device kernel that performs the actual computation
Attributes:
mn_tiler (tuple[int, int]): Tile sizes for M and N dimensions (e.g., (128, 256))
ab_dtype (Type[Numeric]): Data type for input matrices A and B (e.g., Float16)
acc_dtype (Type[Numeric]): Data type for accumulators (typically Float32)
tmem_output_dtype (Type[Numeric]): Data type for TMEM->RMEM copy output
use_2cta_instrs (bool): Whether to use 2-CTA MMA instructions (False = 1-CTA mode)
TMA_STORE_STAGE (int): Number of pipeline stages for TMA store operations
epilogue_op (callable): Optional epilogue function applied to output (default: identity)
"""
def __init__(
self,
mn_tiler: tuple[int, int],
mma_dtype: tuple[Type[Numeric], Type[Numeric]],
tmem_output_dtype: Type[Numeric],
epilogue_op=lambda x: x,
):
"""
Initialize the Dense GEMM kernel configuration.
Args:
mn_tiler: Tuple (M_tile, N_tile) specifying the tile dimensions.
CONSTRAINT: M must be 64 or 128 (SM100 hardware requirement).
Common configurations: (128, 256), (128, 128), (64, 128)
mma_dtype: Tuple (input_dtype, accumulator_dtype)
- input_dtype: Element type for A and B (e.g., Float16, Float8E4M3FN)
- accumulator_dtype: Precision for accumulation (typically Float32)
tmem_output_dtype: Element type for TMEM output during epilogue.
Typically matches the output matrix type.
epilogue_op: Optional function applied to accumulator values before store.
Default is identity (lambda x: x).
Examples: relu, sigmoid, GELU approximations using cute.exp/cute.where
"""
self.mn_tiler = mn_tiler
self.ab_dtype, self.acc_dtype = mma_dtype
self.tmem_output_dtype = tmem_output_dtype
self.use_2cta_instrs = False
# Number of pipeline stages for TMA store operations.
# More stages = better latency hiding, but more SMEM usage.
self.TMA_STORE_STAGE = 4
# Epilogue operation applied in registers before storing output.
self.epilogue_op = epilogue_op
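# Illustrative epilogue_op definitions (a sketch only, mirroring the ops mentioned in the
# epilogue comments further below; whether cute.where / cute.exp / cute.full_like are
# available with these exact signatures depends on the DSL version in use):
#
#   relu_epilogue    = lambda x: cute.where(x > 0, x, cute.full_like(x, 0))
#   sigmoid_epilogue = lambda x: 1 / (1 + cute.exp(-x))
#
# Usage: DenseGemmKernel(mn_tiler=(128, 128), mma_dtype=(ab_dtype, acc_dtype),
#                        tmem_output_dtype=d_dtype, epilogue_op=relu_epilogue)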
# ================================================================================================
# JIT-DECORATED HOST LAUNCHER
# ================================================================================================
@cute.experimental.jit
def __call__(self, mA: cute.Tensor, mB: cute.Tensor, mD: cute.Tensor):
"""
Host-side JIT-compiled launcher function.
The @cute.experimental.jit decorator indicates this function:
- Runs on the HOST (CPU)
- Is JIT-compiled when first called
- Computes launch configuration and invokes the GPU kernel
This function performs two key tasks:
1. Compute the grid dimensions based on output tensor shape and tile size
2. Launch the kernel with appropriate grid/block/cluster/smem configuration
Args:
mA: Input tensor A in global memory, shape (M, K, L) where L is batch
mB: Input tensor B in global memory, shape (N, K, L)
mD: Output tensor D in global memory, shape (M, N, L)
CUTE ALGEBRA EXPLANATION - tiled_divide:
-----------------------------------------
cute.tiled_divide(tensor, tiler) divides a tensor into tiles, producing a tensor
with shape: ((Tile), Rest_M, Rest_N, ...)
Unlike zipped_divide which groups rest dimensions: ((Tile), (Rest_M, Rest_N, ...))
tiled_divide keeps rest dimensions SEPARATE, making it ideal for grid computation.
For example, if mD has shape (1024, 1024, 2) and tile_mn = (128, 128, 1):
- div.shape[0] = (128, 128, 1) - the tile shape
- div.shape[1] = 8 - number of tiles in M dimension (1024/128)
- div.shape[2] = 8 - number of tiles in N dimension (1024/128)
- div.shape[3] = 2 - batch dimension L
The grid is then (8, 8, 2) = 128 CTAs total, each processing one (128, 128) tile.
"""
# Create a packed tile shape for division. The _pack_shape helper handles
# creating the proper CuTe shape representation.
# (*self.mn_tiler, 1) = (M_tile, N_tile, 1) - the 1 handles the batch dimension
tile_mn = cute.core._pack_shape((*self.mn_tiler, 1))
# tiled_divide produces shape: ((tile_M, tile_N, 1), num_M_tiles, num_N_tiles, batch_L)
# This is used to compute the grid dimensions.
div = cute.tiled_divide(mD, tile_mn)
# Grid dimensions: (num_tiles_M, num_tiles_N, batch_size)
# Each CTA (Cooperative Thread Array / thread block) processes one tile.
grid = (div.shape[1], div.shape[2], div.shape[3])
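# Sanity-check sketch with hypothetical sizes: for mD of shape (1024, 1024, 2) and a
# (128, 128) tiler, the expression above reduces to plain integer division:
#   grid = (1024 // 128, 1024 // 128, 2)   # == (8, 8, 2), i.e. 128 CTAs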
# Launch the kernel with Blackwell-specific configuration:
# - block=(192, 1, 1): 6 warps × 32 threads/warp = 192 threads
# Warp assignment: warps 0-3 (epilogue), warp 4 (MMA), warp 5 (TMA load)
# - cluster=(1, 1, 1): Single-CTA mode (no cluster cooperation)
# - smem: Request maximum shared memory capacity for SM100 (~232KB)
self.kernel(mA, mB, mD).launch(
grid=grid,
block=(192, 1, 1), # 6 warps for warp-specialized GEMM
cluster=(1, 1, 1), # Single CTA per cluster
smem=cute.Int64(utils.get_smem_capacity_in_bytes("sm_100")),
)
# ================================================================================================
# DEVICE KERNEL
# ================================================================================================
@cute.experimental.kernel
def kernel(
self,
mA: cute.Tensor,
mB: cute.Tensor,
mD: cute.Tensor,
):
"""
Device-side kernel function - the actual GPU computation.
The @cute.experimental.kernel decorator indicates this function:
- Runs on the DEVICE (GPU)
- Contains all SMEM/TMEM/RMEM allocations, pipeline setup, and compute logic
- Is compiled to PTX and executed by each thread in the grid
This kernel follows the standard LIR GEMM structure:
1. Create tiled_mma configuration
2. Compute tiler and divide tensors
3. Allocate SMEM, TMEM, and RMEM buffers
4. Create pipelines for producer/consumer synchronization
5. Assign warps to specialized roles
6. Execute TMA load, MMA compute, and epilogue/store phases
Args:
mA: Input A tensor (GMEM), shape (M, K, L)
mB: Input B tensor (GMEM), shape (N, K, L)
mD: Output D tensor (GMEM), shape (M, N, L)
"""
# ========================================================================================
# STEP 1: CREATE TILED MMA CONFIGURATION
# ========================================================================================
# The tiled_mma object encapsulates the MMA instruction configuration for Blackwell.
# It defines:
# - The MMA atom shape (the hardware instruction's native tile size)
# - Thread-to-data mapping for the MMA operation
# - Layout requirements for operands
#
# make_trivial_tiled_mma creates a basic tiled MMA configuration:
# - ab_dtype: Element type for A and B operands
# - mma_major_mode(): Returns the major mode for MMA (K-major or MN-major)
# - acc_dtype: Accumulator precision (typically Float32)
# - CtaGroup.ONE: Single-CTA MMA (vs TWO for cooperative 2-CTA)
# - mn_tiler: The (M, N) tile dimensions
#
# The mma_major_mode() is derived from the tensor layout:
# - K-major A: stride(A)[1] < stride(A)[0] (K is the fast dimension)
# - M-major A: stride(A)[0] < stride(A)[1] (M is the fast dimension)
tiled_mma = sm100_utils.make_trivial_tiled_mma(
self.ab_dtype,
utils.LayoutEnum.from_tensor(mA).mma_major_mode(),
utils.LayoutEnum.from_tensor(mB).mma_major_mode(),
self.acc_dtype,
cute.nvgpu.tcgen05.CtaGroup.ONE, # Single CTA mode
self.mn_tiler,
)
# ========================================================================================
# STEP 2: COMPUTE TILER DIMENSIONS (MNK)
# ========================================================================================
# The MMA instruction operates on tiles. We need to compute the full MNK tiler
# which includes the K dimension (reduction dimension).
#
# cute.size(tiled_mma.shape_mnk, mode=[2]):
# - tiled_mma.shape_mnk is the (M, N, K) shape of the MMA instruction
# - mode=[2] extracts the K dimension (0=M, 1=N, 2=K)
# - For SM100, this is typically 16 (the instruction's native K)
#
# mma_inst_tile_k (=4) is the number of MMA instructions per K-tile iteration.
# This is a tuning parameter:
# - Higher values (8): Larger K-tile, better MMA utilization, but more SMEM
# - Lower values (2): Smaller K-tile, less SMEM, but more loop iterations
# - 4 is a safe default that balances these tradeoffs
mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2])
mma_inst_tile_k = 4 # Number of MMA K-tile subdivisions per mainloop iteration
# Full MNK tiler: (M_tile, N_tile, K_tile)
# K_tile = mma_inst_shape_k * mma_inst_tile_k (e.g., 16 * 4 = 64)
mnk_tiler = (
self.mn_tiler[0], # M dimension from constructor
self.mn_tiler[1], # N dimension from constructor
mma_inst_shape_k * mma_inst_tile_k, # K dimension
)
# Get output tensor layout and type for epilogue configuration
d_layout = utils.LayoutEnum.from_tensor(mD)
d_dtype = mD.element_type
# Create sub-tilers for each operand:
# - A has shape (M, K, L) → tiler_mk = (M_tile, K_tile)
# - B has shape (N, K, L) → tiler_nk = (N_tile, K_tile)
# - D has shape (M, N, L) → tiler_mn = (M_tile, N_tile)
tiler_mk = (mnk_tiler[0], mnk_tiler[2])
tiler_nk = (mnk_tiler[1], mnk_tiler[2])
tiler_mn = (mnk_tiler[0], mnk_tiler[1])
# ========================================================================================
# STEP 3: DIVIDE GLOBAL TENSORS INTO TILES (zipped_divide)
# ========================================================================================
# cute.zipped_divide is the PRIMARY tiling operation in LIR kernels.
#
# CUTE ALGEBRA EXPLANATION - zipped_divide:
# ------------------------------------------
# zipped_divide(tensor, tiler) divides a tensor into tiles and produces:
# - Mode 0: The tile shape itself
# - Mode 1: A "zipped" layout of tile coordinates
#
# Result shape: ((TileM, TileK), (RestM, RestK, L))
#
# For example, if mA has shape (1024, 512, 2) and tiler_mk = (128, 64):
# - gA shape = ((128, 64), (8, 8, 2))
# - (128, 64): One tile of A
# - (8, 8, 2): 8 tiles in M, 8 tiles in K, 2 batches = 128 total tiles
#
# Key difference from tiled_divide:
# - zipped_divide: ((Tile), (Rest...)) - rest dimensions grouped together
# - tiled_divide: ((Tile), Rest_M, Rest_N, ...) - rest dimensions separate
#
# zipped_divide is preferred for CTA tile selection because the zipped
# rest coordinates can be indexed with a single (cta_m, k, batch) tuple.
gA = cute.zipped_divide(mA, tiler_mk)
gB = cute.zipped_divide(mB, tiler_nk)
gD = cute.zipped_divide(mD, tiler_mn)
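# Shape sketch, assuming mnkl = (256, 256, 512, 1), a (128, 128) MN tiler, and a K-tile
# of 64 (i.e. a 16-wide MMA K, as with Float16 inputs; other dtypes change the K-tile):
#   gA: ((128, 64), (2, 8, 1))   # 2 M-tiles x 8 K-tiles x 1 batch
#   gB: ((128, 64), (2, 8, 1))   # 2 N-tiles x 8 K-tiles x 1 batch
#   gD: ((128, 128), (2, 2, 1))  # 2 M-tiles x 2 N-tiles x 1 batch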
# ========================================================================================
# STEP 4: PIPELINE CONFIGURATION
# ========================================================================================
# mainloop_stage: Number of pipeline stages for the TMA load → MMA pipeline.
# More stages allow better overlap of TMA loads with MMA compute.
# - 2 stages: Minimum for double-buffering
# - 4 stages: Good for large GEMMs (better latency hiding)
# Trade-off: More stages = more SMEM usage
#
# acc_stage: Number of accumulator stages in TMEM.
# - For N=256 tiles: use 1 (single accumulator buffer)
# - For N=128 tiles: use 2 (double-buffered)
# Using the correct acc_stage provides a measurable performance improvement.
mainloop_stage = 2
acc_stage = 2
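# SMEM budget sketch for the mainloop buffers (assuming Float16 inputs, a (128, 128) MN
# tiler and a K-tile of 64; swizzle padding is ignored):
#   per stage:  A = 128 * 64 * 2 B = 16 KiB,  B = 128 * 64 * 2 B = 16 KiB
#   2 stages:   (16 + 16) KiB * 2 = 64 KiB of the ~232 KB SM100 capacity
# Raising mainloop_stage to 4 would roughly double this to 128 KiB.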
# ========================================================================================
# STEP 5: GET CTA AND THREAD INDICES
# ========================================================================================
# Each CTA is identified by its position in the 3D grid: (cta_m, cta_n, cta_l)
# - cta_m: Which M-tile this CTA processes
# - cta_n: Which N-tile this CTA processes
# - cta_l: Which batch element this CTA processes
#
# Each thread within a CTA is identified by tid_x (0-191 for 192 threads).
cta_m, cta_n, cta_l = cute.arch.block_idx()
tid_x, _, _ = cute.arch.thread_idx()
# ========================================================================================
# STEP 6: SELECT THIS CTA'S TILES FROM GLOBAL TENSORS
# ========================================================================================
# After zipped_divide, we select the specific tiles for this CTA using slicing.
#
# CUTE SLICING NOTATION:
# - None: Keep this dimension (preserve the mode)
# - integer: Fix this dimension at that index
#
# gA has shape ((M_tile, K_tile), (num_M_tiles, num_K_tiles, batch))
# gA_tile = gA[(None, None), (cta_m, None, cta_l)] means:
# - (None, None): Keep the tile shape modes (M_tile, K_tile)
# - (cta_m, None, cta_l): Select M-tile cta_m, keep K dimension, select batch cta_l
#
# Result: gA_tile has shape (M_tile, K_tile, num_K_tiles) - one CTA's work
# The K dimension (None) is kept because we iterate over K in the mainloop.
gA_tile = gA[(None, None), (cta_m, None, cta_l)]
gB_tile = gB[(None, None), (cta_n, None, cta_l)]
gD_tile = gD[(None, None), (cta_m, cta_n, cta_l)]
# ========================================================================================
# STEP 7: CREATE SMEM LAYOUTS WITH SWIZZLING
# ========================================================================================
# SMEM layouts must:
# 1. Match the tile dimensions from the tiler
# 2. Include staging for pipeline buffers
# 3. Use swizzle patterns to avoid bank conflicts
#
# make_smem_layout_a/b are helper functions that:
# - Select appropriate swizzle patterns based on major mode and element type
# - Append the stage dimension for pipelining
# - Return a ComposedLayout (layout + swizzle function)
#
# The swizzle pattern interleaves memory addresses across the 32 SMEM banks,
# ensuring that when a warp accesses consecutive elements, they hit different
# banks (avoiding serialization from bank conflicts).
#
# LAYOUT SHAPE: (MMA_ATOM, MMA_TILE, MMA_K, PIPELINE_STAGES)
# For operand A: this encodes how to store M×K tiles with proper bank conflict avoidance
a_smem_layout_staged = sm100_utils.make_smem_layout_a(
tiled_mma,
mnk_tiler,
self.ab_dtype,
mainloop_stage, # Number of pipeline stages
)
b_smem_layout_staged = sm100_utils.make_smem_layout_b(
tiled_mma,
mnk_tiler,
self.ab_dtype,
mainloop_stage,
)
# ========================================================================================
# STEP 8: COMPUTE EPILOGUE TILE SHAPE
# ========================================================================================
# The epilogue processes output tiles in smaller sub-tiles (epi_tile).
# This is necessary because:
# 1. TMEM→RMEM copies have granularity constraints
# 2. TMA stores work on specific tile sizes
#
# cta_tile_shape_mnk: The effective tile shape per CTA. The MNK tiler describes the tile
# owned by the whole MMA atom, so it is divided by the number of CTAs that cooperate on
# one MMA:
# mnk_tiler / (num_ctas_in_mma_atom, 1, 1)
#
# cute.shape_div performs element-wise division of shapes.
# cute.size(tiled_mma.thr_id.shape) gives the number of CTAs cooperating in the MMA atom
# (1 in 1-CTA mode, 2 in 2-CTA mode), so the division is a no-op here.
cta_tile_shape_mnk = cute.shape_div(
mnk_tiler, (cute.size(tiled_mma.thr_id.shape), 1, 1)
)
# compute_epilogue_tile_shape determines the sub-tile size for epilogue operations.
# It considers:
# - CTA tile shape
# - Whether using 1-CTA or 2-CTA instructions
# - Output layout (M-major or N-major)
# - Output data type
epi_tile = sm100_utils.compute_epilogue_tile_shape(
cta_tile_shape_mnk,
self.use_2cta_instrs,
d_layout,
d_dtype,
)
# Create epilogue SMEM layout for TMA stores.
# This layout is used for the bufferC staging buffer before TMA store to GMEM.
sc_smem_layout_staged = sm100_utils.make_smem_layout_epi(
d_dtype,
d_layout,
epi_tile,
self.TMA_STORE_STAGE, # Number of TMA store pipeline stages
)
# ========================================================================================
# STEP 9: CREATE TMEM LAYOUT FOR ACCUMULATORS
# ========================================================================================
# TMEM (Tensor Memory) is Blackwell's specialized memory for MMA accumulators.
# It provides high-bandwidth access for accumulator updates during MMA operations.
#
# TMEM CHARACTERISTICS:
# - Written by the MMA unit; threads read it back via dedicated tcgen05 load/copy
#   instructions (as the epilogue below does)
# - Has a capacity limit of 512 columns
# - Requires specific layout patterns matching MMA instructions
#
# make_tmem_layout_acc: Derives the TMEM accumulator buffer layout from the
# tiled MMA and MNK tiler, with the given number of pipeline stages.
tmem_layout = cute_ext.make_tmem_layout_acc(tiled_mma, mnk_tiler, acc_stage)
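# TMEM budget sketch (assuming the usual 128-lane x 32-bit TMEM cell organization):
# a (128, 128) Float32 accumulator occupies 128 columns per stage, so acc_stage = 2
# uses 256 of the 512 available columns.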
# ========================================================================================
# STEP 10: ALLOCATE SMEM BUFFERS
# ========================================================================================
# cute_ext.allocate creates a tensor in the specified address space.
#
# Arguments:
# - type: Element type (e.g., Float16, Float32)
# - address_space: One of smem, tmem, rmem, gmem
# - layout: The layout including staging dimensions
# - alignment: Byte alignment (1024 for SMEM, 16 for TMEM, 32 for RMEM)
#
# ALIGNMENT RATIONALE:
# - SMEM (1024 bytes): Optimal for TMA transfers and swizzle patterns
# - TMEM (16 bytes): Standard tensor memory alignment
# - RMEM (32 bytes): Vectorized register loads/stores
# Allocate SMEM buffers for A and B operands.
# These buffers hold multiple pipeline stages of tiles loaded from GMEM.
bufferA = cute_ext.allocate(
self.ab_dtype,
cute.AddressSpace.smem,
a_smem_layout_staged,
alignment=1024,
)
bufferB = cute_ext.allocate(
self.ab_dtype,
cute.AddressSpace.smem,
b_smem_layout_staged,
alignment=1024,
)
# Allocate TMEM buffer for MMA accumulators.
# This stores the running sum: C += A × B across K iterations.
bufferAcc = cute_ext.allocate(
self.acc_dtype,
cute.AddressSpace.tmem,
tmem_layout,
alignment=16,
)
# Allocate SMEM buffer for output (C) - used during epilogue before TMA store.
bufferC = cute_ext.allocate(
d_dtype,
cute.AddressSpace.smem,
sc_smem_layout_staged,
alignment=1024,
)
# ========================================================================================
# STEP 11: CREATE TMEM->RMEM COPY CONFIGURATION
# ========================================================================================
# The epilogue copies data from TMEM (accumulators) → RMEM (registers) → SMEM → GMEM.
# This section sets up the copy atoms and tiled copies for this path.
#
# get_tmem_load_op: Returns the appropriate tcgen05 load operation for TMEM→RMEM.
# It selects the right instruction based on:
# - CTA tile shape
# - Output layout orientation
# - Data types
# - Epilogue tile size
# - 1-CTA vs 2-CTA mode
copy_atom_t2r = sm100_utils.get_tmem_load_op(
cta_tile_shape_mnk,
d_layout,
self.tmem_output_dtype,
self.acc_dtype,
epi_tile,
self.use_2cta_instrs,
)
# ========================================================================================
# STEP 12: PREPARE ACCUMULATOR FOR EPILOGUE ITERATION
# ========================================================================================
# The accumulator buffer is divided into epilogue-sized sub-tiles for iteration.
#
# CUTE ALGEBRA EXPLANATION - zipped_divide on accumulators:
# ----------------------------------------------------------
# We divide bufferAcc by (epi_tile, 1) to create sub-tiles for epilogue processing.
# The "1" preserves the stage dimension.
#
# accumulators = cute.zipped_divide(bufferAcc, ((epi_tile), 1))
# This creates: ((epi_tile_shape), (rest_subtiles, stages))
#
# acc_epi_div = accumulators[((None, None), 0), 0]
# - (None, None): Keep the epilogue tile shape
# - 0: Select the first rest-mode position
# - 0: Select the first stage (for tiled_copy_t2r creation)
#
# This gives us one epilogue tile's worth of data for configuring the copy.
accumulators = cute.zipped_divide(bufferAcc, ((epi_tile), 1))
acc_epi_div = accumulators[((None, None), 0), 0]
# Create the tiled copy operation for TMEM→RMEM.
# make_tmem_copy creates a TiledCopy object that defines:
# - How threads partition the source (TMEM)
# - How threads partition the destination (RMEM)
# - The mapping between source and destination layouts
tiled_copy_t2r = cute.nvgpu.tcgen05.make_tmem_copy(copy_atom_t2r, acc_epi_div)
# ========================================================================================
# STEP 13: DERIVE RMEM LAYOUT FROM COPY PARTITION
# ========================================================================================
# RMEM layouts must match the thread-value ownership pattern of the copy.
# We derive the RMEM layout by partitioning the destination and extracting
# the per-thread layout.
#
# CUTE ALGEBRA EXPLANATION - flat_divide:
# ---------------------------------------
# flat_divide(tensor, tiler) flattens all dimensions:
# Result shape: (Tile_M, Tile_N, Rest_M, Rest_N, ...)
#
# Unlike zipped_divide which groups tile and rest separately,
# flat_divide keeps everything flat, which is useful for iteration.
#
# make_t2r_rmem_layout: Derives the per-thread RMEM buffer layout
# produced by a TMEM->RMEM copy for a single epilogue iteration.
gC_mnl_epi = cute.flat_divide(gD_tile, epi_tile)
acc_d_rmem_layout = cute_ext.make_t2r_rmem_layout(
tiled_copy_t2r, gC_mnl_epi, tid_x
)
# ========================================================================================
# STEP 14: ALLOCATE RMEM BUFFERS FOR EPILOGUE
# ========================================================================================
# RMEM (Register Memory) is per-thread storage. Each thread has its own
# private copy of these buffers.
#
# bufferRAcc: Holds accumulator values copied from TMEM (FP32)
# bufferRD: Holds output values after epilogue conversion (output dtype)
bufferRAcc = cute_ext.allocate(
self.acc_dtype, # FP32 for accumulators
cute.AddressSpace.rmem,
acc_d_rmem_layout,
alignment=32,
)
bufferRD = cute_ext.allocate(
d_dtype, # Output dtype (e.g., FP16)
cute.AddressSpace.rmem,
acc_d_rmem_layout,
alignment=32,
)
# ========================================================================================
# STEP 15: CREATE PIPELINES
# ========================================================================================
# Pipelines provide producer/consumer synchronization using hardware barriers.
# They enable overlapping of memory operations with compute.
#
# PIPELINE 1: TMAToUMMAPipeline (mainloop_pipe)
# ---------------------------------------------
# Synchronizes TMA loads (producer) with UMMA/MMA operations (consumer).
# - num_stages: Number of pipeline stages (matches mainloop_stage)
# - mma_operation_type: The type of MMA operation being consumed
# SM100_MMA_1SM_SS = single-SM (1-CTA) MMA with both A and B sourced from SMEM
mainloop_pipe = cute_ext.TMAToUMMAPipeline.create(
num_stages=mainloop_stage,
mma_operation_type=cute_ext.OperationTypeEnum.SM100_MMA_1SM_SS,
)
# PIPELINE 2: UMMAtoAsyncPipeline (acc_pipe)
# ------------------------------------------
# Synchronizes UMMA/MMA operations (producer) with TMEM→RMEM copies (consumer).
# - num_stages: Accumulator stages (acc_stage)
# - mma_operation_type: The MMA operation producing data
# - consumer: The operation consuming data (SM100_COPY_T2R = TMEM→RMEM copy)
# - consumer_arv_count: Number of threads participating as consumers (128 = 4 warps)
acc_pipe = cute_ext.UMMAtoAsyncPipeline.create(
num_stages=acc_stage,
mma_operation_type=cute_ext.OperationTypeEnum.SM100_MMA_1SM_SS,
consumer=cute_ext.OperationTypeEnum.SM100_COPY_T2R,
consumer_arv_count=128, # 4 epilogue warps × 32 threads
)
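# Hand-off sketch for both mbarrier pipelines above (per stage):
#   producer: producer_acquire_and_get_stage() -> fill buffer -> producer_commit_and_advance()
#   consumer: consumer_wait_and_get_stage()    -> read buffer -> consumer_release_and_advance()
# The same call sequence appears in the TMA-load, MMA, and epilogue code below.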
# ========================================================================================
# STEP 16: WARP ASSIGNMENT AND SPECIALIZATION
# ========================================================================================
# This kernel uses 6 warps (192 threads) with specialized roles:
#
# Warp 0: TMA store (also participates in epilogue)
# Warps 0-3: Epilogue processing (TMEM→RMEM→SMEM)
# Warp 4: MMA compute
# Warp 5: TMA load
#
# cute.arch.warp_idx(): Returns this thread's warp index (0-5)
# make_warp_uniform: Ensures all threads in a warp see the same value
# (important for conditional branching to avoid divergence)
warp_idx = cute.arch.warp_idx()
warp_idx = cute.arch.make_warp_uniform(warp_idx)
# Assign warp roles
tma_store_warp_id = 0
mma_warp_id = 4
tma_load_warp_id = 5
# Boolean flags for role-based execution
is_tma_thr = warp_idx == tma_load_warp_id # Only warp 5
is_mma_thr = warp_idx == mma_warp_id # Only warp 4
is_epi_thr = warp_idx < 4 # Warps 0, 1, 2, 3
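# Role mapping sketch for the 192-thread block (warp_idx == tid_x // 32):
#   tid_x   0..127 -> warps 0-3 -> epilogue (warp 0 additionally issues TMA stores)
#   tid_x 128..159 -> warp 4    -> MMA
#   tid_x 160..191 -> warp 5    -> TMA load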
# PIPELINE 3: TMAStorePipeline (tma_store_pipe)
# ---------------------------------------------
# Synchronizes RMEM→SMEM writes with TMA stores.
# Uses named barriers (not mbarriers) for synchronization.
#
# - stages: Number of TMA store pipeline stages
# - arv_count: Number of threads participating in barriers (128 = 4 warps)
# - barrier_id: Named barrier ID (must be unique per pipeline)
# - tma_warp_id: Which warp issues TMA stores (warp 0)
tma_store_pipe = cute_ext.TMAStorePipeline(
stages=self.TMA_STORE_STAGE,
arv_count=128,
barrier_id=1,
tma_warp_id=tma_store_warp_id,
)
# ========================================================================================
# STEP 17: COMPUTE K-TILE ITERATION COUNT
# ========================================================================================
# cute.size(gA, mode=[1, 1]) extracts the size of the K-tile dimension.
# gA shape after zipped_divide: ((M_tile, K_tile), (num_M_tiles, num_K_tiles, batch))
# mode=[1, 1] accesses the second element of the second mode = num_K_tiles
k_tile_size = cute.size(gA, mode=[1, 1])
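# Example: with K = 512 and a K-tile of 64 (Float16-style 16-wide MMA K), k_tile_size == 8,
# so the producer (TMA) and consumer (MMA) loops below each run 8 iterations.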
# ========================================================================================
# STEP 18: TMA LOAD WARP - PRODUCER PHASE
# ========================================================================================
# The TMA load warp (warp 5) loads A and B tiles from GMEM to SMEM.
# This is the PRODUCER in the mainloop pipeline.
#
# The producer loop iterates over K-tiles, loading data ahead of consumption.
# Pipeline stages allow loads to overlap with MMA operations.
if is_tma_thr:
# cutlass.range: A loop construct that supports unrolling.
# unroll=1 means don't unroll (iterate normally).
# This iterates over K-tiles: k = 0, 1, 2, ... k_tile_size-1
for k in cutlass.range(0, k_tile_size, 1, unroll=1):
# Select the K-tile from the CTA's tile view.
# gA_tile has shape (M_tile, K_tile, num_K_tiles)
# gA_tile[None, None, k] selects the k-th K-tile: shape (M_tile, K_tile)
gA_k = gA_tile[None, None, k]
gB_k = gB_tile[None, None, k]
# ============================================================================
# PIPELINE PRODUCER PROTOCOL
# ============================================================================
# 1. Acquire a pipeline stage (wait for it to be empty)
# 2. Get the mbarrier for TMA synchronization
# 3. Issue TMA loads
# 4. Commit and advance to the next stage
#
# producer_acquire_and_get_stage():
# - Waits for the next pipeline stage to be empty (consumer released it)
# - Returns (stage_token, idx) where:
# - stage_token: Handle for getting the mbarrier
# - idx: Integer index (0 to num_stages-1) for buffer slicing
(
producer_stage_token,
idx,
) = mainloop_pipe.producer_acquire_and_get_stage()
# get_mbarrier: Retrieves the hardware mbarrier pointer for this stage.
# The mbarrier is signaled by TMA hardware when the load completes.
mbar = cute_ext.get_mbarrier(producer_stage_token)
## producer_body begin ##
# Slice SMEM buffers to the current pipeline stage.
# bufferA has shape (atoms, M, K, stages)
# bufferA[None, None, None, idx] selects stage idx: shape (atoms, M, K)
bufferA_sliced = bufferA[None, None, None, idx]
bufferB_sliced = bufferB[None, None, None, idx]
# ============================================================================
# CTA-TO-VALUE MAPS FOR TMA
# ============================================================================
# cta_v_map (CTA-to-Value map) tells TMA which portion of the global tensor
# this CTA should load. It encodes the mapping from CTA coordinates to
# tensor indices.
#
# get_cta_v_map_ab: Computes the CTA-to-value map for operands A or B.
# Arguments:
# - mA/mB: The global tensor
# - mnk_tiler: The MNK tiler dimensions
# - tiled_mma: The MMA configuration
# - "A"/"B": Which operand this is for
a_cta_v_map = cute_ext.get_cta_v_map_ab(mA, mnk_tiler, tiled_mma, "A")
b_cta_v_map = cute_ext.get_cta_v_map_ab(mB, mnk_tiler, tiled_mma, "B")
# ============================================================================
# TMA LOAD OPERATIONS
# ============================================================================
# tma_load: Asynchronous TMA load from GMEM to SMEM.
#
# Arguments:
# - src: Source tensor in GMEM (the K-tile slice)
# - dst: Destination buffer in SMEM (the stage-sliced buffer)
# - mbar: Mbarrier for completion signaling
# - cta_v_map: CTA-to-value mapping layout
#
# The TMA hardware:
# 1. Reads from GMEM at the location specified by cta_v_map
# 2. Writes to SMEM at dst
# 3. Signals mbar when complete
#
# IMPORTANT: src and dst must have matching shapes!
# This is a common source of "source/destination size mismatch" errors.
cute_ext.tma_load(
gA_k, # Source: K-tile from global A
bufferA_sliced, # Destination: SMEM buffer stage
mbar, # Mbarrier for synchronization
cta_v_map=a_cta_v_map,
)
cute_ext.tma_load(
gB_k,
bufferB_sliced,
mbar,
cta_v_map=b_cta_v_map,
)
## producer_body end ##
# producer_commit_and_advance:
# - Signals that producer work is complete (mbarrier will be triggered by TMA)
# - Advances internal pipeline state to the next stage
mainloop_pipe.producer_commit_and_advance()
# ========================================================================================
# STEP 19: MMA WARP - COMPUTE PHASE
# ========================================================================================
# The MMA warp (warp 4) performs matrix multiply-accumulate operations.
# It consumes data from SMEM (loaded by TMA warp) and produces results in TMEM.
#
# The MMA warp is both:
# - CONSUMER of mainloop_pipe (waits for TMA loads to complete)
# - PRODUCER of acc_pipe (signals when accumulation is complete)
if is_mma_thr:
# Acquire accumulator pipeline stage before starting MMA operations.
# This reserves a TMEM accumulator buffer for this K-reduction.
producer_stage_token, idx = acc_pipe.producer_acquire_and_get_stage()
## acc_producer_body begin ##
# Select the TMEM accumulator for this stage.
# bufferAcc has shape (MMA_shape, stages)
accumulators_sliced = bufferAcc[None, None, None, idx]
# ============================================================================
# MMA ATOM CONFIGURATION
# ============================================================================
# cute.make_mma_atom: Creates an MMA atom from the tiled_mma operation.
# The MMA atom represents the hardware MMA instruction configuration.
#
# ACCUMULATE field controls whether to:
# - False: Overwrite accumulator (C = A × B) - used for first iteration
# - True: Accumulate into existing value (C += A × B) - used after first
mma_atom = cute.make_mma_atom(tiled_mma.op)
mma_atom.set(
cute.nvgpu.tcgen05.Field.ACCUMULATE, False
) # First iteration: overwrite
# Iterate over K-tiles (same loop as TMA load warp)
for k in cutlass.range(0, k_tile_size, 1, unroll=1):
# ============================================================================
# PIPELINE CONSUMER PROTOCOL
# ============================================================================
# Wait for TMA load to complete before reading from SMEM.
# consumer_wait_and_get_stage():
# - Waits for the producer (TMA) to signal the mbarrier
# - Returns (stage_token, mainloop_idx) where mainloop_idx is the stage to read
(
_, # Stage token not needed for consumer
mainloop_idx,
) = mainloop_pipe.consumer_wait_and_get_stage()
## tma_consumer_body begin ##
# cute.core.slice_: An alternative slicing function that creates a view.
# This slices the SMEM buffers to the current pipeline stage.
# Equivalent to bufferA[None, None, None, mainloop_idx]
bufferA_sliced_stage = cute.core.slice_(
bufferA, (None, None, None, mainloop_idx)
)
bufferB_sliced_stage = cute.core.slice_(
bufferB, (None, None, None, mainloop_idx)
)
# ============================================================================
# INNER K-TILE LOOP (MMA INSTRUCTION LOOP)
# ============================================================================
# Within each K-tile, we execute multiple MMA instructions.
# mma_inst_tile_k (=4) MMA instructions are executed per K-tile.
#
# unroll_full=True: Fully unroll this loop (generate 4 copies of the body)
# This is important for MMA instruction scheduling.
for k_tile in cutlass.range(mma_inst_tile_k, unroll_full=True):
# Select the k_tile-th sub-slice for this MMA instruction.
# bufferA_sliced_stage has shape (MMA_atom, M_tile, K_tile)
# After slicing [None, None, k_tile]: shape (MMA_atom, M_tile)
bufferA_sliced = bufferA_sliced_stage[None, None, k_tile]
bufferB_sliced = bufferB_sliced_stage[None, None, k_tile]
# ========================================================================
# CUTE.DOT - MATRIX MULTIPLY-ACCUMULATE
# ========================================================================
# cute_ext.dot: Performs MMA operation C = A × B (or C += A × B)
#
# Arguments:
# - mma_atom: The MMA instruction configuration
# - a: Input tensor A (must be rank-3)
# - b: Input tensor B (must be rank-3)
# - c: Accumulator tensor C (in TMEM)
#
# CUTE ALGEBRA EXPLANATION - append_ones:
# ---------------------------------------
# cute.append_ones(tensor, up_to_rank=3):
# The MMA instruction expects rank-3 operands. If bufferA_sliced
# is rank-2 after slicing, append_ones pads it to rank-3 by
# appending singleton dimensions: shape (M, K) → (M, K, 1)
#
# This is necessary because the MMA instruction operates on
# 3D tiles even when the logical operation is 2D.
cute_ext.dot(
mma_atom,
cute.append_ones(bufferA_sliced, up_to_rank=3),
cute.append_ones(bufferB_sliced, up_to_rank=3),
accumulators_sliced,
)
# After the first MMA instruction, enable accumulation mode.
# Subsequent instructions add to the existing accumulator value.
mma_atom.set(cute.nvgpu.tcgen05.Field.ACCUMULATE, True)
## tma_consumer_body end ##
# Release the mainloop pipeline stage for TMA to reuse.
# consumer_release_and_advance():
# - Signals that consumer has finished reading this stage
# - Advances internal state to the next stage
mainloop_pipe.consumer_release_and_advance()
## acc_producer_body end ##
# Signal that MMA computation is complete for this tile.
# The epilogue warps will consume this data.
acc_pipe.producer_commit_and_advance()
# ========================================================================================
# STEP 20: EPILOGUE WARPS - CONSUME AND STORE PHASE
# ========================================================================================
# Warps 0-3 handle the epilogue: copying results from TMEM to GMEM.
# This involves: TMEM → RMEM → apply epilogue op → SMEM → TMA store to GMEM
#
# The epilogue is both:
# - CONSUMER of acc_pipe (waits for MMA to complete)
# - PRODUCER/CONSUMER of tma_store_pipe (coordinates SMEM→GMEM stores)
if is_epi_thr:
# Wait for accumulator data to be ready.
_, idx = acc_pipe.consumer_wait_and_get_stage()
## acc_consume_body begin ##
# Select the accumulator stage for epilogue iteration.
# accumulators_sliced: shape (M_cta_tile, N_cta_tile) after removing the stage dimension;
# it is subdivided into epi_tile-sized pieces just below.
accumulators_sliced = bufferAcc[(None, None), 0, 0, idx]
# Divide the accumulator into epilogue-sized sub-tiles.
# flat_divide creates a flat iteration space over sub-tiles.
# acc_epi_div_tiled: allows iteration with index mn over sub-tiles
acc_epi_div_tiled = cute.flat_divide(accumulators_sliced, epi_tile)
# Get the number of sub-tiles to process.
# mode=[3] accesses the sub-tile count dimension
subtile_cnt = cute.size(acc_epi_div_tiled.shape, mode=[3])
# Iterate over epilogue sub-tiles
for mn in range(subtile_cnt):
# ============================================================================
# TMEM → RMEM COPY
# ============================================================================
# partition_and_copy: High-level function that combines partitioning and copying.
# It handles:
# 1. Partitioning source/destination according to the tiled copy layout
# 2. Selecting the appropriate copy method based on memory spaces
# 3. Executing the copy
#
# For TMEM→RMEM, this uses specialized tcgen05 load instructions.
#
# Arguments:
# - tiled_copy.get_slice(tid_x): Per-thread copy configuration
# - source: TMEM accumulator sub-tile
# - destination: RMEM buffer (per-thread, not partitioned)
cute_ext.partition_and_copy(
tiled_copy_t2r.get_slice(tid_x),
acc_epi_div_tiled[None, None, 0, mn],
bufferRAcc,
)
# ============================================================================
# APPLY EPILOGUE OPERATION IN REGISTERS
# ============================================================================
# bufferRAcc.load(): Reads all values from the RMEM tensor into a register
# .to(d_dtype): Converts from accumulator type (FP32) to output type (FP16)
# self.epilogue_op: Applies user-specified transformation (default: identity)
# bufferRD.store(): Writes the result back to RMEM
#
# Common epilogue operations:
# - Identity: lambda x: x (default)
# - ReLU: cute.where(x > 0, x, cute.full_like(x, 0))
# - GELU: Uses cute.exp for tanh approximation
# - Sigmoid: 1 / (1 + cute.exp(-x))
bufferRD.store(self.epilogue_op(bufferRAcc.load().to(d_dtype)))
# ============================================================================
# TMA STORE PIPELINE PROTOCOL
# ============================================================================
# The TMA store pipeline coordinates multiple warps writing to SMEM
# before a single warp (warp 0) issues the TMA store.
#
# acquire_sync():
# - TMA warp waits for any in-flight TMA ops to complete
# - All warps synchronize via a named barrier
tma_store_pipe.acquire_sync()
# Get the current pipeline stage index for buffer access
idx = tma_store_pipe.get_index()
# ============================================================================
# RMEM → SMEM COPY
# ============================================================================
# Create a tiled copy for RMEM→SMEM using the same layout as TMEM→RMEM.
# make_tiled_copy_D creates a copy with destination-oriented partitioning.
#
# CopyUniversalOp: A generic copy operation that works for any memory pair.
# The partition_and_copy function will select appropriate vectorization.
tiled_copy_r2s = cute.make_tiled_copy_D(
cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), d_dtype),
tiled_copy_t2r,
)
# Copy from RMEM to the current SMEM stage buffer
cute_ext.partition_and_copy(
tiled_copy_r2s.get_slice(tid_x),
bufferRD,
bufferC[None, None, idx],
)
# commit_sync():
# - Fences SMEM writes to ensure visibility for TMA
# - All warps synchronize before TMA store
# This is CRITICAL - TMA must see committed SMEM writes!
tma_store_pipe.commit_sync()
# ============================================================================
# TMA STORE (SINGLE WARP)
# ============================================================================
# Only the designated TMA store warp (warp 0) issues the actual TMA store.
# Other warps skip this but still participate in synchronization.
if warp_idx == tma_store_warp_id:
# get_cta_v_map_c: CTA-to-value map for the output tensor.
# Arguments:
# - mD: Global output tensor
# - epi_tile: Epilogue tile shape
c_cta_v_map = cute_ext.get_cta_v_map_c(mD, epi_tile)
# tma_store: Asynchronous TMA store from SMEM to GMEM.
#
# Arguments:
# - src: Source buffer in SMEM (current stage)
# - dst: Destination in GMEM (sub-tile at position mn)
# - cta_v_map: CTA-to-value mapping
#
# The store is added to an async bulk group managed by the pipeline.
cute_ext.tma_store(
bufferC[None, None, idx],
gC_mnl_epi[None, None, 0, mn],
cta_v_map=c_cta_v_map,
)
# release_advance():
# - TMA warp commits TMA ops to bulk group
# - All warps advance to the next pipeline stage
tma_store_pipe.release_advance()
# ============================================================================
# PIPELINE CLEANUP
# ============================================================================
# tail(): Called at the end of the pipeline to ensure all TMA stores complete.
# This waits for all in-flight TMA operations before the kernel exits.
# Without this, the kernel might exit before stores are globally visible!
tma_store_pipe.tail()
# Release the accumulator pipeline stage
acc_pipe.consumer_release_and_advance()
# ====================================================================================================
# HOST-SIDE UTILITY FUNCTIONS
# ====================================================================================================
def create_tensors(l, m, n, k, a_major, b_major, d_major, ab_dtype, d_dtype):
"""
Create input and output tensors for GEMM operation.
This function creates:
1. CPU tensors with proper layouts (for reference computation)
2. GPU tensors wrapped as CuTe tensors (for kernel execution)
Args:
l: Batch size (L dimension)
m: M dimension (rows of A, rows of D)
n: N dimension (columns of B, columns of D)
k: K dimension (columns of A, rows of B - the reduction dimension)
a_major: "m" for M-major (column-major in M), "k" for K-major
b_major: "n" for N-major, "k" for K-major
d_major: "m" for M-major, "n" for N-major
ab_dtype: Data type for A and B matrices
d_dtype: Data type for output matrix
Returns:
Tuple of (a_tensor, b_tensor, d_tensor, a_cpu, b_cpu, d_cpu, d_gpu)
- *_tensor: CuTe tensor wrappers for kernel input
- *_cpu: PyTorch CPU tensors for reference
- d_gpu: PyTorch GPU tensor for result extraction
TENSOR LAYOUT CONVENTIONS:
- cutlass_torch.matrix(l, m, k, m_major, dtype) creates a tensor of shape (m, k, l)
- m_major=True: M is the fast (stride-1) dimension
- m_major=False: K is the fast dimension
CUTE TENSOR CREATION:
- cute_tensor_like wraps a PyTorch tensor as a CuTe tensor
- is_dynamic_layout=True: Allows variable problem sizes
- assumed_align=16: Assumes 16-byte alignment for TMA
"""
torch.manual_seed(1111) # For reproducibility
# Create PyTorch CPU tensors with specified layouts.
# cutlass_torch.matrix(l, m, k, m_major, dtype) creates (m, k, l) tensor
a_torch_cpu = cutlass_torch.matrix(l, m, k, a_major == "m", ab_dtype)
b_torch_cpu = cutlass_torch.matrix(l, n, k, b_major == "n", ab_dtype)
d_torch_cpu = cutlass_torch.matrix(l, m, n, d_major == "m", d_dtype)
# Wrap as CuTe tensors for kernel input.
# cute_tensor_like returns (cute_tensor, pytorch_gpu_tensor)
a_tensor, _ = cutlass_torch.cute_tensor_like(
a_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16
)
b_tensor, _ = cutlass_torch.cute_tensor_like(
b_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16
)
d_tensor, d_torch_gpu = cutlass_torch.cute_tensor_like(
d_torch_cpu, d_dtype, is_dynamic_layout=True, assumed_align=16
)
return (
a_tensor,
b_tensor,
d_tensor,
a_torch_cpu,
b_torch_cpu,
d_torch_cpu,
d_torch_gpu,
)
def compare(a_torch_cpu, b_torch_cpu, d_torch_gpu, d_dtype, tolerance):
"""
Compare kernel output against PyTorch reference.
The reference computation uses torch.einsum with the pattern "mkl,nkl->mnl":
- A has shape (m, k, l): indices m, k, l
- B has shape (n, k, l): indices n, k, l
- Output has shape (m, n, l): indices m, n, l
- The 'k' index is summed (contraction)
This computes: D[m,n,l] = sum_k A[m,k,l] * B[n,k,l]
Args:
a_torch_cpu: Input A tensor on CPU
b_torch_cpu: Input B tensor on CPU
d_torch_gpu: Kernel output tensor on GPU
d_dtype: Output data type (for reference tensor creation)
tolerance: Absolute tolerance for comparison
Raises:
AssertionError: If kernel output doesn't match reference within tolerance
"""
# Compute reference using einsum
ref = torch.einsum("mkl,nkl->mnl", a_torch_cpu, b_torch_cpu)
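# Equivalent per-batch formulation of the same reference (a sketch; the einsum above is
# what is actually used):
#   ref = torch.stack(
#       [a_torch_cpu[:, :, i] @ b_torch_cpu[:, :, i].T
#        for i in range(a_torch_cpu.shape[-1])],
#       dim=-1,
#   )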
# Wrap reference as CuTe tensor (for consistent comparison)
_, ref_torch_gpu = cutlass_torch.cute_tensor_like(
ref, d_dtype, is_dynamic_layout=True, assumed_align=16
)
ref_result = ref_torch_gpu.cpu()
# Compare with tolerance
torch.testing.assert_close(
d_torch_gpu.cpu(), ref_result, atol=tolerance, rtol=1e-05
)
def run(
mnkl: Tuple[int, int, int, int],
mma_tiler_mn: Tuple[int, int],
cluster_shape_mn: Tuple[int, int],
ab_dtype: Type[Numeric],
c_dtype: Type[Numeric],
acc_dtype: Type[Numeric],
a_major: str,
b_major: str,
c_major: str,
warmup_iterations: int = 0,
iterations: int = 1,
use_cold_l2: bool = False,
tolerance: float = 1e-02,
skip_ref_check: bool = False,
**kwargs,
):
"""Execute a batched dense GEMM operation on Blackwell architecture with performance benchmarking.
This function:
1. Creates input tensors
2. Instantiates and compiles the kernel
3. Executes the kernel
4. Validates correctness against PyTorch reference
5. Benchmarks performance
COMPILATION PATTERN:
-------------------
CRITICAL: Always use explicit compilation to avoid JIT overhead!
WRONG (recompiles every call, ~1000x slower):
kernel = DenseGemmKernel(...)
kernel(a, b, d) # JIT compilation happens here every time!
CORRECT (compile once, run many times):
kernel = DenseGemmKernel(...)
compiled = cute_ext.compile(kernel, a, b, d) # Compile once
compiled(a, b, d) # Fast execution
Args:
mnkl: Problem size tuple (M, N, K, L)
mma_tiler_mn: MMA tile shape (M_tile, N_tile)
cluster_shape_mn: Cluster shape (currently unused in 1-CTA mode)
ab_dtype: Input data type for A and B
c_dtype: Output (D) data type
acc_dtype: Accumulator data type
a_major, b_major, c_major: Layout specifications ("m"/"k"/"n")
warmup_iterations: Warmup iterations before timing
iterations: Timed iterations
use_cold_l2: Whether to use cold L2 cache (requires fresh tensors)
tolerance: Tolerance for numerical comparison
skip_ref_check: Skip reference validation
Returns:
exec_time: Execution time in microseconds per iteration
"""
print("Running Blackwell Dense GEMM test with:")
print(f"mnkl: {mnkl}")
print(f"AB dtype: {ab_dtype}, D dtype: {c_dtype}, Acc dtype: {acc_dtype}")
print(f"Matrix majors - A: {a_major}, B: {b_major}, D: {c_major}")
print(f"Mma Tiler (M, N): {mma_tiler_mn}, Cluster Shape (M, N): {cluster_shape_mn}")
print(f"Tolerance: {tolerance}")
print(f"Warmup iterations: {warmup_iterations}")
print(f"Iterations: {iterations}")
print(f"Skip reference checking: {skip_ref_check}")
print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}")
m, n, k, l = mnkl
# Map the CLI-facing c_* names onto the d_* names used by the kernel internals.
d_dtype = c_dtype
d_major = c_major
# Create tensors
a_tensor, b_tensor, d_tensor, a_torch_cpu, b_torch_cpu, d_torch_cpu, d_torch_gpu = (
create_tensors(l, m, n, k, a_major, b_major, d_major, ab_dtype, d_dtype)
)
# Instantiate kernel with configuration
dense_gemm = DenseGemmKernel(
mn_tiler=mma_tiler_mn,
mma_dtype=(ab_dtype, acc_dtype),
tmem_output_dtype=d_dtype,
)
# compile() pre-compiles the kernel for the given tensor shapes/types
compiled_dense_gemm = cute_ext.compile(dense_gemm, a_tensor, b_tensor, d_tensor)
# Execute the kernel (now fast - no recompilation)
compiled_dense_gemm(a_tensor, b_tensor, d_tensor)
# Validate correctness
if not skip_ref_check:
compare(a_torch_cpu, b_torch_cpu, d_torch_gpu, d_dtype, tolerance)
print("check reference: PASS")
# Tensor generator for benchmarking
def generate_tensors():
a_tensor, _ = cutlass_torch.cute_tensor_like(
a_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16
)
b_tensor, _ = cutlass_torch.cute_tensor_like(
b_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16
)
d_tensor, _ = cutlass_torch.cute_tensor_like(
d_torch_cpu, d_dtype, is_dynamic_layout=True, assumed_align=16
)
return testing.JitArguments(a_tensor, b_tensor, d_tensor)
# For cold L2 benchmarking, we need enough tensor copies to flush the cache
workspace_count = 1
if use_cold_l2:
one_workspace_bytes = (
a_torch_cpu.numel() * a_torch_cpu.element_size()
+ b_torch_cpu.numel() * b_torch_cpu.element_size()
+ d_torch_cpu.numel() * d_torch_cpu.element_size()
)
workspace_count = testing.get_workspace_count(
one_workspace_bytes, warmup_iterations, iterations
)
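# Workspace size sketch for the default problem (256, 256, 512, 1), assuming 4-byte
# elements for the CPU reference tensors:
#   A: 256 * 512 * 4 B = 512 KiB,  B: 256 * 512 * 4 B = 512 KiB,  D: 256 * 256 * 4 B = 256 KiB
#   one_workspace_bytes ~= 1.25 MiB, so several workspaces are needed to exceed L2.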
# Run benchmark
exec_time = testing.benchmark(
compiled_dense_gemm,
workspace_generator=generate_tensors,
workspace_count=workspace_count,
warmup_iterations=warmup_iterations,
iterations=iterations,
)
return exec_time
# ====================================================================================================
# COMMAND-LINE INTERFACE
# ====================================================================================================
if __name__ == "__main__":
def parse_comma_separated_ints(s: str) -> Tuple[int, ...]:
try:
return tuple(int(x.strip()) for x in s.split(","))
except ValueError:
raise argparse.ArgumentTypeError(
"Invalid format. Expected comma-separated integers."
)
parser = argparse.ArgumentParser(description="Example of Dense GEMM on Blackwell.")
parser.add_argument(
"--mnkl",
type=parse_comma_separated_ints,
default=(256, 256, 512, 1),
help="mnkl dimensions (comma-separated)",
)
parser.add_argument(
"--mma_tiler_mn",
type=parse_comma_separated_ints,
default=(128, 128),
help="Mma tile shape (comma-separated)",
)
parser.add_argument(
"--cluster_shape_mn",
type=parse_comma_separated_ints,
default=(1, 1),
help="Cluster shape (comma-separated)",
)
parser.add_argument("--ab_dtype", type=cutlass.dtype, default=cutlass.Float32)
parser.add_argument("--d_dtype", type=cutlass.dtype, default=cutlass.Float32)
parser.add_argument("--acc_dtype", type=cutlass.dtype, default=cutlass.Float32)
parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k")
parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k")
parser.add_argument("--d_major", choices=["n", "m"], type=str, default="n")
parser.add_argument(
"--warmup_iterations", type=int, default=0, help="Warmup iterations"
)
parser.add_argument(
"--iterations", type=int, default=1, help="Number of iterations"
)
parser.add_argument("--use_cold_l2", action="store_true", help="Use cold L2")
parser.add_argument(
"--tolerance", type=float, default=1e-02, help="Tolerance for validation"
)
parser.add_argument(
"--skip_ref_check", action="store_true", help="Skip reference checking"
)
args = parser.parse_args()
if len(args.mnkl) != 4:
parser.error("--mnkl must contain exactly 4 values")
if len(args.mma_tiler_mn) != 2:
parser.error("--mma_tiler_mn must contain exactly 2 values")
exec_time = run(
args.mnkl,
args.mma_tiler_mn,
args.cluster_shape_mn,
args.ab_dtype,
args.d_dtype,
args.acc_dtype,
args.a_major,
args.b_major,
args.d_major,
args.warmup_iterations,
args.iterations,
args.use_cold_l2,
args.tolerance,
args.skip_ref_check,
)
print(f"Execution time: {exec_time} microseconds per iteration")