mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-13 09:45:45 +00:00
305 lines
12 KiB
Python
305 lines
12 KiB
Python
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: BSD-3-Clause
|
|
|
|
# Redistribution and use in source and binary forms, with or without
|
|
# modification, are permitted provided that the following conditions are met:
|
|
|
|
# 1. Redistributions of source code must retain the above copyright notice, this
|
|
# list of conditions and the following disclaimer.
|
|
|
|
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
# this list of conditions and the following disclaimer in the documentation
|
|
# and/or other materials provided with the distribution.
|
|
|
|
# 3. Neither the name of the copyright holder nor the names of its
|
|
# contributors may be used to endorse or promote products derived from
|
|
# this software without specific prior written permission.
|
|
|
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
|
|
|
|
import cutlass
|
|
import cutlass.cute as cute
|
|
|
|
|
|
class MLAStaticTileSchedulerParams:
    """Parameter bundle consumed by ``MLAStaticTileScheduler``.

    Holds the problem extents, the cluster shape and the split-KV factor,
    together with one fast-divmod divisor per dynamic extent. The
    ``__extract_mlir_values__`` / ``__new_from_mlir_values__`` pair lets the
    object be flattened to MLIR values and rebuilt from them.
    """

    def __init__(
        self,
        is_persistent: bool,
        problem_shape_b: cute.Int32,
        problem_shape_s: cute.Int32,
        cluster_shape_mnk: cute.Shape,
        split_kv: cutlass.Int32,
        *,
        problem_shape_b_fdd: cute.FastDivmodDivisor = None,
        problem_shape_s_fdd: cute.FastDivmodDivisor = None,
        split_kv_fdd: cute.FastDivmodDivisor = None,
        loc=None,
        ip=None,
    ):
        """The static tile scheduler parameters prepared for MLA static tile scheduler.

        :param is_persistent: Whether to use persistent kernel mode
        :type is_persistent: bool
        :param problem_shape_b: The extent of the problem along the ``b``
            dimension (presumably batch -- TODO confirm against the caller)
        :type problem_shape_b: cute.Int32
        :param problem_shape_s: The shape of the problem in sequence length Q dimension
        :type problem_shape_s: cute.Int32
        :param cluster_shape_mnk: The shape of the cluster
        :type cluster_shape_mnk: cute.Shape
        :param split_kv: The scalar factor for split KV
        :type split_kv: cutlass.Int32
        :param problem_shape_b_fdd: Fast-divmod divisor for ``problem_shape_b``;
            created from ``problem_shape_b`` when omitted
        :type problem_shape_b_fdd: cute.FastDivmodDivisor
        :param problem_shape_s_fdd: Fast-divmod divisor for ``problem_shape_s``;
            created from ``problem_shape_s`` when omitted
        :type problem_shape_s_fdd: cute.FastDivmodDivisor
        :param split_kv_fdd: Fast-divmod divisor for ``split_kv``; created from
            ``split_kv`` when omitted
        :type split_kv_fdd: cute.FastDivmodDivisor
        :param loc: Forwarded to ``cute.fast_divmod_create_divisor`` and stored
        :param ip: Forwarded to ``cute.fast_divmod_create_divisor`` and stored
        """
        self.is_persistent = is_persistent
        self.problem_shape_b = problem_shape_b
        self.problem_shape_s = problem_shape_s
        self.problem_shape_b_fdd = problem_shape_b_fdd
        self.problem_shape_s_fdd = problem_shape_s_fdd
        self.cluster_shape_mnk = cluster_shape_mnk
        self.split_kv = split_kv
        self.split_kv_fdd = split_kv_fdd
        # Any divisor the caller did not supply is derived from its extent.
        if cutlass.const_expr(problem_shape_b_fdd is None):
            self.problem_shape_b_fdd = cute.fast_divmod_create_divisor(
                problem_shape_b, loc=loc, ip=ip
            )
        if cutlass.const_expr(problem_shape_s_fdd is None):
            self.problem_shape_s_fdd = cute.fast_divmod_create_divisor(
                problem_shape_s, loc=loc, ip=ip
            )
        if cutlass.const_expr(split_kv_fdd is None):
            self.split_kv_fdd = cute.fast_divmod_create_divisor(
                split_kv, loc=loc, ip=ip
            )
        self.loc = loc
        self.ip = ip

    def __extract_mlir_values__(self):
        """Flatten the dynamic fields into a list of MLIR values.

        The order is: ``problem_shape_b``, ``problem_shape_s``, ``split_kv``,
        then the three divisors in the same order. ``__new_from_mlir_values__``
        relies on exactly this order.
        """
        values = cutlass.extract_mlir_values(self.problem_shape_b)
        values += cutlass.extract_mlir_values(self.problem_shape_s)
        values += cutlass.extract_mlir_values(self.split_kv)
        values += cutlass.extract_mlir_values(self.problem_shape_b_fdd)
        values += cutlass.extract_mlir_values(self.problem_shape_s_fdd)
        values += cutlass.extract_mlir_values(self.split_kv_fdd)
        return values

    def __new_from_mlir_values__(self, values):
        """Rebuild a params object from ``values``.

        Reads ``values[0]`` .. ``values[5]``, one value per dynamic field, in
        the order produced by ``__extract_mlir_values__``. ``is_persistent``
        and ``cluster_shape_mnk`` are copied from ``self`` unchanged.
        """
        problem_shape_b = cutlass.new_from_mlir_values(
            self.problem_shape_b, (values[0],)
        )
        problem_shape_s = cutlass.new_from_mlir_values(
            self.problem_shape_s, (values[1],)
        )
        split_kv = cutlass.new_from_mlir_values(self.split_kv, (values[2],))
        problem_shape_b_fdd = cutlass.new_from_mlir_values(
            self.problem_shape_b_fdd, (values[3],)
        )
        problem_shape_s_fdd = cutlass.new_from_mlir_values(
            self.problem_shape_s_fdd, (values[4],)
        )
        split_kv_fdd = cutlass.new_from_mlir_values(self.split_kv_fdd, (values[5],))
        # All three divisors are passed explicitly, so the constructor does not
        # create new ones. Only ``loc`` is forwarded; ``ip`` falls back to its
        # default of None on the rebuilt object.
        return MLAStaticTileSchedulerParams(
            self.is_persistent,
            problem_shape_b,
            problem_shape_s,
            self.cluster_shape_mnk,
            split_kv,
            problem_shape_b_fdd=problem_shape_b_fdd,
            problem_shape_s_fdd=problem_shape_s_fdd,
            split_kv_fdd=split_kv_fdd,
            loc=self.loc,
        )
