# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are met: # 1. Redistributions of source code must retain the above copyright notice, this # list of conditions and the following disclaimer. # 2. Redistributions in binary form must reproduce the above copyright notice, # this list of conditions and the following disclaimer in the documentation # and/or other materials provided with the distribution. # 3. Neither the name of the copyright holder nor the names of its # contributors may be used to endorse or promote products derived from # this software without specific prior written permission. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import argparse from typing import Optional, Tuple, Type, Union from functools import lru_cache import cuda.bindings.driver as cuda import cutlass import cutlass.cute as cute import cutlass.cute.testing as testing import cutlass.utils as utils from cutlass.utils import is_fp8_dtype, create_cute_tensor_for_fp8 import cutlass.pipeline as pipeline from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait from cutlass.cute.nvgpu import cpasync, tcgen05 """ A high-performance persistent batched dense GEMM example for the NVIDIA Blackwell SM100 architecture using CUTE DSL. - Matrix A is MxKxL, L is batch dimension, A can be row-major("K") or column-major("M") - Matrix B is NxKxL, L is batch dimension, B can be row-major("N") or column-major("K") - Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M") This GEMM kernel supports the following features: - Utilizes Tensor Memory Access (TMA) for efficient memory operations - Utilizes Blackwell's tcgen05.mma for matrix multiply-accumulate (MMA) operations (including 2cta mma instructions) - Implements TMA multicast with cluster to reduce L2 memory traffic - Support persistent tile scheduling to better overlap memory load/store with mma between tiles - Support warp specialization to avoid explicit pipelining between mainloop load and mma This GEMM works as follows: 1. DMA warp: Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations. 2. MMA warp: Perform matrix multiply-accumulate (MMA) operations using tcgen05.mma instruction. 3. EPILOGUE warp: - Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld. - Type convert C matrix to output type. - Optionally store C matrix from registers (RMEM) to shared memory (SMEM) to global memory (GMEM) with TMA operations, or directly store C matrix from registers (RMEM) to global memory (GMEM) without TMA operations. 
    - Optionally accept an elementwise lambda function epilogue_op to apply to the output tensor:
      e.g., relu can set epilogue_op = lambda x: cute.where(x > 0, x, cute.full_like(x, 0))

SM100 tcgen05.mma instructions operate as follows:
- Read matrix A from SMEM
- Read matrix B from SMEM
- Write accumulator to TMEM

The accumulator in TMEM must then be loaded to registers before writing back to GMEM.

The input arguments to this example are the same as in dense_gemm.py.

.. code-block:: bash

    python examples/blackwell/dense_gemm_persistent.py \
        --ab_dtype Float16 --c_dtype Float16 --acc_dtype Float32 \
        --mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \
        --mnkl 8192,8192,8192,1 \
        --use_tma_store --use_2cta_instrs

To collect performance data with the NCU profiler:

.. code-block:: bash

    ncu python examples/blackwell/dense_gemm_persistent.py \
        --ab_dtype Float16 --c_dtype Float16 --acc_dtype Float32 \
        --mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \
        --mnkl 8192,8192,8192,1 \
        --use_tma_store --use_2cta_instrs \
        --warmup_iterations 1 --iterations 10 --skip_ref_check

The constraints are the same as in dense_gemm.py:

* Supported input data types: fp16, bf16, tf32, int8, uint8, fp8 (e4m3fn, e5m2);
  see the detailed valid dtype combinations in the PersistentDenseGemmKernel class documentation below
* A and B tensors must have the same data type
* Mma tiler M must be 64/128 (use_2cta_instrs=False) or 128/256 (use_2cta_instrs=True)
* Mma tiler N must be 32-256, step 32
* Cluster shape M/N must be positive and a power of 2, with total cluster size <= 16
* Cluster shape M must be a multiple of 2 if use_2cta_instrs=True
* The contiguous dimension of the A/B/C tensors must be at least 16-byte aligned,
  i.e., the number of elements must be a multiple of 4, 8, and 16 for TFloat32,
  Float16/BFloat16, and Int8/Uint8/Float8, respectively.
* OOB tiles are not allowed when TMA store is disabled
"""


def _compute_stages(
    tiled_mma: cute.TiledMma,
    mma_tiler_mnk: Tuple[int, int, int],
    a_dtype: Type[cutlass.Numeric],
    b_dtype: Type[cutlass.Numeric],
    c_dtype: Type[cutlass.Numeric],
    smem_capacity: int,
    occupancy: int,
    use_tma_store: bool,
    c_smem_layout: Union[cute.Layout, None],
) -> Tuple[int, int, int]:
    """Computes the number of stages for A/B/C operands based on heuristics.

    :param tiled_mma: The tiled MMA object defining the core computation.
    :type tiled_mma: cute.TiledMma
    :param mma_tiler_mnk: The shape (M, N, K) of the MMA tiler.
    :type mma_tiler_mnk: tuple[int, int, int]
    :param a_dtype: Data type of operand A.
    :type a_dtype: type[cutlass.Numeric]
    :param b_dtype: Data type of operand B.
    :type b_dtype: type[cutlass.Numeric]
    :param c_dtype: Data type of operand C (output).
    :type c_dtype: type[cutlass.Numeric]
    :param smem_capacity: Total available shared memory capacity in bytes.
    :type smem_capacity: int
    :param occupancy: Target number of CTAs per SM (occupancy).
    :type occupancy: int
    :param use_tma_store: Whether TMA store is enabled.
    :type use_tma_store: bool
    :param c_smem_layout: Layout of C operand in shared memory, or None if not using TMA store.
:type c_smem_layout: Union[cute.Layout, None] :return: A tuple containing the computed number of stages for: (ACC stages, A/B operand stages, C stages) :rtype: tuple[int, int, int] """ # Default ACC stages num_acc_stage = 2 # Default C stages num_c_stage = 2 if use_tma_store else 0 # Calculate smem layout and size for one stage of A, B, and C with 1-stage a_smem_layout_stage_one = utils.sm100.make_smem_layout_a( tiled_mma, mma_tiler_mnk, a_dtype, 1 ) b_smem_layout_staged_one = utils.sm100.make_smem_layout_b( tiled_mma, mma_tiler_mnk, b_dtype, 1 ) ab_bytes_per_stage = cute.size_in_bytes( a_dtype, a_smem_layout_stage_one ) + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one) mbar_helpers_bytes = 1024 c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout) c_bytes = c_bytes_per_stage * num_c_stage # Calculate A/B stages: # Start with total smem per CTA (capacity / occupancy) # Subtract reserved bytes and initial C stages bytes # Divide remaining by bytes needed per A/B stage num_ab_stage = ( smem_capacity // occupancy - (mbar_helpers_bytes + c_bytes) ) // ab_bytes_per_stage # Refine epilogue stages: # Calculate remaining smem after allocating for A/B stages and reserved bytes # Add remaining unused smem to epilogue if use_tma_store: num_c_stage += ( smem_capacity - occupancy * ab_bytes_per_stage * num_ab_stage - occupancy * (mbar_helpers_bytes + c_bytes) ) // (occupancy * c_bytes_per_stage) return num_acc_stage, num_ab_stage, num_c_stage class PersistentDenseGemmKernel: """This class implements batched matrix multiplication (C = A x B) with support for various data types and architectural features specific to Blackwell GPUs with persistent tile scheduling and warp specialization. :param acc_dtype: Data type for accumulation during computation :type acc_dtype: type[cutlass.Numeric] :param use_2cta_instrs: Whether to use CTA group 2 for advanced thread cooperation :type use_2cta_instrs: bool :param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tile (M,N) :type mma_tiler_mn: Tuple[int, int] :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing :type cluster_shape_mn: Tuple[int, int] :param use_tma_store: Whether to use Tensor Memory Access (TMA) for storing results :type use_tma_store: bool :note: In current version, A and B tensor must have the same data type - i.e., Float8E4M3FN for A and Float8E5M2 for B is not supported :note: Supported A/B data types: - TFloat32 - Float16/BFloat16 - Int8/Uint8 - Float8E4M3FN/Float8E5M2 :note: Supported accumulator data types: - Float32 (for all floating point A/B data types) - Float16 (only for fp16 and fp8 A/B data types) - Int32 (only for uint8/int8 A/B data types) :note: Supported C data types: - Float32 (for float32 and int32 accumulator data types) - Int32 (for float32 and int32 accumulator data types) - Float16/BFloat16 (for fp16 and fp8 accumulator data types) - Int8/Uint8 (for uint8/int8 accumulator data types) - Float8E4M3FN/Float8E5M2 (for float32 accumulator data types) :note: Constraints: - MMA tiler M must be 64/128 (use_2cta_instrs=False) or 128/256 (use_2cta_instrs=True) - MMA tiler N must be 32-256, step 32 - Cluster shape M must be multiple of 2 if use_2cta_instrs=True - Cluster shape M/N must be positive and power of 2, total cluster size <= 16 **Example:** gemm = PersistentDenseGemmKernel( acc_dtype=cutlass.Float32, use_2cta_instrs=True, mma_tiler_mn=(128, 128), cluster_shape_mn=(2, 2) ) gemm(a, b, c, max_active_clusters, stream) """ def __init__( self, acc_dtype: Type[cutlass.Numeric], 
use_2cta_instrs: bool, mma_tiler_mn: Tuple[int, int], cluster_shape_mn: Tuple[int, int], use_tma_store: bool, ): """Initializes the configuration for a Blackwell dense GEMM kernel. This configuration includes several key aspects: 1. MMA Instruction Settings (tcgen05): - acc_dtype: Data types for MMA accumulator. - mma_tiler_mn: The (M, N) shape of the MMA instruction tiler. - use_2cta_instrs: Boolean indicating if the tcgen05 MMA variant with cta_group=2 should be used. 2. Cluster Shape: - cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster. 3. Output C tensor store mode: - use_tma_store: Boolean indicating whether to use Tensor Memory Access (TMA) for storing results. :param acc_dtype: Data type of the accumulator. :type acc_dtype: type[cutlass.Numeric] :param mma_tiler_mn: Tuple (M, N) shape of the MMA instruction. :type mma_tiler_mn: Tuple[int, int] :param use_2cta_instrs: Boolean, True to use cta_group=2 MMA variant. :type use_2cta_instrs: bool :param cluster_shape_mn: Tuple (ClusterM, ClusterN) shape of the cluster. :type cluster_shape_mn: Tuple[int, int] :param use_tma_store: Use Tensor Memory Access (TMA) or normal store for output C tensor. :type use_tma_store: bool """ self.acc_dtype: Type[cutlass.Numeric] = acc_dtype self.use_2cta_instrs = use_2cta_instrs self.cluster_shape_mn = cluster_shape_mn # K dimension is deferred in _setup_attributes self.mma_tiler_mn = mma_tiler_mn self.mma_tiler = (*mma_tiler_mn, 1) self.use_tma_store = use_tma_store self.arch = "sm_100" self.cta_group = ( tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE ) self.occupancy = 1 # Set specialized warp ids self.epilogue_warp_id = (0, 1, 2, 3) self.mma_warp_id = 4 self.tma_warp_id = 5 self.threads_per_cta = 32 * len( (self.mma_warp_id, self.tma_warp_id, *self.epilogue_warp_id) ) # Set barrier id for cta sync, epilogue sync and tmem ptr sync self.epilog_sync_bar_id = 1 self.tmem_alloc_sync_bar_id = 2 self.tmem_dealloc_sync_bar_id = 3 def _create_tiled_mma(self): return utils.sm100.make_trivial_tiled_mma( self.a_dtype, self.a_major_mode, self.b_major_mode, self.acc_dtype, self.cta_group, self.mma_tiler[:2], ) def _setup_attributes(self): """Set up configurations that are dependent on GEMM inputs This method configures various attributes based on the input tensor properties (data types, leading dimensions) and kernel settings: - Configuring tiled MMA - Computing MMA/cluster/tile shapes - Computing cluster layout - Computing multicast CTAs for A/B - Computing epilogue subtile - Setting up A/B/C stage counts in shared memory - Computing A/B/C shared memory layout - Computing tensor memory allocation columns """ # Configure tiled mma tiled_mma = self._create_tiled_mma() # Compute mma/cluster/tile shapes mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2]) mma_inst_tile_k = 4 self.mma_tiler = ( self.mma_tiler[0], self.mma_tiler[1], mma_inst_shape_k * mma_inst_tile_k, ) self.cta_tile_shape_mnk = ( self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), self.mma_tiler[1], self.mma_tiler[2], ) # Compute cluster layout self.cluster_layout_vmnk = cute.tiled_divide( cute.make_layout((*self.cluster_shape_mn, 1)), (tiled_mma.thr_id.shape,), ) # Compute number of multicast CTAs for A/B self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2]) self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) self.is_a_mcast = self.num_mcast_ctas_a > 1 self.is_b_mcast = self.num_mcast_ctas_b > 1 # Compute epilogue subtile if cutlass.const_expr(self.use_tma_store): 
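            # Note: with TMA store, the CTA output tile is subdivided into smaller
            # epilogue subtiles that are staged through SMEM before the TMA store;
            # without TMA store, the full (M, N) CTA tile is used directly as the
            # epilogue tile (see the else branch below).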
self.epi_tile = utils.sm100.compute_epilogue_tile_shape( self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.c_dtype, ) else: self.epi_tile = self.cta_tile_shape_mnk[:2] c_smem_layout = None if cutlass.const_expr(self.use_tma_store): c_smem_layout = utils.sm100.make_smem_layout_epi( self.c_dtype, self.c_layout, self.epi_tile, 1 ) self.smem_capacity = utils.get_smem_capacity_in_bytes() # Setup A/B/C stage count in shared memory and ACC stage count in tensor memory self.num_acc_stage, self.num_ab_stage, self.num_c_stage = _compute_stages( tiled_mma, self.mma_tiler, self.a_dtype, self.b_dtype, self.c_dtype, self.smem_capacity, self.occupancy, self.use_tma_store, c_smem_layout, ) # Compute A/B/C shared memory layout self.a_smem_layout_staged = utils.sm100.make_smem_layout_a( tiled_mma, self.mma_tiler, self.a_dtype, self.num_ab_stage ) self.b_smem_layout_staged = utils.sm100.make_smem_layout_b( tiled_mma, self.mma_tiler, self.b_dtype, self.num_ab_stage ) self.c_smem_layout_staged = None if self.use_tma_store: self.c_smem_layout_staged = utils.sm100.make_smem_layout_epi( self.c_dtype, self.c_layout, self.epi_tile, self.num_c_stage ) # Compute the number of tensor memory allocation columns self.num_tmem_alloc_cols = self._compute_num_tmem_alloc_cols( tiled_mma, self.mma_tiler, self.num_acc_stage, self.arch ) @cute.jit def __call__( self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, max_active_clusters: cutlass.Constexpr, stream: cuda.CUstream, epilogue_op: cutlass.Constexpr = lambda x: x, ): """Execute the GEMM operation in steps: - Setup static attributes before smem/grid/tma computation - Setup TMA load/store atoms and tensors - Compute grid size with regard to hardware constraints - Define shared storage for kernel - Launch the kernel synchronously :param a: Input tensor A :type a: cute.Tensor :param b: Input tensor B :type b: cute.Tensor :param c: Output tensor C :type c: cute.Tensor :param max_active_clusters: Maximum number of active clusters :type max_active_clusters: cutlass.Constexpr :param stream: CUDA stream for asynchronous execution :type stream: cuda.CUstream :param epilogue_op: Optional elementwise lambda function to apply to the output tensor :type epilogue_op: cutlass.Constexpr :raises TypeError: If input data types are incompatible with the MMA instruction. :raises AssertionError: If OOB (Out-Of-Bounds) tiles are present when TMA store is disabled. 
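        :note: For example, a ReLU epilogue (as quoted in the module docstring) can be
            passed as ``epilogue_op=lambda x: cute.where(x > 0, x, cute.full_like(x, 0))``.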
""" # Setup static attributes before smem/grid/tma computation self.a_dtype: Type[cutlass.Numeric] = a.element_type self.b_dtype: Type[cutlass.Numeric] = b.element_type self.c_dtype: Type[cutlass.Numeric] = c.element_type self.a_major_mode = utils.LayoutEnum.from_tensor(a).mma_major_mode() self.b_major_mode = utils.LayoutEnum.from_tensor(b).mma_major_mode() self.c_layout = utils.LayoutEnum.from_tensor(c) # Check if input data types are compatible with MMA instruction if cutlass.const_expr(self.a_dtype != self.b_dtype): raise TypeError(f"Type must match: {self.a_dtype} != {self.b_dtype}") tiled_mma = self._create_tiled_mma() # Setup attributes that dependent on gemm inputs self._setup_attributes() atom_thr_size = cute.size(tiled_mma.thr_id.shape) # Setup TMA load for A a_op = utils.sm100.cluster_shape_to_tma_atom_A( self.cluster_shape_mn, tiled_mma.thr_id ) a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( a_op, a, a_smem_layout, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape, internal_type=( cutlass.TFloat32 if a.element_type is cutlass.Float32 else None ), ) # Setup TMA load for B b_op = utils.sm100.cluster_shape_to_tma_atom_B( self.cluster_shape_mn, tiled_mma.thr_id ) b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( b_op, b, b_smem_layout, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape, internal_type=( cutlass.TFloat32 if b.element_type is cutlass.Float32 else None ), ) a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) self.num_tma_load_bytes = (a_copy_size + b_copy_size) * atom_thr_size # Setup TMA store for C tma_atom_c = None tma_tensor_c = None if cutlass.const_expr(self.use_tma_store): epi_smem_layout = cute.select(self.c_smem_layout_staged, mode=[0, 1]) tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem_layout, self.epi_tile ) # Compute grid size self.tile_sched_params, grid = self._compute_grid( c, self.cta_tile_shape_mnk, self.cluster_shape_mn, max_active_clusters ) # Launch the kernel synchronously self.kernel( tiled_mma, tma_atom_a, tma_tensor_a, tma_atom_b, tma_tensor_b, tma_atom_c, tma_tensor_c if self.use_tma_store else c, self.cluster_layout_vmnk, self.a_smem_layout_staged, self.b_smem_layout_staged, self.c_smem_layout_staged, self.epi_tile, self.tile_sched_params, epilogue_op, ).launch( grid=grid, block=[self.threads_per_cta, 1, 1], cluster=(*self.cluster_shape_mn, 1), stream=stream, ) return # GPU device kernel @cute.kernel def kernel( self, tiled_mma: cute.TiledMma, tma_atom_a: cute.CopyAtom, mA_mkl: cute.Tensor, tma_atom_b: cute.CopyAtom, mB_nkl: cute.Tensor, tma_atom_c: Optional[cute.CopyAtom], mC_mnl: cute.Tensor, cluster_layout_vmnk: cute.Layout, a_smem_layout_staged: cute.ComposedLayout, b_smem_layout_staged: cute.ComposedLayout, c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None], epi_tile: cute.Tile, tile_sched_params: utils.PersistentTileSchedulerParams, epilogue_op: cutlass.Constexpr, ): """ GPU device kernel performing the Persistent batched GEMM computation. 
""" warp_idx = cute.arch.warp_idx() warp_idx = cute.arch.make_warp_uniform(warp_idx) # # Prefetch tma desc # if warp_idx == self.tma_warp_id: cpasync.prefetch_descriptor(tma_atom_a) cpasync.prefetch_descriptor(tma_atom_b) if cutlass.const_expr(self.use_tma_store): cpasync.prefetch_descriptor(tma_atom_c) use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 # # Setup cta/thread coordinates # # Coords inside cluster bidx, bidy, bidz = cute.arch.block_idx() mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) is_leader_cta = mma_tile_coord_v == 0 cta_rank_in_cluster = cute.arch.make_warp_uniform( cute.arch.block_idx_in_cluster() ) block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord( cta_rank_in_cluster ) # Coord inside cta tidx, _, _ = cute.arch.thread_idx() # # Alloc and init: a+b full/empty, accumulator full/empty, tensor memory dealloc barrier # # Define shared storage for kernel @cute.struct class SharedStorage: ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] acc_full_mbar_ptr: cute.struct.MemRange[ cutlass.Int64, self.num_acc_stage * 2 ] tmem_dealloc_mbar_ptr: cutlass.Int64 tmem_holding_buf: cutlass.Int32 smem = utils.SmemAllocator() storage = smem.allocate(SharedStorage) # Initialize mainloop ab_pipeline (barrier) and states ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 ab_pipeline_consumer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, num_tma_producer ) ab_producer, ab_consumer = pipeline.PipelineTmaUmma.create( barrier_storage=storage.ab_full_mbar_ptr.data_ptr(), num_stages=self.num_ab_stage, producer_group=ab_pipeline_producer_group, consumer_group=ab_pipeline_consumer_group, tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cluster_layout_vmnk, defer_sync=True, ).make_participants() # Initialize acc_pipeline (barrier) and states acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) num_acc_consumer_threads = len(self.epilogue_warp_id) * ( 2 if use_2cta_instrs else 1 ) acc_pipeline_consumer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, num_acc_consumer_threads ) acc_pipeline = pipeline.PipelineUmmaAsync.create( barrier_storage=storage.acc_full_mbar_ptr.data_ptr(), num_stages=self.num_acc_stage, producer_group=acc_pipeline_producer_group, consumer_group=acc_pipeline_consumer_group, cta_layout_vmnk=cluster_layout_vmnk, defer_sync=True, ) tmem_alloc_barrier = pipeline.NamedBarrier( barrier_id=self.tmem_alloc_sync_bar_id, num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)), ) tmem_dealloc_barrier = None if cutlass.const_expr(not self.use_tma_store): tmem_dealloc_barrier = pipeline.NamedBarrier( barrier_id=self.tmem_dealloc_sync_bar_id, num_threads=32 * len(self.epilogue_warp_id), ) # Tensor memory dealloc barrier init tmem = utils.TmemAllocator( storage.tmem_holding_buf, barrier_for_retrieve=tmem_alloc_barrier, allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta_instrs, two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, ) # Cluster arrive after barrier init pipeline_init_arrive(cluster_shape_mn=cluster_layout_vmnk, is_relaxed=True) # # Setup smem tensor A/B/C # # (MMA, MMA_M, MMA_K, STAGE) sA = smem.allocate_tensor( element_type=self.a_dtype, layout=a_smem_layout_staged.outer, byte_alignment=128, swizzle=a_smem_layout_staged.inner, ) # (MMA, MMA_N, MMA_K, STAGE) sB = smem.allocate_tensor( element_type=self.b_dtype, layout=b_smem_layout_staged.outer, 
byte_alignment=128, swizzle=b_smem_layout_staged.inner, ) # # Compute multicast mask for A/B buffer full # a_full_mcast_mask = None b_full_mcast_mask = None if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs): a_full_mcast_mask = cpasync.create_tma_multicast_mask( cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 ) b_full_mcast_mask = cpasync.create_tma_multicast_mask( cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1 ) # # Local_tile partition global tensors # # (bM, bK, RestM, RestK, RestL) gA_mkl = cute.local_tile( mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) ) # (bN, bK, RestN, RestK, RestL) gB_nkl = cute.local_tile( mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) ) # (bM, bN, RestM, RestN, RestL) gC_mnl = cute.local_tile( mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None) ) k_tile_cnt = cute.size(gA_mkl, mode=[3]) # # Partition global tensor for TiledMMA_A/B/C # thr_mma = tiled_mma.get_slice(mma_tile_coord_v) # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) tCgA = thr_mma.partition_A(gA_mkl) # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) tCgB = thr_mma.partition_B(gB_nkl) # (MMA, MMA_M, MMA_N, RestM, RestN, RestL) tCgC = thr_mma.partition_C(gC_mnl) # # Partition global/shared tensor for TMA load A/B # # TMA load A partition_S/D a_cta_layout = cute.make_layout( cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape ) # ((atom_v, rest_v), STAGE) # ((atom_v, rest_v), RestM, RestK, RestL) tAsA, tAgA = cpasync.tma_partition( tma_atom_a, block_in_cluster_coord_vmnk[2], a_cta_layout, cute.group_modes(sA, 0, 3), cute.group_modes(tCgA, 0, 3), ) # TMA load B partition_S/D b_cta_layout = cute.make_layout( cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape ) # ((atom_v, rest_v), STAGE) # ((atom_v, rest_v), RestM, RestK, RestL) tBsB, tBgB = cpasync.tma_partition( tma_atom_b, block_in_cluster_coord_vmnk[1], b_cta_layout, cute.group_modes(sB, 0, 3), cute.group_modes(tCgB, 0, 3), ) # # Partition shared/tensor memory tensor for TiledMMA_A/B/C # # (MMA, MMA_M, MMA_K, STAGE) tCrA = tiled_mma.make_fragment_A(sA) # (MMA, MMA_N, MMA_K, STAGE) tCrB = tiled_mma.make_fragment_B(sB) # (MMA, MMA_M, MMA_N) acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) # (MMA, MMA_M, MMA_N, STAGE) tCtAcc_fake = tiled_mma.make_fragment_C( cute.append(acc_shape, self.num_acc_stage) ) # # Cluster wait before tensor memory alloc # pipeline_init_wait(cluster_shape_mn=cluster_layout_vmnk) # # Construct the scheduler # tile_sched = utils.StaticPersistentTileScheduler.create( tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim(), ) work_tile = tile_sched.initial_work_tile_info() # # Specialized TMA load warp # if warp_idx == self.tma_warp_id: # # Persistent tile scheduling loop # while work_tile.is_valid_tile: # Get tile coord from tile scheduler cur_tile_coord = work_tile.tile_idx mma_tile_coord_mnl = ( cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), cur_tile_coord[1], cur_tile_coord[2], ) # # Slice to per mma tile index # # ((atom_v, rest_v), RestK) tAgA_slice = tAgA[ (None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]) ] # ((atom_v, rest_v), RestK) tBgB_slice = tBgB[ (None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]) ] # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt ab_producer.reset() peek_ab_empty_status = ab_producer.try_acquire() # # Tma load loop # for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): # Conditionally wait for AB buffer empty 
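                    # acquire_and_advance() blocks on the "buffer empty" barrier only if
                    # the earlier try_acquire() peek did not succeed. The returned handle
                    # carries the SMEM stage index (handle.index), the running k-tile
                    # count (handle.count), and the mbarrier used as the TMA transaction
                    # barrier (handle.barrier).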
handle = ab_producer.acquire_and_advance(peek_ab_empty_status) # TMA load A/B cute.copy( tma_atom_a, tAgA_slice[(None, handle.count)], tAsA[(None, handle.index)], tma_bar_ptr=handle.barrier, mcast_mask=a_full_mcast_mask, ) cute.copy( tma_atom_b, tBgB_slice[(None, handle.count)], tBsB[(None, handle.index)], tma_bar_ptr=handle.barrier, mcast_mask=b_full_mcast_mask, ) # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1 peek_ab_empty_status = cutlass.Boolean(1) if handle.count + 1 < k_tile_cnt: peek_ab_empty_status = ab_producer.try_acquire() # # Advance to next tile # tile_sched.advance_to_next_work() work_tile = tile_sched.get_current_work() # # Wait A/B buffer empty # ab_producer.tail() # # Specialized MMA warp # if warp_idx == self.mma_warp_id: # # Retrieving tensor memory ptr and make accumulator tensor # tmem.wait_for_alloc() tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) # (MMA, MMA_M, MMA_N, STAGE) tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) # # Persistent tile scheduling loop # acc_producer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Producer, self.num_acc_stage ) while work_tile.is_valid_tile: # Get tile coord from tile scheduler cur_tile_coord = work_tile.tile_idx mma_tile_coord_mnl = ( cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), cur_tile_coord[1], cur_tile_coord[2], ) # Set tensor memory buffer for current tile # (MMA, MMA_M, MMA_N) tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)] # Peek (try_wait) AB buffer full for k_tile = 0 ab_consumer.reset() peek_ab_full_status = cutlass.Boolean(1) if is_leader_cta: peek_ab_full_status = ab_consumer.try_wait() # # Wait for accumulator buffer empty # if is_leader_cta: acc_pipeline.producer_acquire(acc_producer_state) # # Reset the ACCUMULATE field for each tile # tiled_mma.set(tcgen05.Field.ACCUMULATE, False) # # Mma mainloop # for k_tile in range(k_tile_cnt): if is_leader_cta: # Conditionally wait for AB buffer full handle = ab_consumer.wait_and_advance(peek_ab_full_status) # tCtAcc += tCrA * tCrB num_kblocks = cute.size(tCrA, mode=[2]) for kblk_idx in cutlass.range(num_kblocks, unroll_full=True): kblk_crd = (None, None, kblk_idx, handle.index) cute.gemm( tiled_mma, tCtAcc, tCrA[kblk_crd], tCrB[kblk_crd], tCtAcc, ) # Enable accumulate on tCtAcc after first kblock tiled_mma.set(tcgen05.Field.ACCUMULATE, True) # Async arrive AB buffer empty handle.release() # Peek (try_wait) AB buffer full for k_tile = k_tile + 1 peek_ab_full_status = cutlass.Boolean(1) if handle.count + 1 < k_tile_cnt: peek_ab_full_status = ab_consumer.try_wait() # # Async arrive accumulator buffer full # if is_leader_cta: acc_pipeline.producer_commit(acc_producer_state) acc_producer_state.advance() # # Advance to next tile # tile_sched.advance_to_next_work() work_tile = tile_sched.get_current_work() # # Wait for accumulator buffer empty # acc_pipeline.producer_tail(acc_producer_state) sC = None if cutlass.const_expr(self.use_tma_store): # (EPI_TILE_M, EPI_TILE_N, STAGE) sC = smem.allocate_tensor( element_type=self.c_dtype, layout=c_smem_layout_staged.outer, byte_alignment=128, swizzle=c_smem_layout_staged.inner, ) # # Specialized epilogue warps # if warp_idx < self.mma_warp_id: # # Alloc tensor memory buffer # tmem.allocate(self.num_tmem_alloc_cols) # # Retrieving tensor memory ptr and make accumulator tensor # tmem.wait_for_alloc() tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) # (MMA, MMA_M, MMA_N, STAGE) tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) # # Persistent tile 
scheduling loop for epilogue # acc_consumer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Consumer, self.num_acc_stage ) if cutlass.const_expr(self.use_tma_store): assert tma_atom_c is not None and sC is not None c_producer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id), ) c_pipeline = pipeline.PipelineTmaStore.create( num_stages=self.num_c_stage, producer_group=c_producer_group ) while work_tile.is_valid_tile: # Get tile coord from tile scheduler cur_tile_coord = work_tile.tile_idx mma_tile_coord_mnl = ( cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), cur_tile_coord[1], cur_tile_coord[2], ) # # Pre-advance to next tile # tile_sched.advance_to_next_work() work_tile = tile_sched.get_current_work() num_tiles_executed = tile_sched.num_tiles_executed if cutlass.const_expr(self.use_tma_store): acc_consumer_state = utils.gemm.sm100.epilogue_tma_store( self, tidx, warp_idx, tma_atom_c, tCtAcc_base, sC, tCgC, epi_tile, num_tiles_executed, epilogue_op, mma_tile_coord_mnl, acc_consumer_state, acc_pipeline, c_pipeline, ) else: acc_consumer_state = utils.gemm.sm100.epilogue( self, tidx, tCtAcc_base, tCgC, epi_tile, epilogue_op, mma_tile_coord_mnl, acc_consumer_state, acc_pipeline, ) if cutlass.const_expr(self.use_tma_store): # Wait for C store complete c_pipeline.producer_tail() else: # Synchronize before TMEM dealloc (done by the caller) tmem_dealloc_barrier.arrive_and_wait() # # Dealloc the tensor memory buffer # tmem.relinquish_alloc_permit() tmem.free(tmem_ptr) @staticmethod def _compute_grid( c: cute.Tensor, cta_tile_shape_mnk: Tuple[int, int, int], cluster_shape_mn: Tuple[int, int], max_active_clusters: cutlass.Constexpr, ) -> Tuple[utils.PersistentTileSchedulerParams, Tuple[int, int, int]]: """Use persistent tile scheduler to compute the grid size for the output tensor C. :param c: The output tensor C :type c: cute.Tensor :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile. :type cta_tile_shape_mnk: tuple[int, int, int] :param cluster_shape_mn: Shape of each cluster in M, N dimensions. :type cluster_shape_mn: tuple[int, int] :param max_active_clusters: Maximum number of active clusters. :type max_active_clusters: cutlass.Constexpr :return: A tuple containing: - tile_sched_params: Parameters for the persistent tile scheduler. - grid: Grid shape for kernel launch. :rtype: Tuple[utils.PersistentTileSchedulerParams, tuple[int, int, int]] """ c_shape = cute.slice_(cta_tile_shape_mnk, (None, None, 0)) gc = cute.zipped_divide(c, tiler=c_shape) num_ctas_mnl = gc[(0, (None, None, None))].shape cluster_shape_mnl = (*cluster_shape_mn, 1) tile_sched_params = utils.PersistentTileSchedulerParams( num_ctas_mnl, cluster_shape_mnl ) grid = utils.StaticPersistentTileScheduler.get_grid_shape( tile_sched_params, max_active_clusters ) return tile_sched_params, grid @staticmethod def _compute_num_tmem_alloc_cols( tiled_mma: cute.TiledMma, mma_tiler: Tuple[int, int, int], num_acc_stage: int, arch: str, ) -> int: """ Compute the number of tensor memory allocation columns. :param tiled_mma: The tiled MMA object defining the core computation. :type tiled_mma: cute.TiledMma :param mma_tiler: The shape (M, N, K) of the MMA tile. :type mma_tiler: tuple[int, int, int] :param num_acc_stage: The stage of the accumulator tensor. :type num_acc_stage: int :return: The number of tensor memory allocation columns. 
:rtype: int """ acc_shape = tiled_mma.partition_shape_C(mma_tiler[:2]) tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, num_acc_stage)) num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols(tCtAcc_fake, arch=arch) return num_tmem_alloc_cols def check_supported_dtypes( self, a_dtype: Type[cutlass.Numeric], b_dtype: Type[cutlass.Numeric], c_dtype: Type[cutlass.Numeric], ): """ Check if the dtypes are valid :param a_dtype: The data type of the A operands :type a_dtype: Type[cutlass.Numeric] :param b_dtype: The data type of the B operands :type b_dtype: Type[cutlass.Numeric] :param acc_dtype: The data type of the accumulator :type acc_dtype: Type[cutlass.Numeric] :param c_dtype: The data type of the output tensor :type c_dtype: Type[cutlass.Numeric] :raises testing.CantImplementError: If the dtypes are invalid """ valid_ab_dtypes = { cutlass.Float16, cutlass.BFloat16, cutlass.TFloat32, cutlass.Uint8, cutlass.Int8, cutlass.Float8E4M3FN, cutlass.Float8E5M2, } if a_dtype not in valid_ab_dtypes or b_dtype not in valid_ab_dtypes: raise testing.CantImplementError( f"Unsupported AB dtype: {a_dtype} and {b_dtype}" ) if self.acc_dtype not in {cutlass.Float32, cutlass.Float16, cutlass.Int32}: raise testing.CantImplementError( f"Unsupported accumulator dtype: {self.acc_dtype}" ) # Define compatibility mapping between accumulator type and AB type acc_ab_compatibility = { cutlass.Float32: { cutlass.Float16, cutlass.BFloat16, cutlass.TFloat32, cutlass.Float8E4M3FN, cutlass.Float8E5M2, }, # Float32 accumulator supports floating point AB types only cutlass.Float16: { cutlass.Float16, cutlass.Float8E4M3FN, cutlass.Float8E5M2, }, cutlass.Int32: {cutlass.Uint8, cutlass.Int8}, } # Check compatibility between accumulator type and AB type if ( a_dtype not in acc_ab_compatibility[self.acc_dtype] or b_dtype not in acc_ab_compatibility[self.acc_dtype] ): raise testing.CantImplementError( f"Unsupported AB dtype: {a_dtype} and {b_dtype} for accumulator dtype: {self.acc_dtype}" ) # Define compatibility mapping between accumulator type and C type acc_c_compatibility = { cutlass.Float32: { cutlass.Float32, cutlass.Float16, cutlass.BFloat16, cutlass.Float8E4M3FN, cutlass.Float8E5M2, cutlass.Int32, cutlass.Int8, cutlass.Uint8, }, cutlass.Float16: { cutlass.BFloat16, cutlass.Float16, }, cutlass.Int32: { cutlass.BFloat16, cutlass.Float16, cutlass.Float32, cutlass.Int32, cutlass.Int8, cutlass.Uint8, }, } # Check compatibility between accumulator type and C type if c_dtype not in acc_c_compatibility[self.acc_dtype]: raise testing.CantImplementError( f"Unsupported C dtype: {c_dtype} for accumulator dtype: {self.acc_dtype}" ) def check_mma_tiler_and_cluster_shape(self): """Check if the mma tiler and cluster shape are valid. 
:raises testing.CantImplementError: If the mma tiler and cluster shape are invalid """ # Skip invalid mma tile shape if not ( (not self.use_2cta_instrs and self.mma_tiler_mn[0] in [64, 128]) or (self.use_2cta_instrs and self.mma_tiler_mn[0] in [128, 256]) ): raise testing.CantImplementError( f"Invalid mma tiler & use_2cta_instrs: {self.mma_tiler_mn}, {self.use_2cta_instrs}" ) if self.mma_tiler_mn[1] not in range(32, 257, 32): raise testing.CantImplementError( f"Invalid mma tiler N: {self.mma_tiler_mn[1]}" ) # Skip illegal cluster shape if self.cluster_shape_mn[0] % (2 if self.use_2cta_instrs else 1) != 0: raise testing.CantImplementError( f"Invalid cluster shape M: {self.cluster_shape_mn[0]}" ) # Skip invalid cluster shape is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0 if ( self.cluster_shape_mn[0] * self.cluster_shape_mn[1] > 16 or self.cluster_shape_mn[0] <= 0 or self.cluster_shape_mn[1] <= 0 or not is_power_of_2(self.cluster_shape_mn[0]) or not is_power_of_2(self.cluster_shape_mn[1]) ): raise testing.CantImplementError( f"Invalid cluster shape: {self.cluster_shape_mn}" ) def check_tensor_alignment( self, m: int, n: int, k: int, l: int, a_dtype: Type[cutlass.Numeric], b_dtype: Type[cutlass.Numeric], c_dtype: Type[cutlass.Numeric], a_major: str, b_major: str, c_major: str, ): """ Check if the tensor alignment is valid :param m: The number of rows in the A tensor :type m: int :param n: The number of columns in the B tensor :type n: int :param k: The number of columns in the A tensor :type k: int :param l: The number of columns in the C tensor :type l: int :param a_dtype: The data type of the A operands :type a_dtype: Type[cutlass.Numeric] :param b_dtype: The data type of the B operands :type b_dtype: Type[cutlass.Numeric] :param c_dtype: The data type of the output tensor :type c_dtype: Type[cutlass.Numeric] :param a_major: The major axis of the A tensor :type a_major: str :param b_major: The major axis of the B tensor :type b_major: str :param c_major: The major axis of the C tensor :type c_major: str :raises testing.CantImplementError: If the tensor alignment is invalid """ # TODO: move to utils def check_contiguous_16B_alignment(dtype, is_mode0_major, tensor_shape): major_mode_idx = 0 if is_mode0_major else 1 num_major_elements = tensor_shape[major_mode_idx] num_contiguous_elements = 16 * 8 // dtype.width return num_major_elements % num_contiguous_elements == 0 if ( not check_contiguous_16B_alignment(a_dtype, a_major == "m", (m, k, l)) or not check_contiguous_16B_alignment(b_dtype, b_major == "n", (n, k, l)) or not check_contiguous_16B_alignment(c_dtype, c_major == "m", (m, n, l)) ): raise testing.CantImplementError( f"Invalid tensor alignment: {m}, {n}, {k}, {l}, {a_dtype}, {b_dtype}, {c_dtype}, {a_major}, {b_major}, {c_major}" ) def check_epilog_store_option(self, m: int, n: int): """ Check if the epilogue store option is valid :param m: The number of rows in the A tensor :type m: int :param n: The number of columns in the B tensor :type n: int :raises testing.CantImplementError: If the epilogue store option is invalid """ # None TMA store version does not have predication, can not support OOB tiles cta_tile_shape_mn = ( self.mma_tiler_mn[0] // (2 if self.use_2cta_instrs else 1), self.mma_tiler_mn[1], ) if not self.use_tma_store: if not (m % cta_tile_shape_mn[0] == 0 and n % cta_tile_shape_mn[1] == 0): raise testing.CantImplementError( f"Invalid epilog store option: {m}, {n}" ) def can_implement( self, mnkl: Tuple[int, int, int, int], a_dtype: Type[cutlass.Numeric], b_dtype: 
Type[cutlass.Numeric], c_dtype: Type[cutlass.Numeric], a_major: str, b_major: str, c_major: str, ) -> bool: """ Determine if the given tensor configuration can be implemented by this kernel. :param mnkl: Problem size as a tuple (M, N, K, L). :type mnkl: Tuple[int, int, int, int] :param a_dtype: Data type for input tensors A. :type a_dtype: Type[cutlass.Numeric] :param b_dtype: Data type for input tensors B. :type b_dtype: Type[cutlass.Numeric] :param c_dtype: Data type for output tensor C. :type c_dtype: Type[cutlass.Numeric] :param a_major: Major dimension of the A tensor layout ("m" or "k"). :type a_major: str :param b_major: Major dimension of the B tensor layout ("n" or "k"). :type b_major: str :param c_major: Major dimension of the C tensor layout ("m" or "n"). :type c_major: str :return: True if the kernel supports the given configuration, False otherwise. :rtype: bool """ try: # Skip unsupported types self.check_supported_dtypes(a_dtype, b_dtype, c_dtype) # Skip invalid mma tile shape and cluster shape self.check_mma_tiler_and_cluster_shape() m, n, k, l = mnkl self.check_tensor_alignment( m, n, k, l, a_dtype, b_dtype, c_dtype, a_major, b_major, c_major ) self.check_epilog_store_option(m, n) except testing.CantImplementError: return False return True @cute.jit def bmm( gemm_op: cutlass.Constexpr, a: cute.Tensor, # (l, m, k) b: cute.Tensor, # (l, k, n) c: cute.Tensor, # (l, m, n) max_active_clusters: cutlass.Constexpr, stream: cuda.CUstream, epilogue_op: cutlass.Constexpr = lambda x: x, ): """ Wrapper API for persistent GEMM kernel to follow the convention of PyTorch's batch matrix-multiply (bmm). Internally, the tensors are permuted to match CuTe's convention: - a: (m, k, l) - b: (n, k, l) - c: (m, n, l) :param gemm_op: Kernel operation, expects (a, b, c, max_active_clusters, stream, epilogue_op) :type gemm_op: cutlass.Constexpr :param a: Input tensor of shape (l, m, k) :type a: cute.Tensor :param b: Input tensor of shape (l, k, n) :type b: cute.Tensor :param c: Output tensor of shape (l, m, n) :type c: cute.Tensor :param max_active_clusters: Maximum number of hardware clusters to launch :type max_active_clusters: cutlass.Constexpr :param epilogue_op: Optional elementwise lambda function to apply per output element, defaults to identity :type epilogue_op: cutlass.Constexpr, optional """ # (l,m,k) -> (m,k,l) a = cute.make_tensor(a.iterator, cute.select(a.layout, mode=[1, 2, 0])) # (l,k,n) -> (n,k,l) b = cute.make_tensor(b.iterator, cute.select(b.layout, mode=[2, 1, 0])) # (l,m,n) -> (m,n,l) c = cute.make_tensor(c.iterator, cute.select(c.layout, mode=[1, 2, 0])) gemm_op(a, b, c, max_active_clusters, stream, epilogue_op) @lru_cache(maxsize=1) def prepare_tensors( mnkl: Tuple[int, int, int, int], a_dtype: Type[cutlass.Numeric], b_dtype: Type[cutlass.Numeric], c_dtype: Type[cutlass.Numeric], a_major: str, b_major: str, c_major: str, init_random: bool = True, normal_mean: float = 0.0, normal_std: float = 1.0, ): """Prepare tensors for GEMM. 
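
    Note: this helper is decorated with ``functools.lru_cache(maxsize=1)``, so repeated
    calls with the same arguments (e.g., during benchmark workspace generation) reuse
    the most recently prepared tensors rather than reallocating them.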
Returns: Tuple of (a_f32, b_f32, c_f32, a_storage, b_storage, c_storage): - *_f32: Float32 tensors with the logical data (for reference and fp8 conversion) - *_storage: Storage tensors for DLPack (uint8 for fp8, otherwise the target dtype) """ import torch from cutlass.torch import dtype as torch_dtype m, n, k, l = mnkl if a_major == "k": a_f32 = torch.empty((l, m, k), dtype=torch.float32, device="cuda") elif a_major == "m": a_f32 = torch.empty((l, k, m), dtype=torch.float32, device="cuda").permute( 0, 2, 1 ) if b_major == "n": b_f32 = torch.empty((l, k, n), dtype=torch.float32, device="cuda") elif b_major == "k": b_f32 = torch.empty((l, n, k), dtype=torch.float32, device="cuda").permute( 0, 2, 1 ) if c_major == "n": c_f32 = torch.empty((l, m, n), dtype=torch.float32, device="cuda") elif c_major == "m": c_f32 = torch.empty((l, n, m), dtype=torch.float32, device="cuda").permute( 0, 2, 1 ) if init_random: # Uniform random initialization in range [-2, 3) a_f32.random_(-2, 3) b_f32.random_(-2, 3) c_f32.random_(-2, 3) else: # Normal (Gaussian) initialization with user-specified mean and std a_f32.normal_(mean=normal_mean, std=normal_std) b_f32.normal_(mean=normal_mean, std=normal_std) c_f32.normal_(mean=normal_mean, std=normal_std) # For float8 types, use uint8 as storage type to avoid dlpack limitation # (dlpack doesn't support float8 types) # For other types, convert to the target dtype a_storage_dtype = torch.uint8 if is_fp8_dtype(a_dtype) else torch_dtype(a_dtype) b_storage_dtype = torch.uint8 if is_fp8_dtype(b_dtype) else torch_dtype(b_dtype) c_storage_dtype = torch.uint8 if is_fp8_dtype(c_dtype) else torch_dtype(c_dtype) a_storage = a_f32.to(dtype=a_storage_dtype) b_storage = b_f32.to(dtype=b_storage_dtype) c_storage = c_f32.to(dtype=c_storage_dtype) return (a_f32, b_f32, c_f32, a_storage, b_storage, c_storage) @lru_cache(maxsize=1) def compile_bmm( mnkl: Tuple[int, int, int, int], a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, acc_dtype: Type[cutlass.Numeric], a_major: str, b_major: str, c_major: str, mma_tiler_mn: Tuple[int, int] = (256, 256), cluster_shape_mn: Tuple[int, int] = (2, 1), max_active_clusters: cutlass.Constexpr = None, use_2cta_instrs: bool = True, use_tma_store: bool = True, epilogue_op: cutlass.Constexpr = lambda x: x, ): from cutlass.cute.runtime import make_fake_stream gemm = PersistentDenseGemmKernel( acc_dtype, use_2cta_instrs, mma_tiler_mn, cluster_shape_mn, use_tma_store, ) # Check if configuration can be implemented can_implement = gemm.can_implement( mnkl, a.element_type, b.element_type, c.element_type, a_major, b_major, c_major ) if not can_implement: raise testing.CantImplementError( f"The current config which is invalid/unsupported: use_2cta_instrs = {use_2cta_instrs}, " f"mma_tiler_mn = {mma_tiler_mn}, cluster_shape_mn = {cluster_shape_mn}, " f"use_tma_store = {use_tma_store}" ) stream = make_fake_stream() return cute.compile(bmm, gemm, a, b, c, max_active_clusters, stream, epilogue_op) def run( mnkl: Tuple[int, int, int, int], ab_dtype: Type[cutlass.Numeric], c_dtype: Type[cutlass.Numeric], acc_dtype: Type[cutlass.Numeric], a_major: str, b_major: str, c_major: str, mma_tiler_mn: Tuple[int, int] = (256, 256), cluster_shape_mn: Tuple[int, int] = (2, 1), use_2cta_instrs: bool = True, use_tma_store: bool = True, tolerance: float = 1e-01, warmup_iterations: int = 0, iterations: int = 1, skip_ref_check: bool = False, use_cold_l2: bool = False, benchmark: bool = False, **kwargs, ): """ Execute a persistent batched dense GEMM operation on Blackwell architecture 
with performance benchmarking. Prepares input tensors, configures and launches the persistent GEMM kernel, optionally performs reference validation, and benchmarks execution. :param mnkl: Problem size as a tuple (M, N, K, L). :type mnkl: Tuple[int, int, int, int] :param ab_dtype: Data type for input tensors A and B. :type ab_dtype: Type[cutlass.Numeric] :param c_dtype: Data type for output tensor C. :type c_dtype: Type[cutlass.Numeric] :param acc_dtype: Accumulator data type for the matrix multiplication. :type acc_dtype: Type[cutlass.Numeric] :param a_major: Memory layout of tensor A. :type a_major: str :param b_major: Memory layout of tensor B. :type b_major: str :param c_major: Memory layout of tensor C. :type c_major: str :param mma_tiler_mn: MMA tiling size (M, N), defaults to (256, 256). :type mma_tiler_mn: Tuple[int, int], optional :param cluster_shape_mn: Cluster shape (M, N), defaults to (2, 1). :type cluster_shape_mn: Tuple[int, int], optional :param use_2cta_instrs: Whether to use 2CTA MMA instructions, defaults to True. :type use_2cta_instrs: bool, optional :param use_tma_store: Whether to use TMA store, defaults to True. :type use_tma_store: bool, optional :param tolerance: Tolerance for reference validation, defaults to 1e-01. :type tolerance: float, optional :param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0. :type warmup_iterations: int, optional :param iterations: Number of benchmark iterations to run, defaults to 1. :type iterations: int, optional :param skip_ref_check: Whether to skip reference result validation, defaults to False. :type skip_ref_check: bool, optional :param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False. :type use_cold_l2: bool, optional :param benchmark: Whether to only benchmark the kernel, defaults to False. :type benchmark: bool, optional :raises RuntimeError: If CUDA GPU is not available. :raises ValueError: If the configuration is invalid or unsupported by the kernel. :return: Execution time of the GEMM kernel. 
:rtype: float """ import torch from cutlass.torch import dtype as torch_dtype if not torch.cuda.is_available(): raise RuntimeError("GPU is required to run this example!") # Get current CUDA stream from PyTorch torch_stream = torch.cuda.current_stream() # Get the raw stream pointer as a CUstream current_stream = cuda.CUstream(torch_stream.cuda_stream) # Check if configuration can be implemented max_active_clusters = utils.HardwareInfo().get_max_active_clusters( cluster_shape_mn[0] * cluster_shape_mn[1] ) # Run and verify BMM with torch a_f32, b_f32, c_f32, a_storage, b_storage, c_storage = prepare_tensors( mnkl, ab_dtype, ab_dtype, c_dtype, a_major, b_major, c_major ) leading_dim_a = 2 if a_major == "k" else 1 leading_dim_b = 1 if b_major == "k" else 2 leading_dim_c = 2 if c_major == "n" else 1 # Create CuTe tensors, passing float32 source for fp8 conversion a_ = create_cute_tensor_for_fp8( a_storage, ab_dtype, leading_dim_a, source_f32_tensor=a_f32 ) b_ = create_cute_tensor_for_fp8( b_storage, ab_dtype, leading_dim_b, source_f32_tensor=b_f32 ) c_ = create_cute_tensor_for_fp8( c_storage, c_dtype, leading_dim_c, source_f32_tensor=c_f32 ) compiled_fn = compile_bmm( mnkl, a_, b_, c_, acc_dtype, a_major, b_major, c_major, mma_tiler_mn, cluster_shape_mn, max_active_clusters, use_2cta_instrs, use_tma_store, epilogue_op=lambda x: x, ) print("Running Blackwell Persistent Dense GEMM test with:") print(f"mnkl: {mnkl}") print(f"Tolerance: {tolerance}") print(f"Warmup iterations: {warmup_iterations}") print(f"Iterations: {iterations}") print(f"Skip reference checking: {skip_ref_check}") print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}") if not skip_ref_check: # Use small random number for deterministic result for reference check compiled_fn(a_, b_, c_, current_stream) # Manually quantize to be comparable # Use float32 source data for reference calculation ref = ( torch.bmm(a_f32, b_f32) .to(dtype=torch_dtype(c_dtype)) .to(dtype=torch.float32) ) # Read back the result from CuTe tensor (c_storage was updated in-place) torch.testing.assert_close( c_storage.to(dtype=torch.float32), ref, atol=tolerance, rtol=1e-03 ) if not benchmark: return 0 def generate_tensors(): a_f32, b_f32, c_f32, a_st, b_st, c_st = prepare_tensors( mnkl, ab_dtype, ab_dtype, c_dtype, a_major, b_major, c_major, ) a_ = create_cute_tensor_for_fp8( a_st, ab_dtype, leading_dim_a, source_f32_tensor=a_f32 ) b_ = create_cute_tensor_for_fp8( b_st, ab_dtype, leading_dim_b, source_f32_tensor=b_f32 ) c_ = create_cute_tensor_for_fp8( c_st, c_dtype, leading_dim_c, source_f32_tensor=c_f32 ) return testing.JitArguments(a_, b_, c_, current_stream) workspace_count = 1 if use_cold_l2: one_workspace_bytes = ( a_storage.numel() * a_storage.element_size() + b_storage.numel() * b_storage.element_size() + c_storage.numel() * c_storage.element_size() ) workspace_count = testing.get_workspace_count( one_workspace_bytes, warmup_iterations, iterations ) # Return execution time in microseconds return testing.benchmark( compiled_fn, workspace_generator=generate_tensors, workspace_count=workspace_count, stream=current_stream, warmup_iterations=warmup_iterations, iterations=iterations, ) def compute_tflops(time_ns, m, n, k): return 2.0 * m * n * k / time_ns / 1000.0 def _parse_comma_separated_ints(s: str) -> Tuple[int, ...]: try: return tuple(int(x.strip()) for x in s.split(",")) except ValueError: raise argparse.ArgumentTypeError( "Invalid format. Expected comma-separated integers." 
        )


def prepare_parser():
    parser = argparse.ArgumentParser(
        description="Example of Dense Persistent GEMM on Blackwell."
    )
    parser.add_argument(
        "--mnkl",
        type=_parse_comma_separated_ints,
        default=(256, 256, 512, 1),
        help="mnkl dimensions (comma-separated)",
    )
    parser.add_argument(
        "--cluster_shape_mn",
        type=_parse_comma_separated_ints,
        default=(1, 1),
        help="Cluster shape (comma-separated)",
    )
    parser.add_argument("--ab_dtype", type=cutlass.dtype, default=cutlass.TFloat32)
    parser.add_argument("--c_dtype", type=cutlass.dtype, default=cutlass.Float32)
    parser.add_argument("--acc_dtype", type=cutlass.dtype, default=cutlass.Float32)
    parser.add_argument(
        "--use_2cta_instrs",
        action="store_true",
        help="Enable 2CTA MMA instructions feature",
    )
    parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k")
    parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k")
    parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n")
    parser.add_argument(
        "--use_tma_store", action="store_true", help="Use TMA store for the output C tensor"
    )
    parser.add_argument(
        "--tolerance", type=float, default=1e-01, help="Tolerance for validation"
    )
    parser.add_argument(
        "--benchmark",
        type=str,
        default="default",
        choices=[
            "default",
            "none",
        ],
        help="Benchmark the kernel with cute.testing.benchmark (default) or skip benchmarking (none)",
    )
    parser.add_argument(
        "--warmup_iterations", type=int, default=0, help="Warmup iterations"
    )
    parser.add_argument(
        "--iterations",
        type=int,
        default=1,
        help="Number of iterations to run the kernel",
    )
    parser.add_argument(
        "--skip_ref_check", action="store_true", help="Skip reference checking"
    )
    parser.add_argument(
        "--use_cold_l2",
        action="store_true",
        default=False,
        help="Use circular buffer tensor sets to ensure L2 cold cache",
    )
    return parser


if __name__ == "__main__":
    parser = prepare_parser()
    parser.add_argument(
        "--mma_tiler_mn",
        type=_parse_comma_separated_ints,
        default=(128, 128),
        help="Mma tile shape (comma-separated)",
    )
    args = parser.parse_args()

    if len(args.mnkl) != 4:
        parser.error("--mnkl must contain exactly 4 values")
    if len(args.mma_tiler_mn) != 2:
        parser.error("--mma_tiler_mn must contain exactly 2 values")
    if len(args.cluster_shape_mn) != 2:
        parser.error("--cluster_shape_mn must contain exactly 2 values")

    print("[DSL INFO] Compiling Blackwell Persistent Dense GEMM with:")
    print(
        f"[DSL INFO] A dtype: {args.ab_dtype}, B dtype: {args.ab_dtype}, C dtype: {args.c_dtype}, Acc dtype: {args.acc_dtype}"
    )
    print(
        f"[DSL INFO] Matrix majors - A: {args.a_major}, B: {args.b_major}, C: {args.c_major}"
    )
    print(f"[DSL INFO] Mma Tiler (M, N): {args.mma_tiler_mn}")
    print(f"[DSL INFO] Cluster Shape (M, N): {args.cluster_shape_mn}")
    print(
        f"[DSL INFO] 2CTA MMA instructions: {'True' if args.use_2cta_instrs else 'False'}"
    )
    print(f"[DSL INFO] Use TMA Store: {'True' if args.use_tma_store else 'False'}")

    run(
        args.mnkl,
        args.ab_dtype,
        args.c_dtype,
        args.acc_dtype,
        args.a_major,
        args.b_major,
        args.c_major,
        args.mma_tiler_mn,
        args.cluster_shape_mn,
        args.use_2cta_instrs,
        args.use_tma_store,
        args.tolerance,
        args.warmup_iterations,
        args.iterations,
        args.skip_ref_check,
        args.use_cold_l2,
        args.benchmark == "default",
    )
    print("PASS")
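

# A minimal sketch (not executed) of compiling the same kernel with a non-identity
# epilogue, using the elementwise lambda form quoted in the module docstring. It
# assumes CuTe tensors a_, b_, c_, a CUstream current_stream, and max_active_clusters
# prepared exactly as in run() above; only the epilogue_op argument differs.
#
#     relu = lambda x: cute.where(x > 0, x, cute.full_like(x, 0))
#     compiled_relu = compile_bmm(
#         mnkl, a_, b_, c_, acc_dtype,
#         a_major, b_major, c_major,
#         mma_tiler_mn, cluster_shape_mn,
#         max_active_clusters, use_2cta_instrs, use_tma_store,
#         epilogue_op=relu,
#     )
#     compiled_relu(a_, b_, c_, current_stream)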