Files
cutlass/examples/python/CuTeDSL/hopper/grouped_gemm.py
2026-04-07 12:16:05 -04:00

2422 lines
92 KiB
Python

# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import argparse
import functools
import os
from typing import List, Optional, Tuple, Type
from inspect import isclass
import math
import cuda.bindings.driver as cuda
import torch
import cutlass
import cutlass.cute as cute
import cutlass.cute.testing as testing
import cutlass.pipeline as pipeline
from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait
import cutlass.utils as utils
import cutlass.utils.hopper_helpers as sm90_utils
import cutlass.torch as cutlass_torch
from cutlass.cutlass_dsl import extract_mlir_values, new_from_mlir_values
from cutlass.cute.core import (
AddressSpace as _CuteAddressSpace,
make_ptr as _cute_make_ptr,
)
from cutlass._mlir.dialects import nvvm as _nvvm_d
from cutlass._mlir.dialects._nvvm_enum_gen import (
CpAsyncBulkTensorLoadMode as _CpAsyncBulkTensorLoadMode,
)
from cutlass.cutlass_dsl import dsl_user_op as _dsl_user_op, T as _T
from cutlass.cute.typing import Int32 as _Int32, Pointer as _Pointer
def _env_flag(name: str, default: bool) -> bool:
    """Read a boolean switch from the environment variable ``name``.

    Truthy spellings (case-insensitive, surrounding whitespace ignored) are
    ``1``, ``true``, ``yes`` and ``on``; any other non-empty value reads as
    ``False``. An unset variable, or one set to an empty/whitespace-only
    string (e.g. ``VAR=``), yields ``default``.

    :param name: Environment variable to read.
    :param default: Value returned when the variable is unset or blank.
    :return: The parsed flag.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    val = raw.strip().lower()
    # `VAR=` is a common way to "clear" a variable in a shell. Treat it as
    # unset instead of silently forcing the switch off, which would flip
    # switches whose default is True.
    if not val:
        return default
    return val in {"1", "true", "yes", "on"}
# Debug switch: force `cute.copy` path for non-mcast loads.
_ENABLE_NVVM_NON_MCAST_LOAD = not _env_flag("GROUPED_GEMM_FORCE_CUTE_COPY", False)
# Experimental switch: enable true SMEM tensor map update/publish path in
# _FixedTensorMapManager for investigation.
_ENABLE_TRUE_SMEM_TMAP = _env_flag("GROUPED_GEMM_ENABLE_TRUE_SMEM_TMAP", False)
# Sub-switches of the true SMEM path. They only matter when the switch above
# is on and the kernel runs in SMEM tensormap update mode.
# PREUPDATE: write the updated descriptors into their SMEM copies first.
_ENABLE_TRUE_SMEM_TMAP_PREUPDATE = _env_flag(
    "GROUPED_GEMM_ENABLE_TRUE_SMEM_TMAP_PREUPDATE", True
)
# PUBLISH: copy the SMEM descriptors to GMEM with a release fence. When off,
# the GMEM descriptors are updated directly instead.
_ENABLE_TRUE_SMEM_TMAP_PUBLISH = _env_flag(
    "GROUPED_GEMM_ENABLE_TRUE_SMEM_TMAP_PUBLISH", True
)
"""
Grouped GEMM (C_g = A_g * B_g for each group g) for the NVIDIA Hopper architecture
using CuTe DSL.

This kernel extends hopper/dense_gemm_persistent.py with per-group TMA tensor map updates
and a group-aware persistent tile scheduler (StaticPersistentGroupTileScheduler).

Key features:
  - WGMMA + TMA + persistent warp-specialized kernel (inherited from dense_gemm_persistent)
  - Per-group A/B/C TMA descriptor updates (tensor map) via GMEM or SMEM mode
  - DMA warp group: loads A/B tiles, updates tensor maps A/B on group boundary
  - MMA warp group: performs WGMMA, updates tensor map C on group boundary, stores C

To run:

.. code-block:: bash

    python hopper/grouped_gemm.py \\
      --num_groups 4 \\
      --problem_sizes_mnkl "(8192,1280,32,1),(16,384,1536,1),(640,1280,16,1),(640,160,16,1)" \\
      --tile_shape_mn 128,256 --cluster_shape_mn 1,1 \\
      --a_dtype Float16 --b_dtype Float16 --c_dtype Float16 --acc_dtype Float32 \\
      --a_major k --b_major k --c_major n \\
      --tensormap_update_mode SMEM

Constraints (same as dense_gemm_persistent.py plus):
  * Only fp16/bf16 inputs are supported for grouped mode
  * l (batch) must be 1 for each group
  * CTA tile M: 64/128, N: 64/128/256
  * Cluster shape M/N: power of 2, total <= 4
  * Contiguous dim must be 16-byte aligned

Debug environment options (switches are on for 1/true/yes/on, case-insensitive):
  * `GROUPED_GEMM_FORCE_CUTE_COPY=1` (default off)
    Disable the non-mcast NVVM TMA load path and always use `cute.copy`.
  * `GROUPED_GEMM_ENABLE_TRUE_SMEM_TMAP=1` (experimental, default off)
    In SMEM tensormap update mode, route per-group descriptor updates through
    their SMEM copies instead of writing the GMEM descriptors directly.
  * `GROUPED_GEMM_ENABLE_TRUE_SMEM_TMAP_PREUPDATE=0` (default on)
    On the true SMEM path, skip writing the updated descriptors into SMEM.
  * `GROUPED_GEMM_ENABLE_TRUE_SMEM_TMAP_PUBLISH=0` (default on)
    On the true SMEM path, update the GMEM descriptors directly instead of
    publishing the SMEM copies with a fenced copy.
"""
@_dsl_user_op
def _tma_load_ab_nvvm_no_mcast(
    k_coord: _Int32,
    m_coord: _Int32,
    n_coord: _Int32,
    desc_a: _Pointer,
    desc_b: _Pointer,
    smem_a: _Pointer,
    smem_b: _Pointer,
    mbar: _Pointer,
    *,
    loc=None,
    ip=None,
) -> None:
    """Emit the A and B TMA tile loads for one K block (no multicast) as NVVM ops.

    A is loaded at coordinates ``(k_coord, m_coord, 0)`` and B at
    ``(k_coord, n_coord, 0)``. Both loads use the same mbarrier ``mbar``.

    Rationale for bypassing ``cute.copy``: the ``elect_sync`` result is
    passed as the op's ``predicate`` instead of wrapping the op in an
    ``scf.IfOp``. All operands (coordinates, descriptors, SMEM destination,
    mbarrier) are therefore computed unconditionally, so PTXAS places any
    R2UR conversions outside the elected region. The ``scf.IfOp`` form lets
    PTXAS sink R2UR into the ``@P0``-predicated block, producing an illegal
    "@P0 R2UR" instruction on sm_90a (CUDA_ERROR_ILLEGAL_INSTRUCTION / 715).
    """
    # The L (batch) coordinate is always 0.
    zero_l = _Int32(0).ir_value(loc=loc, ip=ip)

    # `llvm_ptr` is a property (no call). Materialize every pointer up front,
    # outside of any predicate.
    dst_a, dst_b = smem_a.llvm_ptr, smem_b.llvm_ptr
    barrier = mbar.llvm_ptr
    tmap_a, tmap_b = desc_a.llvm_ptr, desc_b.llvm_ptr

    # (SMEM destination, tensor-map descriptor, row coordinate): A, then B.
    loads = ((dst_a, tmap_a, m_coord), (dst_b, tmap_b, n_coord))
    for dst, tmap, row_coord in loads:
        # One elected lane issues the copy; the election result becomes the
        # op predicate rather than a branch.
        elected = _nvvm_d.elect_sync(_T.bool(), loc=loc, ip=ip)
        _nvvm_d.CpAsyncBulkTensorGlobalToSharedClusterOp(
            dstMem=dst,
            tmaDescriptor=tmap,
            coordinates=[
                k_coord.ir_value(loc=loc, ip=ip),
                row_coord.ir_value(loc=loc, ip=ip),
                zero_l,
            ],
            mbar=barrier,
            im2colOffsets=[],
            predicate=elected,
            loadMode=_CpAsyncBulkTensorLoadMode.TILE,
            loc=loc,
            ip=ip,
        )
class _GroupedWorkTileInfo:
    """Work tile handed out by the grouped tile scheduler.

    Pairs the base scheduler's ``is_valid_tile`` flag with the group search
    result for that tile. It implements the DSL (de)serialization hooks so
    the DSL can rebuild it from IR values (e.g. when it is carried through a
    ``while work_tile.is_valid_tile`` loop).
    """

    def __init__(self, is_valid_tile, group_search_result):
        self._is_valid_tile = is_valid_tile
        self.group_search_result = group_search_result

    @property
    def is_valid_tile(self):
        """Flag telling whether this work tile should be processed."""
        return self._is_valid_tile

    def __extract_mlir_values__(self):
        # Flattened layout: [is_valid_tile values..., group_search_result values...]
        return [
            *extract_mlir_values(self._is_valid_tile),
            *extract_mlir_values(self.group_search_result),
        ]

    def __new_from_mlir_values__(self, values):
        # The length of the leading chunk comes from this instance's own
        # flattened is_valid_tile.
        split = len(extract_mlir_values(self._is_valid_tile))
        return _GroupedWorkTileInfo(
            new_from_mlir_values(self._is_valid_tile, values[:split]),
            new_from_mlir_values(self.group_search_result, values[split:]),
        )
class StaticPersistentGroupTileScheduler:
    """Persistent tile scheduler that also resolves the GEMM group of each tile.

    It composes two cutlass.utils pieces:

    * ``utils.StaticPersistentTileScheduler`` produces work tiles
      (``tile_idx`` / ``is_valid_tile``) and advances between them;
    * ``utils.GroupedGemmTileSchedulerHelper.delinearize_z`` maps a tile's
      z index, together with ``problem_sizes_mnkl``, to a group search result.

    This class is not yet in cutlass.utils 4.3.5, so it is defined locally.
    """

    def __init__(self, tile_sched, group_helper, problem_sizes_mnkl):
        self._tile_sched = tile_sched
        self._group_helper = group_helper
        # Passed through as-is; not part of the flattened MLIR values.
        self._problem_sizes_mnkl = problem_sizes_mnkl

    def __extract_mlir_values__(self):
        # Flattened layout: [tile scheduler values..., group helper values...]
        return [
            *extract_mlir_values(self._tile_sched),
            *extract_mlir_values(self._group_helper),
        ]

    def __new_from_mlir_values__(self, values):
        n_sched = len(extract_mlir_values(self._tile_sched))
        return StaticPersistentGroupTileScheduler(
            new_from_mlir_values(self._tile_sched, values[:n_sched]),
            new_from_mlir_values(self._group_helper, values[n_sched:]),
            self._problem_sizes_mnkl,
        )

    @staticmethod
    def create(
        tile_sched_params,
        bid,
        grid_dim,
        cluster_tile_shape_mnk,
        search_state,
        group_count,
        problem_sizes_mnkl,
    ):
        """Build the base persistent scheduler and the group helper, then wrap them."""
        base_sched = utils.StaticPersistentTileScheduler.create(
            tile_sched_params, bid, grid_dim
        )
        helper = utils.GroupedGemmTileSchedulerHelper(
            group_count, tile_sched_params, cluster_tile_shape_mnk, search_state
        )
        return StaticPersistentGroupTileScheduler(
            base_sched, helper, problem_sizes_mnkl
        )

    @staticmethod
    def get_grid_shape(tile_sched_params, max_active_clusters):
        """Grid shape of the underlying static persistent scheduler."""
        return utils.StaticPersistentTileScheduler.get_grid_shape(
            tile_sched_params, max_active_clusters
        )

    def initial_work_tile_info(self):
        """The first work tile is simply the scheduler's current one."""
        return self.get_current_work()

    def get_current_work(self):
        """Return the current tile plus the group it belongs to."""
        base = self._tile_sched.get_current_work()
        # delinearize_z's inner search would never terminate for an invalid
        # tile (linear index >= total tiles). The z index is therefore
        # multiplied by the zero-extended validity flag (i1 -> i32), which
        # forces it to 0 for invalid tiles. z = 0 is always a valid index, and
        # the resulting search result is only consumed inside
        # `while work_tile.is_valid_tile:`.
        keep = base.is_valid_tile.to(cutlass.Int32)
        x_idx, y_idx, z_idx = base.tile_idx[0], base.tile_idx[1], base.tile_idx[2]
        search_result = self._group_helper.delinearize_z(
            (x_idx, y_idx, z_idx * keep), self._problem_sizes_mnkl
        )
        return _GroupedWorkTileInfo(base.is_valid_tile, search_result)

    def advance_to_next_work(self, *, advance_count=1):
        """Advance the underlying persistent scheduler by ``advance_count`` tiles."""
        self._tile_sched.advance_to_next_work(advance_count=advance_count)

    @property
    def num_tiles_executed(self):
        """Tiles executed so far, as reported by the underlying scheduler."""
        return self._tile_sched.num_tiles_executed
