# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.

import enum
from typing import Tuple, Optional

import cutlass
from cutlass.cute.typing import Boolean
from cutlass.cutlass_dsl import (
    Int32,
    Float32,
    min,
    extract_mlir_values,
    new_from_mlir_values,
)
from cutlass.utils.hardware_info import HardwareInfo
from cutlass.utils import WorkTileInfo
import cutlass.cute as cute


##############################################################################
# Fmha static tile scheduler
##############################################################################


class FmhaStaticTileSchedulerParams:
    """A class to represent parameters for the FMHA (Fused Multi-Head Attention)
    static tile scheduler.

    This class holds the configuration parameters needed to initialize and
    configure the tile scheduler for FMHA operations.

    :ivar is_persistent: Whether to use persistent kernel mode.
    :type is_persistent: bool
    :ivar problem_shape_mbh: Problem shape in (M, B, H) format.
    :type problem_shape_mbh: cute.Shape
    """

    def __init__(
        self,
        is_persistent: bool,
        problem_shape_mbh: cute.Shape,
        *,
        loc=None,
        ip=None,
    ):
        """
        Initializes the FmhaStaticTileSchedulerParams with the given parameters.

        :param is_persistent: Whether to use persistent kernel mode.
        :type is_persistent: bool
        :param problem_shape_mbh: Problem shape in (M, B, H) format.
        :type problem_shape_mbh: cute.Shape
        :param loc: Optional location forwarded to DSL calls; stored as-is.
        :param ip: Optional insertion point forwarded to DSL calls; stored as-is.
        """
        self.is_persistent = is_persistent
        self.problem_shape_mbh = problem_shape_mbh
        self._loc = loc
        self._ip = ip

    def __extract_mlir_values__(self):
        """Flatten the dynamic members into a list of MLIR values.

        Only ``problem_shape_mbh`` is flattened; ``is_persistent`` stays a
        plain Python attribute. The number of values contributed by each
        member is recorded in ``self._values_pos`` so that
        ``__new_from_mlir_values__`` can split the list again.
        """
        values, self._values_pos = [], []
        for obj in [self.problem_shape_mbh]:
            obj_values = extract_mlir_values(obj)
            values += obj_values
            self._values_pos.append(len(obj_values))
        return values

    def __new_from_mlir_values__(self, values):
        """Rebuild a params object from a flat list of MLIR values.

        Consumes ``values`` front to back using the per-member counts recorded
        by ``__extract_mlir_values__``.
        """
        obj_list = []
        for obj, n_items in zip([self.problem_shape_mbh], self._values_pos):
            obj_list.append(new_from_mlir_values(obj, values[:n_items]))
            values = values[n_items:]
        return FmhaStaticTileSchedulerParams(
            self.is_persistent, *(tuple(obj_list)), loc=self._loc
        )