|
|
|
|
|
|
def create_mla_static_tile_scheduler_params(
    is_persistent: bool,
    problem_shape_b: cute.Int32,
    problem_shape_s: cute.Int32,
    cluster_shape_mnk: cute.Shape,
    split_kv: cutlass.Int32,
) -> MLAStaticTileSchedulerParams:
    """Build ``MLAStaticTileSchedulerParams`` from the raw problem description.

    No divisors are passed, so the constructor creates the fast-divmod
    divisors for ``problem_shape_b``, ``problem_shape_s`` and ``split_kv``.

    :param is_persistent: Whether to use persistent kernel mode
    :param problem_shape_b: The extent of the problem along the ``b`` dimension
    :param problem_shape_s: The shape of the problem in sequence length Q dimension
    :param cluster_shape_mnk: The shape of the cluster
    :param split_kv: The scalar factor for split KV
    :return: The populated parameter object
    :rtype: MLAStaticTileSchedulerParams
    """
    return MLAStaticTileSchedulerParams(
        is_persistent, problem_shape_b, problem_shape_s, cluster_shape_mnk, split_kv
    )
|
|
|
|
|
|
class WorkTileInfo:
    """A block coordinate paired with a flag saying whether it is valid work."""

    def __init__(self, blk_coord: cute.Coord, is_valid: bool):
        """Store the coordinate and wrap the flag in ``cutlass.Boolean``.

        :param blk_coord: The coordinate of the work tile
        :type blk_coord: cute.Coord
        :param is_valid: Whether the tile is valid
        :type is_valid: bool
        """
        self.blk_coord = blk_coord
        self.is_valid = cutlass.Boolean(is_valid)

    def __extract_mlir_values__(self):
        """Flatten to MLIR values: the coordinate values first, then the flag."""
        values = cutlass.extract_mlir_values(self.blk_coord)
        values += cutlass.extract_mlir_values(self.is_valid)
        return values

    def __new_from_mlir_values__(self, values):
        """Rebuild from ``values``: the last entry is the flag, the rest the coordinate."""
        new_tile_idx = cutlass.new_from_mlir_values(self.blk_coord, values[:-1])
        new_is_valid_tile = cutlass.new_from_mlir_values(self.is_valid, [values[-1]])
        return WorkTileInfo(new_tile_idx, new_is_valid_tile)

    @property
    def is_valid_tile(self) -> cutlass.Boolean:
        """The validity flag (alias of ``is_valid``)."""
        return self.is_valid

    @property
    def tile_idx(self) -> cute.Coord:
        """The block coordinate (alias of ``blk_coord``)."""
        return self.blk_coord