class _FixedTensorMapManager(utils.TensorMapManager):
    """Local stability manager for environments using older cutlass.utils.

    Overrides ``update_tensormap`` only. By default, SMEM update/publish is
    routed through the GMEM branch for stability. Set
    GROUPED_GEMM_ENABLE_TRUE_SMEM_TMAP=1 to test the true SMEM path during
    investigation.
    """

    @_dsl_user_op
    @cute.jit
    def update_tensormap(
        self,
        tensor_gmem,
        tma_copy_atom,
        tensormap_gmem_ptr,
        warp_id: int,
        tensormap_smem_ptr,
        *,
        loc=None,
        ip=None,
    ) -> None:
        """Point per-group TMA descriptors at new tensors.

        ``tensor_gmem``, ``tma_copy_atom``, ``tensormap_gmem_ptr`` and (in SMEM
        mode) ``tensormap_smem_ptr`` are parallel sequences, zipped
        element-wise: one entry per descriptor. Only the warp whose index
        equals ``warp_id`` performs the update.

        Branches, selected at trace time via ``cutlass.const_expr``:

        * SMEM mode with ``_ENABLE_TRUE_SMEM_TMAP`` and
          ``_ENABLE_TRUE_SMEM_TMAP_PREUPDATE``: write the new descriptors into
          the SMEM copies, wait on the bulk-async group, then ``sync_warp``.
        * SMEM mode with ``_ENABLE_TRUE_SMEM_TMAP`` and
          ``_ENABLE_TRUE_SMEM_TMAP_PUBLISH``: copy each SMEM descriptor to its
          GMEM location with ``cp_fence_tma_desc_release``.
        * Otherwise (including the default SMEM-mode configuration and GMEM
          mode): update the GMEM descriptors in place, ``sync_warp``, then
          ``fence_tma_desc_release``.

        :param tensor_gmem: Tensors the descriptors should now describe.
        :param tma_copy_atom: TMA copy atoms matching ``tensor_gmem``.
        :param tensormap_gmem_ptr: GMEM descriptor locations.
        :param warp_id: Index of the warp that performs the update.
        :param tensormap_smem_ptr: SMEM descriptor copies; only dereferenced in
            SMEM mode (the kernel in this file passes ``None`` entries in GMEM mode).
        """
        # Warp index with a warp-uniformity hint, used for the
        # `warp_idx == warp_id` check below.
        warp_idx = cute.arch.make_warp_uniform(
            cute.arch.warp_idx(loc=loc, ip=ip), loc=loc, ip=ip
        )
        if cutlass.const_expr(
            self.tensormap_update_mode == utils.TensorMapUpdateMode.SMEM
        ):
            # Hoist SMEM pointer uniformization outside predicated blocks to avoid
            # predicated R2UR generation on sm_90a.
            uniform_smem_ptrs = tuple(
                _cute_make_ptr(
                    p.dtype,
                    cute.arch.make_warp_uniform(p.toint(), loc=loc, ip=ip),
                    mem_space=_CuteAddressSpace.smem,
                    assumed_align=p.alignment,
                )
                for p in tensormap_smem_ptr
            )
        else:
            # GMEM mode: only the GMEM branch below runs, so these are unused.
            uniform_smem_ptrs = tensormap_smem_ptr
        if warp_idx == warp_id:
            if cutlass.const_expr(
                self.tensormap_update_mode == utils.TensorMapUpdateMode.SMEM
                and _ENABLE_TRUE_SMEM_TMAP
                and _ENABLE_TRUE_SMEM_TMAP_PREUPDATE
            ):
                # Write the updated descriptors into their SMEM copies.
                for atom, tensor, sptr in zip(
                    tma_copy_atom, tensor_gmem, uniform_smem_ptrs
                ):
                    cute.nvgpu.cpasync.update_tma_descriptor(
                        atom, tensor, sptr, loc=loc, ip=ip
                    )
                # One lane commits and waits (read=True) on the bulk-async
                # group before the warp moves on to publishing.
                with cute.arch.elect_one(loc=loc, ip=ip):
                    cute.arch.cp_async_bulk_commit_group(loc=loc, ip=ip)
                    cute.arch.cp_async_bulk_wait_group(0, read=True, loc=loc, ip=ip)
                cute.arch.sync_warp(loc=loc, ip=ip)
            if cutlass.const_expr(
                self.tensormap_update_mode == utils.TensorMapUpdateMode.SMEM
                and _ENABLE_TRUE_SMEM_TMAP
                and _ENABLE_TRUE_SMEM_TMAP_PUBLISH
            ):
                # Publish: SMEM descriptor -> GMEM descriptor with release fence.
                for gptr, sptr in zip(tensormap_gmem_ptr, uniform_smem_ptrs):
                    cute.nvgpu.cpasync.cp_fence_tma_desc_release(
                        gptr, sptr, loc=loc, ip=ip
                    )
            else:
                # Default path: rewrite the GMEM descriptors directly.
                for atom, tensor, gptr in zip(
                    tma_copy_atom, tensor_gmem, tensormap_gmem_ptr
                ):
                    cute.nvgpu.cpasync.update_tma_descriptor(
                        atom, tensor, gptr, loc=loc, ip=ip
                    )
                cute.arch.sync_warp(loc=loc, ip=ip)
                cute.nvgpu.cpasync.fence_tma_desc_release(loc=loc, ip=ip)
class HopperGroupedGemmPersistentKernel:
"""
This class implements batched matrix multiplication (C = A x B) with support for various data types
and architectural features specific to Hopper GPUs.
:param acc_dtype: Data type for accumulation during computation
:type acc_dtype: type[cutlass.Numeric]
:param tile_shape_mn: Shape of the CTA tile (M,N)
:type tile_shape_mn: Tuple[int, int]
:param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing
:type cluster_shape_mn: Tuple[int, int]
:note: Supported A/B data types:
- Float16
A and B must have the same data type
- Float8E4M3FN/Float8E5M2
A and B can have different types (Float8E4M3FN/Float8E5M2)
only support k-major layout
- Int8/Uint8
A and B can have different types (Int8/Uint8)
only support k-major layout
:note: Supported accumulation types:
- Float32/Float16 (for all floating point inputs)
- Int32 (for Int8/Uint8 inputs)
:note: Constraints:
- CTA tile M must be 64/128
- CTA tile N must be 64/128/256
- CTA tile K must be 64
- Cluster shape M/N must be positive and power of 2, total cluster size <= 4
Example:
>>> gemm = HopperGroupedGemmPersistentKernel(
... acc_dtype=cutlass.Float32,
... tile_shape_mn=(128, 256),
... cluster_shape_mn=(1, 1)
... )
>>> gemm(a_tensor, b_tensor, c_tensor, stream)
"""
bytes_per_tensormap = 128
num_tensormaps = 3 # A, B, C
def __init__(
    self,
    acc_dtype: type[cutlass.Numeric],
    tile_shape_mn: tuple[int, int],
    cluster_shape_mn: tuple[int, int],
    swizzle_size: int,
    raster_along_m: bool,
    tensormap_update_mode: utils.TensorMapUpdateMode = utils.TensorMapUpdateMode.SMEM,
):
    """Configure a persistent, warp-specialized Hopper grouped GEMM kernel.

    Only settings that do not depend on the operand tensors are fixed here.
    Anything that needs dtypes/layouts (tile K, stage counts, SMEM layouts,
    multicast) is derived later in ``_setup_attributes``.

    :param acc_dtype: Data type for accumulation during computation
    :type acc_dtype: type[cutlass.Numeric]
    :param tile_shape_mn: CTA tile shape (M, N); K is filled in later
    :type tile_shape_mn: Tuple[int, int]
    :param cluster_shape_mn: Thread-block cluster shape (M, N)
    :type cluster_shape_mn: Tuple[int, int]
    :param swizzle_size: Stored on the instance; presumably consumed when
        building the tile-scheduler parameters
    :param raster_along_m: Stored on the instance; presumably selects the
        tile-scheduler rasterization order
    :param tensormap_update_mode: Whether per-group TMA descriptors are
        updated via GMEM or SMEM; SMEM mode also delegates A/B tensor map
        initialization to the MMA warps
    """
    self.acc_dtype = acc_dtype
    self.cluster_shape_mn = cluster_shape_mn
    self.swizzle_size = swizzle_size
    self.raster_along_m = raster_along_m
    self.mma_inst_shape_mn = None
    # Placeholder K of 1; _setup_attributes replaces it once the MMA atom is known.
    self.tile_shape_mnk = (*tile_shape_mn, 1)

    # Large tiles use two MMA warp groups; a single warp group may spill registers.
    use_two_mma_wgs = self.tile_shape_mnk[0] > 64 and self.tile_shape_mnk[1] > 128
    if use_two_mma_wgs:
        self.atom_layout_mnk = (2, 1, 1)
    else:
        self.atom_layout_mnk = (1, 1, 1)

    # Resolved in _setup_attributes.
    self.num_mcast_ctas_a = None
    self.num_mcast_ctas_b = None
    self.is_a_mcast = False
    self.is_b_mcast = False
    self.tiled_mma = None
    self.occupancy = 1

    # Warp specialization: DMA (producer) warp group(s) first, then one MMA
    # warp group per atom in atom_layout_mnk.
    self.num_dma_warp_groups = 1
    self.num_mma_warp_groups = math.prod(self.atom_layout_mnk)
    self.num_warps_per_warp_group = 4
    self.num_threads_per_warp_group = 32 * self.num_warps_per_warp_group
    total_warp_groups = self.num_dma_warp_groups + self.num_mma_warp_groups
    self.threads_per_cta = total_warp_groups * self.num_threads_per_warp_group
    self.load_warp_id = 0
    # First warp after the DMA warp group(s).
    self.epi_store_warp_id = self.num_warps_per_warp_group * self.num_dma_warp_groups

    # Per-warp-group register budgets (the load budget is passed to
    # warpgroup_reg_dealloc in the kernel).
    self.load_register_requirement = 40
    self.mma_register_requirement = 232
    self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_90")

    # Stage counts, epilogue tile and SMEM layouts come from _setup_attributes.
    self.ab_stage = None
    self.epi_stage = None
    self.a_smem_layout_staged = None
    self.b_smem_layout_staged = None
    self.epi_smem_layout_staged = None
    self.epi_tile = None
    self.shared_storage = None
    self.buffer_align_bytes = 1024

    self.num_mma_threads = self.num_threads_per_warp_group * self.num_mma_warp_groups
    # Named barrier 1: epilogue synchronization among all MMA threads.
    self.epilog_sync_barrier = pipeline.NamedBarrier(
        barrier_id=1, num_threads=self.num_mma_threads
    )

    self.tensormap_update_mode = tensormap_update_mode
    # SMEM mode: the MMA warps initialize the A/B tensor maps, which hides that
    # latency from the load warp.
    self.delegate_tensormap_ab_init = (
        tensormap_update_mode == utils.TensorMapUpdateMode.SMEM
    )
    # Named barrier 2 (1 is the epilogue barrier). The participants are the
    # load warp (32 threads) plus all MMA threads; DMA warps 1-3 are idle and
    # never reach it.
    self.tensormap_ab_init_barrier = pipeline.NamedBarrier(
        barrier_id=2,
        num_threads=32 + self.num_mma_threads,
    )
