# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are met: # 1. Redistributions of source code must retain the above copyright notice, this # list of conditions and the following disclaimer. # 2. Redistributions in binary form must reproduce the above copyright notice, # this list of conditions and the following disclaimer in the documentation # and/or other materials provided with the distribution. # 3. Neither the name of the copyright holder nor the names of its # contributors may be used to endorse or promote products derived from # this software without specific prior written permission. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
import os
import sys
import argparse
import math
from typing import Type, Tuple, Optional
from types import SimpleNamespace

import cuda.bindings.driver as cuda

import cutlass
import cutlass.cute as cute
import cutlass.cute.testing as testing
from cutlass.cute.nvgpu import tcgen05
from cutlass.cute.nvgpu.tcgen05 import OperandMajorMode
import cutlass.cute.nvgpu.cpasync as cpasync
import cutlass.utils as utils
import cutlass.pipeline as pipeline
from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait
import cutlass.utils.blackwell_helpers as sm100_utils
from cutlass.cute.runtime import from_dlpack
from cutlass.cute.arch import Arch
from cutlass.cutlass_dsl import BaseDSL

# When this file is run directly as a script, prepend the directory two levels
# above it to sys.path -- presumably so that the `blackwell.mla.mla_helpers`
# import below resolves. When the file is imported as a module instead, the
# caller is responsible for making `blackwell` importable.
if __name__ == "__main__":
    current_dir = os.path.dirname(os.path.abspath(__file__))
    sys.path.insert(0, os.path.join(current_dir, "../.."))

from blackwell.mla.mla_helpers import (
    ceil_div,
    MAX_SPLITS,
    LOG2_E,
    MLAStaticTileScheduler,
    MLAStaticTileSchedulerParams,
    create_mla_static_tile_scheduler,
    create_mla_static_tile_scheduler_params,
)

"""
A Multi-Head Latent Attention (MLA) example using FP8 input/output for the NVIDIA Blackwell SM100
architecture, written with the CuTe DSL.

This example demonstrates an inference implementation of multi-head latent attention using a
TMA + Blackwell SM100 TensorCore warp-specialized (optionally persistent) kernel. The implementation
fuses the (Qc + Qr)*(Kc + Kr)^T matrix multiplication, softmax normalization, and
softmax((Qc + Qr)*(Kc + Kr)^T)*Vc into a single kernel.
The kernel supports paged (page table) KV cache storage and variable-length KV cache sequences.
It implements KV splitting to reduce latency when processing long KV sequences.

The kernel implements key optimizations including:

- Warp specialization for different computation phases (load, MMA, softmax, correction, epilogue)
- Pipeline stages between different warps for overlapping computation and memory access
- FP8 (Float8E4M3FN) input/output with Float32 accumulation (see the constraints below)
- Two sub-kernels (a split-KV kernel and a reduction kernel) that together implement split-KV processing

To run this example:

.. code-block:: bash

    python examples/blackwell/mla_fp8.py \
      --batch_size 4 --latent_dim 512 --rope_dim 64 \
      --num_heads 128 --seq_len_q 1 --seq_len_k 1024 \
      --in_dtype Float8E4M3FN --out_dtype Float8E4M3FN \
      --acc_dtype Float32 --lse_dtype Float32 \
      --is_var_seq --is_var_split_kv \
      --is_persistent

The above example runs Multi-Head Latent Attention (MLA) with the following configuration:

- Batch size: 4
- Sequence length of Q: 1
- Sequence length of K: 1024
- Latent dimension: 512
- RoPE dimension: 64
- Number of heads: 128
- Data types: Float8E4M3FN (input), Float8E4M3FN (output), Float32 (accumulation and LSE)

It uses page table storage for the KV cache and enables both variable-length KV cache sequences
and variable split-KV processing with persistent scheduling.

To collect performance metrics with the NCU profiler:

..
code-block:: bash

    ncu python examples/blackwell/mla_fp8.py \
      --batch_size 4 --latent_dim 512 --rope_dim 64 \
      --num_heads 128 --seq_len_q 1 --seq_len_k 1024 \
      --in_dtype Float8E4M3FN --out_dtype Float8E4M3FN \
      --acc_dtype Float32 --lse_dtype Float32 \
      --is_var_seq --is_var_split_kv \
      --is_persistent --warmup_iterations 3 \
      --iterations 10 --skip_ref_check

Constraints for this example:

* Data type requirements:
  - Input/output: Float8E4M3FN
  - Accumulation and LSE: Float32
* Fixed architecture parameters:
  - Number of attention heads: 128
  - Latent dimension: 512
  - RoPE dimension: 64
* Input query modes should be (NumHeads, LatentDim/RopeDim, SeqLenQ, BatchSize)
* Input kv latent/rope modes should be (SeqLenK, LatentDim/RopeDim, BatchSize)
* Query sequence length must be 1-4
* Only supports 2-CTA instructions
* Variable sequence length requires page table storage to be enabled
"""