|
|
|
|
|
|
class MLAStaticTileScheduler:
    """Static tile scheduler for the MLA split kv kernel.

    See ``__init__`` for the two supported modes.
    """

    def __init__(
        self,
        params: MLAStaticTileSchedulerParams,
        current_work_linear_idx: cutlass.Int32,
        blk_coord: cute.Coord,
        grid_shape: cute.Shape,
        *,
        is_valid: bool = True,
        loc=None,
        ip=None,
    ):
        """The static tile scheduler for MLA split kv kernel.
        Based on `is_persistent`, it provides 2 modes for use:
        - Persistent mode: Launch fixed blocks and reschedule the data blocks.
        - Non-persistent mode: Launch dynamic blocks and exit when the current work is done.

        :param params: The static tile scheduler parameters
        :type params: MLAStaticTileSchedulerParams
        :param current_work_linear_idx: The linear index of the current work
        :type current_work_linear_idx: cutlass.Int32
        :param blk_coord: The coordinate of the current work
        :type blk_coord: cute.Coord
        :param grid_shape: The shape of the grid
        :type grid_shape: cute.Shape
        :param is_valid: Whether the current work is valid (only stored in
            non-persistent mode)
        :type is_valid: bool
        :param loc: Forwarded to the ``cute`` calls made here and stored
        :param ip: Forwarded to the ``cute`` calls made here and stored
        """
        self.params = params
        self.blk_coord = blk_coord
        self.grid_shape = grid_shape
        self.current_work_linear_idx = current_work_linear_idx
        if params.is_persistent:
            # Layout over (cluster, s, b, split_kv); its size is the total
            # number of work blocks a linear index can address.
            self.persistent_blk_layout = cute.make_layout(
                (
                    params.cluster_shape_mnk[0],
                    params.problem_shape_s,
                    params.problem_shape_b,
                    params.split_kv,
                ),
                loc=loc,
                ip=ip,
            )
            self.num_blocks = cute.size(self.persistent_blk_layout, loc=loc, ip=ip)
            # Used for persistent scheduling
            self.num_persistent_sm = cute.size(grid_shape, loc=loc, ip=ip)
        else:
            self.is_valid = is_valid
        self.loc = loc
        self.ip = ip

    @staticmethod
    def get_grid_shape(
        params: MLAStaticTileSchedulerParams,
        max_active_clusters: int,
        *,
        loc=None,
        ip=None,
    ) -> cute.Shape:
        """Return the launch grid shape for ``params``.

        Non-persistent mode returns
        ``(cluster_shape_mnk[0], problem_shape_b * problem_shape_s, split_kv)``.
        Persistent mode returns a grid ``(n, 1, 1)`` where ``n`` is the smaller
        of ``max_active_clusters * size(cluster_shape_mnk)`` and the size of
        the non-persistent grid.

        :param params: The static tile scheduler parameters
        :type params: MLAStaticTileSchedulerParams
        :param max_active_clusters: Upper bound on concurrently active clusters
        :type max_active_clusters: int
        :return: The grid shape
        :rtype: cute.Shape
        """
        # called by host
        grid_shape = (
            params.cluster_shape_mnk[0],
            params.problem_shape_b * params.problem_shape_s,
            params.split_kv,
        )
        if params.is_persistent:
            return (
                cutlass.min(
                    max_active_clusters
                    * cute.size(params.cluster_shape_mnk, loc=loc, ip=ip),
                    cute.size(grid_shape, loc=loc, ip=ip),
                ),
                1,
                1,
            )
        else:
            return grid_shape

    def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
        """Return the work tile for the scheduler's current state.

        Persistent mode: the tile is valid while ``current_work_linear_idx`` is
        below ``num_blocks``; the linear index is decomposed into
        ``(cluster_idx, s_idx, b_idx, split_kv_idx)``.

        Non-persistent mode: validity is the stored ``is_valid`` flag;
        ``blk_coord[1]`` is split by ``problem_shape_b_fdd`` into
        ``(s_idx, b_idx)`` while ``blk_coord[0]`` and ``blk_coord[2]`` are
        passed through as the first and last coordinate entries.

        :return: The current work tile
        :rtype: WorkTileInfo
        """
        is_valid = (
            self.current_work_linear_idx < self.num_blocks
            if self.params.is_persistent
            else self.is_valid
        )

        if self.params.is_persistent:
            # Peel off one mode at a time; each remainder is that mode's index
            # and the quotient carries over to the next mode.
            current_work_cluster_batch, cluster_idx = (
                self.current_work_linear_idx // self.params.cluster_shape_mnk[0],
                self.current_work_linear_idx % self.params.cluster_shape_mnk[0],
            )
            current_work_s_batch, s_idx = divmod(
                current_work_cluster_batch, self.params.problem_shape_s_fdd
            )
            current_work_b_batch, b_idx = divmod(
                current_work_s_batch, self.params.problem_shape_b_fdd
            )
            _, split_kv_idx = divmod(current_work_b_batch, self.params.split_kv_fdd)

            blk_coord = (cluster_idx, s_idx, b_idx, split_kv_idx)
        else:
            s_idx, b_idx = divmod(self.blk_coord[1], self.params.problem_shape_b_fdd)
            blk_coord = (self.blk_coord[0], s_idx, b_idx, self.blk_coord[2])

        return WorkTileInfo(blk_coord, is_valid)

    def initial_work_tile_info(self, *, loc=None, ip=None):
        """Return the first work tile; identical to ``get_current_work``."""
        return self.get_current_work(loc=loc, ip=ip)

    def advance_to_next_work(self, *, advance_count=1, loc=None, ip=None):
        """Move the scheduler to its next unit of work.

        Persistent mode advances the linear index by
        ``advance_count * num_persistent_sm``. Non-persistent mode marks the
        scheduler as having no further valid work.
        """
        if self.params.is_persistent:
            self.current_work_linear_idx += advance_count * self.num_persistent_sm
        else:
            self.is_valid = False

    def __extract_mlir_values__(self):
        """Flatten to MLIR values: params, linear index, block coord, grid shape."""
        values = cutlass.extract_mlir_values(self.params)
        values.extend(cutlass.extract_mlir_values(self.current_work_linear_idx))
        values.extend(cutlass.extract_mlir_values(self.blk_coord))
        values.extend(cutlass.extract_mlir_values(self.grid_shape))
        return values

    def __new_from_mlir_values__(self, values):
        """Rebuild a scheduler from the 13 values of ``__extract_mlir_values__``.

        Slices used: ``[0:6]`` params, ``[6]`` linear index, ``[7:10]`` block
        coordinate, ``[10:]`` grid shape. ``is_valid``, ``loc`` and ``ip`` are
        not forwarded, so the rebuilt scheduler gets the constructor defaults.
        """
        assert len(values) == 13
        new_params = cutlass.new_from_mlir_values(self.params, values[0:6])
        new_current_work_linear_idx = cutlass.new_from_mlir_values(
            self.current_work_linear_idx, [values[6]]
        )
        new_blk_coord = cutlass.new_from_mlir_values(self.blk_coord, values[7:10])
        new_grid_shape = cutlass.new_from_mlir_values(self.grid_shape, values[10:])
        return MLAStaticTileScheduler(
            new_params, new_current_work_linear_idx, new_blk_coord, new_grid_shape
        )