def _setup_attributes(self):
    """Derive all configuration that depends on operand dtypes/layouts.

    Expects ``a_dtype``, ``b_dtype``, ``c_dtype`` and ``a_layout``,
    ``b_layout``, ``c_layout`` to be set on ``self`` already (``__call__``
    assigns them before calling this). In order, it:

    1. validates the CTA tile M/N extents,
    2. builds the tiled MMA and fixes the tile K extent,
    3. builds the cluster layout and the A/B multicast settings,
    4. computes the cluster-level tile shape used by the group scheduler,
    5. picks the epilogue subtile,
    6. sizes the A/B and epilogue pipeline stages, and
    7. builds the staged A/B/C shared-memory layouts.

    :raises ValueError: If CTA tile M is not 64/128 or N is not 64/128/256.
    """
    tile_m = self.tile_shape_mnk[0]
    tile_n = self.tile_shape_mnk[1]
    if tile_m not in (64, 128):
        raise ValueError("CTA tile shape M must be 64/128")
    if tile_n not in (64, 128, 256):
        raise ValueError("CTA tile shape N must be 64/128/256")

    # MMA tiler: 64 rows by the full CTA tile N.
    self.tiled_mma = sm90_utils.make_trivial_tiled_mma(
        self.a_dtype,
        self.b_dtype,
        self.a_layout.sm90_mma_major_mode(),
        self.b_layout.sm90_mma_major_mode(),
        self.acc_dtype,
        self.atom_layout_mnk,
        tiler_mn=(64, tile_n),
    )
    # Replace the placeholder K with four MMA instructions' worth of K.
    inst_k = cute.size(self.tiled_mma.shape_mnk, mode=[2])
    k_blocks_per_tile = 4
    self.tile_shape_mnk = (tile_m, tile_n, inst_k * k_blocks_per_tile)

    cluster = self.cluster_shape_mn
    self.cta_layout_mnk = cute.make_layout((*cluster, 1))
    # A is multicast across the cluster's N extent, B across its M extent.
    self.num_mcast_ctas_a = cluster[1]
    self.num_mcast_ctas_b = cluster[0]
    self.is_a_mcast = self.num_mcast_ctas_a > 1
    self.is_b_mcast = self.num_mcast_ctas_b > 1

    # Tile extent covered by one whole cluster, for the group tile scheduler.
    tile_k = self.tile_shape_mnk[2]
    self.cluster_tile_shape_mnk = (
        tile_m * cluster[0],
        tile_n * cluster[1],
        tile_k,
    )

    # A (2, 1, 1) atom layout is treated as the cooperative configuration
    # when choosing the epilogue tile.
    cooperative = self.atom_layout_mnk == (2, 1, 1)
    self.epi_tile = self._sm90_compute_tile_shape_or_override(
        self.tile_shape_mnk, self.c_dtype, is_cooperative=cooperative
    )

    # Stage counts must be known before the staged SMEM layouts are built.
    self.ab_stage, self.epi_stage = self._compute_stages(
        self.tile_shape_mnk,
        self.a_dtype,
        self.b_dtype,
        self.epi_tile,
        self.c_dtype,
        self.smem_capacity,
        self.occupancy,
    )
    smem_layouts = self._make_smem_layouts(
        self.tile_shape_mnk,
        self.epi_tile,
        self.a_dtype,
        self.a_layout,
        self.b_dtype,
        self.b_layout,
        self.ab_stage,
        self.c_dtype,
        self.c_layout,
        self.epi_stage,
    )
    (
        self.a_smem_layout_staged,
        self.b_smem_layout_staged,
        self.epi_smem_layout_staged,
    ) = smem_layouts
@cute.jit
def __call__(
    self,
    initial_a: cute.Tensor,
    initial_b: cute.Tensor,
    initial_c: cute.Tensor,
    group_count: cutlass.Constexpr[int],
    problem_shape_mnkl: cute.Tensor,
    strides_abc: cute.Tensor,
    tensor_address_abc: cute.Tensor,
    total_num_clusters: cutlass.Constexpr[int],
    tensormap_cute_tensor: cute.Tensor,
    max_active_clusters: cutlass.Constexpr[int],
    stream: cuda.CUstream,
):
    """Execute the grouped GEMM operation.

    Records operand dtypes/layouts on ``self`` and validates them, then
    derives the input-dependent configuration (``_setup_attributes``). It
    builds TMA atoms from the ``initial_*`` template tensors, sizes the
    shared-memory storage, and launches ``self.kernel`` on ``stream``.

    :param initial_a: Carries dtype+majorness only (shape irrelevant).
    :param initial_b: Carries dtype+majorness only (shape irrelevant).
    :param initial_c: Carries dtype+majorness only (shape irrelevant).
    :param group_count: Number of GEMM groups (compile-time constant).
    :param problem_shape_mnkl: Device tensor of shape (G, 4) Int32 with (M,N,K,L) per group.
    :param strides_abc: Device tensor of shape (G, 3, 2) Int32 with strides per group.
    :param tensor_address_abc: Device tensor of shape (G, 3) Int64 with base ptrs per group.
    :param total_num_clusters: Total clusters across all groups (compile-time constant).
    :param tensormap_cute_tensor: Tensor map workspace, shape (num_sms, 3, 16) Int64
        (16 Int64 = one 128-byte descriptor each for A, B and C).
    :param max_active_clusters: Max active clusters (compile-time constant).
    :param stream: CUDA stream.
    :raises TypeError: If A/B are 16-bit but of different dtypes, if the A/B
        bit widths differ, or if the A bit width is neither 16 nor 8.
    """
    # Setup static attributes from initial tensor dtype/layout
    self.a_dtype = initial_a.element_type
    self.b_dtype = initial_b.element_type
    self.c_dtype = initial_c.element_type
    self.a_layout = utils.LayoutEnum.from_tensor(initial_a)
    self.b_layout = utils.LayoutEnum.from_tensor(initial_b)
    self.c_layout = utils.LayoutEnum.from_tensor(initial_c)
    # Trace-time dtype validation: 16-bit A/B must share the exact dtype,
    # A/B widths must match, and only 16- or 8-bit A is accepted.
    if cutlass.const_expr(
        self.a_dtype.width == 16 and self.a_dtype != self.b_dtype
    ):
        raise TypeError(f"Type mismatch: {self.a_dtype} != {self.b_dtype}")
    if cutlass.const_expr(self.a_dtype.width != self.b_dtype.width):
        raise TypeError(
            f"Type width mismatch: {self.a_dtype.width} != {self.b_dtype.width}"
        )
    if cutlass.const_expr(self.a_dtype.width != 16 and self.a_dtype.width != 8):
        raise TypeError("a_dtype should be float16, float8, or int8")
    self._setup_attributes()
    # TMA load atoms for A (M x K tiles) and B (N x K tiles). The last argument
    # is presumably the multicast CTA count (cluster N for A, cluster M for B).
    tma_atom_a, tma_tensor_a = self._make_tma_atoms_and_tensors(
        initial_a,
        self.a_smem_layout_staged,
        (self.tile_shape_mnk[0], self.tile_shape_mnk[2]),
        self.cluster_shape_mn[1],
    )
    tma_atom_b, tma_tensor_b = self._make_tma_atoms_and_tensors(
        initial_b,
        self.b_smem_layout_staged,
        (self.tile_shape_mnk[1], self.tile_shape_mnk[2]),
        self.cluster_shape_mn[0],
    )
    # TMA store atom for C, tiled by the epilogue subtile.
    tma_atom_c, tma_tensor_c = self._make_tma_store_atoms_and_tensors(
        initial_c,
        self.epi_smem_layout_staged,
        self.epi_tile,
    )
    tile_sched_params, grid = self._compute_grid(
        total_num_clusters,
        self.cluster_shape_mn,
        max_active_clusters,
    )
    # Number of Int64 words needed for the SMEM tensor map buffer (0 in GMEM mode).
    # In SMEM mode: 3 descriptors * 128 bytes / 8 bytes per Int64 = 48 words.
    self.size_tensormap_in_i64 = (
        0
        if self.tensormap_update_mode == utils.TensorMapUpdateMode.GMEM
        else HopperGroupedGemmPersistentKernel.num_tensormaps
        * HopperGroupedGemmPersistentKernel.bytes_per_tensormap
        // 8
    )

    @cute.struct
    class SharedStorage:
        # SMEM copies of the A/B/C tensor maps (empty in GMEM mode).
        tensormap_buffer: cute.struct.MemRange[
            cutlass.Int64, self.size_tensormap_in_i64
        ]
        # Two Int64 mbarrier slots per A/B stage (presumably full + empty).
        mainloop_pipeline_array_ptr: cute.struct.MemRange[
            cutlass.Int64, self.ab_stage * 2
        ]
        # Staged operand/epilogue buffers, each aligned to buffer_align_bytes.
        sA: cute.struct.Align[
            cute.struct.MemRange[
                self.a_dtype, cute.cosize(self.a_smem_layout_staged)
            ],
            self.buffer_align_bytes,
        ]
        sB: cute.struct.Align[
            cute.struct.MemRange[
                self.b_dtype, cute.cosize(self.b_smem_layout_staged)
            ],
            self.buffer_align_bytes,
        ]
        sC: cute.struct.Align[
            cute.struct.MemRange[
                self.c_dtype,
                cute.cosize(self.epi_smem_layout_staged),
            ],
            self.buffer_align_bytes,
        ]

    self.shared_storage = SharedStorage
    # Launch the kernel on `stream`. No host-side synchronization happens here;
    # the caller must synchronize the stream before reading results.
    self.kernel(
        tma_atom_a,
        tma_tensor_a,
        tma_atom_b,
        tma_tensor_b,
        tma_atom_c,
        tma_tensor_c,
        self.tiled_mma,
        self.cta_layout_mnk,
        self.a_smem_layout_staged,
        self.b_smem_layout_staged,
        self.epi_smem_layout_staged,
        tile_sched_params,
        group_count,
        problem_shape_mnkl,
        strides_abc,
        tensor_address_abc,
        tensormap_cute_tensor,
    ).launch(
        grid=grid,
        block=[self.threads_per_cta, 1, 1],
        cluster=(*self.cluster_shape_mn, 1),
        min_blocks_per_mp=1,
        stream=stream,
    )
    return