class FmhaStaticTileScheduler:
    """A static tile scheduler for FMHA (Fused Multi-Head Attention) operations.

    This class manages the scheduling of work tiles for FMHA kernels, supporting
    both persistent and non-persistent kernel modes. It tracks the current work
    position and advances through the problem space efficiently.

    :ivar _params: Scheduler parameters.
    :type _params: FmhaStaticTileSchedulerParams
    :ivar _blk_coord: Block coordinates.
    :type _blk_coord: cute.Coord
    :ivar _grid_shape: Grid shape for the kernel.
    :type _grid_shape: cute.Shape
    :ivar _is_persistent: Whether to use persistent kernel mode.
    :type _is_persistent: bool
    :ivar _current_work_linear_idx: Current linear work index.
    :type _current_work_linear_idx: Int32
    :ivar _problem_shape_mbh: Layout built from the (M, B, H) problem shape.
    :type _problem_shape_mbh: cute.Layout
    :ivar _num_blocks: Number of blocks in the problem.
    :type _num_blocks: Int32
    :ivar _is_first_block: Whether this is the first block.
    :type _is_first_block: bool
    :ivar num_persistent_sm: Number of persistent SMs (size of the grid shape).
    :type num_persistent_sm: Int32
    """

    def __init__(
        self,
        params: FmhaStaticTileSchedulerParams,
        current_work_linear_idx: Int32,
        blk_coord: cute.Coord,
        grid_shape: cute.Shape,
        *,
        loc=None,
        ip=None,
    ):
        """
        Initializes the FmhaStaticTileScheduler with the given parameters.

        :param params: Scheduler parameters.
        :type params: FmhaStaticTileSchedulerParams
        :param current_work_linear_idx: Current linear work index.
        :type current_work_linear_idx: Int32
        :param blk_coord: Block coordinates.
        :type blk_coord: cute.Coord
        :param grid_shape: Grid shape for the kernel.
        :type grid_shape: cute.Shape
        """
        self._params = params
        self._blk_coord = blk_coord
        self._grid_shape = grid_shape
        self._is_persistent = params.is_persistent
        self._current_work_linear_idx = current_work_linear_idx
        self._problem_shape_mbh = cute.make_layout(
            params.problem_shape_mbh, loc=loc, ip=ip
        )
        self._num_blocks = cute.size(self._problem_shape_mbh, loc=loc, ip=ip)
        self._is_first_block = True
        self.num_persistent_sm = cute.size(grid_shape, loc=loc, ip=ip)
        self._loc = loc
        self._ip = ip

    # called by host
    @staticmethod
    def get_grid_shape(
        params: FmhaStaticTileSchedulerParams,
        *,
        loc=None,
        ip=None,
    ) -> cute.Shape:
        """
        Determine the grid shape for the FMHA kernel.

        For persistent kernels, the grid shape is limited by the number of SMs
        (Streaming Multiprocessors) available on the device. For non-persistent
        kernels, the grid shape matches the problem shape.

        :param params: Scheduler parameters.
        :type params: FmhaStaticTileSchedulerParams

        :return: Grid shape; ``(min(sm_count, num_tiles), 1, 1)`` when
            persistent, otherwise ``params.problem_shape_mbh``.
        :rtype: cute.Shape
        """
        if params.is_persistent:
            hardware_info = HardwareInfo()
            sm_count = hardware_info.get_device_multiprocessor_count()
            return (
                min(sm_count, cute.size(params.problem_shape_mbh, loc=loc, ip=ip)),
                1,
                1,
            )
        else:
            return params.problem_shape_mbh

    @staticmethod
    def check_valid_work_for_seqlen_q(
        q_tiler: int,
        current_idx: Int32,
        seqlen_q: Int32,
    ) -> Boolean:
        """
        Check if the current work index is valid for the given query sequence length.

        This method verifies that the current work tile index multiplied by the
        query tiler size is within the bounds of the query sequence length.

        :param q_tiler: Query tiler size.
        :type q_tiler: int
        :param current_idx: Current work index.
        :type current_idx: Int32
        :param seqlen_q: Query sequence length.
        :type seqlen_q: Int32

        :return: True if the work is valid, False otherwise.
        :rtype: Boolean
        """
        return current_idx * q_tiler < seqlen_q

    def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
        """
        Get information about the current work tile.

        Determines if the current work is valid and computes the tile
        coordinates based on whether the kernel is persistent or non-persistent.

        :return: WorkTileInfo containing tile coordinates and validity flag.
        :rtype: WorkTileInfo
        """
        # Persistent: valid while the linear index is inside the problem.
        # Non-persistent: valid only until advance_to_next_work() is called.
        is_valid = (
            self._current_work_linear_idx < self._num_blocks
            if self._is_persistent
            else self._is_first_block
        )

        blk_coord = (0, 0, 0)
        if self._is_persistent:
            blk_coord = self._problem_shape_mbh.get_hier_coord(
                self._current_work_linear_idx, loc=loc, ip=ip
            )
        else:
            blk_coord = self._blk_coord

        # cur_tile_coord is (mid, 0, (bid, hid))
        cur_tile_coord = (
            blk_coord[0],
            0,
            (blk_coord[1], blk_coord[2]),
        )

        return WorkTileInfo(cur_tile_coord, is_valid)

    def initial_work_tile_info(self, *, loc=None, ip=None):
        """
        Get the initial work tile information.

        :return: Initial WorkTileInfo.
        :rtype: WorkTileInfo
        """
        return self.get_current_work(loc=loc, ip=ip)

    def advance_to_next_work(self, *, advance_count=1, loc=None, ip=None):
        """
        Advance to the next work tile.

        For persistent kernels, advances the linear index by
        ``advance_count * num_persistent_sm``. In both modes, marks that the
        first block has been processed.

        :param advance_count: Number of steps to advance (default: 1).
        :type advance_count: int
        """
        if self._is_persistent:
            self._current_work_linear_idx += advance_count * self.num_persistent_sm
        self._is_first_block = False

    def __extract_mlir_values__(self):
        """Flatten params, linear index, block coord and grid shape, in that order."""
        values = extract_mlir_values(self._params)
        values.extend(extract_mlir_values(self._current_work_linear_idx))
        values.extend(extract_mlir_values(self._blk_coord))
        values.extend(extract_mlir_values(self._grid_shape))
        return values

    def __new_from_mlir_values__(self, values):
        """Rebuild a scheduler from the flat list made by ``__extract_mlir_values__``.

        The slicing below expects exactly 10 values: 3 for the params, 1 for
        the linear work index, 3 for the block coordinate and 3 for the grid
        shape.
        """
        assert len(values) == 10
        new_params = new_from_mlir_values(self._params, values[0:3])
        new_current_work_linear_idx = new_from_mlir_values(
            self._current_work_linear_idx, [values[3]]
        )
        new_blk_coord = new_from_mlir_values(self._blk_coord, values[4:7])
        new_grid_shape = new_from_mlir_values(self._grid_shape, values[7:])
        return FmhaStaticTileScheduler(
            new_params, new_current_work_linear_idx, new_blk_coord, new_grid_shape
        )