class BlackwellMultiHeadLatentAttentionForwardFP8:
    def __init__(
        self,
        acc_dtype: Type[cutlass.Numeric],
        lse_dtype: Type[cutlass.Numeric],
        mma_qk_tiler_mn: Tuple[int, int],
        mma_pv_tiler_mn: Tuple[int, int],
        max_active_clusters: int,
        page_size: int,
        skip_correction_threshold: float,
        is_persistent: bool,
        is_var_seq: bool,
        is_var_split_kv: bool,
    ):
        """Initializes the configuration for a Blackwell Multi-Head Latent Attention (MLA) kernel.

        Only host-side configuration is computed here. The data types of Q/K/V/O are taken
        from the tensors in ``__call__``, and pipeline stage counts are set by
        ``_setup_attributes``.

        The latent dimension (512) and RoPE dimension (64) are hard-coded. The QK MMA tile uses
        a K extent of ``2 * rope_dim``. The PV MMA tile K extent is derived as
        ``mma_qk_tiler[1] * mma_qk_tiler[2] // mma_pv_tiler_mn[1]``.

        :param acc_dtype: Data type for accumulation S and O
        :type acc_dtype: Type[cutlass.Numeric]
        :param lse_dtype: Data type for output LSE
        :type lse_dtype: Type[cutlass.Numeric]
        :param mma_qk_tiler_mn: The (H, K) tile shape of the MMA instruction for S = Q*K^T
        :type mma_qk_tiler_mn: Tuple[int, int]
        :param mma_pv_tiler_mn: The (H, D) tile shape of the MMA instruction for O = P*V
        :type mma_pv_tiler_mn: Tuple[int, int]
        :param max_active_clusters: Maximum number of active clusters
        :type max_active_clusters: int
        :param page_size: The page size of the paged KV cache
        :type page_size: int
        :param skip_correction_threshold: Threshold to skip correction
        :type skip_correction_threshold: float
        :param is_persistent: Whether to use persistent kernel mode
        :type is_persistent: bool
        :param is_var_seq: Whether to use variable sequence length
        :type is_var_seq: bool
        :param is_var_split_kv: Whether to use variable split KV
        :type is_var_split_kv: bool
        """
        self.latent_dim = 512
        self.rope_dim = 64
        self.acc_dtype = acc_dtype
        self.lse_dtype = lse_dtype
        self.mma_qk_tiler_mn = mma_qk_tiler_mn
        self.mma_pv_tiler_mn = mma_pv_tiler_mn
        self.max_active_clusters = max_active_clusters
        self.skip_correction_threshold = skip_correction_threshold
        self.is_persistent = is_persistent
        self.page_size = page_size
        self.is_var_seq = is_var_seq
        self.is_var_split_kv = is_var_split_kv
        # Always a 2-CTA cluster along M, matching the 2-CTA (CtaGroup.TWO) MMA used in __call__.
        self.cluster_shape_mnk = (2, 1, 1)
        self.use_2cta_instrs = True
        # When using 2 CTAs with m=128: warps 0-1 handle accumulation for first half [0, n/2),
        # while warps 2-3 handle accumulation for second half [n/2, n)
        self.warps_in_n = 2
        self.num_compute_warps = 4
        self.threads_per_warp = 32
        # QK tile K extent; the latent part is consumed in latent_dim // mma_qk_tiler[2] steps
        # and the rope part in a single step with its own (rope_dim-deep) tiler.
        mma_qk_tiler_k = self.rope_dim * 2
        self.mma_qk_tiler = (
            self.mma_qk_tiler_mn[0],
            self.mma_qk_tiler_mn[1],
            mma_qk_tiler_k,
        )
        self.mma_qk_rope_tiler = (
            self.mma_qk_tiler_mn[0],
            self.mma_qk_tiler_mn[1],
            self.rope_dim,
        )
        self.mma_pv_tiler = (
            self.mma_pv_tiler_mn[0],
            self.mma_pv_tiler_mn[1],
            self.mma_qk_tiler[1] * self.mma_qk_tiler[2] // self.mma_pv_tiler_mn[1],
        )
        self.iterations_qk_latent = self.latent_dim // self.mma_qk_tiler[2]
        self.iterations_qk_rope = 1
        self.iterations_qk = self.iterations_qk_latent + self.iterations_qk_rope
        # Number of PV k-blocks needed to cover one QK tile's N extent, and number of
        # PV n-blocks needed to cover the full latent (output) dimension.
        self.iterations_pv_k = self.mma_qk_tiler[1] // self.mma_pv_tiler[2]
        self.iterations_pv_n = self.latent_dim // self.mma_pv_tiler[1]

        # Set specialized warp ids
        # 12 warps per CTA: compute/softmax 0-3, correction 4-7, MMA 8,
        # K/Q TMA load 9, V TMA load 10, and one unused warp 11.
        self.compute_warp_ids = (0, 1, 2, 3)
        self.correction_warp_ids = (4, 5, 6, 7)
        self.mma_warp_id = 8
        self.load_tma_k_warp_id = 9
        self.load_tma_v_warp_id = 10
        self.empty_warp_ids = (11,)
        self.threads_per_cta = self.threads_per_warp * len(
            (
                self.mma_warp_id,
                self.load_tma_k_warp_id,
                self.load_tma_v_warp_id,
                *self.compute_warp_ids,
                *self.correction_warp_ids,
                *self.empty_warp_ids,
            )
        )

        # register settings
        # Per-thread register caps set via setmaxregister in the kernel. With 4 warps per
        # role this totals 128*192 + 128*256 + 128*48 = 63488 registers -- presumably sized
        # to fit one CTA per SM (the launch uses min_blocks_per_mp=1); confirm if tuning.
        self.softmax_reg_num = 192
        self.correction_reg_num = 256
        self.other_reg_num = 48
        # Named barriers
        # Barrier used by the TMEM allocator to publish the TMEM pointer: one warp plus the
        # compute and correction warp groups (num_compute_warps * 2 warps). In split_kv_kernel
        # the allocator warp is mma_warp_id.
        self.tmem_ptr_sync_bar = pipeline.NamedBarrier(
            barrier_id=1,
            num_threads=(
                self.threads_per_warp + self.threads_per_warp * self.num_compute_warps * 2
            ),
        )
        self.softmax_exchange_sync_bar = pipeline.NamedBarrier(
            barrier_id=2, num_threads=(self.threads_per_warp * self.num_compute_warps)
        )
        self.epilogue_exchange_sync_bar = pipeline.NamedBarrier(
            barrier_id=3, num_threads=(self.threads_per_warp * self.num_compute_warps)
        )

    def _setup_attributes(self):
        """Set pipeline stage counts and tensor-memory (TMEM) column offsets for the MLA kernel.

        ``__call__`` invokes this before building shared-memory layouts and the shared-storage
        struct, because both read the stage counts set here.

        - ``load_q_stage`` / ``load_k_stage`` / ``load_v_stage``: smem buffering depth of the
          Q, K (latent + rope) and V (transposed latent) TMA load pipelines.
        - ``mma_s_stage`` / ``p_mma_stage`` / ``p_cor_stage`` / ``mma_o_stage``: depths of the
          S, P, correction and O pipelines between the MMA, compute and correction warps.
        - ``tmem_o_offset``: TMEM column offset computed as
          ``mma_s_stage * mma_qk_tiler[1] // warps_in_n`` (presumably where the O accumulator
          starts, after the S buffers).
        - ``correction_factor_offset``: ``tmem_o_offset + latent_dim // warps_in_n``.
        """
        self.load_q_stage = 1
        self.load_k_stage = 3
        self.load_v_stage = 2
        self.mma_s_stage = 2
        self.p_mma_stage = 2
        self.p_cor_stage = 2
        self.mma_o_stage = 2

        self.tmem_o_offset = self.mma_s_stage * self.mma_qk_tiler[1] // self.warps_in_n
        self.correction_factor_offset = (
            self.tmem_o_offset + self.latent_dim // self.warps_in_n
        )

    @cute.jit
    def __call__(
        self,
        q_latent: cute.Tensor,
        q_rope: cute.Tensor,
        c_latent: cute.Tensor,
        c_rope: cute.Tensor,
        page_table: cute.Tensor,
        o: cute.Tensor,
        lse: cute.Tensor,
        workspace: cute.Tensor,
        split_kv: cutlass.Int32,
        cache_seqs: Optional[cute.Tensor],
        block_split_kvs: Optional[cute.Tensor],
        softmax_scale: cutlass.Float32,
        output_scale: cutlass.Float32,
        stream: cuda.CUstream,
    ):
        """Execute the Multi-Head Latent Attention operation on the provided tensors.

        The method handles:

        1. Initialization of the workspace for temporary split KV buffers
        2. Validation of tensor data types and of unit-stride (leading) modes
        3. Initialization of hardware-specific parameters and memory layouts
        4. Configuration of TMA (Tensor Memory Accelerator) load operations
        5. Grid and work scheduling computation
        6. Kernel launch: the split KV kernel, followed by the reduction kernel only when
           ``initialize_workspace`` returns a non-None ``acc_o``

        :param q_latent: The query tensor with shape [num_head, latent_dim, seq_len_q, batch_size]
        :type q_latent: cute.Tensor
        :param q_rope: The query RoPE tensor with shape [num_head, rope_dim, seq_len_q, batch_size]
        :type q_rope: cute.Tensor
        :param c_latent: The key tensor with shape [seq_len_k, latent_dim, batch_size]; a view
            with modes 0 and 1 swapped is also used as the V operand
        :type c_latent: cute.Tensor
        :param c_rope: The key RoPE tensor with shape [seq_len_k, rope_dim, batch_size]
        :type c_rope: cute.Tensor
        :param page_table: The page table tensor with shape [page_count, batch_size]
        :type page_table: cute.Tensor
        :param o: The output tensor with shape [num_head, latent_dim, seq_len_q, batch_size]
        :type o: cute.Tensor
        :param lse: The LSE tensor with shape [num_head, seq_len_q, batch_size]
        :type lse: cute.Tensor
        :param workspace: The workspace tensor with 1-d shape prepared for acc_o and acc_lse
        :type workspace: cute.Tensor
        :param split_kv: The scalar factor for split KV
        :type split_kv: cutlass.Int32
        :param cache_seqs: The cache sequences tensor with shape [batch_size], or None
        :type cache_seqs: Optional[cute.Tensor]
        :param block_split_kvs: The block split KV tensor with shape [batch_size], or None
        :type block_split_kvs: Optional[cute.Tensor]
        :param softmax_scale: The scale factor for softmax; it is multiplied by LOG2_E before
            being passed to the split KV kernel
        :type softmax_scale: cutlass.Float32
        :param output_scale: The scale factor for the output
        :type output_scale: cutlass.Float32
        :param stream: The CUDA stream to execute the kernel on
        :type stream: cuda.CUstream

        :raises TypeError: If the query and key/value (``c_latent``) data types differ
        :raises ValueError: If q_latent, q_rope, c_latent, c_rope or o is not unit-stride in
            mode 1, or lse is not unit-stride in mode 0
        """
        # setup static attributes before smem/grid/tma computation
        self.q_dtype = q_latent.element_type
        self.k_dtype = c_latent.element_type
        self.v_dtype = c_latent.element_type
        self.o_dtype = o.element_type

        # check type consistency
        if cutlass.const_expr(
            self.q_dtype != self.k_dtype or self.q_dtype != self.v_dtype
        ):
            raise TypeError(
                f"Type mismatch: {self.q_dtype} != {self.k_dtype} or {self.q_dtype} != {self.v_dtype}"
            )
        # check leading dimensions of input/output
        if cutlass.const_expr(q_latent.stride[1] != 1 or q_rope.stride[1] != 1):
            raise ValueError("q_latent and q_rope must have leading dimension 1")
        if cutlass.const_expr(c_latent.stride[1] != 1 or c_rope.stride[1] != 1):
            raise ValueError("c_latent and c_rope must have leading dimension 1")
        if cutlass.const_expr(o.stride[1] != 1):
            raise ValueError("o must have leading dimension 1")
        if cutlass.const_expr(lse.stride[0] != 1):
            raise ValueError("lse must have leading dimension 0")

        acc_o, acc_lse = self.initialize_workspace(
            q_latent.shape[0],
            q_latent.shape[1],
            q_latent.shape[2],
            q_latent.shape[3],
            split_kv,
            self.acc_dtype,
            workspace,
        )

        # View of c_latent with modes 0 and 1 swapped (latent dim first). It is loaded as the
        # MN-major V operand of P*V, so the same buffer serves as both K and V.
        c_latent_tranpose_layout = cute.select(c_latent.layout, mode=[1, 0, 2])
        c_latent_transpose = cute.make_tensor(
            c_latent.iterator, c_latent_tranpose_layout
        )

        self.q_major_mode = OperandMajorMode.K
        self.k_major_mode = OperandMajorMode.K
        self.v_major_mode = OperandMajorMode.MN

        self._setup_attributes()

        cta_group = tcgen05.CtaGroup.TWO
        # the intermediate tensor p is from smem & k-major
        p_major_mode = OperandMajorMode.K
        qk_tiled_mma = sm100_utils.make_trivial_tiled_mma(
            self.q_dtype,
            self.q_major_mode,
            self.k_major_mode,
            self.acc_dtype,
            cta_group,
            self.mma_qk_tiler[:2],
        )
        pv_tiled_mma = sm100_utils.make_trivial_tiled_mma(
            self.v_dtype,
            p_major_mode,
            self.v_major_mode,
            self.acc_dtype,
            cta_group,
            self.mma_pv_tiler[:2],
        )

        cta_layout_vmnk = cute.tiled_divide(
            cute.make_layout(self.cluster_shape_mnk),
            (qk_tiled_mma.thr_id.shape,),
        )

        self.epi_tile = self.mma_pv_tiler[:2]

        # Q latent: iterations_qk_latent k-blocks per stage, regrouped so the last mode
        # indexes (k-block, stage).
        q_latent_smem_layout_staged = sm100_utils.make_smem_layout_a(
            qk_tiled_mma,
            self.mma_qk_tiler,
            self.q_dtype,
            (self.iterations_qk_latent * self.load_q_stage),
        )
        q_latent_smem_layout_staged = cute.logical_divide(
            q_latent_smem_layout_staged, (None, None, None, self.iterations_qk_latent)
        )
        q_rope_smem_layout_staged = sm100_utils.make_smem_layout_a(
            qk_tiled_mma,
            self.mma_qk_rope_tiler,
            self.q_dtype,
            self.load_q_stage,
        )

        kc_latent_smem_layout_staged = sm100_utils.make_smem_layout_b(
            qk_tiled_mma,
            self.mma_qk_tiler,
            self.k_dtype,
            (self.iterations_qk_latent * self.load_k_stage),
        )
        # Rows per K TMA box, clipped to page_size (presumably so a single TMA copy never
        # straddles two KV pages).
        # NOTE(review): this uses the MMA M extent (shape_mnk[0]) while
        # make_paged_tiled_tma_atom below derives the K box from the N extent
        # (mma_qk_tiler[1]); the two agree only when mma_qk_tiler_mn[0] == mma_qk_tiler_mn[1].
        kc_page_tile_size = min(
            self.page_size, qk_tiled_mma.op.shape_mnk[0] // qk_tiled_mma.thr_id.shape
        )
        kc_latent_smem_layout_staged = cute.logical_divide(
            kc_latent_smem_layout_staged, (None, None, None, self.iterations_qk_latent)
        )

        # Separate (page-tiled) view of the same K smem used only as the TMA destination.
        # NOTE(review): the row extent here is mma_qk_tiler[0] (M), see note above.
        kc_latent_smem_layout_for_tma = sm100_utils.make_smem_layout(
            OperandMajorMode.K,
            (self.mma_qk_tiler[0] // qk_tiled_mma.thr_id.shape, self.mma_qk_tiler[2]),
            self.k_dtype,
            (self.iterations_qk_latent * self.load_k_stage),
        )
        kc_latent_smem_layout_for_tma = cute.tiled_divide(
            kc_latent_smem_layout_for_tma, (kc_page_tile_size, self.mma_qk_tiler[2])
        )
        kc_latent_smem_layout_for_tma = cute.logical_divide(
            kc_latent_smem_layout_for_tma, (None, None, None, self.iterations_qk_latent)
        )

        kc_rope_smem_layout_staged = sm100_utils.make_smem_layout_b(
            qk_tiled_mma,
            self.mma_qk_rope_tiler,
            self.k_dtype,
            self.load_k_stage,
        )
        kc_rope_smem_layout_for_tma = sm100_utils.make_smem_layout(
            OperandMajorMode.K,
            (
                self.mma_qk_rope_tiler[0] // qk_tiled_mma.thr_id.shape,
                self.mma_qk_rope_tiler[2],
            ),
            self.k_dtype,
            (self.iterations_qk_rope * self.load_k_stage),
        )
        kc_rope_smem_layout_for_tma = cute.tiled_divide(
            kc_rope_smem_layout_for_tma, (kc_page_tile_size, self.mma_qk_rope_tiler[2])
        )

        # P (softmax probabilities) is written to smem and consumed as the A operand of P*V.
        p_smem_layout_staged = sm100_utils.make_smem_layout_a(
            pv_tiled_mma,
            self.mma_pv_tiler,
            self.q_dtype,
            (self.iterations_pv_k * self.p_mma_stage),
        )
        p_smem_layout_staged = cute.logical_divide(
            p_smem_layout_staged, (None, None, None, self.iterations_pv_k)
        )

        # V: iterations_pv_k * iterations_pv_n blocks per stage, regrouped so the last mode
        # indexes ((n-block, k-block), stage).
        vc_smem_layout_staged = sm100_utils.make_smem_layout_b(
            pv_tiled_mma,
            self.mma_pv_tiler,
            self.v_dtype,
            (self.iterations_pv_k * self.iterations_pv_n * self.load_v_stage),
        )
        vc_smem_layout_staged = cute.logical_divide(
            cute.logical_divide(
                vc_smem_layout_staged,
                (None, None, None, self.iterations_pv_k * self.iterations_pv_n),
            ),
            (None, None, None, (self.iterations_pv_n, None)),
        )
        vc_page_tile_size = min(self.page_size, self.mma_pv_tiler[2])
        vc_smem_layout_for_tma = sm100_utils.make_smem_layout(
            OperandMajorMode.MN,
            (self.mma_pv_tiler[1] // pv_tiled_mma.thr_id.shape, self.mma_pv_tiler[2]),
            self.v_dtype,
            (self.iterations_pv_k * self.iterations_pv_n * self.load_v_stage),
        )
        vc_smem_layout_for_tma = cute.tiled_divide(
            vc_smem_layout_for_tma,
            (
                pv_tiled_mma.op.shape_mnk[1] // pv_tiled_mma.thr_id.shape,
                vc_page_tile_size,
            ),
        )
        vc_smem_layout_for_tma = cute.logical_divide(
            cute.logical_divide(
                vc_smem_layout_for_tma,
                (None, None, None, self.iterations_pv_k * self.iterations_pv_n),
            ),
            (None, None, None, (self.iterations_pv_n, None)),
        )
        # TMA load for Q latent and rope
        tma_load_op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp(cta_group)

        q_smem_layout = cute.select(q_latent_smem_layout_staged, mode=[0, 1, 2])
        tma_atom_q_latent, tma_tensor_q_latent = cute.nvgpu.make_tiled_tma_atom_A(
            tma_load_op,
            q_latent,
            q_smem_layout,
            self.mma_qk_tiler,
            qk_tiled_mma,
            cta_layout_vmnk.shape,
        )
        q_rope_smem_layout = cute.select(q_rope_smem_layout_staged, mode=[0, 1, 2])
        tma_atom_q_rope, tma_tensor_q_rope = cute.nvgpu.make_tiled_tma_atom_A(
            tma_load_op,
            q_rope,
            q_rope_smem_layout,
            self.mma_qk_rope_tiler,
            qk_tiled_mma,
            cta_layout_vmnk.shape,
        )
        # TMA load for c latent and k rope
        # Paged loads use one page-clipped smem box (mode 0 of the *_for_tma layouts).
        kc_smem_layout = cute.select(kc_latent_smem_layout_for_tma, mode=[0])
        tma_atom_c_latent, tma_tensor_c_latent = self.make_paged_tiled_tma_atom(
            tma_load_op,
            c_latent,
            kc_smem_layout,
            (self.mma_qk_tiler[1], self.mma_qk_tiler[2]),
            qk_tiled_mma,
            is_k_load=True,
        )
        kc_rope_smem_layout = cute.select(kc_rope_smem_layout_for_tma, mode=[0])
        tma_atom_c_rope, tma_tensor_c_rope = self.make_paged_tiled_tma_atom(
            tma_load_op,
            c_rope,
            kc_rope_smem_layout,
            (self.mma_qk_rope_tiler[1], self.mma_qk_rope_tiler[2]),
            qk_tiled_mma,
            is_k_load=True,
        )
        # TMA load for c latent transpose
        vc_smem_layout = cute.select(vc_smem_layout_for_tma, mode=[0])
        tma_atom_c_latent_transpose, tma_tensor_c_latent_transpose = (
            self.make_paged_tiled_tma_atom(
                tma_load_op,
                c_latent_transpose,
                vc_smem_layout,
                (self.mma_pv_tiler[1], self.mma_pv_tiler[2]),
                pv_tiled_mma,
                is_k_load=False,
            )
        )

        # Expected TMA transaction bytes per pipeline stage. Each term is scaled by the number
        # of CTAs in the 2-CTA MMA (cute.size(thr_id.shape)), presumably because the load
        # pipeline's barrier accounts for the bytes landing in both CTAs of the pair.
        q_latent_copy_size = (
            cute.size_in_bytes(self.q_dtype, q_smem_layout)
            * cute.size(qk_tiled_mma.thr_id.shape)
            * self.iterations_qk_latent
        )
        q_rope_copy_size = (
            cute.size_in_bytes(self.q_dtype, q_rope_smem_layout)
            * cute.size(qk_tiled_mma.thr_id.shape)
            * self.iterations_qk_rope
        )
        kc_latent_copy_size = (
            cute.size_in_bytes(
                self.k_dtype,
                cute.select(kc_latent_smem_layout_staged, mode=[0, 1, 2]),
            )
            * cute.size(qk_tiled_mma.thr_id.shape)
            * self.iterations_qk_latent
        )
        kc_rope_copy_size = (
            cute.size_in_bytes(
                self.k_dtype,
                cute.select(kc_rope_smem_layout_staged, mode=[0, 1, 2]),
            )
            * cute.size(qk_tiled_mma.thr_id.shape)
            * self.iterations_qk_rope
        )
        vc_copy_size = (
            cute.size_in_bytes(
                self.v_dtype, cute.select(vc_smem_layout_staged, mode=[0, 1, 2])
            )
            * cute.size(pv_tiled_mma.thr_id.shape)
            * self.iterations_pv_n
            * self.iterations_pv_k
        )

        self.tma_copy_q_bytes = q_latent_copy_size + q_rope_copy_size
        self.tma_copy_kc_bytes = kc_latent_copy_size + kc_rope_copy_size
        self.tma_copy_vc_bytes = vc_copy_size

        tile_sched_params, grid = self._compute_grid(
            o,
            split_kv,
            self.cluster_shape_mnk,
            self.max_active_clusters,
            self.is_persistent,
        )

        @cute.struct
        class SplitKVKernelSharedStorage:
            # Pipeline barriers
            # (2 mbarriers per stage -- presumably one "full" and one "empty" barrier each)
            load_q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.load_q_stage * 2]
            load_k_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.load_k_stage * 2]
            load_v_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.load_v_stage * 2]
            mma_s_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.mma_s_stage * 2]
            p_mma_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.p_mma_stage * 2]
            p_cor_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.p_cor_stage * 2]
            mma_o_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.mma_o_stage * 2]

            # Smem tensors (each aligned to 1024 bytes)
            smem_p: cute.struct.Align[
                cute.struct.MemRange[self.q_dtype, cute.cosize(p_smem_layout_staged)],
                1024,
            ]
            smem_kc_latent: cute.struct.Align[
                cute.struct.MemRange[
                    self.k_dtype, cute.cosize(kc_latent_smem_layout_staged)
                ],
                1024,
            ]
            smem_kc_rope: cute.struct.Align[
                cute.struct.MemRange[
                    self.k_dtype, cute.cosize(kc_rope_smem_layout_staged)
                ],
                1024,
            ]
            smem_q_latent: cute.struct.Align[
                cute.struct.MemRange[
                    self.q_dtype, cute.cosize(q_latent_smem_layout_staged)
                ],
                1024,
            ]
            smem_q_rope: cute.struct.Align[
                cute.struct.MemRange[
                    self.q_dtype, cute.cosize(q_rope_smem_layout_staged)
                ],
                1024,
            ]
            smem_vc: cute.struct.Align[
                cute.struct.MemRange[self.v_dtype, cute.cosize(vc_smem_layout_staged)],
                1024,
            ]
            # One acc_dtype slot per compute thread for cross-warp exchanges.
            softmax_smem_exchange: cute.struct.MemRange[
                self.acc_dtype, self.num_compute_warps * self.threads_per_warp
            ]
            epilogue_smem_exchange: cute.struct.MemRange[
                self.acc_dtype, self.num_compute_warps * self.threads_per_warp
            ]

            # Tmem dealloc cluster barrier
            tmem_dealloc_mbar_ptr: cutlass.Int64

            # Tmem holding buffer
            tmem_holding_buf: cutlass.Int32

        # The kernel works in the log2 domain, so fold log2(e) into the softmax scale.
        softmax_scale_log2 = softmax_scale * LOG2_E
        self.split_kv_kernel(
            qk_tiled_mma,
            pv_tiled_mma,
            tma_atom_q_latent,
            tma_tensor_q_latent,
            tma_atom_q_rope,
            tma_tensor_q_rope,
            tma_atom_c_latent,
            tma_tensor_c_latent,
            tma_atom_c_rope,
            tma_tensor_c_rope,
            tma_atom_c_latent_transpose,
            tma_tensor_c_latent_transpose,
            page_table,
            o,
            lse,
            acc_o,
            acc_lse,
            split_kv,
            cache_seqs,
            block_split_kvs,
            softmax_scale_log2,
            output_scale,
            q_latent_smem_layout_staged,
            q_rope_smem_layout_staged,
            kc_latent_smem_layout_staged,
            kc_rope_smem_layout_staged,
            p_smem_layout_staged,
            vc_smem_layout_staged,
            kc_latent_smem_layout_for_tma,
            kc_rope_smem_layout_for_tma,
            vc_smem_layout_for_tma,
            cta_layout_vmnk,
            tile_sched_params,
            SplitKVKernelSharedStorage,
        ).launch(
            grid=grid,
            block=[self.threads_per_cta, 1, 1],
            cluster=self.cluster_shape_mnk,
            smem=SplitKVKernelSharedStorage.size_in_bytes(),
            stream=stream,
            min_blocks_per_mp=1,
        )
        # Combine the per-split partial results only when a split-KV workspace was created.
        # One CTA per (head, query position, batch); smem holds MAX_SPLITS acc_dtype values.
        if cutlass.const_expr(acc_o is not None):
            self.reduction_kernel(
                o,
                lse,
                acc_o,
                acc_lse,
                split_kv,
                cache_seqs,
                block_split_kvs,
            ).launch(
                grid=(q_latent.shape[0], q_latent.shape[2], q_latent.shape[3]),
                block=[self.threads_per_warp * self.num_compute_warps, 1, 1],
                smem=MAX_SPLITS * self.acc_dtype.width // 8,
                stream=stream,
                min_blocks_per_mp=1,
            )

    @cute.jit
    def make_paged_tiled_tma_atom(
        self,
        tma_load_op: cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp,
        gmem: cute.Tensor,
        smem_layout: cute.Layout,
        mma_tiler,
        tiled_mma: cute.TiledMma,
        is_k_load: bool,
    ):
        """Build a non-executable TMA load atom whose box is clipped to at most one KV page.

        The 2-D MMA tile ``mma_tiler`` is first split across the CTAs of the MMA: mode 0 is
        divided by ``tiled_mma.thr_id.shape`` and the first CTA's slice is kept. That per-CTA
        tile is then divided so one TMA box covers at most ``page_size`` elements along the
        paged (sequence) axis:

        - ``is_k_load=True``: box is ``(min(page_size, cta_mn), mma_tiler[1])``, i.e. the
          page axis is mode 0 (as for ``c_latent`` / ``c_rope``).
        - ``is_k_load=False``: box is ``(cta_mn, min(page_size, mma_tiler[1]))``, i.e. the
          page axis is mode 1 (as for the transposed ``c_latent`` view used for V).

        No multicast is used (``num_multicast=1``).

        :param tma_load_op: The G2S bulk tensor copy operation
        :param gmem: The global-memory tensor to load from
        :param smem_layout: The shared-memory layout of one TMA box
        :param mma_tiler: 2-D extent of the MMA tile for this operand
        :param tiled_mma: The tiled MMA; ``thr_id.shape`` is the number of CTAs sharing a tile
        :param is_k_load: Selects which mode of the tile is the paged axis (see above)
        :return: ``(copy_atom, tma_tensor)`` -- a ``cute.CopyAtom`` with a non-exec G2S trait,
            and the second result of the IR op (presumably the TMA coordinate tensor).
        """
        ident = cute.make_identity_layout(gmem.shape)
        g_tile = cute.composition(ident, mma_tiler)
        # Per-CTA share of mode 0 of the MMA tile.
        cta_mn = mma_tiler[0] // tiled_mma.thr_id.shape
        cta_v_map = cute.flat_divide(g_tile, (cta_mn,))
        # Keep (cta_mn, mode 1): drops the "which CTA" mode produced by the flat divide.
        cta_v_map = cute.select(cta_v_map, mode=[0, 2])
        page_tile_size = (
            min(self.page_size, cta_mn)
            if is_k_load
            else min(self.page_size, mma_tiler[1])
        )
        cta_v_map = cute.zipped_divide(
            cta_v_map,
            (page_tile_size, mma_tiler[1]) if is_k_load else (cta_mn, page_tile_size),
        )
        # Keep only the first page-sized box.
        cta_v_map = cute.select(cta_v_map, mode=[0])
        from cutlass._mlir.dialects import cute_nvgpu as _cute_nvgpu_ir

        res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_load(
            gmem.value,
            smem_layout.value,
            cta_v_map,
            tma_load_op._to_ir(),
            num_multicast=1,
        )
        return (
            cute.CopyAtom(
                tma_load_op, cpasync.CopyBulkTensorTileG2SNonExecTrait(res[0])
            ),
            res[1],
        )

    @cute.kernel
    def split_kv_kernel(
        self,
        tiled_mma_qk: cute.TiledMma,
        tiled_mma_pv: cute.TiledMma,
        tma_atom_q_latent: Optional[cute.CopyAtom],
        mQL: cute.Tensor,
        tma_atom_q_rope: Optional[cute.CopyAtom],
        mQR: cute.Tensor,
        tma_atom_c_latent: Optional[cute.CopyAtom],
        mCL: cute.Tensor,
        tma_atom_c_rope: Optional[cute.CopyAtom],
        mKR: cute.Tensor,
        tma_atom_c_latent_transpose: Optional[cute.CopyAtom],
        mCLT: cute.Tensor,
        mPT: cute.Tensor,
        mO: Optional[cute.Tensor],
        mLSE: Optional[cute.Tensor],
        mAccO: Optional[cute.Tensor],
        mAccLSE: Optional[cute.Tensor],
        split_kv: cutlass.Int32,
        cache_seqs: cute.Tensor,
        block_split_kvs: cute.Tensor,
        softmax_scale_log2: cutlass.Float32,
        output_scale: cutlass.Float32,
        q_latent_smem_layout_staged: cute.ComposedLayout,
        q_rope_smem_layout_staged: cute.ComposedLayout,
kc_latent_smem_layout_staged: cute.ComposedLayout, kc_rope_smem_layout_staged: cute.ComposedLayout, p_smem_layout_staged: cute.ComposedLayout, vc_smem_layout_staged: cute.ComposedLayout, kc_latent_smem_layout_for_tma: Optional[cute.ComposedLayout], kc_rope_smem_layout_for_tma: Optional[cute.ComposedLayout], vc_smem_layout_for_tma: Optional[cute.ComposedLayout], cta_layout_vmnk: cute.Layout, tile_sched_params: MLAStaticTileSchedulerParams, SharedStorage: cutlass.Constexpr, ): """The device split_kv kernel implementation of the Multi-Head Latent Attention. This kernel coordinates multiple specialized warps to perform different phases of the MLA computation: 1. Load warp: Loads Q/C latent/rope data from global memory to shared memory using TMA 2. MMA warp: Performs matrix multiplications (Q*K^T and P*V) 3. Compute warps: Compute softmax and do rescaling on accumulators, and store the intermediate/final results to global memory The kernel produces either intermediate or final results of the MLA computation based on the split_kv parameter. When split_kv is 1, the kernel generates the final results directly. Otherwise, it produces intermediate results that will later be combined by a reduction kernel. The kernel implements a complex pipeline with overlapping computation and memory operations, using tensor memory access (TMA) for efficient data loading, warp specialization for different computation phases. 
:param tiled_mma_qk: Tiled MMA for Q*K^T :type tiled_mma_qk: cute.TiledMma :param tiled_mma_pv: Tiled MMA for P*V :type tiled_mma_pv: cute.TiledMma :param tma_atom_q_latent: TMA copy atom for query latent tensor :type tma_atom_q_latent: cute.CopyAtom :param mQL: query latent tensor :type mQL: cute.Tensor :param tma_atom_q_rope: TMA copy atom for query rope tensor :type tma_atom_q_rope: cute.CopyAtom :param mKR: Compressed rope tensor :type mKR: cute.Tensor :param tma_atom_c_latent: TMA copy atom for c latent tensor :type tma_atom_c_latent: cute.CopyAtom :param mCL: Compressed latent tensor :type mCL: cute.Tensor :param tma_atom_c_rope: TMA copy atom for c rope tensor :type tma_atom_c_rope: cute.CopyAtom :param mCLT: Compressed latent transpose tensor :type mCLT: cute.Tensor :param mPT: Page table tensor :type mPT: cute.Tensor :param mO: Output tensor :type mO: cute.Tensor :param mLSE: Log-sum-exp tensor :type mLSE: cute.Tensor :param mAccO: Intermediate accumulator output tensor :type mAccO: cute.Tensor :param mAccLSE: Intermediate accumulator log-sum-exp tensor :type mAccLSE: cute.Tensor :param split_kv: The split_kv parameter :type split_kv: cutlass.Int32 :param cache_seqs: The variable sequence length tensor :type cache_seqs: cute.Tensor :param block_split_kvs: The per-block split_kv values tensor :type block_split_kvs: cute.Tensor :param softmax_scale_log2: The log2 scale factor for softmax :type softmax_scale_log2: cutlass.Float32 :param output_scale: The scale factor for the output :type output_scale: cutlass.Float32 :param q_latent_smem_layout_staged: Shared memory layout for query tensor :type q_latent_smem_layout_staged: cute.ComposedLayout :param q_rope_smem_layout_staged: Shared memory layout for query rope tensor :type q_rope_smem_layout_staged: cute.ComposedLayout :param kc_latent_smem_layout_staged: Shared memory layout for key tensor :type kc_latent_smem_layout_staged: cute.ComposedLayout :param kc_rope_smem_layout_staged: Shared memory layout for 
key rope tensor :type kc_rope_smem_layout_staged: cute.ComposedLayout :param p_smem_layout_staged: Shared memory layout for probability matrix :type p_smem_layout_staged: cute.ComposedLayout :param vc_smem_layout_staged: Shared memory layout for value tensor :type vc_smem_layout_staged: cute.ComposedLayout :param cta_layout_vmnk: Layout for compute threads :type cta_layout_vmnk: cute.Layout :param tile_sched_params: Scheduling parameters for work distribution :type tile_sched_params: MLAStaticTileSchedulerParams :param SharedStorage: Shared storage for the kernel :type SharedStorage: cutlass.Constexpr """ warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) tidx, _, _ = cute.arch.thread_idx() bidx, _, _ = cute.arch.block_idx() mma_tile_coord_v = bidx % cute.size(tiled_mma_qk.thr_id.shape) is_leader_cta = mma_tile_coord_v == 0 # Prefetch tma descriptor if warp_idx == self.mma_warp_id: cpasync.prefetch_descriptor(tma_atom_q_latent) cpasync.prefetch_descriptor(tma_atom_q_rope) cpasync.prefetch_descriptor(tma_atom_c_latent) cpasync.prefetch_descriptor(tma_atom_c_rope) cpasync.prefetch_descriptor(tma_atom_c_latent_transpose) # Alloc smem = utils.SmemAllocator() storage = smem.allocate(SharedStorage) # Tensor memory dealloc barrier init tmem = utils.TmemAllocator( storage.tmem_holding_buf, barrier_for_retrieve=self.tmem_ptr_sync_bar, allocator_warp_id=self.mma_warp_id, is_two_cta=self.use_2cta_instrs, two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, ) load_q_pipeline = self.make_and_init_load_qkv_pipeline( storage.load_q_mbar_ptr.data_ptr(), cta_layout_vmnk, self.load_q_stage, self.tma_copy_q_bytes, ) load_k_pipeline = self.make_and_init_load_qkv_pipeline( storage.load_k_mbar_ptr.data_ptr(), cta_layout_vmnk, self.load_k_stage, self.tma_copy_kc_bytes, ) load_v_pipeline = self.make_and_init_load_qkv_pipeline( storage.load_v_mbar_ptr.data_ptr(), cta_layout_vmnk, self.load_v_stage, self.tma_copy_vc_bytes, ) mma_s_pipeline = 
self.make_and_init_mma_s_pipeline( storage.mma_s_mbar_ptr.data_ptr(), cta_layout_vmnk ) p_mma_pipeline = self.make_and_init_p_mma_pipeline( storage.p_mma_mbar_ptr.data_ptr(), cta_layout_vmnk ) p_cor_pipeline = self.make_and_init_p_cor_pipeline( storage.p_cor_mbar_ptr.data_ptr() ) mma_o_pipeline = self.make_and_init_mma_o_pipeline( storage.mma_o_mbar_ptr.data_ptr(), cta_layout_vmnk ) # Cluster arrive after barrier init pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mnk, is_relaxed=True) # Generate smem tensor Q/KC/VC/exchange # (MMA, MMA_H, MMA_R, PIPE) sQ = storage.smem_q_latent.get_tensor( q_latent_smem_layout_staged.outer, swizzle=q_latent_smem_layout_staged.inner ) sQ_rope = storage.smem_q_rope.get_tensor( q_rope_smem_layout_staged.outer, swizzle=q_rope_smem_layout_staged.inner ) # (MMA, MMA_K, MMA_R, PIPE) sKC = storage.smem_kc_latent.get_tensor( kc_latent_smem_layout_staged.outer, swizzle=kc_latent_smem_layout_staged.inner, ) sKC_rope = storage.smem_kc_rope.get_tensor( kc_rope_smem_layout_staged.outer, swizzle=kc_rope_smem_layout_staged.inner ) sKC_for_tma = storage.smem_kc_latent.get_tensor( kc_latent_smem_layout_for_tma.outer, swizzle=kc_latent_smem_layout_for_tma.inner, ) sKC_rope_for_tma = storage.smem_kc_rope.get_tensor( kc_rope_smem_layout_for_tma.outer, swizzle=kc_rope_smem_layout_for_tma.inner ) # (MMA, MMA_D, MMA_K, PIPE) sVC = storage.smem_vc.get_tensor( vc_smem_layout_staged.outer, swizzle=vc_smem_layout_staged.inner ) sVC_for_tma = storage.smem_vc.get_tensor( vc_smem_layout_for_tma.outer, swizzle=vc_smem_layout_for_tma.inner ) # (MMA, MMA_H, MMA_K) sP = storage.smem_p.get_tensor( p_smem_layout_staged.outer, swizzle=p_smem_layout_staged.inner ) # (compute_threads,) softmax_smem_exchange = storage.softmax_smem_exchange.get_tensor( cute.make_layout(self.num_compute_warps * self.threads_per_warp) ) epilogue_smem_exchange = storage.epilogue_smem_exchange.get_tensor( cute.make_layout(self.num_compute_warps * self.threads_per_warp) ) # # Cluster 
wait before tensor memory alloc # pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mnk) # /////////////////////////////////////////////////////////////////////////////// # Load warps, including page table and data tensors # /////////////////////////////////////////////////////////////////////////////// if warp_idx >= self.empty_warp_ids[0] and warp_idx <= self.empty_warp_ids[-1]: cute.arch.setmaxregister_decrease(self.other_reg_num) if warp_idx == self.load_tma_k_warp_id: cute.arch.setmaxregister_decrease(self.other_reg_num) load_q_producer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Producer, self.load_q_stage ) load_k_producer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Producer, self.load_k_stage ) tile_sched = create_mla_static_tile_scheduler( tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() ) work_tile = tile_sched.initial_work_tile_info() while work_tile.is_valid_tile: blk_coord = work_tile.tile_idx k_index, k_tile_count, local_split_kv = self.get_k_tile_count( split_kv, cache_seqs, block_split_kvs, blk_coord, ) if k_tile_count > 0: # Construct fixed common/tma_qk/tma_pv params for load_tma tma_common_params = SimpleNamespace( blk_coord=blk_coord, local_split_kv=local_split_kv, load_q_pipeline=load_q_pipeline, load_k_pipeline=load_k_pipeline, load_v_pipeline=load_v_pipeline, mPT=mPT, ) tma_qk_params = SimpleNamespace( tiled_mma_qk=tiled_mma_qk, tma_atom_q_latent=tma_atom_q_latent, tma_atom_q_rope=tma_atom_q_rope, tma_atom_c_latent=tma_atom_c_latent, tma_atom_c_rope=tma_atom_c_rope, mQL=mQL, mQR=mQR, mCL=mCL, mKR=mKR, sQ=sQ, sQ_rope=sQ_rope, sKC=sKC_for_tma, sKC_rope=sKC_rope_for_tma, ) # Load tma load_q_producer_state, load_k_producer_state = self.load_tma_qk( tma_common_params, tma_qk_params, k_index, k_tile_count, load_q_producer_state, load_k_producer_state, ) tile_sched.advance_to_next_work() work_tile = tile_sched.get_current_work() load_q_pipeline.producer_tail(load_q_producer_state) 
load_k_pipeline.producer_tail(load_k_producer_state) if warp_idx == self.load_tma_v_warp_id: cute.arch.setmaxregister_decrease(self.other_reg_num) load_v_producer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Producer, self.load_v_stage ) tile_sched = create_mla_static_tile_scheduler( tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() ) work_tile = tile_sched.initial_work_tile_info() while work_tile.is_valid_tile: blk_coord = work_tile.tile_idx k_index, k_tile_count, local_split_kv = self.get_k_tile_count( split_kv, cache_seqs, block_split_kvs, blk_coord, ) if k_tile_count > 0: # Construct fixed common/tma_qk/tma_pv params for load_tma tma_common_params = SimpleNamespace( blk_coord=blk_coord, local_split_kv=local_split_kv, load_v_pipeline=load_v_pipeline, mPT=mPT, ) tma_pv_params = SimpleNamespace( tiled_mma_pv=tiled_mma_pv, tma_atom_c_latent_transpose=tma_atom_c_latent_transpose, mCLT=mCLT, sVC=sVC_for_tma, ) # Load tma load_v_producer_state = self.load_tma_v( tma_common_params, tma_pv_params, k_index, k_tile_count, load_v_producer_state, ) tile_sched.advance_to_next_work() work_tile = tile_sched.get_current_work() load_v_pipeline.producer_tail(load_v_producer_state) # /////////////////////////////////////////////////////////////////////////////// # MMA warp # /////////////////////////////////////////////////////////////////////////////// if warp_idx == self.mma_warp_id: cute.arch.setmaxregister_decrease(self.other_reg_num) # Alloc tensor memory buffer tmem.allocate(cute.arch.get_max_tmem_alloc_cols("sm_100")) tmem.wait_for_alloc() tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) load_q_consumer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Consumer, self.load_q_stage ) load_k_consumer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Consumer, self.load_k_stage ) load_v_consumer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Consumer, self.load_v_stage ) mma_s_producer_state = 
pipeline.make_pipeline_state( pipeline.PipelineUserType.Producer, self.mma_s_stage ) p_mma_consumer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Consumer, self.p_mma_stage ) mma_o_producer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Producer, self.mma_o_stage ) tile_sched = create_mla_static_tile_scheduler( tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() ) work_tile = tile_sched.initial_work_tile_info() while work_tile.is_valid_tile: blk_coord = work_tile.tile_idx k_index, k_tile_count, local_split_kv = self.get_k_tile_count( split_kv, cache_seqs, block_split_kvs, blk_coord ) if k_tile_count > 0: mma_common_params = SimpleNamespace( blk_coord=blk_coord, local_split_kv=local_split_kv, load_q_pipeline=load_q_pipeline, load_k_pipeline=load_k_pipeline, load_v_pipeline=load_v_pipeline, tmem_ptr=tmem_ptr, is_leader_cta=is_leader_cta, L=mCL.shape[1], ) mma_qk_params = SimpleNamespace( mma_s_pipeline=mma_s_pipeline, sQ=sQ, sQ_rope=sQ_rope, sKC=sKC, sKC_rope=sKC_rope, ) mma_pv_params = SimpleNamespace( p_mma_pipeline=p_mma_pipeline, mma_o_pipeline=mma_o_pipeline, sP=sP, sVC=sVC, ) ( tiled_mma_qk, tiled_mma_pv, load_q_consumer_state, load_k_consumer_state, load_v_consumer_state, mma_s_producer_state, p_mma_consumer_state, mma_o_producer_state, ) = self.mma( mma_common_params, mma_qk_params, mma_pv_params, k_tile_count, tiled_mma_qk, tiled_mma_pv, load_q_consumer_state, load_k_consumer_state, load_v_consumer_state, mma_s_producer_state, p_mma_consumer_state, mma_o_producer_state, ) tile_sched.advance_to_next_work() work_tile = tile_sched.get_current_work() mma_s_pipeline.producer_tail(mma_s_producer_state) mma_o_pipeline.producer_tail(mma_o_producer_state) tmem.relinquish_alloc_permit() tmem.free(tmem_ptr) # /////////////////////////////////////////////////////////////////////////////// # Compute warp # /////////////////////////////////////////////////////////////////////////////// if ( warp_idx >= self.compute_warp_ids[0] 
and warp_idx <= self.compute_warp_ids[-1] ): cute.arch.setmaxregister_increase(self.softmax_reg_num) mma_s_consumer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Consumer, self.mma_s_stage ) p_mma_producer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Producer, self.p_mma_stage ) p_cor_producer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Producer, self.p_cor_stage ) mma_o_consumer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Consumer, self.mma_o_stage ) tmem.wait_for_alloc() tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) tile_sched = create_mla_static_tile_scheduler( tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() ) work_tile = tile_sched.initial_work_tile_info() while work_tile.is_valid_tile: blk_coord = work_tile.tile_idx k_index, k_tile_count, local_split_kv = self.get_k_tile_count( split_kv, cache_seqs, block_split_kvs, blk_coord ) if k_tile_count > 0: compute_common_params = SimpleNamespace( blk_coord=blk_coord, split_kv=split_kv, local_split_kv=local_split_kv, smem_exchange=softmax_smem_exchange, mAccO=mAccO, mO=mO, K=cache_seqs[blk_coord[2]], L=mCL.shape[1], tmem_ptr=tmem_ptr, tidx=tidx, p_cor_pipeline=p_cor_pipeline, ) compute_softmax_params = SimpleNamespace( tiled_mma_qk=tiled_mma_qk, sP=sP, mma_s_pipeline=mma_s_pipeline, p_mma_pipeline=p_mma_pipeline, softmax_scale_log2=softmax_scale_log2, ) mma_s_consumer_state, p_mma_producer_state, p_cor_producer_state = ( self.compute( compute_common_params, compute_softmax_params, k_index=k_index, k_tile_count=k_tile_count, mma_s_consumer_state=mma_s_consumer_state, p_mma_producer_state=p_mma_producer_state, p_cor_producer_state=p_cor_producer_state, ) ) tile_sched.advance_to_next_work() work_tile = tile_sched.get_current_work() p_cor_pipeline.producer_tail(p_cor_producer_state) # /////////////////////////////////////////////////////////////////////////////// # Correction warp # 
/////////////////////////////////////////////////////////////////////////////// if ( warp_idx >= self.correction_warp_ids[0] and warp_idx <= self.correction_warp_ids[-1] ): cute.arch.setmaxregister_increase(self.correction_reg_num) p_cor_consumer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Consumer, self.p_cor_stage ) mma_o_consumer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Consumer, self.mma_o_stage ) # sync with mma warp before retrieving tmem ptr tmem.wait_for_alloc() tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) tile_sched = create_mla_static_tile_scheduler( tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() ) work_tile = tile_sched.initial_work_tile_info() while work_tile.is_valid_tile: blk_coord = work_tile.tile_idx k_index, k_tile_count, local_split_kv = self.get_k_tile_count( split_kv, cache_seqs, block_split_kvs, blk_coord ) if k_tile_count > 0: compute_common_params = SimpleNamespace( blk_coord=blk_coord, split_kv=split_kv, local_split_kv=local_split_kv, smem_exchange=epilogue_smem_exchange, mAccO=mAccO, mO=mO, K=cache_seqs[blk_coord[2]], L=mCL.shape[1], H=mQL.shape[0], tmem_ptr=tmem_ptr, tidx=tidx, tiled_mma_pv=tiled_mma_pv, p_cor_pipeline=p_cor_pipeline, mma_o_pipeline=mma_o_pipeline, ) compute_epilogue_params = SimpleNamespace( output_scale=output_scale, softmax_scale_log2=softmax_scale_log2, mAccLSE=mAccLSE, mLSE=mLSE, ) p_cor_consumer_state, mma_o_consumer_state = self.correction( compute_common_params, compute_epilogue_params, k_tile_count=k_tile_count, p_cor_consumer_state=p_cor_consumer_state, mma_o_consumer_state=mma_o_consumer_state, ) tile_sched.advance_to_next_work() work_tile = tile_sched.get_current_work() return @cute.kernel def reduction_kernel( self, mO: cute.Tensor, mLSE: cute.Tensor, mAccO: cute.Tensor, mAccLSE: cute.Tensor, split_kv: cutlass.Int32, cache_seqs: cute.Tensor, block_split_kvs: cute.Tensor, ): """The reduction kernel for Multi-Head Latent Attention (MLA) that combines 
intermediate results from multiple split_kv blocks into final outputs. :param mO: Output tensor for storing final results :type mO: cute.Tensor :param mLSE: Log-sum-exp tensor for storing final LSE values :type mLSE: cute.Tensor :param mAccO: Accumulated output tensor from split_kv blocks :type mAccO: cute.Tensor :param mAccLSE: Accumulated LSE tensor from split_kv blocks :type mAccLSE: cute.Tensor :param split_kv: Number of split_kv blocks :type split_kv: cutlass.Int32 :param cache_seqs: Cache sequence lengths tensor :type cache_seqs: cute.Tensor :param block_split_kvs: Per-block split_kv values tensor (for variable split_kv) :type block_split_kvs: cute.Tensor """ bidx, bidy, bidz = cute.arch.block_idx() tidx, _, _ = cute.arch.thread_idx() blk_coord = (bidx, bidy, bidz) local_split_kv = ( block_split_kvs[blk_coord[2]] if self.is_var_split_kv else split_kv ) k_tile_total = cute.ceil_div(cache_seqs[blk_coord[2]], self.mma_qk_tiler[1]) k_tile_per_cta = cute.ceil_div(k_tile_total, local_split_kv) local_split_kv = cute.ceil_div(k_tile_total, k_tile_per_cta) # Alloc shared memory smem = utils.SmemAllocator() storage = smem.allocate(MAX_SPLITS * self.acc_dtype.width // 8, 16) lse_scale_ptr = cute.recast_ptr(storage, dtype=self.acc_dtype) smem_lse_scale = cute.make_tensor(lse_scale_ptr, cute.make_layout(MAX_SPLITS)) gLSE = mAccLSE[blk_coord[0], None, blk_coord[1], blk_coord[2]] warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) if warp_idx == 0: # calculate the global lse and exp ^ (local_lse - global_lse) lse_per_thread = cute.ceil_div(MAX_SPLITS, self.threads_per_warp) local_lse = cute.make_rmem_tensor( cute.make_layout(lse_per_thread), self.lse_dtype ) lse_max = -self.lse_dtype.inf # find the max lse for i in cutlass.range_constexpr(lse_per_thread): split_kv_idx = tidx + i * self.threads_per_warp local_lse[i] = ( gLSE[split_kv_idx] if cute.elem_less(split_kv_idx, local_split_kv) else -self.lse_dtype.inf ) # reduce the local lse lse_max = 
cute.arch.fmax(lse_max, local_lse[i]) lse_max = cute.arch.warp_reduction_max(lse_max) lse_max = lse_max if lse_max != -self.lse_dtype.inf else 0.0 # calculate sum_lse sum_lse = 0.0 for i in cutlass.range_constexpr(lse_per_thread): sum_lse += cute.math.exp2(local_lse[i] - lse_max, fastmath=True) sum_lse = cute.arch.warp_reduction_sum(sum_lse) # calculate the global_lse global_lse = ( lse_max + cute.math.log2(sum_lse, fastmath=True) if not sum_lse == self.lse_dtype(0.0) or sum_lse != sum_lse else self.lse_dtype.inf ) if tidx == 0: mLSE[blk_coord[0], blk_coord[1], blk_coord[2]] = global_lse # store the scale to shared memory for i in cutlass.range_constexpr(lse_per_thread): split_kv_idx = tidx + i * self.threads_per_warp if cute.elem_less(split_kv_idx, local_split_kv): smem_lse_scale[split_kv_idx] = cute.math.exp2( local_lse[i] - global_lse, fastmath=True ) pipeline.sync(barrier_id=4) elements_per_thread = cute.ceil_div( self.latent_dim, self.threads_per_warp * self.num_compute_warps ) gAccO = mAccO[blk_coord[0], None, None, blk_coord[1], blk_coord[2]] rAccO = cute.make_rmem_tensor( cute.make_layout(elements_per_thread), self.acc_dtype ) rO = cute.make_rmem_tensor(cute.make_layout(elements_per_thread), self.o_dtype) rAccO.fill(0.0) for i in range(local_split_kv): for j in cutlass.range_constexpr(elements_per_thread): element_idx = tidx + j * self.threads_per_warp * self.num_compute_warps rAccO[j] += gAccO[i, element_idx] * smem_lse_scale[i] rO.store(rAccO.load().to(self.o_dtype)) for j in cutlass.range_constexpr(elements_per_thread): element_idx = tidx + j * self.threads_per_warp * self.num_compute_warps mO[blk_coord[0], element_idx, blk_coord[1], blk_coord[2]] = rO[j] return @staticmethod def get_split_kv( B: int, S: int, K: int, mma_qk_tiler_mn: tuple, max_active_blocks: int ) -> int: """Get the proper split_kv value for the MLA kernel based on parameters. 
:param B: Batch size :type B: int :param S: Sequence length :type S: int :param K: Sequence length :type K: int :param mma_qk_tiler_mn: MLA tiling parameters :type mma_qk_tiler_mn: tuple :param max_active_blocks: Maximum number of active blocks :type max_active_blocks: int :return: Split_kv value :rtype: int """ max_splits = ceil_div(K, mma_qk_tiler_mn[1]) blocks_per_batch = max(1, max_active_blocks // B // (S * 2)) split_heur = min(max_splits, blocks_per_batch) k_waves = ceil_div(max_splits, split_heur) split_wave_aware = ceil_div(max_splits, k_waves) max_split_kv = 32 return min(split_wave_aware, max_split_kv) @cute.jit def get_k_tile_count( self, split_kv: cutlass.Int32, cache_seqs: cute.Tensor, block_split_kvs: cute.Tensor, blk_coord: cute.Coord, ) -> tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32]: """Get the current k_index, k_tile_count, and local split_kv value for the MLA kernel. :param split_kv: Split_kv value :type split_kv: cutlass.Int32 :param cache_seqs: Cache sequence lengths tensor :type cache_seqs: cute.Tensor :param block_split_kvs: Per-block split_kv values tensor :type block_split_kvs: cute.Tensor :param blk_coord: Block coordinate :type blk_coord: cute.Coord :return: k_index, k_tile_count, split_kv :rtype: tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32] """ K = cache_seqs[blk_coord[2]] if cutlass.const_expr(self.is_var_split_kv): split_kv = block_split_kvs[blk_coord[2]] k_tile_total = cute.ceil_div(K, self.mma_qk_tiler[1]) k_tile_per_cta = cute.ceil_div(k_tile_total, split_kv) k_index = blk_coord[3] * k_tile_per_cta k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index) return k_index, k_tile_count, split_kv @cute.jit def load_tma_qk( self, common_params: SimpleNamespace, qk_params: SimpleNamespace, k_index: cutlass.Int32, k_tile_count: cutlass.Int32, load_q_producer_state: pipeline.PipelineState | None = None, load_k_producer_state: pipeline.PipelineState | None = None, ) -> tuple[pipeline.PipelineState, 
pipeline.PipelineState]: """Load wrap to load Q/K tensors. Updates the load qk producer state. :param common_params: The common parameters :type common_params: SimpleNamespace :param qk_params: The qk parameters :type qk_params: SimpleNamespace :param k_index: The k index :type k_index: cutlass.Int32 :param k_tile_count: The k tile count :type k_tile_count: cutlass.Int32 :param load_q_producer_state: The load q producer state :type load_q_producer_state: pipeline.PipelineState :param load_k_producer_state: The load k producer state :type load_k_producer_state: pipeline.PipelineState :return: The load q producer state and load k producer state :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState] """ # page table mPT = common_params.mPT[None, common_params.blk_coord[2]] # Flatten divide and partition global tensors for QK TMA load # (bM, bK, rM, rK, rL) mma_qk_tiler_mk = cute.select(self.mma_qk_tiler, mode=[0, 2]) gQL = cute.flat_divide(qk_params.mQL, mma_qk_tiler_mk) mma_qk_tiler_mk_rope = cute.select(self.mma_qk_rope_tiler, mode=[0, 2]) gQR = cute.flat_divide(qk_params.mQR, mma_qk_tiler_mk_rope) thr_mma_qk = qk_params.tiled_mma_qk.get_slice( common_params.blk_coord[0] % cute.size(qk_params.tiled_mma_qk.thr_id) ) tSgQL = thr_mma_qk.partition_A(gQL) tSgQR = thr_mma_qk.partition_A(gQR) cta_m = min( qk_params.tiled_mma_qk.op.shape_mnk[0] // qk_params.tiled_mma_qk.thr_id.shape, self.page_size, ) page_tile_size = min(self.page_size, cta_m) gCL = cute.tiled_divide(qk_params.mCL, (page_tile_size, self.mma_qk_tiler[2])) tSgCL = ( gCL[ None, common_params.blk_coord[0] % qk_params.tiled_mma_qk.thr_id.shape, None, None, ] if cta_m < self.page_size else gCL[None, 0, None, None] ) gKR = cute.tiled_divide( qk_params.mKR, (page_tile_size, self.mma_qk_rope_tiler[2]) ) tSgKR = ( gKR[ None, common_params.blk_coord[0] % qk_params.tiled_mma_qk.thr_id.shape, None, None, ] if cta_m < self.page_size else gKR[None, 0, None, None] ) # tma partition for q, k latent/rope # smem: 
((atom_v, rest_v), STAGE) # gmem: ((atom_v, rest_v), RestM, RestK, RestL) tQsQ, tQLgQL_mkl = cpasync.tma_partition( qk_params.tma_atom_q_latent, 0, cute.make_layout(1), cute.group_modes(qk_params.sQ, 0, 3), cute.group_modes(tSgQL, 0, 3), ) tQsQ_rope, tQRgQR_mkl = cpasync.tma_partition( qk_params.tma_atom_q_rope, 0, cute.make_layout(1), cute.group_modes(qk_params.sQ_rope, 0, 3), cute.group_modes(tSgQR, 0, 3), ) tKCsKC, tCLgCL = cpasync.tma_partition( qk_params.tma_atom_c_latent, 0, cute.make_layout(1), qk_params.sKC, tSgCL, ) tKCsKC_rope, tKRgKR = cpasync.tma_partition( qk_params.tma_atom_c_rope, 0, cute.make_layout(1), qk_params.sKC_rope, tSgKR, ) tQLgQL = tQLgQL_mkl[ None, None, None, common_params.blk_coord[1], common_params.blk_coord[2] ] tQRgQR = tQRgQR_mkl[ None, None, None, common_params.blk_coord[1], common_params.blk_coord[2] ] # set extra params common_params.mPT = mPT qk_params.tQLgQL = tQLgQL qk_params.tQRgQR = tQRgQR qk_params.tCLgCL = tCLgCL qk_params.tKRgKR = tKRgKR qk_params.tQsQ = tQsQ qk_params.tQsQ_rope = tQsQ_rope qk_params.tKCsKC = tKCsKC qk_params.tKCsKC_rope = tKCsKC_rope k_tile_count_init = k_tile_count while k_tile_count > 0: load_q_producer_state, load_k_producer_state = self.load_tma_qk_one_k_tile( common_params, qk_params, k_index, k_tile_count, load_q_producer_state, load_k_producer_state, load_q=k_tile_count_init == k_tile_count, ) k_index += 1 k_tile_count -= 1 return load_q_producer_state, load_k_producer_state @cute.jit def load_tma_v( self, common_params: SimpleNamespace, v_params: SimpleNamespace, k_index: cutlass.Int32, k_tile_count: cutlass.Int32, load_v_producer_state: pipeline.PipelineState, ) -> pipeline.PipelineState: """Load wrap to load V tensors. Updates the load v producer state. 
:param common_params: The common parameters :type common_params: SimpleNamespace :param v_params: The v parameters :type v_params: SimpleNamespace :param k_index: The k index :type k_index: cutlass.Int32 :param k_tile_count: The k tile count :type k_tile_count: cutlass.Int32 :param load_v_producer_state: The load v producer state :type load_v_producer_state: pipeline.PipelineState :return: The load v producer state :rtype: pipeline.PipelineState """ # page table mPT = common_params.mPT[None, common_params.blk_coord[2]] # Flatten divide and partition global tensors for V TMA load page_tile_size = min(self.page_size, self.mma_pv_tiler[2]) gCLT = cute.flat_divide(v_params.mCLT, (self.mma_pv_tiler[1], page_tile_size)) cta_n = self.mma_pv_tiler[1] // v_params.tiled_mma_pv.thr_id.shape gCLT = cute.logical_divide(gCLT, (cta_n,))[ (None, common_params.blk_coord[0]), None, None, None, None ] tOgCLT = cute.tiled_divide(gCLT, (cta_n, page_tile_size)) tOgCLT = tOgCLT[None, 0, 0, None, None, None] # tma partition for vc # smem: ((atom_v, rest_v), STAGE) # gmem: ((atom_v, rest_v), RestM, RestK, RestL) tVCsVC, tCLTgCLT = cpasync.tma_partition( v_params.tma_atom_c_latent_transpose, 0, cute.make_layout(1), v_params.sVC, tOgCLT, ) # set extra params common_params.mPT = mPT v_params.tCLTgCLT = tCLTgCLT v_params.tVCsVC = tVCsVC while k_tile_count > 0: load_v_producer_state = self.load_tma_v_one_k_tile( common_params, v_params, k_index, load_v_producer_state, ) k_index += 1 k_tile_count -= 1 return load_v_producer_state @cute.jit def load_tma_qk_one_k_tile( self, common_params: SimpleNamespace, qk_params: SimpleNamespace, k_index: cutlass.Int32, k_tile_count: cutlass.Int32, load_q_producer_state: pipeline.PipelineState, load_k_producer_state: pipeline.PipelineState, load_q: bool, ) -> tuple[pipeline.PipelineState, pipeline.PipelineState]: """Load one k-tile of Q/C latent/rope tensors. Updates the load qkv producer state. 
:param common_params: The common parameters :type common_params: SimpleNamespace :param qk_params: The qk parameters :type qk_params: SimpleNamespace :param k_index: The k index :type k_index: cutlass.Int32 :param k_tile_count: The k tile count :type k_tile_count: cutlass.Int32 :param load_q_producer_state: The load q producer state :type load_q_producer_state: pipeline.PipelineState :param load_k_producer_state: The load kv producer state :type load_k_producer_state: pipeline.PipelineState :param load_q: Whether to load q :type load_q: bool :return: The load q producer state and load kv producer state :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState] """ page_per_tile = ceil_div( self.mma_qk_tiler[1] // self.page_size, qk_params.tiled_mma_qk.thr_id.shape ) k_idx = cute.make_rmem_tensor(cute.make_layout(page_per_tile), cutlass.Int32) for i in cutlass.range_constexpr(page_per_tile): k_idx[i] = ( common_params.mPT[k_index] if self.mma_qk_tiler[1] // self.page_size == 1 else common_params.mPT[ ( k_index * qk_params.tiled_mma_qk.thr_id.shape + common_params.blk_coord[0] ) * page_per_tile + i ] ) # load q once at first iteration load_q_pipeline = common_params.load_q_pipeline if load_q: # get the mbar ptr from pipeline. tma_bar_ptr = load_q_pipeline.producer_get_barrier(load_q_producer_state) # expect the extra bytes for q. load_q_pipeline.producer_acquire(load_q_producer_state) for i in cutlass.range_constexpr(self.iterations_qk_latent): # load q latent cute.copy( qk_params.tma_atom_q_latent, qk_params.tQLgQL[None, 0, i], qk_params.tQsQ[None, (i, 0)], tma_bar_ptr=tma_bar_ptr, ) for i in cutlass.range_constexpr(self.iterations_qk_rope): # load q rope cute.copy( qk_params.tma_atom_q_rope, qk_params.tQRgQR[None, 0, i], qk_params.tQsQ_rope[None, i], tma_bar_ptr=tma_bar_ptr, ) load_q_producer_state.advance() # get the mbar ptr from pipeline. 
tma_bar_ptr = common_params.load_k_pipeline.producer_get_barrier( load_k_producer_state ) common_params.load_k_pipeline.producer_acquire(load_k_producer_state) for i in range(self.iterations_qk_latent): for k in range(page_per_tile): # load k latent cute.copy( qk_params.tma_atom_c_latent, qk_params.tCLgCL[None, i, k_idx[k]], qk_params.tKCsKC[None, k, 0, (i, load_k_producer_state.index)], tma_bar_ptr=tma_bar_ptr, ) for i in cutlass.range_constexpr(self.iterations_qk_rope): for k in cutlass.range_constexpr(page_per_tile): # load k rope cute.copy( qk_params.tma_atom_c_rope, qk_params.tKRgKR[None, i, k_idx[k]], qk_params.tKCsKC_rope[None, k, 0, load_k_producer_state.index], tma_bar_ptr=tma_bar_ptr, ) load_k_producer_state.advance() return load_q_producer_state, load_k_producer_state @cute.jit def load_tma_v_one_k_tile( self, common_params: SimpleNamespace, v_params: SimpleNamespace, k_index: cutlass.Int32, load_v_producer_state: pipeline.PipelineState, ) -> pipeline.PipelineState: """Load one k-tile of compressed latent transpose tensor(v). Updates the load qkv producer state. :param common_params: The common parameters :type common_params: SimpleNamespace :param v_params: The load tma v parameters :type v_params: SimpleNamespace :param k_index: The k index :type k_index: cutlass.Int32 :param load_v_producer_state: The load v producer state :type load_v_producer_state: pipeline.PipelineState :return: The load qkv producer state :rtype: pipeline.PipelineState """ page_per_tile = self.mma_pv_tiler[2] * self.iterations_pv_k // self.page_size page_per_subtile = ceil_div(page_per_tile, self.iterations_pv_k) k_idx = cute.make_rmem_tensor(cute.make_layout(page_per_tile), cutlass.Int32) for i in cutlass.range_constexpr(page_per_tile): k_idx[i] = ( common_params.mPT[k_index] if page_per_tile == 1 else common_params.mPT[k_index * page_per_tile + i] ) # get the mbar ptr from pipeline. 
tma_bar_ptr = common_params.load_v_pipeline.producer_get_barrier( load_v_producer_state ) common_params.load_v_pipeline.producer_acquire(load_v_producer_state) for j in cutlass.range_constexpr(self.iterations_pv_n): for i in cutlass.range_constexpr(self.iterations_pv_k): if cutlass.const_expr(page_per_tile > 1): for k in cutlass.range_constexpr(page_per_subtile): k_idx_i = k_idx[k + i * page_per_subtile] cute.copy( v_params.tma_atom_c_latent_transpose, v_params.tCLTgCLT[None, j, 0, k_idx_i], v_params.tVCsVC[ None, 0, k, ((j, i), load_v_producer_state.index) ], tma_bar_ptr=tma_bar_ptr, ) else: cute.copy( v_params.tma_atom_c_latent_transpose, v_params.tCLTgCLT[None, j, i, k_idx[0]], v_params.tVCsVC[ None, 0, 0, ((j, i), load_v_producer_state.index) ], tma_bar_ptr=tma_bar_ptr, ) load_v_producer_state.advance() return load_v_producer_state @cute.jit def mma( self, common_params: SimpleNamespace, qk_params: SimpleNamespace, pv_params: SimpleNamespace, k_tile_count: cutlass.Int32, tiled_mma_qk: cute.TiledMma, tiled_mma_pv: cute.TiledMma, load_q_consumer_state: pipeline.PipelineState, load_k_consumer_state: pipeline.PipelineState, load_v_consumer_state: pipeline.PipelineState, mma_s_producer_state: pipeline.PipelineState, p_mma_consumer_state: pipeline.PipelineState, mma_o_producer_state: pipeline.PipelineState, ) -> tuple[ cute.TiledMma, cute.TiledMma, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, ]: """MMA warp to compute the result of Q*K^T and P*V. Updates the tiled mma and pipeline states. 
:param common_params: The common parameters for mma qk and pv :type common_params: SimpleNamespace :param qk_params: The mma qk parameters :type qk_params: SimpleNamespace :param pv_params: The mma pv parameters :type pv_params: SimpleNamespace :param k_tile_count: The k tile count :type k_tile_count: cutlass.Int32 :param tiled_mma_qk: The tiled mma qk :type tiled_mma_qk: cute.TiledMma :param tiled_mma_pv: The tiled mma pv :type tiled_mma_pv: cute.TiledMma :param load_q_consumer_state: The load q consumer state :type load_q_consumer_state: pipeline.PipelineState :param load_k_consumer_state: The load k consumer state :type load_k_consumer_state: pipeline.PipelineState :param load_v_consumer_state: The load v consumer state :type load_v_consumer_state: pipeline.PipelineState :param mma_s_producer_state: The mma s producer state :type mma_s_producer_state: pipeline.PipelineState :param p_mma_consumer_state: The p mma consumer state :type p_mma_consumer_state: pipeline.PipelineState :param mma_o_producer_state: The mma o producer state :type mma_o_producer_state: pipeline.PipelineState :return: The tiled mma qk, the tiled mma pv, the load q consumer state, the load k consumer state, the load v consumer state, the mma s producer state, the p mma consumer state, and the mma o producer state :rtype: tuple[cute.TiledMma, cute.TiledMma, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState] """ tSrQ = tiled_mma_qk.make_fragment_A(qk_params.sQ) tSrQ_rope = tiled_mma_qk.make_fragment_A(qk_params.sQ_rope) tSrKC = tiled_mma_qk.make_fragment_B(qk_params.sKC) tSrKC_rope = tiled_mma_qk.make_fragment_B(qk_params.sKC_rope) tOrP = tiled_mma_pv.make_fragment_A(pv_params.sP) tOrVC = tiled_mma_pv.make_fragment_B(pv_params.sVC) tStS_shape = tiled_mma_qk.partition_shape_C( cute.select(self.mma_qk_tiler, mode=[0, 1]) ) tStS_staged_fake = tiled_mma_qk.make_fragment_C( cute.append(tStS_shape, 
self.mma_s_stage) ) # use real tmem ptr for tStS tStS_staged = cute.make_tensor(common_params.tmem_ptr, tStS_staged_fake.layout) tOtO_shape = tiled_mma_pv.partition_shape_C( cute.select(self.mma_pv_tiler, mode=[0, 1]) ) # mma O has 1 stage. tOtO = tiled_mma_pv.make_fragment_C(tOtO_shape) tOtO_layout = cute.append( tOtO.layout, cute.make_layout( common_params.L // self.mma_pv_tiler[1], stride=self.mma_pv_tiler[1] // self.warps_in_n, ), ) tOtO_staged = cute.make_tensor( tStS_staged.iterator + self.tmem_o_offset, tOtO_layout ) # set more parameters qk_params.tSrQ = tSrQ qk_params.tSrQ_rope = tSrQ_rope qk_params.tSrKC = tSrKC qk_params.tSrKC_rope = tSrKC_rope qk_params.tStS_staged = tStS_staged pv_params.tOrP = tOrP pv_params.tOrVC = tOrVC pv_params.tOtO_staged = tOtO_staged # mma O accumulates on K, so the accumlate flag is set to False once before all K blocks. tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, False) load_q_pipeline = common_params.load_q_pipeline if common_params.is_leader_cta: load_q_release_state = load_q_consumer_state.clone() ( tiled_mma_qk, load_q_consumer_state, load_k_consumer_state, mma_s_producer_state, ) = self.mma_qk( common_params, qk_params, tiled_mma_qk, load_q_consumer_state, load_k_consumer_state, mma_s_producer_state, wait_q=True, ) k_tile_count -= 1 while k_tile_count > 0: ( tiled_mma_qk, load_q_consumer_state, load_k_consumer_state, mma_s_producer_state, ) = self.mma_qk( common_params, qk_params, tiled_mma_qk, load_q_consumer_state, load_k_consumer_state, mma_s_producer_state, wait_q=False, ) ( tiled_mma_pv, load_v_consumer_state, p_mma_consumer_state, mma_o_producer_state, ) = self.mma_pv( common_params, pv_params, tiled_mma_pv, load_v_consumer_state, p_mma_consumer_state, mma_o_producer_state, ) k_tile_count -= 1 # release q consumer states load_q_pipeline.consumer_release(load_q_release_state) load_q_release_state.advance() ( tiled_mma_pv, load_v_consumer_state, p_mma_consumer_state, mma_o_producer_state, ) = self.mma_pv( 
common_params, pv_params, tiled_mma_pv, load_v_consumer_state, p_mma_consumer_state, mma_o_producer_state, ) return ( tiled_mma_qk, tiled_mma_pv, load_q_consumer_state, load_k_consumer_state, load_v_consumer_state, mma_s_producer_state, p_mma_consumer_state, mma_o_producer_state, ) @cute.jit def mma_qk( self, common_params: SimpleNamespace, qk_params: SimpleNamespace, tiled_mma_qk: cute.TiledMma, load_q_consumer_state: pipeline.PipelineState, load_k_consumer_state: pipeline.PipelineState, mma_s_producer_state: pipeline.PipelineState, wait_q: bool, ) -> tuple[ cute.TiledMma, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, ]: """Compute one k-tile of mma for Q*K^T. Updates the tiled MMA QK and pipeline states. :param qk_params: The qk parameters :type qk_params: SimpleNamespace :param tiled_mma_qk: The tiled mma qk :type tiled_mma_qk: cute.TiledMma :param load_q_consumer_state: The load q consumer state :type load_q_consumer_state: pipeline.PipelineState :param load_k_consumer_state: The load k consumer state :type load_k_consumer_state: pipeline.PipelineState :param mma_s_producer_state: The mma s producer state :type mma_s_producer_state: pipeline.PipelineState :return: The tiled mma qk, the load q consumer state, the load k consumer state, and the mma s producer state :rtype: tuple[cute.TiledMma, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState] """ tStS = qk_params.tStS_staged[None, None, None, mma_s_producer_state.index] qk_params.mma_s_pipeline.producer_acquire(mma_s_producer_state) tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, False) load_q_pipeline = common_params.load_q_pipeline load_k_pipeline = common_params.load_k_pipeline if cutlass.const_expr(wait_q): load_q_pipeline.consumer_wait(load_q_consumer_state) load_k_pipeline.consumer_wait(load_k_consumer_state) for q_stage in range(self.iterations_qk_latent): kc_stage = load_k_consumer_state.index for k_block in 
cutlass.range_constexpr(cute.size(qk_params.tSrQ.shape[2])): cute.gemm( tiled_mma_qk, tStS, qk_params.tSrQ[None, None, k_block, (q_stage, 0)], qk_params.tSrKC[None, None, k_block, (q_stage, kc_stage)], tStS, ) tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, True) for q_stage in range(self.iterations_qk_rope): kc_stage = load_k_consumer_state.index for k_block in cutlass.range_constexpr( self.rope_dim // tiled_mma_qk.shape_mnk[2] ): cute.gemm( tiled_mma_qk, tStS, qk_params.tSrQ_rope[None, None, k_block, q_stage], qk_params.tSrKC_rope[None, None, k_block, kc_stage], tStS, ) tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, True) load_k_pipeline.consumer_release(load_k_consumer_state) load_k_consumer_state.advance() if cutlass.const_expr(wait_q): load_q_consumer_state.advance() qk_params.mma_s_pipeline.producer_commit(mma_s_producer_state) mma_s_producer_state.advance() return ( tiled_mma_qk, load_q_consumer_state, load_k_consumer_state, mma_s_producer_state, ) @cute.jit def mma_pv( self, common_params: SimpleNamespace, pv_params: SimpleNamespace, tiled_mma_pv: cute.TiledMma, load_v_consumer_state: pipeline.PipelineState, p_mma_consumer_state: pipeline.PipelineState, mma_o_producer_state: pipeline.PipelineState, ) -> tuple[ cute.TiledMma, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, ]: """Compute one k-tile of mma for P*V. Updates the tiled mma pv and pipeline states. 
:param common_params: The common parameters :type common_params: SimpleNamespace :param pv_params: The pv parameters :type pv_params: SimpleNamespace :param tiled_mma_pv: The tiled mma pv :type tiled_mma_pv: cute.TiledMma :param load_v_consumer_state: The load v consumer state :type load_v_consumer_state: pipeline.PipelineState :param p_mma_consumer_state: The P MMA consumer state :type p_mma_consumer_state: pipeline.PipelineState :param mma_o_producer_state: The MMA o producer state :type mma_o_producer_state: pipeline.PipelineState :return: The tiled mma pv, the load v consumer state, the P MMA consumer state, and the MMA o producer state :rtype: tuple[cute.TiledMma, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState] """ pv_params.p_mma_pipeline.consumer_wait(p_mma_consumer_state) load_v_pipeline = common_params.load_v_pipeline accumulate_flag = tiled_mma_pv.get(tcgen05.Field.ACCUMULATE) mma_o_pipeline = pv_params.mma_o_pipeline load_v_pipeline.consumer_wait(load_v_consumer_state) vc_stage = load_v_consumer_state.index for acc_stage in range(self.iterations_pv_n): mma_o_pipeline.producer_acquire(mma_o_producer_state) tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, accumulate_flag) for p_stage in range(self.iterations_pv_k): tOtO = pv_params.tOtO_staged[None, None, None, acc_stage] for k_block in cutlass.range_constexpr(pv_params.tOrP.shape[2]): cute.gemm( tiled_mma_pv, tOtO, pv_params.tOrP[ None, None, k_block, (p_stage, p_mma_consumer_state.index), ], pv_params.tOrVC[ None, None, k_block, ((acc_stage, p_stage), vc_stage) ], tOtO, ) tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, True) mma_o_pipeline.producer_commit(mma_o_producer_state) mma_o_producer_state.advance() load_v_pipeline.consumer_release(load_v_consumer_state) load_v_consumer_state.advance() pv_params.p_mma_pipeline.consumer_release(p_mma_consumer_state) p_mma_consumer_state.advance() return ( tiled_mma_pv, load_v_consumer_state, p_mma_consumer_state, mma_o_producer_state, ) @cute.jit def 
compute( self, common_params: SimpleNamespace, softmax_params: SimpleNamespace, k_index: cutlass.Int32, k_tile_count: cutlass.Int32, mma_s_consumer_state: pipeline.PipelineState, p_mma_producer_state: pipeline.PipelineState, p_cor_producer_state: pipeline.PipelineState, ) -> tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState]: """Compute warp to compute the result of softmax, rescale, and epilogue. Updates the related pipeline states. :param common_params: The common parameters :type common_params: SimpleNamespace :param softmax_params: The softmax parameters :type softmax_params: SimpleNamespace :param k_index: The index of the k-tile :type k_index: cutlass.Int32 :param k_tile_count: The number of k-tiles :type k_tile_count: cutlass.Int32 :param mma_s_consumer_state: The MMA s consumer state :type mma_s_consumer_state: pipeline.PipelineState :param p_mma_producer_state: The P MMA producer state :type p_mma_producer_state: pipeline.PipelineState :param p_cor_producer_state: The P correction producer state :type p_cor_producer_state: pipeline.PipelineState :return: The MMA s consumer state, the P MMA producer state, and the P correction producer state :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState] """ k_tile_total = cute.ceil_div(common_params.K, self.mma_qk_tiler[1]) row_max = -self.acc_dtype.inf row_sum = self.acc_dtype(0) correction_factor = self.acc_dtype(1) common_params.p_cor_pipeline.producer_acquire(p_cor_producer_state) # no mask applied while k_tile_count > 1: ( mma_s_consumer_state, p_mma_producer_state, p_cor_producer_state, row_max, row_sum, correction_factor, ) = self.softmax( common_params, softmax_params, k_index, mma_s_consumer_state, p_mma_producer_state, p_cor_producer_state, row_max, row_sum, correction_factor, False, False, ) k_index = k_index + 1 k_tile_count = k_tile_count - 1 # mask applied if cutlass.const_expr(common_params.mAccO is not None): ( mma_s_consumer_state, 
p_mma_producer_state, p_cor_producer_state, row_max, row_sum, correction_factor, ) = self.softmax( common_params, softmax_params, k_index, mma_s_consumer_state, p_mma_producer_state, p_cor_producer_state, row_max, row_sum, correction_factor, k_index == k_tile_total - 1, True, ) else: ( mma_s_consumer_state, p_mma_producer_state, p_cor_producer_state, row_max, row_sum, correction_factor, ) = self.softmax( common_params, softmax_params, k_index, mma_s_consumer_state, p_mma_producer_state, p_cor_producer_state, row_max, row_sum, correction_factor, True, True, ) return mma_s_consumer_state, p_mma_producer_state, p_cor_producer_state @cute.jit def correction( self, common_params: SimpleNamespace, epilogue_params: SimpleNamespace, k_tile_count: cutlass.Int32, p_cor_consumer_state: pipeline.PipelineState, mma_o_consumer_state: pipeline.PipelineState, ) -> tuple[pipeline.PipelineState, pipeline.PipelineState]: """Compute warp to compute the result of softmax, rescale, and epilogue. Updates the related pipeline states. 
:param common_params: The common parameters :type common_params: SimpleNamespace :param epilogue_params: The epilogue parameters :type epilogue_params: SimpleNamespace :param k_index: The index of the k-tile :type k_index: cutlass.Int32 :param k_tile_count: The number of k-tiles :type k_tile_count: cutlass.Int32 :param p_cor_consumer_state: The P correction consumer state :type p_cor_consumer_state: pipeline.PipelineState :param mma_o_consumer_state: The MMA o consumer state :type mma_o_consumer_state: pipeline.PipelineState :return: The P correction consumer state, and the MMA o consumer state :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState] """ k_tile_count_init = k_tile_count while k_tile_count > 0: p_cor_consumer_state, row_sum, row_max, correction_factor, no_correction = ( self.get_correction_factor(common_params, p_cor_consumer_state) ) if k_tile_count_init != k_tile_count: mma_o_consumer_state = self.rescale( common_params, mma_o_consumer_state, correction_factor, no_correction, ) k_tile_count = k_tile_count - 1 if k_tile_count == 0: mma_o_consumer_state = self.epilogue( common_params, epilogue_params, mma_o_consumer_state, row_sum, row_max, ) return p_cor_consumer_state, mma_o_consumer_state @cute.jit def exchange_p_cor_metadata( self, common_params: SimpleNamespace, softmax_params: SimpleNamespace, correction_factor: cutlass.Float32, row_sum: cutlass.Float32, row_max: cutlass.Float32, row_max_new: cutlass.Float32, tAcc: cute.Tensor, tidx: cutlass.Int32, p_cor_producer_state: pipeline.PipelineState, ) -> tuple[pipeline.PipelineState, cutlass.Float32]: """Compute the correction factor for the last k tile.""" no_correction = 0 if ( row_max_new - row_max ) * softmax_params.softmax_scale_log2 <= self.skip_correction_threshold: no_correction = 1 row_max_new = row_max # pad for 4x32b corr_layout = cute.make_layout( (tAcc.shape[0], (4, tAcc.shape[1][1]), self.mma_s_stage), stride=(tAcc.stride[0], (1, tAcc.stride[1][1]), 4), ) tCor = cute.make_tensor( 
common_params.tmem_ptr + self.correction_factor_offset, corr_layout, ) cCor = cute.make_identity_tensor(tCor.shape) corr_tmem_store_atom = cute.make_copy_atom( tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(4)), self.acc_dtype ) corr_tmem_store_tiled_copy = tcgen05.make_tmem_copy(corr_tmem_store_atom, tCor) corr_tmem_store_thr_copy = corr_tmem_store_tiled_copy.get_slice(tidx) cCor_for_copy = corr_tmem_store_thr_copy.partition_S(cCor) tCor_for_copy = corr_tmem_store_thr_copy.partition_D(tCor) rCor = cute.make_fragment_like( cCor_for_copy[None, None, None, 0], self.acc_dtype ) rCor_int = cute.make_tensor( cute.recast_ptr(rCor.iterator, dtype=cutlass.Int32), rCor.layout ) rCor[0] = row_sum rCor[1] = row_max_new rCor[2] = correction_factor rCor_int[3] = no_correction cute.copy( corr_tmem_store_tiled_copy, rCor, tCor_for_copy[None, None, None, p_cor_producer_state.index], ) # fence between tmem store and correction warp cute.arch.fence_view_async_tmem_store() common_params.p_cor_pipeline.producer_commit(p_cor_producer_state) p_cor_producer_state.advance() return p_cor_producer_state, row_max_new @cute.jit def softmax( self, common_params: SimpleNamespace, softmax_params: SimpleNamespace, k_index: cutlass.Int32, mma_s_consumer_state: pipeline.PipelineState, p_mma_producer_state: pipeline.PipelineState, p_cor_producer_state: pipeline.PipelineState, row_max: cutlass.Float32, row_sum: cutlass.Float32, correction_factor: cutlass.Float32, is_last_tile: bool, is_local_last_tile: cutlass.Boolean, ) -> tuple[ pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, cutlass.Float32, cutlass.Float32, cutlass.Float32, ]: """Softmax for one k-tile. Updates the related pipeline states and returns the computed results. 
:param common_params: The common parameters :type common_params: SimpleNamespace :param softmax_params: The softmax parameters :type softmax_params: SimpleNamespace :param k_index: The index of the k-tile :type k_index: cutlass.Int32 :param mma_s_consumer_state: The MMA s consumer state :type mma_s_consumer_state: pipeline.PipelineState :param p_mma_producer_state: The P MMA producer state :type p_mma_producer_state: pipeline.PipelineState :param p_cor_producer_state: The P correction producer state :type p_cor_producer_state: pipeline.PipelineState :param row_max: The row max :type row_max: cutlass.Float32 :param row_sum: The row sum :type row_sum: cutlass.Float32 :param correction_factor: The correction factor :type correction_factor: cutlass.Float32 :param is_last_tile: Whether the last tile :type is_last_tile: bool :param is_local_last_tile: Whether the last tile is local :type is_local_last_tile: cutlass.Boolean :return: The MMA s consumer state, the P MMA producer state, the P correction producer state, the row max, the row sum, and the correction factor :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, cutlass.Float32, cutlass.Float32, cutlass.Float32] """ softmax_params.p_mma_pipeline.producer_acquire(p_mma_producer_state) softmax_params.mma_s_pipeline.consumer_wait(mma_s_consumer_state) # load S from tmem tStS_shape = softmax_params.tiled_mma_qk.partition_shape_C( cute.select(self.mma_qk_tiler, mode=[0, 1]) ) tStS_staged_fake = softmax_params.tiled_mma_qk.make_fragment_C( cute.append(tStS_shape, self.mma_s_stage) ) tStS_staged = cute.make_tensor(common_params.tmem_ptr, tStS_staged_fake.layout) tStS = tStS_staged[None, None, None, mma_s_consumer_state.index] tAcc = tStS[(None, None), 0, 0] cta_qk_tiler = ( self.mma_qk_tiler[0] // self.cluster_shape_mnk[0], self.mma_qk_tiler[1], self.mma_qk_tiler[2], ) cS = cute.make_identity_tensor(cute.select(cta_qk_tiler, mode=[0, 1])) tmem_load_atom = cute.make_copy_atom( 
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.acc_dtype ) tmem_tiled_copy = tcgen05.make_tmem_copy(tmem_load_atom, tAcc) tidx = common_params.tidx % (self.num_compute_warps * self.threads_per_warp) tmem_thr_copy = tmem_tiled_copy.get_slice(tidx) tTR_tAcc = tmem_thr_copy.partition_S(tAcc) tTR_tS = tmem_thr_copy.partition_D(cS) tTR_rAcc = cute.make_fragment_like(tTR_tS, self.acc_dtype) row_max_new = row_max arch = BaseDSL._get_dsl().get_arch_enum() if cutlass.const_expr(arch >= Arch.sm_100 and arch <= Arch.sm_100f): cute.copy(tmem_tiled_copy, tTR_tAcc, tTR_rAcc) for i in cutlass.range_constexpr(cute.size(tTR_rAcc)): if is_last_tile: tTR_rAcc[i] = ( tTR_rAcc[i] if cute.elem_less( tTR_tS[i][1] + self.mma_qk_tiler[1] * k_index, common_params.K, ) else -self.acc_dtype.inf ) # reduction for row_max row_max_new = tTR_rAcc.load().reduce(cute.ReductionOp.MAX, row_max_new, 0) elif cutlass.const_expr(arch >= Arch.sm_103 and arch <= Arch.sm_103f): tmem_load_red_atom = cute.make_copy_atom( tcgen05.copy.LdRed32x32bOp( tcgen05.copy.Repetition(64), redOp=tcgen05.TmemLoadRedOp.MAX ), self.acc_dtype, ) tmem_red_tiled_copy = tcgen05.make_tmem_copy(tmem_load_red_atom, tAcc) tmem_red_thr_copy = tmem_red_tiled_copy.get_slice(tidx) tTR_tAcc_red = tmem_red_thr_copy.partition_S(tAcc) tTR_tS_red = tmem_red_thr_copy.partition_D(cS) tTR_rAcc_red = cute.make_fragment_like(tTR_tS_red, self.acc_dtype) tTR_rMax = cute.make_rmem_tensor( cute.make_layout((1, tTR_tS_red.shape[1], tTR_tS_red.shape[2])), self.acc_dtype, ) cute.copy( tmem_red_tiled_copy, tTR_tAcc_red, (tTR_rAcc_red, tTR_rMax), ) tTR_rAcc = cute.make_tensor(tTR_rAcc_red.iterator, tTR_rAcc.layout) if is_last_tile: for i in cutlass.range_constexpr(cute.size(tTR_rAcc)): tTR_rAcc[i] = ( tTR_rAcc[i] if cute.elem_less( tTR_tS[i][1] + self.mma_qk_tiler[1] * k_index, common_params.K, ) else -self.acc_dtype.inf ) # reduction for row_max row_max_new = tTR_rAcc.load().reduce( cute.ReductionOp.MAX, row_max_new, 0 ) else: row_max_new = 
cute.arch.fmax(row_max_new, tTR_rMax[0]) # if warps in N is 2, reduce row_max across warps (0, 1) and (2, 3) if cutlass.const_expr(self.warps_in_n == 2): common_params.smem_exchange[tidx] = row_max_new self.softmax_exchange_sync_bar.wait() row_max_new = cute.arch.fmax( row_max_new, common_params.smem_exchange[ (tidx + 64) % (self.num_compute_warps * self.threads_per_warp) ], ) # find correction factor correction_factor = cute.math.exp2( (row_max - row_max_new) * softmax_params.softmax_scale_log2, fastmath=True ) # split kv case if cutlass.const_expr(not is_local_last_tile): p_cor_producer_state, row_max_new = self.exchange_p_cor_metadata( common_params, softmax_params, correction_factor, row_sum, row_max, row_max_new, tAcc, tidx, p_cor_producer_state, ) # softmax fma_b = softmax_params.softmax_scale_log2 fma_c = (0.0 - row_max_new) * softmax_params.softmax_scale_log2 for i in cutlass.range(cute.size(tTR_rAcc), vectorize=True, unroll_full=True): tTR_rAcc[i] = tTR_rAcc[i] * fma_b + fma_c tTR_rAcc[i] = cute.math.exp2(tTR_rAcc[i], fastmath=True) tTR_rS = cute.make_fragment_like(tTR_tS, self.q_dtype) # quantize tTR_rS.store(tTR_rAcc.load().to(self.q_dtype)) # create sP sP = softmax_params.sP[None, None, None, (None, p_mma_producer_state.index)] sP_mk_view = cute.make_tensor( sP.iterator, cute.make_layout( ( (sP.shape[0][0], sP.shape[1]), (sP.shape[0][1], sP.shape[2], sP.shape[3]), ), stride=( (sP.stride[0][0], sP.stride[1]), (sP.stride[0][1], sP.stride[2], sP.stride[3]), ), ), ) # change to PISL sP_wo_swizzle_iter = cute.recast_ptr(sP.iterator, swizzle_=None) swizzle_bits = ( int(math.log2(self.mma_pv_tiler[2] * self.q_dtype.width // 8 // 32)) + 1 ) swizzle_base = 3 if self.q_dtype.width == 16 else 4 sP_swizzle = cute.make_swizzle(swizzle_bits, swizzle_base, 3) sP_mk_view = cute.make_tensor( sP_wo_swizzle_iter, cute.make_composed_layout(sP_swizzle, 0, sP_mk_view.layout), ) universal_copy_bits = 128 smem_copy_atom = cute.make_copy_atom( cute.nvgpu.CopyUniversalOp(), 
self.q_dtype, num_bits_per_copy=universal_copy_bits, ) smem_tiled_copy = cute.make_tiled_copy_D(smem_copy_atom, tmem_tiled_copy) smem_thr_copy = smem_tiled_copy.get_slice(tidx) rP_copy_view = smem_thr_copy.retile(tTR_rS) sP_copy_view = smem_thr_copy.partition_D(sP_mk_view) cute.copy(smem_tiled_copy, rP_copy_view, sP_copy_view) # fence between smem store and mma o cute.arch.fence_view_async_shared() softmax_params.p_mma_pipeline.producer_commit(p_mma_producer_state) p_mma_producer_state.advance() # row_sum, using `add_packed_f32x2` to reduce the number of instructions row_sum = row_sum * correction_factor row_sum_vec = (0.0, 0.0) for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc), 2): row_sum_vec = cute.arch.add_packed_f32x2( row_sum_vec, (tTR_rAcc[i], tTR_rAcc[i + 1]) ) row_sum = row_sum_vec[0] + row_sum_vec[1] + row_sum # split kv case if cutlass.const_expr(is_local_last_tile): p_cor_producer_state, row_max_new = self.exchange_p_cor_metadata( common_params, softmax_params, correction_factor, row_sum, row_max, row_max_new, tAcc, tidx, p_cor_producer_state, ) # store correction factor/row_sum/row_max to tmem for correction warp common_params.p_cor_pipeline.producer_acquire(p_cor_producer_state) # fence between tmem load and mma s cute.arch.fence_view_async_tmem_load() softmax_params.mma_s_pipeline.consumer_release(mma_s_consumer_state) mma_s_consumer_state.advance() return ( mma_s_consumer_state, p_mma_producer_state, p_cor_producer_state, row_max_new, row_sum, correction_factor, ) @cute.jit def _tmem_load_partition( self, common_params: SimpleNamespace, tiled_mma_pv: cute.TiledMma, iter_n: int ) -> tuple[ cute.TiledMma, cute.TiledMma, cute.TiledMma, cute.TiledMma, cute.TiledMma ]: """Tensor memory load partition for rescale and epilogue. 
:param common_params: The common parameters :type common_params: SimpleNamespace :param tiled_mma_pv: The tiled mma pv :type tiled_mma_pv: cute.TiledMma :param iter_n: The iteration number :type iter_n: int :return: The tiled mma pv, the tiled mma pv, the tiled mma pv, the tiled mma pv, the tiled mma pv :rtype: tuple[cute.TiledMma, cute.TiledMma, cute.TiledMma, cute.TiledMma, cute.TiledMma] """ tOtO_shape = tiled_mma_pv.partition_shape_C( cute.select(self.mma_pv_tiler, mode=[0, 1]) ) tOtO = tiled_mma_pv.make_fragment_C(tOtO_shape) tOtO_layout = cute.append( tOtO.layout, cute.make_layout( common_params.L // self.mma_pv_tiler[1], stride=self.mma_pv_tiler[1] // self.warps_in_n, ), ) tOtO = cute.make_tensor( common_params.tmem_ptr + self.tmem_o_offset, tOtO_layout ) tOtO = tOtO[None, None, None, iter_n] tAcc = tOtO[(None, None), 0, 0] tmem_load_atom = cute.make_copy_atom( tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.acc_dtype ) tmem_load_tiled_copy = tcgen05.make_tmem_copy(tmem_load_atom, tAcc) tmem_load_thr_copy = tmem_load_tiled_copy.get_slice( common_params.tidx % (self.num_compute_warps * self.threads_per_warp) ) cta_pv_tiler = ( self.mma_pv_tiler[0] // self.cluster_shape_mnk[0], self.mma_pv_tiler[1], self.mma_pv_tiler[2], ) # Flatten divide and partition global tensors for O cta_pv_tiler_mn = cute.select(cta_pv_tiler, mode=[0, 1]) gO = None if cutlass.const_expr(common_params.mAccO is not None): gO = cute.local_tile( common_params.mAccO[None, common_params.blk_coord[3], None, None, None], cta_pv_tiler_mn, ( common_params.blk_coord[0], iter_n, common_params.blk_coord[1], common_params.blk_coord[2], ), ) cO = cute.local_tile( cute.make_identity_tensor( common_params.mAccO[ None, common_params.blk_coord[3], None, None, None ].shape ), cta_pv_tiler_mn, ( common_params.blk_coord[0], iter_n, common_params.blk_coord[1], common_params.blk_coord[2], ), ) else: gO = cute.local_tile( common_params.mO, cta_pv_tiler_mn, ( common_params.blk_coord[0], iter_n, 
common_params.blk_coord[1], common_params.blk_coord[2], ), ) cO = cute.local_tile( cute.make_identity_tensor(common_params.mO.shape), cta_pv_tiler_mn, ( common_params.blk_coord[0], iter_n, common_params.blk_coord[1], common_params.blk_coord[2], ), ) tTR_tAcc = tmem_load_thr_copy.partition_S(tAcc) tTR_gO = tmem_load_thr_copy.partition_D(gO) tTR_cO = tmem_load_thr_copy.partition_D(cO) tTR_rAcc = cute.make_fragment_like(tTR_gO, self.acc_dtype) return tmem_load_tiled_copy, tAcc, tTR_tAcc, tTR_gO, tTR_cO, tTR_rAcc def get_correction_factor( self, common_params: SimpleNamespace, p_cor_consumer_state: pipeline.PipelineState, ) -> tuple[ pipeline.PipelineState, cutlass.Float32, cutlass.Float32, cutlass.Float32, cutlass.Int32, ]: """Get the correction factor from the P correction consumer state. :param common_params: The common parameters :type common_params: SimpleNamespace :param p_cor_consumer_state: The P correction consumer state :type p_cor_consumer_state: pipeline.PipelineState :return: The P correction consumer state, the row_sum, the row_max, and the correction factor :rtype: tuple[pipeline.PipelineState, cutlass.Float32, cutlass.Float32, cutlass.Float32, cutlass.Int32] """ common_params.p_cor_pipeline.consumer_wait(p_cor_consumer_state) tidx = common_params.tidx % (self.num_compute_warps * self.threads_per_warp) # load correction factor _, tAcc, _, _, _, _ = self._tmem_load_partition( common_params, common_params.tiled_mma_pv, 0 ) corr_layout = cute.make_layout( (tAcc.shape[0], (4, tAcc.shape[1][1]), self.p_cor_stage), stride=(tAcc.stride[0], (1, tAcc.stride[1][1]), 4), ) tCor = cute.make_tensor( common_params.tmem_ptr + self.correction_factor_offset, corr_layout ) cCor = cute.make_identity_tensor(tCor.shape) corr_tmem_load_atom = cute.make_copy_atom( tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(4)), self.acc_dtype ) corr_tmem_load_tiled_copy = tcgen05.make_tmem_copy(corr_tmem_load_atom, tCor) corr_tmem_load_thr_copy = corr_tmem_load_tiled_copy.get_slice(tidx) 
tCor_for_copy = corr_tmem_load_thr_copy.partition_S(tCor) cCor_for_copy = corr_tmem_load_thr_copy.partition_D(cCor) rCor = cute.make_fragment_like( cCor_for_copy[None, None, None, 0], self.acc_dtype ) rCor_int = cute.make_tensor( cute.recast_ptr(rCor.iterator, dtype=cutlass.Int32), rCor.layout ) cute.copy( corr_tmem_load_tiled_copy, tCor_for_copy[None, None, None, p_cor_consumer_state.index], rCor, ) row_sum = rCor[0] row_max = rCor[1] correction_factor = rCor[2] no_correction = rCor_int[3] common_params.p_cor_pipeline.consumer_release(p_cor_consumer_state) p_cor_consumer_state.advance() return p_cor_consumer_state, row_sum, row_max, correction_factor, no_correction @cute.jit def rescale( self, common_params: SimpleNamespace, mma_o_consumer_state: pipeline.PipelineState, correction_factor: cutlass.Float32, no_correction: cutlass.Int32, ) -> pipeline.PipelineState: """Rescale for one k-tile. Updates the related pipeline state. :param common_params: The common parameters :type common_params: SimpleNamespace :param mma_o_consumer_state: The mma o consumer state :type mma_o_consumer_state: pipeline.PipelineState :param correction_factor: The correction factor :type correction_factor: cutlass.Float32 :param no_correction: Whether to apply correction factor :type no_correction: cutlass.Int32 :return: The MMA o consumer state :rtype: pipeline.PipelineState """ skip_correction = cute.arch.vote_all_sync(no_correction == 1) for iter_n in cutlass.range_constexpr(self.iterations_pv_n): common_params.mma_o_pipeline.consumer_wait(mma_o_consumer_state) if not skip_correction: # tmem load tiled copy and partition results. 
tmem_load_tiled_copy, tAcc, tTR_tAcc, tTR_gO, tTR_cO, tTR_rAcc = ( self._tmem_load_partition( common_params, common_params.tiled_mma_pv, iter_n ) ) # tmem store tiled copy tmem_store_atom = cute.make_copy_atom( tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.acc_dtype ) tmem_store_tiled_copy = tcgen05.make_tmem_copy(tmem_store_atom, tAcc) # load o cute.copy(tmem_load_tiled_copy, tTR_tAcc, tTR_rAcc) # rescale, using `mul_packed_f32x2` to reduce the number of instructions for i in cutlass.range( cute.size(tTR_rAcc), vectorize=True, unroll_full=True ): tTR_rAcc[i] = tTR_rAcc[i] * correction_factor # store o to tensor memory for next k tile cute.copy(tmem_store_tiled_copy, tTR_rAcc, tTR_tAcc) cute.arch.fence_view_async_tmem_store() common_params.mma_o_pipeline.consumer_release(mma_o_consumer_state) mma_o_consumer_state.advance() return mma_o_consumer_state @cute.jit def epilogue( self, common_params: SimpleNamespace, epilogue_params: SimpleNamespace, mma_o_consumer_state: pipeline.PipelineState, row_sum: cutlass.Float32, row_max: cutlass.Float32, ) -> pipeline.PipelineState: """Epilogue for one k-tile. Updates the related pipeline state. 
:param common_params: The common parameters :type common_params: SimpleNamespace :param epilogue_params: The epilogue parameters :type epilogue_params: SimpleNamespace :param mma_o_consumer_state: The mma o consumer state :type mma_o_consumer_state: pipeline.PipelineState :param row_sum: The row sum :type row_sum: cutlass.Float32 :param row_max: The row max :type row_max: cutlass.Float32 :return: The MMA o consumer state :rtype: pipeline.PipelineState """ tidx = common_params.tidx % (self.num_compute_warps * self.threads_per_warp) # exchange row_sum between warps (0, 1) and (2, 3) if cutlass.const_expr(self.warps_in_n == 2): common_params.smem_exchange[tidx] = row_sum self.epilogue_exchange_sync_bar.wait() # (64, 2) row_sum = ( row_sum + common_params.smem_exchange[ (tidx + 64) % (self.num_compute_warps * self.threads_per_warp) ] ) # mma_o pipeline consumer wait for iter_n in cutlass.range_constexpr(self.iterations_pv_n): common_params.mma_o_pipeline.consumer_wait(mma_o_consumer_state) # tmem load tiled copy and partition results. 
tmem_load_tiled_copy, tAcc, tTR_tAcc, tTR_gO, tTR_cO, tTR_rAcc = ( self._tmem_load_partition( common_params, common_params.tiled_mma_pv, iter_n ) ) # load o cute.copy(tmem_load_tiled_copy, tTR_tAcc, tTR_rAcc) # apply output scale and normalize by row_sum for i in cutlass.range( cute.size(tTR_rAcc), vectorize=True, unroll_full=True ): tTR_rAcc[i] = ( tTR_rAcc[i] * epilogue_params.output_scale * cute.arch.rcp_approx(row_sum) ) # store o to global memory tR2G_rO_src = None tR2G_rO_dst = tTR_gO if cutlass.const_expr(common_params.mAccO is None): tR2G_rO_src = cute.make_fragment_like(tTR_gO, self.o_dtype) # using final output dtype for o tR2G_rO_src.store(tTR_rAcc.load().to(self.o_dtype)) else: # using accumulate dtype for o tR2G_rO_src = tTR_rAcc if cute.elem_less(tTR_cO[0][0], common_params.H): cute.autovec_copy( tR2G_rO_src, tR2G_rO_dst, l1c_evict_priority=cute.nvgpu.CacheEvictionPriority.NO_ALLOCATE, ) # store the lse to global memory cta_pv_tiler = ( self.mma_pv_tiler[0] // self.cluster_shape_mnk[0], self.mma_pv_tiler[1], self.mma_pv_tiler[2], ) gLSE = None cLSE = None if cutlass.const_expr(epilogue_params.mAccLSE is None): gLSE = cute.local_tile( epilogue_params.mLSE, (cta_pv_tiler[0], 1, 1), ( common_params.blk_coord[0], common_params.blk_coord[1], common_params.blk_coord[2], ), (1, 1, 1), ) cLSE = cute.local_tile( cute.make_identity_tensor(epilogue_params.mLSE.shape), (cta_pv_tiler[0], 1, 1), ( common_params.blk_coord[0], common_params.blk_coord[1], common_params.blk_coord[2], ), (1, 1, 1), ) else: gLSE = cute.local_tile( epilogue_params.mAccLSE[ None, common_params.blk_coord[3], None, None ], (cta_pv_tiler[0], 1, 1), ( common_params.blk_coord[0], common_params.blk_coord[1], common_params.blk_coord[2], ), (1, 1, 1), ) cLSE = cute.local_tile( cute.make_identity_tensor( epilogue_params.mAccLSE[ None, common_params.blk_coord[3], None, None ].shape ), (cta_pv_tiler[0], 1, 1), ( common_params.blk_coord[0], common_params.blk_coord[1], common_params.blk_coord[2], ), 
(1, 1, 1), ) lse = ( cute.math.log2(row_sum, fastmath=True) + epilogue_params.softmax_scale_log2 * row_max ) if cutlass.const_expr(self.warps_in_n == 2): if cute.elem_less(cLSE[tidx][0], common_params.H): gLSE[tidx] = lse cute.arch.fence_view_async_tmem_load() common_params.mma_o_pipeline.consumer_release(mma_o_consumer_state) mma_o_consumer_state.advance() return mma_o_consumer_state def make_and_init_load_qkv_pipeline( self, load_qkv_mbar_ptr, cta_layout_vmnk, load_stages, tx_count ) -> pipeline.PipelineTmaUmma: """Create and initialize the tma load qkv pipeline. :param load_qkv_mbar_ptr: The load qkv mbar pointer :type load_qkv_mbar_ptr: cute.Tensor :param cta_layout_vmnk: The cta layout vmnk :type cta_layout_vmnk: tuple[int, int, int] :param load_stages: The load stages :type load_stages: list[int] :param tx_count: The tx count :type tx_count: int :return: The tma load qkv pipeline :rtype: pipeline.PipelineTmaUmma """ load_qkv_producer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, len([self.load_tma_k_warp_id]) ) load_qkv_consumer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, len([self.mma_warp_id]) ) return pipeline.PipelineTmaUmma.create( barrier_storage=load_qkv_mbar_ptr, num_stages=load_stages, producer_group=load_qkv_producer_group, consumer_group=load_qkv_consumer_group, tx_count=tx_count, cta_layout_vmnk=cta_layout_vmnk, defer_sync=True, ) def make_and_init_mma_s_pipeline( self, mma_s_mbar_ptr, cta_layout_vmnk ) -> pipeline.PipelineUmmaAsync: """Create and initialize the mma s pipeline. 
:param mma_s_mbar_ptr: The mma s mbar pointer :type mma_s_mbar_ptr: cute.Tensor :param cta_layout_vmnk: The cta layout vmnk :type cta_layout_vmnk: tuple[int, int, int] :return: The mma s pipeline :rtype: pipeline.PipelineUmmaAsync """ mma_s_producer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, len([self.mma_warp_id]) ) consumer_thread_size = ( self.threads_per_warp * len(self.compute_warp_ids) * self.cluster_shape_mnk[0] ) mma_s_consumer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, consumer_thread_size, ) return pipeline.PipelineUmmaAsync.create( barrier_storage=mma_s_mbar_ptr, num_stages=self.mma_s_stage, producer_group=mma_s_producer_group, consumer_group=mma_s_consumer_group, cta_layout_vmnk=cta_layout_vmnk, defer_sync=True, ) def make_and_init_p_mma_pipeline( self, p_mma_mbar_ptr, cta_layout_vmnk ) -> pipeline.PipelineAsyncUmma: """Create and initialize the p mma pipeline. :param p_mma_mbar_ptr: The p mma mbar pointer :type p_mma_mbar_ptr: cute.Tensor :param cta_layout_vmnk: The cta layout vmnk :type cta_layout_vmnk: tuple[int, int, int] :return: The p mma pipeline :rtype: pipeline.PipelineAsyncUmma """ producer_thread_size = ( self.threads_per_warp * len(self.compute_warp_ids) * self.cluster_shape_mnk[0] ) p_mma_producer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, producer_thread_size, ) p_mma_consumer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, len([self.mma_warp_id]) ) return pipeline.PipelineAsyncUmma.create( barrier_storage=p_mma_mbar_ptr, num_stages=self.p_mma_stage, producer_group=p_mma_producer_group, consumer_group=p_mma_consumer_group, cta_layout_vmnk=cta_layout_vmnk, defer_sync=True, ) def make_and_init_p_cor_pipeline( self, p_cor_mbar_ptr ) -> pipeline.PipelineAsyncUmma: """Create and initialize the p correction pipeline. 
:param p_cor_mbar_ptr: The p correction mbar pointer :type p_cor_mbar_ptr: cute.Tensor :return: The p correction pipeline :rtype: pipeline.PipelineAsyncUmma """ producer_thread_size = self.threads_per_warp * len(self.compute_warp_ids) p_cor_producer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, producer_thread_size, ) p_cor_consumer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, producer_thread_size, ) return pipeline.PipelineAsync.create( barrier_storage=p_cor_mbar_ptr, num_stages=self.p_cor_stage, producer_group=p_cor_producer_group, consumer_group=p_cor_consumer_group, defer_sync=True, ) def make_and_init_mma_o_pipeline( self, mma_o_mbar_ptr, cta_layout_vmnk ) -> pipeline.PipelineUmmaAsync: """Create and initialize the mma o pipeline. :param mma_o_mbar_ptr: The mma o mbar pointer :type mma_o_mbar_ptr: cute.Tensor :param cta_layout_vmnk: The cta layout vmnk :type cta_layout_vmnk: tuple[int, int, int] :return: The mma o pipeline :rtype: pipeline.PipelineUmmaAsync """ mma_o_producer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, len([self.mma_warp_id]) ) consumer_thread_size = ( self.threads_per_warp * len(self.compute_warp_ids) * self.cluster_shape_mnk[0] ) mma_o_consumer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, consumer_thread_size, ) return pipeline.PipelineUmmaAsync.create( barrier_storage=mma_o_mbar_ptr, num_stages=self.mma_o_stage, producer_group=mma_o_producer_group, consumer_group=mma_o_consumer_group, cta_layout_vmnk=cta_layout_vmnk, defer_sync=True, ) @staticmethod def _compute_grid( o: cute.Tensor, split_kv: cutlass.Int32, cluster_shape_mnk: Tuple[int, int, int], max_active_clusters: int, is_persistent: bool, ) -> Tuple[MLAStaticTileSchedulerParams, Tuple[int, int, int]]: """Compute grid shape for the output tensor C. :param c: The output tensor C :type c: cute.Tensor :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile. 
:type cta_tile_shape_mnk: tuple[int, int, int] :param cluster_shape_mn: Shape of each cluster in M, N dimensions. :type cluster_shape_mn: tuple[int, int] :return: Tile scheduler parameters and grid shape. :rtype: tuple[MLAStaticTileSchedulerParams, tuple[int, int, int]] """ o_shape = o.shape tile_sched_params = create_mla_static_tile_scheduler_params( is_persistent, cute.size(o_shape[3]), cute.size(o_shape[2]), cluster_shape_mnk, split_kv, ) grid = MLAStaticTileScheduler.get_grid_shape( tile_sched_params, max_active_clusters ) return tile_sched_params, grid @staticmethod def get_workspace_size( H: int, S: int, D: int, B: int, split_kv: int, acc_dtype: Type[cutlass.Numeric], ) -> int: """Get the extra workspace(device memory) size for the MLA kernel when split_kv is not 1. :param H: The height of the output tensor C :type H: int :param S: The sequence length of the output tensor C :type S: int :param D: The depth of the output tensor C :type D: int :param B: The batch size of the output tensor C :type B: int :param split_kv: The split key-value of the output tensor C :type split_kv: int :param acc_dtype: The data type of the output tensor C :type acc_dtype: Type[cutlass.Numeric] :return: The workspace size for the MLA kernel :rtype: int """ if split_kv == 1: return 0 return B * H * S * split_kv * (D + 1) * acc_dtype.width // 8 @cute.jit def initialize_workspace( self, H: cutlass.Int32, D: cutlass.Int32, S: cutlass.Int32, B: cutlass.Int32, split_kv: cutlass.Int32, acc_dtype: Type[cutlass.Numeric], workspace: cute.Tensor, ) -> tuple[cute.Tensor, cute.Tensor]: """Initialize the workspace for the MLA kernel. Construct the intermediate tensors acc_o and acc_lse. 
:param H: The height of the output tensor C :type H: cutlass.Int32 :param D: The depth of the output tensor C :type D: cutlass.Int32 :param S: The sequence length of the output tensor C :type S: cutlass.Int32 :param B: The batch size of the output tensor C :type B: cutlass.Int32 :param split_kv: The split key-value of the output tensor C :type split_kv: cutlass.Int32 :param acc_dtype: The data type of the output tensor C :type acc_dtype: Type[cutlass.Numeric] :param workspace: The workspace tensor :type workspace: cute.Tensor :return: The output tensor C and the workspace tensor :rtype: tuple[cute.Tensor, cute.Tensor] """ acc_o, acc_lse = None, None if cutlass.const_expr(workspace is not None): align = 256 // self.q_dtype.width acc_o_layout = cute.make_layout( (H, split_kv, D, S, B), stride=( cute.assume(split_kv * D, align), cute.assume(D, align), 1, cute.assume(split_kv * H * D, align), cute.assume(H * split_kv * S * D, align), ), ) acc_o_iter = cute.recast_ptr(workspace.iterator, dtype=acc_dtype) acc_o = cute.make_tensor(acc_o_iter, acc_o_layout) acc_lse_layout = cute.make_layout( (H, split_kv, S, B), stride=(split_kv, 1, H * split_kv, H * split_kv * S), ) acc_lse_iter = cute.recast_ptr( workspace.iterator + cute.cosize(acc_o_layout) * acc_dtype.width // 8, dtype=acc_dtype, ) acc_lse = cute.make_tensor(acc_lse_iter, acc_lse_layout) return acc_o, acc_lse @staticmethod def can_implement( B: int, S: int, K: int, H: int, L: int, R: int, in_dtype: Type[cutlass.Numeric], out_dtype: Type[cutlass.Numeric], acc_dtype: Type[cutlass.Numeric], lse_dtype: Type[cutlass.Numeric], mma_qk_tiler_mn: Tuple[int, int], mma_pv_tiler_mn: Tuple[int, int], split_kv: int, is_persistent: bool, is_var_seq: bool, is_var_split_kv: bool, page_size: int, ) -> bool: """Check if the MLA kernel can be implemented. 
:param B: The batch size of the output tensor C :type B: int :param S: The sequence length of the output tensor C :type S: int :param K: The width of the output tensor KV :type K: int :param H: The number of heads of the output tensor C :type H: int :param L: The number of latent dimensions of the tensor KV :type L: int :param R: The number of rope dimensions of the tensor C_rope :type R: int :param in_dtype: The data type of the input tensor :type in_dtype: Type[cutlass.Numeric] :param out_dtype: The data type of the output tensor :type out_dtype: Type[cutlass.Numeric] :param acc_dtype: The data type of the accumulator :type acc_dtype: Type[cutlass.Numeric] :param lse_dtype: The data type of the log-sum-exp :type lse_dtype: Type[cutlass.Numeric] :param mma_qk_tiler_mn: The tile shape of the query-key matrix multiplication :type mma_qk_tiler_mn: Tuple[int, int] :param mma_pv_tiler_mn: The tile shape of the probability-value matrix multiplication :type mma_pv_tiler_mn: Tuple[int, int] :param split_kv: The split key-value of the output tensor C :type split_kv: int :param is_persistent: Whether to use persistent kernel optimization :type is_persistent: bool :param is_var_seq: Whether to use variable sequence length :type is_var_seq: bool :param is_var_split_kv: Whether to use variable split_kv :type is_var_split_kv: bool :param page_size: The page size of the page table :type page_size: int :return: Whether the MLA kernel can be implemented :rtype: bool """ if L != 512 or R != 64: return False if in_dtype not in [cutlass.Float8E4M3FN]: return False if out_dtype not in [cutlass.Float8E4M3FN]: return False if acc_dtype != cutlass.Float32 or lse_dtype != cutlass.Float32: return False # page size equals 1 is prohibited by tma specification, not 128B aligned. 
if mma_qk_tiler_mn[1] % page_size != 0 or page_size == 1: return False if mma_qk_tiler_mn[0] != mma_pv_tiler_mn[0] or mma_qk_tiler_mn[0] != 128: return False if is_var_split_kv and not is_var_seq: return False if H > 128 or (H < 128 and split_kv != 1): return False if S <= 0 or S > 4: return False if K <= 0: return False return True def run( batch_size: int, seq_len_q: int, seq_len_k: int, num_heads: int, latent_dim: int, rope_dim: int, in_dtype: Type[cutlass.Numeric], out_dtype: Type[cutlass.Numeric], acc_dtype: Type[cutlass.Numeric], lse_dtype: Type[cutlass.Numeric], mma_qk_tiler_mn: Tuple[int, int], mma_pv_tiler_mn: Tuple[int, int], split_kv: int, is_persistent: bool, is_var_seq: bool, is_var_split_kv: bool, page_size: int, softmax_scale: float, output_scale: float, skip_correction_threshold: float, tolerance: float, warmup_iterations: int, iterations: int, skip_ref_check: bool, use_cold_l2: bool, **kwargs, ): """Execute Multi-Head Latent Attention (MLA) on Blackwell architecture and validate results. This function creates random input tensors for query latent/rope, compressed latent/rope, and value, then performs the complete MLA computation pipeline. It supports configurable data types, tiling parameters, page table, variable sequence length, and variable split_kv. Results can be validated against a PyTorch reference implementation or run multiple times for performance measurement. 
:param batch_size: Batch size :type batch_size: int :param seq_len_q: Sequence length of Q :type seq_len_q: int :param seq_len_k: Sequence length of K :type seq_len_k: int :param num_heads: Number of heads :type num_heads: int :param latent_dim: dimension of query/compressed latent :type latent_dim: int :param rope_dim: dimension of query/compressed rope :type rope_dim: int :param in_dtype: Input data type for query/compressed latent/rope tensors :type in_dtype: Type[cutlass.Numeric] :param out_dtype: Output data type for attention output :type out_dtype: Type[cutlass.Numeric] :param acc_dtype: Accumulator data type for query-key matrix multiplication :type acc_dtype: Type[cutlass.Numeric] :param lse_dtype: Accumulator data type for log-sum-exp :type lse_dtype: Type[cutlass.Numeric] :param mma_qk_tiler_mn: Matrix multiply accumulate tile shape (M, N) for query-key matrix multiplication :type mma_qk_tiler_mn: Tuple[int, int] :param mma_pv_tiler_mn: Matrix multiply accumulate tile shape (M, N) for probability-value matrix multiplication :type mma_pv_tiler_mn: Tuple[int, int] :param split_kv: Split key-value :type split_kv: int :param is_persistent: Whether to use persistent kernel optimization :type is_persistent: bool :param is_var_seq: Whether to use variable sequence length :type is_var_seq: bool :param is_var_split_kv: Whether to use variable split_kv :type is_var_split_kv: bool :param page_size: Page size of the page table :type page_size: int :param softmax_scale: Attention score scaling factor :type softmax_scale: float :param output_scale: Output scaling factor :type output_scale: float :param skip_correction_threshold: Threshold to skip correction :type skip_correction_threshold: float :param tolerance: Maximum acceptable error for validation :type tolerance: float :param warmup_iterations: Number of warmup iterations :type warmup_iterations: int :param iterations: Number of iterations to run for performance testing :type iterations: int :param 
skip_ref_check: Skip validation against reference implementation :type skip_ref_check: bool :param use_cold_l2: Whether to use cold L2 cache :type use_cold_l2: bool :raises ValueError: If input shapes are incompatible or head dimension is unsupported :raises RuntimeError: If GPU is unavailable for computation """ print("Running Blackwell MLA test with:") print(f" batch_size: {batch_size}") print(f" seq_len_q: {seq_len_q}") print(f" seq_len_k: {seq_len_k}") print(f" num_heads: {num_heads}") print(f" latent_dim: {latent_dim}") print(f" rope_dim: {rope_dim}") print(f" in_dtype: {in_dtype}") print(f" out_dtype: {out_dtype}") print(f" acc_dtype: {acc_dtype}") print(f" mma_qk_tiler_mn: {mma_qk_tiler_mn}") print(f" mma_pv_tiler_mn: {mma_pv_tiler_mn}") print(f" split_kv: {split_kv}") print(f" is_persistent: {is_persistent}") print(f" is_var_seq: {is_var_seq}") print(f" is_var_split_kv: {is_var_split_kv}") print(f" page_size: {page_size}") print(f" softmax_scale: {softmax_scale}") print(f" output_scale: {output_scale}") print(f" skip_correction_threshold: {skip_correction_threshold}") print(f" tolerance: {tolerance}") print(f" warmup_iterations: {warmup_iterations}") print(f" iterations: {iterations}") print(f" skip_ref_check: {skip_ref_check}") print(f" use_cold_l2: {use_cold_l2}") import torch import cutlass.torch as cutlass_torch # Prepare pytorch tensors: Q, K, V (random from 0 to 2) and O (all zero) if not torch.cuda.is_available(): raise RuntimeError("GPU is required to run this example!") if not BlackwellMultiHeadLatentAttentionForwardFP8.can_implement( batch_size, seq_len_q, seq_len_k, num_heads, latent_dim, rope_dim, in_dtype, out_dtype, acc_dtype, lse_dtype, mma_qk_tiler_mn, mma_pv_tiler_mn, split_kv, is_persistent, is_var_seq, is_var_split_kv, page_size, ): raise TypeError( f"Unsupported testcase {batch_size}, {seq_len_q}, {seq_len_k}, {num_heads}, {latent_dim}, {rope_dim}, {in_dtype}, {out_dtype}, {acc_dtype}, {lse_dtype}, {mma_qk_tiler_mn}, {mma_pv_tiler_mn}, 
{split_kv}, {is_persistent}, {is_var_seq}, {is_var_split_kv}, {page_size}" ) torch.manual_seed(1111) def create_data_tensor( B, HK, D, dtype, is_dynamic_layout=True, page_table=None, cache_seqs=None, is_lse=False, seq_len_q=None, ): shape = (B, HK, D) if page_table is not None: if cache_seqs is not None: max_seq_len = torch.max(cache_seqs) shape = (B * ceil_div(max_seq_len, page_size), page_size, D) else: shape = (B * ceil_div(HK, page_size), page_size, D) if seq_len_q is not None: shape = (B, seq_len_q, HK, D) permute_order = (1, 2, 0) stride_order = (2, 0, 1) leading_dim = 1 if is_lse: shape = (B, seq_len_q, HK) permute_order = (2, 1, 0) stride_order = (2, 1, 0) leading_dim = 0 elif seq_len_q is not None: permute_order = (2, 3, 1, 0) stride_order = (3, 2, 0, 1) leading_dim = 1 init_config = cutlass.torch.RandomInitConfig(min_val=-2, max_val=2) torch_dtype = ( cutlass_torch.dtype(dtype) if dtype != cutlass.Float8E4M3FN else torch.int8 ) # Create dtype torch tensor (cpu) torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor( shape, torch_dtype, permute_order=permute_order, init_type=cutlass.torch.TensorInitType.RANDOM, init_config=init_config, ) # Create dtype torch tensor (gpu) torch_tensor_gpu = torch_tensor_cpu.cuda() # Create f32 torch tensor (cpu) f32_torch_tensor = torch_tensor_cpu.to(dtype=torch.float32) # Create dtype cute tensor (gpu) cute_tensor = from_dlpack(torch_tensor_gpu, assumed_align=16) cute_tensor.element_type = dtype if is_dynamic_layout: cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=leading_dim) if not is_lse: cute_tensor = cute_tensor.mark_compact_shape_dynamic( mode=leading_dim, stride_order=stride_order, divisibility=(128 // dtype.width), ) cute_tensor = cutlass_torch.convert_cute_tensor( f32_torch_tensor, cute_tensor, dtype, is_dynamic_layout=is_dynamic_layout, ) return f32_torch_tensor, cute_tensor, torch_tensor_gpu def create_cache_seqs(batch_size, seq_len_k, is_var_seq): cache_seqs_ref = torch.ones(batch_size, 
dtype=torch.int32) * seq_len_k cache_seqs_gpu = cache_seqs_ref.cuda() cache_seqs = from_dlpack(cache_seqs_gpu, assumed_align=16).mark_layout_dynamic() if is_var_seq: max_seq_len = seq_len_k min_seq_len = int(seq_len_k * 0.8) cache_seqs_ref = cutlass_torch.create_and_permute_torch_tensor( (batch_size,), torch.int32, init_type=cutlass.torch.TensorInitType.RANDOM, init_config=cutlass.torch.RandomInitConfig( min_val=min_seq_len, max_val=max_seq_len + 1 ), ) cache_seqs_gpu = cache_seqs_ref.cuda() cache_seqs = from_dlpack( cache_seqs_gpu, assumed_align=16, ).mark_layout_dynamic() return cache_seqs_ref, cache_seqs, cache_seqs_gpu def create_page_table(batch_size, seq_len_k, is_var_seq, page_size): max_seq_len = seq_len_k if not is_var_seq else torch.max(cache_seqs_ref) page_count = ceil_div(max_seq_len, page_size) page_table_ref = torch.empty([batch_size, page_count], dtype=torch.int32) # use transposed index for page table to make sure the value is in bound of `batch_size * seq_len_block`. In practice, the value could be any positive values. This setting is only for testing purpose. 
for b in range(batch_size): for j in range(page_count): page_table_ref[b, j] = b + j * batch_size page_table_gpu = page_table_ref.permute(1, 0).cuda() page_table = from_dlpack(page_table_gpu, assumed_align=16).mark_layout_dynamic( leading_dim=0 ) return page_table_ref, page_table, page_table_gpu def create_block_split_kvs( batch_size, split_kv, cache_seqs_ref, is_var_split_kv, mma_qk_tiler_mn, cluster_shape_mnk, max_active_clusters, ): block_split_kvs_ref, block_split_kvs, block_split_kvs_gpu = None, None, None # check if split_kv is valid otherwise do auto setting of split_kv if is_var_split_kv: block_split_kvs_ref = torch.zeros([batch_size], dtype=torch.int32) for b in range(batch_size): block_split_kvs_ref[b] = ( BlackwellMultiHeadLatentAttentionForwardFP8.get_split_kv( batch_size, seq_len_q, cache_seqs_ref[b].item(), mma_qk_tiler_mn, max_active_clusters * cluster_shape_mnk[0], ) ) split_kv = torch.max(block_split_kvs_ref).item() block_split_kvs_gpu = block_split_kvs_ref.cuda() block_split_kvs = from_dlpack( block_split_kvs_gpu, assumed_align=16 ).mark_layout_dynamic() elif split_kv <= 0: split_kv = BlackwellMultiHeadLatentAttentionForwardFP8.get_split_kv( batch_size, seq_len_q, cache_seqs_ref[0].item(), mma_qk_tiler_mn, max_active_clusters * cluster_shape_mnk[0], ) return split_kv, block_split_kvs_ref, block_split_kvs, block_split_kvs_gpu def create_workspace( num_heads, seq_len_q, latent_dim, batch_size, split_kv, acc_dtype ): workspace_size = BlackwellMultiHeadLatentAttentionForwardFP8.get_workspace_size( num_heads, seq_len_q, latent_dim, batch_size, split_kv, acc_dtype, ) workspace, workspace_torch = None, None if workspace_size > 0: workspace_torch = torch.empty([workspace_size], dtype=torch.int8).cuda() workspace = from_dlpack(workspace_torch, assumed_align=32) return workspace, workspace_torch cache_seqs_ref, cache_seqs, cache_seqs_torch = create_cache_seqs( batch_size, seq_len_k, is_var_seq ) page_table_ref, page_table, page_table_torch = 
create_page_table( batch_size, seq_len_k, is_var_seq, page_size ) cluster_shape_mnk = (2, 1, 1) hardware_info = utils.HardwareInfo() max_active_clusters = hardware_info.get_max_active_clusters( cluster_shape_mnk[0] * cluster_shape_mnk[1] ) split_kv, block_split_kvs_ref, block_split_kvs, block_split_kvs_torch = ( create_block_split_kvs( batch_size, split_kv, cache_seqs_ref, is_var_split_kv, mma_qk_tiler_mn, cluster_shape_mnk, max_active_clusters, ) ) q_latent_ref, q_latent, q_latent_torch = create_data_tensor( batch_size, num_heads, latent_dim, in_dtype, is_dynamic_layout=True, seq_len_q=seq_len_q, ) q_rope_ref, q_rope, q_rope_torch = create_data_tensor( batch_size, num_heads, rope_dim, in_dtype, is_dynamic_layout=True, seq_len_q=seq_len_q, ) c_latent_ref, c_latent, c_latent_torch = create_data_tensor( batch_size, seq_len_k, latent_dim, in_dtype, is_dynamic_layout=True, page_table=page_table, cache_seqs=cache_seqs_ref, ) c_rope_ref, c_rope, c_rope_torch = create_data_tensor( batch_size, seq_len_k, rope_dim, in_dtype, is_dynamic_layout=True, page_table=page_table, cache_seqs=cache_seqs_ref, ) o_ref, o, o_torch = create_data_tensor( batch_size, num_heads, latent_dim, out_dtype, is_dynamic_layout=True, seq_len_q=seq_len_q, ) lse_ref, lse, lse_torch = create_data_tensor( batch_size, num_heads, 1, lse_dtype, is_dynamic_layout=True, is_lse=True, seq_len_q=seq_len_q, ) workspace, workspace_torch = create_workspace( num_heads, seq_len_q, latent_dim, batch_size, split_kv, acc_dtype ) mla = BlackwellMultiHeadLatentAttentionForwardFP8( acc_dtype, lse_dtype, mma_qk_tiler_mn, mma_pv_tiler_mn, max_active_clusters, page_size, skip_correction_threshold, is_persistent, is_var_seq, is_var_split_kv, ) # Get current CUDA stream from PyTorch torch_stream = torch.cuda.current_stream() # Get the raw stream pointer as a CUstream stream = cuda.CUstream(torch_stream.cuda_stream) # compile mla kernel compiled_mla = cute.compile( mla, q_latent, q_rope, c_latent, c_rope, page_table, o, lse, 
workspace, split_kv, cache_seqs, block_split_kvs, softmax_scale, output_scale, stream, options="--opt-level 2", ) def torch_reference_mla( q_latent, q_rope, c_latent, c_rope, page_table, cache_seqs, softmax_scale=1.0, output_scale=1.0, ): # expand and concat q_latent and q_rope to have the dimension of sequence length for q q_ref = torch.cat([q_latent, q_rope], dim=1).permute(3, 2, 0, 1) # expand and concat c_latent and c_rope to have the dimension of num_heads for k and v page_count = page_table_ref.shape[1] k_ref_paged = ( torch.cat([c_latent, c_rope], dim=1) .permute(2, 0, 1) .reshape(batch_size * page_count, page_size, latent_dim + rope_dim) ) v_ref_paged = c_latent.permute(2, 0, 1).reshape( batch_size * page_count, page_size, latent_dim ) if is_var_seq: max_seq_len = torch.max(cache_seqs_ref) else: max_seq_len = seq_len_k k_ref = torch.zeros([batch_size, 1, max_seq_len, latent_dim + rope_dim]) v_ref = torch.zeros([batch_size, 1, max_seq_len, latent_dim]) k_ref = torch.index_select( k_ref_paged, 0, torch.flatten(page_table_ref) ).reshape(batch_size, 1, -1, latent_dim + rope_dim)[:, :, :max_seq_len, :] v_ref = torch.index_select( v_ref_paged, 0, torch.flatten(page_table_ref) ).reshape(batch_size, 1, -1, latent_dim)[:, :, :max_seq_len, :] for b in range(batch_size): k_ref[b, :, cache_seqs_ref[b] :, :] = 0 v_ref[b, :, cache_seqs_ref[b] :, :] = 0 import torch.nn.functional as F o_ref = F.scaled_dot_product_attention( q_ref, k_ref, v_ref, attn_mask=None, dropout_p=0.0, scale=softmax_scale, is_causal=False, ) s_ref = torch.einsum("bhld,bhsd->bhls", q_ref, k_ref) s_ref_max, s_ref_max_pos = torch.max(s_ref, dim=-1, keepdim=True) softmax_scale_log2 = LOG2_E * softmax_scale s_ref_sum = torch.sum( torch.exp2((s_ref - s_ref_max) * softmax_scale_log2), dim=-1, keepdim=True ) lse_ref = s_ref_max * softmax_scale_log2 + torch.log2(s_ref_sum) lse_ref = lse_ref.squeeze(3).permute(2, 1, 0) o_ref = o_ref * output_scale o_ref = o_ref.permute(2, 3, 1, 0) return o_ref, lse_ref if 
skip_correction_threshold > 0.0: print( "Skipping correction verification since skip_correction_threshold is greater than 0.0..." ) skip_ref_check = True if not skip_ref_check: # Execute kernel once for reference checking compiled_mla( q_latent, q_rope, c_latent, c_rope, page_table, o, lse, workspace, split_kv, cache_seqs, block_split_kvs, softmax_scale, output_scale, stream, ) torch.cuda.synchronize() print("Verifying results...") if in_dtype == cutlass.Float8E4M3FN: tolerance = 0.13 o_ref, lse_ref = torch_reference_mla( q_latent_ref, q_rope_ref, c_latent_ref, c_rope_ref, page_table, cache_seqs, softmax_scale, output_scale, ) if out_dtype in [cutlass.Float8E5M2, cutlass.Float8E4M3FN]: # convert o back to f32 for comparison o_fp32, o_fp32_torch = cutlass_torch.cute_tensor_like( torch.empty(*o_torch.shape, dtype=torch.float32), cutlass.Float32, is_dynamic_layout=True, assumed_align=16, ) cute.testing.convert(o, o_fp32) o = o_fp32_torch.cpu() ref_fp8, _ = cutlass_torch.cute_tensor_like( torch.empty( *o_ref.permute(3, 2, 0, 1).shape, dtype=torch.uint8 ).permute(2, 3, 1, 0), out_dtype, is_dynamic_layout=True, assumed_align=16, ) o_ref_gpu = o_ref.cuda() o_ref_f32 = from_dlpack(o_ref_gpu).mark_layout_dynamic(leading_dim=1) # convert ref : f32 -> fp8 -> f32 cute.testing.convert(o_ref_f32, ref_fp8) cute.testing.convert(ref_fp8, o_ref_f32) o_ref = o_ref_gpu.cpu() else: o = o_torch.cpu().to(torch.float32) lse = lse_torch.cpu() lse_ref = lse_ref.to(cutlass.torch.dtype(lse_dtype)) # Assert close results torch.testing.assert_close(o, o_ref, atol=tolerance, rtol=1e-05) torch.testing.assert_close(lse, lse_ref, atol=tolerance, rtol=1e-05) print("Results verified successfully!") def generate_tensors(): _, cache_seqs, _ = create_cache_seqs(batch_size, seq_len_k, is_var_seq) _, page_table, _ = create_page_table( batch_size, seq_len_k, is_var_seq, page_size ) _split_kv, _, block_split_kvs, _ = create_block_split_kvs( batch_size, split_kv, cache_seqs_ref, is_var_split_kv, 
mma_qk_tiler_mn, cluster_shape_mnk, max_active_clusters, ) _, q_latent, _ = create_data_tensor( batch_size, num_heads, latent_dim, in_dtype, is_dynamic_layout=True, seq_len_q=seq_len_q, ) _, q_rope, _ = create_data_tensor( batch_size, num_heads, rope_dim, in_dtype, is_dynamic_layout=True, seq_len_q=seq_len_q, ) _, c_latent, _ = create_data_tensor( batch_size, seq_len_k, latent_dim, in_dtype, is_dynamic_layout=True, page_table=page_table, cache_seqs=cache_seqs_ref, ) _, c_rope, _ = create_data_tensor( batch_size, seq_len_k, rope_dim, in_dtype, is_dynamic_layout=True, page_table=page_table, cache_seqs=cache_seqs_ref, ) _, o, _ = create_data_tensor( batch_size, num_heads, latent_dim, out_dtype, is_dynamic_layout=True, seq_len_q=seq_len_q, ) _, lse, _ = create_data_tensor( batch_size, num_heads, 1, lse_dtype, is_dynamic_layout=True, is_lse=True, seq_len_q=seq_len_q, ) workspace, workspace_torch = create_workspace( num_heads, seq_len_q, latent_dim, batch_size, _split_kv, acc_dtype ) return testing.JitArguments( q_latent, q_rope, c_latent, c_rope, page_table, o, lse, workspace, _split_kv, cache_seqs, block_split_kvs, softmax_scale, output_scale, stream, ) workspace_count = 1 if use_cold_l2: one_workspace_bytes = ( q_latent_torch.numel() * q_latent_torch.element_size() + q_rope_torch.numel() * q_rope_torch.element_size() + c_latent_torch.numel() * c_latent_torch.element_size() + c_rope_torch.numel() * c_rope_torch.element_size() + o_torch.numel() * o_torch.element_size() + lse_torch.numel() * lse_torch.element_size() + cache_seqs_torch.numel() * cache_seqs_torch.element_size() ) one_workspace_bytes += ( page_table_torch.numel() * page_table_torch.element_size() ) if is_var_split_kv: one_workspace_bytes += ( block_split_kvs_torch.numel() * block_split_kvs_torch.element_size() ) if workspace_torch is not None: one_workspace_bytes += ( workspace_torch.numel() * workspace_torch.element_size() ) workspace_count = testing.get_workspace_count( one_workspace_bytes, 
warmup_iterations, iterations ) avg_time_us = testing.benchmark( compiled_mla, workspace_generator=generate_tensors, workspace_count=workspace_count, stream=stream, warmup_iterations=warmup_iterations, iterations=iterations, ) return avg_time_us # Return execution time in microseconds if __name__ == "__main__": def parse_comma_separated_ints(s: str) -> Tuple[int, ...]: try: return tuple(int(x.strip()) for x in s.split(",")) except ValueError: raise argparse.ArgumentTypeError( "Invalid format. Expected comma-separated integers." ) def parse_mma_tiler(s: str) -> Tuple[int, int, Tuple[int, int]]: ret = parse_comma_separated_ints(s) if len(ret) != 2: raise argparse.ArgumentTypeError( "Invalid format. Expected 2 comma-separated integers." ) return (ret[0], ret[1]) parser = argparse.ArgumentParser(description="Example of MLA on Blackwell.") parser.add_argument( "--in_dtype", type=cutlass.dtype, default=cutlass.Float8E4M3FN, help="Input data type", ) parser.add_argument( "--out_dtype", type=cutlass.dtype, default=cutlass.Float8E4M3FN, help="Output data type", ) parser.add_argument( "--acc_dtype", type=cutlass.dtype, default=cutlass.Float32, help="Accumulator data type", ) parser.add_argument( "--lse_dtype", type=cutlass.dtype, default=cutlass.Float32, help="LSE data type", ) parser.add_argument( "--mma_qk_tiler_mn", type=parse_mma_tiler, default=(128, 128), help="MMA tile shape (H, K)", ) parser.add_argument( "--mma_pv_tiler_mn", type=parse_mma_tiler, default=(128, 256), help="MMA tile shape (H, D)", ) parser.add_argument( "--is_persistent", action="store_true", help="Is persistent", ) parser.add_argument( "--batch_size", type=int, default=1, help="Batch size", ) parser.add_argument( "--seq_len_q", type=int, default=1, help="Sequence length of Q", ) parser.add_argument( "--seq_len_k", type=int, default=128, help="Sequence length of K/V", ) parser.add_argument( "--num_heads", type=int, default=128, help="Number of heads of Q", ) parser.add_argument( "--latent_dim", 
type=int, default=512, help="Latent dimension of Q/C", ) parser.add_argument( "--rope_dim", type=int, default=64, help="Rope dimension of Q/C", ) parser.add_argument( "--is_var_seq", action="store_true", help="Use variable length of sequence length or not", ) parser.add_argument( "--is_var_split_kv", action="store_true", help="Use variable length of split kv or not", ) parser.add_argument( "--page_size", type=int, default=128, help="Page size of page table", ) parser.add_argument( "--split_kv", type=int, default=-1, help="Split KV setting", ) parser.add_argument( "--softmax_scale", type=float, default=0.0416, help="Scaling factor to scale softmax", ) parser.add_argument( "--output_scale", type=float, default=1.0, help="Scaling factor to scale output", ) parser.add_argument( "--skip_correction_threshold", type=float, default=0.0, help="Threshold to skip correction", ) parser.add_argument( "--tolerance", type=float, default=1e-02, help="Tolerance for validation" ) parser.add_argument( "--warmup_iterations", type=int, default=0, help="Number of iterations for warmup", ) parser.add_argument( "--iterations", type=int, default=1, help="Number of iterations after warmup", ) parser.add_argument( "--skip_ref_check", action="store_true", help="Skip reference check", ) parser.add_argument( "--use_cold_l2", action="store_true", help="Use cold L2 cache", ) args = parser.parse_args() run( args.batch_size, args.seq_len_q, args.seq_len_k, args.num_heads, args.latent_dim, args.rope_dim, args.in_dtype, args.out_dtype, args.acc_dtype, args.lse_dtype, args.mma_qk_tiler_mn, args.mma_pv_tiler_mn, args.split_kv, args.is_persistent, args.is_var_seq, args.is_var_split_kv, args.page_size, args.softmax_scale, args.output_scale, args.skip_correction_threshold, args.tolerance, args.warmup_iterations, args.iterations, args.skip_ref_check, args.use_cold_l2, ) print("PASS")