# GPU device kernel
@cute.kernel
def kernel(
self,
tma_atom_a: cute.CopyAtom,
mA_mkl: cute.Tensor,
tma_atom_b: cute.CopyAtom,
mB_nkl: cute.Tensor,
tma_atom_c: cute.CopyAtom,
mC_mnl: cute.Tensor,
tiled_mma: cute.TiledMma,
cta_layout_mnk: cute.Layout,
a_smem_layout_staged: cute.ComposedLayout,
b_smem_layout_staged: cute.ComposedLayout,
epi_smem_layout_staged: cute.ComposedLayout,
tile_sched_params: utils.PersistentTileSchedulerParams,
group_count: cutlass.Constexpr[int],
problem_sizes_mnkl: cute.Tensor,
strides_abc: cute.Tensor,
ptrs_abc: cute.Tensor,
tensormaps: cute.Tensor,
):
"""
GPU device kernel performing the batched GEMM computation.
:param tma_atom_a: TMA copy atom for A tensor
:type tma_atom_a: cute.CopyAtom
:param mA_mkl: Input tensor A
:type mA_mkl: cute.Tensor
:param tma_atom_b: TMA copy atom for B tensor
:type tma_atom_b: cute.CopyAtom
:param mB_nkl: Input tensor B
:type mB_nkl: cute.Tensor
:param tma_atom_c: TMA copy atom for C tensor
:type tma_atom_c: cute.CopyAtom
:param mC_mnl: Output tensor C
:type mC_mnl: cute.Tensor
:param tiled_mma: Tiled MMA object
:type tiled_mma: cute.TiledMma
:param cta_layout_mnk: CTA layout
:type cta_layout_mnk: cute.Layout
:param a_smem_layout_staged: Shared memory layout for A
:type a_smem_layout_staged: cute.ComposedLayout
:param b_smem_layout_staged: Shared memory layout for B
:type b_smem_layout_staged: cute.ComposedLayout
:param epi_smem_layout_staged: Shared memory layout for epilogue
:type epi_smem_layout_staged: cute.ComposedLayout
:param tile_sched_params: Parameters for the persistent tile scheduler
:type tile_sched_params: utils.PersistentTileSchedulerParams
"""
tidx, _, _ = cute.arch.thread_idx()
warp_idx = cute.arch.warp_idx()
warp_idx = cute.arch.make_warp_uniform(warp_idx)
# Prefetch Tma desc
if warp_idx == 0:
cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_a)
cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_b)
cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_c)
cta_rank_in_cluster = cute.arch.make_warp_uniform(
cute.arch.block_idx_in_cluster()
)
cluster_coord_mnk = cta_layout_mnk.get_flat_coord(cta_rank_in_cluster)
a_mcast_mask = cute.make_layout_image_mask(
cta_layout_mnk, cluster_coord_mnk, mode=1
)
b_mcast_mask = cute.make_layout_image_mask(
cta_layout_mnk, cluster_coord_mnk, mode=0
)
a_mcast_mask = a_mcast_mask if self.is_a_mcast else 0
b_mcast_mask = b_mcast_mask if self.is_b_mcast else 0
a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, 0))
b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, 0))
tma_copy_bytes = cute.size_in_bytes(
self.a_dtype, a_smem_layout
) + cute.size_in_bytes(self.b_dtype, b_smem_layout)
# Alloc and init AB full/empty + ACC full mbar (pipeline)
smem = cutlass.utils.SmemAllocator()
storage = smem.allocate(self.shared_storage)
# mbar arrays
mainloop_pipeline_array_ptr = storage.mainloop_pipeline_array_ptr.data_ptr()
# Threads/warps participating in this pipeline
mainloop_pipeline_producer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread
)
# Each warp will constribute to the arrive count with the number of mcast size
mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
consumer_arrive_cnt = (
mcast_size * self.num_mma_warp_groups * self.num_warps_per_warp_group
)
mainloop_pipeline_consumer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread, consumer_arrive_cnt
)
mainloop_pipeline = pipeline.PipelineTmaAsync.create(
barrier_storage=mainloop_pipeline_array_ptr,
num_stages=self.ab_stage,
producer_group=mainloop_pipeline_producer_group,
consumer_group=mainloop_pipeline_consumer_group,
tx_count=tma_copy_bytes,
cta_layout_vmnk=cute.make_layout((1, *cta_layout_mnk.shape)),
defer_sync=True,
)
# Cluster arrive after barrier init
pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mn, is_relaxed=True)
# Generate smem tensor A/B
sA = storage.sA.get_tensor(
a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner
)
sB = storage.sB.get_tensor(
b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner
)
sC = storage.sC.get_tensor(
epi_smem_layout_staged.outer, swizzle=epi_smem_layout_staged.inner
)
# Local_tile partition global tensors
# (bM, bK, RestM, RestK, RestL)
gA_mkl = cute.local_tile(
mA_mkl,
cute.slice_(self.tile_shape_mnk, (None, 0, None)),
(None, None, None),
)
# (bN, bK, RestN, RestK, RestL)
gB_nkl = cute.local_tile(
mB_nkl,
cute.slice_(self.tile_shape_mnk, (0, None, None)),
(None, None, None),
)
# (bM, bN, RestM, RestN, RestL)
gC_mnl = cute.local_tile(
mC_mnl,
cute.slice_(self.tile_shape_mnk, (None, None, 0)),
(None, None, None),
)
# Partition shared tensor for TMA load A/B
# TMA load A partition_S/D
a_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (0, None, 0)).shape)
a_cta_crd = cluster_coord_mnk[1]
tAsA, tAgA = cute.nvgpu.cpasync.tma_partition(
tma_atom_a,
a_cta_crd,
a_cta_layout,
cute.group_modes(sA, 0, 2),
cute.group_modes(gA_mkl, 0, 2),
)
# TMA load B partition_S/D
b_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (None, 0, 0)).shape)
b_cta_crd = cluster_coord_mnk[0]
tBsB, tBgB = cute.nvgpu.cpasync.tma_partition(
tma_atom_b,
b_cta_crd,
b_cta_layout,
cute.group_modes(sB, 0, 2),
cute.group_modes(gB_nkl, 0, 2),
)
# Partition global tensor for TiledMMA_A/B/C
warp_group_idx = cute.arch.make_warp_uniform(
tidx // self.num_threads_per_warp_group
)
mma_warp_group_thread_layout = cute.make_layout(
self.num_mma_warp_groups, stride=self.num_threads_per_warp_group
)
thr_mma = tiled_mma.get_slice(
mma_warp_group_thread_layout(warp_group_idx - self.num_dma_warp_groups)
)
# Make fragments
tCsA = thr_mma.partition_A(sA)
tCsB = thr_mma.partition_B(sB)
tCrA = tiled_mma.make_fragment_A(tCsA)
tCrB = tiled_mma.make_fragment_B(tCsB)
tCgC = thr_mma.partition_C(gC_mnl)
acc_shape = tCgC.shape[:3]
accumulators = cute.make_rmem_tensor(acc_shape, self.acc_dtype)
# Cluster wait for barrier init
pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mn)
#
# Setup per-SM tensor map pointers (shared by DMA and MMA warps)
#
grid_dim = cute.arch.grid_dim()
bid = cute.arch.block_idx()
sm_idx = bid[2] * grid_dim[1] * grid_dim[0] + bid[1] * grid_dim[0] + bid[0]
tensormap_manager = _FixedTensorMapManager(
self.tensormap_update_mode,
HopperGroupedGemmPersistentKernel.bytes_per_tensormap,
)
tensormap_a_ptr = tensormap_manager.get_tensormap_ptr(
tensormaps[(sm_idx, 0, None)].iterator
)
tensormap_b_ptr = tensormap_manager.get_tensormap_ptr(
tensormaps[(sm_idx, 1, None)].iterator
)
tensormap_c_ptr = tensormap_manager.get_tensormap_ptr(
tensormaps[(sm_idx, 2, None)].iterator
)
# SMEM buffer pointers for tensor maps (only non-None in SMEM mode)
if cutlass.const_expr(
self.tensormap_update_mode == utils.TensorMapUpdateMode.SMEM
):
smem_tm_base = storage.tensormap_buffer.data_ptr()
tensormap_a_smem_ptr = smem_tm_base
tensormap_b_smem_ptr = (
smem_tm_base
+ HopperGroupedGemmPersistentKernel.bytes_per_tensormap // 8
)
tensormap_c_smem_ptr = (
smem_tm_base
+ 2 * HopperGroupedGemmPersistentKernel.bytes_per_tensormap // 8
)
else:
tensormap_a_smem_ptr = None
tensormap_b_smem_ptr = None
tensormap_c_smem_ptr = None
tile_sched_params_for_sched = tile_sched_params
is_dma_warp_group = warp_group_idx < self.num_dma_warp_groups
if is_dma_warp_group:
cute.arch.warpgroup_reg_dealloc(self.load_register_requirement)
#
# DMA warp group (load A/B with TMA, update tensor maps A/B per group)
#
if warp_idx == self.load_warp_id:
# Initialize tensor maps A/B (either here or delegated to MMA warp)
if cutlass.const_expr(not self.delegate_tensormap_ab_init):
tensormap_manager.init_tensormap_from_atom(
tma_atom_a, tensormap_a_ptr, self.load_warp_id
)
tensormap_manager.init_tensormap_from_atom(
tma_atom_b, tensormap_b_ptr, self.load_warp_id
)
tensormap_manager.fence_tensormap_initialization()
else:
# Delegate path: wait for MMA warp to finish A/B tensor map init.
# Must be unconditional (before the tile loop) so every CTA
# participates even when it processes zero tiles.
self.tensormap_ab_init_barrier.arrive_and_wait()
last_group_idx = cutlass.Int32(-1)
# Create a per-warp scheduler (same state — each warp runs its own instance)
tile_sched = StaticPersistentGroupTileScheduler.create(
tile_sched_params_for_sched,
bid,
grid_dim,
self.cluster_tile_shape_mnk,
utils.create_initial_search_state(),
group_count,
problem_sizes_mnkl,
)
work_tile = tile_sched.initial_work_tile_info()
mainloop_producer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Producer, self.ab_stage
)
while work_tile.is_valid_tile:
grouped_info = work_tile.group_search_result
cur_group_idx = grouped_info.group_idx
cur_k_tile_cnt = grouped_info.cta_tile_count_k
if cur_k_tile_cnt != 0:
is_group_changed = cur_group_idx != last_group_idx
if is_group_changed:
real_a = self.make_tensor_for_tensormap_update(
cur_group_idx,
self.a_dtype,
(
grouped_info.problem_shape_m,
grouped_info.problem_shape_n,
grouped_info.problem_shape_k,
),
strides_abc,
ptrs_abc,
0,
)
real_b = self.make_tensor_for_tensormap_update(
cur_group_idx,
self.b_dtype,
(
grouped_info.problem_shape_m,
grouped_info.problem_shape_n,
grouped_info.problem_shape_k,
),
strides_abc,
ptrs_abc,
1,
)
tensormap_manager.update_tensormap(
(real_a, real_b),
(tma_atom_a, tma_atom_b),
(tensormap_a_ptr, tensormap_b_ptr),
self.load_warp_id,
(tensormap_a_smem_ptr, tensormap_b_smem_ptr),
)
tensormap_manager.fence_tensormap_update(tensormap_a_ptr)
tensormap_manager.fence_tensormap_update(tensormap_b_ptr)
mma_tile_coord_mnl = (
grouped_info.cta_tile_idx_m,
grouped_info.cta_tile_idx_n,
0,
)
tAgA_slice = tAgA[
(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])
]
tBgB_slice = tBgB[
(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])
]
# Cache loop-invariant TMA descriptor pointers before K-loop.
# Keep two variants:
# - gmem descriptors for direct NVVM cp.async.bulk.tensor ops
# - generic descriptors for cute.copy fallback (mcast path)
#
# Using explicit gmem descriptors in the direct NVVM path avoids
# relying on generic-pointer lowering for the descriptor operand.
tma_a_desc_ptr_nvvm = tensormap_manager.get_tensormap_ptr(
tensormap_a_ptr, cute.AddressSpace.gmem
)
tma_b_desc_ptr_nvvm = tensormap_manager.get_tensormap_ptr(
tensormap_b_ptr, cute.AddressSpace.gmem
)
tma_a_desc_ptr_copy = tensormap_manager.get_tensormap_ptr(
tensormap_a_ptr, cute.AddressSpace.generic
)
tma_b_desc_ptr_copy = tensormap_manager.get_tensormap_ptr(
tensormap_b_ptr, cute.AddressSpace.generic
)
# Pre-compute loop-invariant TMA coordinates (m, n).
# For the non-mcast case (cluster 1x1), the TMA box offset is
# simply cta_tile_idx * tile_size. k_coord is computed inside
# the loop because it varies per K-tile.
_tile_k = self.tile_shape_mnk[2]
_tile_m = self.tile_shape_mnk[0]
_tile_n = self.tile_shape_mnk[1]
use_nvvm_non_mcast_load = cutlass.const_expr(
_ENABLE_NVVM_NON_MCAST_LOAD
and not self.is_a_mcast
and not self.is_b_mcast
)
mainloop_producer_state.reset_count()
for k_tile in cutlass.range(0, cur_k_tile_cnt, 1, unroll=1):
mainloop_pipeline.producer_acquire(mainloop_producer_state)
if use_nvvm_non_mcast_load:
# Non-mcast path: use NVVM dialect TMA op with
# predicate= so operands are computed outside any
# predicated block. This prevents PTXAS from
# generating the illegal @P0 R2UR instruction on
# sm_90a (CUDA_ERROR_ILLEGAL_INSTRUCTION / 715).
_tma_load_ab_nvvm_no_mcast(
k_tile * _tile_k,
mma_tile_coord_mnl[0] * _tile_m,
mma_tile_coord_mnl[1] * _tile_n,
tma_a_desc_ptr_nvvm,
tma_b_desc_ptr_nvvm,
tAsA[(None, mainloop_producer_state.index)].iterator,
tBsB[(None, mainloop_producer_state.index)].iterator,
mainloop_pipeline.producer_get_barrier(
mainloop_producer_state
),
)
else:
# Mcast path: fall back to cute.copy which handles
# the multicast mask for multi-CTA clusters.
cute.copy(
tma_atom_a,
tAgA_slice[(None, k_tile)],
tAsA[(None, mainloop_producer_state.index)],
tma_bar_ptr=mainloop_pipeline.producer_get_barrier(
mainloop_producer_state
),
mcast_mask=a_mcast_mask,
tma_desc_ptr=tma_a_desc_ptr_copy,
)
cute.copy(
tma_atom_b,
tBgB_slice[(None, k_tile)],
tBsB[(None, mainloop_producer_state.index)],
tma_bar_ptr=mainloop_pipeline.producer_get_barrier(
mainloop_producer_state
),
mcast_mask=b_mcast_mask,
tma_desc_ptr=tma_b_desc_ptr_copy,
)
mainloop_pipeline.producer_commit(mainloop_producer_state)
mainloop_producer_state.advance()
else:
pass # k_tile_cnt == 0: tensor map init already done before loop
tile_sched.advance_to_next_work()
work_tile = tile_sched.get_current_work()
last_group_idx = cur_group_idx
mainloop_pipeline.producer_tail(mainloop_producer_state)
#
# MMA warp group (WGMMA + epilogue, update tensor map C per group)
#
if not is_dma_warp_group:
cute.arch.warpgroup_reg_alloc(self.mma_register_requirement)
# MMA warp always initializes tensor map C
tensormap_manager.init_tensormap_from_atom(
tma_atom_c, tensormap_c_ptr, self.epi_store_warp_id
)
# When delegating, MMA warp also initializes A/B and signals DMA warp
if cutlass.const_expr(self.delegate_tensormap_ab_init):
tensormap_manager.init_tensormap_from_atom(
tma_atom_a, tensormap_a_ptr, self.epi_store_warp_id
)
tensormap_manager.init_tensormap_from_atom(
tma_atom_b, tensormap_b_ptr, self.epi_store_warp_id
)
self.tensormap_ab_init_barrier.arrive_and_wait()
tensormap_manager.fence_tensormap_initialization()
tile_sched = StaticPersistentGroupTileScheduler.create(
tile_sched_params_for_sched,
bid,
grid_dim,
self.cluster_tile_shape_mnk,
utils.create_initial_search_state(),
group_count,
problem_sizes_mnkl,
)
work_tile = tile_sched.initial_work_tile_info()
mainloop_consumer_read_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.ab_stage
)
mainloop_consumer_release_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.ab_stage
)
num_k_blocks = cute.size(tCrA, mode=[2])
# Partition for epilogue
copy_atom_r2s = sm90_utils.sm90_get_smem_store_op(
self.c_layout,
elem_ty_d=self.c_dtype,
elem_ty_acc=self.acc_dtype,
)
copy_atom_C = cute.make_copy_atom(
cute.nvgpu.warp.StMatrix8x8x16bOp(
self.c_layout.is_m_major_c(),
4,
),
self.c_dtype,
)
tiled_copy_C_Atom = cute.make_tiled_copy_C_atom(copy_atom_C, tiled_mma)
tiled_copy_r2s = cute.make_tiled_copy_S(
copy_atom_r2s,
tiled_copy_C_Atom,
)
# (R2S, R2S_M, R2S_N, PIPE_D)
thr_copy_r2s = tiled_copy_r2s.get_slice(
tidx - self.num_dma_warp_groups * self.num_threads_per_warp_group
)
# (t)hread-partition for (r)egister to (s)mem copy (tRS_)
tRS_sD = thr_copy_r2s.partition_D(sC)
# (R2S, R2S_M, R2S_N)
tRS_rAcc = tiled_copy_r2s.retile(accumulators)
# Allocate D registers.
rD_shape = cute.shape(thr_copy_r2s.partition_S(sC))
tRS_rD_layout = cute.make_layout(rD_shape[:3])
tRS_rD = cute.make_rmem_tensor(tRS_rD_layout.shape, self.acc_dtype)
tRS_rD_out = cute.make_rmem_tensor(tRS_rD_layout.shape, self.c_dtype)
size_tRS_rD = cute.size(tRS_rD)
k_pipe_mmas = 1
# Initialize tma store pipeline
tma_store_producer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread,
self.num_mma_threads,
)
tma_store_pipeline = pipeline.PipelineTmaStore.create(
num_stages=self.epi_stage,
producer_group=tma_store_producer_group,
)
last_group_idx_mma = cutlass.Int32(-1)
while work_tile.is_valid_tile:
grouped_info = work_tile.group_search_result
cur_group_idx = grouped_info.group_idx
cur_k_tile_cnt = grouped_info.cta_tile_count_k
# Per-group tensor map C update (only epi_store warp issues it)
is_group_changed = cur_group_idx != last_group_idx_mma
if is_group_changed and warp_idx == self.epi_store_warp_id:
real_c = self.make_tensor_for_tensormap_update(
cur_group_idx,
self.c_dtype,
(
grouped_info.problem_shape_m,
grouped_info.problem_shape_n,
grouped_info.problem_shape_k,
),
strides_abc,
ptrs_abc,
2,
)
tensormap_manager.update_tensormap(
(real_c,),
(tma_atom_c,),
(tensormap_c_ptr,),
self.epi_store_warp_id,
(tensormap_c_smem_ptr,),
)
tensormap_manager.fence_tensormap_update(tensormap_c_ptr)
mma_tile_coord_mnl = (
grouped_info.cta_tile_idx_m,
grouped_info.cta_tile_idx_n,
0,
)
gC_mnl_slice = gC_mnl[(None, None, *mma_tile_coord_mnl)]
# MAINLOOP
mainloop_consumer_read_state.reset_count()
mainloop_consumer_release_state.reset_count()
accumulators.fill(0.0)
tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True)
cute.nvgpu.warpgroup.fence()
prologue_mma_cnt = cutlass.min(k_pipe_mmas, cur_k_tile_cnt)
for k_tile in cutlass.range(0, prologue_mma_cnt, 1, unroll=1):
# Wait for TMA copies to complete
mainloop_pipeline.consumer_wait(mainloop_consumer_read_state)
# WGMMA
for k_block_idx in cutlass.range_constexpr(num_k_blocks):
k_block_coord = (
None,
None,
k_block_idx,
mainloop_consumer_read_state.index,
)
cute.gemm(
tiled_mma,
accumulators,
tCrA[k_block_coord],
tCrB[k_block_coord],
accumulators,
)
cute.nvgpu.warpgroup.commit_group()
mainloop_consumer_read_state.advance()
for k_tile in cutlass.range(prologue_mma_cnt, cur_k_tile_cnt, 1, unroll=1):
# Wait for TMA copies to complete
mainloop_pipeline.consumer_wait(mainloop_consumer_read_state)
# WGMMA
for k_block_idx in cutlass.range_constexpr(num_k_blocks):
k_block_coord = (
None,
None,
k_block_idx,
mainloop_consumer_read_state.index,
)
cute.gemm(
tiled_mma,
accumulators,
tCrA[k_block_coord],
tCrB[k_block_coord],
accumulators,
)
cute.nvgpu.warpgroup.commit_group()
# Wait on the wgmma barrier for WGMMA to complete
cute.nvgpu.warpgroup.wait_group(k_pipe_mmas)
mainloop_pipeline.consumer_release(mainloop_consumer_release_state)
mainloop_consumer_release_state.advance()
mainloop_consumer_read_state.advance()
cute.nvgpu.warpgroup.wait_group(0)
for k_tile in cutlass.range(0, prologue_mma_cnt, 1, unroll=1):
mainloop_pipeline.consumer_release(mainloop_consumer_release_state)
mainloop_consumer_release_state.advance()
# Epilogue
tCgC_for_tma_partition = cute.zipped_divide(gC_mnl_slice, self.epi_tile)
# thread(b)lock-partition for (s)mem to (g)mem copy (bSG_)
bSG_sD, bSG_gD = cute.nvgpu.cpasync.tma_partition(
tma_atom_c,
0,
cute.make_layout(1),
cute.group_modes(sC, 0, 2),
tCgC_for_tma_partition,
)
epi_tile_num = cute.size(tCgC_for_tma_partition, mode=[1])
epi_tile_shape = tCgC_for_tma_partition.shape[1]
epi_tile_layout = cute.make_layout(
epi_tile_shape, stride=(epi_tile_shape[1], 1)
)
num_prev_epi_tiles = tile_sched.num_tiles_executed * epi_tile_num
for epi_idx in cutlass.range_constexpr(epi_tile_num):
# Copy from accumulators to D registers
for epi_v in cutlass.range_constexpr(size_tRS_rD):
tRS_rD[epi_v] = tRS_rAcc[epi_idx * size_tRS_rD + epi_v]
# Type conversion
acc_vec = tRS_rD.load()
tRS_rD_out.store(acc_vec.to(self.c_dtype))
# Copy from D registers to shared memory
epi_buffer = (num_prev_epi_tiles + epi_idx) % cute.size(
tRS_sD, mode=[3]
)
cute.copy(
tiled_copy_r2s,
tRS_rD_out,
tRS_sD[(None, None, None, epi_buffer)],
)
cute.arch.fence_proxy(
"async.shared",
space="cta",
)
self.epilog_sync_barrier.arrive_and_wait()
gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
# Copy from shared memory to global memory (TMA store with updated desc)
if warp_idx == self.epi_store_warp_id:
cute.copy(
tma_atom_c,
bSG_sD[(None, epi_buffer)],
bSG_gD[(None, gmem_coord)],
tma_desc_ptr=tensormap_manager.get_tensormap_ptr(
tensormap_c_ptr, cute.AddressSpace.generic
),
)
tma_store_pipeline.producer_commit()
tma_store_pipeline.producer_acquire()
self.epilog_sync_barrier.arrive_and_wait()
last_group_idx_mma = cur_group_idx
tile_sched.advance_to_next_work()
work_tile = tile_sched.get_current_work()
tma_store_pipeline.producer_tail()
@cute.jit
def make_tensor_for_tensormap_update(
    self,
    group_idx: cutlass.Int32,
    dtype: Type[cutlass.Numeric],
    problem_shape_mnk: tuple,
    strides_abc: cute.Tensor,
    tensor_address_abc: cute.Tensor,
    tensor_index: int,
):
    """Construct a global tensor for tensormap update from per-group metadata.

    Reads the base address and the two strides recorded for operand
    ``tensor_index`` of group ``group_idx``. It then builds a rank-3 gmem
    tensor whose shape is ``(M, K, 1)`` for A, ``(N, K, 1)`` for B, or
    ``(M, N, 1)`` for C. The trailing mode always has extent 1 and stride 0.

    :param group_idx: Index of the current group.
    :param dtype: Element type of the tensor (A, B, or C).
    :param problem_shape_mnk: (M, N, K) of the current group.
    :param strides_abc: Tensor of strides, shape (G, 3, 2), dtype Int32.
    :param tensor_address_abc: Tensor of base ptrs, shape (G, 3), dtype Int64.
    :param tensor_index: 0=A, 1=B, 2=C. It is compared inside
        ``cutlass.const_expr``, so it is expected to be a compile-time value.
    :raises TypeError: If ``dtype`` is not a subclass of ``cutlass.Numeric``
        (checked at trace time via ``cutlass.const_expr``).
    :return: A ``cute.Tensor`` over global memory describing the operand.
    """
    # Raw Int64 base address of this group's operand.
    ptr_i64 = tensor_address_abc[(group_idx, tensor_index)]
    # Trace-time validation of the element type.
    if cutlass.const_expr(
        not isclass(dtype) or not issubclass(dtype, cutlass.Numeric)
    ):
        raise TypeError(
            f"dtype must be a type of cutlass.Numeric, got {type(dtype)}"
        )
    # Reinterpret the integer address as a typed gmem pointer.
    # NOTE(review): assumed_align=16 relies on every per-group base pointer
    # being 16-byte aligned; confirm the host side guarantees this.
    tensor_gmem_ptr = cute.make_ptr(
        dtype, ptr_i64, cute.AddressSpace.gmem, assumed_align=16
    )
    # Copy this operand's two strides into a small 2-element rmem tensor
    # so they can be read as scalars below.
    strides_tensor_gmem = strides_abc[(group_idx, tensor_index, None)]
    strides_tensor_reg = cute.make_rmem_tensor(
        cute.make_layout(2),
        strides_abc.element_type,
    )
    cute.autovec_copy(strides_tensor_gmem, strides_tensor_reg)
    # stride_mn is the mode-0 stride (M for A/C, N for B); stride_k is the
    # mode-1 stride (K for A/B, but N for C, see the C branch below).
    stride_mn = strides_tensor_reg[0]
    stride_k = strides_tensor_reg[1]
    # Extent 1 / stride 0 for the trailing (batch) mode.
    c1 = cutlass.Int32(1)
    c0 = cutlass.Int32(0)
    if cutlass.const_expr(tensor_index == 0):  # tensor A
        m = problem_shape_mnk[0]
        k = problem_shape_mnk[2]
        return cute.make_tensor(
            tensor_gmem_ptr,
            cute.make_layout((m, k, c1), stride=(stride_mn, stride_k, c0)),
        )
    elif cutlass.const_expr(tensor_index == 1):  # tensor B
        n = problem_shape_mnk[1]
        k = problem_shape_mnk[2]
        return cute.make_tensor(
            tensor_gmem_ptr,
            cute.make_layout((n, k, c1), stride=(stride_mn, stride_k, c0)),
        )
    else:  # tensor C
        m = problem_shape_mnk[0]
        n = problem_shape_mnk[1]
        # For C the two stored strides are the M and N strides respectively.
        return cute.make_tensor(
            tensor_gmem_ptr,
            cute.make_layout((m, n, c1), stride=(stride_mn, stride_k, c0)),
        )