def create_fmha_static_tile_scheduler(
    params: FmhaStaticTileSchedulerParams,
    blk_coord: cute.Coord,
    grid_shape: cute.Shape,
) -> FmhaStaticTileScheduler:
    """
    Create a new FMHA static tile scheduler.

    The scheduler's initial linear work index is ``blk_coord[0]``.

    :param params: Scheduler parameters.
    :type params: FmhaStaticTileSchedulerParams
    :param blk_coord: Block coordinates.
    :type blk_coord: cute.Coord
    :param grid_shape: Grid shape.
    :type grid_shape: cute.Shape

    :return: New FmhaStaticTileScheduler instance.
    :rtype: FmhaStaticTileScheduler
    """
    return FmhaStaticTileScheduler(params, blk_coord[0], blk_coord, grid_shape)


def create_fmha_static_tile_scheduler_params(
    is_persistent: bool,
    problem_shape_mbh: cute.Shape,
) -> FmhaStaticTileSchedulerParams:
    """
    Create FMHA static tile scheduler parameters.

    :param is_persistent: Whether to use persistent kernel mode.
    :type is_persistent: bool
    :param problem_shape_mbh: Problem shape in (M, B, H) format.
    :type problem_shape_mbh: cute.Shape

    :return: New FmhaStaticTileSchedulerParams instance.
    :rtype: FmhaStaticTileSchedulerParams
    """
    return FmhaStaticTileSchedulerParams(is_persistent, problem_shape_mbh)


def compute_grid(
    o_shape: cute.Shape,
    cta_tiler: Tuple[int, int, int],
    is_persistent: bool,
) -> Tuple[FmhaStaticTileSchedulerParams, Tuple[int, int, int]]:
    """
    Compute grid parameters for FMHA operation.

    This function calculates the appropriate grid shape and scheduler
    parameters based on the output tensor shape, CTA (Cooperative Thread Array)
    tiler, and whether to use persistent kernel mode.

    The output tensor o has shape (s, d, ((h_r, h_k), b)) where:
    - s: sequence length
    - d: head dimension
    - h_r: number of heads for query
    - h_k: number of heads for key
    - b: batch size

    The scheduler problem shape is built as
    ``(ceil_div(size(o_shape[0]), cta_tiler[0]), size(o_shape[2][0]),
    size(o_shape[2][1]))``.

    :param o_shape: Output tensor shape for grid computation.
    :type o_shape: cute.Shape
    :param cta_tiler: CTA tiler dimensions (M, N, K).
    :type cta_tiler: Tuple[int, int, int]
    :param is_persistent: Whether to use persistent kernel mode.
    :type is_persistent: bool

    :return: Tuple of (scheduler_params, grid_shape).
    :rtype: Tuple[FmhaStaticTileSchedulerParams, Tuple[int, int, int]]
    """
    tile_sched_params = create_fmha_static_tile_scheduler_params(
        is_persistent,
        (
            cute.ceil_div(cute.size(o_shape[0]), cta_tiler[0]),
            cute.size(o_shape[2][0]),
            cute.size(o_shape[2][1]),
        ),
    )
    grid = FmhaStaticTileScheduler.get_grid_shape(tile_sched_params)

    return tile_sched_params, grid


##############################################################################
# Fused Mask
##############################################################################


class MaskEnum(enum.Enum):
    """Enumeration of mask types for FMHA operations.

    - RESIDUAL_MASK: Residual mask for handling variable sequence lengths
    - RESIDUAL_MASK_BWD: Residual mask for backward pass
    - WINDOW_MASK: Window mask for attention which also includes causal and no mask
    - WINDOW_MASK_INFERENCE: Same as the window mask, but has the limitation that the end of q is aligned with the end of k
    - WINDOW_MASK_BWD: Window mask for backward pass
    - WINDOW_MASK_BWD_INFERENCE: Same as the window mask for backward pass, but has the limitation that the end of q is aligned with the end of k
    """

    RESIDUAL_MASK = enum.auto()
    RESIDUAL_MASK_BWD = enum.auto()
    WINDOW_MASK = enum.auto()
    WINDOW_MASK_INFERENCE = enum.auto()
    WINDOW_MASK_BWD = enum.auto()
    WINDOW_MASK_BWD_INFERENCE = enum.auto()


