# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import argparse
from typing import Optional, Tuple, Type, Union
from functools import lru_cache
import cuda.bindings.driver as cuda
import cutlass
import cutlass.cute as cute
import cutlass.cute.testing as testing
import cutlass.utils as utils
from cutlass.utils import is_fp8_dtype, create_cute_tensor_for_fp8
import cutlass.pipeline as pipeline
from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait
from cutlass.cute.nvgpu import cpasync, tcgen05
"""
A high-performance persistent batched dense GEMM example for the NVIDIA Blackwell SM100 architecture
using CUTE DSL.
- Matrix A is MxKxL, L is batch dimension, A can be row-major("K") or column-major("M")
- Matrix B is NxKxL, L is batch dimension, B can be row-major("N") or column-major("K")
- Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M")
This GEMM kernel supports the following features:
- Utilizes Tensor Memory Access (TMA) for efficient memory operations
- Utilizes Blackwell's tcgen05.mma for matrix multiply-accumulate (MMA) operations (including 2cta mma instructions)
- Implements TMA multicast with cluster to reduce L2 memory traffic
- Supports persistent tile scheduling to better overlap memory load/store with MMA between tiles
- Supports warp specialization to avoid explicit pipelining between mainloop load and MMA
This GEMM works as follows:
1. DMA warp: Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations.
2. MMA warp: Perform matrix multiply-accumulate (MMA) operations using tcgen05.mma instruction.
3. EPILOGUE warp:
- Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld.
- Type convert C matrix to output type.
- Optionally store the C matrix from registers (RMEM) through shared memory (SMEM) to global memory (GMEM) with TMA operations,
or store the C matrix directly from registers (RMEM) to global memory (GMEM) without TMA operations.
- Optionally accept an elementwise lambda function epilogue_op to apply to the output tensor:
e.g., relu can set epilogue_op = lambda x: cute.where(x > 0, x, cute.full_like(x, 0))
SM100 tcgen05.mma instructions operate as follows:
- Read matrix A from SMEM
- Read matrix B from SMEM
- Write accumulator to TMEM
The accumulator in TMEM must then be loaded to registers before writing back to GMEM.
Input arguments to this example are the same as for dense_gemm.py.
.. code-block:: bash
python examples/blackwell/dense_gemm_persistent.py \
--ab_dtype Float16 --c_dtype Float16 --acc_dtype Float32 \
--mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \
--mnkl 8192,8192,8192,1 \
--use_tma_store --use_2cta_instrs
To collect performance with NCU profiler:
.. code-block:: bash
ncu python examples/blackwell/dense_gemm_persistent.py \
--ab_dtype Float16 --c_dtype Float16 --acc_dtype Float32 \
--mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \
--mnkl 8192,8192,8192,1 \
--use_tma_store --use_2cta_instrs \
--warmup_iterations 1 --iterations 10 --skip_ref_check
Constraints are the same as for dense_gemm.py:
* Supported input data types: fp16, bf16, tf32, int8, uint8, fp8 (e4m3fn, e5m2),
see the valid dtype combinations detailed in the PersistentDenseGemmKernel class documentation below
* A/B tensor must have the same data type
* Mma tiler M must be 64/128 (use_2cta_instrs=False) or 128/256 (use_2cta_instrs=True)
* Mma tiler N must be 32-256, step 32
* Cluster shape M/N must be positive and power of 2, total cluster size <= 16
* Cluster shape M must be multiple of 2 if use_2cta_instrs=True
* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned,
i.e., the number of elements is a multiple of 4, 8, and 16 for TFloat32,
Float16/BFloat16, and Int8/Uint8/Float8, respectively.
* OOB tiles are not allowed when TMA store is disabled
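A minimal sketch of programmatic use (an assumption-laden illustration, not part of the CLI above:
`a`, `b`, `c` are assumed to be already-created CuTe tensors with the (M,K,L), (N,K,L), (M,N,L)
layouts described above and `stream` a CUDA stream; see run() at the bottom of this file for a
complete, working setup):
.. code-block:: python
    gemm = PersistentDenseGemmKernel(
        acc_dtype=cutlass.Float32,
        use_2cta_instrs=True,
        mma_tiler_mn=(256, 128),
        cluster_shape_mn=(2, 1),
        use_tma_store=True,
    )
    # Query how many clusters of this size can be resident at once
    max_active_clusters = utils.HardwareInfo().get_max_active_clusters(2)
    # Optional elementwise epilogue, e.g. ReLU
    relu = lambda x: cute.where(x > 0, x, cute.full_like(x, 0))
    gemm(a, b, c, max_active_clusters, stream, epilogue_op=relu)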
"""
def _compute_stages(
tiled_mma: cute.TiledMma,
mma_tiler_mnk: Tuple[int, int, int],
a_dtype: Type[cutlass.Numeric],
b_dtype: Type[cutlass.Numeric],
c_dtype: Type[cutlass.Numeric],
smem_capacity: int,
occupancy: int,
use_tma_store: bool,
c_smem_layout: Union[cute.Layout, None],
) -> Tuple[int, int, int]:
"""Computes the number of stages for A/B/C operands based on heuristics.
:param tiled_mma: The tiled MMA object defining the core computation.
:type tiled_mma: cute.TiledMma
:param mma_tiler_mnk: The shape (M, N, K) of the MMA tiler.
:type mma_tiler_mnk: tuple[int, int, int]
:param a_dtype: Data type of operand A.
:type a_dtype: type[cutlass.Numeric]
:param b_dtype: Data type of operand B.
:type b_dtype: type[cutlass.Numeric]
:param c_dtype: Data type of operand C (output).
:type c_dtype: type[cutlass.Numeric]
:param smem_capacity: Total available shared memory capacity in bytes.
:type smem_capacity: int
:param occupancy: Target number of CTAs per SM (occupancy).
:type occupancy: int
:param use_tma_store: Whether TMA store is enabled.
:type use_tma_store: bool
:param c_smem_layout: Layout of C operand in shared memory, or None if not using TMA store.
:type c_smem_layout: Union[cute.Layout, None]
:return: A tuple containing the computed number of stages for:
(ACC stages, A/B operand stages, C stages)
:rtype: tuple[int, int, int]
"""
# Default ACC stages
num_acc_stage = 2
# Default C stages
num_c_stage = 2 if use_tma_store else 0
# Calculate the 1-stage smem layouts and per-stage sizes for A and B (the per-stage C size comes from c_smem_layout)
a_smem_layout_stage_one = utils.sm100.make_smem_layout_a(
tiled_mma, mma_tiler_mnk, a_dtype, 1
)
b_smem_layout_stage_one = utils.sm100.make_smem_layout_b(
tiled_mma, mma_tiler_mnk, b_dtype, 1
)
ab_bytes_per_stage = cute.size_in_bytes(
a_dtype, a_smem_layout_stage_one
) + cute.size_in_bytes(b_dtype, b_smem_layout_stage_one)
mbar_helpers_bytes = 1024
# C occupies smem only when TMA store is used (c_smem_layout is None otherwise)
c_bytes_per_stage = (
cute.size_in_bytes(c_dtype, c_smem_layout) if c_smem_layout is not None else 0
)
c_bytes = c_bytes_per_stage * num_c_stage
# Calculate A/B stages:
# Start with total smem per CTA (capacity / occupancy)
# Subtract reserved bytes and initial C stages bytes
# Divide remaining by bytes needed per A/B stage
num_ab_stage = (
smem_capacity // occupancy - (mbar_helpers_bytes + c_bytes)
) // ab_bytes_per_stage
# Refine epilogue stages:
# Calculate remaining smem after allocating for A/B stages and reserved bytes
# Add remaining unused smem to epilogue
if use_tma_store:
num_c_stage += (
smem_capacity
- occupancy * ab_bytes_per_stage * num_ab_stage
- occupancy * (mbar_helpers_bytes + c_bytes)
) // (occupancy * c_bytes_per_stage)
return num_acc_stage, num_ab_stage, num_c_stage
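# Illustrative example of the heuristic above (numbers are assumptions, not queried
# from hardware): with Float16 A/B, mma_tiler_mnk = (128, 128, 64), occupancy = 1 and
# use_tma_store = False:
#   one A stage = 128 * 64 * 2 B = 16 KiB
#   one B stage = 128 * 64 * 2 B = 16 KiB  -> ab_bytes_per_stage = 32 KiB
#   c_bytes = 0, mbar_helpers_bytes = 1 KiB
# Assuming roughly 224 KiB of usable shared memory per CTA, the mainloop gets
#   num_ab_stage = (224 KiB - 1 KiB) // 32 KiB = 6
# A/B tile buffers in flight; any leftover smem is handed to extra C stages when
# TMA store is enabled.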
class PersistentDenseGemmKernel:
"""This class implements batched matrix multiplication (C = A x B) with support for various data types
and architectural features specific to Blackwell GPUs with persistent tile scheduling and warp specialization.
:param acc_dtype: Data type for accumulation during computation
:type acc_dtype: type[cutlass.Numeric]
:param use_2cta_instrs: Whether to use CTA group 2 for advanced thread cooperation
:type use_2cta_instrs: bool
:param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tile (M,N)
:type mma_tiler_mn: Tuple[int, int]
:param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing
:type cluster_shape_mn: Tuple[int, int]
:param use_tma_store: Whether to use Tensor Memory Access (TMA) for storing results
:type use_tma_store: bool
:note: In the current version, A and B tensors must have the same data type
- i.e., Float8E4M3FN for A and Float8E5M2 for B is not supported
:note: Supported A/B data types:
- TFloat32
- Float16/BFloat16
- Int8/Uint8
- Float8E4M3FN/Float8E5M2
:note: Supported accumulator data types:
- Float32 (for all floating point A/B data types)
- Float16 (only for fp16 and fp8 A/B data types)
- Int32 (only for uint8/int8 A/B data types)
:note: Supported C data types:
- Float32 (for float32 and int32 accumulator data types)
- Int32 (for float32 and int32 accumulator data types)
- Float16/BFloat16 (for float32, float16, and int32 accumulator data types)
- Int8/Uint8 (for float32 and int32 accumulator data types)
- Float8E4M3FN/Float8E5M2 (for float32 accumulator data types)
:note: Constraints:
- MMA tiler M must be 64/128 (use_2cta_instrs=False) or 128/256 (use_2cta_instrs=True)
- MMA tiler N must be 32-256, step 32
- Cluster shape M must be multiple of 2 if use_2cta_instrs=True
- Cluster shape M/N must be positive and power of 2, total cluster size <= 16
**Example:**
gemm = PersistentDenseGemmKernel(
acc_dtype=cutlass.Float32,
use_2cta_instrs=True,
mma_tiler_mn=(128, 128),
cluster_shape_mn=(2, 2)
)
gemm(a, b, c, max_active_clusters, stream)
"""
def __init__(
self,
acc_dtype: Type[cutlass.Numeric],
use_2cta_instrs: bool,
mma_tiler_mn: Tuple[int, int],
cluster_shape_mn: Tuple[int, int],
use_tma_store: bool,
):
"""Initializes the configuration for a Blackwell dense GEMM kernel.
This configuration includes several key aspects:
1. MMA Instruction Settings (tcgen05):
- acc_dtype: Data types for MMA accumulator.
- mma_tiler_mn: The (M, N) shape of the MMA instruction tiler.
- use_2cta_instrs: Boolean indicating if the tcgen05 MMA variant
with cta_group=2 should be used.
2. Cluster Shape:
- cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster.
3. Output C tensor store mode:
- use_tma_store: Boolean indicating whether to use Tensor Memory Access (TMA) for storing results.
:param acc_dtype: Data type of the accumulator.
:type acc_dtype: type[cutlass.Numeric]
:param mma_tiler_mn: Tuple (M, N) shape of the MMA instruction.
:type mma_tiler_mn: Tuple[int, int]
:param use_2cta_instrs: Boolean, True to use cta_group=2 MMA variant.
:type use_2cta_instrs: bool
:param cluster_shape_mn: Tuple (ClusterM, ClusterN) shape of the cluster.
:type cluster_shape_mn: Tuple[int, int]
:param use_tma_store: Use Tensor Memory Access (TMA) or normal store for output C tensor.
:type use_tma_store: bool
"""
self.acc_dtype: Type[cutlass.Numeric] = acc_dtype
self.use_2cta_instrs = use_2cta_instrs
self.cluster_shape_mn = cluster_shape_mn
# K dimension is deferred in _setup_attributes
self.mma_tiler_mn = mma_tiler_mn
self.mma_tiler = (*mma_tiler_mn, 1)
self.use_tma_store = use_tma_store
self.arch = "sm_100"
self.cta_group = (
tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
)
self.occupancy = 1
# Set specialized warp ids
self.epilogue_warp_id = (0, 1, 2, 3)
self.mma_warp_id = 4
self.tma_warp_id = 5
self.threads_per_cta = 32 * len(
(self.mma_warp_id, self.tma_warp_id, *self.epilogue_warp_id)
)
# Set barrier ids for epilogue sync and tensor memory alloc/dealloc sync
self.epilog_sync_bar_id = 1
self.tmem_alloc_sync_bar_id = 2
self.tmem_dealloc_sync_bar_id = 3
def _create_tiled_mma(self):
return utils.sm100.make_trivial_tiled_mma(
self.a_dtype,
self.a_major_mode,
self.b_major_mode,
self.acc_dtype,
self.cta_group,
self.mma_tiler[:2],
)
def _setup_attributes(self):
"""Set up configurations that are dependent on GEMM inputs
This method configures various attributes based on the input tensor properties
(data types, leading dimensions) and kernel settings:
- Configuring tiled MMA
- Computing MMA/cluster/tile shapes
- Computing cluster layout
- Computing multicast CTAs for A/B
- Computing epilogue subtile
- Setting up A/B/C stage counts in shared memory
- Computing A/B/C shared memory layout
- Computing tensor memory allocation columns
"""
# Configure tiled mma
tiled_mma = self._create_tiled_mma()
# Compute mma/cluster/tile shapes
mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2])
mma_inst_tile_k = 4
self.mma_tiler = (
self.mma_tiler[0],
self.mma_tiler[1],
mma_inst_shape_k * mma_inst_tile_k,
)
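# For example (illustrative): 16-bit inputs use a tcgen05 MMA instruction K extent of
# 16 elements, so the K tile becomes 16 * 4 = 64 elements per mainloop stage.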
self.cta_tile_shape_mnk = (
self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
self.mma_tiler[1],
self.mma_tiler[2],
)
# Compute cluster layout
self.cluster_layout_vmnk = cute.tiled_divide(
cute.make_layout((*self.cluster_shape_mn, 1)),
(tiled_mma.thr_id.shape,),
)
# Compute number of multicast CTAs for A/B
self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2])
self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1])
self.is_a_mcast = self.num_mcast_ctas_a > 1
self.is_b_mcast = self.num_mcast_ctas_b > 1
# Compute epilogue subtile
if cutlass.const_expr(self.use_tma_store):
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
self.cta_tile_shape_mnk,
self.use_2cta_instrs,
self.c_layout,
self.c_dtype,
)
else:
self.epi_tile = self.cta_tile_shape_mnk[:2]
c_smem_layout = None
if cutlass.const_expr(self.use_tma_store):
c_smem_layout = utils.sm100.make_smem_layout_epi(
self.c_dtype, self.c_layout, self.epi_tile, 1
)
self.smem_capacity = utils.get_smem_capacity_in_bytes()
# Setup A/B/C stage count in shared memory and ACC stage count in tensor memory
self.num_acc_stage, self.num_ab_stage, self.num_c_stage = _compute_stages(
tiled_mma,
self.mma_tiler,
self.a_dtype,
self.b_dtype,
self.c_dtype,
self.smem_capacity,
self.occupancy,
self.use_tma_store,
c_smem_layout,
)
# Compute A/B/C shared memory layout
self.a_smem_layout_staged = utils.sm100.make_smem_layout_a(
tiled_mma, self.mma_tiler, self.a_dtype, self.num_ab_stage
)
self.b_smem_layout_staged = utils.sm100.make_smem_layout_b(
tiled_mma, self.mma_tiler, self.b_dtype, self.num_ab_stage
)
self.c_smem_layout_staged = None
if self.use_tma_store:
self.c_smem_layout_staged = utils.sm100.make_smem_layout_epi(
self.c_dtype, self.c_layout, self.epi_tile, self.num_c_stage
)
# Compute the number of tensor memory allocation columns
self.num_tmem_alloc_cols = self._compute_num_tmem_alloc_cols(
tiled_mma, self.mma_tiler, self.num_acc_stage, self.arch
)
@cute.jit
def __call__(
self,
a: cute.Tensor,
b: cute.Tensor,
c: cute.Tensor,
max_active_clusters: cutlass.Constexpr,
stream: cuda.CUstream,
epilogue_op: cutlass.Constexpr = lambda x: x,
):
"""Execute the GEMM operation in steps:
- Setup static attributes before smem/grid/tma computation
- Setup TMA load/store atoms and tensors
- Compute grid size with regard to hardware constraints
- Define shared storage for kernel
- Launch the kernel synchronously
:param a: Input tensor A
:type a: cute.Tensor
:param b: Input tensor B
:type b: cute.Tensor
:param c: Output tensor C
:type c: cute.Tensor
:param max_active_clusters: Maximum number of active clusters
:type max_active_clusters: cutlass.Constexpr
:param stream: CUDA stream for asynchronous execution
:type stream: cuda.CUstream
:param epilogue_op: Optional elementwise lambda function to apply to the output tensor
:type epilogue_op: cutlass.Constexpr
:raises TypeError: If input data types are incompatible with the MMA instruction.
:raises AssertionError: If OOB (Out-Of-Bounds) tiles are present when TMA store is disabled.
"""
# Setup static attributes before smem/grid/tma computation
self.a_dtype: Type[cutlass.Numeric] = a.element_type
self.b_dtype: Type[cutlass.Numeric] = b.element_type
self.c_dtype: Type[cutlass.Numeric] = c.element_type
self.a_major_mode = utils.LayoutEnum.from_tensor(a).mma_major_mode()
self.b_major_mode = utils.LayoutEnum.from_tensor(b).mma_major_mode()
self.c_layout = utils.LayoutEnum.from_tensor(c)
# Check if input data types are compatible with MMA instruction
if cutlass.const_expr(self.a_dtype != self.b_dtype):
raise TypeError(f"Type must match: {self.a_dtype} != {self.b_dtype}")
tiled_mma = self._create_tiled_mma()
# Set up attributes that depend on GEMM inputs
self._setup_attributes()
atom_thr_size = cute.size(tiled_mma.thr_id.shape)
# Setup TMA load for A
a_op = utils.sm100.cluster_shape_to_tma_atom_A(
self.cluster_shape_mn, tiled_mma.thr_id
)
a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
a_op,
a,
a_smem_layout,
self.mma_tiler,
tiled_mma,
self.cluster_layout_vmnk.shape,
internal_type=(
cutlass.TFloat32 if a.element_type is cutlass.Float32 else None
),
)
# Setup TMA load for B
b_op = utils.sm100.cluster_shape_to_tma_atom_B(
self.cluster_shape_mn, tiled_mma.thr_id
)
b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
b_op,
b,
b_smem_layout,
self.mma_tiler,
tiled_mma,
self.cluster_layout_vmnk.shape,
internal_type=(
cutlass.TFloat32 if b.element_type is cutlass.Float32 else None
),
)
a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout)
b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout)
self.num_tma_load_bytes = (a_copy_size + b_copy_size) * atom_thr_size
# Setup TMA store for C
tma_atom_c = None
tma_tensor_c = None
if cutlass.const_expr(self.use_tma_store):
epi_smem_layout = cute.select(self.c_smem_layout_staged, mode=[0, 1])
tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom(
cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem_layout, self.epi_tile
)
# Compute grid size
self.tile_sched_params, grid = self._compute_grid(
c, self.cta_tile_shape_mnk, self.cluster_shape_mn, max_active_clusters
)
# Launch the kernel synchronously
self.kernel(
tiled_mma,
tma_atom_a,
tma_tensor_a,
tma_atom_b,
tma_tensor_b,
tma_atom_c,
tma_tensor_c if self.use_tma_store else c,
self.cluster_layout_vmnk,
self.a_smem_layout_staged,
self.b_smem_layout_staged,
self.c_smem_layout_staged,
self.epi_tile,
self.tile_sched_params,
epilogue_op,
).launch(
grid=grid,
block=[self.threads_per_cta, 1, 1],
cluster=(*self.cluster_shape_mn, 1),
stream=stream,
)
return
# GPU device kernel
@cute.kernel
def kernel(
self,
tiled_mma: cute.TiledMma,
tma_atom_a: cute.CopyAtom,
mA_mkl: cute.Tensor,
tma_atom_b: cute.CopyAtom,
mB_nkl: cute.Tensor,
tma_atom_c: Optional[cute.CopyAtom],
mC_mnl: cute.Tensor,
cluster_layout_vmnk: cute.Layout,
a_smem_layout_staged: cute.ComposedLayout,
b_smem_layout_staged: cute.ComposedLayout,
c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None],
epi_tile: cute.Tile,
tile_sched_params: utils.PersistentTileSchedulerParams,
epilogue_op: cutlass.Constexpr,
):
"""
GPU device kernel performing the Persistent batched GEMM computation.
"""
warp_idx = cute.arch.warp_idx()
warp_idx = cute.arch.make_warp_uniform(warp_idx)
#
# Prefetch tma desc
#
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_atom_a)
cpasync.prefetch_descriptor(tma_atom_b)
if cutlass.const_expr(self.use_tma_store):
cpasync.prefetch_descriptor(tma_atom_c)
use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2
#
# Setup cta/thread coordinates
#
# Coords inside cluster
bidx, bidy, bidz = cute.arch.block_idx()
mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape)
is_leader_cta = mma_tile_coord_v == 0
cta_rank_in_cluster = cute.arch.make_warp_uniform(
cute.arch.block_idx_in_cluster()
)
block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(
cta_rank_in_cluster
)
# Coord inside cta
tidx, _, _ = cute.arch.thread_idx()
#
# Alloc and init: a+b full/empty, accumulator full/empty, tensor memory dealloc barrier
#
# Define shared storage for kernel
@cute.struct
class SharedStorage:
ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
acc_full_mbar_ptr: cute.struct.MemRange[
cutlass.Int64, self.num_acc_stage * 2
]
tmem_dealloc_mbar_ptr: cutlass.Int64
tmem_holding_buf: cutlass.Int32
smem = utils.SmemAllocator()
storage = smem.allocate(SharedStorage)
# Initialize mainloop ab_pipeline (barrier) and states
ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
ab_pipeline_consumer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread, num_tma_producer
)
ab_producer, ab_consumer = pipeline.PipelineTmaUmma.create(
barrier_storage=storage.ab_full_mbar_ptr.data_ptr(),
num_stages=self.num_ab_stage,
producer_group=ab_pipeline_producer_group,
consumer_group=ab_pipeline_consumer_group,
tx_count=self.num_tma_load_bytes,
cta_layout_vmnk=cluster_layout_vmnk,
defer_sync=True,
).make_participants()
# Initialize acc_pipeline (barrier) and states
acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
num_acc_consumer_threads = len(self.epilogue_warp_id) * (
2 if use_2cta_instrs else 1
)
acc_pipeline_consumer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread, num_acc_consumer_threads
)
acc_pipeline = pipeline.PipelineUmmaAsync.create(
barrier_storage=storage.acc_full_mbar_ptr.data_ptr(),
num_stages=self.num_acc_stage,
producer_group=acc_pipeline_producer_group,
consumer_group=acc_pipeline_consumer_group,
cta_layout_vmnk=cluster_layout_vmnk,
defer_sync=True,
)
tmem_alloc_barrier = pipeline.NamedBarrier(
barrier_id=self.tmem_alloc_sync_bar_id,
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)),
)
tmem_dealloc_barrier = None
if cutlass.const_expr(not self.use_tma_store):
tmem_dealloc_barrier = pipeline.NamedBarrier(
barrier_id=self.tmem_dealloc_sync_bar_id,
num_threads=32 * len(self.epilogue_warp_id),
)
# Tensor memory allocator (also manages the dealloc barrier)
tmem = utils.TmemAllocator(
storage.tmem_holding_buf,
barrier_for_retrieve=tmem_alloc_barrier,
allocator_warp_id=self.epilogue_warp_id[0],
is_two_cta=use_2cta_instrs,
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr,
)
# Cluster arrive after barrier init
pipeline_init_arrive(cluster_shape_mn=cluster_layout_vmnk, is_relaxed=True)
#
# Setup smem tensor A/B/C
#
# (MMA, MMA_M, MMA_K, STAGE)
sA = smem.allocate_tensor(
element_type=self.a_dtype,
layout=a_smem_layout_staged.outer,
byte_alignment=128,
swizzle=a_smem_layout_staged.inner,
)
# (MMA, MMA_N, MMA_K, STAGE)
sB = smem.allocate_tensor(
element_type=self.b_dtype,
layout=b_smem_layout_staged.outer,
byte_alignment=128,
swizzle=b_smem_layout_staged.inner,
)
#
# Compute multicast mask for A/B buffer full
#
a_full_mcast_mask = None
b_full_mcast_mask = None
if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs):
a_full_mcast_mask = cpasync.create_tma_multicast_mask(
cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2
)
b_full_mcast_mask = cpasync.create_tma_multicast_mask(
cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1
)
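# CTAs that differ only in the cluster-N coordinate consume the same A tile, so the
# A-buffer-full arrival is multicast along mode 2 (N) of the VMNK cluster layout;
# B is handled symmetrically along mode 1 (M).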
#
# Local_tile partition global tensors
#
# (bM, bK, RestM, RestK, RestL)
gA_mkl = cute.local_tile(
mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None)
)
# (bN, bK, RestN, RestK, RestL)
gB_nkl = cute.local_tile(
mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None)
)
# (bM, bN, RestM, RestN, RestL)
gC_mnl = cute.local_tile(
mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None)
)
k_tile_cnt = cute.size(gA_mkl, mode=[3])
#
# Partition global tensor for TiledMMA_A/B/C
#
thr_mma = tiled_mma.get_slice(mma_tile_coord_v)
# (MMA, MMA_M, MMA_K, RestM, RestK, RestL)
tCgA = thr_mma.partition_A(gA_mkl)
# (MMA, MMA_N, MMA_K, RestN, RestK, RestL)
tCgB = thr_mma.partition_B(gB_nkl)
# (MMA, MMA_M, MMA_N, RestM, RestN, RestL)
tCgC = thr_mma.partition_C(gC_mnl)
#
# Partition global/shared tensor for TMA load A/B
#
# TMA load A partition_S/D
a_cta_layout = cute.make_layout(
cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape
)
# ((atom_v, rest_v), STAGE)
# ((atom_v, rest_v), RestM, RestK, RestL)
tAsA, tAgA = cpasync.tma_partition(
tma_atom_a,
block_in_cluster_coord_vmnk[2],
a_cta_layout,
cute.group_modes(sA, 0, 3),
cute.group_modes(tCgA, 0, 3),
)
# TMA load B partition_S/D
b_cta_layout = cute.make_layout(
cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape
)
# ((atom_v, rest_v), STAGE)
# ((atom_v, rest_v), RestM, RestK, RestL)
tBsB, tBgB = cpasync.tma_partition(
tma_atom_b,
block_in_cluster_coord_vmnk[1],
b_cta_layout,
cute.group_modes(sB, 0, 3),
cute.group_modes(tCgB, 0, 3),
)
#
# Partition shared/tensor memory tensor for TiledMMA_A/B/C
#
# (MMA, MMA_M, MMA_K, STAGE)
tCrA = tiled_mma.make_fragment_A(sA)
# (MMA, MMA_N, MMA_K, STAGE)
tCrB = tiled_mma.make_fragment_B(sB)
# (MMA, MMA_M, MMA_N)
acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
# (MMA, MMA_M, MMA_N, STAGE)
tCtAcc_fake = tiled_mma.make_fragment_C(
cute.append(acc_shape, self.num_acc_stage)
)
#
# Cluster wait before tensor memory alloc
#
pipeline_init_wait(cluster_shape_mn=cluster_layout_vmnk)
#
# Construct the scheduler
#
tile_sched = utils.StaticPersistentTileScheduler.create(
tile_sched_params,
cute.arch.block_idx(),
cute.arch.grid_dim(),
)
work_tile = tile_sched.initial_work_tile_info()
#
# Specialized TMA load warp
#
if warp_idx == self.tma_warp_id:
#
# Persistent tile scheduling loop
#
while work_tile.is_valid_tile:
# Get tile coord from tile scheduler
cur_tile_coord = work_tile.tile_idx
mma_tile_coord_mnl = (
cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape),
cur_tile_coord[1],
cur_tile_coord[2],
)
#
# Slice to per mma tile index
#
# ((atom_v, rest_v), RestK)
tAgA_slice = tAgA[
(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])
]
# ((atom_v, rest_v), RestK)
tBgB_slice = tBgB[
(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])
]
# Peek (try_wait) AB buffer empty for k_tile = 0
ab_producer.reset()
peek_ab_empty_status = ab_producer.try_acquire()
#
# Tma load loop
#
for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):
# Conditionally wait for AB buffer empty
handle = ab_producer.acquire_and_advance(peek_ab_empty_status)
# TMA load A/B
cute.copy(
tma_atom_a,
tAgA_slice[(None, handle.count)],
tAsA[(None, handle.index)],
tma_bar_ptr=handle.barrier,
mcast_mask=a_full_mcast_mask,
)
cute.copy(
tma_atom_b,
tBgB_slice[(None, handle.count)],
tBsB[(None, handle.index)],
tma_bar_ptr=handle.barrier,
mcast_mask=b_full_mcast_mask,
)
# Peek (try_wait) AB buffer empty for k_tile = k_tile + 1
peek_ab_empty_status = cutlass.Boolean(1)
if handle.count + 1 < k_tile_cnt:
peek_ab_empty_status = ab_producer.try_acquire()
#
# Advance to next tile
#
tile_sched.advance_to_next_work()
work_tile = tile_sched.get_current_work()
#
# Wait A/B buffer empty
#
ab_producer.tail()
#
# Specialized MMA warp
#
if warp_idx == self.mma_warp_id:
#
# Retrieve tensor memory ptr and make accumulator tensor
#
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
# (MMA, MMA_M, MMA_N, STAGE)
tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
#
# Persistent tile scheduling loop
#
acc_producer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Producer, self.num_acc_stage
)
while work_tile.is_valid_tile:
# Get tile coord from tile scheduler
cur_tile_coord = work_tile.tile_idx
mma_tile_coord_mnl = (
cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape),
cur_tile_coord[1],
cur_tile_coord[2],
)
# Set tensor memory buffer for current tile
# (MMA, MMA_M, MMA_N)
tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)]
# Peek (try_wait) AB buffer full for k_tile = 0
ab_consumer.reset()
peek_ab_full_status = cutlass.Boolean(1)
if is_leader_cta:
peek_ab_full_status = ab_consumer.try_wait()
#
# Wait for accumulator buffer empty
#
if is_leader_cta:
acc_pipeline.producer_acquire(acc_producer_state)
#
# Reset the ACCUMULATE field for each tile
#
tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
#
# Mma mainloop
#
for k_tile in range(k_tile_cnt):
if is_leader_cta:
# Conditionally wait for AB buffer full
handle = ab_consumer.wait_and_advance(peek_ab_full_status)
# tCtAcc += tCrA * tCrB
num_kblocks = cute.size(tCrA, mode=[2])
for kblk_idx in cutlass.range(num_kblocks, unroll_full=True):
kblk_crd = (None, None, kblk_idx, handle.index)
cute.gemm(
tiled_mma,
tCtAcc,
tCrA[kblk_crd],
tCrB[kblk_crd],
tCtAcc,
)
# Enable accumulate on tCtAcc after first kblock
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
# Async arrive AB buffer empty
handle.release()
# Peek (try_wait) AB buffer full for k_tile = k_tile + 1
peek_ab_full_status = cutlass.Boolean(1)
if handle.count + 1 < k_tile_cnt:
peek_ab_full_status = ab_consumer.try_wait()
#
# Async arrive accumulator buffer full
#
if is_leader_cta:
acc_pipeline.producer_commit(acc_producer_state)
acc_producer_state.advance()
#
# Advance to next tile
#
tile_sched.advance_to_next_work()
work_tile = tile_sched.get_current_work()
#
# Wait for accumulator buffer empty
#
acc_pipeline.producer_tail(acc_producer_state)
sC = None
if cutlass.const_expr(self.use_tma_store):
# (EPI_TILE_M, EPI_TILE_N, STAGE)
sC = smem.allocate_tensor(
element_type=self.c_dtype,
layout=c_smem_layout_staged.outer,
byte_alignment=128,
swizzle=c_smem_layout_staged.inner,
)
#
# Specialized epilogue warps
#
if warp_idx < self.mma_warp_id:
#
# Alloc tensor memory buffer
#
tmem.allocate(self.num_tmem_alloc_cols)
#
# Retrieve tensor memory ptr and make accumulator tensor
#
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
# (MMA, MMA_M, MMA_N, STAGE)
tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
#
# Persistent tile scheduling loop for epilogue
#
acc_consumer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_acc_stage
)
if cutlass.const_expr(self.use_tma_store):
assert tma_atom_c is not None and sC is not None
c_producer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread,
32 * len(self.epilogue_warp_id),
)
c_pipeline = pipeline.PipelineTmaStore.create(
num_stages=self.num_c_stage, producer_group=c_producer_group
)
while work_tile.is_valid_tile:
# Get tile coord from tile scheduler
cur_tile_coord = work_tile.tile_idx
mma_tile_coord_mnl = (
cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape),
cur_tile_coord[1],
cur_tile_coord[2],
)
#
# Pre-advance to next tile
#
tile_sched.advance_to_next_work()
work_tile = tile_sched.get_current_work()
num_tiles_executed = tile_sched.num_tiles_executed
if cutlass.const_expr(self.use_tma_store):
acc_consumer_state = utils.gemm.sm100.epilogue_tma_store(
self,
tidx,
warp_idx,
tma_atom_c,
tCtAcc_base,
sC,
tCgC,
epi_tile,
num_tiles_executed,
epilogue_op,
mma_tile_coord_mnl,
acc_consumer_state,
acc_pipeline,
c_pipeline,
)
else:
acc_consumer_state = utils.gemm.sm100.epilogue(
self,
tidx,
tCtAcc_base,
tCgC,
epi_tile,
epilogue_op,
mma_tile_coord_mnl,
acc_consumer_state,
acc_pipeline,
)
if cutlass.const_expr(self.use_tma_store):
# Wait for C store complete
c_pipeline.producer_tail()
else:
# Synchronize before TMEM dealloc (done by the caller)
tmem_dealloc_barrier.arrive_and_wait()
#
# Dealloc the tensor memory buffer
#
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
@staticmethod
def _compute_grid(
c: cute.Tensor,
cta_tile_shape_mnk: Tuple[int, int, int],
cluster_shape_mn: Tuple[int, int],
max_active_clusters: cutlass.Constexpr,
) -> Tuple[utils.PersistentTileSchedulerParams, Tuple[int, int, int]]:
"""Use persistent tile scheduler to compute the grid size for the output tensor C.
:param c: The output tensor C
:type c: cute.Tensor
:param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile.
:type cta_tile_shape_mnk: tuple[int, int, int]
:param cluster_shape_mn: Shape of each cluster in M, N dimensions.
:type cluster_shape_mn: tuple[int, int]
:param max_active_clusters: Maximum number of active clusters.
:type max_active_clusters: cutlass.Constexpr
:return: A tuple containing:
- tile_sched_params: Parameters for the persistent tile scheduler.
- grid: Grid shape for kernel launch.
:rtype: Tuple[utils.PersistentTileSchedulerParams, tuple[int, int, int]]
"""
c_shape = cute.slice_(cta_tile_shape_mnk, (None, None, 0))
gc = cute.zipped_divide(c, tiler=c_shape)
num_ctas_mnl = gc[(0, (None, None, None))].shape
cluster_shape_mnl = (*cluster_shape_mn, 1)
tile_sched_params = utils.PersistentTileSchedulerParams(
num_ctas_mnl, cluster_shape_mnl
)
grid = utils.StaticPersistentTileScheduler.get_grid_shape(
tile_sched_params, max_active_clusters
)
return tile_sched_params, grid
@staticmethod
def _compute_num_tmem_alloc_cols(
tiled_mma: cute.TiledMma,
mma_tiler: Tuple[int, int, int],
num_acc_stage: int,
arch: str,
) -> int:
"""
Compute the number of tensor memory allocation columns.
:param tiled_mma: The tiled MMA object defining the core computation.
:type tiled_mma: cute.TiledMma
:param mma_tiler: The shape (M, N, K) of the MMA tile.
:type mma_tiler: tuple[int, int, int]
:param num_acc_stage: The number of accumulator stages.
:type num_acc_stage: int
:param arch: The target architecture string (e.g., "sm_100").
:type arch: str
:return: The number of tensor memory allocation columns.
:rtype: int
"""
acc_shape = tiled_mma.partition_shape_C(mma_tiler[:2])
tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, num_acc_stage))
num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols(tCtAcc_fake, arch=arch)
return num_tmem_alloc_cols
def check_supported_dtypes(
self,
a_dtype: Type[cutlass.Numeric],
b_dtype: Type[cutlass.Numeric],
c_dtype: Type[cutlass.Numeric],
):
"""
Check if the dtypes are valid
:param a_dtype: The data type of the A operands
:type a_dtype: Type[cutlass.Numeric]
:param b_dtype: The data type of the B operands
:type b_dtype: Type[cutlass.Numeric]
:param c_dtype: The data type of the output tensor
:type c_dtype: Type[cutlass.Numeric]
:raises testing.CantImplementError: If the dtypes are invalid
"""
valid_ab_dtypes = {
cutlass.Float16,
cutlass.BFloat16,
cutlass.TFloat32,
cutlass.Uint8,
cutlass.Int8,
cutlass.Float8E4M3FN,
cutlass.Float8E5M2,
}
if a_dtype not in valid_ab_dtypes or b_dtype not in valid_ab_dtypes:
raise testing.CantImplementError(
f"Unsupported AB dtype: {a_dtype} and {b_dtype}"
)
if self.acc_dtype not in {cutlass.Float32, cutlass.Float16, cutlass.Int32}:
raise testing.CantImplementError(
f"Unsupported accumulator dtype: {self.acc_dtype}"
)
# Define compatibility mapping between accumulator type and AB type
acc_ab_compatibility = {
cutlass.Float32: {
cutlass.Float16,
cutlass.BFloat16,
cutlass.TFloat32,
cutlass.Float8E4M3FN,
cutlass.Float8E5M2,
}, # Float32 accumulator supports floating point AB types only
cutlass.Float16: {
cutlass.Float16,
cutlass.Float8E4M3FN,
cutlass.Float8E5M2,
},
cutlass.Int32: {cutlass.Uint8, cutlass.Int8},
}
# Check compatibility between accumulator type and AB type
if (
a_dtype not in acc_ab_compatibility[self.acc_dtype]
or b_dtype not in acc_ab_compatibility[self.acc_dtype]
):
raise testing.CantImplementError(
f"Unsupported AB dtype: {a_dtype} and {b_dtype} for accumulator dtype: {self.acc_dtype}"
)
# Define compatibility mapping between accumulator type and C type
acc_c_compatibility = {
cutlass.Float32: {
cutlass.Float32,
cutlass.Float16,
cutlass.BFloat16,
cutlass.Float8E4M3FN,
cutlass.Float8E5M2,
cutlass.Int32,
cutlass.Int8,
cutlass.Uint8,
},
cutlass.Float16: {
cutlass.BFloat16,
cutlass.Float16,
},
cutlass.Int32: {
cutlass.BFloat16,
cutlass.Float16,
cutlass.Float32,
cutlass.Int32,
cutlass.Int8,
cutlass.Uint8,
},
}
# Check compatibility between accumulator type and C type
if c_dtype not in acc_c_compatibility[self.acc_dtype]:
raise testing.CantImplementError(
f"Unsupported C dtype: {c_dtype} for accumulator dtype: {self.acc_dtype}"
)
def check_mma_tiler_and_cluster_shape(self):
"""Check if the mma tiler and cluster shape are valid.
:raises testing.CantImplementError: If the mma tiler and cluster shape are invalid
"""
# Skip invalid mma tile shape
if not (
(not self.use_2cta_instrs and self.mma_tiler_mn[0] in [64, 128])
or (self.use_2cta_instrs and self.mma_tiler_mn[0] in [128, 256])
):
raise testing.CantImplementError(
f"Invalid mma tiler & use_2cta_instrs: {self.mma_tiler_mn}, {self.use_2cta_instrs}"
)
if self.mma_tiler_mn[1] not in range(32, 257, 32):
raise testing.CantImplementError(
f"Invalid mma tiler N: {self.mma_tiler_mn[1]}"
)
# Skip illegal cluster shape
if self.cluster_shape_mn[0] % (2 if self.use_2cta_instrs else 1) != 0:
raise testing.CantImplementError(
f"Invalid cluster shape M: {self.cluster_shape_mn[0]}"
)
# Skip invalid cluster shape
is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0
if (
self.cluster_shape_mn[0] * self.cluster_shape_mn[1] > 16
or self.cluster_shape_mn[0] <= 0
or self.cluster_shape_mn[1] <= 0
or not is_power_of_2(self.cluster_shape_mn[0])
or not is_power_of_2(self.cluster_shape_mn[1])
):
raise testing.CantImplementError(
f"Invalid cluster shape: {self.cluster_shape_mn}"
)
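# Examples: mma_tiler_mn = (128, 48) is rejected (N must be a multiple of 32);
# cluster_shape_mn = (3, 1) is rejected (extents must be powers of two);
# cluster_shape_mn = (1, 1) with use_2cta_instrs = True is rejected (cluster M must be even).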
def check_tensor_alignment(
self,
m: int,
n: int,
k: int,
l: int,
a_dtype: Type[cutlass.Numeric],
b_dtype: Type[cutlass.Numeric],
c_dtype: Type[cutlass.Numeric],
a_major: str,
b_major: str,
c_major: str,
):
"""
Check if the tensor alignment is valid
:param m: The number of rows in the A tensor
:type m: int
:param n: The number of rows in the B tensor (columns of C)
:type n: int
:param k: The number of columns in the A tensor
:type k: int
:param l: The batch dimension
:type l: int
:param a_dtype: The data type of the A operands
:type a_dtype: Type[cutlass.Numeric]
:param b_dtype: The data type of the B operands
:type b_dtype: Type[cutlass.Numeric]
:param c_dtype: The data type of the output tensor
:type c_dtype: Type[cutlass.Numeric]
:param a_major: The major axis of the A tensor
:type a_major: str
:param b_major: The major axis of the B tensor
:type b_major: str
:param c_major: The major axis of the C tensor
:type c_major: str
:raises testing.CantImplementError: If the tensor alignment is invalid
"""
# TODO: move to utils
def check_contiguous_16B_alignment(dtype, is_mode0_major, tensor_shape):
major_mode_idx = 0 if is_mode0_major else 1
num_major_elements = tensor_shape[major_mode_idx]
num_contiguous_elements = 16 * 8 // dtype.width
return num_major_elements % num_contiguous_elements == 0
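# Worked example of the 16-byte rule: 16 B equals 16 * 8 / width elements, i.e.
# 4 TFloat32, 8 Float16/BFloat16, or 16 Int8/Uint8/Float8 elements. For a K-major
# Float16 A tensor this means K must be a multiple of 8.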
if (
not check_contiguous_16B_alignment(a_dtype, a_major == "m", (m, k, l))
or not check_contiguous_16B_alignment(b_dtype, b_major == "n", (n, k, l))
or not check_contiguous_16B_alignment(c_dtype, c_major == "m", (m, n, l))
):
raise testing.CantImplementError(
f"Invalid tensor alignment: {m}, {n}, {k}, {l}, {a_dtype}, {b_dtype}, {c_dtype}, {a_major}, {b_major}, {c_major}"
)
def check_epilog_store_option(self, m: int, n: int):
"""
Check if the epilogue store option is valid
:param m: The number of rows in the A tensor
:type m: int
:param n: The number of columns in the C tensor
:type n: int
:raises testing.CantImplementError: If the epilogue store option is invalid
"""
# The non-TMA store path has no predication, so OOB tiles cannot be supported
cta_tile_shape_mn = (
self.mma_tiler_mn[0] // (2 if self.use_2cta_instrs else 1),
self.mma_tiler_mn[1],
)
if not self.use_tma_store:
if not (m % cta_tile_shape_mn[0] == 0 and n % cta_tile_shape_mn[1] == 0):
raise testing.CantImplementError(
f"Invalid epilog store option: {m}, {n}"
)
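# Example: with mma_tiler_mn = (256, 128) and use_2cta_instrs = True the CTA tile is
# (128, 128), so when use_tma_store is False both M and N must be multiples of 128.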
def can_implement(
self,
mnkl: Tuple[int, int, int, int],
a_dtype: Type[cutlass.Numeric],
b_dtype: Type[cutlass.Numeric],
c_dtype: Type[cutlass.Numeric],
a_major: str,
b_major: str,
c_major: str,
) -> bool:
"""
Determine if the given tensor configuration can be implemented by this kernel.
:param mnkl: Problem size as a tuple (M, N, K, L).
:type mnkl: Tuple[int, int, int, int]
:param a_dtype: Data type for input tensors A.
:type a_dtype: Type[cutlass.Numeric]
:param b_dtype: Data type for input tensors B.
:type b_dtype: Type[cutlass.Numeric]
:param c_dtype: Data type for output tensor C.
:type c_dtype: Type[cutlass.Numeric]
:param a_major: Major dimension of the A tensor layout ("m" or "k").
:type a_major: str
:param b_major: Major dimension of the B tensor layout ("n" or "k").
:type b_major: str
:param c_major: Major dimension of the C tensor layout ("m" or "n").
:type c_major: str
:return: True if the kernel supports the given configuration, False otherwise.
:rtype: bool
"""
try:
# Skip unsupported types
self.check_supported_dtypes(a_dtype, b_dtype, c_dtype)
# Skip invalid mma tile shape and cluster shape
self.check_mma_tiler_and_cluster_shape()
m, n, k, l = mnkl
self.check_tensor_alignment(
m, n, k, l, a_dtype, b_dtype, c_dtype, a_major, b_major, c_major
)
self.check_epilog_store_option(m, n)
except testing.CantImplementError:
return False
return True
@cute.jit
def bmm(
gemm_op: cutlass.Constexpr,
a: cute.Tensor, # (l, m, k)
b: cute.Tensor, # (l, k, n)
c: cute.Tensor, # (l, m, n)
max_active_clusters: cutlass.Constexpr,
stream: cuda.CUstream,
epilogue_op: cutlass.Constexpr = lambda x: x,
):
"""
Wrapper API for persistent GEMM kernel to follow the convention of PyTorch's batch matrix-multiply (bmm).
Internally, the tensors are permuted to match CuTe's convention:
- a: (m, k, l)
- b: (n, k, l)
- c: (m, n, l)
:param gemm_op: Kernel operation, expects (a, b, c, max_active_clusters, stream, epilogue_op)
:type gemm_op: cutlass.Constexpr
:param a: Input tensor of shape (l, m, k)
:type a: cute.Tensor
:param b: Input tensor of shape (l, k, n)
:type b: cute.Tensor
:param c: Output tensor of shape (l, m, n)
:type c: cute.Tensor
:param max_active_clusters: Maximum number of hardware clusters to launch
:type max_active_clusters: cutlass.Constexpr
:param epilogue_op: Optional elementwise lambda function to apply per output element, defaults to identity
:type epilogue_op: cutlass.Constexpr, optional
"""
# (l,m,k) -> (m,k,l)
a = cute.make_tensor(a.iterator, cute.select(a.layout, mode=[1, 2, 0]))
# (l,k,n) -> (n,k,l)
b = cute.make_tensor(b.iterator, cute.select(b.layout, mode=[2, 1, 0]))
# (l,m,n) -> (m,n,l)
c = cute.make_tensor(c.iterator, cute.select(c.layout, mode=[1, 2, 0]))
gemm_op(a, b, c, max_active_clusters, stream, epilogue_op)
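# Illustrative note: a contiguous torch tensor of shape (l, m, k) has strides
# (m*k, k, 1); after cute.select(a.layout, mode=[1, 2, 0]) the CuTe view has shape
# (m, k, l) with strides (k, 1, m*k), i.e. a K-major ("k") layout with the batch
# stride moved to the last mode, matching the (M, K, L) convention used above.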
@lru_cache(maxsize=1)
def prepare_tensors(
mnkl: Tuple[int, int, int, int],
a_dtype: Type[cutlass.Numeric],
b_dtype: Type[cutlass.Numeric],
c_dtype: Type[cutlass.Numeric],
a_major: str,
b_major: str,
c_major: str,
init_random: bool = True,
normal_mean: float = 0.0,
normal_std: float = 1.0,
):
"""Prepare tensors for GEMM.
Returns:
Tuple of (a_f32, b_f32, c_f32, a_storage, b_storage, c_storage):
- *_f32: Float32 tensors with the logical data (for reference and fp8 conversion)
- *_storage: Storage tensors for DLPack (uint8 for fp8, otherwise the target dtype)
"""
import torch
from cutlass.torch import dtype as torch_dtype
m, n, k, l = mnkl
if a_major == "k":
a_f32 = torch.empty((l, m, k), dtype=torch.float32, device="cuda")
elif a_major == "m":
a_f32 = torch.empty((l, k, m), dtype=torch.float32, device="cuda").permute(
0, 2, 1
)
if b_major == "n":
b_f32 = torch.empty((l, k, n), dtype=torch.float32, device="cuda")
elif b_major == "k":
b_f32 = torch.empty((l, n, k), dtype=torch.float32, device="cuda").permute(
0, 2, 1
)
if c_major == "n":
c_f32 = torch.empty((l, m, n), dtype=torch.float32, device="cuda")
elif c_major == "m":
c_f32 = torch.empty((l, n, m), dtype=torch.float32, device="cuda").permute(
0, 2, 1
)
if init_random:
# Uniform random integer initialization in [-2, 3)
a_f32.random_(-2, 3)
b_f32.random_(-2, 3)
c_f32.random_(-2, 3)
else:
# Normal (Gaussian) initialization with user-specified mean and std
a_f32.normal_(mean=normal_mean, std=normal_std)
b_f32.normal_(mean=normal_mean, std=normal_std)
c_f32.normal_(mean=normal_mean, std=normal_std)
# For float8 types, use uint8 as storage type to avoid dlpack limitation
# (dlpack doesn't support float8 types)
# For other types, convert to the target dtype
a_storage_dtype = torch.uint8 if is_fp8_dtype(a_dtype) else torch_dtype(a_dtype)
b_storage_dtype = torch.uint8 if is_fp8_dtype(b_dtype) else torch_dtype(b_dtype)
c_storage_dtype = torch.uint8 if is_fp8_dtype(c_dtype) else torch_dtype(c_dtype)
a_storage = a_f32.to(dtype=a_storage_dtype)
b_storage = b_f32.to(dtype=b_storage_dtype)
c_storage = c_f32.to(dtype=c_storage_dtype)
return (a_f32, b_f32, c_f32, a_storage, b_storage, c_storage)
@lru_cache(maxsize=1)
def compile_bmm(
mnkl: Tuple[int, int, int, int],
a: cute.Tensor,
b: cute.Tensor,
c: cute.Tensor,
acc_dtype: Type[cutlass.Numeric],
a_major: str,
b_major: str,
c_major: str,
mma_tiler_mn: Tuple[int, int] = (256, 256),
cluster_shape_mn: Tuple[int, int] = (2, 1),
max_active_clusters: cutlass.Constexpr = None,
use_2cta_instrs: bool = True,
use_tma_store: bool = True,
epilogue_op: cutlass.Constexpr = lambda x: x,
):
from cutlass.cute.runtime import make_fake_stream
gemm = PersistentDenseGemmKernel(
acc_dtype,
use_2cta_instrs,
mma_tiler_mn,
cluster_shape_mn,
use_tma_store,
)
# Check if configuration can be implemented
can_implement = gemm.can_implement(
mnkl, a.element_type, b.element_type, c.element_type, a_major, b_major, c_major
)
if not can_implement:
raise testing.CantImplementError(
f"The current config which is invalid/unsupported: use_2cta_instrs = {use_2cta_instrs}, "
f"mma_tiler_mn = {mma_tiler_mn}, cluster_shape_mn = {cluster_shape_mn}, "
f"use_tma_store = {use_tma_store}"
)
stream = make_fake_stream()
return cute.compile(bmm, gemm, a, b, c, max_active_clusters, stream, epilogue_op)
def run(
mnkl: Tuple[int, int, int, int],
ab_dtype: Type[cutlass.Numeric],
c_dtype: Type[cutlass.Numeric],
acc_dtype: Type[cutlass.Numeric],
a_major: str,
b_major: str,
c_major: str,
mma_tiler_mn: Tuple[int, int] = (256, 256),
cluster_shape_mn: Tuple[int, int] = (2, 1),
use_2cta_instrs: bool = True,
use_tma_store: bool = True,
tolerance: float = 1e-01,
warmup_iterations: int = 0,
iterations: int = 1,
skip_ref_check: bool = False,
use_cold_l2: bool = False,
benchmark: bool = False,
**kwargs,
):
"""
Execute a persistent batched dense GEMM operation on Blackwell architecture with performance benchmarking.
Prepares input tensors, configures and launches the persistent GEMM kernel,
optionally performs reference validation, and benchmarks execution.
:param mnkl: Problem size as a tuple (M, N, K, L).
:type mnkl: Tuple[int, int, int, int]
:param ab_dtype: Data type for input tensors A and B.
:type ab_dtype: Type[cutlass.Numeric]
:param c_dtype: Data type for output tensor C.
:type c_dtype: Type[cutlass.Numeric]
:param acc_dtype: Accumulator data type for the matrix multiplication.
:type acc_dtype: Type[cutlass.Numeric]
:param a_major: Memory layout of tensor A.
:type a_major: str
:param b_major: Memory layout of tensor B.
:type b_major: str
:param c_major: Memory layout of tensor C.
:type c_major: str
:param mma_tiler_mn: MMA tiling size (M, N), defaults to (256, 256).
:type mma_tiler_mn: Tuple[int, int], optional
:param cluster_shape_mn: Cluster shape (M, N), defaults to (2, 1).
:type cluster_shape_mn: Tuple[int, int], optional
:param use_2cta_instrs: Whether to use 2CTA MMA instructions, defaults to True.
:type use_2cta_instrs: bool, optional
:param use_tma_store: Whether to use TMA store, defaults to True.
:type use_tma_store: bool, optional
:param tolerance: Tolerance for reference validation, defaults to 1e-01.
:type tolerance: float, optional
:param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0.
:type warmup_iterations: int, optional
:param iterations: Number of benchmark iterations to run, defaults to 1.
:type iterations: int, optional
:param skip_ref_check: Whether to skip reference result validation, defaults to False.
:type skip_ref_check: bool, optional
:param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False.
:type use_cold_l2: bool, optional
:param benchmark: Whether to only benchmark the kernel, defaults to False.
:type benchmark: bool, optional
:raises RuntimeError: If CUDA GPU is not available.
:raises ValueError: If the configuration is invalid or unsupported by the kernel.
:return: Average execution time of the GEMM kernel in microseconds (0 if benchmarking is skipped).
:rtype: float
"""
import torch
from cutlass.torch import dtype as torch_dtype
if not torch.cuda.is_available():
raise RuntimeError("GPU is required to run this example!")
# Get current CUDA stream from PyTorch
torch_stream = torch.cuda.current_stream()
# Get the raw stream pointer as a CUstream
current_stream = cuda.CUstream(torch_stream.cuda_stream)
# Query the maximum number of active clusters for the given cluster size
max_active_clusters = utils.HardwareInfo().get_max_active_clusters(
cluster_shape_mn[0] * cluster_shape_mn[1]
)
# Run and verify BMM with torch
a_f32, b_f32, c_f32, a_storage, b_storage, c_storage = prepare_tensors(
mnkl, ab_dtype, ab_dtype, c_dtype, a_major, b_major, c_major
)
leading_dim_a = 2 if a_major == "k" else 1
leading_dim_b = 1 if b_major == "k" else 2
leading_dim_c = 2 if c_major == "n" else 1
# Create CuTe tensors, passing float32 source for fp8 conversion
a_ = create_cute_tensor_for_fp8(
a_storage, ab_dtype, leading_dim_a, source_f32_tensor=a_f32
)
b_ = create_cute_tensor_for_fp8(
b_storage, ab_dtype, leading_dim_b, source_f32_tensor=b_f32
)
c_ = create_cute_tensor_for_fp8(
c_storage, c_dtype, leading_dim_c, source_f32_tensor=c_f32
)
compiled_fn = compile_bmm(
mnkl,
a_,
b_,
c_,
acc_dtype,
a_major,
b_major,
c_major,
mma_tiler_mn,
cluster_shape_mn,
max_active_clusters,
use_2cta_instrs,
use_tma_store,
epilogue_op=lambda x: x,
)
print("Running Blackwell Persistent Dense GEMM test with:")
print(f"mnkl: {mnkl}")
print(f"Tolerance: {tolerance}")
print(f"Warmup iterations: {warmup_iterations}")
print(f"Iterations: {iterations}")
print(f"Skip reference checking: {skip_ref_check}")
print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}")
if not skip_ref_check:
# Small random inputs (see prepare_tensors) keep the reference check deterministic
compiled_fn(a_, b_, c_, current_stream)
# Manually quantize to be comparable
# Use float32 source data for reference calculation
ref = (
torch.bmm(a_f32, b_f32)
.to(dtype=torch_dtype(c_dtype))
.to(dtype=torch.float32)
)
# Read back the result from CuTe tensor (c_storage was updated in-place)
torch.testing.assert_close(
c_storage.to(dtype=torch.float32), ref, atol=tolerance, rtol=1e-03
)
if not benchmark:
return 0
def generate_tensors():
a_f32, b_f32, c_f32, a_st, b_st, c_st = prepare_tensors(
mnkl,
ab_dtype,
ab_dtype,
c_dtype,
a_major,
b_major,
c_major,
)
a_ = create_cute_tensor_for_fp8(
a_st, ab_dtype, leading_dim_a, source_f32_tensor=a_f32
)
b_ = create_cute_tensor_for_fp8(
b_st, ab_dtype, leading_dim_b, source_f32_tensor=b_f32
)
c_ = create_cute_tensor_for_fp8(
c_st, c_dtype, leading_dim_c, source_f32_tensor=c_f32
)
return testing.JitArguments(a_, b_, c_, current_stream)
workspace_count = 1
if use_cold_l2:
one_workspace_bytes = (
a_storage.numel() * a_storage.element_size()
+ b_storage.numel() * b_storage.element_size()
+ c_storage.numel() * c_storage.element_size()
)
workspace_count = testing.get_workspace_count(
one_workspace_bytes, warmup_iterations, iterations
)
# Return execution time in microseconds
return testing.benchmark(
compiled_fn,
workspace_generator=generate_tensors,
workspace_count=workspace_count,
stream=current_stream,
warmup_iterations=warmup_iterations,
iterations=iterations,
)
def compute_tflops(time_ns, m, n, k):
return 2.0 * m * n * k / time_ns / 1000.0
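# Example: m = n = k = 8192 with time_ns = 1e6 (1 ms) gives
# 2 * 8192**3 / 1e6 / 1e3 ~= 1099.5 TFLOP/s (the batch dimension L is not counted here).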
def _parse_comma_separated_ints(s: str) -> Tuple[int, ...]:
try:
return tuple(int(x.strip()) for x in s.split(","))
except ValueError:
raise argparse.ArgumentTypeError(
"Invalid format. Expected comma-separated integers."
)
def prepare_parser():
parser = argparse.ArgumentParser(
description="Example of Dense Persistent GEMM on Blackwell."
)
parser.add_argument(
"--mnkl",
type=_parse_comma_separated_ints,
default=(256, 256, 512, 1),
help="mnkl dimensions (comma-separated)",
)
parser.add_argument(
"--cluster_shape_mn",
type=_parse_comma_separated_ints,
default=(1, 1),
help="Cluster shape (comma-separated)",
)
parser.add_argument("--ab_dtype", type=cutlass.dtype, default=cutlass.TFloat32)
parser.add_argument("--c_dtype", type=cutlass.dtype, default=cutlass.Float32)
parser.add_argument("--acc_dtype", type=cutlass.dtype, default=cutlass.Float32)
parser.add_argument(
"--use_2cta_instrs",
action="store_true",
help="Enable 2CTA MMA instructions feature",
)
parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k")
parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k")
parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n")
parser.add_argument(
"--use_tma_store", action="store_true", help="Use tma store or not"
)
parser.add_argument(
"--tolerance", type=float, default=1e-01, help="Tolerance for validation"
)
parser.add_argument(
"--benchmark",
type=str,
default="default",
choices=[
"default",
"none",
],
help="Benchmark the kernel with nsight or default (cute.testing.benchmark) or none",
)
parser.add_argument(
"--warmup_iterations", type=int, default=0, help="Warmup iterations"
)
parser.add_argument(
"--iterations",
type=int,
default=1,
help="Number of iterations to run the kernel",
)
parser.add_argument(
"--skip_ref_check", action="store_true", help="Skip reference checking"
)
parser.add_argument(
"--use_cold_l2",
action="store_true",
default=False,
help="Use circular buffer tensor sets to ensure L2 cold cache",
)
return parser
if __name__ == "__main__":
parser = prepare_parser()
parser.add_argument(
"--mma_tiler_mn",
type=_parse_comma_separated_ints,
default=(128, 128),
help="Mma tile shape (comma-separated)",
)
args = parser.parse_args()
if len(args.mnkl) != 4:
parser.error("--mnkl must contain exactly 4 values")
if len(args.mma_tiler_mn) != 2:
parser.error("--mma_tiler_mn must contain exactly 2 values")
if len(args.cluster_shape_mn) != 2:
parser.error("--cluster_shape_mn must contain exactly 2 values")
print(f"[DSL INFO] Compiling Blackwell Persistent Dense GEMM with:")
print(
f"[DSL INFO] A dtype: {args.ab_dtype}, B dtype: {args.c_dtype}, C dtype: {args.acc_dtype}, Acc dtype: {args.acc_dtype}"
)
print(
f"[DSL INFO] Matrix majors - A: {args.a_major}, B: {args.b_major}, C: {args.c_major}"
)
print(f"[DSL INFO] Mma Tiler (M, N): {args.mma_tiler_mn}")
print(f"[DSL INFO] Cluster Shape (M, N): {args.cluster_shape_mn}")
print(
f"[DSL INFO] 2CTA MMA instructions: {'True' if args.use_2cta_instrs else 'False'}"
)
print(f"[DSL INFO] Use TMA Store: {'True' if args.use_tma_store else 'False'}")
run(
args.mnkl,
args.ab_dtype,
args.c_dtype,
args.acc_dtype,
args.a_major,
args.b_major,
args.c_major,
args.mma_tiler_mn,
args.cluster_shape_mn,
args.use_2cta_instrs,
args.use_tma_store,
args.tolerance,
args.warmup_iterations,
args.iterations,
args.skip_ref_check,
args.use_cold_l2,
args.benchmark == "default",
)
print("PASS")