@staticmethod
def _compute_stages(
    tile_shape_mnk: tuple[int, int, int],
    a_dtype: type[cutlass.Numeric],
    b_dtype: type[cutlass.Numeric],
    epi_tile: tuple[int, int],
    c_dtype: type[cutlass.Numeric],
    smem_capacity: int,
    occupancy: int,
) -> tuple[int, int]:
    """Computes the number of stages for A/B/C operands based on heuristics.

    The epilogue always gets 4 stages. Every A/B stage that fits in the
    remaining per-CTA shared-memory budget is then given to the mainloop.

    :param tile_shape_mnk: The shape (M, N, K) of the CTA tile.
    :type tile_shape_mnk: tuple[int, int, int]
    :param a_dtype: Data type of operand A.
    :type a_dtype: type[cutlass.Numeric]
    :param b_dtype: Data type of operand B.
    :type b_dtype: type[cutlass.Numeric]
    :param epi_tile: Epilogue tile shape
    :type epi_tile: Tuple[int, int]
    :param c_dtype: The data type of the output tensor
    :type c_dtype: type[cutlass.Numeric]
    :param smem_capacity: Total available shared memory capacity in bytes.
    :type smem_capacity: int
    :param occupancy: Target number of CTAs per SM (occupancy).
    :type occupancy: int
    :return: A tuple containing the computed number of stages for:
        (A/B operand stages, epilogue stages)
    :rtype: tuple[int, int]
    :raises ValueError: If the shared-memory budget cannot hold even one
        A/B stage (tile too large for the requested occupancy).
    """
    # A tile is (M, K) and B tile is (N, K): drop the N / M mode respectively.
    a_shape = cute.slice_(tile_shape_mnk, (None, 0, None))
    b_shape = cute.slice_(tile_shape_mnk, (0, None, None))
    ab_bytes_per_stage = (
        cute.size(a_shape) * a_dtype.width // 8
        + cute.size(b_shape) * b_dtype.width // 8
    )
    c_bytes_per_stage = cute.size(epi_tile) * c_dtype.width // 8
    epi_stage = 4
    epi_bytes = c_bytes_per_stage * epi_stage
    # Fixed reserve for barrier/helper storage.
    # NOTE(review): presumably covers the pipeline mbarriers; not verified here.
    mbar_helpers_bytes = 1024
    ab_budget_bytes = smem_capacity // occupancy - (mbar_helpers_bytes + epi_bytes)
    ab_stage = ab_budget_bytes // ab_bytes_per_stage
    # Fail early and clearly instead of returning a zero/negative stage count
    # that would only break later during smem layout / pipeline construction.
    if ab_stage < 1:
        raise ValueError(
            "Not enough shared memory for a single A/B pipeline stage: "
            f"{ab_budget_bytes} bytes available per CTA after reserving "
            f"epilogue and helper storage, but one stage needs "
            f"{ab_bytes_per_stage} bytes (tile_shape_mnk={tile_shape_mnk}, "
            f"occupancy={occupancy})."
        )
    return ab_stage, epi_stage