|
|
|
|
|
|
def create_mla_static_tile_scheduler(
    params: MLAStaticTileSchedulerParams,
    blk_coord: cute.Coord,
    grid_shape: cute.Shape,
) -> MLAStaticTileScheduler:
    """Build an ``MLAStaticTileScheduler`` for the block at ``blk_coord``.

    The scheduler's initial linear work index is ``blk_coord[0]``.

    :param params: The static tile scheduler parameters
    :type params: MLAStaticTileSchedulerParams
    :param blk_coord: The coordinate of the current block
    :type blk_coord: cute.Coord
    :param grid_shape: The shape of the grid
    :type grid_shape: cute.Shape
    :return: The scheduler
    :rtype: MLAStaticTileScheduler
    """
    return MLAStaticTileScheduler(params, blk_coord[0], blk_coord, grid_shape)
|
|
|
|
|
|
# Numeric value of log2(e).
LOG2_E = 1.4426950408889634074
# avoid register indexing on array.
# NOTE(review): presumably a fixed upper bound on the number of KV splits so
# that per-split storage can have a static size -- confirm against the users
# of this constant.
MAX_SPLITS = 256
|
|
|
|
|
|
def ceil_div(a: int, b: int) -> int:
    """Return ``a`` divided by ``b``, rounded up to the next integer.

    Computed as ``(a + b - 1) // b``, which equals the ceiling of ``a / b``
    when ``b`` is a positive integer. ``b == 0`` raises ``ZeroDivisionError``.

    :param a: The dividend
    :type a: int
    :param b: The divisor; expected to be positive
    :type b: int
    :return: The rounded-up quotient
    :rtype: int
    """
    return (a + b - 1) // b
|