# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import argparse
import sys
from typing import Tuple, Type

import cuda.bindings.driver as cuda

import cutlass
import cutlass.cute as cute
import cutlass.cute.testing as testing
import cutlass.utils as utils
import cutlass.pipeline as pipeline
import cutlass.utils.hopper_helpers as sm90_utils

"""
|
|
A high-performance batched dense GEMM (C = A * B) example for the NVIDIA Blackwell Geforce architecture
|
|
using CUTE DSL.
|
|
- Matrix A is MxKxL, L is batch dimension, A can be row-major("K") or column-major("M")
|
|
- Matrix B is NxKxL, L is batch dimension, B can be row-major("N") or column-major("K")
|
|
- Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M")
|
|
|
|
This GEMM kernel supports the following features:
|
|
- Utilizes Tensor Memory Access (TMA) for efficient memory operations
|
|
- Utilizes Blackwell MMA for matrix multiply-accumulate (MMA) operations
|
|
- Supports multi-stage pipeline to overlap computation and memory access
|
|
|
|
This GEMM works as follows:
|
|
1. Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations.
|
|
2. Perform matrix multiply-accumulate (MMA) operations using Blackwell MMA instruction.
|
|
3. Store results from registers (RMEM) to shared memory (SMEM), then to global memory (GMEM) with TMA operations.
|
|
|
|
Blackwell MMA instructions operate as follows:
|
|
- Read matrix A from registers
|
|
- Read matrix B from registers
|
|
- Perform MMA operation and store the result in Accumulator(register)
|
|
|
|
To run this example:
|
|
|
|
.. code-block:: bash
|
|
|
|
python examples/blackwell_geforce/dense_gemm.py \
|
|
--mnkl 8192,8192,8192,1 --tile_shape_mnk 128,256,64 \
|
|
--a_dtype Float16 --b_dtype Float16 \
|
|
--c_dtype Float16 --acc_dtype Float32 \
|
|
--a_major k --b_major k --c_major n
|
|
|
|
The above example command compute batched gemm with M=8192, N=8192, K=8192,
|
|
batch_count=1. The tile shape is 128x256x64 and the cluster shape is (1,1).
|
|
The input, mma accumulator and output data type are set as fp16, fp32
|
|
and fp16, respectively.
|
|
|
|
To collect performance with NCU profiler:
|
|
|
|
.. code-block:: bash
|
|
|
|
ncu python examples/blackwell_geforce/dense_gemm.py \
|
|
--mnkl 8192,8192,8192,1 --tile_shape_mnk 128,256,64 \
|
|
--a_dtype Float16 --b_dtype Float16 \
|
|
--c_dtype Float16 --acc_dtype Float32 \
|
|
--a_major k --b_major k --c_major n
|
|
|
|
Constraints:
|
|
* Supported input data types: fp16, bf16
|
|
* For fp16 types, A and B must have the same data type
|
|
* Only fp32 accumulation is supported in this example
|
|
* CTA tile shape M must be 64/128
|
|
* CTA tile shape N must be 64/128/256
|
|
* CTA tile shape K must be 64
|
|
* Cluster shape M/N must be positive and power of 2, total cluster size <= 4
|
|
* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned,
|
|
i.e, number of elements is a multiple of 8, 16 for Float16, respectively.
|
|
* OOB tiles are not allowed when TMA store is disabled
|
|
"""
|
|
|
|
|
|
# /////////////////////////////////////////////////////////////////////////////
# Helpers to parse args
# /////////////////////////////////////////////////////////////////////////////
def parse_comma_separated_ints(s: str):
    try:
        return tuple([int(x.strip()) for x in s.split(",")])
    except ValueError:
        raise argparse.ArgumentTypeError(
            "Invalid format. Expected comma-separated integers."
        )
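
# Example: parse_comma_separated_ints("128,256,64") -> (128, 256, 64)
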
def parse_arguments() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Example of MxNxKxL GEMM on Blackwell GeForce.")

    parser.add_argument(
        "--mnkl",
        type=parse_comma_separated_ints,
        default=(4096, 4096, 4096, 1),
        help="mnkl dimensions (comma-separated)",
    )
    parser.add_argument(
        "--tile_shape_mnk",
        type=parse_comma_separated_ints,
        choices=[
            (64, 64, 64),
            (64, 128, 64),
            (128, 64, 64),
            (128, 128, 64),
            (128, 256, 64),
            (128, 128, 128),
        ],
        default=(64, 64, 64),
        help="CTA tile shape (comma-separated)",
    )
    parser.add_argument(
        "--a_dtype",
        type=cutlass.dtype,
        default=cutlass.Float16,
    )
    parser.add_argument(
        "--b_dtype",
        type=cutlass.dtype,
        default=cutlass.Float16,
    )
    parser.add_argument(
        "--c_dtype",
        type=cutlass.dtype,
        default=cutlass.Float16,
    )
    parser.add_argument(
        "--acc_dtype",
        type=cutlass.dtype,
        default=cutlass.Float32,
    )
    parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k")
    parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k")
    parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n")
    parser.add_argument(
        "--tolerance", type=float, default=1e-01, help="Tolerance for validation"
    )
    parser.add_argument(
        "--warmup_iterations", type=int, default=0, help="Warmup iterations"
    )
    parser.add_argument(
        "--iterations",
        type=int,
        default=1,
        help="Number of iterations to run the kernel",
    )
    parser.add_argument(
        "--skip_ref_check",
        action="store_true",
        default=False,
        help="Skip reference checking",
    )
    parser.add_argument(
        "--use_cold_l2",
        action="store_true",
        default=False,
        help="Use circular buffer tensor sets to ensure L2 cold cache",
    )

    args = parser.parse_args()

    if len(args.mnkl) != 4:
        parser.error("--mnkl must contain exactly 4 values")

    return args


# /////////////////////////////////////////////////////////////////////////////
# Host setup and device kernel launch
# /////////////////////////////////////////////////////////////////////////////


class Sm120GemmKernel:
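    """Persistent batched dense GEMM (C = A * B) kernel for Blackwell GeForce
    (SM120), written in the CuTe DSL.

    One DMA warp streams A/B tiles from GMEM to SMEM with TMA through a
    multi-stage mbarrier pipeline; the MMA warps consume the tiles with
    warp-level MMA instructions and stage the output through SMEM before
    TMA-storing C back to GMEM.
    """
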
    def __init__(
        self,
        acc_dtype,
        tile_shape_mnk,
    ):
        self.acc_dtype = acc_dtype
        self.cluster_shape_mnk = (1, 1, 1)
        self.tile_shape_mnk = tuple(tile_shape_mnk)
        self.tiled_mma = None
        self.num_mcast_ctas_a = None
        self.num_mcast_ctas_b = None
        self.is_a_mcast = False
        self.is_b_mcast = False

        self.occupancy = 1
        # TODO: replace this hardcoded value with user input?
        self.atom_layout = (2, 2, 1)
        self.num_mma_warps = (
            self.atom_layout[0] * self.atom_layout[1] * self.atom_layout[2]
        )
        self.num_threads_per_warp = 32
        self.threads_per_cta = (
            self.num_mma_warps + 1  # 1 warp for DMA
        ) * self.num_threads_per_warp
        self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_120")

        self.ab_stage = None
        self.epi_stage = None

        self.a_smem_layout_staged = None
        self.b_smem_layout_staged = None
        self.epi_smem_layout_staged = None
        self.epi_tile = None

        self.shared_storage = None
        self.buffer_align_bytes = 1024

        self.epilog_sync_barrier = pipeline.NamedBarrier(
            barrier_id=2,
            num_threads=self.num_mma_warps * self.num_threads_per_warp,
        )
        self.load_register_requirement = 40
        self.mma_register_requirement = 232

    def _setup_attributes(self):
        # TODO: replace this hardcoded value with user input?
        self.mma_inst_mnk = (16, 8, 16)
        op = cute.nvgpu.warp.MmaF16BF16Op(
            self.a_dtype,
            self.acc_dtype,
            self.mma_inst_mnk,
        )
        tC = cute.make_layout(self.atom_layout)
        # With atom_layout (2, 2, 1) and a 16x8x16 MMA instruction, this yields a
        # 32x32x16 tiled-MMA shape: the N extent is doubled so each atom covers
        # two 8-wide instruction tiles.
        permutation_mnk = (
            self.atom_layout[0] * self.mma_inst_mnk[0],
            # TODO: to leverage ldmatrix.x4, when self.atom_layout[1] is 1, the MMA tile is ((8x16)x2)
            self.atom_layout[1] * self.mma_inst_mnk[1] * 2,
            self.atom_layout[2] * self.mma_inst_mnk[2],
        )
        self.tiled_mma = cute.make_tiled_mma(
            op,
            tC,
            permutation_mnk=permutation_mnk,
        )

        self.cta_layout_mnk = cute.make_layout(self.cluster_shape_mnk)

        self.num_mcast_ctas_a = self.cluster_shape_mnk[1]
        self.num_mcast_ctas_b = self.cluster_shape_mnk[0]
        self.is_a_mcast = self.num_mcast_ctas_a > 1
        self.is_b_mcast = self.num_mcast_ctas_b > 1

        self.epi_tile = sm90_utils.compute_tile_shape_or_override(
            self.tile_shape_mnk, self.c_dtype, is_cooperative=False
        )

        # Compute stage counts before computing the smem layouts
        self.ab_stage, self.epi_stage = self._compute_stages(
            self.tile_shape_mnk,
            self.a_dtype,
            self.b_dtype,
            self.epi_tile,
            self.c_dtype,
            self.smem_capacity,
            self.occupancy,
        )

        if self.ab_stage == 0:
            print("ab_stage == 0, not enough shared memory. This case will be skipped.")
            sys.exit(0)

        (
            self.a_smem_layout_staged,
            self.b_smem_layout_staged,
            self.epi_smem_layout_staged,
        ) = self._make_smem_layouts(
            self.tile_shape_mnk,
            self.epi_tile,
            self.a_dtype,
            self.a_layout,
            self.b_dtype,
            self.b_layout,
            self.ab_stage,
            self.c_dtype,
            self.c_layout,
            self.epi_stage,
        )

    @cute.jit
    def __call__(
        self,
        a: cute.Tensor,
        b: cute.Tensor,
        c: cute.Tensor,
        max_active_clusters: cutlass.Constexpr,
        stream: cuda.CUstream,
    ):
        """Execute the GEMM operation in steps:
        - Setup static attributes
        - Setup TMA load/store atoms and tensors
        - Compute grid size
        - Define shared storage for kernel
        - Launch the kernel synchronously

        :param a: Input tensor A
        :type a: cute.Tensor
        :param b: Input tensor B
        :type b: cute.Tensor
        :param c: Output tensor C
        :type c: cute.Tensor
        :param max_active_clusters: Maximum number of active clusters
        :type max_active_clusters: cutlass.Constexpr
        :param stream: CUDA stream for asynchronous execution
        :type stream: cuda.CUstream
        """

        # Setup static attributes before smem/grid/tma computation
        self.a_dtype = a.element_type
        self.b_dtype = b.element_type
        self.c_dtype = c.element_type

        self.a_layout = utils.LayoutEnum.from_tensor(a)
        self.b_layout = utils.LayoutEnum.from_tensor(b)
        self.c_layout = utils.LayoutEnum.from_tensor(c)

        if cutlass.const_expr(
            self.a_dtype.width == 16 and self.a_dtype != self.b_dtype
        ):
            raise TypeError(f"Type mismatch: {self.a_dtype} != {self.b_dtype}")
        if cutlass.const_expr(self.a_dtype.width != self.b_dtype.width):
            raise TypeError(f"Type mismatch: {self.a_dtype} != {self.b_dtype}")
        if cutlass.const_expr(self.a_dtype.width != 16 and self.a_dtype.width != 8):
            raise TypeError("a_dtype should be float16 or float8")
        if cutlass.const_expr(self.b_dtype.width != 16 and self.b_dtype.width != 8):
            raise TypeError("b_dtype should be float16 or float8")

        self._setup_attributes()

        tma_atom_a, tma_tensor_a = self._make_tma_atoms_and_tensors(
            a,
            self.a_smem_layout_staged,
            (self.tile_shape_mnk[0], self.tile_shape_mnk[2]),
            1,
        )

        tma_atom_b, tma_tensor_b = self._make_tma_atoms_and_tensors(
            b,
            self.b_smem_layout_staged,
            (self.tile_shape_mnk[1], self.tile_shape_mnk[2]),
            1,
        )

        tma_atom_c, tma_tensor_c = self._make_tma_store_atoms_and_tensors(
            c,
            self.epi_smem_layout_staged,
            self.epi_tile,
        )

        tile_sched_params, grid = self._compute_grid(
            c,
            self.tile_shape_mnk,
            max_active_clusters,
        )

        @cute.struct
        class SharedStorage:
            mainloop_pipeline_array_ptr: cute.struct.MemRange[
                cutlass.Int64, self.ab_stage * 2
            ]
            sA: cute.struct.Align[
                cute.struct.MemRange[
                    self.a_dtype, cute.cosize(self.a_smem_layout_staged)
                ],
                self.buffer_align_bytes,
            ]
            sB: cute.struct.Align[
                cute.struct.MemRange[
                    self.b_dtype, cute.cosize(self.b_smem_layout_staged)
                ],
                self.buffer_align_bytes,
            ]
            sC: cute.struct.Align[
                cute.struct.MemRange[
                    self.c_dtype, cute.cosize(self.epi_smem_layout_staged)
                ],
                self.buffer_align_bytes,
            ]

        self.shared_storage = SharedStorage

        # Launch the kernel synchronously
        self.kernel(
            tma_atom_a,
            tma_tensor_a,
            tma_atom_b,
            tma_tensor_b,
            tma_atom_c,
            tma_tensor_c,
            self.tiled_mma,
            self.cta_layout_mnk,
            self.a_smem_layout_staged,
            self.b_smem_layout_staged,
            self.epi_smem_layout_staged,
            tile_sched_params,
        ).launch(
            grid=grid,
            block=[self.threads_per_cta, 1, 1],
            cluster=[1, 1, 1],
            stream=stream,
        )
        return

    # GPU device kernel
    @cute.kernel
    def kernel(
        self,
        tma_atom_a: cute.CopyAtom,
        mA_mkl: cute.Tensor,
        tma_atom_b: cute.CopyAtom,
        mB_nkl: cute.Tensor,
        tma_atom_c: cute.CopyAtom,
        mC_mnl: cute.Tensor,
        tiled_mma: cute.TiledMma,
        cta_layout_mnk: cute.Layout,
        a_smem_layout_staged: cute.ComposedLayout,
        b_smem_layout_staged: cute.ComposedLayout,
        epi_smem_layout_staged: cute.ComposedLayout,
        tile_sched_params: utils.PersistentTileSchedulerParams,
    ):
        """
        GPU device kernel performing the batched GEMM computation.

        :param tma_atom_a: TMA copy atom for A tensor
        :type tma_atom_a: cute.CopyAtom
        :param mA_mkl: Input tensor A
        :type mA_mkl: cute.Tensor
        :param tma_atom_b: TMA copy atom for B tensor
        :type tma_atom_b: cute.CopyAtom
        :param mB_nkl: Input tensor B
        :type mB_nkl: cute.Tensor
        :param tma_atom_c: TMA copy atom for C tensor
        :type tma_atom_c: cute.CopyAtom
        :param mC_mnl: Output tensor C
        :type mC_mnl: cute.Tensor
        :param tiled_mma: Tiled MMA object
        :type tiled_mma: cute.TiledMma
        :param cta_layout_mnk: CTA layout
        :type cta_layout_mnk: cute.Layout
        :param a_smem_layout_staged: Shared memory layout for A
        :type a_smem_layout_staged: cute.ComposedLayout
        :param b_smem_layout_staged: Shared memory layout for B
        :type b_smem_layout_staged: cute.ComposedLayout
        :param epi_smem_layout_staged: Shared memory layout for epilogue
        :type epi_smem_layout_staged: cute.ComposedLayout
        :param tile_sched_params: Persistent tile scheduler parameters
        :type tile_sched_params: utils.PersistentTileSchedulerParams
        """

        # ///////////////////////////////////////////////////////////////////////////////
        # Get cta/warp/thread idx
        # ///////////////////////////////////////////////////////////////////////////////
        tidx, _, _ = cute.arch.thread_idx()
        warp_idx = cute.arch.warp_idx()
        warp_idx = cute.arch.make_warp_uniform(warp_idx)
        # bidx, bidy, bidz = cute.arch.block_idx()
        # bdimx, bdimy, bdimz = cute.arch.grid_dim()

        # /////////////////////////////////////////////////////////////////////////////
        # Prefetch TMA descriptors
        # /////////////////////////////////////////////////////////////////////////////
        if warp_idx == 0:
            cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_a)
            cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_b)
            cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_c)

        cta_rank_in_cluster = cute.arch.make_warp_uniform(
            cute.arch.block_idx_in_cluster()
        )
        cluster_coord_mnk = cta_layout_mnk.get_flat_coord(cta_rank_in_cluster)

        # ///////////////////////////////////////////////////////////////////////////////
        # Get mcast mask
        # ///////////////////////////////////////////////////////////////////////////////
        a_mcast_mask = cute.make_layout_image_mask(
            cta_layout_mnk, cluster_coord_mnk, mode=1
        )
        b_mcast_mask = cute.make_layout_image_mask(
            cta_layout_mnk, cluster_coord_mnk, mode=0
        )

        a_mcast_mask = a_mcast_mask if self.is_a_mcast else 0
        b_mcast_mask = b_mcast_mask if self.is_b_mcast else 0
        a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, 0))
        b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, 0))
        # Bytes of one A stage plus one B stage; used as the TMA transaction count
        tma_copy_bytes = cute.size_in_bytes(
            self.a_dtype, a_smem_layout
        ) + cute.size_in_bytes(self.b_dtype, b_smem_layout)

        # /////////////////////////////////////////////////////////////////////////////
        # Alloc and init AB full/empty + ACC full mbar (pipeline)
        # /////////////////////////////////////////////////////////////////////////////
        smem = cutlass.utils.SmemAllocator()
        storage = smem.allocate(self.shared_storage)

        # mbar arrays
        mainloop_pipeline_array_ptr = storage.mainloop_pipeline_array_ptr.data_ptr()

        # Threads/warps participating in this pipeline
        mainloop_pipeline_producer_group = pipeline.CooperativeGroup(
            pipeline.Agent.Thread
        )
        # Each warp contributes to the arrive count with the mcast size
        mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
        consumer_arrive_cnt = mcast_size * self.num_mma_warps
        mainloop_pipeline_consumer_group = pipeline.CooperativeGroup(
            pipeline.Agent.Thread, consumer_arrive_cnt
        )

        cta_layout_vmnk = cute.make_layout((1, *cta_layout_mnk.shape))
        mainloop_pipeline = pipeline.PipelineTmaAsync.create(
            num_stages=self.ab_stage,
            producer_group=mainloop_pipeline_producer_group,
            consumer_group=mainloop_pipeline_consumer_group,
            tx_count=tma_copy_bytes,
            barrier_storage=mainloop_pipeline_array_ptr,
            cta_layout_vmnk=cta_layout_vmnk,
        )

        # Cluster arrive after barrier init
        if cute.size(self.cluster_shape_mnk) > 1:
            cute.arch.cluster_arrive_relaxed()

        # ///////////////////////////////////////////////////////////////////////////////
        # Generate smem tensors A/B/C
        # ///////////////////////////////////////////////////////////////////////////////
        sA = storage.sA.get_tensor(
            a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner
        )
        sB = storage.sB.get_tensor(
            b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner
        )
        sC = storage.sC.get_tensor(
            epi_smem_layout_staged.outer, swizzle=epi_smem_layout_staged.inner
        )

        # ///////////////////////////////////////////////////////////////////////////////
        # Local_tile partition global tensors
        # ///////////////////////////////////////////////////////////////////////////////
        # (bM, bK, loopM, loopK, loopL)
        gA_mkl = cute.local_tile(
            mA_mkl,
            cute.slice_(self.tile_shape_mnk, (None, 0, None)),
            (None, None, None),
        )
        # (bN, bK, loopN, loopK, loopL)
        gB_nkl = cute.local_tile(
            mB_nkl,
            cute.slice_(self.tile_shape_mnk, (0, None, None)),
            (None, None, None),
        )
        # (bM, bN, loopM, loopN, loopL)
        gC_mnl = cute.local_tile(
            mC_mnl,
            cute.slice_(self.tile_shape_mnk, (None, None, 0)),
            (None, None, None),
        )

        # //////////////////////////////////////////////////////////////////////////////
        # Partition global tensor for TiledMMA_A/B/C
        # //////////////////////////////////////////////////////////////////////////////
        thr_mma = tiled_mma.get_slice(tidx)

        # //////////////////////////////////////////////////////////////////////////////
        # Partition shared tensor for TMA load A/B
        # //////////////////////////////////////////////////////////////////////////////
        # TMA load A partition_S/D
        a_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (0, None, 0)).shape)
        a_cta_crd = cluster_coord_mnk[1]
        tAsA, tAgA = cute.nvgpu.cpasync.tma_partition(
            tma_atom_a,
            a_cta_crd,
            a_cta_layout,
            cute.group_modes(sA, 0, 2),
            cute.group_modes(gA_mkl, 0, 2),
        )

        # TMA load B partition_S/D
        b_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (None, 0, 0)).shape)
        b_cta_crd = cluster_coord_mnk[0]
        tBsB, tBgB = cute.nvgpu.cpasync.tma_partition(
            tma_atom_b,
            b_cta_crd,
            b_cta_layout,
            cute.group_modes(sB, 0, 2),
            cute.group_modes(gB_nkl, 0, 2),
        )

        # Make fragments
        tCsA = thr_mma.partition_A(sA)
        tCsB = thr_mma.partition_B(sB)
        tCrA = tiled_mma.make_fragment_A(tCsA[None, None, None, 0])
        tCrB = tiled_mma.make_fragment_B(tCsB[None, None, None, 0])

        tCgC = thr_mma.partition_C(gC_mnl)
        acc_shape = tCgC.shape[:3]
        accumulators = cute.make_rmem_tensor(acc_shape, self.acc_dtype)

        # Cluster wait for barrier init
        if cute.size(self.cluster_shape_mnk) > 1:
            cute.arch.cluster_wait()
        else:
            pipeline.sync(barrier_id=1)

        k_tile_cnt = cute.size(gA_mkl, mode=[3])

        # Create the tile scheduler
        tile_sched = utils.StaticPersistentTileScheduler.create(
            tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()
        )
        work_tile = tile_sched.initial_work_tile_info()

        # Create the pipeline states for producer and consumer
        mainloop_producer_state = pipeline.make_pipeline_state(
            pipeline.PipelineUserType.Producer, self.ab_stage
        )
        mainloop_consumer_state = pipeline.make_pipeline_state(
            pipeline.PipelineUserType.Consumer, self.ab_stage
        )

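        # The mainloop below is pipelined at two levels: A/B tiles cycle through
        # `ab_stage` SMEM buffers guarded by the TMA mbarrier pipeline above,
        # and within each stage the ldmatrix copies for k_block i+1 are issued
        # before the MMAs for k_block i, double-buffering operands in registers.
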
        # MMA warp group
        if warp_idx < self.num_mma_warps:
            cute.arch.setmaxregister_increase(self.mma_register_requirement)

            num_k_blocks = cute.size(tCrA, mode=[2])

            # ///////////////////////////////////////////////////////////////////////////////
            # Copy atom A/B retiling for ldmatrix loads of A/B
            # ///////////////////////////////////////////////////////////////////////////////
            atom_copy_ldmatrix_A = cute.make_copy_atom(
                cute.nvgpu.warp.LdMatrix8x8x16bOp(self.a_layout.is_m_major_a(), 4),
                self.a_dtype,
            )
            atom_copy_ldmatrix_B = cute.make_copy_atom(
                cute.nvgpu.warp.LdMatrix8x8x16bOp(self.b_layout.is_n_major_b(), 4),
                self.b_dtype,
            )
            smem_tiled_copy_A = cute.make_tiled_copy_A(atom_copy_ldmatrix_A, tiled_mma)

            smem_tiled_copy_B = cute.make_tiled_copy_B(atom_copy_ldmatrix_B, tiled_mma)

            thr_copy_ldmatrix_A = smem_tiled_copy_A.get_slice(tidx)
            thr_copy_ldmatrix_B = smem_tiled_copy_B.get_slice(tidx)
            tCsA_copy_view = thr_copy_ldmatrix_A.partition_S(sA)
            tCrA_copy_view = thr_copy_ldmatrix_A.retile(tCrA)

            tCsB_copy_view = thr_copy_ldmatrix_B.partition_S(sB)
            tCrB_copy_view = thr_copy_ldmatrix_B.retile(tCrB)

            while work_tile.is_valid_tile:
                tile_coord_mnl = work_tile.tile_idx
                gC_mnl_slice = gC_mnl[(None, None, *tile_coord_mnl)]
                # Clear the accumulator
                accumulators.fill(0.0)

                # /////////////////////////////////////////////////////////////////////////////
                # Pipelined MAINLOOP
                # /////////////////////////////////////////////////////////////////////////////

                mainloop_consumer_state.reset_count()

                peek_ab_full_status = cutlass.Boolean(1)
                if mainloop_consumer_state.count < k_tile_cnt:
                    peek_ab_full_status = mainloop_pipeline.consumer_try_wait(
                        mainloop_consumer_state
                    )

                # Wait for TMA copies to complete
                mainloop_pipeline.consumer_wait(
                    mainloop_consumer_state, peek_ab_full_status
                )
                # tCsA_p: (MMA, (4, MMA_M / 4), MMA_K), tCsB_p: (MMA, (4, MMA_N / 4), MMA_K)
                tCsA_p = tCsA_copy_view[None, None, None, mainloop_consumer_state.index]
                tCsB_p = tCsB_copy_view[None, None, None, mainloop_consumer_state.index]
                cute.copy(
                    smem_tiled_copy_A,
                    tCsA_p[None, None, 0],
                    tCrA_copy_view[None, None, 0],
                )
                cute.copy(
                    smem_tiled_copy_B,
                    tCsB_p[None, None, 0],
                    tCrB_copy_view[None, None, 0],
                )

                for k_tile in range(0, k_tile_cnt - 1, 1, unroll=1):
                    # Unroll the loop over k blocks
                    for k_block_idx in cutlass.range_constexpr(num_k_blocks):
                        k_block_next = (
                            0 if k_block_idx + 1 == num_k_blocks else k_block_idx + 1
                        )

                        if k_block_idx == num_k_blocks - 1:
                            mainloop_pipeline.consumer_release(mainloop_consumer_state)
                            mainloop_consumer_state.advance()

                            peek_ab_full_status = cutlass.Boolean(1)
                            peek_ab_full_status = mainloop_pipeline.consumer_try_wait(
                                mainloop_consumer_state
                            )

                            # tCsA_p: (MMA, (4, MMA_M / 4), MMA_K), tCsB_p: (MMA, (4, MMA_N / 4), MMA_K)
                            tCsA_p = tCsA_copy_view[
                                None, None, None, mainloop_consumer_state.index
                            ]
                            tCsB_p = tCsB_copy_view[
                                None, None, None, mainloop_consumer_state.index
                            ]
                            mainloop_pipeline.consumer_wait(
                                mainloop_consumer_state, peek_ab_full_status
                            )

                        # Copy data from smem to tCrA/tCrB for the next k_block
                        cute.copy(
                            smem_tiled_copy_A,
                            tCsA_p[None, None, k_block_next],
                            tCrA_copy_view[None, None, k_block_next],
                        )
                        cute.copy(
                            smem_tiled_copy_B,
                            tCsB_p[None, None, k_block_next],
                            tCrB_copy_view[None, None, k_block_next],
                        )
                        # GEMM on the current k_block
                        cute.gemm(
                            tiled_mma,
                            accumulators,
                            tCrA[None, None, k_block_idx],
                            tCrB[None, None, k_block_idx],
                            accumulators,
                        )
                # end of k_tile loop

                # Hoist out the last k_tile
                for k_block_idx in cutlass.range_constexpr(num_k_blocks):
                    k_block_next = (
                        0 if k_block_idx + 1 == num_k_blocks else k_block_idx + 1
                    )

                    if k_block_idx == num_k_blocks - 1:
                        mainloop_pipeline.consumer_release(mainloop_consumer_state)
                        mainloop_consumer_state.advance()

                    if k_block_next > 0:
                        cute.copy(
                            smem_tiled_copy_A,
                            tCsA_p[None, None, k_block_next],
                            tCrA_copy_view[None, None, k_block_next],
                        )
                        cute.copy(
                            smem_tiled_copy_B,
                            tCsB_p[None, None, k_block_next],
                            tCrB_copy_view[None, None, k_block_next],
                        )
                    # GEMM on the current k_block
                    cute.gemm(
                        tiled_mma,
                        accumulators,
                        tCrA[None, None, k_block_idx],
                        tCrB[None, None, k_block_idx],
                        accumulators,
                    )

                # /////////////////////////////////////////////////////////////////////////////
                # EPILOG
                # /////////////////////////////////////////////////////////////////////////////

                copy_atom_r2s = sm90_utils.sm90_get_smem_store_op(
                    self.c_layout,
                    elem_ty_d=self.c_dtype,
                    elem_ty_acc=self.acc_dtype,
                )

                copy_atom_C = cute.make_copy_atom(
                    cute.nvgpu.warp.StMatrix8x8x16bOp(
                        self.c_layout.is_m_major_c(),
                        4,
                    ),
                    self.c_dtype,
                )

                tiled_copy_C_Atom = cute.make_tiled_copy_C_atom(copy_atom_C, tiled_mma)

                tiled_copy_r2s = cute.make_tiled_copy_S(
                    copy_atom_r2s,
                    tiled_copy_C_Atom,
                )

                thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
                # (R2S, R2S_M, R2S_N, PIPE_D)
                tRS_sD = thr_copy_r2s.partition_D(sC)
                # (R2S, R2S_M, R2S_N)
                tRS_rAcc = tiled_copy_r2s.retile(accumulators)

                # Allocate D registers
                rD_shape = cute.shape(thr_copy_r2s.partition_S(sC))
                tRS_rD_layout = cute.make_layout(rD_shape[:3])
                tRS_rD = cute.make_rmem_tensor(tRS_rD_layout.shape, self.acc_dtype)
                size_tRS_rD = cute.size(tRS_rD)

                sepi_for_tma_partition = cute.group_modes(sC, 0, 2)
                tcgc_for_tma_partition = cute.zipped_divide(gC_mnl_slice, self.epi_tile)

                bSG_sD, bSG_gD = cute.nvgpu.cpasync.tma_partition(
                    tma_atom_c,
                    0,
                    cute.make_layout(1),
                    sepi_for_tma_partition,
                    tcgc_for_tma_partition,
                )

                epi_tile_num = cute.size(tcgc_for_tma_partition, mode=[1])
                epi_tile_shape = tcgc_for_tma_partition.shape[1]
                epi_tile_layout = cute.make_layout(
                    epi_tile_shape, stride=(1, epi_tile_shape[0])
                )

                # Initialize the TMA store pipeline
                tma_store_producer_group = pipeline.CooperativeGroup(
                    pipeline.Agent.Thread,
                    self.num_mma_warps * self.num_threads_per_warp,
                )
                tma_store_pipeline = pipeline.PipelineTmaStore.create(
                    num_stages=self.epi_stage,
                    producer_group=tma_store_producer_group,
                )

                for epi_idx in cutlass.range_constexpr(epi_tile_num):
                    # Copy from accumulators to D registers
                    for epi_v in cutlass.range_constexpr(size_tRS_rD):
                        tRS_rD[epi_v] = tRS_rAcc[epi_idx * size_tRS_rD + epi_v]

                    # Type conversion
                    tRS_rD_out = cute.make_rmem_tensor(
                        tRS_rD_layout.shape, self.c_dtype
                    )
                    acc_vec = tRS_rD.load()
                    tRS_rD_out.store(acc_vec.to(self.c_dtype))

                    # Register to shared memory
                    epi_buffer = epi_idx % cute.size(tRS_sD, mode=[3])
                    cute.copy(
                        tiled_copy_r2s,
                        tRS_rD_out,
                        tRS_sD[(None, None, None, epi_buffer)],
                    )

                    cute.arch.fence_proxy(
                        "async.shared",
                        space="cta",
                    )
                    # Barrier to sync the MMA warps before the TMA store
                    self.epilog_sync_barrier.arrive_and_wait()

                    # Get the global memory coordinate of the current epi tile
                    gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
                    # Copy from shared memory to global memory
                    if warp_idx == 0:
                        cute.copy(
                            tma_atom_c,
                            bSG_sD[(None, epi_buffer)],
                            bSG_gD[(None, gmem_coord)],
                        )
                        tma_store_pipeline.producer_commit()
                        tma_store_pipeline.producer_acquire()

                # Advance to the next work tile
                tile_sched.advance_to_next_work()
                work_tile = tile_sched.get_current_work()
                tma_store_pipeline.producer_tail()
            # End of while loop
        # End of MMA warp group
        # Start of DMA warp group
        elif warp_idx == self.num_mma_warps:
            cute.arch.setmaxregister_decrease(self.load_register_requirement)

            while work_tile.is_valid_tile:
                tile_coord_mnl = work_tile.tile_idx
                tAgA_mkl = tAgA[(None, tile_coord_mnl[0], None, tile_coord_mnl[2])]
                tBgB_nkl = tBgB[(None, tile_coord_mnl[1], None, tile_coord_mnl[2])]

                mainloop_producer_state.reset_count()

                for k_tile in range(0, k_tile_cnt, 1, unroll=1):
                    # /////////////////////////////////////////////////////////////////////////////
                    # Wait for A/B buffers to be empty before loading into them
                    # Also sets the transaction barrier for the A/B buffers
                    # /////////////////////////////////////////////////////////////////////////////
                    mainloop_pipeline.producer_acquire(mainloop_producer_state)

                    # /////////////////////////////////////////////////////////////////////////////
                    # Slice the global/shared memrefs to the current k_tile
                    # /////////////////////////////////////////////////////////////////////////////
                    tAgA_k = tAgA_mkl[(None, mainloop_producer_state.count)]
                    tAsA_pipe = tAsA[(None, mainloop_producer_state.index)]

                    tBgB_k = tBgB_nkl[(None, mainloop_producer_state.count)]
                    tBsB_pipe = tBsB[(None, mainloop_producer_state.index)]

                    # /////////////////////////////////////////////////////////////////////////////
                    # TMA load A/B
                    # /////////////////////////////////////////////////////////////////////////////
                    cute.copy(
                        tma_atom_a,
                        tAgA_k,
                        tAsA_pipe,
                        tma_bar_ptr=mainloop_pipeline.producer_get_barrier(
                            mainloop_producer_state
                        ),
                        mcast_mask=a_mcast_mask,
                    )
                    cute.copy(
                        tma_atom_b,
                        tBgB_k,
                        tBsB_pipe,
                        tma_bar_ptr=mainloop_pipeline.producer_get_barrier(
                            mainloop_producer_state
                        ),
                        mcast_mask=b_mcast_mask,
                    )
                    # The mainloop pipeline's producer commit is a NOP
                    mainloop_pipeline.producer_commit(mainloop_producer_state)
                    mainloop_producer_state.advance()

                tile_sched.advance_to_next_work()
                work_tile = tile_sched.get_current_work()
            # End of while loop

            # Wait for the A/B buffers to drain
            mainloop_pipeline.producer_tail(mainloop_producer_state)
        return

    @staticmethod
    def _compute_stages(
        tile_shape_mnk: tuple[int, int, int],
        a_dtype: type[cutlass.Numeric],
        b_dtype: type[cutlass.Numeric],
        epi_tile: tuple[int, int],
        c_dtype: type[cutlass.Numeric],
        smem_capacity: int,
        occupancy: int,
    ) -> tuple[int, int]:
        """Computes the number of stages for the A/B/C operands based on heuristics.

        :param tile_shape_mnk: The shape (M, N, K) of the CTA tile.
        :type tile_shape_mnk: tuple[int, int, int]
        :param a_dtype: Data type of operand A.
        :type a_dtype: type[cutlass.Numeric]
        :param b_dtype: Data type of operand B.
        :type b_dtype: type[cutlass.Numeric]
        :param epi_tile: The epilogue tile shape.
        :type epi_tile: tuple[int, int]
        :param c_dtype: Data type of operand C.
        :type c_dtype: type[cutlass.Numeric]
        :param smem_capacity: Total available shared memory capacity in bytes.
        :type smem_capacity: int
        :param occupancy: Target number of CTAs per SM (occupancy).
        :type occupancy: int

        :return: A tuple containing the computed number of stages for:
            (A/B operand stages, epilogue stages)
        :rtype: tuple[int, int]
        """
        epi_stage = 8
        c_bytes_per_stage = cute.size(epi_tile) * c_dtype.width // 8
        epi_bytes = c_bytes_per_stage * epi_stage

        a_shape = cute.slice_(tile_shape_mnk, (None, 0, None))
        b_shape = cute.slice_(tile_shape_mnk, (0, None, None))
        ab_bytes_per_stage = (
            cute.size(a_shape) * a_dtype.width // 8
            + cute.size(b_shape) * b_dtype.width // 8
        )
        mbar_helpers_bytes = 1024

        ab_stage = (
            (smem_capacity - occupancy * 1024) // occupancy
            - mbar_helpers_bytes
            - epi_bytes
        ) // ab_bytes_per_stage
        return ab_stage, epi_stage
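
    # A worked example of the heuristic above (a sketch, assuming ~100 KB of
    # SMEM on sm_120 and a hypothetical 64x32 epilogue tile) for the default
    # 64x64x64 Float16 tile with a Float16 output:
    #   ab_bytes_per_stage = 64*64*2 + 64*64*2       = 16384 bytes
    #   epi_bytes          = 64*32*2 * 8 stages      = 32768 bytes
    #   ab_stage           = (102400 - 1024 - 1024 - 32768) // 16384 = 4
    # so roughly four A/B stages fit alongside the epilogue buffers.
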
    @staticmethod
    def _make_smem_layouts(
        tile_shape_mnk: tuple[int, int, int],
        epi_tile: tuple[int, int],
        a_dtype: type[cutlass.Numeric],
        a_layout: cute.Layout,
        b_dtype: type[cutlass.Numeric],
        b_layout: cute.Layout,
        ab_stage: int,
        c_dtype: type[cutlass.Numeric],
        c_layout: cute.Layout,
        epi_stage: int,
    ) -> tuple[cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout]:
        """Create shared memory layouts for the A, B, and C tensors.

        :param tile_shape_mnk: CTA tile shape (M, N, K)
        :type tile_shape_mnk: tuple[int, int, int]
        :param epi_tile: Epilogue tile shape
        :type epi_tile: tuple[int, int]
        :param a_dtype: Data type for matrix A
        :type a_dtype: type[cutlass.Numeric]
        :param a_layout: Layout for matrix A
        :type a_layout: cute.Layout
        :param b_dtype: Data type for matrix B
        :type b_dtype: type[cutlass.Numeric]
        :param b_layout: Layout for matrix B
        :type b_layout: cute.Layout
        :param ab_stage: Number of stages for the A/B tensors
        :type ab_stage: int
        :param c_dtype: Data type for the output matrix C
        :type c_dtype: type[cutlass.Numeric]
        :param c_layout: Layout for the output matrix C
        :type c_layout: cute.Layout
        :param epi_stage: Number of epilogue stages
        :type epi_stage: int

        :return: Tuple of shared memory layouts for A, B, and C
        :rtype: tuple[cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout]
        """
        a_smem_layout_staged = sm90_utils.make_smem_layout_a(
            a_layout,
            tile_shape_mnk,
            a_dtype,
            ab_stage,
        )

        b_smem_layout_staged = sm90_utils.make_smem_layout_b(
            b_layout,
            tile_shape_mnk,
            b_dtype,
            ab_stage,
        )

        epi_smem_layout_staged = sm90_utils.make_smem_layout_epi(
            c_dtype,
            c_layout,
            epi_tile,
            epi_stage,
        )

        return a_smem_layout_staged, b_smem_layout_staged, epi_smem_layout_staged

    @staticmethod
    def _compute_grid(
        c: cute.Tensor,
        tile_shape_mnk: tuple[int, int, int],
        max_active_clusters: cutlass.Constexpr,
    ) -> tuple[utils.PersistentTileSchedulerParams, tuple[int, int, int]]:
        """Compute the tile scheduler parameters and grid shape for the output tensor C.

        :param c: The output tensor C
        :type c: cute.Tensor
        :param tile_shape_mnk: The shape (M, N, K) of the CTA tile.
        :type tile_shape_mnk: tuple[int, int, int]
        :param max_active_clusters: Maximum number of active clusters
        :type max_active_clusters: cutlass.Constexpr

        :return: Tile scheduler parameters and grid shape for the kernel launch.
        :rtype: tuple[utils.PersistentTileSchedulerParams, tuple[int, int, int]]
        """
        c_shape = cute.slice_(tile_shape_mnk, (None, None, 0))
        gc = cute.zipped_divide(c, tiler=c_shape)
        num_ctas_mnl = gc[(0, (None, None, None))].shape
        cluster_shape_mnl = (1, 1, 1)
        tile_sched_params = utils.PersistentTileSchedulerParams(
            num_ctas_mnl, cluster_shape_mnl
        )
        grid = utils.StaticPersistentTileScheduler.get_grid_shape(
            tile_sched_params, max_active_clusters
        )
        return tile_sched_params, grid
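
    # For example, with mnkl = (8192, 8192, 8192, 1) and a 128x256x64 CTA tile,
    # C splits into 64 x 32 x 1 = 2048 output tiles; the persistent scheduler
    # then sizes the grid by `max_active_clusters` and loops CTAs over tiles.
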
    @staticmethod
    def _make_tma_store_atoms_and_tensors(
        tensor_c: cute.Tensor,
        epi_smem_layout_staged: cute.ComposedLayout,
        epi_tile: tuple[int, int],
    ) -> tuple[cute.CopyAtom, cute.Tensor]:
        """Create TMA atoms and tensors for C tensor storage.

        :param tensor_c: Output tensor C
        :type tensor_c: cute.Tensor
        :param epi_smem_layout_staged: Shared memory layout for epilogue
        :type epi_smem_layout_staged: cute.ComposedLayout
        :param epi_tile: Epilogue tile shape
        :type epi_tile: tuple[int, int]

        :return: TMA atom and tensor for C
        :rtype: tuple[cute.CopyAtom, cute.Tensor]
        """
        epi_smem_layout = cute.slice_(epi_smem_layout_staged, (None, None, 0))
        tma_atom_c, tma_tensor_c = cute.nvgpu.cpasync.make_tiled_tma_atom(
            cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp(),
            tensor_c,
            epi_smem_layout,
            epi_tile,
        )

        return tma_atom_c, tma_tensor_c

    @staticmethod
    def _make_tma_atoms_and_tensors(
        tensor: cute.Tensor,
        smem_layout_staged: cute.ComposedLayout,
        smem_tile: tuple[int, int],
        mcast_dim: int,
    ) -> tuple[cute.CopyAtom, cute.Tensor]:
        """Create TMA atoms and tensors for input tensors.

        :param tensor: Input tensor (A or B)
        :type tensor: cute.Tensor
        :param smem_layout_staged: Shared memory layout for the tensor
        :type smem_layout_staged: cute.ComposedLayout
        :param smem_tile: Shared memory tile shape
        :type smem_tile: tuple[int, int]
        :param mcast_dim: Multicast dimension
        :type mcast_dim: int

        :return: TMA atom and tensor
        :rtype: tuple[cute.CopyAtom, cute.Tensor]
        """
        op = (
            cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp()
            if mcast_dim == 1
            else cute.nvgpu.cpasync.CopyBulkTensorTileG2SMulticastOp()
        )

        smem_layout = cute.slice_(smem_layout_staged, (None, None, 0))
        tma_atom, tma_tensor = cute.nvgpu.cpasync.make_tiled_tma_atom(
            op,
            tensor,
            smem_layout,
            smem_tile,
            num_multicast=mcast_dim,
        )
        return tma_atom, tma_tensor


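# A minimal sketch of driving `run` directly from Python instead of the CLI
# (values mirror the argparse defaults; all of them are illustrative):
#
#   run(
#       (4096, 4096, 4096, 1),
#       cutlass.Float16, cutlass.Float16, cutlass.Float16, cutlass.Float32,
#       "k", "k", "n",
#       (64, 64, 64),
#       tolerance=1e-01, warmup_iterations=0, iterations=1,
#       skip_ref_check=False,
#   )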
def run(
    mnkl: Tuple[int, int, int, int],
    a_dtype: Type[cutlass.Numeric],
    b_dtype: Type[cutlass.Numeric],
    c_dtype: Type[cutlass.Numeric],
    acc_dtype: Type[cutlass.Numeric],
    a_major: str,
    b_major: str,
    c_major: str,
    tile_shape_mnk: Tuple[int, int, int],
    tolerance: float,
    warmup_iterations: int,
    iterations: int,
    skip_ref_check: bool,
    use_cold_l2: bool = False,
    **kwargs,
):
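    """Host-side driver: prepare tensors, compile, validate, and benchmark.

    Builds reference tensors on the CPU, wraps GPU copies as CuTe tensors,
    compiles the kernel once with `cute.compile`, optionally checks the result
    against a `torch.einsum` reference, and then times the compiled kernel.

    :return: Execution time in microseconds per iteration.
    """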
    import torch
    import cutlass.torch as cutlass_torch

    print("Running Blackwell GeForce Dense GEMM with:")
    print(f"mnkl: {mnkl}")
    print(
        f"A dtype: {a_dtype}, B dtype: {b_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}"
    )
    print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}")
    print(f"Tile Shape: {tile_shape_mnk}")
    print(f"Tolerance: {tolerance}")
    print(f"Warmup iterations: {warmup_iterations}")
    print(f"Iterations: {iterations}")
    print(f"Skip reference checking: {skip_ref_check}")
    print(f"Use cold L2: {use_cold_l2}")

    a_dtype = getattr(cutlass, a_dtype) if isinstance(a_dtype, str) else a_dtype
    b_dtype = getattr(cutlass, b_dtype) if isinstance(b_dtype, str) else b_dtype
    c_dtype = getattr(cutlass, c_dtype) if isinstance(c_dtype, str) else c_dtype
    acc_dtype = getattr(cutlass, acc_dtype) if isinstance(acc_dtype, str) else acc_dtype

    # Unpack parameters
    m, n, k, l = mnkl
    cluster_shape_mnk = (1, 1, 1)

    # Prepare pytorch tensors: A, B (random from 0 to 2) and C (all zero)
    if not torch.cuda.is_available():
        raise RuntimeError("GPU is required to run this example!")

    a_torch_cpu = cutlass_torch.matrix(l, m, k, a_major, a_dtype)
    b_torch_cpu = cutlass_torch.matrix(l, n, k, b_major, b_dtype)
    c_torch_cpu = cutlass_torch.matrix(l, m, n, c_major, c_dtype)

    def create_cute_tensor(data_ref, cutlass_dtype):
        cute_tensor, torch_tensor = cutlass_torch.cute_tensor_like(
            data_ref, cutlass_dtype, True, 16
        )

        if cutlass_dtype.is_float and cutlass_dtype.width == 8:
            f32_torch_tensor = data_ref.to(dtype=torch.float32)
            cute_tensor = cutlass_torch.convert_cute_tensor(
                f32_torch_tensor,
                cute_tensor,
                cutlass_dtype,
                is_dynamic_layout=True,
            )
        return cute_tensor, torch_tensor

    a_tensor, a_torch_gpu = create_cute_tensor(a_torch_cpu, a_dtype)
    b_tensor, b_torch_gpu = create_cute_tensor(b_torch_cpu, b_dtype)
    c_tensor, c_torch_gpu = create_cute_tensor(c_torch_cpu, c_dtype)

    gemm = Sm120GemmKernel(
        acc_dtype,
        tile_shape_mnk,
    )

    # Compute max active clusters on the current device
    hardware_info = cutlass.utils.HardwareInfo()
    max_active_clusters = hardware_info.get_max_active_clusters(
        cluster_shape_mnk[0] * cluster_shape_mnk[1]
    )

    # Initialize the stream
    stream = cutlass_torch.default_stream()
    # Compile the GEMM kernel
    compiled_gemm = cute.compile(
        gemm, a_tensor, b_tensor, c_tensor, max_active_clusters, stream
    )

    if not skip_ref_check:
        print("Reference checking ...")
        # Execute the compiled kernel
        compiled_gemm(a_tensor, b_tensor, c_tensor, stream)
        torch.cuda.synchronize()

        # Reference check
        ref = torch.einsum(
            "mkl,nkl->mnl",
            a_torch_cpu.to(dtype=torch.float32),
            b_torch_cpu.to(dtype=torch.float32),
        )

        # Copy the GPU tensor back to the CPU
        kernel_result = c_torch_gpu.cpu()

        # Convert the reference to c_dtype
        _, ref_torch_gpu = create_cute_tensor(ref, c_dtype)
        ref_result = ref_torch_gpu.cpu()

        # Assert that the results are close
        torch.testing.assert_close(
            kernel_result, ref_result, atol=tolerance, rtol=1e-03
        )

    def generate_tensors():
        a_torch_cpu = cutlass_torch.matrix(l, m, k, a_major, a_dtype)
        b_torch_cpu = cutlass_torch.matrix(l, n, k, b_major, b_dtype)
        c_torch_cpu = cutlass_torch.matrix(l, m, n, c_major, c_dtype)
        mA_workspace, _ = create_cute_tensor(a_torch_cpu, a_dtype)
        mB_workspace, _ = create_cute_tensor(b_torch_cpu, b_dtype)
        mC_workspace, _ = create_cute_tensor(c_torch_cpu, c_dtype)
        return testing.JitArguments(mA_workspace, mB_workspace, mC_workspace, stream)

    workspace_count = 1
    if use_cold_l2:
        one_workspace_bytes = (
            a_torch_gpu.numel() * a_torch_gpu.element_size()
            + b_torch_gpu.numel() * b_torch_gpu.element_size()
            + c_torch_gpu.numel() * c_torch_gpu.element_size()
        )
        workspace_count = testing.get_workspace_count(
            one_workspace_bytes, warmup_iterations, iterations
        )

    exec_time = testing.benchmark(
        compiled_gemm,
        workspace_generator=generate_tensors,
        workspace_count=workspace_count,
        stream=stream,
        warmup_iterations=warmup_iterations,
        iterations=iterations,
    )

    print(f"Execution time: {exec_time} microseconds per iteration")

    return exec_time  # Return execution time in microseconds


if __name__ == "__main__":
    args = parse_arguments()
    run(
        args.mnkl,
        args.a_dtype,
        args.b_dtype,
        args.c_dtype,
        args.acc_dtype,
        args.a_major,
        args.b_major,
        args.c_major,
        args.tile_shape_mnk,
        args.tolerance,
        args.warmup_iterations,
        args.iterations,
        args.skip_ref_check,
        args.use_cold_l2,
    )
    print("PASS")