@staticmethod
def _sm90_compute_tile_shape_or_override(
    tile_shape_mnk: tuple[int, int, int],
    element_type: type[cutlass.Numeric],
    is_cooperative: bool = False,
    epi_tile_override: Optional[tuple[int, int]] = None,
) -> tuple[int, int]:
    """Pick the epilogue tile shape, honoring an explicit override.

    Without an override, the CTA tile's M and N extents are clamped to caps.
    Cooperative mode uses 128x32. Otherwise the cap is 64 in M, and in N it
    is 64 for 8-bit element types and 32 for everything else.

    :param tile_shape_mnk: CTA tile shape (M,N,K)
    :type tile_shape_mnk: Tuple[int, int, int]
    :param element_type: Data type of elements
    :type element_type: type[cutlass.Numeric]
    :param is_cooperative: Whether to use cooperative approach
    :type is_cooperative: bool
    :param epi_tile_override: Optional override for epilogue tile shape
    :type epi_tile_override: Tuple[int, int] or None
    :return: Computed epilogue tile shape
    :rtype: Tuple[int, int]
    """
    if epi_tile_override is not None:
        return epi_tile_override

    if is_cooperative:
        m_cap, n_cap = 128, 32
    else:
        m_cap = 64
        n_cap = 64 if element_type.width == 8 else 32

    return (
        min(m_cap, cute.size(tile_shape_mnk, mode=[0])),
        min(n_cap, cute.size(tile_shape_mnk, mode=[1])),
    )
@staticmethod
def _make_smem_layouts(
    tile_shape_mnk: tuple[int, int, int],
    epi_tile: tuple[int, int],
    a_dtype: type[cutlass.Numeric],
    a_layout: utils.LayoutEnum,
    b_dtype: type[cutlass.Numeric],
    b_layout: utils.LayoutEnum,
    ab_stage: int,
    c_dtype: type[cutlass.Numeric],
    c_layout: utils.LayoutEnum,
    epi_stage: int,
) -> tuple[cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout]:
    """Build the staged shared-memory layouts for operands A, B and output C.

    A and B are tiled to ``(M|N, K, ab_stage)``. C is tiled to
    ``(*epi_tile, epi_stage)``. Each uses a layout atom picked for its dtype
    and for the extent of its contiguous (major) mode.

    :param tile_shape_mnk: CTA tile shape (M,N,K)
    :type tile_shape_mnk: Tuple[int, int, int]
    :param epi_tile: Epilogue tile shape
    :type epi_tile: Tuple[int, int]
    :param a_dtype: Data type for matrix A
    :type a_dtype: type[cutlass.Numeric]
    :param a_layout: Layout enum for matrix A
    :type a_layout: utils.LayoutEnum
    :param b_dtype: Data type for matrix B
    :type b_dtype: type[cutlass.Numeric]
    :param b_layout: Layout enum for matrix B
    :type b_layout: utils.LayoutEnum
    :param ab_stage: Number of stages for A/B tensors
    :type ab_stage: int
    :param c_dtype: Data type for output matrix C
    :type c_dtype: type[cutlass.Numeric]
    :param c_layout: Layout enum for the output matrix C
    :type c_layout: utils.LayoutEnum
    :param epi_stage: Number of epilogue stages
    :type epi_stage: int
    :return: Tuple of shared memory layouts for A, B, and C
    :rtype: Tuple[cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout]
    """
    k_major_mode = cute.nvgpu.warpgroup.OperandMajorMode.K

    def _staged_operand_layout(operand_dtype, operand_layout, mn_mode, slice_coord):
        # Staged (tile, ab_stage) smem layout for one MMA operand; mn_mode is
        # the index of the operand's non-K mode in tile_shape_mnk.
        is_k_major = operand_layout.sm90_mma_major_mode() == k_major_mode
        major_extent = tile_shape_mnk[2] if is_k_major else tile_shape_mnk[mn_mode]
        layout_atom = cute.nvgpu.warpgroup.make_smem_layout_atom(
            sm90_utils.get_smem_layout_atom(
                operand_layout,
                operand_dtype,
                major_extent,
            ),
            operand_dtype,
        )
        operand_tile = cute.slice_(tile_shape_mnk, slice_coord)
        return cute.tile_to_shape(
            layout_atom,
            cute.append(operand_tile, ab_stage),
            order=(0, 1, 2) if is_k_major else (1, 0, 2),
        )

    # A drops the N mode, B drops the M mode.
    a_staged = _staged_operand_layout(a_dtype, a_layout, 0, (None, 0, None))
    b_staged = _staged_operand_layout(b_dtype, b_layout, 1, (0, None, None))

    c_major_extent = epi_tile[1] if c_layout.is_n_major_c() else epi_tile[0]
    c_layout_atom = cute.nvgpu.warpgroup.make_smem_layout_atom(
        sm90_utils.get_smem_layout_atom(
            c_layout,
            c_dtype,
            c_major_extent,
        ),
        c_dtype,
    )
    c_staged = cute.tile_to_shape(
        c_layout_atom,
        cute.append(epi_tile, epi_stage),
        order=(1, 0, 2) if c_layout.is_m_major_c() else (0, 1, 2),
    )
    return a_staged, b_staged, c_staged
@staticmethod
def _compute_grid(
    total_num_clusters: int,
    cluster_shape_mn: tuple[int, int],
    max_active_clusters: cutlass.Constexpr,
) -> tuple[utils.PersistentTileSchedulerParams, tuple]:
    """Build persistent tile-scheduler params and the launch grid.

    The scheduler receives the tile-count shape
    ``(cluster_m, cluster_n, Int32(total_num_clusters))`` together with the
    cluster shape ``(cluster_m, cluster_n, 1)``. The grouped scheduler then
    derives the grid from those params and ``max_active_clusters``.

    :param total_num_clusters: Total clusters across all groups.
    :type total_num_clusters: int
    :param cluster_shape_mn: Shape of each cluster in M, N dimensions.
    :type cluster_shape_mn: tuple[int, int]
    :param max_active_clusters: Maximum number of active clusters.
    :type max_active_clusters: cutlass.Constexpr
    :return: (tile_sched_params, grid)
    :rtype: tuple
    """
    cluster_m = cluster_shape_mn[0]
    cluster_n = cluster_shape_mn[1]
    num_tiles_mnl = (cluster_m, cluster_n, cutlass.Int32(total_num_clusters))
    sched_params = utils.PersistentTileSchedulerParams(
        num_tiles_mnl, (*cluster_shape_mn, 1)
    )
    launch_grid = StaticPersistentGroupTileScheduler.get_grid_shape(
        sched_params, max_active_clusters
    )
    return sched_params, launch_grid
@staticmethod
def _make_tma_store_atoms_and_tensors(
    tensor_c: cute.Tensor,
    epi_smem_layout_staged: cute.ComposedLayout,
    epi_tile: tuple[int, int],
) -> tuple[cute.CopyAtom, cute.Tensor]:
    """Create the TMA store (smem -> gmem) atom and TMA tensor for C.

    The atom is built from the layout of a single epilogue stage (stage 0
    of the staged smem layout).

    :param tensor_c: Output tensor C
    :type tensor_c: cute.Tensor
    :param epi_smem_layout_staged: Shared memory layout for epilogue
    :type epi_smem_layout_staged: cute.ComposedLayout
    :param epi_tile: Epilogue tile shape
    :type epi_tile: Tuple[int, int]
    :return: TMA atom and tensor for C
    :rtype: Tuple[cute.CopyAtom, cute.Tensor]
    """
    store_op = cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp()
    single_stage_layout = cute.slice_(epi_smem_layout_staged, (None, None, 0))
    store_atom, store_tensor = cute.nvgpu.cpasync.make_tiled_tma_atom(
        store_op,
        tensor_c,
        single_stage_layout,
        epi_tile,
    )
    return store_atom, store_tensor
@staticmethod
def _make_tma_atoms_and_tensors(
    tensor: cute.Tensor,
    smem_layout_staged: cute.ComposedLayout,
    smem_tile: tuple[int, int],
    mcast_dim: int,
) -> tuple[cute.CopyAtom, cute.Tensor]:
    """Create the TMA load (gmem -> smem) atom and TMA tensor for A or B.

    A plain G2S op is used when ``mcast_dim == 1``; otherwise the multicast
    G2S op is used with ``num_multicast=mcast_dim``. The atom is built from
    the stage-0 slice of the staged smem layout.

    :param tensor: Input tensor (A or B)
    :type tensor: cute.Tensor
    :param smem_layout_staged: Shared memory layout for the tensor
    :type smem_layout_staged: cute.ComposedLayout
    :param smem_tile: Shared memory tile shape
    :type smem_tile: Tuple[int, int]
    :param mcast_dim: Multicast dimension
    :type mcast_dim: int
    :return: TMA atom and tensor
    :rtype: Tuple[cute.CopyAtom, cute.Tensor]
    """
    if mcast_dim == 1:
        load_op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp()
    else:
        load_op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SMulticastOp()
    load_atom, load_tensor = cute.nvgpu.cpasync.make_tiled_tma_atom(
        load_op,
        tensor,
        cute.slice_(smem_layout_staged, (None, None, 0)),
        smem_tile,
        num_multicast=mcast_dim,
    )
    return load_atom, load_tensor
@staticmethod
def is_valid_dtypes(
    a_dtype: Type[cutlass.Numeric],
    b_dtype: Type[cutlass.Numeric],
    acc_dtype: Type[cutlass.Numeric],
    c_dtype: Type[cutlass.Numeric],
    a_major: str,
    b_major: str,
) -> bool:
    """
    Check if the dtypes are valid

    An accumulator type outside the supported set (Float32, Float16, Int32)
    is reported as invalid (``False``) instead of raising ``KeyError``.

    :param a_dtype: The data type of tensor A
    :type a_dtype: Type[cutlass.Numeric]
    :param b_dtype: The data type of tensor B
    :type b_dtype: Type[cutlass.Numeric]
    :param acc_dtype: The data type of the accumulator
    :type acc_dtype: Type[cutlass.Numeric]
    :param c_dtype: The data type of the output tensor
    :type c_dtype: Type[cutlass.Numeric]
    :param a_major: major mode of tensor A
    :type a_major: str
    :param b_major: major mode of tensor B
    :type b_major: str
    :return: True if the dtypes are valid, False otherwise
    :rtype: bool
    """
    is_valid = True
    valid_ab_dtypes = {
        cutlass.Float16,
        cutlass.Float8E4M3FN,
        cutlass.Float8E5M2,
        cutlass.Uint8,
        cutlass.Int8,
    }
    if a_dtype not in valid_ab_dtypes:
        is_valid = False
    if b_dtype not in valid_ab_dtypes:
        is_valid = False
    # make sure a_dtype == b_dtype for Float16
    if a_dtype.width == 16 and a_dtype != b_dtype:
        is_valid = False
    if a_dtype.width != b_dtype.width:
        is_valid = False
    if not a_dtype.is_same_kind(b_dtype):
        is_valid = False
    # for 8-bit types, this implementation only supports k-major layout
    if (a_dtype.width == 8 and a_major != "k") or (
        b_dtype.width == 8 and b_major != "k"
    ):
        is_valid = False
    # Define compatibility mapping between accumulator type and AB type
    acc_ab_compatibility = {
        cutlass.Float32: {
            cutlass.Float16,
            cutlass.Float8E4M3FN,
            cutlass.Float8E5M2,
        },
        cutlass.Float16: {
            cutlass.Float16,
            cutlass.Float8E4M3FN,
            cutlass.Float8E5M2,
        },
        cutlass.Int32: {cutlass.Uint8, cutlass.Int8},
    }
    # Check compatibility between accumulator type and A type.
    # .get(): an unknown accumulator type is simply unsupported, not an error.
    if a_dtype not in acc_ab_compatibility.get(acc_dtype, set()):
        is_valid = False
    # Define compatibility mapping between accumulator type and C type
    acc_c_compatibility = {
        cutlass.Float32: {
            cutlass.Float32,
            cutlass.Float16,
            cutlass.Float8E4M3FN,
            cutlass.Float8E5M2,
        },
        cutlass.Float16: {
            cutlass.Float32,
            cutlass.Float16,
            cutlass.Float8E4M3FN,
            cutlass.Float8E5M2,
        },
        cutlass.Int32: {
            cutlass.Float32,
            cutlass.Float16,
            cutlass.Int32,
            cutlass.Int8,
            cutlass.Uint8,
        },
    }
    # Check compatibility between accumulator type and C type
    if c_dtype not in acc_c_compatibility.get(acc_dtype, set()):
        is_valid = False
    return is_valid
