cutlass/examples/python/CuTeDSL/blackwell/grouped_mixed_input_gemm.py
2025-11-20 20:49:44 -05:00


# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import argparse
from math import log2, ceil
from typing import Optional, Union
import torch
import cuda.bindings.driver as cuda
import cutlass
import cutlass.cute as cute
from cutlass.cutlass_dsl import (
    extract_mlir_values,
    new_from_mlir_values,
)
import cutlass.pipeline as pipeline
from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait
import cutlass.torch as cutlass_torch
import cutlass.utils as utils
import cutlass.utils.blackwell_helpers as sm100_utils
import cutlass.utils.mixed_input_helpers as mixed_input_utils
from cutlass.utils.mixed_input_helpers import TransformMode
import cutlass.cute.testing as testing
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass.cute.runtime import from_dlpack
"""
A mixed-input grouped GEMM example for the NVIDIA Blackwell SM100 architecture using CuTe DSL.

This example demonstrates an implementation of mixed-input grouped GEMM using a TMA plus Blackwell
SM100 TensorCore warp-specialized persistent kernel. It can be viewed as an extension of the batched
mixed-input GEMM example that supports a specific grouped GEMM pattern: grouped GEMM with contiguous offsets.

Specifically, the input A tensor is still in the shape of (M, K, L), where L is the number of groups. The
input B tensor is in the shape of (N, K) and the result C tensor is in the shape of (M, N). Tensor B
and tensor C are not divided into groups explicitly; instead, an extra input tensor, cumsum, defines
the mapping from the N mode to groups. The cumsum tensor has shape (L+1,), and cumsum[i] is the
accumulated size along the N mode of groups 0 through i-1 (group i itself is not included), so
cumsum[0] = 0 and cumsum[L] = N:

```
          Group 0  Group 1  Group 2  .....      Group L-1
        -+--------+--------+--------+.....+--------------------+
         |        |        |        |     |                    |
         |<- N0 ->|<- N1 ->|<- N2 ->|.....|<------ NL-1 ------>|
         |        |        |        |     |                    |
        -+--------+--------+--------+.....+--------------------+
cumsum:  0        N0       N0+N1     ...  sum(N0..NL-2)        sum(N0..NL-1)
```
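
Finding the group that owns a given N index amounts to a search over cumsum. The following is a minimal host-side sketch that uses plain Python lists rather than the kernel's cumsum tensor; make_cumsum and group_of_column are hypothetical helper names used only for illustration.

```python
from bisect import bisect_right
from itertools import accumulate


def make_cumsum(group_sizes):
    # cumsum[i] = total size of groups 0..i-1, so cumsum[0] == 0 and
    # cumsum[L] == N; the result has L + 1 entries.
    return [0] + list(accumulate(group_sizes))


def group_of_column(cumsum, n):
    # Group g owns the columns n with cumsum[g] <= n < cumsum[g + 1].
    if not 0 <= n < cumsum[-1]:
        raise IndexError(f"column {n} is outside [0, {cumsum[-1]})")
    return bisect_right(cumsum, n) - 1


cumsum = make_cumsum([3, 5, 4])  # L = 3 groups, N = 12
print(cumsum)  # [0, 3, 8, 12]
print([group_of_column(cumsum, n) for n in (0, 2, 3, 7, 8, 11)])  # [0, 0, 1, 1, 2, 2]
```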

The computation flow is the same as in the batched mixed-input GEMM example. A is the narrow-precision tensor
and B holds data with a wider precision. The MMA works in the wide precision of tensor B, and tensor A
is transformed to that precision in one of two modes:

1. convert-only mode:
   C = type_convert(A) x B
   In convert-only mode, tensor A is directly converted to the wide precision of tensor B.
2. convert-scale mode:
   C = (type_convert(A) * scale) x B
   In convert-scale mode, tensor A is first converted to the wide precision of tensor B and then scaled by the scale tensor.
   The scale tensor is in the same precision as tensor B.

The mode is determined by tensor A's data type as follows:

- if tensor A is int8 or uint8, convert-only mode is used.
- if tensor A is int4, convert-scale mode is used.

In the kernel, the mode is selected from the scale granularities: setting both scale_granularity_m and
scale_granularity_k to 0 selects convert-only mode, and any other setting selects convert-scale mode.

The output tensor C can have the same precision as tensor B or fp32.
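
To make the two modes concrete, here is a slow, pure-Python reference sketch of the grouped computation. It is for checking results on tiny inputs only and is not how the kernel computes. Python float (float64) stands in for the conversion to bf16, so its results are not bit-exact with the kernel. The name reference_grouped_mixed_input_gemm and the nested-list index conventions a[m][k][l], b[n][k], and scale[m // gran_m][k // gran_k][l] are assumptions chosen to follow the shapes described above.

```python
def reference_grouped_mixed_input_gemm(a, b, cumsum, scale=None, gran_m=0, gran_k=0):
    # a[m][k][l]: narrow integer values of A, shape (M, K, L)
    # b[n][k]: wide-precision values of B, shape (N, K)
    # scale[m // gran_m][k // gran_k][l]: convert-scale mode only; None selects convert-only
    num_m, num_k = len(a), len(a[0])
    c = [[0.0] * cumsum[-1] for _ in range(num_m)]
    for g in range(len(cumsum) - 1):
        for n in range(cumsum[g], cumsum[g + 1]):  # columns owned by group g
            for m in range(num_m):
                acc = 0.0
                for k in range(num_k):
                    a_val = float(a[m][k][g])  # type_convert(A); float stands in for bf16
                    if scale is not None:
                        a_val *= scale[m // gran_m][k // gran_k][g]
                    acc += a_val * b[n][k]
                c[m][n] = acc
    return c


# M = 2, K = 4, L = 2 groups; group 0 owns columns 0-1 and group 1 owns column 2 (N = 3).
a = [
    [[1, 2], [0, -1], [3, 1], [-2, 0]],
    [[2, 0], [1, 1], [0, -3], [1, 2]],
]
b = [[1, 1, 1, 1], [1, 0, 0, 1], [0, 1, 2, 0]]
cumsum = [0, 2, 3]

print(reference_grouped_mixed_input_gemm(a, b, cumsum))
# [[2.0, -1.0, 1.0], [4.0, 3.0, -5.0]]

# Convert-scale with granularity (1, 2): scale shape (2, 2, 2), 0.5 for k < 2 and 2.0 for k >= 2.
scale = [[[0.5, 0.5], [2.0, 2.0]], [[0.5, 0.5], [2.0, 2.0]]]
print(reference_grouped_mixed_input_gemm(a, b, cumsum, scale, gran_m=1, gran_k=2)[0][0])
# 2.5
```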

To run this example:

.. code-block:: bash

    python examples/blackwell/grouped_mixed_input_gemm.py \
      --a_dtype Int8 --b_dtype BFloat16 \
      --scale_granularity_m 0 --scale_granularity_k 0 \
      --c_dtype BFloat16 --acc_dtype Float32 \
      --mma_tiler_mnk 128,128,64 --cluster_shape_mn 1,1 \
      --mnkl 256,512,8192,1

Input A and B have int8 and bf16 data types, respectively. The Blackwell tcgen05 MMA tile shape
is specified as (128,128,64) and the cluster shape is (1,1). The MMA accumulator and output data types
are set to fp32 and bf16, respectively. As tensor A is int8, convert-only mode is used, and
scale_granularity_m and scale_granularity_k are set to 0.

Here is an example of running convert-scale mode:

.. code-block:: bash

    python examples/blackwell/grouped_mixed_input_gemm.py \
      --a_dtype Int4 --b_dtype BFloat16 \
      --scale_granularity_m 1 --scale_granularity_k 256 \
      --c_dtype BFloat16 --acc_dtype Float32 \
      --mma_tiler_mnk 256,128,128 --cluster_shape_mn 2,1 \
      --use_2cta_instrs --mnkl 1024,8192,6144,16

Input A and B have int4 and bf16 data types, respectively. The scale granularity is set to (1, 256),
which means each element along the M mode of tensor A has its own scale element and 256 contiguous elements
along the K mode share the same scale element. There is no scale reuse along the L mode. If the GEMM shape is
(M, N, K, L), then the scale tensor shape is (M // scale_granularity_m, K // scale_granularity_k, L),
which is (1024, 6144 // 256, 16) = (1024, 24, 16) in this example.
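
The scale shape and the scale lookup for a single element of A can be sketched as follows. scale_shape and scale_index are hypothetical helpers. The sketch assumes the granularities divide M and K evenly and applies to convert-scale mode only, since granularity 0 would divide by zero.

```python
def scale_shape(m, k, l, gran_m, gran_k):
    # Shape of the scale tensor in convert-scale mode: (M // gran_m, K // gran_k, L).
    assert m % gran_m == 0 and k % gran_k == 0, "granularity must divide M and K"
    return (m // gran_m, k // gran_k, l)


def scale_index(m_idx, k_idx, l_idx, gran_m, gran_k):
    # A[m_idx, k_idx, l_idx] is multiplied by scale[m_idx // gran_m, k_idx // gran_k, l_idx];
    # the L index passes through unchanged because scales are not reused across L.
    return (m_idx // gran_m, k_idx // gran_k, l_idx)


print(scale_shape(1024, 6144, 16, 1, 256))  # (1024, 24, 16)
print(scale_index(5, 300, 3, 1, 256))  # (5, 1, 3)
```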
The Blackwell tcgen05 MMA tile shape is specified as (256,128,128) and the tcgen05 2-CTA feature is enabled.
The cluster shape is (2,1). The MMA accumulator and output data types are set to fp32 and bf16, respectively.
As tensor A is int4, convert-scale mode is used.
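
The tile sizes in these examples are related by simple arithmetic. The sketch below mirrors how the kernel derives cta_tile_shape_mnk and cluster_tile_shape_mnk from mma_tiler_mnk in _setup_attributes, and how it computes num_k_tiles_per_scale. tile_shapes is a hypothetical helper, not part of the CuTe DSL API.

```python
def tile_shapes(mma_tiler_mnk, cluster_shape_mn, use_2cta_instrs, scale_granularity_k=0):
    # With 2-CTA instructions, one MMA tile is split across a CTA pair along M.
    ctas_per_mma = 2 if use_2cta_instrs else 1
    cta_tile = (mma_tiler_mnk[0] // ctas_per_mma, mma_tiler_mnk[1], mma_tiler_mnk[2])
    cluster_tile = (
        cluster_shape_mn[0] * cta_tile[0],
        cluster_shape_mn[1] * cta_tile[1],
        cta_tile[2],
    )
    # Number of consecutive K tiles that share one scale element (0 in convert-only mode).
    k_tiles_per_scale = scale_granularity_k // cta_tile[2]
    return cta_tile, cluster_tile, k_tiles_per_scale


print(tile_shapes((256, 128, 128), (2, 1), True, 256))
# ((128, 128, 128), (256, 128, 128), 2)
print(tile_shapes((128, 128, 64), (1, 1), False))
# ((128, 128, 64), (128, 128, 64), 0)
```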

To collect performance with the NCU profiler:

.. code-block:: bash

    ncu python examples/blackwell/grouped_mixed_input_gemm.py \
      --a_dtype Int8 --b_dtype BFloat16 \
      --scale_granularity_m 0 --scale_granularity_k 0 \
      --c_dtype BFloat16 --acc_dtype Float32 \
      --mma_tiler_mnk 128,128,64 --cluster_shape_mn 1,1 \
      --mnkl 256,512,8192,1 \
      --warmup_iterations 1 --iterations 10 --skip_ref_check

Besides the requirements from the batched mixed-input GEMM example, this example has the following constraint:

* The --use_tma_store option is removed because no alignment is assumed for the size of each group.
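
Because group sizes are not padded to the tile size, one natural way to cover the N mode is to tile each group separately, so that the last tile of a group can be partial. The sketch below is a host-side illustration under that assumption. It mirrors the bookkeeping suggested by the fields of ContiguousGGSearchState (last_tile_count, cur_offset, cur_boundary, cur_start) and GroupedWorkTileInfo (group_idx, distance_to_boundary), but it is not the kernel's device code: find_tile is hypothetical, and the kernel's ordering of M and N tiles is not modeled.

```python
def find_tile(cumsum, cluster_tile_n, tile_idx):
    # Hypothetical host-side helper: linear scan over the groups.
    # Assumption: each group is tiled separately with ceil(size / cluster_tile_n) tiles.
    group_count = len(cumsum) - 1
    last_tile_count = 0  # cluster tiles in all groups before the current one
    for group_idx in range(group_count):
        cur_offset, cur_boundary = cumsum[group_idx], cumsum[group_idx + 1]
        num_tiles = -(-(cur_boundary - cur_offset) // cluster_tile_n)  # ceil division
        if tile_idx < last_tile_count + num_tiles:
            cur_start = cur_offset + (tile_idx - last_tile_count) * cluster_tile_n
            # Only min(cluster_tile_n, distance_to_boundary) columns of this tile are valid.
            return group_idx, cur_start, cur_boundary - cur_start
        last_tile_count += num_tiles
    # group_idx == group_count marks "no valid tile" (compare GroupedWorkTileInfo.is_valid_tile).
    return group_count, None, None


cumsum = [0, 200, 328, 400]  # group sizes 200, 128, 72
for t in range(5):
    print(t, find_tile(cumsum, 128, t))
# 0 (0, 0, 200)
# 1 (0, 128, 72)
# 2 (1, 200, 128)
# 3 (2, 328, 72)
# 4 (3, None, None)
```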
"""


class ContiguousGGSearchState:
    """
    The state of the group search for grouped GEMM with contiguous offsets.

    The state records the progress of the group search algorithm along one mode. It is
    initialized once and updated in every round of group index search.

    :param last_tile_count: Number of cluster tiles before the current group
    :type last_tile_count: cutlass.Int32
    :param cur_boundary: The boundary of the current group, i.e. the size along the search
        mode before the next group
    :type cur_boundary: cutlass.Int32
    :param cur_tile_count: Number of cluster tiles searched so far
    :type cur_tile_count: cutlass.Int32
    :param cur_group_idx: The index of the current group
    :type cur_group_idx: cutlass.Int32
    :param cur_offset: The starting offset of the current group along the search mode
    :type cur_offset: cutlass.Int32
    :param cur_start: The starting offset of the current cluster tile along the search mode
        once the group search is done
    :type cur_start: cutlass.Int32
    """

    def __init__(
        self,
        last_tile_count: cutlass.Int32,
        cur_boundary: cutlass.Int32,
        cur_tile_count: cutlass.Int32,
        cur_group_idx: cutlass.Int32,
        cur_offset: cutlass.Int32,
        cur_start: cutlass.Int32,
    ):
        self.last_tile_count = last_tile_count
        self.cur_boundary = cur_boundary
        self.cur_tile_count = cur_tile_count
        self.cur_group_idx = cur_group_idx
        self.cur_offset = cur_offset
        self.cur_start = cur_start

    def __extract_mlir_values__(self):
        values = extract_mlir_values(self.last_tile_count)
        values.extend(extract_mlir_values(self.cur_boundary))
        values.extend(extract_mlir_values(self.cur_tile_count))
        values.extend(extract_mlir_values(self.cur_group_idx))
        values.extend(extract_mlir_values(self.cur_offset))
        values.extend(extract_mlir_values(self.cur_start))
        return values

    def __new_from_mlir_values__(self, values) -> "ContiguousGGSearchState":
        last_tile_count = new_from_mlir_values(self.last_tile_count, [values[0]])
        cur_boundary = new_from_mlir_values(self.cur_boundary, [values[1]])
        cur_tile_count = new_from_mlir_values(self.cur_tile_count, [values[2]])
        cur_group_idx = new_from_mlir_values(self.cur_group_idx, [values[3]])
        cur_offset = new_from_mlir_values(self.cur_offset, [values[4]])
        cur_start = new_from_mlir_values(self.cur_start, [values[5]])
        return ContiguousGGSearchState(
            last_tile_count,
            cur_boundary,
            cur_tile_count,
            cur_group_idx,
            cur_offset,
            cur_start,
        )


def create_initial_search_state() -> ContiguousGGSearchState:
    """
    Create an initial search state for grouped GEMM with contiguous offsets.
    """
    return ContiguousGGSearchState(
        last_tile_count=cutlass.Int32(0),
        cur_boundary=cutlass.Int32(0),
        cur_tile_count=cutlass.Int32(0),
        cur_group_idx=cutlass.Int32(0),
        cur_offset=cutlass.Int32(0),
        cur_start=cutlass.Int32(0),
    )


class GroupedWorkTileInfo:
    """
    Tile info for grouped GEMM with contiguous offsets.

    It is constructed from the search state and contains the information needed by the
    different warps.

    :param group_count: The total number of groups
    :type group_count: int
    :param cta_coord_m: The coordinate of the current CTA tile along the M mode
    :type cta_coord_m: cutlass.Int32
    :param coord_n: The starting offset on the N mode for the current CTA tile
    :type coord_n: cutlass.Int32
    :param group_idx: The index of the current group
    :type group_idx: cutlass.Int32
    :param distance_to_boundary: The distance to the boundary of the current group
    :type distance_to_boundary: cutlass.Int32
    """

    def __init__(
        self,
        group_count: int,
        cta_coord_m: cutlass.Int32,
        coord_n: cutlass.Int32,
        group_idx: cutlass.Int32,
        distance_to_boundary: cutlass.Int32,
    ):
        self.cta_coord_m = cta_coord_m
        self.coord_n = coord_n
        self.group_idx = group_idx
        self.distance_to_boundary = distance_to_boundary
        self.group_count = group_count

    def __extract_mlir_values__(self):
        values = extract_mlir_values(self.cta_coord_m)
        values.extend(extract_mlir_values(self.coord_n))
        values.extend(extract_mlir_values(self.group_idx))
        values.extend(extract_mlir_values(self.distance_to_boundary))
        return values

    def __new_from_mlir_values__(self, values):
        assert len(values) == 4
        new_cta_coord_m = new_from_mlir_values(self.cta_coord_m, [values[0]])
        new_coord_n = new_from_mlir_values(self.coord_n, [values[1]])
        new_group_idx = new_from_mlir_values(self.group_idx, [values[2]])
        new_distance_to_boundary = new_from_mlir_values(
            self.distance_to_boundary, [values[3]]
        )
        return GroupedWorkTileInfo(
            self.group_count,
            new_cta_coord_m,
            new_coord_n,
            new_group_idx,
            new_distance_to_boundary,
        )

    @property
    def is_valid_tile(self):
        return self.group_idx < self.group_count


class GroupedMixedInputGemmKernel:
    """
    Mixed-input grouped GEMM kernel for the NVIDIA Blackwell SM100 architecture.

    This kernel supports GEMM operations where input tensors A and B have different
    data types, with tensor A being transformed to the precision of tensor B before
    matrix multiplication.

    Tensor A has shape [M, K, L], with L being the number of groups. Tensor B has shape [N, K],
    and a group search algorithm is applied on the N mode to find the group index for each
    CTA tile. A cumsum input tensor provides the offset of each group along the N mode.

    :param scale_granularity_m: Number of elements sharing the same scale factor along the M mode
    :type scale_granularity_m: int
    :param scale_granularity_k: Number of elements sharing the same scale factor along the K mode
    :type scale_granularity_k: int
    :param acc_dtype: Data type for accumulation during computation
    :type acc_dtype: type[cutlass.Numeric]
    :param use_2cta_instrs: Whether to use 2-CTA tcgen05 MMA instructions (CtaGroup.TWO),
        in which a pair of CTAs cooperates on one MMA tile
    :type use_2cta_instrs: bool
    :param mma_tiler_mnk: Shape of the Matrix Multiply-Accumulate (MMA) tile (M, N, K)
    :type mma_tiler_mnk: tuple[int, int, int]
    :param cluster_shape_mn: Cluster shape (M, N), in CTAs
    :type cluster_shape_mn: tuple[int, int]
    :param group_count: The total number of groups
    :type group_count: int
    """

    def __init__(
        self,
        scale_granularity_m: int,
        scale_granularity_k: int,
        acc_dtype: type[cutlass.Numeric],
        use_2cta_instrs: bool,
        mma_tiler_mnk: tuple[int, int, int],
        cluster_shape_mn: tuple[int, int],
        group_count: int,
    ):
        """
        Initializes the mixed-input GEMM kernel with a specified configuration.
        """
        # Scale granularity defines how many elements share the same scale factor
        # along the M and K modes.
        self.scale_granularity_m = scale_granularity_m
        self.scale_granularity_k = scale_granularity_k
        # Set transform mode
        if cutlass.const_expr(
            self.scale_granularity_m == 0 and self.scale_granularity_k == 0
        ):
            self.scale_mode = TransformMode.ConvertOnly
        else:
            self.scale_mode = TransformMode.ConvertScale
        self.group_count = group_count
        self.acc_dtype = acc_dtype
        self.use_2cta_instrs = use_2cta_instrs
        self.cluster_shape_mn = cluster_shape_mn
        self.mma_tiler = mma_tiler_mnk
        self.cta_group = (
            tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE
        )
        # Set specialized warp ids
        self.epilog_warp_id = (
            0,
            1,
            2,
            3,
        )
        self.mma_warp_id = 4
        self.tma_warp_id = 5
        self.scale_tma_warp_id = 6
        # Scheduler warp that performs the group search
        self.schedule_warp_id = 7
        self.transform_warp_id = (
            8,
            9,
            10,
            11,
        )
        # Define expected register count for different warps
        self.num_regs_epilogue_warps = 192
        self.num_regs_mma_warp = 96
        self.num_regs_tma_warps = 80
        self.num_regs_transform_warps = 208
        self.num_regs_schedule_warp = 64
        self.threads_per_cta = 32 * (
            max(
                (
                    self.mma_warp_id,
                    self.tma_warp_id,
                    self.scale_tma_warp_id,
                    *self.epilog_warp_id,
                    *self.transform_warp_id,
                )
            )
            + 1
        )
        # Set barrier ids for epilogue sync, TMEM pointer sync, transform sync, CTA sync,
        # and scheduler sync
        self.epilog_sync_barrier = pipeline.NamedBarrier(
            1, 32 * len(self.epilog_warp_id)
        )
        self.tmem_ptr_sync_barrier = pipeline.NamedBarrier(2, self.threads_per_cta)
        self.transform_sync_barrier = pipeline.NamedBarrier(
            3, 32 * len(self.transform_warp_id)
        )
        self.cta_sync_barrier = pipeline.NamedBarrier(4, self.threads_per_cta)
        self.sched_sync_barrier = pipeline.NamedBarrier(5, 32)
        self.smem_buffer_align_bytes = 1024

    def _setup_attributes(self):
        """Set up configurations that are dependent on GEMM inputs.

        This method configures various attributes based on the input tensor properties
        (data types, leading dimensions) and kernel settings:

        - Deducing where the transformed A tensor is stored
        - Configuring tiled MMA
        - Computing MMA/cluster/tile shapes
        - Computing cluster layout
        - Computing multicast CTAs for A/B
        - Computing epilogue subtile
        - Setting up A/scale/B/C stage counts in shared memory
        - Setting up transformed A stage count in shared memory or tensor memory
        - Computing A/transformed A/scale/B/C memory layout
        - Computing tensor memory allocation columns
        """
        # Deduce where the transformed A tensor is stored: shared memory (SMEM) or tensor memory (TMEM)
        self.transform_a_source = mixed_input_utils.get_transform_a_source(
            self.a_major_mode
        )
        tiled_mma = sm100_utils.make_trivial_tiled_mma(
            self.mma_dtype,
            self.a_major_mode,
            self.b_major_mode,
            self.acc_dtype,
            self.cta_group,
            self.mma_tiler[:2],
            self.transform_a_source,
        )
        self.cta_tile_shape_mnk = (
            self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
            self.mma_tiler[1],
            self.mma_tiler[2],
        )
        self.cluster_tile_shape_mnk = (
            self.cluster_shape_mn[0] * self.cta_tile_shape_mnk[0],
            self.cluster_shape_mn[1] * self.cta_tile_shape_mnk[1],
            self.cta_tile_shape_mnk[2],
        )
        self.cluster_layout_vmnk = cute.tiled_divide(
            cute.make_layout((*self.cluster_shape_mn, 1)),
            (tiled_mma.thr_id.shape,),
        )
        self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2])
        self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1])
        self.is_a_mcast = self.num_mcast_ctas_a > 1
        self.is_b_mcast = self.num_mcast_ctas_b > 1
        self.epi_tile = sm100_utils.compute_epilogue_tile_shape(
            self.cta_tile_shape_mnk,
            self.use_2cta_instrs,
            self.c_layout,
            self.c_dtype,
        )
        # Compute tensor memory (TMEM) columns and stages for each pipeline
        (
            self.num_load2trans_stage,
            self.num_scale_load2trans_stage,
            self.num_trans2mma_stage,
            self.num_acc_stage,
            self.num_c_stage,
            self.num_tile_info_stage,
            self.num_acc_tmem_cols,
            self.num_a_tmem_cols,
        ) = self._compute_stages_and_tmem_cols(
            tiled_mma,
            self.mma_tiler,
            self.cta_tile_shape_mnk,
            self.epi_tile,
            self.a_dtype,
            self.b_dtype,
            self.c_dtype,
            self.c_layout,
            self.transform_a_source,
            self.scale_granularity_m,
            self.scale_granularity_k,
            self.smem_buffer_align_bytes,
            self.scale_mode,
        )
        # Align TMEM columns for allocation
        # TMEM allocation requires power-of-2 column alignment
        # and must meet minimum allocation requirements
        self.num_tmem_alloc_cols = GroupedMixedInputGemmKernel.align_up(
            self.num_acc_tmem_cols + self.num_a_tmem_cols,
            cute.arch.SM100_TMEM_MIN_ALLOC_COLUMNS,
        )
        self.num_tmem_alloc_cols = 2 ** (ceil(log2(self.num_tmem_alloc_cols)))
        # Get smem layout for C tensor
        self.c_smem_layout_staged = sm100_utils.make_smem_layout_epi(
            self.c_dtype,
            self.c_layout,
            self.epi_tile,
            self.num_c_stage,
        )
        # Get smem layout for A, transformed A, and B
        (
            self.smem_layout_a,
            self.smem_layout_a_transform,
            self.smem_layout_b,
        ) = mixed_input_utils.compute_smem_layout(
            tiled_mma,
            self.mma_tiler,
            self.a_dtype,
            self.b_dtype,
            self.num_load2trans_stage,
            self.num_trans2mma_stage,
        )
        # Get smem layout for scale tensor
        self.smem_layout_scale_per_stage = None
        self.smem_layout_scale = None
        if cutlass.const_expr(self.scale_mode == TransformMode.ConvertScale):
            # Get scale tile shape and smem layout for scale tensor
            (
                self.scale_tile_shape,
                self.smem_layout_scale_per_stage,
                self.smem_layout_scale,
            ) = mixed_input_utils.get_smem_layout_scale(
                self.mma_tiler,
                self.use_2cta_instrs,
                self.scale_granularity_m,
                self.scale_granularity_k,
                self.scale_major_mode,
                self.a_scale_dtype,
                self.num_scale_load2trans_stage,
            )

    def _validate_inputs(
        self,
        a: cute.Tensor,
        a_scale: Optional[cute.Tensor],
        b: cute.Tensor,
        c: cute.Tensor,
    ) -> None:
        """
        Validates input tensors and their properties.

        :param a: Input tensor A.
        :type a: cute.Tensor
        :param a_scale: Scale tensor for tensor A (None for ConvertOnly mode).
        :type a_scale: Optional[cute.Tensor]
        :param b: Input tensor B.
        :type b: cute.Tensor
        :param c: Output tensor C.
        :type c: cute.Tensor
        :raises ValueError: If inputs don't meet kernel requirements.
        """
        # Validate scale tensor major mode
        if cutlass.const_expr(
            self.scale_mode == TransformMode.ConvertScale
            and utils.LayoutEnum.from_tensor(a_scale).mma_major_mode()
            != tcgen05.OperandMajorMode.MN
        ):
            raise ValueError("scale_major_mode should be m-major")

    @cute.jit
    def __call__(
        self,
        a: cute.Tensor,
        a_scale: Optional[cute.Tensor],  # None for ConvertOnly mode
        b: cute.Tensor,
        cumsum: cute.Tensor,
        c: cute.Tensor,
        max_active_clusters: cutlass.Constexpr,
        stream: cuda.CUstream,
    ):
        """
        Executes the mixed-input grouped GEMM operation.

        This method sets up the kernel parameters, computes the grid size,
        defines the shared storage, and launches the kernel.

        The execution steps are as follows:

        - Set up static attributes before smem/grid/TMA computation.
        - Set up TMA load/store atoms and tensors.
        - Compute grid size with regard to hardware constraints.
        - Define shared storage for the kernel.
        - Launch the kernel synchronously.

        :param a: Input tensor A.
        :type a: cute.Tensor
        :param a_scale: Scale tensor for tensor A (None for ConvertOnly mode).
        :type a_scale: Optional[cute.Tensor]
        :param b: Input tensor B.
        :type b: cute.Tensor
        :param cumsum: Tensor of shape (L+1,) holding the cumulative size of the groups along
            the search mode (the N mode in this example).
        :type cumsum: cute.Tensor
        :param c: Output tensor C.
        :type c: cute.Tensor
        :param max_active_clusters: Maximum number of active clusters to launch.
        :type max_active_clusters: cutlass.Constexpr
        :param stream: CUDA stream to launch the kernel on.
        :type stream: cuda.CUstream
        """
        self.a_dtype: type[cutlass.Numeric] = a.element_type
        self.a_scale_dtype: type[cutlass.Numeric] = (
            a_scale.element_type
            if self.scale_mode is TransformMode.ConvertScale
            else None
        )
        self.b_dtype: type[cutlass.Numeric] = b.element_type
        self.c_dtype: type[cutlass.Numeric] = c.element_type
        self.mma_dtype = self.b_dtype
        self.a_major_mode = utils.LayoutEnum.from_tensor(a).mma_major_mode()
        self.scale_major_mode = (
            utils.LayoutEnum.from_tensor(a_scale).mma_major_mode()
            if self.scale_mode is TransformMode.ConvertScale
            else None
        )
        self.b_major_mode = utils.LayoutEnum.from_tensor(b).mma_major_mode()
        self.c_layout = utils.LayoutEnum.from_tensor(c)
        if cutlass.const_expr(self.scale_mode == TransformMode.ConvertScale):
            # Get gmem layout for scale tensor
            self.gmem_layout_scale = mixed_input_utils.get_gmem_layout_scale(
                a.shape,
                self.scale_granularity_m,
                self.scale_granularity_k,
                self.scale_major_mode,
            )
        # Validate inputs
        self._validate_inputs(a, a_scale, b, c)
        # Set up attributes that depend on GEMM inputs
        self._setup_attributes()
        tiled_mma = sm100_utils.make_trivial_tiled_mma(
            self.mma_dtype,
            self.a_major_mode,
            self.b_major_mode,
            self.acc_dtype,
            self.cta_group,
            self.mma_tiler[:2],
            self.transform_a_source,
        )
        # Set up gmem copy atoms for A, scale, and B
        a_op = mixed_input_utils.get_tma_atom_kind(
            self.is_a_mcast, self.use_2cta_instrs, False
        )
        b_op = mixed_input_utils.get_tma_atom_kind(
            self.is_b_mcast, self.use_2cta_instrs, True
        )
        a_scale_op = a_op
        # Deduce TMA copy atom and TMA tensor for A, scale, and B
        smem_layout_a_per_stage = cute.slice_(self.smem_layout_a, (None, None, None, 0))
        tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
            a_op,
            a,
            smem_layout_a_per_stage,
            self.mma_tiler,
            tiled_mma,
            self.cluster_layout_vmnk.shape,
            internal_type=(
                cutlass.TFloat32 if a.element_type is cutlass.Float32 else None
            ),
        )
        tma_atom_scale, tma_tensor_scale = None, None
        if cutlass.const_expr(self.scale_mode == TransformMode.ConvertScale):
            # Partition smem layout for scale tensor to make it compatible with TMA atom
            smem_layout_for_tma_atom = cute.get(
                tiled_mma._thrfrg_A(self.smem_layout_scale_per_stage.outer), mode=[1]
            )
            # ((MMA_M, MMA_K), REST_M, REST_K)
            smem_layout_for_tma_atom = cute.dice(
                smem_layout_for_tma_atom,
                (1, (1,) * cute.rank(self.smem_layout_scale_per_stage.outer)),
            )
            tma_atom_scale, tma_tensor_scale = cute.nvgpu.make_tiled_tma_atom_A(
                a_scale_op,
                cute.make_tensor(a_scale.iterator, self.gmem_layout_scale),
                smem_layout_for_tma_atom,
                # (SCALE_M, 1, SCALE_K)
                (self.scale_tile_shape[0], 1, self.scale_tile_shape[1]),
                tiled_mma,
                self.cluster_layout_vmnk.shape,
                internal_type=(
                    cutlass.TFloat32
                    if a_scale.element_type is cutlass.Float32
                    else None
                ),
            )
        smem_layout_b_per_stage = cute.slice_(self.smem_layout_b, (None, None, None, 0))
        tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
            b_op,
            b,
            smem_layout_b_per_stage,
            self.mma_tiler,
            tiled_mma,
            self.cluster_layout_vmnk.shape,
            internal_type=(
                cutlass.TFloat32 if b.element_type is cutlass.Float32 else None
            ),
        )
        # Calculate copy size for tensor A, B, and scale
        a_copy_size = cute.size_in_bytes(self.a_dtype, smem_layout_a_per_stage)
        b_copy_size = cute.size_in_bytes(self.b_dtype, smem_layout_b_per_stage)
        a_scale_copy_size = (
            cute.size_in_bytes(self.a_scale_dtype, self.smem_layout_scale_per_stage)
            if self.scale_mode is TransformMode.ConvertScale
            else 0
        )
        self.num_tma_load_bytes_a = a_copy_size
        self.num_tma_load_bytes_b = b_copy_size * cute.size(tiled_mma.thr_id.shape)
        self.num_tma_load_bytes_scale = a_scale_copy_size
        self.tile_sched_params, grid = self._compute_grid(
            c,
            self.cta_tile_shape_mnk,
            self.cluster_shape_mn,
            max_active_clusters,
        )
        epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0))
        tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom(
            cpasync.CopyBulkTensorTileS2GOp(),
            c,
            epi_smem_layout,
            self.epi_tile,
        )
        c_smem_size = cute.cosize(self.c_smem_layout_staged.outer)
        # Shared memory structure
        a_smem_size = cute.cosize(self.smem_layout_a.outer)
        b_smem_size = cute.cosize(self.smem_layout_b.outer)
        a_transform_smem_size = (
            cute.cosize(self.smem_layout_a_transform.outer)
            if self.transform_a_source == tcgen05.OperandSource.SMEM
            else 0
        )
        a_scale_smem_size = (
            cute.cosize(self.smem_layout_scale.outer)
            if self.scale_mode is TransformMode.ConvertScale
            else 0
        )

        @cute.struct
        class SharedStorage:
            # Buffer holding group search results
            tile_info: cute.struct.MemRange[cutlass.Int32, 4 * self.num_tile_info_stage]
            a_load2trans_full_mbar_ptr: cute.struct.MemRange[
                cutlass.Int64, self.num_load2trans_stage
            ]
            a_load2trans_empty_mbar_ptr: cute.struct.MemRange[
                cutlass.Int64, self.num_load2trans_stage
            ]
            a_scale_load2trans_full_mbar_ptr: cute.struct.MemRange[
                cutlass.Int64, self.num_scale_load2trans_stage
            ]
            a_scale_load2trans_empty_mbar_ptr: cute.struct.MemRange[
                cutlass.Int64, self.num_scale_load2trans_stage
            ]
            a_trans2mma_full_mbar_ptr: cute.struct.MemRange[
                cutlass.Int64, self.num_trans2mma_stage
            ]
            a_trans2mma_empty_mbar_ptr: cute.struct.MemRange[
                cutlass.Int64, self.num_trans2mma_stage
            ]
            b_load2mma_full_mbar_ptr: cute.struct.MemRange[
                cutlass.Int64, self.num_load2trans_stage
            ]
            b_load2mma_empty_mbar_ptr: cute.struct.MemRange[
                cutlass.Int64, self.num_load2trans_stage
            ]
            acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage]
            acc_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage]
            tile_info_full_mbar_ptr: cute.struct.MemRange[
                cutlass.Int64, self.num_tile_info_stage
            ]
            tile_info_empty_mbar_ptr: cute.struct.MemRange[
                cutlass.Int64, self.num_tile_info_stage
            ]
            tmem_dealloc_mbar_ptr: cutlass.Int64
            tmem_holding_buf: cutlass.Int32

        self.shared_storage = SharedStorage
        # Launch kernel
        self.kernel(
            tiled_mma,
            tma_atom_a,
            tma_tensor_a,
            tma_atom_scale,
            tma_tensor_scale,
            tma_atom_b,
            tma_tensor_b,
            tma_atom_c,
            tma_tensor_c,
            c,
            cumsum,
            self.group_count,
            self.cluster_layout_vmnk,
            self.smem_layout_a,
            self.smem_layout_scale,
            self.smem_layout_a_transform,
            self.smem_layout_b,
            self.c_smem_layout_staged,
            self.epi_tile,
            self.tile_sched_params,
        ).launch(
            grid=grid,
            block=[self.threads_per_cta, 1, 1],
            cluster=(*self.cluster_shape_mn, 1),
            min_blocks_per_mp=1,
            stream=stream,
        )
        return
# GPU device kernel
@cute.kernel
def kernel(
self,
tiled_mma: cute.TiledMma,
tma_atom_a: cute.CopyAtom,
mA_mkl: cute.Tensor,
tma_atom_s: Optional[cute.CopyAtom],
mS_mkl: Optional[cute.Tensor],
tma_atom_b: cute.CopyAtom,
mB_nkl: cute.Tensor,
tma_atom_c: cute.CopyAtom,
mC_mnl: cute.Tensor,
tensor_c: cute.Tensor,
cumsum: cute.Tensor,
group_count: cutlass.Constexpr[int],
cluster_layout_vmnk: cute.Layout,
a_smem_layout: cute.ComposedLayout,
scale_smem_layout: cute.ComposedLayout,
a_smem_layout_transform: cute.ComposedLayout,
b_smem_layout: cute.ComposedLayout,
c_smem_layout_staged: cute.ComposedLayout,
epi_tile: cute.Tile,
tile_sched_params: utils.PersistentTileSchedulerParams,
):
"""
GPU device kernel performing the Persistent Mixed-Input Grouped GEMM computation.
"""
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx, _, _ = cute.arch.thread_idx()
bidx, bidy, bidz = cute.arch.block_idx()
# Prefetch TMA descriptors
if warp_idx == self.epilog_warp_id[0]:
cpasync.prefetch_descriptor(tma_atom_a)
cpasync.prefetch_descriptor(tma_atom_b)
if cutlass.const_expr(self.scale_mode == TransformMode.ConvertScale):
cpasync.prefetch_descriptor(tma_atom_s)
cpasync.prefetch_descriptor(tma_atom_c)
use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2
bidx, bidy, bidz = cute.arch.block_idx()
# Compute how many k_tiles share the same scale
num_k_tiles_per_scale = self.scale_granularity_k // self.cta_tile_shape_mnk[2]
mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape)
is_leader_cta = mma_tile_coord_v == 0
cta_rank_in_cluster = cute.arch.make_warp_uniform(
cute.arch.block_idx_in_cluster()
)
block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(
cta_rank_in_cluster
)
tidx, _, _ = cute.arch.thread_idx()
smem = utils.SmemAllocator()
storage = smem.allocate(self.shared_storage)
# Initialize load2transform pipeline, which tracks the dependencies between TMA's loading
# of A and B, and the transformation of A and MMA's consumption
transform_thread_idx = (
tidx - 32 * self.transform_warp_id[0]
if tidx >= 32 * self.transform_warp_id[0]
else tidx
)
a_load2trans_pipeline = pipeline.PipelineTmaAsync.create(
barrier_storage=storage.a_load2trans_full_mbar_ptr.data_ptr(),
num_stages=self.num_load2trans_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(
pipeline.Agent.Thread,
self.num_mcast_ctas_a * len(self.transform_warp_id),
),
tx_count=self.num_tma_load_bytes_a,
cta_layout_vmnk=cluster_layout_vmnk,
tidx=transform_thread_idx,
mcast_mode_mn=(1, 0), # multicast for A will only happen on the M-mode
defer_sync=True,
)
# Initialize scale_load2trans pipeline, which tracks the dependencies between TMA's loading
# of scale, and the transformation of A
scale_load2trans_pipeline = None
if cutlass.const_expr(self.scale_mode == TransformMode.ConvertScale):
num_producers_a_scale = self.num_mcast_ctas_a
scale_load2trans_pipeline = pipeline.PipelineTmaAsync.create(
barrier_storage=storage.a_scale_load2trans_full_mbar_ptr.data_ptr(),
num_stages=self.num_scale_load2trans_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(
pipeline.Agent.Thread,
num_producers_a_scale
* len(self.transform_warp_id)
* num_k_tiles_per_scale,
),
tx_count=self.num_tma_load_bytes_scale,
cta_layout_vmnk=cluster_layout_vmnk,
tidx=transform_thread_idx,
mcast_mode_mn=(
1,
0,
), # multicast for scale_a will only happen on the M-mode
defer_sync=True,
)
        # Initialize the transform2mma pipeline, which tracks the dependency between the
        # transformation of A and the MMA's consumption of the transformed A
        cta_v_size = cute.size(cluster_layout_vmnk, mode=[0])
        trans2mma_pipeline = pipeline.PipelineAsyncUmma.create(
            barrier_storage=storage.a_trans2mma_full_mbar_ptr.data_ptr(),
            num_stages=self.num_trans2mma_stage,
            producer_group=pipeline.CooperativeGroup(
                pipeline.Agent.Thread,
                32 * len(self.transform_warp_id) * cta_v_size,
            ),
            consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
            cta_layout_vmnk=cluster_layout_vmnk,
            defer_sync=True,
        )
        # Initialize the pipeline that feeds tensor B from the TMA load to the MMA.
        # The MMA warp signals the TMA warp when it may load the next tile of B.
        b_load2mma_pipeline = pipeline.PipelineTmaUmma.create(
            barrier_storage=storage.b_load2mma_full_mbar_ptr.data_ptr(),
            num_stages=self.num_load2trans_stage,
            producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
            consumer_group=pipeline.CooperativeGroup(
                pipeline.Agent.Thread, self.num_mcast_ctas_b
            ),
            tx_count=self.num_tma_load_bytes_b,
            cta_layout_vmnk=cluster_layout_vmnk,
            mcast_mode_mn=(0, 1),  # multicast for B only happens on the N-mode
            defer_sync=True,
        )
        # Initialize the accumulator pipeline, which tracks the dependency between
        # the MMA's computation of accumulators and the epilogue warps' consumption
        acc_pipeline = pipeline.PipelineUmmaAsync.create(
            barrier_storage=storage.acc_full_mbar_ptr.data_ptr(),
            num_stages=self.num_acc_stage,
            producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
            consumer_group=pipeline.CooperativeGroup(
                pipeline.Agent.Thread, cta_v_size * len(self.epilog_warp_id)
            ),
            cta_layout_vmnk=cluster_layout_vmnk,
            defer_sync=True,
        )
        # Initialize the tile info pipeline, which tracks the dependency between
        # the tile scheduling warp and the other warps.
        # The consumer thread count excludes the scheduler warp (the producer), and
        # also the scale TMA load warp when scale_mode is ConvertOnly, because that
        # warp does not consume tile info in that mode.
        num_tile_info_pipeline_consumer_threads = (
            self.threads_per_cta
            - 32
            - (32 if self.scale_mode is TransformMode.ConvertOnly else 0)
        )
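        # Worked example (hypothetical 12-warp CTA, for illustration only):
        # threads_per_cta = 12 * 32 = 384, so the tile info pipeline expects
        #   384 - 32 - 32 = 320 consumer threads when scale_mode is ConvertOnly
        #   384 - 32      = 352 consumer threads when scale_mode is ConvertScale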
        tile_info_pipeline = pipeline.PipelineAsync.create(
            barrier_storage=storage.tile_info_full_mbar_ptr.data_ptr(),
            num_stages=self.num_tile_info_stage,
            producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * 1),
            consumer_group=pipeline.CooperativeGroup(
                pipeline.Agent.Thread,
                num_tile_info_pipeline_consumer_threads,
            ),
            defer_sync=True,
        )
        # Tensor memory allocator; also initializes the TMEM dealloc barrier used in 2-CTA mode
        tmem = utils.TmemAllocator(
            storage.tmem_holding_buf,
            barrier_for_retrieve=self.tmem_ptr_sync_barrier,
            allocator_warp_id=self.epilog_warp_id[0],
            is_two_cta=use_2cta_instrs,
            two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr,
        )
        # Cluster arrive after barrier init
        pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mn, is_relaxed=True)
        # Set up smem tensors for A/scale/B/C
        sC = smem.allocate_tensor(
            element_type=self.c_dtype,
            layout=c_smem_layout_staged.outer,
            byte_alignment=self.smem_buffer_align_bytes,
            swizzle=c_smem_layout_staged.inner,
        )
        sA_input = smem.allocate_tensor(
            element_type=self.a_dtype,
            layout=a_smem_layout.outer,
            byte_alignment=self.smem_buffer_align_bytes,
            swizzle=a_smem_layout.inner,
        )
        sS_input = (
            smem.allocate_tensor(
                element_type=self.mma_dtype,
                layout=scale_smem_layout.outer,
                byte_alignment=self.smem_buffer_align_bytes,
                swizzle=scale_smem_layout.inner,
            )
            if self.scale_mode is TransformMode.ConvertScale
            else None
        )
        sB_input = smem.allocate_tensor(
            element_type=self.b_dtype,
            layout=b_smem_layout.outer,
            byte_alignment=self.smem_buffer_align_bytes,
            swizzle=b_smem_layout.inner,
        )
        sA_transform = None
        # Allocate an smem tensor for the transformed A when transform_a_source is SMEM
        if cutlass.const_expr(self.transform_a_source == tcgen05.OperandSource.SMEM):
            sA_transform = smem.allocate_tensor(
                element_type=self.mma_dtype,
                layout=a_smem_layout_transform.outer,
                byte_alignment=self.smem_buffer_align_bytes,
                swizzle=a_smem_layout_transform.inner,
            )
        sTile_info = storage.tile_info.get_tensor(
            cute.make_layout((4, self.num_tile_info_stage), stride=(1, 4))
        )
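        # Tile info record layout (derived from the scheduler warp's stores below):
        # each stage is a column of 4 values with stride (1, 4), so field i of
        # stage s lives at smem offset i + 4 * s.
        #   [0] CTA tile coordinate along M
        #   [1] starting N coordinate of this CTA's tile within the grouped tensor
        #   [2] 0-based group index (the final record, written once the search runs
        #       past the last group, holds group_count)
        #   [3] distance from the N coordinate in [1] to the group's N boundary
        # For example, with num_tile_info_stage = 2 (illustrative), the group index
        # of stage 1 sits at offset 2 + 4 * 1 = 6.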
        # Compute multicast masks for the A, scale, and B buffer-full barriers
        a_full_mcast_mask = None
        b_full_mcast_mask = None
        s_full_mcast_mask = None
        if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs):
            a_full_mcast_mask = cpasync.create_tma_multicast_mask(
                cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2
            )
            # The scale tensor shares the A tensor's multicast mask
            s_full_mcast_mask = a_full_mcast_mask
            b_full_mcast_mask = cpasync.create_tma_multicast_mask(
                cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1
            )
        # local_tile partition global tensors
        # (bM, bK, loopM, loopK, loopL)
        gA_mkl = cute.local_tile(
            mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None)
        )
        # (bM, bK, loopM, loopK, loopL)
        gS_mkl = (
            cute.local_tile(
                mS_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None)
            )
            if self.scale_mode is TransformMode.ConvertScale
            else None
        )
        # (bN, bK, loopN, loopK, loopL)
        gB_nkl = cute.local_tile(
            mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None)
        )
        # (bM, bN, loopM, loopN, loopL)
        gC_mnl = cute.local_tile(
            mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None)
        )
        gC_mnl_simt = cute.local_tile(
            tensor_c, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None)
        )
        k_tile_cnt = cute.size(gA_mkl, mode=[3])
        # Partition global tensors for TiledMMA_A/B/C
        thr_mma = tiled_mma.get_slice(mma_tile_coord_v)
        # (MMA, MMA_M, MMA_K, loopM, loopK, loopL)
        tCgA = thr_mma.partition_A(gA_mkl)
        # (MMA, MMA_M, MMA_K, loopM, loopK, loopL)
        tCgS = (
            thr_mma.partition_A(gS_mkl)
            if self.scale_mode is TransformMode.ConvertScale
            else None
        )
        # (MMA, MMA_N, MMA_K, loopN, loopK, loopL)
        tCgB = thr_mma.partition_B(gB_nkl)
        # (MMA, MMA_M, MMA_N, loopM, loopN, loopL)
        tCgC = thr_mma.partition_C(gC_mnl)
        tCgC_simt = thr_mma.partition_C(gC_mnl_simt)
        # Set up the copy atom that loads A from shared memory for transformation
        copy_atom_a_input = (
            cute.make_copy_atom(
                cute.nvgpu.CopyUniversalOp(), self.a_dtype, num_bits_per_copy=32
            )
            if self.scale_mode is TransformMode.ConvertScale
            else None
        )
        a_smem_shape = tiled_mma.partition_shape_A(
            cute.dice(self.mma_tiler, (1, None, 1))
        )
        # Set up the copy atom that stores the transformed A into tensor memory
        # or shared memory
        copy_atom_a_transform = mixed_input_utils.get_copy_atom_a_transform(
            self.mma_dtype,
            self.use_2cta_instrs,
            self.transform_a_source,
            a_smem_shape,
            self.a_dtype,
        )
        # Partition global/shared tensors for the TMA loads of A/B
        # TMA load A partition_S/D
        a_cta_layout = cute.make_layout(
            cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape
        )
        # ((atom_v, rest_v), STAGE)
        # ((atom_v, rest_v), loopM, loopK, loopL)
        tAsA, tAgA = cpasync.tma_partition(
            tma_atom_a,
            block_in_cluster_coord_vmnk[2],
            a_cta_layout,
            cute.group_modes(sA_input, 0, 3),
            cute.group_modes(tCgA, 0, 3),
        )
        tCsS = None
        tSsS = None
        tSgS = None
        if cutlass.const_expr(self.scale_mode == TransformMode.ConvertScale):
            thr_mma_leader_cta = tiled_mma.get_slice(0)
            # (MMA, MMA_M, MMA_K, STAGE)
            tCsS = thr_mma_leader_cta.partition_A(sS_input)
            # ((atom_v, rest_v), STAGE)
            # ((atom_v, rest_v), loopM, loopK, loopL)
            tSsS, tSgS = mixed_input_utils.scale_tma_partition(
                tCsS,
                tCgS,
                tma_atom_s,
                block_in_cluster_coord_vmnk,
                a_cta_layout,
            )
        # TMA load B partition_S/D
        b_cta_layout = cute.make_layout(
            cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape
        )
        # ((atom_v, rest_v), STAGE)
        # ((atom_v, rest_v), loopN, loopK, loopL)
        tBsB, tBgB = cpasync.tma_partition(
            tma_atom_b,
            block_in_cluster_coord_vmnk[1],
            b_cta_layout,
            cute.group_modes(sB_input, 0, 3),
            cute.group_modes(tCgB, 0, 3),
        )
        # (MMA, MMA_N, MMA_K, STAGE)
        tCrB = tiled_mma.make_fragment_B(sB_input)
        # (MMA, MMA_M, MMA_N)
        acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
        tCtAcc_fake = tiled_mma.make_fragment_C(
            cute.append(acc_shape, self.num_acc_stage)
        )
        # Cluster wait before TMEM alloc, to ensure the pipelines are ready
        pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mn)
        # TMEM allocation
        tmem.allocate(self.num_tmem_alloc_cols)
        tmem.wait_for_alloc()
        # Get the pointer to the TMEM buffer
        tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
        accumulators = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
        tCrA = None
        if cutlass.const_expr(self.transform_a_source == tcgen05.OperandSource.TMEM):
            tmem_ptr_transform = cute.recast_ptr(
                accumulators.iterator + self.num_acc_tmem_cols, dtype=self.mma_dtype
            )
            tCrA = cute.make_tensor(
                tmem_ptr_transform,
                tiled_mma.make_fragment_A(a_smem_layout_transform.outer).layout,
            )
        else:
            tCrA = tiled_mma.make_fragment_A(sA_transform)
        # Tile scheduler warp
        if warp_idx == self.schedule_warp_id:
            cute.arch.warpgroup_reg_dealloc(self.num_regs_schedule_warp)
            # Persistent tile scheduling loop
            tile_sched = utils.StaticPersistentRuntimeTileScheduler.create(
                tile_sched_params,
                (bidx, bidy, bidz),
                cute.arch.grid_dim(),
                inner_mode=0,
            )
            work_tile = tile_sched.initial_work_tile_info()
            tile_info_producer_state = pipeline.make_pipeline_state(
                pipeline.PipelineUserType.Producer, self.num_tile_info_stage
            )
            # Create the initial group search state
            search_state = create_initial_search_state()
            not_last_tile = cutlass.Boolean(1)
            while not_last_tile:
                tile_info_pipeline.producer_acquire(tile_info_producer_state)
                cluster_tile_coord_mnl = work_tile.tile_idx
                cta_tile_coord_m = (
                    cluster_tile_coord_mnl[0] * self.cluster_shape_mn[0]
                    + block_in_cluster_coord_vmnk[1] * cute.size(tiled_mma.thr_id.shape)
                    + block_in_cluster_coord_vmnk[0]
                )
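                # Worked example (illustrative values): with cluster_shape_mn[0] = 4
                # and 2-CTA MMA instructions (size(thr_id) = 2), the V and M extents of
                # cluster_layout_vmnk are both 2. The cluster at cluster_tile_coord_mnl[0] = 1
                # therefore covers CTA tiles 4..7 along M, and the CTA at
                # (v=1, m=1) inside it gets cta_tile_coord_m = 1 * 4 + 1 * 2 + 1 = 7.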
                cta_tile_offset_n = block_in_cluster_coord_vmnk[2]
                search_state = self.group_search(
                    group_count,
                    cluster_tile_coord_mnl[1],
                    search_state,
                    cumsum,
                    1,  # mode index to search along: 0 for M, 1 for N
                )
                cur_sTile_info = sTile_info[(None, tile_info_producer_state.index)]
                not_last_tile = search_state.cur_group_idx <= group_count
                # Store the tile info into the shared memory buffer
                with cute.arch.elect_one():
                    cur_sTile_info[0] = cta_tile_coord_m
                    cur_sTile_info[1] = (
                        search_state.cur_start
                        + cta_tile_offset_n * self.cta_tile_shape_mnk[1]
                    )
                    cur_sTile_info[2] = search_state.cur_group_idx - 1
                    cur_sTile_info[3] = (
                        search_state.cur_boundary
                        - search_state.cur_start
                        - (cta_tile_offset_n * self.cta_tile_shape_mnk[1])
                    )
                # Fence and barrier to ensure the tile info store has finished
                cute.arch.fence_proxy(
                    cute.arch.ProxyKind.async_shared,
                    space=cute.arch.SharedSpace.shared_cta,
                )
                self.sched_sync_barrier.arrive_and_wait()
                # Commit the tile info pipeline
                tile_info_pipeline.producer_commit(tile_info_producer_state)
                # Advance to the next tile
                tile_info_producer_state.advance()
                tile_sched.advance_to_next_work()
                work_tile = tile_sched.get_current_work()
            tile_info_pipeline.producer_tail(tile_info_producer_state)
        # Specialized TMA load warp for the A/B tensors
        if warp_idx == self.tma_warp_id:
            cute.arch.warpgroup_reg_dealloc(self.num_regs_tma_warps)
            # Persistent tile scheduling loop
            tile_info_consumer_state = pipeline.make_pipeline_state(
                pipeline.PipelineUserType.Consumer, self.num_tile_info_stage
            )
            tile_info_pipeline.consumer_wait(tile_info_consumer_state)
            work_tile = self.make_work_tile_info(
                sTile_info[(None, tile_info_consumer_state.index)]
            )
            cute.arch.fence_proxy(
                cute.arch.ProxyKind.async_shared,
                space=cute.arch.SharedSpace.shared_cta,
            )
            tile_info_pipeline.consumer_release(tile_info_consumer_state)
            tile_info_consumer_state.advance()
            a_load2trans_producer_state = pipeline.make_pipeline_state(
                pipeline.PipelineUserType.Producer, self.num_load2trans_stage
            )
            b_load2mma_producer_state = pipeline.make_pipeline_state(
                pipeline.PipelineUserType.Producer, self.num_load2trans_stage
            )
            while work_tile.is_valid_tile:
                tAgA_slice = tAgA[
                    (
                        None,
                        work_tile.cta_coord_m // cute.size(tiled_mma.thr_id.shape),
                        None,
                        work_tile.group_idx,
                    )
                ]
                # Shift the B coordinate tensor along N by the group search result;
                # which coordinate position holds N depends on B's major mode
                coord_n_offset = (
                    (work_tile.coord_n, 0, 0)
                    if cutlass.const_expr(
                        self.b_major_mode == tcgen05.OperandMajorMode.MN
                    )
                    else (0, work_tile.coord_n, 0)
                )
                tBgB_slice = cute.make_tensor(
                    (
                        tBgB.iterator[0] + coord_n_offset[0],
                        coord_n_offset[1] + tBgB.iterator[1],
                        coord_n_offset[2] + tBgB.iterator[2],
                    ),
                    cute.slice_(tBgB.layout, (None, 0, None, 0)),
                )
                a_load2trans_producer_state.reset_count()
                peek_load2trans_empty_status = cutlass.Boolean(1)
                if a_load2trans_producer_state.count < k_tile_cnt:
                    peek_load2trans_empty_status = (
                        a_load2trans_pipeline.producer_try_acquire(
                            a_load2trans_producer_state
                        )
                    )
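                # Sketch of the peek-then-acquire pattern used by the loop below (and
                # by the scale, transform, and MMA loops). The non-blocking
                # producer_try_acquire for the next stage is issued before the current
                # stage's work, and its status is passed to producer_acquire, which
                # presumably blocks only if the peek found the buffer still in use:
                #   status = try_acquire(state)      # non-blocking peek
                #   for k_tile in range(k_tile_cnt):
                #       acquire(state, status)       # wait only if the peek failed
                #       ...issue the TMA copy into stage state.index...
                #       commit(state)
                #       state.advance()
                #       status = try_acquire(state)  # peek at the next stage early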
                b_load2mma_producer_state.reset_count()
                for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):
                    a_load2trans_pipeline.producer_acquire(
                        a_load2trans_producer_state, peek_load2trans_empty_status
                    )
                    b_load2mma_pipeline.producer_acquire(b_load2mma_producer_state)
                    # TMA load A/B
                    cute.copy(
                        tma_atom_a,
                        tAgA_slice[(None, a_load2trans_producer_state.count)],
                        tAsA[(None, a_load2trans_producer_state.index)],
                        tma_bar_ptr=a_load2trans_pipeline.producer_get_barrier(
                            a_load2trans_producer_state
                        ),
                        mcast_mask=a_full_mcast_mask,
                    )
                    cute.copy(
                        tma_atom_b,
                        tBgB_slice[(None, b_load2mma_producer_state.count)],
                        tBsB[(None, b_load2mma_producer_state.index)],
                        tma_bar_ptr=b_load2mma_pipeline.producer_get_barrier(
                            b_load2mma_producer_state
                        ),
                        mcast_mask=b_full_mcast_mask,
                    )
                    a_load2trans_pipeline.producer_commit(a_load2trans_producer_state)
                    b_load2mma_pipeline.producer_commit(b_load2mma_producer_state)
                    a_load2trans_producer_state.advance()
                    b_load2mma_producer_state.advance()
                    if a_load2trans_producer_state.count < k_tile_cnt:
                        peek_load2trans_empty_status = (
                            a_load2trans_pipeline.producer_try_acquire(
                                a_load2trans_producer_state
                            )
                        )
                # Advance to the next tile
                tile_info_pipeline.consumer_wait(tile_info_consumer_state)
                work_tile = self.make_work_tile_info(
                    sTile_info[(None, tile_info_consumer_state.index)]
                )
                cute.arch.fence_proxy(
                    cute.arch.ProxyKind.async_shared,
                    space=cute.arch.SharedSpace.shared_cta,
                )
                tile_info_pipeline.consumer_release(tile_info_consumer_state)
                tile_info_consumer_state.advance()
            # Wait for the A/B buffers to be empty
            a_load2trans_pipeline.producer_tail(a_load2trans_producer_state)
            b_load2mma_pipeline.producer_tail(b_load2mma_producer_state)
        # Specialized TMA load warp for the scale tensor
        if warp_idx == self.scale_tma_warp_id:
            cute.arch.warpgroup_reg_dealloc(self.num_regs_tma_warps)
            if cutlass.const_expr(self.scale_mode == TransformMode.ConvertScale):
                # Persistent tile scheduling loop
                tile_info_consumer_state = pipeline.make_pipeline_state(
                    pipeline.PipelineUserType.Consumer, self.num_tile_info_stage
                )
                tile_info_pipeline.consumer_wait(tile_info_consumer_state)
                work_tile = self.make_work_tile_info(
                    sTile_info[(None, tile_info_consumer_state.index)]
                )
                cute.arch.fence_proxy(
                    cute.arch.ProxyKind.async_shared,
                    space=cute.arch.SharedSpace.shared_cta,
                )
                tile_info_pipeline.consumer_release(tile_info_consumer_state)
                tile_info_consumer_state.advance()
                scale_load2trans_producer_state = pipeline.make_pipeline_state(
                    pipeline.PipelineUserType.Producer, self.num_scale_load2trans_stage
                )
                scale_k_tile_cnt = cute.size(mS_mkl.layout.shape[1][1])
                while work_tile.is_valid_tile:
                    # ((atom_v, rest_v), RestK)
                    tSgS_slice = tSgS[
                        (
                            None,
                            work_tile.cta_coord_m // cute.size(tiled_mma.thr_id.shape),
                            None,
                            work_tile.group_idx,
                        )
                    ]
                    # Filter the zero-stride modes out of the rest mode
                    rest_filtered = cute.filter_zeros(tSgS_slice[(0, None)].layout)
                    tSgS_slice_filtered = cute.make_tensor(
                        tSgS_slice.iterator,
                        cute.make_layout(
                            (tSgS_slice.layout[0].shape, rest_filtered.shape),
                            stride=(tSgS_slice.layout[0].stride, rest_filtered.stride),
                        ),
                    )
                    scale_load2trans_producer_state.reset_count()
                    peek_scale_load2trans_empty_status = cutlass.Boolean(1)
                    if scale_load2trans_producer_state.count < scale_k_tile_cnt:
                        peek_scale_load2trans_empty_status = (
                            scale_load2trans_pipeline.producer_try_acquire(
                                scale_load2trans_producer_state
                            )
                        )
                    for k_tile in cutlass.range(0, scale_k_tile_cnt, 1, unroll=1):
                        scale_load2trans_pipeline.producer_acquire(
                            scale_load2trans_producer_state,
                            peek_scale_load2trans_empty_status,
                        )
                        # TMA load scale
                        cute.copy(
                            tma_atom_s,
                            tSgS_slice_filtered[
                                (None, scale_load2trans_producer_state.count)
                            ],
                            tSsS[(None, scale_load2trans_producer_state.index)],
                            tma_bar_ptr=scale_load2trans_pipeline.producer_get_barrier(
                                scale_load2trans_producer_state
                            ),
                            mcast_mask=s_full_mcast_mask,
                        )
                        scale_load2trans_producer_state.advance()
                        peek_scale_load2trans_empty_status = cutlass.Boolean(1)
                        if scale_load2trans_producer_state.count < scale_k_tile_cnt:
                            peek_scale_load2trans_empty_status = (
                                scale_load2trans_pipeline.producer_try_acquire(
                                    scale_load2trans_producer_state
                                )
                            )
                    # Advance to the next tile
                    tile_info_pipeline.consumer_wait(tile_info_consumer_state)
                    work_tile = self.make_work_tile_info(
                        sTile_info[(None, tile_info_consumer_state.index)]
                    )
                    cute.arch.fence_proxy(
                        cute.arch.ProxyKind.async_shared,
                        space=cute.arch.SharedSpace.shared_cta,
                    )
                    tile_info_pipeline.consumer_release(tile_info_consumer_state)
                    tile_info_consumer_state.advance()
                # Wait for the scale buffers to be empty
                scale_load2trans_pipeline.producer_tail(scale_load2trans_producer_state)
        # Specialized transform warps
        if warp_idx >= self.transform_warp_id[0]:
            cute.arch.warpgroup_reg_alloc(self.num_regs_transform_warps)
            transform_local_tidx = tidx - 32 * self.transform_warp_id[0]
            # Partition the tensors for the transform input and output, and set up the
            # copy atoms used to load A and to store the transformed A
            src_copy_a, dst_copy_a, tAsA_input, tAsA_transform = (
                mixed_input_utils.transform_partition(
                    self.transform_a_source,
                    self.scale_mode,
                    copy_atom_a_input,
                    copy_atom_a_transform,
                    sA_input,
                    (
                        tCrA
                        if self.transform_a_source == tcgen05.OperandSource.TMEM
                        else sA_transform
                    ),
                    transform_local_tidx,
                )
            )
            # Make register fragments for the input A and the transformed A
            tArA = cute.make_rmem_tensor(
                tAsA_input[(None, None, None, None, 0)].shape, tAsA_input.element_type
            )
            tArA_transform = cute.make_rmem_tensor(
                tAsA_input[(None, None, None, None, 0)].shape, self.mma_dtype
            )
            # Partition the scale tensor
            smem_thr_copy_S = None
            tSsS_trans = None
            tSrS_copy = None
            tSrS = None
            if cutlass.const_expr(self.scale_mode == TransformMode.ConvertScale):
                smem_thr_copy_S, tSsS_trans, tSrS_copy, tSrS = (
                    mixed_input_utils.scale_partition(
                        src_copy_a, tCsS, transform_local_tidx, self.mma_dtype
                    )
                )
                assert cute.size(tSrS, mode=[0]) == cute.size(tArA, mode=[0]), (
                    "tSrS and tArA have different leading dimensions"
                )
                assert cute.size(tSrS) == cute.size(tArA), (
                    "tSrS and tArA have different shapes"
                )
            # Choose a subtile size (at most 64 elements) and tile the fragments by it
            transform_tiler_size = min(
                cute.size(cute.coalesce(tAsA_input.layout), mode=[0]), 64
            )
            transform_tiler = cute.make_layout(transform_tiler_size)
            tArA_load = cute.flat_divide(tArA, transform_tiler)
            tArA_load = cute.group_modes(tArA_load, 1, cute.rank(tArA_load))
            tSrS_load = (
                cute.flat_divide(tSrS, transform_tiler)
                if self.scale_mode is TransformMode.ConvertScale
                else None
            )
            tSrS_load = (
                cute.group_modes(tSrS_load, 1, cute.rank(tSrS_load))
                if self.scale_mode is TransformMode.ConvertScale
                else None
            )
            tArA_transform_store = cute.flat_divide(tArA_transform, transform_tiler)
            tArA_transform_store = cute.group_modes(
                tArA_transform_store, 1, cute.rank(tArA_transform_store)
            )
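            # Illustrative example of the subtiling above (hypothetical sizes): if the
            # coalesced A layout has a 128-element leading mode, transform_tiler_size =
            # min(128, 64) = 64. flat_divide followed by group_modes then reshapes a
            # 128-element fragment to (64, 2), and the conversion loop below runs once
            # per 64-element subtile (size(tArA_load, mode=[1]) = 2 iterations).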
            tile_info_consumer_state = pipeline.make_pipeline_state(
                pipeline.PipelineUserType.Consumer, self.num_tile_info_stage
            )
            tile_info_pipeline.consumer_wait(tile_info_consumer_state)
            work_tile = self.make_work_tile_info(
                sTile_info[(None, tile_info_consumer_state.index)]
            )
            cute.arch.fence_proxy(
                cute.arch.ProxyKind.async_shared,
                space=cute.arch.SharedSpace.shared_cta,
            )
            tile_info_pipeline.consumer_release(tile_info_consumer_state)
            tile_info_consumer_state.advance()
            a_load2trans_consumer_state = pipeline.make_pipeline_state(
                pipeline.PipelineUserType.Consumer,
                self.num_load2trans_stage,
            )
            scale_load2trans_consumer_state = (
                pipeline.make_pipeline_state(
                    pipeline.PipelineUserType.Consumer,
                    self.num_scale_load2trans_stage,
                )
                if self.scale_mode is TransformMode.ConvertScale
                else None
            )
            trans2mma_producer_state = pipeline.make_pipeline_state(
                pipeline.PipelineUserType.Producer,
                self.num_trans2mma_stage,
            )
            while work_tile.is_valid_tile:
                a_load2trans_consumer_state.reset_count()
                peek_load2trans_full_status = cutlass.Boolean(1)
                if a_load2trans_consumer_state.count < k_tile_cnt:
                    peek_load2trans_full_status = (
                        a_load2trans_pipeline.consumer_try_wait(
                            a_load2trans_consumer_state
                        )
                    )
                peek_scale_load2trans_full_status = cutlass.Boolean(1)
                if cutlass.const_expr(self.scale_mode == TransformMode.ConvertScale):
                    scale_load2trans_consumer_state.reset_count()
                    peek_scale_load2trans_full_status = (
                        scale_load2trans_pipeline.consumer_try_wait(
                            scale_load2trans_consumer_state
                        )
                    )
                trans2mma_producer_state.reset_count()
                peek_trans2mma_empty_status = cutlass.Boolean(1)
                if trans2mma_producer_state.count < k_tile_cnt:
                    peek_trans2mma_empty_status = (
                        trans2mma_pipeline.producer_try_acquire(
                            trans2mma_producer_state
                        )
                    )
                for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):
                    a_load2trans_pipeline.consumer_wait(
                        a_load2trans_consumer_state, peek_load2trans_full_status
                    )
                    tAsA_input_slice = tAsA_input[
                        (None, None, None, None, a_load2trans_consumer_state.index)
                    ]
                    tAsA_input_slice = cute.flat_divide(
                        tAsA_input_slice, transform_tiler
                    )
                    tAsA_input_slice = cute.group_modes(
                        tAsA_input_slice, 1, cute.rank(tAsA_input_slice)
                    )
                    if cutlass.const_expr(
                        self.scale_mode == TransformMode.ConvertScale
                    ):
                        scale_load2trans_pipeline.consumer_wait(
                            scale_load2trans_consumer_state,
                            peek_scale_load2trans_full_status,
                        )
                    trans2mma_pipeline.producer_acquire(
                        trans2mma_producer_state, peek_trans2mma_empty_status
                    )
                    # Copy the scales from smem into registers only at the first of the
                    # num_k_tiles_per_scale k-tiles that share one scale stage
                    if cutlass.const_expr(
                        self.scale_mode == TransformMode.ConvertScale
                    ):
                        if k_tile % num_k_tiles_per_scale == 0:
                            tSsS_slice = tSsS_trans[
                                (
                                    None,
                                    None,
                                    None,
                                    None,
                                    scale_load2trans_consumer_state.index,
                                )
                            ]
                            tSsS_slice_filtered = cute.make_tensor(
                                tSsS_slice.iterator,
                                cute.filter_zeros(tSsS_slice.layout),
                            )
                            cute.autovec_copy(tSsS_slice_filtered, tSrS_copy)
                        cur_scale_load2trans_consumer_state = (
                            scale_load2trans_consumer_state.clone()
                        )
                        if (k_tile + 1) % num_k_tiles_per_scale == 0:
                            scale_load2trans_consumer_state.advance()
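                        # Example schedule (illustrative): with num_k_tiles_per_scale = 2,
                        # k-tiles 0 and 1 share scale stage s, and k-tiles 2 and 3 share s + 1:
                        #   k_tile 0: copy scales of stage s into tSrS_copy
                        #   k_tile 1: reuse the registers, then advance to stage s + 1
                        #   k_tile 2: copy scales of stage s + 1 into tSrS_copy
                        #   k_tile 3: reuse the registers, then advance to stage s + 2
                        # cur_scale_load2trans_consumer_state keeps the pre-advance state,
                        # so the release after the transform targets the stage used here.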
                    cur_a_load2trans_consumer_state = (
                        a_load2trans_consumer_state.clone()
                    )
                    for idx in cutlass.range_constexpr(cute.size(tArA_load, mode=[1])):
                        # Load A from shared memory
                        cute.autovec_copy(
                            tAsA_input_slice[(None, idx)],
                            tArA_load[(None, idx)],
                        )
                        if cutlass.const_expr(
                            idx == cute.size(tArA_load, mode=[1]) - 1
                        ):
                            a_load2trans_consumer_state.advance()
                            if a_load2trans_consumer_state.count < k_tile_cnt:
                                peek_load2trans_full_status = (
                                    a_load2trans_pipeline.consumer_try_wait(
                                        a_load2trans_consumer_state
                                    )
                                )
                            if cutlass.const_expr(
                                self.scale_mode == TransformMode.ConvertScale
                            ):
                                peek_scale_load2trans_full_status = (
                                    scale_load2trans_pipeline.consumer_try_wait(
                                        scale_load2trans_consumer_state
                                    )
                                )
                        # Convert to the MMA dtype
                        tensor_transformed = (
                            tArA_load[(None, idx)].load().to(self.mma_dtype)
                        )
                        if cutlass.const_expr(
                            self.scale_mode == TransformMode.ConvertScale
                        ):
                            scale = cute.TensorSSA(
                                tSrS_load[(None, idx)].load(),
                                tensor_transformed.shape,
                                self.mma_dtype,
                            )
                            # Apply the scale
                            tensor_transformed = tensor_transformed * scale
                        tArA_transform_store[(None, idx)].store(tensor_transformed)
                    # Store the transformed A to tensor memory or shared memory
                    if cutlass.const_expr(dst_copy_a is not None):
                        cute.copy(
                            dst_copy_a,
                            tArA_transform,
                            tAsA_transform[
                                (None, None, None, None, trans2mma_producer_state.index)
                            ],
                        )
                    else:
                        cute.autovec_copy(
                            tArA_transform,
                            tAsA_transform[
                                (None, None, None, None, trans2mma_producer_state.index)
                            ],
                        )
                    # Ensure all transform threads have finished the copy and reached the fence
                    self.transform_sync_barrier.arrive_and_wait()
                    if cutlass.const_expr(
                        self.transform_a_source == tcgen05.OperandSource.TMEM
                    ):
                        cute.arch.fence_view_async_tmem_store()
                    else:
                        cute.arch.fence_proxy(
                            cute.arch.ProxyKind.async_shared,
                            space=cute.arch.SharedSpace.shared_cta,
                        )
                    if cutlass.const_expr(
                        self.scale_mode == TransformMode.ConvertScale
                    ):
                        scale_load2trans_pipeline.consumer_release(
                            cur_scale_load2trans_consumer_state
                        )
                    a_load2trans_pipeline.consumer_release(
                        cur_a_load2trans_consumer_state
                    )
                    # Signal the completion of the transformation
                    trans2mma_pipeline.producer_commit(trans2mma_producer_state)
                    trans2mma_producer_state.advance()
                    if trans2mma_producer_state.count < k_tile_cnt:
                        peek_trans2mma_empty_status = (
                            trans2mma_pipeline.producer_try_acquire(
                                trans2mma_producer_state
                            )
                        )
                # Advance to the next tile
                tile_info_pipeline.consumer_wait(tile_info_consumer_state)
                work_tile = self.make_work_tile_info(
                    sTile_info[(None, tile_info_consumer_state.index)]
                )
                cute.arch.fence_proxy(
                    cute.arch.ProxyKind.async_shared,
                    space=cute.arch.SharedSpace.shared_cta,
                )
                tile_info_pipeline.consumer_release(tile_info_consumer_state)
                tile_info_consumer_state.advance()
            # Wait for the transformed-A buffers to be empty
            trans2mma_pipeline.producer_tail(trans2mma_producer_state)
        # Specialized MMA warp
        if warp_idx == self.mma_warp_id:
            cute.arch.warpgroup_reg_dealloc(self.num_regs_mma_warp)
            tCtAcc_base = accumulators
            # Persistent tile scheduling loop
            tile_info_consumer_state = pipeline.make_pipeline_state(
                pipeline.PipelineUserType.Consumer, self.num_tile_info_stage
            )
            tile_info_pipeline.consumer_wait(tile_info_consumer_state)
            work_tile = self.make_work_tile_info(
                sTile_info[(None, tile_info_consumer_state.index)]
            )
            cute.arch.fence_proxy(
                cute.arch.ProxyKind.async_shared,
                space=cute.arch.SharedSpace.shared_cta,
            )
            tile_info_pipeline.consumer_release(tile_info_consumer_state)
            tile_info_consumer_state.advance()
            trans2mma_consumer_state = pipeline.make_pipeline_state(
                pipeline.PipelineUserType.Consumer, self.num_trans2mma_stage
            )
            b_load2mma_consumer_state = pipeline.make_pipeline_state(
                pipeline.PipelineUserType.Consumer, self.num_load2trans_stage
            )
            acc_producer_state = pipeline.make_pipeline_state(
                pipeline.PipelineUserType.Producer, self.num_acc_stage
            )
            while work_tile.is_valid_tile:
                # (MMA, MMA_M, MMA_N)
                tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)]
                b_load2mma_consumer_state.reset_count()
                trans2mma_consumer_state.reset_count()
                peek_trans2mma_full_status = cutlass.Boolean(1)
                if is_leader_cta:
                    if trans2mma_consumer_state.count < k_tile_cnt:
                        peek_trans2mma_full_status = (
                            trans2mma_pipeline.consumer_try_wait(
                                trans2mma_consumer_state
                            )
                        )
                    acc_pipeline.producer_acquire(acc_producer_state)
                    tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
                    # MMA mainloop
                    for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):
                        trans2mma_pipeline.consumer_wait(
                            trans2mma_consumer_state, peek_trans2mma_full_status
                        )
                        b_load2mma_pipeline.consumer_wait(b_load2mma_consumer_state)
                        num_kblocks = cute.size(tCrA, mode=[2])
                        for kblock_idx in cutlass.range(num_kblocks, unroll_full=True):
                            kblock_coord_a = (
                                None,
                                None,
                                kblock_idx,
                                trans2mma_consumer_state.index,
                            )
                            kblock_coord_b = (
                                None,
                                None,
                                kblock_idx,
                                b_load2mma_consumer_state.index,
                            )
                            cute.gemm(
                                tiled_mma,
                                tCtAcc,
                                tCrA[kblock_coord_a],
                                tCrB[kblock_coord_b],
                                tCtAcc,
                            )
                            # Enable accumulation on tCtAcc after the first kblock
                            tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
                        trans2mma_pipeline.consumer_release(trans2mma_consumer_state)
                        b_load2mma_pipeline.consumer_release(b_load2mma_consumer_state)
                        trans2mma_consumer_state.advance()
                        b_load2mma_consumer_state.advance()
                        peek_trans2mma_full_status = cutlass.Boolean(1)
                        if trans2mma_consumer_state.count < k_tile_cnt:
                            peek_trans2mma_full_status = (
                                trans2mma_pipeline.consumer_try_wait(
                                    trans2mma_consumer_state
                                )
                            )
                    # Async arrive accumulator buffer full
                    acc_pipeline.producer_commit(acc_producer_state)
                acc_producer_state.advance()
                # Advance to the next tile
                tile_info_pipeline.consumer_wait(tile_info_consumer_state)
                work_tile = self.make_work_tile_info(
                    sTile_info[(None, tile_info_consumer_state.index)]
                )
                cute.arch.fence_proxy(
                    cute.arch.ProxyKind.async_shared,
                    space=cute.arch.SharedSpace.shared_cta,
                )
                tile_info_pipeline.consumer_release(tile_info_consumer_state)
                tile_info_consumer_state.advance()
            # Wait for the accumulator buffers to be empty
            acc_pipeline.producer_tail(acc_producer_state)
        # Specialized epilogue warps
        if warp_idx < self.mma_warp_id:
            cute.arch.warpgroup_reg_alloc(self.num_regs_epilogue_warps)
            epi_tidx = tidx
            tCtAcc_base = accumulators
            # Partition tensors for the epilogue
            tiled_copy_t2r, tTR_tAcc_base, tTR_rAcc = (
                self.epilog_tmem_copy_and_partition(
                    epi_tidx, tCtAcc_base, tCgC, epi_tile, use_2cta_instrs
                )
            )
            tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype)
            tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition(
                tiled_copy_t2r, tTR_rC, epi_tidx, sC
            )
            (tma_atom_c, bSG_sC, bSG_gC_partitioned, simt_atom, tTR_gC_partitioned) = (
                self.epilog_gmem_copy_and_partition(
                    epi_tidx, tma_atom_c, tiled_copy_t2r, tCgC, tCgC_simt, epi_tile, sC
                )
            )
            # Coordinate (identity) tensor used to build the SIMT-store predicates
            thr_mapping = cute.make_identity_tensor(
                (self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[1])
            )
            thr_mapping_mn = cute.flat_divide(thr_mapping, epi_tile)
            thr_copy_t2r = tiled_copy_t2r.get_slice(epi_tidx)
            m_thr_offset = thr_copy_t2r.partition_D(thr_mapping_mn)
            m_thr_offset = cute.group_modes(m_thr_offset, 3, cute.rank(m_thr_offset))
            acc_consumer_state = pipeline.make_pipeline_state(
                pipeline.PipelineUserType.Consumer, self.num_acc_stage
            )
            c_producer_group = pipeline.CooperativeGroup(
                pipeline.Agent.Thread,
                32 * len(self.epilog_warp_id),
            )
            c_pipeline = pipeline.PipelineTmaStore.create(
                num_stages=self.num_c_stage,
                producer_group=c_producer_group,
            )
            # Persistent tile scheduling loop
            tile_info_consumer_state = pipeline.make_pipeline_state(
                pipeline.PipelineUserType.Consumer, self.num_tile_info_stage
            )
            tile_info_pipeline.consumer_wait(tile_info_consumer_state)
            work_tile = self.make_work_tile_info(
                sTile_info[(None, tile_info_consumer_state.index)]
            )
            cute.arch.fence_proxy(
                cute.arch.ProxyKind.async_shared,
                space=cute.arch.SharedSpace.shared_cta,
            )
            tile_info_pipeline.consumer_release(tile_info_consumer_state)
            tile_info_consumer_state.advance()
            num_prev_subtiles = cutlass.Int32(0)
            while work_tile.is_valid_tile:
                bSG_gC = bSG_gC_partitioned[
                    (
                        None,
                        None,
                        None,
                        work_tile.cta_coord_m // cute.size(tiled_mma.thr_id.shape),
                        0,
                        0,
                    )
                ]
                tma_store_offset_coord = (
                    (work_tile.coord_n, 0, 0)
                    if cutlass.const_expr(self.c_layout.is_n_major_c())
                    else (0, work_tile.coord_n, 0)
                )
                bSG_gC = cute.make_tensor(
                    (
                        tma_store_offset_coord[0] + bSG_gC.iterator[0],
                        tma_store_offset_coord[1] + bSG_gC.iterator[1],
                        tma_store_offset_coord[2] + bSG_gC.iterator[2],
                    ),
                    bSG_gC.layout,
                )
                tTR_gC = tTR_gC_partitioned[
                    (
                        None,
                        None,
                        None,
                        None,
                        None,
                        work_tile.cta_coord_m // cute.size(tiled_mma.thr_id.shape),
                        0,
                        0,
                    )
                ]
                tTR_gC = cute.make_tensor(
                    tTR_gC.iterator + (work_tile.coord_n * tensor_c.layout.stride[1]),
                    tTR_gC.layout,
                )
                tTR_tAcc = tTR_tAcc_base[
                    (None, None, None, None, None, acc_consumer_state.index)
                ]
                # Wait for the accumulator buffer to be full
                acc_pipeline.consumer_wait(acc_consumer_state)
                tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc))
                bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC))
                tTR_gC = cute.group_modes(tTR_gC, 3, cute.rank(tTR_gC))
                # Store the accumulator to global memory in subtiles
                subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3])
                for subtile_idx in cutlass.range(subtile_cnt):
                    # Load the accumulator subtile from tensor memory into registers
                    tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)]
                    cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc)
                    if work_tile.distance_to_boundary >= self.cta_tile_shape_mnk[1]:
                        # Full tile along N: convert to the C type and store it
                        # through shared memory with a TMA store
                        acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load()
                        acc_vec = acc_vec.to(self.c_dtype)
                        tRS_rC.store(acc_vec)
                        num_prev_subtiles += 1
                        c_buffer = num_prev_subtiles % self.num_c_stage
                        # Store C to shared memory
                        cute.copy(
                            tiled_copy_r2s,
                            tRS_rC,
                            tRS_sC[(None, None, None, c_buffer)],
                        )
                        # Fence and barrier to make sure the shared memory store is
                        # visible to the TMA store
                        cute.arch.fence_proxy(
                            cute.arch.ProxyKind.async_shared,
                            space=cute.arch.SharedSpace.shared_cta,
                        )
                        self.epilog_sync_barrier.arrive_and_wait()
                        # TMA store C to global memory
                        if warp_idx == self.epilog_warp_id[0]:
                            cute.copy(
                                tma_atom_c,
                                bSG_sC[(None, c_buffer)],
                                bSG_gC[(None, subtile_idx)],
                            )
                            c_pipeline.producer_commit()
                            c_pipeline.producer_acquire()
                        self.epilog_sync_barrier.arrive_and_wait()
                    else:
                        # Partial tile along N: convert to the C type and use a
                        # predicated SIMT store
                        acc_vec = tTR_rAcc.load()
                        acc_vec = acc_vec.to(self.c_dtype)
                        tTR_rC.store(acc_vec)
                        # Compute the predicate for the SIMT store
                        tCpC = cute.make_rmem_tensor(
                            cute.make_layout(tTR_rC.shape),
                            cutlass.Boolean,
                        )
                        m_thr_slice = m_thr_offset[(None, None, None, subtile_idx)]
                        for i in cutlass.range(cute.size(tCpC), unroll_full=True):
                            tCpC[i] = (
                                m_thr_slice[(i)][0]
                                + work_tile.cta_coord_m * self.cta_tile_shape_mnk[0]
                                < tensor_c.shape[0]
                            ) and (m_thr_slice[(i)][1] < work_tile.distance_to_boundary)
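                        # Illustrative reading of the predicate (hypothetical sizes): an
                        # element at CTA-tile coordinate (m, n) is stored only if
                        #   m + cta_coord_m * cta_tile_shape_mnk[0] < M   (row inside C)
                        #   n < distance_to_boundary                      (column inside group)
                        # e.g. with cta_tile_shape_mnk[1] = 128 and distance_to_boundary = 70,
                        # columns 0..69 are written and columns 70..127 are masked off.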
                        # Store C to global memory
                        cute.copy(
                            simt_atom,
                            cute.flatten(tTR_rC),
                            cute.flatten(tTR_gC[(None, None, None, subtile_idx)]),
                            pred=cute.flatten(tCpC),
                        )
                # Async arrive accumulator buffer empty
                with cute.arch.elect_one():
                    acc_pipeline.consumer_release(acc_consumer_state)
                acc_consumer_state.advance()
                # Advance to the next tile
                tile_info_pipeline.consumer_wait(tile_info_consumer_state)
                work_tile = self.make_work_tile_info(
                    sTile_info[(None, tile_info_consumer_state.index)]
                )
                cute.arch.fence_proxy(
                    cute.arch.ProxyKind.async_shared,
                    space=cute.arch.SharedSpace.shared_cta,
                )
                tile_info_pipeline.consumer_release(tile_info_consumer_state)
                tile_info_consumer_state.advance()
            # Deallocate the tensor memory buffer
            tmem.relinquish_alloc_permit()
            self.epilog_sync_barrier.arrive_and_wait()
            tmem.free(tmem_ptr)
            c_pipeline.producer_tail()

    @cute.jit
    def group_search(
        self,
        group_count: cutlass.Int32,
        linear_idx: cutlass.Int32,
        search_state: ContiguousGGSearchState,
        cumsum: cute.Tensor,
        search_mode: int,
    ) -> ContiguousGGSearchState:
        """
        Group search for contiguously grouped GEMM.

        Starting from ``search_state``, walk forward through the groups until the
        one containing cluster tile ``linear_idx`` along mode ``search_mode``
        (0 for M, 1 for N) is found. Group ``g`` (1-based, ``g <= group_count``)
        ends at ``cumsum[g]`` and contributes
        ``ceil_div(cumsum[g] - previous_boundary, cluster_tile_shape_mnk[search_mode])``
        tiles. Because the search resumes from the previous state, successive calls
        must pass non-decreasing ``linear_idx`` values. If ``linear_idx`` lies past
        the last group's tiles, the returned ``cur_group_idx`` is ``group_count + 1``.
        """
        not_found = linear_idx >= search_state.cur_tile_count
        next_boundary = cutlass.Int32(0)
        cur_group_idx = search_state.cur_group_idx
        cur_offset = search_state.cur_offset
        last_tile_count = search_state.last_tile_count
        cur_boundary = search_state.cur_boundary
        cur_tile_count = search_state.cur_tile_count
        if not_found:
            cur_group_idx = cur_group_idx + 1
        while not_found and cur_group_idx <= group_count:
            next_boundary = cumsum[cur_group_idx]
            # Number of cluster tiles in this group along the searched mode
            num_blocks = cute.ceil_div(
                (next_boundary - cur_boundary),
                self.cluster_tile_shape_mnk[search_mode],
            )
            next_tile_count = num_blocks + cur_tile_count
            not_found = linear_idx >= next_tile_count
            last_tile_count = cur_tile_count
            cur_offset = cur_boundary
            cur_boundary = next_boundary
            cur_tile_count = next_tile_count
            if not_found:
                cur_group_idx = cur_group_idx + 1
        cur_start = cur_offset + self.cluster_tile_shape_mnk[search_mode] * (
            linear_idx - last_tile_count
        )
        return ContiguousGGSearchState(
            last_tile_count,
            cur_boundary,
            cur_tile_count,
            cur_group_idx,
            cur_offset,
            cur_start,
        )
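def make_work_tile_info(self, sTile_info: cute.Tensor):
"""
Copy one stage of tile info (four Int32 values) from shared memory into
registers and wrap it in a GroupedWorkTileInfo.
"""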
tile_info = cute.make_rmem_tensor(sTile_info.shape, sTile_info.element_type)
cute.autovec_copy(sTile_info, tile_info)
return GroupedWorkTileInfo(
self.group_count, tile_info[0], tile_info[1], tile_info[2], tile_info[3]
)
def epilog_gmem_copy_and_partition(
self,
tidx: cutlass.Int32,
tma_atom_c: cute.CopyAtom,
tiled_copy_t2r: cute.TiledCopy,
gC_mnl_tma: cute.Tensor,
gC_mnl_simt: cute.Tensor,
epi_tile: cute.Tile,
sC: cute.Tensor,
) -> tuple[cute.CopyAtom, cute.Tensor, cute.Tensor, cute.CopyAtom, cute.Tensor]:
"""
Partitions source and destination tensors for the global memory store of C.
This method prepares both store paths: for the TMA store it partitions the
shared memory tensor sC and the TMA view of C, and for the SIMT store it
partitions the SIMT view of C for a per-thread register-to-global copy
through a universal copy atom.
:param tidx: The thread index in epilogue warp groups.
:type tidx: cutlass.Int32
:param tma_atom_c: The TMA copy atom.
:type tma_atom_c: cute.CopyAtom
:param tiled_copy_t2r: The tiled copy operation for tmem to register copy.
:type tiled_copy_t2r: cute.TiledCopy
:param gC_mnl_tma: The global tensor C for TMA.
:type gC_mnl_tma: cute.Tensor
:param gC_mnl_simt: The global tensor C for SIMT Copy.
:type gC_mnl_simt: cute.Tensor
:param epi_tile: The epilogue tiler.
:type epi_tile: cute.Tile
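:param sC: The shared memory tensor C.
:type sC: cute.Tensor
:return: A tuple ``(tma_atom_c, bSG_sC, bSG_gC, simt_atom, tTR_gC)``: the TMA
    copy atom with its partitioned shared memory source and global memory
    destination, followed by the SIMT copy atom and the partitioned global
    memory destination for the SIMT store.
:rtype: tuple[cute.CopyAtom, cute.Tensor, cute.Tensor, cute.CopyAtom, cute.Tensor]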
"""
gC_epi_tma = cute.flat_divide(
gC_mnl_tma[((None, None), 0, 0, None, None, None)], epi_tile
)
gC_epi_simt = cute.flat_divide(
gC_mnl_simt[((None, None), 0, 0, None, None, None)], epi_tile
)
# TMA store
sC_for_tma_partition = cute.group_modes(sC, 0, 2)
gC_for_tma_partition = cute.group_modes(gC_epi_tma, 0, 2)
# ((ATOM_V, REST_V), EPI_M, EPI_N)
# ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL)
bSG_sC, bSG_gC = cpasync.tma_partition(
tma_atom_c,
0,
cute.make_layout(1),
sC_for_tma_partition,
gC_for_tma_partition,
)
# SIMT Store
# (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL)
thr_copy_t2r = tiled_copy_t2r.get_slice(tidx)
tTR_gC = thr_copy_t2r.partition_D(gC_epi_simt)
simt_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), self.c_dtype)
return tma_atom_c, bSG_sC, bSG_gC, simt_atom, tTR_gC
def epilog_smem_copy_and_partition(
self,
tiled_copy_t2r: cute.TiledCopy,
tTR_rC: cute.Tensor,
tidx: cutlass.Int32,
sC: cute.Tensor,
) -> tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
"""
Partitions source and destination tensors for a shared memory store.
This method generates a tiled copy for storing results to shared memory
and partitions the source (register) and destination (shared memory)
tensors accordingly.
:param tiled_copy_t2r: The tiled copy operation for tmem to register copy.
:param tTR_rC: The partitioned accumulator tensor.
:param tidx: The thread index in epilogue warp groups.
:param sC: The shared memory tensor to be copied and partitioned.
:return: A tuple containing the tiled copy for the store operation and
the partitioned source and destination tensors.
"""
copy_atom_r2s = sm100_utils.get_smem_store_op(
self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r
)
tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r)
# (R2S, R2S_M, R2S_N, PIPE_D)
thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
tRS_sC = thr_copy_r2s.partition_D(sC)
# (R2S, R2S_M, R2S_N)
tRS_rC = tiled_copy_r2s.retile(tTR_rC)
return tiled_copy_r2s, tRS_rC, tRS_sC
def epilog_tmem_copy_and_partition(
self,
tidx: cutlass.Int32,
tAcc: cute.Tensor,
gC_mnl: cute.Tensor,
epi_tile: cute.Tile,
use_2cta_instrs: Union[cutlass.Boolean, bool],
) -> tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
"""
Partitions source and destination tensors for a tensor memory load.
This method generates a tiled copy for loading accumulators from tensor
memory and partitions the source (tensor memory) and destination
(register) tensors accordingly.
:param tidx: The thread index in epilogue warp groups.
:param tAcc: The accumulator tensor to be copied and partitioned.
:param gC_mnl: The global tensor C.
:param epi_tile: The epilogue tiler.
:param use_2cta_instrs: Whether 2-CTA tcgen05 MMA instructions are used.
:return: A tuple containing the tiled copy for the load operation and
the partitioned source and destination tensors.
"""
# Make tiledCopy for tensor memory load
copy_atom_t2r = sm100_utils.get_tmem_load_op(
self.cta_tile_shape_mnk,
self.c_layout,
self.c_dtype,
self.acc_dtype,
epi_tile,
use_2cta_instrs,
)
# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, STAGE)
tAcc_epi = cute.flat_divide(
tAcc[((None, None), 0, 0, None)],
epi_tile,
)
# (EPI_TILE_M, EPI_TILE_N)
tiled_copy_t2r = tcgen05.make_tmem_copy(
copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)]
)
thr_copy_t2r = tiled_copy_t2r.get_slice(tidx)
# (T2R, T2R_M, T2R_N, EPI_M, EPI_N, STAGE)
tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi)
# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL)
gC_mnl_epi = cute.flat_divide(
gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile
)
# (T2R, T2R_M, T2R_N, EPI_M, EPI_N, loopM, loopN, loopL)
tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi)
# (T2R, T2R_M, T2R_N)
tTR_rAcc = cute.make_rmem_tensor(
tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype
)
return tiled_copy_t2r, tTR_tAcc, tTR_rAcc
@staticmethod
def align_up(x: int, align: int) -> int:
"""Align x up to the nearest multiple of align."""
return (x + align - 1) // align * align
@staticmethod
def _compute_stages_and_tmem_cols(
tiled_mma: cute.TiledMma,
mma_tiler_mnk: tuple[int, int, int],
cta_tile_shape_mnk: tuple[int, int, int],
epi_tile: cute.Tile,
a_dtype: type[cutlass.Numeric],
b_dtype: type[cutlass.Numeric],
c_dtype: type[cutlass.Numeric],
c_layout: utils.LayoutEnum,
transform_a_source: tcgen05.OperandSource,
scale_granularity_m: int,
scale_granularity_k: int,
smem_buffer_align_bytes: int,
scale_mode: TransformMode,
) -> tuple[int, int, int, int, int, int, int, int]:
"""
Compute pipeline stages and TMEM column allocation configurations.
This method calculates the number of pipeline stages for different operations
(tile_info, load2trans, trans2mma, accumulator, etc.) and determines TMEM column allocation
based on available memory resources and tile configuration.
:param tiled_mma: The tiled MMA object defining the core computation.
:type tiled_mma: cute.TiledMma
:param mma_tiler_mnk: The shape (M, N, K) of the MMA tiler.
:type mma_tiler_mnk: tuple[int, int, int]
:param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile.
:type cta_tile_shape_mnk: tuple[int, int, int]
:param epi_tile: The epilogue tile shape.
:type epi_tile: cute.Tile
:param a_dtype: Data type of operand A.
:type a_dtype: type[cutlass.Numeric]
:param b_dtype: Data type of operand B.
:type b_dtype: type[cutlass.Numeric]
:param c_dtype: Data type of operand C.
:type c_dtype: type[cutlass.Numeric]
:param c_layout: Layout enum of operand C.
:type c_layout: utils.LayoutEnum
:param transform_a_source: The source of the transformed A tensor.
:type transform_a_source: tcgen05.OperandSource
:param scale_granularity_m: The granularity of the scale tensor along the M mode.
:type scale_granularity_m: int
:param scale_granularity_k: The granularity of the scale tensor along the K mode.
:type scale_granularity_k: int
:param smem_buffer_align_bytes: The alignment of the shared memory buffer.
:type smem_buffer_align_bytes: int
:param scale_mode: The transform mode.
:type scale_mode: TransformMode
:return: A tuple of eight values: the stage counts for load2trans,
    scale_load2trans, transform2mma, accumulator, C, and tile_info, followed
    by the TMEM column counts for the accumulator and the transformed A tensor.
:rtype: tuple[int, int, int, int, int, int, int, int]
- num_load2trans_stage: Stages for load-to-transform A and B tensors pipeline
- num_scale_load2trans_stage: Stages for scale load-to-transform A tensor pipeline
- num_trans2mma_stage: Stages for transform-to-MMA pipeline
- num_acc_stage: Stages for accumulator-to-epilogue pipeline
- num_c_stage: Stages for epilogue-to-output C pipeline
- num_tile_info_stage: Stages for buffers storing tile info
- num_acc_tmem_cols: TMEM columns for accumulator
- num_a_tmem_cols: TMEM columns for transformed A tensor
"""
# Compute tmem columns required for accumulator
acc_shape = tiled_mma.partition_shape_C(mma_tiler_mnk[:2])
tCtAcc_stage1 = tiled_mma.make_fragment_C(cute.append(acc_shape, 1))
num_tmem_acc_col_per_stage = utils.get_num_tmem_alloc_cols(tCtAcc_stage1, True)
# Heuristic to decide the number of stages for accumulator
sm100_tmem_columns = cute.arch.SM100_TMEM_CAPACITY_COLUMNS
accumulator_stage_count = sm100_tmem_columns // num_tmem_acc_col_per_stage
if transform_a_source == tcgen05.OperandSource.TMEM:
if num_tmem_acc_col_per_stage < 128:
accumulator_stage_count = 3
elif num_tmem_acc_col_per_stage < 256:
accumulator_stage_count = 2
else:
accumulator_stage_count = 1
# Each TMEM column is 32 bits wide, so it holds 32 // width elements of the
# transformed A (2 elements for a 16-bit type)
num_elts_per_tmem_col = 32 // tiled_mma.op.a_dtype.width
num_tmem_cols_a_per_stage = GroupedMixedInputGemmKernel.align_up(
(
cta_tile_shape_mnk[2] // num_elts_per_tmem_col
if transform_a_source == tcgen05.OperandSource.TMEM
else 0
),
4,
)
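# Worked example (illustrative; assumes a 1-CTA 128x128x128 CTA tile, a Float32
# accumulator, a BFloat16 transformed A sourced from TMEM, 128 accumulator TMEM
# columns per stage, and 512 TMEM columns in total):
#   accumulator:   128 columns per stage; 128 < 256, so 2 stages -> 256 columns
#   transformed A: 32 // 16 = 2 elements per column -> 128 // 2 = 64 columns
#                  per stage (already a multiple of 4)
#   leftover TMEM: (512 - 256) // 64 = 4 transform2mma stages at most, before
#                  the SMEM limit computed below is applied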
bytes_per_pipeline_stage = 16
# By default, we use 2 stages for tile info
num_tile_info_stage = 2
tile_info_bytes = (
cute.size_in_bytes(cute.Int32, cute.make_layout((4, num_tile_info_stage)))
+ bytes_per_pipeline_stage * num_tile_info_stage
)
c_stage_count = 2
c_smem_layout_staged_one = sm100_utils.make_smem_layout_epi(
c_dtype,
c_layout,
epi_tile,
1,
)
c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one)
c_bytes = c_bytes_per_stage * c_stage_count
smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
if scale_mode == TransformMode.ConvertOnly:
scale_load2trans_stage_count = 0
a_scale_bytes_per_stage = 0
else:
# Ensure we have 4 buffers for scale tiles needed for 1 CTA tile
a_scale_k_mode = max(cta_tile_shape_mnk[2] // scale_granularity_k, 1)
a_scale_m_mode = max(cta_tile_shape_mnk[0] // scale_granularity_m, 1)
scale_load2trans_stage_count = 4
a_scale_bytes_per_stage = GroupedMixedInputGemmKernel.align_up(
cute.size_in_bytes(
tiled_mma.op.a_dtype,
cute.make_layout((a_scale_m_mode, a_scale_k_mode)),
),
smem_buffer_align_bytes,
)
a_scale_bytes = (
a_scale_bytes_per_stage + bytes_per_pipeline_stage
) * scale_load2trans_stage_count
caveout_smem_bytes = (
bytes_per_pipeline_stage * accumulator_stage_count
+ a_scale_bytes
+ c_bytes
+ tile_info_bytes
)
# Compute transform stages if A is in TMEM
num_tmem_acc_cols = GroupedMixedInputGemmKernel.align_up(
accumulator_stage_count * num_tmem_acc_col_per_stage, 4
)
transform2mma_stage_count_a_source_tmem_potential = (
(sm100_tmem_columns - num_tmem_acc_cols) // num_tmem_cols_a_per_stage
if transform_a_source == tcgen05.OperandSource.TMEM
else -1
)
if (
transform_a_source == tcgen05.OperandSource.TMEM
and transform2mma_stage_count_a_source_tmem_potential <= 0
):
raise ValueError("Not enough TMEM capacity for selected tile size")
a_load_bytes_per_stage = GroupedMixedInputGemmKernel.align_up(
cute.size_in_bytes(
a_dtype,
cute.make_layout((cta_tile_shape_mnk[0], cta_tile_shape_mnk[2])),
),
smem_buffer_align_bytes,
)
b_load_bytes_per_stage = GroupedMixedInputGemmKernel.align_up(
cute.size_in_bytes(
b_dtype,
cute.make_layout(
(
cta_tile_shape_mnk[1] // cute.size(tiled_mma.thr_id),
cta_tile_shape_mnk[2],
)
),
),
smem_buffer_align_bytes,
)
ab_load_bytes_per_stage = (
a_load_bytes_per_stage
+ b_load_bytes_per_stage
+ 2 * bytes_per_pipeline_stage
)
a_transform_bytes_per_stage = (
GroupedMixedInputGemmKernel.align_up(
cute.size_in_bytes(
tiled_mma.op.a_dtype,
cute.make_layout((cta_tile_shape_mnk[0], cta_tile_shape_mnk[2])),
),
smem_buffer_align_bytes,
)
if transform_a_source == tcgen05.OperandSource.SMEM
else 0
)
a_transform_bytes_per_stage = (
a_transform_bytes_per_stage + bytes_per_pipeline_stage
)
transform2mma_stage_count_a_source_smem_potential = (
smem_capacity - caveout_smem_bytes
) // (ab_load_bytes_per_stage + a_transform_bytes_per_stage)
transform2mma_stage_count = (
min(
transform2mma_stage_count_a_source_tmem_potential,
transform2mma_stage_count_a_source_smem_potential,
)
if transform_a_source == tcgen05.OperandSource.TMEM
else transform2mma_stage_count_a_source_smem_potential
)
load2transform_stage_count = (
smem_capacity
- caveout_smem_bytes
- (transform2mma_stage_count * a_transform_bytes_per_stage)
) // ab_load_bytes_per_stage
if (
load2transform_stage_count < 2
or transform2mma_stage_count < 2
or accumulator_stage_count < 1
):
raise ValueError("Not enough SMEM or TMEM capacity for selected tile size")
num_tmem_a_cols = transform2mma_stage_count * num_tmem_cols_a_per_stage
# Check if we can increase c_stage_count with leftover smem
c_stage_count += (
smem_capacity
- load2transform_stage_count * ab_load_bytes_per_stage
- transform2mma_stage_count * a_transform_bytes_per_stage
- scale_load2trans_stage_count * a_scale_bytes_per_stage
- c_bytes
) // c_bytes_per_stage
return (
load2transform_stage_count,
scale_load2trans_stage_count,
transform2mma_stage_count,
accumulator_stage_count,
c_stage_count,
num_tile_info_stage,
num_tmem_acc_cols,
num_tmem_a_cols,
)
@staticmethod
def _compute_grid(
c: cute.Tensor,
cta_tile_shape_mnk: tuple[int, int, int],
cluster_shape_mn: tuple[int, int],
max_active_clusters: cutlass.Constexpr,
) -> tuple[utils.PersistentTileSchedulerParams, tuple[int, int, int]]:
"""
Compute the persistent tile scheduler parameters and the launch grid for the
output tensor C. The grid is (cluster_shape_m, cluster_shape_n, max_active_clusters).
"""
c_shape = cute.slice_(cta_tile_shape_mnk, (None, None, 0))
gc = cute.zipped_divide(c, tiler=c_shape)
num_ctas_mnl = gc[(0, (None, None, None))].shape
cluster_shape_mnl = (*cluster_shape_mn, 1)
tile_sched_params = utils.PersistentTileSchedulerParams(
num_ctas_mnl, cluster_shape_mnl
)
grid = (cluster_shape_mn[0], cluster_shape_mn[1], max_active_clusters)
return tile_sched_params, grid
def is_valid_tensor_alignment(
m: int,
n: int,
k: int,
a_dtype: type[cutlass.Numeric],
b_dtype: type[cutlass.Numeric],
c_dtype: type[cutlass.Numeric],
scale_dtype: type[cutlass.Numeric],
a_major: str,
b_major: str,
c_major: str,
mma_tiler_mnk: tuple[int, int, int],
use_2cta_instrs: bool,
cluster_shape_mn: tuple[int, int],
scale_granularity_m: int,
scale_granularity_k: int,
) -> bool:
"""
Check if the tensor alignments are valid for the given problem size and data types.
"""
def check_contiguous_16B_alignment(dtype, is_mode0_major, tensor_shape):
major_mode_idx = 0 if is_mode0_major else 1
num_major_elements = tensor_shape[major_mode_idx]
num_contiguous_elements = 16 * 8 // dtype.width
return num_major_elements % num_contiguous_elements == 0
if not (
check_contiguous_16B_alignment(a_dtype, a_major == "m", (m, k))
and check_contiguous_16B_alignment(b_dtype, b_major == "n", (n, k))
and check_contiguous_16B_alignment(c_dtype, c_major == "m", (m, n))
and (
scale_granularity_k == 0
or check_contiguous_16B_alignment(
scale_dtype, True, (m, k // scale_granularity_k)
)
)
):
return False
# Check if scale tensor matches the TMA load 128B alignment requirement
cta_tile_shape_mnk = (
mma_tiler_mnk[0] // (2 if use_2cta_instrs else 1),
mma_tiler_mnk[1],
mma_tiler_mnk[2],
)
if (
scale_granularity_m > 0
and (cta_tile_shape_mnk[0] // cluster_shape_mn[1] // scale_granularity_m)
* (scale_dtype.width // 8)
< 128
):
return False
return True
def is_valid_mma_tiler_and_cluster_shape(
mma_tiler: tuple[int, int, int],
cluster_shape_mn: tuple[int, int],
use_2cta_instrs: bool,
) -> bool:
"""
Check that the cluster M mode is a multiple of the CTA count per MMA (2 with
2-CTA instructions) and that the per-CTA MMA tile M is 64 or 128.
"""
if cluster_shape_mn[0] % (2 if use_2cta_instrs else 1) != 0:
return False
if (mma_tiler[0] // (2 if use_2cta_instrs else 1)) not in [64, 128]:
return False
return True
def can_implement(
mnkl: tuple[int, int, int, int],
a_dtype: type[cutlass.Numeric],
b_dtype: type[cutlass.Numeric],
c_dtype: type[cutlass.Numeric],
a_major: str,
b_major: str,
c_major: str,
scale_granularity_m: int,
scale_granularity_k: int,
mma_tiler: tuple[int, int, int],
cluster_shape_mn: tuple[int, int],
use_2cta_instrs: bool,
) -> bool:
"""
Check if the kernel can be implemented for the given tensor shapes and data types.
"""
m, n, k, l = mnkl
if not GroupedMixedInputGemmKernel.is_valid_mma_tiler_and_cluster_shape(
mma_tiler, cluster_shape_mn, use_2cta_instrs
):
return False
if not mixed_input_utils.is_valid_scale_granularity(
scale_granularity_m, scale_granularity_k, a_dtype, k, mma_tiler[2]
):
return False
if not GroupedMixedInputGemmKernel.is_valid_tensor_alignment(
m,
n,
k,
a_dtype,
b_dtype,
c_dtype,
b_dtype,
a_major,
b_major,
c_major,
mma_tiler,
use_2cta_instrs,
cluster_shape_mn,
scale_granularity_m,
scale_granularity_k,
):
return False
return True
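# Illustrative sketch (not part of the kernel): the 16-byte alignment rule and
# the MMA tiler rule checked above, restated in plain Python on bit widths
# instead of cutlass.Numeric types. The helper names below are hypothetical;
# the 16-byte and {64, 128} constants are taken from the checks above.
def _sketch_min_contiguous_elements(width_bits: int) -> int:
    # Elements per 16 contiguous bytes: 8 for 16-bit types, 32 for 4-bit types
    return 16 * 8 // width_bits


def _sketch_is_16b_aligned(major_extent: int, width_bits: int) -> bool:
    # The major (contiguous) mode must be a whole number of 16-byte chunks
    return major_extent % _sketch_min_contiguous_elements(width_bits) == 0


def _sketch_valid_tiler(mma_m: int, cluster_m: int, use_2cta_instrs: bool) -> bool:
    # With 2-CTA instructions the cluster M mode must hold whole CTA pairs,
    # and each CTA's share of the MMA tile M must be 64 or 128
    ctas_per_mma = 2 if use_2cta_instrs else 1
    return cluster_m % ctas_per_mma == 0 and (mma_m // ctas_per_mma) in (64, 128)


# Example: an Int4 A with an M-major extent of 4096 is aligned (4096 % 32 == 0),
# while a BFloat16 B with a K-major extent of 100 is not (100 % 8 != 0).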
def create_cumsum_tensor(
num_groups: int,
fused_n: int,
alignment: int,
uniform_distribution: bool = False,
) -> tuple[cute.Tensor, torch.Tensor]:
"""
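Create an Int32 tensor of shape (num_groups + 1,) holding the prefix sums of
the group sizes along the fused N mode: entry 0 is 0 and entry g is the total
size of groups 0..g-1. Every group size is a multiple of ``alignment``.
Returns the cute tensor and the matching CPU torch tensor.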
"""
assert fused_n % alignment == 0, "fused_n must be divisible by alignment"
if uniform_distribution:
# keep a uniform distribution for debug and performance collection
group_counts = torch.tensor([fused_n // num_groups] * num_groups)
else:
# Assign each alignment-sized chunk of fused_n to a group chosen uniformly at random
probs = torch.ones(num_groups) / num_groups
group_sizes = torch.multinomial(probs, fused_n // alignment, replacement=True)
group_counts = torch.bincount(group_sizes, minlength=num_groups) * alignment
print(group_counts.tolist())
# Create cumulative sum
cumsum_torch = torch.cat([torch.tensor([0]), group_counts.cumsum(0)])
print(cumsum_torch.tolist())
cumsum_tensor, _ = cutlass_torch.cute_tensor_like(
cumsum_torch, cutlass.Int32, is_dynamic_layout=False
)
return cumsum_tensor, cumsum_torch.to("cpu")
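# Illustrative sketch (not part of the kernel): a host-side restatement of the
# linear scan that GroupedMixedInputGemmKernel.group_search performs over the
# cumsum tensor built above. Each group needs ceil(group_size / tile_size)
# tiles along the fused N mode, and a linear tile index maps to the group whose
# tile range contains it. `_sketch_group_search` is a hypothetical name. The
# 0-based group index and the third return value (the extent left before the
# group boundary, used here as a stand-in for the kernel's
# distance_to_boundary) are assumptions about how the kernel's state is read.
def _sketch_group_search(cumsum, tile_size, linear_idx):
    # Returns (group, tile_start, distance_to_boundary), or None when
    # linear_idx is past the last tile of the last group
    tiles_before = 0
    for g in range(1, len(cumsum)):
        group_size = cumsum[g] - cumsum[g - 1]
        num_tiles = -(-group_size // tile_size)  # ceil_div; 0 for empty groups
        if linear_idx < tiles_before + num_tiles:
            tile_start = cumsum[g - 1] + tile_size * (linear_idx - tiles_before)
            return g - 1, tile_start, cumsum[g] - tile_start
        tiles_before += num_tiles
    return None


# Example: with cumsum [0, 64, 64, 200] and 128-wide tiles, group 0 needs one
# tile, group 1 is empty, and group 2 needs two tiles; linear tile 2 starts at
# column 192 with only 8 valid columns before the group boundary.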
def create_i4_tensor_and_scale(
l: int,
m: int,
k: int,
is_m_major: bool,
dtype: type[cutlass.Numeric],
scale_granularity_m: int,
scale_granularity_k: int,
is_dynamic_layout: bool = True,
init_config: tuple = (
cutlass_torch.TensorInitType.RANDOM,
cutlass_torch.RandomInitConfig(min_val=-7, max_val=6),
),
divisibility: int = 16,
transformed_dtype: Optional[type[cutlass.Numeric]] = None,
) -> tuple[
cute.Tensor,
torch.Tensor,
torch.Tensor,
cute.Tensor,
torch.Tensor,
torch.Tensor,
]:
"""
Create a quantized 4-bit tensor A and its scale tensor, with one scale per row
and per block of scale_granularity_k elements along K.
"""
lb_4b = -8 if dtype == cutlass.Int4 else 0
up_4b = 7 if dtype == cutlass.Int4 else 15
if not (
init_config[0] == cutlass_torch.TensorInitType.RANDOM
or init_config[0] == cutlass_torch.TensorInitType.SCALAR
):
raise ValueError(
"Only random and scalar initialization is supported for 4bit data type"
)
# Construct reference tensor in f32
ref_fp32 = cutlass_torch.matrix(l, m, k, is_m_major, cutlass.Float32, *init_config)
# Generate scale data and perform quantization
num_scales = k // scale_granularity_k
ref = ref_fp32.to(dtype=cutlass_torch.dtype(transformed_dtype)).reshape(
m, num_scales, scale_granularity_k, l
)
# Per-element scale that maps each value into the 4-bit range; the block
# scale below is the maximum of these over each K block
a_max = (
torch.maximum(ref / up_4b, ref / lb_4b)
if dtype == cutlass.Int4
else ref / up_4b
)
a_scales, _ = torch.max(a_max, dim=2, keepdim=True)
a_scale_inv = torch.where(a_scales == 0, 0, 1 / a_scales)
a_quant = ref * a_scale_inv
# Truncate to integers so the quantized values are exact in 4 bits
a_quant = a_quant.to(dtype=torch.int32).reshape((m, k, l)).to(dtype=torch.float32)
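# Overwrite the computed scales with random integers in [-3, 2]
# (random_(from, to) samples from [from, to - 1]), shaped (m, num_scales, l)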
a_scales = a_scales.random_(-3, 3).reshape((m, num_scales, l))
# Scale tensor is always m-major
a_scales = a_scales.permute(2, 1, 0).contiguous().permute(2, 1, 0).to(device="cuda")
# Construct A quantized tensor
cute_a_quant_tensor, torch_a_quant_tensor = cutlass_torch.cute_tensor_like(
a_quant,
dtype,
is_dynamic_layout=is_dynamic_layout,
assumed_align=divisibility,
)
cute_scale_tensor = from_dlpack(a_scales, assumed_align=divisibility)
for i, stride in enumerate(a_scales.stride()):
if stride == 1:
leading_dim = i
break
if is_dynamic_layout:
cute_scale_tensor = cute_scale_tensor.mark_layout_dynamic(
leading_dim=leading_dim
)
return (
cute_a_quant_tensor,
torch_a_quant_tensor,
a_quant.to("cpu"),
cute_scale_tensor,
a_scales,
a_scales.to("cpu"),
)
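# Illustrative sketch (not part of the example's data path): the per-block
# symmetric Int4 quantization that create_i4_tensor_and_scale computes before
# it overwrites the scales with random values, written for a single K block in
# plain Python. The helper names below are hypothetical.
def _sketch_quantize_int4_block(values):
    # Scale needed per element to land inside [-8, 7]; the block scale is
    # the largest of these, matching torch.maximum(ref / 7, ref / -8)
    scale = max(max(v / 7, v / -8) for v in values)
    scale_inv = 0.0 if scale == 0 else 1.0 / scale
    # int() truncates toward zero, like .to(dtype=torch.int32)
    quant = [int(v * scale_inv) for v in values]
    return scale, quant


def _sketch_dequantize_block(scale, quant):
    # The reference in compare() multiplies each quantized value by its scale
    return [q * scale for q in quant]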
def create_tensor_a(
l: int,
m: int,
k: int,
a_major: str,
a_dtype: type[cutlass.Numeric],
scale_granularity_m: int = 0,
scale_granularity_k: int = 0,
transformed_dtype: Optional[type[cutlass.Numeric]] = None,
) -> tuple[cute.Tensor, Optional[cute.Tensor], torch.Tensor, Optional[torch.Tensor]]:
"""
Create tensor A and scale tensor.
"""
a_scale_tensor = None
a_scale_torch_cpu = None
if a_dtype in (cutlass.Int4,):
(
a_tensor,
a_torch_gpu,
a_torch_cpu,
a_scale_tensor,
a_scale_torch_gpu,
a_scale_torch_cpu,
) = create_i4_tensor_and_scale(
l,
m,
k,
a_major == "m",
a_dtype,
scale_granularity_m,
scale_granularity_k,
divisibility=mixed_input_utils.get_divisibility(m if a_major == "m" else k),
transformed_dtype=transformed_dtype,
)
else:
a_torch_cpu = cutlass_torch.matrix(
l,
m,
k,
a_major == "m",
a_dtype,
)
a_tensor, _ = cutlass_torch.cute_tensor_like(
a_torch_cpu,
a_dtype,
is_dynamic_layout=True,
assumed_align=mixed_input_utils.get_divisibility(
m if a_major == "m" else k
),
)
return a_tensor, a_scale_tensor, a_torch_cpu, a_scale_torch_cpu
def create_tensors(
l: int,
m: int,
n: int,
k: int,
a_major: str,
b_major: str,
c_major: str,
a_dtype: type[cutlass.Numeric],
b_dtype: type[cutlass.Numeric],
c_dtype: type[cutlass.Numeric],
scale_granularity_m: int = 0,
scale_granularity_k: int = 0,
uniform_group_sizes: bool = False,
) -> tuple:
"""
Create all input and output tensors for the mixed-input GEMM.
"""
torch.manual_seed(2025)
a_tensor, a_scale_tensor, a_torch_cpu, a_scale_torch_cpu = create_tensor_a(
l,
m,
k,
a_major,
a_dtype,
scale_granularity_m,
scale_granularity_k,
b_dtype,
)
# In GROUP mode, l is the number of groups. The group mode is fused into the
# N mode of tensors B and C, so their batch mode is 1.
num_groups = l
fused_n = n * num_groups
b_torch_cpu = cutlass_torch.matrix(
1, # batch=1
fused_n,
k,
b_major == "n",
b_dtype,
cutlass_torch.TensorInitType.RANDOM,
cutlass_torch.RandomInitConfig(min_val=-10, max_val=10),
)
b_tensor, _ = cutlass_torch.cute_tensor_like(
b_torch_cpu,
b_dtype,
is_dynamic_layout=True,
assumed_align=mixed_input_utils.get_divisibility(n if b_major == "n" else k),
)
c_torch_cpu = cutlass_torch.matrix(
1, # batch=1
m,
fused_n,
c_major == "m",
c_dtype,
)
c_tensor, c_torch_gpu = cutlass_torch.cute_tensor_like(
c_torch_cpu,
c_dtype,
is_dynamic_layout=True,
assumed_align=mixed_input_utils.get_divisibility(m if c_major == "m" else n),
)
c_tensor = c_tensor.mark_compact_shape_dynamic(
mode=(0 if c_major == "m" else 1),
stride_order=(2, 1, 0) if c_major == "m" else (2, 0, 1),
divisibility=mixed_input_utils.get_divisibility(m if c_major == "m" else n),
)
# We need to ensure mode N satisfies 16B alignment for each group
alignment_n = 16 * 8 // b_dtype.width
cumsum_tensor, cumsum_torch = create_cumsum_tensor(
num_groups, fused_n, alignment_n, uniform_distribution=uniform_group_sizes
)
return (
a_tensor,
a_scale_tensor,
b_tensor,
cumsum_tensor,
c_tensor,
a_torch_cpu,
a_scale_torch_cpu,
b_torch_cpu,
cumsum_torch,
c_torch_gpu,
)
def compare(
a_torch_cpu: torch.Tensor,
b_torch_cpu: torch.Tensor,
a_scale_torch_cpu: Optional[torch.Tensor],
cumsum_torch_cpu: torch.Tensor,
c_torch_gpu: torch.Tensor,
c_dtype: type[cutlass.Numeric],
tolerance: float,
) -> None:
"""
Compare kernel result with reference computation.
"""
kernel_result = c_torch_gpu.cpu()
assert kernel_result.shape[2] == 1, "batch mode must be 1"
kernel_result = kernel_result.reshape(
kernel_result.shape[0], kernel_result.shape[1]
)
# Compute reference result
a_for_gemm = a_torch_cpu
if a_scale_torch_cpu is not None:
scale_shape = a_scale_torch_cpu.shape
a_shape = a_torch_cpu.shape
a_scale_torch_cpu = a_scale_torch_cpu.to(dtype=torch.float32).reshape(
scale_shape[0], scale_shape[1], 1, scale_shape[2]
)
a_torch_cpu = a_torch_cpu.to(dtype=torch.float32).reshape(
a_torch_cpu.shape[0], scale_shape[1], -1, a_torch_cpu.shape[2]
)
a_for_gemm = (a_torch_cpu * a_scale_torch_cpu).reshape(a_shape)
# A is (m, k, l), B is (fused_n, k, 1), and C is (m, fused_n)
assert cumsum_torch_cpu.shape[0] == a_for_gemm.shape[-1] + 1, (
"cumsum tensor must have one more element than the number of groups in A"
)
assert b_torch_cpu.shape[2] == 1, (
"b_torch_cpu must have a singleton dimension in the last position"
)
prev_idx = 0
ref = torch.zeros((a_for_gemm.shape[0], b_torch_cpu.shape[0]), dtype=torch.float32)
for group_idx in range(1, cumsum_torch_cpu.shape[0]):
# Skip groups with no columns
if cumsum_torch_cpu[group_idx] == prev_idx:
continue
# Get A slice for current group
sliced_a = a_for_gemm[:, :, group_idx - 1]
# Get B slice for current group
sliced_b = b_torch_cpu[prev_idx : cumsum_torch_cpu[group_idx], :, 0]
sliced_ref = torch.einsum(
"mk,nk->mn",
sliced_a.to(dtype=torch.float32),
sliced_b.to(dtype=torch.float32),
)
ref[:, prev_idx : cumsum_torch_cpu[group_idx]] = sliced_ref
prev_idx = cumsum_torch_cpu[group_idx]
# Convert ref to c_dtype
_, ref_torch_gpu = cutlass_torch.cute_tensor_like(
ref, c_dtype, is_dynamic_layout=True, assumed_align=16
)
ref_result = ref_torch_gpu.cpu()
# Assert close results
torch.testing.assert_close(kernel_result, ref_result, atol=tolerance, rtol=1e-05)
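# Illustrative sketch (not used by the example): the grouped reference that
# compare() builds with torch.einsum, written with plain Python lists. Group g
# multiplies its own A slice (m x k) by rows cumsum[g]:cumsum[g + 1] of the
# fused B (fused_n x k) and writes columns cumsum[g]:cumsum[g + 1] of C.
# `_sketch_grouped_reference` is a hypothetical helper.
def _sketch_grouped_reference(a_groups, b_fused, cumsum):
    m = len(a_groups[0])
    fused_n = cumsum[-1]
    c = [[0.0] * fused_n for _ in range(m)]
    for g, a in enumerate(a_groups):
        for n in range(cumsum[g], cumsum[g + 1]):  # empty groups do nothing
            for i in range(m):
                # "mk,nk->mn": dot product of row i of A with row n of B
                c[i][n] = sum(x * y for x, y in zip(a[i], b_fused[n]))
    return c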
def run(
mnkl: tuple[int, int, int, int],
scale_granularity_m: int,
scale_granularity_k: int,
a_dtype: type[cutlass.Numeric],
b_dtype: type[cutlass.Numeric],
c_dtype: type[cutlass.Numeric],
acc_dtype: type[cutlass.Numeric],
a_major: str,
b_major: str,
c_major: str,
mma_tiler_mnk: tuple[int, int, int],
cluster_shape_mn: tuple[int, int],
use_2cta_instrs: bool,
tolerance: float,
warmup_iterations: int = 0,
iterations: int = 1,
skip_ref_check: bool = False,
uniform_group_sizes: bool = False,
use_cold_l2: bool = False,
**kwargs,
) -> Optional[float]:
"""
Run the mixed-input GEMM kernel with specified parameters.
This function creates tensors, validates parameters, executes the kernel,
optionally compares results with a reference implementation, and returns
the kernel execution time in microseconds (None when iterations <= 0).
"""
m, n, k, l = mnkl
if not torch.cuda.is_available():
raise ValueError("CUDA is not available")
# Check if given configuration is supported
if not GroupedMixedInputGemmKernel.can_implement(
mnkl,
a_dtype,
b_dtype,
c_dtype,
a_major,
b_major,
c_major,
scale_granularity_m,
scale_granularity_k,
mma_tiler_mnk,
cluster_shape_mn,
use_2cta_instrs,
):
raise ValueError("GEMM configuration not supported")
# Get current CUDA stream from PyTorch
torch_stream = torch.cuda.current_stream()
# Get the raw stream pointer as a CUstream
current_stream = cuda.CUstream(torch_stream.cuda_stream)
group_count = l
mixed_input_gemm = GroupedMixedInputGemmKernel(
scale_granularity_m,
scale_granularity_k,
acc_dtype,
use_2cta_instrs,
mma_tiler_mnk,
cluster_shape_mn,
group_count,
)
(
a_tensor,
a_scale_tensor,
b_tensor,
cumsum_tensor,
c_tensor,
a_torch_cpu,
a_scale_torch_cpu,
b_torch_cpu,
cumsum_torch_cpu,
c_torch_gpu,
) = create_tensors(
l,
m,
n,
k,
a_major,
b_major,
c_major,
a_dtype,
b_dtype,
c_dtype,
scale_granularity_m,
scale_granularity_k,
uniform_group_sizes,
)
max_active_clusters = utils.HardwareInfo().get_max_active_clusters(
cluster_shape_mn[0] * cluster_shape_mn[1],
)
compiled_kernel = cute.compile(
mixed_input_gemm,
a_tensor,
a_scale_tensor,
b_tensor,
cumsum_tensor,
c_tensor,
max_active_clusters,
current_stream,
)
if not skip_ref_check:
compiled_kernel(
a_tensor,
a_scale_tensor,
b_tensor,
cumsum_tensor,
c_tensor,
current_stream,
)
compare(
a_torch_cpu,
b_torch_cpu,
a_scale_torch_cpu,
cumsum_torch_cpu,
c_torch_gpu,
c_dtype,
tolerance,
)
# Early return if no performance measurement is needed
if iterations <= 0:
return
def generate_tensors():
a_tensor, a_scale_tensor, a_torch_cpu, a_scale_torch_cpu = create_tensor_a(
l,
m,
k,
a_major,
a_dtype,
scale_granularity_m,
scale_granularity_k,
b_dtype,
)
num_groups = l
fused_n = n * num_groups
b_torch_cpu = cutlass_torch.matrix(
1,
fused_n,
k,
b_major == "n",
b_dtype,
cutlass_torch.TensorInitType.RANDOM,
cutlass_torch.RandomInitConfig(min_val=-10, max_val=10),
)
b_tensor, _ = cutlass_torch.cute_tensor_like(
b_torch_cpu,
b_dtype,
is_dynamic_layout=True,
assumed_align=mixed_input_utils.get_divisibility(
n if b_major == "n" else k
),
)
c_torch_cpu = cutlass_torch.matrix(1, m, fused_n, c_major == "m", c_dtype)
c_tensor, c_torch_gpu = cutlass_torch.cute_tensor_like(
c_torch_cpu,
c_dtype,
is_dynamic_layout=True,
assumed_align=mixed_input_utils.get_divisibility(
m if c_major == "m" else n
),
)
c_tensor = c_tensor.mark_compact_shape_dynamic(
mode=(0 if c_major == "m" else 1),
stride_order=(2, 1, 0) if c_major == "m" else (2, 0, 1),
divisibility=mixed_input_utils.get_divisibility(m if c_major == "m" else n),
)
alignment_n = 16 * 8 // b_dtype.width
cumsum_tensor, cumsum_torch_cpu = create_cumsum_tensor(
num_groups, fused_n, alignment_n, uniform_distribution=uniform_group_sizes
)
return testing.JitArguments(
a_tensor, a_scale_tensor, b_tensor, cumsum_tensor, c_tensor, current_stream
)
workspace_count = 1
if use_cold_l2:
one_workspace_bytes = (
a_torch_cpu.numel() * a_torch_cpu.element_size()
+ b_torch_cpu.numel() * b_torch_cpu.element_size()
+ c_torch_gpu.numel() * c_torch_gpu.element_size()
+ (
a_scale_torch_cpu.numel() * a_scale_torch_cpu.element_size()
if a_scale_torch_cpu is not None
else 0
)
)
workspace_count = testing.get_workspace_count(
one_workspace_bytes, warmup_iterations, iterations
)
exec_time = testing.benchmark(
compiled_kernel,
workspace_generator=generate_tensors,
workspace_count=workspace_count,
stream=current_stream,
warmup_iterations=warmup_iterations,
iterations=iterations,
)
return exec_time # Return execution time in microseconds
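# Illustrative sketch (not part of the example): run() returns the benchmark
# time in microseconds, so throughput can be derived from the problem shape.
# The cumsum group sizes add up to fused_n = n * l (as in create_tensors), so
# the total work is 2 * m * k * fused_n floating-point operations regardless
# of how the groups are distributed. `_sketch_tflops` is a hypothetical helper.
def _sketch_tflops(m: int, n: int, k: int, l: int, exec_time_us: float) -> float:
    fused_n = n * l
    flops = 2 * m * k * fused_n
    # microseconds -> seconds, then FLOP/s -> TFLOP/s
    return flops / (exec_time_us * 1e-6) / 1e12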
if __name__ == "__main__":
def parse_comma_separated_ints(s: str) -> tuple[int, ...]:
try:
return tuple(int(x.strip()) for x in s.split(","))
except ValueError:
raise argparse.ArgumentTypeError(
"Invalid format. Expected comma-separated integers."
)
parser = argparse.ArgumentParser()
parser.add_argument(
"--mnkl", type=parse_comma_separated_ints, default=(128, 128, 128, 1)
)
parser.add_argument(
"--mma_tiler_mnk", type=parse_comma_separated_ints, default=(128, 128, 128)
)
parser.add_argument(
"--cluster_shape_mn", type=parse_comma_separated_ints, default=(1, 1)
)
parser.add_argument(
"--use_2cta_instrs",
action="store_true",
help="Enable 2CTA MMA instructions feature",
)
parser.add_argument(
"--a_dtype",
type=cutlass.dtype,
default=cutlass.Int4,
choices=[cutlass.Int8, cutlass.Uint8, cutlass.Int4],
)
parser.add_argument(
"--b_dtype",
type=cutlass.dtype,
default=cutlass.BFloat16,
choices=[cutlass.BFloat16, cutlass.Float16],
)
parser.add_argument("--c_dtype", type=cutlass.dtype, default=cutlass.BFloat16)
parser.add_argument("--acc_dtype", type=cutlass.dtype, default=cutlass.Float32)
parser.add_argument("--a_major", choices=["k", "m"], type=str, default="m")
parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k")
parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n")
parser.add_argument(
"--scale_granularity_m",
type=int,
default=1,
help="Scale granularity along M dimension.",
)
parser.add_argument(
"--scale_granularity_k",
type=int,
default=128,
help="Scale granularity along K dimension.",
)
parser.add_argument(
"--tolerance", type=float, default=1e-01, help="Tolerance for validation"
)
parser.add_argument(
"--warmup_iterations", type=int, default=0, help="Warmup iterations"
)
parser.add_argument(
"--iterations",
type=int,
default=1,
help="Number of iterations to run the kernel",
)
parser.add_argument(
"--skip_ref_check", action="store_true", help="Skip reference checking"
)
parser.add_argument(
"--uniform_group_sizes", action="store_true", help="Use uniform group sizes"
)
args = parser.parse_args()
run(
args.mnkl,
args.scale_granularity_m,
args.scale_granularity_k,
args.a_dtype,
args.b_dtype,
args.c_dtype,
args.acc_dtype,
args.a_major,
args.b_major,
args.c_major,
args.mma_tiler_mnk,
args.cluster_shape_mn,
args.use_2cta_instrs,
args.tolerance,
args.warmup_iterations,
args.iterations,
args.skip_ref_check,
args.uniform_group_sizes,
)
print("PASS")