class FusedMask:
    """A fused mask implementation for FMHA operations.

    This class handles different types of attention masks including no mask,
    residual mask for variable sequence lengths, and causal mask for
    autoregressive attention patterns.

    The class provides methods to:
    - Calculate trip counts for different mask types
    - Apply masks to attention scores
    - Handle masked and unmasked trip calculations

    All methods take no ``self`` and are called on the class itself, e.g.
    ``FusedMask.get_trip_start(...)``.
    """

    # NOTE(review): unlike the sibling helpers this function is not decorated
    # with @cute.jit. Its body only branches on cutlass.const_expr conditions;
    # confirm whether the missing decorator is intentional.
    def get_trip_count(
        mask_type: MaskEnum,
        blk_coord: cute.Coord,
        tile_shape: cute.Shape,
        seqlen_q: Int32,
        seqlen_k: Int32,
        window_size_left: Optional[Int32] = None,
        window_size_right: Optional[Int32] = None,
    ) -> Int32:
        """
        Calculate the number of trips needed for the current block.

        The trip count depends on the mask type and the block coordinates.
        For window masks it is bounded by the window, and the start block
        returned by ``get_trip_start`` is subtracted from the result.

        :param mask_type: Type of mask to use
        :type mask_type: utils.MaskEnum
        :param blk_coord: Block coordinates.
        :type blk_coord: cute.Coord
        :param tile_shape: Shape of the tile.
        :type tile_shape: cute.Shape
        :param seqlen_q: Query sequence length for attention computation.
        :type seqlen_q: Int32
        :param seqlen_k: Key sequence length for attention computation.
        :type seqlen_k: Int32
        :param window_size_left: Left-side sliding window size for attention masking.
        :type window_size_left: Optional[Int32]
        :param window_size_right: Right-side sliding window size for attention masking.
        :type window_size_right: Optional[Int32]

        :return: Number of trips needed.
        :rtype: Int32
        """
        result = 0
        offset = 0
        if cutlass.const_expr(mask_type is MaskEnum.WINDOW_MASK_INFERENCE):
            offset = seqlen_k - seqlen_q
        if cutlass.const_expr(mask_type is MaskEnum.WINDOW_MASK_BWD_INFERENCE):
            offset = seqlen_q - seqlen_k

        if cutlass.const_expr(mask_type is MaskEnum.RESIDUAL_MASK):
            result = cute.ceil_div(seqlen_k, tile_shape[1])
        if cutlass.const_expr(mask_type is MaskEnum.RESIDUAL_MASK_BWD):
            result = cute.ceil_div(seqlen_q, tile_shape[0])

        if cutlass.const_expr(
            mask_type is MaskEnum.WINDOW_MASK
            or mask_type is MaskEnum.WINDOW_MASK_INFERENCE
        ):
            if cutlass.const_expr(window_size_right is None):
                result = cute.ceil_div(seqlen_k, tile_shape[1])
            else:
                # One past the last q row of this tile, shifted by the window,
                # gives the exclusive upper bound on the k index.
                max_idx_q = (blk_coord[0] + 1) * tile_shape[0]
                idx_k = max_idx_q + offset + window_size_right
                tmp_blocks_k = cute.ceil_div(idx_k, tile_shape[1])
                max_blocks_k = cute.ceil_div(seqlen_k, tile_shape[1])
                result = min(max_blocks_k, tmp_blocks_k)

        if cutlass.const_expr(
            mask_type is MaskEnum.WINDOW_MASK_BWD
            or mask_type is MaskEnum.WINDOW_MASK_BWD_INFERENCE
        ):
            if cutlass.const_expr(window_size_left is None):
                result = cute.ceil_div(seqlen_q, tile_shape[0])
            else:
                # Backward pass iterates over q blocks for a fixed k tile.
                max_idx_k = (blk_coord[1] + 1) * tile_shape[1]
                idx_q = max_idx_k + offset + window_size_left
                tmp_blocks_q = cute.ceil_div(idx_q, tile_shape[0])
                max_blocks_q = cute.ceil_div(seqlen_q, tile_shape[0])
                result = min(max_blocks_q, tmp_blocks_q)

        start_block = FusedMask.get_trip_start(
            mask_type,
            blk_coord,
            tile_shape,
            seqlen_q,
            seqlen_k,
            window_size_left,
            window_size_right,
        )
        result = result - start_block
        return result

    @cute.jit
    def get_trip_start(
        mask_type: MaskEnum,
        blk_coord: cute.Coord,
        tile_shape: cute.Shape,
        seqlen_q: Int32,
        seqlen_k: Int32,
        window_size_left: Optional[Int32] = None,
        window_size_right: Optional[Int32] = None,
    ) -> Int32:
        """
        Get the start of the trip for the current block.

        Returns 0 unless a window mask with the relevant window size is used,
        in which case the start block is clamped to be non-negative.

        :param mask_type: Type of mask to use
        :type mask_type: utils.MaskEnum
        :param blk_coord: Block coordinates.
        :type blk_coord: cute.Coord
        :param tile_shape: Shape of the tile.
        :type tile_shape: cute.Shape
        :param seqlen_q: Query sequence length for attention computation.
        :type seqlen_q: Int32
        :param seqlen_k: Key sequence length for attention computation.
        :type seqlen_k: Int32
        :param window_size_left: Left-side sliding window size for attention masking.
        :type window_size_left: Optional[Int32]
        :param window_size_right: Right-side sliding window size for attention masking.
        :type window_size_right: Optional[Int32]

        :return: Index of the first block of the trip.
        :rtype: Int32
        """
        result = 0
        offset = 0
        if cutlass.const_expr(mask_type is MaskEnum.WINDOW_MASK_INFERENCE):
            offset = seqlen_k - seqlen_q
        if cutlass.const_expr(mask_type is MaskEnum.WINDOW_MASK_BWD_INFERENCE):
            offset = seqlen_q - seqlen_k

        if cutlass.const_expr(
            mask_type is MaskEnum.WINDOW_MASK
            or mask_type is MaskEnum.WINDOW_MASK_INFERENCE
        ):
            if cutlass.const_expr(window_size_left is not None):
                min_idx_q = blk_coord[0] * tile_shape[0]
                idx_k = min_idx_q + offset - window_size_left
                tmp_blocks_k = idx_k // tile_shape[1]
                result = max(tmp_blocks_k, result)

        if cutlass.const_expr(
            mask_type is MaskEnum.WINDOW_MASK_BWD
            or mask_type is MaskEnum.WINDOW_MASK_BWD_INFERENCE
        ):
            if cutlass.const_expr(window_size_right is not None):
                min_idx_k = blk_coord[1] * tile_shape[1]
                idx_q = min_idx_k + offset - window_size_right
                tmp_blocks_q = idx_q // tile_shape[0]
                result = max(tmp_blocks_q, result)

        return result

    @cute.jit
    def get_leading_mask_id(
        mask_type: MaskEnum,
        blk_coord: cute.Coord,
        tile_shape: cute.Shape,
        seqlen_q: Int32,
        seqlen_k: Int32,
        window_size_left: Optional[Int32] = None,
        window_size_right: Optional[Int32] = None,
    ) -> Tuple[Int32, Int32]:
        """
        Get the begin and end tile idx for the leading mask.

        When the relevant window size is ``None`` the end is set to
        ``begin - 1``, i.e. the range is empty.

        :param mask_type: Type of mask to use
        :type mask_type: utils.MaskEnum
        :param blk_coord: Block coordinates.
        :type blk_coord: cute.Coord
        :param tile_shape: Shape of the tile.
        :type tile_shape: cute.Shape
        :param seqlen_q: Query sequence length for attention computation.
        :type seqlen_q: Int32
        :param seqlen_k: Key sequence length for attention computation.
        :type seqlen_k: Int32
        :param window_size_left: Left-side sliding window size for attention masking.
        :type window_size_left: Optional[Int32]
        :param window_size_right: Right-side sliding window size for attention masking.
        :type window_size_right: Optional[Int32]

        :return: Tuple of (begin, end) tile idx for the leading mask.
        :rtype: Tuple[Int32, Int32]
        """
        offset = 0
        if cutlass.const_expr(mask_type is MaskEnum.WINDOW_MASK_INFERENCE):
            offset = seqlen_k - seqlen_q
        if cutlass.const_expr(mask_type is MaskEnum.WINDOW_MASK_BWD_INFERENCE):
            offset = seqlen_q - seqlen_k

        leading_mask_begin = FusedMask.get_trip_start(
            mask_type,
            blk_coord,
            tile_shape,
            seqlen_q,
            seqlen_k,
            window_size_left,
            window_size_right,
        )
        trip_count = FusedMask.get_trip_count(
            mask_type,
            blk_coord,
            tile_shape,
            seqlen_q,
            seqlen_k,
            window_size_left,
            window_size_right,
        )

        leading_mask_end = leading_mask_begin
        if cutlass.const_expr(
            mask_type is MaskEnum.WINDOW_MASK
            or mask_type is MaskEnum.WINDOW_MASK_INFERENCE
        ):
            if cutlass.const_expr(window_size_left is not None):
                min_idx_q = (
                    (blk_coord[0] + 1) * tile_shape[0] + offset - window_size_left
                )
                # Never extend past the last block of the trip.
                leading_mask_end = min(
                    cute.ceil_div(min_idx_q, tile_shape[1]) - 1,
                    trip_count + leading_mask_begin - 1,
                )
            else:
                leading_mask_end = leading_mask_begin - 1
        elif cutlass.const_expr(
            mask_type is MaskEnum.WINDOW_MASK_BWD
            or mask_type is MaskEnum.WINDOW_MASK_BWD_INFERENCE
        ):
            if cutlass.const_expr(window_size_right is not None):
                min_idx_k = (
                    (blk_coord[1] + 1) * tile_shape[1] + offset - window_size_right
                )
                leading_mask_end = cute.ceil_div(min_idx_k, tile_shape[0]) - 1
            else:
                leading_mask_end = leading_mask_begin - 1

        return leading_mask_begin, leading_mask_end

    @cute.jit
    def get_trailing_mask_id(
        mask_type: MaskEnum,
        blk_coord: cute.Coord,
        tile_shape: cute.Shape,
        seqlen_q: Int32,
        seqlen_k: Int32,
        window_size_left: Optional[Int32] = None,
        window_size_right: Optional[Int32] = None,
    ) -> Tuple[Optional[Int32], Optional[Int32]]:
        """
        Get the begin and end tile idx for the trailing mask.

        :param mask_type: Type of mask to use
        :type mask_type: utils.MaskEnum
        :param blk_coord: Block coordinates.
        :type blk_coord: cute.Coord
        :param tile_shape: Shape of the tile.
        :type tile_shape: cute.Shape
        :param seqlen_q: Query sequence length for attention computation.
        :type seqlen_q: Int32
        :param seqlen_k: Key sequence length for attention computation.
        :type seqlen_k: Int32
        :param window_size_left: Left-side sliding window size for attention masking.
        :type window_size_left: Optional[Int32]
        :param window_size_right: Right-side sliding window size for attention masking.
        :type window_size_right: Optional[Int32]

        :return: Tuple of (begin, end) tile idx for the trailing mask.
        :rtype: Tuple[Optional[Int32], Optional[Int32]]
        """
        offset = 0
        if cutlass.const_expr(mask_type is MaskEnum.WINDOW_MASK_INFERENCE):
            offset = seqlen_k - seqlen_q
        if cutlass.const_expr(mask_type is MaskEnum.WINDOW_MASK_BWD_INFERENCE):
            offset = seqlen_q - seqlen_k

        trip_start = FusedMask.get_trip_start(
            mask_type,
            blk_coord,
            tile_shape,
            seqlen_q,
            seqlen_k,
            window_size_left,
            window_size_right,
        )
        trip_count = FusedMask.get_trip_count(
            mask_type,
            blk_coord,
            tile_shape,
            seqlen_q,
            seqlen_k,
            window_size_left,
            window_size_right,
        )

        trailing_mask_begin, trailing_mask_end = None, None
        if cutlass.const_expr(
            mask_type is MaskEnum.WINDOW_MASK
            or mask_type is MaskEnum.WINDOW_MASK_INFERENCE
        ):
            if cutlass.const_expr(window_size_right is not None):
                min_idx_q = blk_coord[0] * tile_shape[0] + offset + window_size_right
                trailing_mask_begin = min(
                    min_idx_q // tile_shape[1], trip_count + trip_start - 1
                )
                trailing_mask_end = trip_count + trip_start - 1
            else:
                # last tile, we always apply mask on it regardless whether it's a residual tile
                trailing_mask_begin = trip_count + trip_start - 1
                trailing_mask_end = trip_count + trip_start - 1
        else:
            if cutlass.const_expr(window_size_left is not None):
                min_idx_k = blk_coord[1] * tile_shape[1] + offset + window_size_left + 1
                max_idx_k = (
                    (blk_coord[1] + 1) * tile_shape[1] + offset + window_size_left
                )
                trailing_mask_begin = min(
                    cute.ceil_div(min_idx_k, tile_shape[0]) - 1,
                    trip_count + trip_start - 1,
                )
                trailing_mask_end = min(
                    cute.ceil_div(max_idx_k, tile_shape[0]) - 1,
                    trip_count + trip_start - 1,
                )
            else:
                # last tile, we always apply mask on it regardless whether it's a residual tile
                trailing_mask_begin = trip_count + trip_start - 1
                trailing_mask_end = trip_count + trip_start - 1

        return trailing_mask_begin, trailing_mask_end

    @cute.jit
    def get_masked_leading_count(
        mask_type: MaskEnum,
        blk_coord: cute.Coord,
        tile_shape: cute.Shape,
        seqlen_q: Int32,
        seqlen_k: Int32,
        window_size_left: Optional[Int32] = None,
        window_size_right: Optional[Int32] = None,
    ) -> Int32:
        """
        Calculate the number of masked trips for the leading mask.

        This is used for blocks that need special handling due to masking.
        The result is 0 for residual mask types and when neither window size
        is given.

        :param mask_type: Type of mask to use
        :type mask_type: utils.MaskEnum
        :param blk_coord: Block coordinates.
        :type blk_coord: cute.Coord
        :param tile_shape: Shape of the tile.
        :type tile_shape: cute.Shape
        :param seqlen_q: Query sequence length for attention computation.
        :type seqlen_q: Int32
        :param seqlen_k: Key sequence length for attention computation.
        :type seqlen_k: Int32
        :param window_size_left: Left-side sliding window size for attention masking.
        :type window_size_left: Optional[Int32]
        :param window_size_right: Right-side sliding window size for attention masking.
        :type window_size_right: Optional[Int32]

        :return: Number of masked trips.
        :rtype: Int32
        """
        result = 0
        if cutlass.const_expr(
            mask_type is not MaskEnum.RESIDUAL_MASK
            and mask_type is not MaskEnum.RESIDUAL_MASK_BWD
        ):
            if cutlass.const_expr(
                window_size_left is not None or window_size_right is not None
            ):
                leading_mask_begin, leading_mask_end = FusedMask.get_leading_mask_id(
                    mask_type,
                    blk_coord,
                    tile_shape,
                    seqlen_q,
                    seqlen_k,
                    window_size_left,
                    window_size_right,
                )
                # An empty range (end < begin) counts as zero masked trips.
                result = max(leading_mask_end - leading_mask_begin + 1, 0)

        return result

    @cute.jit
    def get_masked_trailing_count(
        mask_type: MaskEnum,
        blk_coord: cute.Coord,
        tile_shape: cute.Shape,
        seqlen_q: Int32,
        seqlen_k: Int32,
        window_size_left: Optional[Int32] = None,
        window_size_right: Optional[Int32] = None,
        rem_count: Optional[Int32] = 0,
    ) -> Int32:
        """
        Calculate the number of masked trips for the trailing mask.

        This is used for blocks that need special handling due to masking.
        For residual mask types the count is 1 when ``seqlen_k`` is not a
        multiple of ``tile_shape[1]`` and 0 otherwise. ``rem_count`` is added
        to the result in all cases.

        :param mask_type: Type of mask to use
        :type mask_type: utils.MaskEnum
        :param blk_coord: Block coordinates.
        :type blk_coord: cute.Coord
        :param tile_shape: Shape of the tile.
        :type tile_shape: cute.Shape
        :param seqlen_q: Query sequence length for attention computation.
        :type seqlen_q: Int32
        :param seqlen_k: Key sequence length for attention computation.
        :type seqlen_k: Int32
        :param window_size_left: Left-side sliding window size for attention masking.
        :type window_size_left: Optional[Int32]
        :param window_size_right: Right-side sliding window size for attention masking.
        :type window_size_right: Optional[Int32]
        :param rem_count: Remaining count from previous calculations.
        :type rem_count: Int32

        :return: Number of masked trips.
        :rtype: Int32
        """
        result = 0
        if cutlass.const_expr(
            mask_type is not MaskEnum.RESIDUAL_MASK
            and mask_type is not MaskEnum.RESIDUAL_MASK_BWD
        ):
            if cutlass.const_expr(
                window_size_left is not None or window_size_right is not None
            ):
                trailing_mask_begin, trailing_mask_end = FusedMask.get_trailing_mask_id(
                    mask_type,
                    blk_coord,
                    tile_shape,
                    seqlen_q,
                    seqlen_k,
                    window_size_left,
                    window_size_right,
                )
                leading_mask_begin, leading_mask_end = FusedMask.get_leading_mask_id(
                    mask_type,
                    blk_coord,
                    tile_shape,
                    seqlen_q,
                    seqlen_k,
                    window_size_left,
                    window_size_right,
                )
                if cutlass.const_expr(
                    trailing_mask_begin is not None and trailing_mask_end is not None
                ):
                    # Do not count tiles already covered by the leading mask.
                    if trailing_mask_begin <= leading_mask_end:
                        result = max(trailing_mask_end - leading_mask_end, 0)
                    else:
                        result = max(trailing_mask_end - trailing_mask_begin + 1, 0)
        else:
            if seqlen_k % tile_shape[1] != 0:
                result = 1
            else:
                result = 0

        return result + rem_count

    @cute.jit
    def get_unmasked_trip_count(
        mask_type: MaskEnum,
        blk_coord: cute.Coord,
        tile_shape: cute.Shape,
        seqlen_q: Int32,
        seqlen_k: Int32,
        window_size_left: Optional[Int32] = None,
        window_size_right: Optional[Int32] = None,
    ) -> Int32:
        """
        Calculate the number of unmasked trips for the current block.

        This represents the number of trips that don't require special
        masking treatment: the total trip count minus the masked leading and
        masked trailing counts.

        :param mask_type: Type of mask to use
        :type mask_type: utils.MaskEnum
        :param blk_coord: Block coordinates.
        :type blk_coord: cute.Coord
        :param tile_shape: Shape of the tile.
        :type tile_shape: cute.Shape
        :param seqlen_q: Query sequence length for attention computation.
        :type seqlen_q: Int32
        :param seqlen_k: Key sequence length for attention computation.
        :type seqlen_k: Int32
        :param window_size_left: Left-side sliding window size for attention masking.
        :type window_size_left: Optional[Int32]
        :param window_size_right: Right-side sliding window size for attention masking.
        :type window_size_right: Optional[Int32]

        :return: Number of unmasked trips.
        :rtype: Int32
        """
        result = (
            FusedMask.get_trip_count(
                mask_type,
                blk_coord,
                tile_shape,
                seqlen_q,
                seqlen_k,
                window_size_left,
                window_size_right,
            )
            - FusedMask.get_masked_leading_count(
                mask_type,
                blk_coord,
                tile_shape,
                seqlen_q,
                seqlen_k,
                window_size_left,
                window_size_right,
            )
            - FusedMask.get_masked_trailing_count(
                mask_type,
                blk_coord,
                tile_shape,
                seqlen_q,
                seqlen_k,
                window_size_left,
                window_size_right,
                0,
            )
        )
        return result

    @cute.jit
    def apply_mask(
        mask_type: MaskEnum,
        acc_qk: cute.Tensor,
        index_qk: cute.Tensor,
        seqlen_q: Int32,
        seqlen_k: Int32,
        window_size_left: Optional[int] = None,
        window_size_right: Optional[int] = None,
        index_transform: cutlass.Constexpr = lambda index_q, index_k: (
            index_q,
            index_k,
        ),
    ):
        """
        Apply the appropriate mask to the attention scores.

        This method modifies the attention scores (acc_qk) in place based on
        the mask type and the positions in the index tensor. Masked elements
        are set to ``-Float32.inf``.

        :param mask_type: Type of mask to use
        :type mask_type: utils.MaskEnum
        :param acc_qk: Accumulated QK attention scores tensor.
        :type acc_qk: cute.Tensor
        :param index_qk: Index tensor containing position information.
        :type index_qk: cute.Tensor
        :param seqlen_q: Query sequence length for attention computation.
        :type seqlen_q: Int32
        :param seqlen_k: Key sequence length for attention computation.
        :type seqlen_k: Int32
        :param window_size_left: Left-side sliding window size for attention masking.
        :type window_size_left: Optional[int]
        :param window_size_right: Right-side sliding window size for attention masking.
        :type window_size_right: Optional[int]
        :param index_transform: Callable applied to the unpacked entry
            ``index_qk[i]``; must return ``(index_q, index_k)``. Defaults to
            the identity mapping.
        :type index_transform: cutlass.Constexpr
        """
        offset = (
            seqlen_k - seqlen_q
            if cutlass.const_expr(
                mask_type is MaskEnum.WINDOW_MASK_INFERENCE
                or mask_type is MaskEnum.WINDOW_MASK_BWD_INFERENCE
            )
            else 0
        )
        for i in cutlass.range_constexpr(cute.size(acc_qk)):
            index_q, index_k = index_transform(*index_qk[i])

            # Window part: only emitted when at least one window size is set.
            if cutlass.const_expr(
                window_size_left is not None or window_size_right is not None
            ):
                if cutlass.const_expr(window_size_left is None):
                    if index_q + offset + window_size_right < index_k:
                        acc_qk[i] = -Float32.inf
                elif cutlass.const_expr(window_size_right is None):
                    if index_q + offset - window_size_left > index_k:
                        acc_qk[i] = -Float32.inf
                else:
                    max_K_index = min(index_q + offset + window_size_right, seqlen_k)
                    min_K_index = max(0, index_q + offset - window_size_left)
                    if index_k > max_K_index or index_k < min_K_index:
                        acc_qk[i] = -Float32.inf

            # Residual part: out-of-range positions are masked whenever a
            # window is active or the mask type is one of the residual masks.
            if cutlass.const_expr(
                window_size_left is not None
                or window_size_right is not None
                or mask_type is MaskEnum.RESIDUAL_MASK
                or mask_type is MaskEnum.RESIDUAL_MASK_BWD
            ):
                if index_k >= seqlen_k or index_q >= seqlen_q:
                    acc_qk[i] = -Float32.inf