@staticmethod
def is_valid_tensor_alignment(
    m: int,
    n: int,
    k: int,
    l: int,
    ab_dtype: Type[cutlass.Numeric],
    c_dtype: Type[cutlass.Numeric],
    a_major: str,
    b_major: str,
    c_major: str,
) -> bool:
    """
    Check that every operand's contiguous dimension spans a multiple of 16 bytes.

    A is viewed as (m, k, l), B as (n, k, l) and C as (m, n, l). For each
    operand, the extent of its major (contiguous) mode must be divisible by
    the number of elements that fit in 16 bytes.

    :param m: Problem size M (rows of A and C)
    :type m: int
    :param n: Problem size N (rows of B, columns of C)
    :type n: int
    :param k: Problem size K (reduction extent of A and B)
    :type k: int
    :param l: Batch extent (trailing mode of A, B and C)
    :type l: int
    :param ab_dtype: The data type of the A and B operands
    :type ab_dtype: Type[cutlass.Numeric]
    :param c_dtype: The data type of the output tensor
    :type c_dtype: Type[cutlass.Numeric]
    :param a_major: The major axis of the A tensor ("m" or "k")
    :type a_major: str
    :param b_major: The major axis of the B tensor ("n" or "k")
    :type b_major: str
    :param c_major: The major axis of the C tensor ("m" or "n")
    :type c_major: str
    :return: True if all three contiguous dimensions are 16-byte aligned
    :rtype: bool
    """

    def _major_extent_is_16b_aligned(dtype, mode0_is_major, extents):
        elems_per_16_bytes = 16 * 8 // dtype.width
        major_extent = extents[0] if mode0_is_major else extents[1]
        return major_extent % elems_per_16_bytes == 0

    operand_specs = (
        (ab_dtype, a_major == "m", (m, k, l)),
        (ab_dtype, b_major == "n", (n, k, l)),
        (c_dtype, c_major == "m", (m, n, l)),
    )
    return all(_major_extent_is_16b_aligned(*spec) for spec in operand_specs)
# ---------------------------------------------------------------------------
# Helper functions for tensor creation (ported from blackwell/grouped_gemm.py)
# ---------------------------------------------------------------------------
def create_tensor_and_stride(
    l: int,
    mode0: int,
    mode1: int,
    is_mode0_major: bool,
    dtype: Type[cutlass.Numeric],
    is_dynamic_layout: bool = True,
    torch_tensor_cpu: "torch.Tensor" = None,
) -> tuple:
    """Create (or reuse) a host matrix and its device/cute counterparts.

    If ``torch_tensor_cpu`` is None, a fresh host matrix is generated with
    ``cutlass_torch.matrix``. It is then mirrored with
    ``cutlass_torch.cute_tensor_like`` (16-byte assumed alignment).

    :return: ``(data_ptr, torch_tensor, cute_tensor, torch_tensor_cpu,
        strides)``, where ``strides`` is the torch tensor's stride tuple
        without its last entry (the trailing batch mode).
    """
    host_tensor = (
        torch_tensor_cpu
        if torch_tensor_cpu is not None
        else cutlass_torch.matrix(l, mode0, mode1, is_mode0_major, dtype)
    )
    cute_view, device_tensor = cutlass_torch.cute_tensor_like(
        host_tensor, dtype, is_dynamic_layout, assumed_align=16
    )
    base_address = device_tensor.data_ptr()
    matrix_strides = device_tensor.stride()[:-1]
    return (base_address, device_tensor, cute_view, host_tensor, matrix_strides)
def create_tensors_for_all_groups(
    problem_sizes_mnkl: List[Tuple[int, int, int, int]],
    a_dtype: Type[cutlass.Numeric],
    b_dtype: Type[cutlass.Numeric],
    c_dtype: Type[cutlass.Numeric],
    a_major: str,
    b_major: str,
    c_major: str,
    torch_fp32_tensors_abc: List[List] = None,
) -> tuple:
    """Create A/B/C tensors for every group.

    When ``torch_fp32_tensors_abc`` is given, it must hold one ``[A, B, C]``
    host-tensor entry per group. Those host tensors are reused, and the same
    list object is returned. Otherwise new host tensors are generated and
    collected into a new list.

    :raises ValueError: If ``torch_fp32_tensors_abc`` has the wrong length.
    :return: ``(ptrs_abc, torch_tensors_abc, cute_tensors_abc, strides_abc,
        host_tensors_abc)``, each indexed by group.
    """
    reuse_host = torch_fp32_tensors_abc is not None
    if reuse_host and len(torch_fp32_tensors_abc) != len(problem_sizes_mnkl):
        raise ValueError("torch_fp32_tensors_abc must have one entry per group")

    host_tensors_abc = torch_fp32_tensors_abc if reuse_host else []
    ptrs_abc = []
    torch_tensors_abc = []
    cute_tensors_abc = []
    strides_abc = []

    for group_idx, (m, n, k, l) in enumerate(problem_sizes_mnkl):
        if torch_fp32_tensors_abc:
            group_host = torch_fp32_tensors_abc[group_idx]
            host_a, host_b, host_c = group_host[0], group_host[1], group_host[2]
        else:
            host_a = host_b = host_c = None

        ptr_a, dev_a, view_a, fp32_a, stride_a = create_tensor_and_stride(
            l, m, k, a_major == "m", a_dtype, torch_tensor_cpu=host_a
        )
        ptr_b, dev_b, view_b, fp32_b, stride_b = create_tensor_and_stride(
            l, n, k, b_major == "n", b_dtype, torch_tensor_cpu=host_b
        )
        ptr_c, dev_c, view_c, fp32_c, stride_c = create_tensor_and_stride(
            l, m, n, c_major == "m", c_dtype, torch_tensor_cpu=host_c
        )

        if not reuse_host:
            host_tensors_abc.append([fp32_a, fp32_b, fp32_c])
        ptrs_abc.append([ptr_a, ptr_b, ptr_c])
        torch_tensors_abc.append([dev_a, dev_b, dev_c])
        strides_abc.append([stride_a, stride_b, stride_c])
        cute_tensors_abc.append((view_a, view_b, view_c))

    return (
        ptrs_abc,
        torch_tensors_abc,
        cute_tensors_abc,
        strides_abc,
        host_tensors_abc,
    )
def create_group_metadata(
    problem_sizes_mnkl: List[Tuple[int, int, int, int]],
    a_major: str,
    b_major: str,
    c_major: str,
) -> tuple[list[list[int]], list[list[tuple[int, int]]]]:
    """Build per-group pointer and stride metadata without allocating tensors.

    Every pointer is a placeholder 0. Strides are those of a packed 2-D
    matrix: unit stride on the major mode, and the major extent on the other
    mode.

    :return: ``(ptrs_abc, strides_abc)``. ``ptrs_abc[g]`` is ``[0, 0, 0]``
        and ``strides_abc[g]`` is ``[stride_A, stride_B, stride_C]``.
    """

    def _packed_strides(extent0: int, extent1: int, mode0_major: bool) -> tuple[int, int]:
        # Intended to match the layout produced by
        # cutlass_torch.matrix(...).permute(...).
        if mode0_major:
            return (1, extent0)
        return (extent1, 1)

    per_group = [
        (
            [0, 0, 0],
            [
                _packed_strides(m, k, a_major == "m"),
                _packed_strides(n, k, b_major == "n"),
                _packed_strides(m, n, c_major == "m"),
            ],
        )
        for m, n, k, _ in problem_sizes_mnkl
    ]
    ptrs_abc = [ptrs for ptrs, _strides in per_group]
    strides_abc = [strides for _ptrs, strides in per_group]
    return ptrs_abc, strides_abc
def _to_reference_operand_fp32(
    tensor: "torch.Tensor", dtype: Type[cutlass.Numeric]
) -> "torch.Tensor":
    """Return a host-side fp32 copy of an operand for the reference GEMM.

    FP8 operands are held as int8 bit patterns by ``cutlass_torch.matrix``.
    They are therefore reinterpreted (``view``) as the matching torch FP8
    dtype before the numeric cast; all other dtypes are cast directly.
    """
    host_tensor = tensor.cpu()
    if dtype == cutlass.Float8E4M3FN:
        host_tensor = host_tensor.view(torch.float8_e4m3fn)
    elif dtype == cutlass.Float8E5M2:
        host_tensor = host_tensor.view(torch.float8_e5m2)
    return host_tensor.to(dtype=torch.float32)
