Files
cutlass/examples/python/CuTeDSL/blackwell/mla/mla_helpers.py
2026-02-13 23:27:58 -05:00

305 lines
12 KiB
Python

# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import cutlass
import cutlass.cute as cute
class MLAStaticTileSchedulerParams:
def __init__(
self,
is_persistent: bool,
problem_shape_b: cute.Int32,
problem_shape_s: cute.Int32,
cluster_shape_mnk: cute.Shape,
split_kv: cutlass.Int32,
*,
problem_shape_b_fdd: cute.FastDivmodDivisor = None,
problem_shape_s_fdd: cute.FastDivmodDivisor = None,
split_kv_fdd: cute.FastDivmodDivisor = None,
loc=None,
ip=None,
):
"""The static tile scheduler parameters prepared for MLA static tile scheduler.
:param is_persistent: Whether to use persistent kernel mode
:type is_persistent: bool
:param problem_shape_b: The shape of the problem
:type problem_shape_b: cute.Int32
:param problem_shape_s: The shape of the problem in sequence length Q dimension
:type problem_shape_s: cute.Int32
:param cluster_shape_mnk: The shape of the cluster
:type cluster_shape_mnk: cute.Shape
:param split_kv: The scalar factor for split KV
"""
self.is_persistent = is_persistent
self.problem_shape_b = problem_shape_b
self.problem_shape_s = problem_shape_s
self.problem_shape_b_fdd = problem_shape_b_fdd
self.problem_shape_s_fdd = problem_shape_s_fdd
self.cluster_shape_mnk = cluster_shape_mnk
self.split_kv = split_kv
self.split_kv_fdd = split_kv_fdd
if cutlass.const_expr(problem_shape_b_fdd is None):
self.problem_shape_b_fdd = cute.fast_divmod_create_divisor(
problem_shape_b, loc=loc, ip=ip
)
if cutlass.const_expr(problem_shape_s_fdd is None):
self.problem_shape_s_fdd = cute.fast_divmod_create_divisor(
problem_shape_s, loc=loc, ip=ip
)
if cutlass.const_expr(split_kv_fdd is None):
self.split_kv_fdd = cute.fast_divmod_create_divisor(
split_kv, loc=loc, ip=ip
)
self.loc = loc
self.ip = ip
def __extract_mlir_values__(self):
values = cutlass.extract_mlir_values(self.problem_shape_b)
values += cutlass.extract_mlir_values(self.problem_shape_s)
values += cutlass.extract_mlir_values(self.split_kv)
values += cutlass.extract_mlir_values(self.problem_shape_b_fdd)
values += cutlass.extract_mlir_values(self.problem_shape_s_fdd)
values += cutlass.extract_mlir_values(self.split_kv_fdd)
return values
def __new_from_mlir_values__(self, values):
problem_shape_b = cutlass.new_from_mlir_values(
self.problem_shape_b, (values[0],)
)
problem_shape_s = cutlass.new_from_mlir_values(
self.problem_shape_s, (values[1],)
)
split_kv = cutlass.new_from_mlir_values(self.split_kv, (values[2],))
problem_shape_b_fdd = cutlass.new_from_mlir_values(
self.problem_shape_b_fdd, (values[3],)
)
problem_shape_s_fdd = cutlass.new_from_mlir_values(
self.problem_shape_s_fdd, (values[4],)
)
split_kv_fdd = cutlass.new_from_mlir_values(self.split_kv_fdd, (values[5],))
return MLAStaticTileSchedulerParams(
self.is_persistent,
problem_shape_b,
problem_shape_s,
self.cluster_shape_mnk,
split_kv,
problem_shape_b_fdd=problem_shape_b_fdd,
problem_shape_s_fdd=problem_shape_s_fdd,
split_kv_fdd=split_kv_fdd,
loc=self.loc,
)
def create_mla_static_tile_scheduler_params(
is_persistent: bool,
problem_shape_b: cute.Int32,
problem_shape_s: cute.Int32,
cluster_shape_mnk: cute.Shape,
split_kv: cutlass.Int32,
) -> MLAStaticTileSchedulerParams:
return MLAStaticTileSchedulerParams(
is_persistent, problem_shape_b, problem_shape_s, cluster_shape_mnk, split_kv
)
class WorkTileInfo:
def __init__(self, blk_coord: cute.Coord, is_valid: bool):
self.blk_coord = blk_coord
self.is_valid = cutlass.Boolean(is_valid)
def __extract_mlir_values__(self):
values = cutlass.extract_mlir_values(self.blk_coord)
values += cutlass.extract_mlir_values(self.is_valid)
return values
def __new_from_mlir_values__(self, values):
new_tile_idx = cutlass.new_from_mlir_values(self.blk_coord, values[:-1])
new_is_valid_tile = cutlass.new_from_mlir_values(self.is_valid, [values[-1]])
return WorkTileInfo(new_tile_idx, new_is_valid_tile)
@property
def is_valid_tile(self) -> cutlass.Boolean:
return self.is_valid
@property
def tile_idx(self) -> cute.Coord:
return self.blk_coord
class MLAStaticTileScheduler:
def __init__(
self,
params: MLAStaticTileSchedulerParams,
current_work_linear_idx: cutlass.Int32,
blk_coord: cute.Coord,
grid_shape: cute.Shape,
*,
is_valid: bool = True,
loc=None,
ip=None,
):
"""The static tile scheduler for MLA split kv kernel.
Based on `is_persistent`, it provides 2 modes for use:
- Persistent mode: Launch fixed blocks and reschedule the data blocks.
- Non-persistent mode: Launch dynamic blocks and exit when the current work is done.
:param params: The static tile scheduler parameters
:type params: MLAStaticTileSchedulerParams
:param current_work_linear_idx: The linear index of the current work
:type current_work_linear_idx: cutlass.Int32
:param blk_coord: The coordinate of the current work
:type blk_coord: cute.Coord
:param grid_shape: The shape of the grid
:type grid_shape: cute.Shape
:param is_valid: Whether the current work is valid
:type is_valid: bool
"""
self.params = params
self.blk_coord = blk_coord
self.grid_shape = grid_shape
self.current_work_linear_idx = current_work_linear_idx
if params.is_persistent:
self.persistent_blk_layout = cute.make_layout(
(
params.cluster_shape_mnk[0],
params.problem_shape_s,
params.problem_shape_b,
params.split_kv,
),
loc=loc,
ip=ip,
)
self.num_blocks = cute.size(self.persistent_blk_layout, loc=loc, ip=ip)
# Used for persistent scheduling
self.num_persistent_sm = cute.size(grid_shape, loc=loc, ip=ip)
else:
self.is_valid = is_valid
self.loc = loc
self.ip = ip
@staticmethod
def get_grid_shape(
params: MLAStaticTileSchedulerParams,
max_active_clusters: int,
*,
loc=None,
ip=None,
) -> cute.Shape:
# called by host
grid_shape = (
params.cluster_shape_mnk[0],
params.problem_shape_b * params.problem_shape_s,
params.split_kv,
)
if params.is_persistent:
return (
cutlass.min(
max_active_clusters * cute.size(params.cluster_shape_mnk),
cute.size(grid_shape, loc=loc, ip=ip),
),
1,
1,
)
else:
return grid_shape
def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
is_valid = (
self.current_work_linear_idx < self.num_blocks
if self.params.is_persistent
else self.is_valid
)
if self.params.is_persistent:
current_work_cluster_batch, cluster_idx = (
self.current_work_linear_idx // self.params.cluster_shape_mnk[0],
self.current_work_linear_idx % self.params.cluster_shape_mnk[0],
)
current_work_s_batch, s_idx = divmod(
current_work_cluster_batch, self.params.problem_shape_s_fdd
)
current_work_b_batch, b_idx = divmod(
current_work_s_batch, self.params.problem_shape_b_fdd
)
_, split_kv_idx = divmod(current_work_b_batch, self.params.split_kv_fdd)
blk_coord = (cluster_idx, s_idx, b_idx, split_kv_idx)
else:
s_idx, b_idx = divmod(self.blk_coord[1], self.params.problem_shape_b_fdd)
blk_coord = (self.blk_coord[0], s_idx, b_idx, self.blk_coord[2])
return WorkTileInfo(blk_coord, is_valid)
def initial_work_tile_info(self, *, loc=None, ip=None):
return self.get_current_work(loc=loc, ip=ip)
def advance_to_next_work(self, *, advance_count=1, loc=None, ip=None):
if self.params.is_persistent:
self.current_work_linear_idx += advance_count * self.num_persistent_sm
else:
self.is_valid = False
def __extract_mlir_values__(self):
values = cutlass.extract_mlir_values(self.params)
values.extend(cutlass.extract_mlir_values(self.current_work_linear_idx))
values.extend(cutlass.extract_mlir_values(self.blk_coord))
values.extend(cutlass.extract_mlir_values(self.grid_shape))
return values
def __new_from_mlir_values__(self, values):
assert len(values) == 13
new_params = cutlass.new_from_mlir_values(self.params, values[0:6])
new_current_work_linear_idx = cutlass.new_from_mlir_values(
self.current_work_linear_idx, [values[6]]
)
new_blk_coord = cutlass.new_from_mlir_values(self.blk_coord, values[7:10])
new_grid_shape = cutlass.new_from_mlir_values(self.grid_shape, values[10:])
return MLAStaticTileScheduler(
new_params, new_current_work_linear_idx, new_blk_coord, new_grid_shape
)
def create_mla_static_tile_scheduler(
params: MLAStaticTileSchedulerParams,
blk_coord: cute.Coord,
grid_shape: cute.Shape,
) -> MLAStaticTileScheduler:
return MLAStaticTileScheduler(params, blk_coord[0], blk_coord, grid_shape)
LOG2_E = 1.4426950408889634074
# avoid register indexing on array.
MAX_SPLITS = 256
def ceil_div(a: int, b: int) -> int:
return (a + b - 1) // b