def run(
num_groups: int,
problem_sizes_mnkl: List[Tuple[int, int, int, int]],
a_dtype: Type[cutlass.Numeric],
b_dtype: Type[cutlass.Numeric],
c_dtype: Type[cutlass.Numeric],
acc_dtype: Type[cutlass.Numeric],
a_major: str,
b_major: str,
c_major: str,
tile_shape_mn: Tuple[int, int],
cluster_shape_mn: Tuple[int, int],
tensormap_update_mode: utils.TensorMapUpdateMode = utils.TensorMapUpdateMode.SMEM,
tolerance: float = 1e-01,
warmup_iterations: int = 0,
iterations: int = 1,
skip_ref_check: bool = False,
use_cold_l2: bool = False,
**kwargs,
):
"""Prepare per-group tensors, compile, launch, and validate the Hopper grouped GEMM kernel.
:return: Execution time in microseconds.
:rtype: float
"""
print("Running Hopper Grouped GEMM test with:")
print(f"{num_groups} groups")
for i, (m, n, k, l) in enumerate(problem_sizes_mnkl):
print(f"Group {i}: {m}x{n}x{k}x{l}")
print(
f"A dtype: {a_dtype}, B dtype: {b_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}"
)
print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}")
print(f"Tile Shape: {tile_shape_mn}, Cluster Shape: {cluster_shape_mn}")
print(f"Tensor map update mode: {tensormap_update_mode}")
print(f"Tolerance: {tolerance}")
print(f"Warmup iterations: {warmup_iterations}")
print(f"Iterations: {iterations}")
print(f"Skip reference checking: {skip_ref_check}")
print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}")
if not torch.cuda.is_available():
raise RuntimeError("GPU is required to run this example!")
# Validate dtypes (reuse existing static method, check each group)
for m, n, k, l in problem_sizes_mnkl:
if not HopperGroupedGemmPersistentKernel.is_valid_dtypes(
a_dtype, b_dtype, acc_dtype, c_dtype, a_major, b_major
):
raise TypeError(
f"unsupported dtype combination: A {a_dtype}, B {b_dtype}, "
f"Acc {acc_dtype}, C {c_dtype}, {a_major=}, {b_major=}"
)
if not HopperGroupedGemmPersistentKernel.is_valid_tensor_alignment(
m, n, k, l, a_dtype, c_dtype, a_major, b_major, c_major
):
raise TypeError(
f"Group {m}x{n}x{k}x{l}: contiguous dimension not 16-byte aligned"
)
compile_only = skip_ref_check and iterations <= 0
if compile_only:
ptrs_abc, strides_abc = create_group_metadata(
problem_sizes_mnkl, a_major, b_major, c_major
)
torch_tensors_abc = []
torch_fp32_tensors_abc = []
else:
# Create per-group tensors only when we will execute or validate.
(
ptrs_abc,
torch_tensors_abc,
_,
strides_abc,
torch_fp32_tensors_abc,
) = create_tensors_for_all_groups(
problem_sizes_mnkl, a_dtype, b_dtype, c_dtype, a_major, b_major, c_major
)
# Build small "initial" tensors that carry only dtype+majorness (used for TMA atom init)
alignment = 16
min_ab_size = alignment * 8 // a_dtype.width
min_c_size = alignment * 8 // c_dtype.width
initial_cute_tensors_abc = [
create_tensor_and_stride(1, min_ab_size, min_ab_size, a_major == "m", a_dtype)[2],
create_tensor_and_stride(1, min_ab_size, min_ab_size, b_major == "n", b_dtype)[2],
create_tensor_and_stride(1, min_c_size, min_c_size, c_major == "m", c_dtype)[2],
]
hardware_info = utils.HardwareInfo()
sm_count = hardware_info.get_max_active_clusters(1)
max_active_clusters = hardware_info.get_max_active_clusters(
cluster_shape_mn[0] * cluster_shape_mn[1]
)
# Tensor map workspace: (num_sms, 3, bytes_per_tensormap // 8) of Int64
tensormap_shape = (
sm_count,
HopperGroupedGemmPersistentKernel.num_tensormaps,
HopperGroupedGemmPersistentKernel.bytes_per_tensormap // 8,
)
tensor_of_tensormap, tensor_of_tensormap_torch = cutlass_torch.cute_tensor_like(
torch.empty(tensormap_shape, dtype=torch.int64),
cutlass.Int64,
is_dynamic_layout=False,
)
grouped_gemm = HopperGroupedGemmPersistentKernel(
acc_dtype, tile_shape_mn, cluster_shape_mn,
swizzle_size=1, raster_along_m=True,
tensormap_update_mode=tensormap_update_mode,
)
# Build device tensors for problem shapes, strides, and pointers
tensor_of_dim_size_mnkl, tensor_of_dim_size_mnkl_torch = cutlass_torch.cute_tensor_like(
torch.tensor(problem_sizes_mnkl, dtype=torch.int32),
cutlass.Int32,
is_dynamic_layout=False,
assumed_align=16,
)
tensor_of_strides_abc, tensor_of_strides_abc_torch = cutlass_torch.cute_tensor_like(
torch.tensor(strides_abc, dtype=torch.int32),
cutlass.Int32,
is_dynamic_layout=False,
assumed_align=16,
)
tensor_of_ptrs_abc, tensor_of_ptrs_abc_torch = cutlass_torch.cute_tensor_like(
torch.tensor(ptrs_abc, dtype=torch.int64),
cutlass.Int64,
is_dynamic_layout=False,
assumed_align=16,
)
# Compute total number of cluster tiles across all groups
def compute_total_num_clusters(
problem_sizes: List[Tuple[int, int, int, int]],
cluster_tile_shape_mn: Tuple[int, int],
) -> int:
total = 0
for m, n, _, _ in problem_sizes:
nm = (m + cluster_tile_shape_mn[0] - 1) // cluster_tile_shape_mn[0]
nn = (n + cluster_tile_shape_mn[1] - 1) // cluster_tile_shape_mn[1]
total += nm * nn
return total
# cluster tile shape for Hopper: tile_shape_mn * cluster_shape_mn
cluster_tile_shape_mn = (
tile_shape_mn[0] * cluster_shape_mn[0],
tile_shape_mn[1] * cluster_shape_mn[1],
)
total_num_clusters = compute_total_num_clusters(
problem_sizes_mnkl, cluster_tile_shape_mn
)
current_stream = cutlass_torch.default_stream()
# Compile kernel
_compiler = cute.compile
if os.environ.get("CUTE_DSL_KEEP_PTX"):
_compiler = cute.compile[cute.KeepPTX()]
compiled_grouped_gemm = _compiler(
grouped_gemm,
initial_cute_tensors_abc[0],
initial_cute_tensors_abc[1],
initial_cute_tensors_abc[2],
num_groups,
tensor_of_dim_size_mnkl,
tensor_of_strides_abc,
tensor_of_ptrs_abc,
total_num_clusters,
tensor_of_tensormap,
max_active_clusters,
current_stream,
)
if not skip_ref_check:
compiled_grouped_gemm(
initial_cute_tensors_abc[0],
initial_cute_tensors_abc[1],
initial_cute_tensors_abc[2],
tensor_of_dim_size_mnkl,
tensor_of_strides_abc,
tensor_of_ptrs_abc,
tensor_of_tensormap,
current_stream,
)
torch.cuda.synchronize()
for i, (a_t, b_t, c_t) in enumerate(torch_tensors_abc):
a_ref = _to_reference_operand_fp32(a_t, a_dtype)
b_ref = _to_reference_operand_fp32(b_t, b_dtype)
ref = torch.einsum(
"mkl,nkl->mnl",
a_ref,
b_ref,
)
print(f"Checking group {i}...")
torch.testing.assert_close(
c_t.cpu(),
ref.to(cutlass_torch.dtype(c_dtype)),
atol=tolerance,
rtol=1e-03,
)
if iterations <= 0:
return 0
def generate_tensors():
(
ptrs_abc_ws,
torch_tensors_abc_ws,
_,
strides_abc_ws,
__,
) = create_tensors_for_all_groups(
problem_sizes_mnkl, a_dtype, b_dtype, c_dtype, a_major, b_major, c_major,
torch_fp32_tensors_abc,
)
init_ws = [
create_tensor_and_stride(1, min_ab_size, min_ab_size, a_major == "m", a_dtype)[2],
create_tensor_and_stride(1, min_ab_size, min_ab_size, b_major == "n", b_dtype)[2],
create_tensor_and_stride(1, min_c_size, min_c_size, c_major == "m", c_dtype)[2],
]
strides_ws, _ = cutlass_torch.cute_tensor_like(
torch.tensor(strides_abc_ws, dtype=torch.int32),
cutlass.Int32, is_dynamic_layout=False, assumed_align=16,
)
ptrs_ws, _ = cutlass_torch.cute_tensor_like(
torch.tensor(ptrs_abc_ws, dtype=torch.int64),
cutlass.Int64, is_dynamic_layout=False, assumed_align=16,
)
tensormap_ws, _ = cutlass_torch.cute_tensor_like(
torch.empty(tensormap_shape, dtype=torch.int64),
cutlass.Int64, is_dynamic_layout=False,
)
args = testing.JitArguments(
init_ws[0], init_ws[1], init_ws[2],
tensor_of_dim_size_mnkl,
strides_ws, ptrs_ws, tensormap_ws,
current_stream,
)
return args
workspace_count = 1
if use_cold_l2:
one_workspace_bytes = sum(
t.numel() * t.element_size()
for group in torch_tensors_abc
for t in group
)
workspace_count = testing.get_workspace_count(
one_workspace_bytes, warmup_iterations, iterations
)
exec_time = testing.benchmark(
compiled_grouped_gemm,
workspace_generator=generate_tensors,
workspace_count=workspace_count,
stream=current_stream,
warmup_iterations=warmup_iterations,
iterations=iterations,
)
runtime_s = exec_time / 1.0e6
fmas = sum(m * n * k for m, n, k, _ in problem_sizes_mnkl)
gflops = (2 * fmas / 1.0e9) / runtime_s
print(f"Average Runtime : {exec_time / 1000:.3f} ms")
print(f"GFLOPS : {gflops:.1f}")
return exec_time
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def _parse_comma_separated_ints(s: str) -> tuple:
    """Convert a string such as ``"128,256"`` into ``(128, 256)``.

    Whitespace around each item is ignored. Any item that is not an integer
    raises ``argparse.ArgumentTypeError`` so argparse can report it cleanly.
    """
    try:
        values = [int(item.strip()) for item in s.split(",")]
    except ValueError:
        raise argparse.ArgumentTypeError("Expected comma-separated integers.")
    return tuple(values)
def _parse_problem_sizes(s: str) -> List[Tuple[int, ...]]:
    """Parse e.g. '(4096,4096,4096,1),(512,512,512,1)' into a list of tuples.

    Whitespace around the numbers, and around the comma that separates two
    tuples, is tolerated. For example, '(4096, 4096, 4096, 1), (512, 512, 512, 1)'
    parses to the same result. Tuple length is not checked here.

    Raises:
        argparse.ArgumentTypeError: if the string does not start with '(' or
            any tuple contains a value that is not an integer.
    """
    import re  # local import: only this CLI helper needs it

    s = s.strip()
    if not s.startswith("("):
        raise argparse.ArgumentTypeError(
            "Expected a list of tuples like '(M,N,K,L),(M,N,K,L)'"
        )
    # Drop the outer parentheses, then split on ")" "," "(" with optional
    # whitespace between them so "), (" is accepted as well as "),(".
    chunks = re.split(r"\)\s*,\s*\(", s.strip("()"))
    result = []
    for chunk in chunks:
        try:
            result.append(tuple(int(x.strip()) for x in chunk.split(",")))
        except ValueError:
            # Surface bad input as an argparse error naming the offending
            # tuple, rather than a bare int() ValueError.
            raise argparse.ArgumentTypeError(
                f"Invalid problem size tuple '({chunk})': "
                "expected comma-separated integers"
            ) from None
    return result
def _validate_problem_sizes_args(args, parser: argparse.ArgumentParser) -> None:
    """Check ``args.problem_sizes_mnkl`` for consistency with ``args.num_groups``.

    Each entry must be a 4-tuple (M, N, K, L) with L == 1. The number of
    entries must equal ``args.num_groups``; an empty list is also let through
    by this check. Every violation is reported through ``parser.error``,
    which prints usage and exits, rather than through a Python traceback.
    """
    if len(args.problem_sizes_mnkl) not in (0, args.num_groups):
        parser.error("--problem_sizes_mnkl must contain exactly --num_groups tuples")
    for size in args.problem_sizes_mnkl:
        # Check the arity first: unpacking a short or long tuple would raise
        # a ValueError traceback instead of a clean CLI error.
        if len(size) != 4:
            parser.error(
                f"each problem size must have 4 entries (M,N,K,L), got {tuple(size)}"
            )
        elif size[3] != 1:
            parser.error("l (batch size) must be 1 for all groups")
def _resolve_tensormap_update_mode(
    mode: str, parser: argparse.ArgumentParser
) -> utils.TensorMapUpdateMode:
    """Translate the ``--tensormap_update_mode`` string into the enum value.

    "GMEM" and "SMEM" map to the identically named members of
    ``utils.TensorMapUpdateMode``. Any other string is reported via
    ``parser.error``. The trailing SMEM return only matters if that call
    returns instead of exiting.
    """
    for name in ("GMEM", "SMEM"):
        if mode == name:
            return getattr(utils.TensorMapUpdateMode, name)
    parser.error("--tensormap_update_mode must be GMEM or SMEM")
    return utils.TensorMapUpdateMode.SMEM
if __name__ == "__main__":
    # Command-line entry point: declare the options, validate them, then
    # hand everything to run() positionally and print "PASS" on return.
    parser = argparse.ArgumentParser(
        description="Hopper Grouped GEMM (CuTe DSL)"
    )
    parser.add_argument(
        "--num_groups", type=int, default=1, help="Number of groups"
    )
    # One (M, N, K, L) tuple per group; _validate_problem_sizes_args below
    # checks the count against --num_groups and requires L == 1.
    parser.add_argument(
        "--problem_sizes_mnkl",
        type=_parse_problem_sizes,
        default=((4096, 4096, 4096, 1),),
        help="Problem sizes per group, e.g. '(4096,4096,4096,1),(512,512,512,1)'",
    )
    # Tile shape (M, N); run() multiplies it by --cluster_shape_mn to get the
    # cluster tile shape. Only the listed combinations are accepted.
    parser.add_argument(
        "--tile_shape_mn",
        type=_parse_comma_separated_ints,
        choices=[(128, 128), (128, 256), (128, 64), (64, 64)],
        default=(128, 128),
    )
    # Cluster shape (M, N), given as e.g. "2,1".
    parser.add_argument(
        "--cluster_shape_mn",
        type=_parse_comma_separated_ints,
        choices=[(1, 1), (2, 1), (1, 2), (2, 2)],
        default=(1, 1),
    )
    # String form; converted to utils.TensorMapUpdateMode after parsing.
    parser.add_argument(
        "--tensormap_update_mode",
        type=str,
        choices=["GMEM", "SMEM"],
        default="SMEM",
        help="Tensor map update mode",
    )
    # Element types for A, B, C and the accumulator.
    parser.add_argument("--a_dtype", type=cutlass.dtype, default=cutlass.Float16)
    parser.add_argument("--b_dtype", type=cutlass.dtype, default=cutlass.Float16)
    parser.add_argument("--c_dtype", type=cutlass.dtype, default=cutlass.Float16)
    parser.add_argument("--acc_dtype", type=cutlass.dtype, default=cutlass.Float32)
    # Major mode of each operand.
    parser.add_argument("--a_major", choices=["k", "m"], default="k")
    parser.add_argument("--b_major", choices=["k", "n"], default="k")
    parser.add_argument("--c_major", choices=["n", "m"], default="n")
    # Used by run() as the absolute tolerance (atol) of the reference check.
    parser.add_argument("--tolerance", type=float, default=1e-1)
    parser.add_argument("--warmup_iterations", type=int, default=0, help="Warmup iterations")
    parser.add_argument(
        "--iterations", type=int, default=1, help="Number of iterations to run the kernel"
    )
    parser.add_argument("--skip_ref_check", action="store_true", help="Skip reference checking")
    # When set, run() sizes the benchmark workspace count from the total
    # input bytes so consecutive iterations use different buffers.
    parser.add_argument("--use_cold_l2", action="store_true", default=False, help="Use cold L2")
    args = parser.parse_args()
    _validate_problem_sizes_args(args, parser)
    tensormap_update_mode = _resolve_tensormap_update_mode(
        args.tensormap_update_mode, parser
    )
    # Fixed torch RNG seed for reproducible runs (assumes the input tensors
    # are generated from torch's RNG; TODO confirm).
    torch.manual_seed(2025)
    # Positional call: argument order must match run()'s signature.
    run(
        args.num_groups,
        args.problem_sizes_mnkl,
        args.a_dtype,
        args.b_dtype,
        args.c_dtype,
        args.acc_dtype,
        args.a_major,
        args.b_major,
        args.c_major,
        args.tile_shape_mn,
        args.cluster_shape_mn,
        tensormap_update_mode,
        args.tolerance,
        args.warmup_iterations,
        args.iterations,
        args.skip_ref_check,
        args.use_cold_l2,
    )
    # NOTE(review): "PASS" is printed even with --skip_ref_check, when no
    # reference comparison ran; run() only raises on a mismatch when the
    # check is enabled.
    print("PASS")