mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-12 17:25:45 +00:00
3153 lines
117 KiB
Python
3153 lines
117 KiB
Python
# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||
# SPDX-License-Identifier: BSD-3-Clause
|
||
|
||
# Redistribution and use in source and binary forms, with or without
|
||
# modification, are permitted provided that the following conditions are met:
|
||
|
||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||
# list of conditions and the following disclaimer.
|
||
|
||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||
# this list of conditions and the following disclaimer in the documentation
|
||
# and/or other materials provided with the distribution.
|
||
|
||
# 3. Neither the name of the copyright holder nor the names of its
|
||
# contributors may be used to endorse or promote products derived from
|
||
# this software without specific prior written permission.
|
||
|
||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||
|
||
import argparse
|
||
from typing import Type, Tuple, Union, Literal
|
||
|
||
import cuda.bindings.driver as cuda
|
||
import torch
|
||
|
||
import cutlass
|
||
import cutlass.cute as cute
|
||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||
import cutlass.torch as cutlass_torch
|
||
import cutlass.utils as utils
|
||
import cutlass.pipeline as pipeline
|
||
from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait
|
||
import cutlass.utils.blackwell_helpers as sm100_utils
|
||
import cutlass.utils.blockscaled_layout as blockscaled_utils
|
||
from cutlass.cute.runtime import make_ptr
|
||
|
||
"""
|
||
This example provides an experimental implementation of the SM100 batched dense blockscaled GEMM kernel, please note that the APIs and implementation details related to this kernel may change in future releases.
|
||
|
||
A high-performance persistent batched dense blockscaled GEMM example for the NVIDIA Blackwell SM100 architecture
|
||
using CUTE DSL.
|
||
- Matrix A is MxKxL, L is batch dimension, A can be row-major("K") or column-major("M") for MXF8 input type and can only be row-major("K") for MXF4/NVF4 input type
|
||
- Matrix B is NxKxL, L is batch dimension, B can be row-major("N") or column-major("K") for MXF8 input type and can only be row-major("K") for MXF4/NVF4 input type
|
||
- Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M")
|
||
- Matrix SFA layout is filled internally according to A shape and BlockScaledBasicChunk, which has M×ceil_div(K, sf_vec_size)×L elements respectively
|
||
- Matrix SFB layout is filled internally according to B shape and BlockScaledBasicChunk, which has N×ceil_div(K, sf_vec_size)×L elements respectively
|
||
|
||
This GEMM kernel supports the following features:
|
||
- Utilizes Tensor Memory Access (TMA) for efficient memory operations
|
||
- Utilizes Blackwell's tcgen05.mma for matrix multiply-accumulate (MMA) operations (including 2cta mma instructions)
|
||
- Implements TMA multicast with cluster to reduce L2 memory traffic
|
||
- Support persistent tile scheduling to better overlap memory load/store with mma between tiles
|
||
- Support warp specialization to avoid explicit pipelining between mainloop load and mma
|
||
|
||
This GEMM works as follows:
|
||
1. DMA warp: Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations.
|
||
2. MMA warp:
|
||
- Load scale factor A/B from shared memory (SMEM) to tensor memory (TMEM) using tcgen05.cp instruction.
|
||
- Perform matrix multiply-accumulate (MMA) operations using tcgen05.mma instruction.
|
||
3. EPILOGUE warp:
|
||
- Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld.
|
||
- Type convert C matrix to output type.
|
||
- Optionally store C matrix from registers (RMEM) to shared memory (SMEM) to global memory (GMEM) with TMA operations,
|
||
or directly store C matrix from registers (RMEM) to global memory (GMEM) without TMA operations.
|
||
- Optionally accept an elementwise lambda function epilogue_op to apply to the output tensor:
|
||
e.g., relu can set epilogue_op = lambda x: cute.where(x > 0, x, cute.full_like(x, 0))
|
||
|
||
SM100 tcgen05.mma.kind.block_scale instructions operate as follows:
|
||
- Read matrix A from SMEM
|
||
- Read matrix B from SMEM
|
||
- Read scalefactor A from TMEM
|
||
- Read scalefactor B from TMEM
|
||
- Write accumulator to TMEM
|
||
The accumulator in TMEM must then be loaded to registers before writing back to GMEM.
|
||
|
||
Input arguments to this example is shown below:
|
||
|
||
.. code-block:: bash
|
||
|
||
python examples/blackwell/dense_blockscaled_gemm_persistent.py \
|
||
--ab_dtype Float4E2M1FN --sf_dtype Float8E8M0FNU --sf_vec_size 16 \
|
||
--c_dtype Float16 \
|
||
--mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \
|
||
--mnkl 8192,8192,1024,1
|
||
|
||
To collect performance with NCU profiler:
|
||
|
||
.. code-block:: bash
|
||
|
||
ncu python examples/blackwell/dense_blockscaled_gemm_persistent.py \
|
||
--ab_dtype Float4E2M1FN --sf_dtype Float8E8M0FNU --sf_vec_size 16 \
|
||
--c_dtype Float16 \
|
||
--mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \
|
||
--mnkl 8192,8192,1024,1 \
|
||
--warmup_iterations 1 --iterations 10 --skip_ref_check
|
||
|
||
|
||
Constraints:
|
||
* Supported input data types: mxf8, mxf4, nvf4
|
||
see detailed valid dtype combinations in below Sm100BlockScaledPersistentDenseGemmKernel class documentation
|
||
* A/B tensor must have the same data type, mixed data type is not supported (e.g., mxf8 x mxf4)
|
||
* Mma tiler M must be 128 or 256(use_2cta_instrs)
|
||
* Mma tiler N must be 64/128/192/256
|
||
* Cluster shape M/N must be positive and power of 2, total cluster size <= 16
|
||
* Cluster shape M must be multiple of 2 if Mma tiler M is 256(use_2cta_instrs)
|
||
* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned,
|
||
i.e, number of elements is a multiple of 16 and 32 for Float8 and Float4, respectively.
|
||
"""
|
||
|
||
|
||
class Sm100BlockScaledPersistentDenseGemmKernel:
|
||
"""This class implements batched matrix multiplication (C = A x SFA x B x SFB) with support for various data types
|
||
and architectural features specific to Blackwell GPUs with persistent tile scheduling and warp specialization.
|
||
|
||
:param sf_vec_size: Scalefactor vector size.
|
||
:type sf_vec_size: int
|
||
:param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tile (M,N)
|
||
:type mma_tiler_mn: Tuple[int, int]
|
||
:param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing
|
||
:type cluster_shape_mn: Tuple[int, int]
|
||
|
||
:note: In current version, A and B tensor must have the same data type
|
||
- i.e., Float8E4M3FN for A and Float8E5M2 for B is not supported
|
||
|
||
:note: Supported combinations of A/B data types, SF data typs and SF vector size:
|
||
- MXF8: A/B: Float8E5M2/Float8E4M3FN + SF: Float8E8M0FNU + sf_vec_size: 32
|
||
- MXF4: A/B: Float4E2M1FN + SF: Float8E8M0FNU + sf_vec_size: 32
|
||
- NVF4: A/B: Float4E2M1FN + SF: Float8E8M0FNU/Float8E4M3FN + sf_vec_size: 16
|
||
|
||
:note: Supported accumulator data types:
|
||
- Float32
|
||
|
||
:note: Supported C data types:
|
||
- Float32
|
||
- Float16/BFloat16
|
||
- Float8E4M3FN/Float8E5M2
|
||
:note: Constraints:
|
||
- MMA tiler M must be 128 or 256 (use_2cta_instrs)
|
||
- MMA tiler N must be 64/128/192/256
|
||
- Cluster shape M must be multiple of 2 if Mma tiler M is 256
|
||
- Cluster shape M/N must be positive and power of 2, total cluster size <= 16
|
||
- Also, Cluster shape M/N must be <= 4 for scale factor multicasts due to limited size of scale factors
|
||
|
||
Example:
|
||
>>> gemm = Sm100BlockScaledPersistentDenseGemmKernel(
|
||
... sf_vec_size=16,
|
||
... mma_tiler_mn=(256, 128),
|
||
... cluster_shape_mn=(2, 1)
|
||
... )
|
||
>>> gemm(a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor, max_active_clusters, stream)
|
||
"""
|
||
|
||
def __init__(
|
||
self,
|
||
sf_vec_size: int,
|
||
mma_tiler_mn: Tuple[int, int],
|
||
cluster_shape_mn: Tuple[int, int],
|
||
):
|
||
"""Initializes the configuration for a Blackwell dense GEMM kernel.
|
||
|
||
This configuration includes several key aspects:
|
||
|
||
1. MMA Instruction Settings (tcgen05):
|
||
- acc_dtype: Data types for MMA accumulator, always set to Float32
|
||
- sf_vec_size: Scalefactor A/B vector size.
|
||
- mma_tiler_mn: The (M, N) shape of the MMA instruction tiler.
|
||
|
||
2. Cluster Shape:
|
||
- cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster.
|
||
|
||
:param sf_vec_size: Scalefactor vector size.
|
||
:type sf_vec_size: int
|
||
:param mma_tiler_mn: Tuple (M, N) shape of the MMA instruction.
|
||
:type mma_tiler_mn: Tuple[int, int]
|
||
:param cluster_shape_mn: Tuple (ClusterM, ClusterN) shape of the cluster.
|
||
:type cluster_shape_mn: Tuple[int, int]
|
||
"""
|
||
|
||
self.acc_dtype = cutlass.Float32
|
||
self.sf_vec_size = sf_vec_size
|
||
self.use_2cta_instrs = mma_tiler_mn[0] == 256
|
||
self.cluster_shape_mn = cluster_shape_mn
|
||
# K dimension is deferred in _setup_attributes
|
||
self.mma_tiler = (*mma_tiler_mn, 1)
|
||
|
||
self.cta_group = (
|
||
tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE
|
||
)
|
||
|
||
self.occupancy = 1
|
||
# Set specialized warp ids
|
||
self.epilog_warp_id = (
|
||
0,
|
||
1,
|
||
2,
|
||
3,
|
||
)
|
||
self.mma_warp_id = 4
|
||
self.tma_warp_id = 5
|
||
self.threads_per_warp = 32
|
||
self.threads_per_cta = self.threads_per_warp * len(
|
||
(self.mma_warp_id, self.tma_warp_id, *self.epilog_warp_id)
|
||
)
|
||
# Set barrier id for epilogue sync and tmem ptr sync
|
||
self.epilog_sync_barrier = pipeline.NamedBarrier(
|
||
barrier_id=1,
|
||
num_threads=self.threads_per_warp * len(self.epilog_warp_id),
|
||
)
|
||
self.tmem_alloc_barrier = pipeline.NamedBarrier(
|
||
barrier_id=2,
|
||
num_threads=self.threads_per_warp * len((self.mma_warp_id, *self.epilog_warp_id)),
|
||
)
|
||
self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
|
||
self.num_tmem_alloc_cols = cute.arch.get_max_tmem_alloc_cols("sm_100")
|
||
|
||
def _setup_attributes(self):
|
||
"""Set up configurations that are dependent on GEMM inputs
|
||
|
||
This method configures various attributes based on the input tensor properties
|
||
(data types, leading dimensions) and kernel settings:
|
||
- Configuring tiled MMA
|
||
- Computing MMA/cluster/tile shapes
|
||
- Computing cluster layout
|
||
- Computing multicast CTAs for A/B/SFA/SFB
|
||
- Computing epilogue subtile
|
||
- Setting up A/B/SFA/SFB/C stage counts in shared memory
|
||
- Computing A/B/SFA/SFB/C shared memory layout
|
||
"""
|
||
# Compute mma instruction shapes
|
||
# (MMA_Tile_Shape_M, MMA_Tile_Shape_N, MMA_Inst_Shape_K)
|
||
self.mma_inst_shape_mn = (
|
||
self.mma_tiler[0],
|
||
self.mma_tiler[1],
|
||
)
|
||
# (CTA_Tile_Shape_M, Round_Up(MMA_Tile_Shape_N, 128), MMA_Inst_Shape_K)
|
||
self.mma_inst_shape_mn_sfb = (
|
||
self.mma_inst_shape_mn[0] // (2 if self.use_2cta_instrs else 1),
|
||
cute.round_up(self.mma_inst_shape_mn[1], 128),
|
||
)
|
||
|
||
tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma(
|
||
self.a_dtype,
|
||
self.a_major_mode,
|
||
self.b_major_mode,
|
||
self.sf_dtype,
|
||
self.sf_vec_size,
|
||
self.cta_group,
|
||
self.mma_inst_shape_mn,
|
||
)
|
||
|
||
tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma(
|
||
self.a_dtype,
|
||
self.a_major_mode,
|
||
self.b_major_mode,
|
||
self.sf_dtype,
|
||
self.sf_vec_size,
|
||
cute.nvgpu.tcgen05.CtaGroup.ONE,
|
||
self.mma_inst_shape_mn_sfb,
|
||
)
|
||
|
||
# Compute mma/cluster/tile shapes
|
||
mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2])
|
||
mma_inst_tile_k = 4
|
||
self.mma_tiler = (
|
||
self.mma_inst_shape_mn[0],
|
||
self.mma_inst_shape_mn[1],
|
||
mma_inst_shape_k * mma_inst_tile_k,
|
||
)
|
||
self.mma_tiler_sfb = (
|
||
self.mma_inst_shape_mn_sfb[0],
|
||
self.mma_inst_shape_mn_sfb[1],
|
||
mma_inst_shape_k * mma_inst_tile_k,
|
||
)
|
||
self.cta_tile_shape_mnk = (
|
||
self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
|
||
self.mma_tiler[1],
|
||
self.mma_tiler[2],
|
||
)
|
||
self.cta_tile_shape_mnk_sfb = (
|
||
self.mma_tiler_sfb[0] // cute.size(tiled_mma.thr_id.shape),
|
||
self.mma_tiler_sfb[1],
|
||
self.mma_tiler_sfb[2],
|
||
)
|
||
|
||
# Compute cluster layout
|
||
self.cluster_layout_vmnk = cute.tiled_divide(
|
||
cute.make_layout((*self.cluster_shape_mn, 1)),
|
||
(tiled_mma.thr_id.shape,),
|
||
)
|
||
self.cluster_layout_sfb_vmnk = cute.tiled_divide(
|
||
cute.make_layout((*self.cluster_shape_mn, 1)),
|
||
(tiled_mma_sfb.thr_id.shape,),
|
||
)
|
||
|
||
# Compute number of multicast CTAs for A/B
|
||
self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2])
|
||
self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1])
|
||
self.num_mcast_ctas_sfb = cute.size(self.cluster_layout_sfb_vmnk.shape[1])
|
||
self.is_a_mcast = self.num_mcast_ctas_a > 1
|
||
self.is_b_mcast = self.num_mcast_ctas_b > 1
|
||
self.is_sfb_mcast = self.num_mcast_ctas_sfb > 1
|
||
|
||
# Compute epilogue subtile
|
||
self.epi_tile = sm100_utils.compute_epilogue_tile_shape(
|
||
self.cta_tile_shape_mnk,
|
||
self.use_2cta_instrs,
|
||
self.c_layout,
|
||
self.c_dtype,
|
||
)
|
||
self.epi_tile_n = cute.size(self.epi_tile[1])
|
||
|
||
# Setup A/B/C stage count in shared memory and ACC stage count in tensor memory
|
||
self.num_acc_stage, self.num_ab_stage, self.num_c_stage = self._compute_stages(
|
||
tiled_mma,
|
||
self.mma_tiler,
|
||
self.a_dtype,
|
||
self.b_dtype,
|
||
self.epi_tile,
|
||
self.c_dtype,
|
||
self.c_layout,
|
||
self.sf_dtype,
|
||
self.sf_vec_size,
|
||
self.smem_capacity,
|
||
self.occupancy,
|
||
)
|
||
|
||
# Compute A/B/SFA/SFB/C shared memory layout
|
||
self.a_smem_layout_staged = sm100_utils.make_smem_layout_a(
|
||
tiled_mma,
|
||
self.mma_tiler,
|
||
self.a_dtype,
|
||
self.num_ab_stage,
|
||
)
|
||
self.b_smem_layout_staged = sm100_utils.make_smem_layout_b(
|
||
tiled_mma,
|
||
self.mma_tiler,
|
||
self.b_dtype,
|
||
self.num_ab_stage,
|
||
)
|
||
self.sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa(
|
||
tiled_mma,
|
||
self.mma_tiler,
|
||
self.sf_vec_size,
|
||
self.num_ab_stage,
|
||
)
|
||
self.sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb(
|
||
tiled_mma,
|
||
self.mma_tiler,
|
||
self.sf_vec_size,
|
||
self.num_ab_stage,
|
||
)
|
||
self.c_smem_layout_staged = sm100_utils.make_smem_layout_epi(
|
||
self.c_dtype,
|
||
self.c_layout,
|
||
self.epi_tile,
|
||
self.num_c_stage,
|
||
)
|
||
|
||
# Overlap and double buffer accumulator when num_acc_stage == 1 for cta_tile_n = 256 case
|
||
self.overlapping_accum = self.num_acc_stage == 1
|
||
|
||
# Compute number of TMEM columns for SFA/SFB/Accumulator
|
||
sf_atom_mn = 32
|
||
self.num_sfa_tmem_cols = (
|
||
self.cta_tile_shape_mnk[0] // sf_atom_mn
|
||
) * mma_inst_tile_k
|
||
self.num_sfb_tmem_cols = (
|
||
self.cta_tile_shape_mnk_sfb[1] // sf_atom_mn
|
||
) * mma_inst_tile_k
|
||
self.num_sf_tmem_cols = self.num_sfa_tmem_cols + self.num_sfb_tmem_cols
|
||
self.num_accumulator_tmem_cols = (
|
||
self.cta_tile_shape_mnk[1] * self.num_acc_stage
|
||
if not self.overlapping_accum
|
||
else self.cta_tile_shape_mnk[1] * 2 - self.num_sf_tmem_cols
|
||
)
|
||
|
||
# Only when overlapping_accum is enabled, we need to release accumulator buffer early in epilogue
|
||
self.iter_acc_early_release_in_epilogue = (
|
||
self.num_sf_tmem_cols // self.epi_tile_n
|
||
)
|
||
|
||
@cute.jit
|
||
def __call__(
|
||
self,
|
||
a_ptr: cute.Pointer,
|
||
b_ptr: cute.Pointer,
|
||
sfa_ptr: cute.Pointer,
|
||
sfb_ptr: cute.Pointer,
|
||
c_ptr: cute.Pointer,
|
||
layouts: cutlass.Constexpr[
|
||
Tuple[tcgen05.OperandMajorMode, tcgen05.OperandMajorMode, utils.LayoutEnum]
|
||
],
|
||
problem_mnkl: Tuple[int, int, int, int],
|
||
max_active_clusters: cutlass.Constexpr,
|
||
stream: cuda.CUstream,
|
||
epilogue_op: cutlass.Constexpr = lambda x: x,
|
||
):
|
||
"""Execute the GEMM operation in steps:
|
||
- Setup static attributes before smem/grid/tma computation
|
||
- Setup TMA load/store atoms and tensors
|
||
- Compute grid size with regard to hardware constraints
|
||
- Define shared storage for kernel
|
||
- Launch the kernel synchronously
|
||
|
||
:param a_tensor: Input tensor A
|
||
:type a_tensor: cute.Tensor
|
||
:param b_tensor: Input tensor B
|
||
:type b_tensor: cute.Tensor
|
||
:param sfa_tensor: Scale factor tensor A
|
||
:type sfa_tensor: cute.Tensor
|
||
:param sfb_tensor: Scale factor tensor B
|
||
:type sfb_tensor: cute.Tensor
|
||
:param c_tensor: Output tensor C
|
||
:type c_tensor: cute.Tensor
|
||
:param max_active_clusters: Maximum number of active clusters
|
||
:type max_active_clusters: cutlass.Constexpr
|
||
:param stream: CUDA stream for asynchronous execution
|
||
:type stream: cuda.CUstream
|
||
:param epilogue_op: Optional elementwise lambda function to apply to the output tensor
|
||
:type epilogue_op: cutlass.Constexpr
|
||
:raises TypeError: If input data types are incompatible with the MMA instruction.
|
||
"""
|
||
# Setup static attributes before smem/grid/tma computation
|
||
self.a_dtype: Type[cutlass.Numeric] = a_ptr.value_type
|
||
self.b_dtype: Type[cutlass.Numeric] = b_ptr.value_type
|
||
self.sf_dtype: Type[cutlass.Numeric] = sfa_ptr.value_type
|
||
self.c_dtype: Type[cutlass.Numeric] = c_ptr.value_type
|
||
|
||
m, n, k, l = problem_mnkl
|
||
self.a_major_mode, self.b_major_mode, self.c_layout = layouts
|
||
|
||
# Check if input data types are compatible with MMA instruction
|
||
if cutlass.const_expr(self.a_dtype != self.b_dtype):
|
||
raise TypeError(f"Type must match: {self.a_dtype} != {self.b_dtype}")
|
||
|
||
# Setup attributes that dependent on gemm inputs
|
||
self._setup_attributes()
|
||
|
||
a_layout = cute.make_ordered_layout((m, cute.assume(k, 32), l), order=(0, 1, 2))
|
||
if cutlass.const_expr(self.a_major_mode == tcgen05.OperandMajorMode.K):
|
||
a_layout = cute.make_ordered_layout(
|
||
(cute.assume(m, 32), k, l), order=(1, 0, 2)
|
||
)
|
||
b_layout = cute.make_ordered_layout((n, cute.assume(k, 32), l), order=(0, 1, 2))
|
||
if cutlass.const_expr(self.b_major_mode == tcgen05.OperandMajorMode.K):
|
||
b_layout = cute.make_ordered_layout(
|
||
(cute.assume(n, 32), k, l), order=(1, 0, 2)
|
||
)
|
||
c_layout = cute.make_ordered_layout((cute.assume(m, 32), n, l), order=(0, 1, 2))
|
||
if cutlass.const_expr(self.c_layout == utils.LayoutEnum.ROW_MAJOR):
|
||
c_layout = cute.make_ordered_layout(
|
||
(m, cute.assume(n, 32), l), order=(1, 0, 2)
|
||
)
|
||
a_tensor = cute.make_tensor(a_ptr, a_layout)
|
||
b_tensor = cute.make_tensor(b_ptr, b_layout)
|
||
c_tensor = cute.make_tensor(c_ptr, c_layout)
|
||
|
||
# Setup sfa/sfb tensor by filling A/B tensor to scale factor atom layout
|
||
# ((Atom_M, Rest_M),(Atom_K, Rest_K),RestL)
|
||
sfa_layout = blockscaled_utils.tile_atom_to_shape_SF(
|
||
a_tensor.shape, self.sf_vec_size
|
||
)
|
||
sfa_tensor = cute.make_tensor(sfa_ptr, sfa_layout)
|
||
|
||
# ((Atom_N, Rest_N),(Atom_K, Rest_K),RestL)
|
||
sfb_layout = blockscaled_utils.tile_atom_to_shape_SF(
|
||
b_tensor.shape, self.sf_vec_size
|
||
)
|
||
sfb_tensor = cute.make_tensor(sfb_ptr, sfb_layout)
|
||
|
||
tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma(
|
||
self.a_dtype,
|
||
self.a_major_mode,
|
||
self.b_major_mode,
|
||
self.sf_dtype,
|
||
self.sf_vec_size,
|
||
self.cta_group,
|
||
self.mma_inst_shape_mn,
|
||
)
|
||
|
||
tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma(
|
||
self.a_dtype,
|
||
self.a_major_mode,
|
||
self.b_major_mode,
|
||
self.sf_dtype,
|
||
self.sf_vec_size,
|
||
cute.nvgpu.tcgen05.CtaGroup.ONE,
|
||
self.mma_inst_shape_mn_sfb,
|
||
)
|
||
atom_thr_size = cute.size(tiled_mma.thr_id.shape)
|
||
|
||
# Setup TMA load for A
|
||
a_op = sm100_utils.cluster_shape_to_tma_atom_A(
|
||
self.cluster_shape_mn, tiled_mma.thr_id
|
||
)
|
||
a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
|
||
tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
|
||
a_op,
|
||
a_tensor,
|
||
a_smem_layout,
|
||
self.mma_tiler,
|
||
tiled_mma,
|
||
self.cluster_layout_vmnk.shape,
|
||
)
|
||
|
||
# Setup TMA load for B
|
||
b_op = sm100_utils.cluster_shape_to_tma_atom_B(
|
||
self.cluster_shape_mn, tiled_mma.thr_id
|
||
)
|
||
b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
|
||
tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
|
||
b_op,
|
||
b_tensor,
|
||
b_smem_layout,
|
||
self.mma_tiler,
|
||
tiled_mma,
|
||
self.cluster_layout_vmnk.shape,
|
||
)
|
||
|
||
# Setup TMA load for SFA
|
||
sfa_op = sm100_utils.cluster_shape_to_tma_atom_A(
|
||
self.cluster_shape_mn, tiled_mma.thr_id
|
||
)
|
||
sfa_smem_layout = cute.slice_(
|
||
self.sfa_smem_layout_staged, (None, None, None, 0)
|
||
)
|
||
tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A(
|
||
sfa_op,
|
||
sfa_tensor,
|
||
sfa_smem_layout,
|
||
self.mma_tiler,
|
||
tiled_mma,
|
||
self.cluster_layout_vmnk.shape,
|
||
internal_type=cutlass.Int16,
|
||
)
|
||
|
||
# Setup TMA load for SFB
|
||
sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(
|
||
self.cluster_shape_mn, tiled_mma.thr_id
|
||
)
|
||
sfb_smem_layout = cute.slice_(
|
||
self.sfb_smem_layout_staged, (None, None, None, 0)
|
||
)
|
||
tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B(
|
||
sfb_op,
|
||
sfb_tensor,
|
||
sfb_smem_layout,
|
||
self.mma_tiler_sfb,
|
||
tiled_mma_sfb,
|
||
self.cluster_layout_sfb_vmnk.shape,
|
||
internal_type=cutlass.Int16,
|
||
)
|
||
|
||
if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192):
|
||
x = tma_tensor_sfb.stride[0][1]
|
||
y = cute.ceil_div(tma_tensor_sfb.shape[0][1], 4)
|
||
|
||
new_shape = (
|
||
(tma_tensor_sfb.shape[0][0], ((2, 2), y)),
|
||
tma_tensor_sfb.shape[1],
|
||
tma_tensor_sfb.shape[2],
|
||
)
|
||
# Use right multiplication for ScaledBasis (3 * x instead of x * 3)
|
||
x_times_3 = 3 * x
|
||
new_stride = (
|
||
(tma_tensor_sfb.stride[0][0], ((x, x), x_times_3)),
|
||
tma_tensor_sfb.stride[1],
|
||
tma_tensor_sfb.stride[2],
|
||
)
|
||
tma_tensor_sfb_new_layout = cute.make_layout(new_shape, stride=new_stride)
|
||
tma_tensor_sfb = cute.make_tensor(
|
||
tma_tensor_sfb.iterator, tma_tensor_sfb_new_layout
|
||
)
|
||
|
||
a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout)
|
||
b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout)
|
||
sfa_copy_size = cute.size_in_bytes(self.sf_dtype, sfa_smem_layout)
|
||
sfb_copy_size = cute.size_in_bytes(self.sf_dtype, sfb_smem_layout)
|
||
self.num_tma_load_bytes = (
|
||
a_copy_size + b_copy_size + sfa_copy_size + sfb_copy_size
|
||
) * atom_thr_size
|
||
|
||
# Setup TMA store for C
|
||
epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0))
|
||
tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom(
|
||
cpasync.CopyBulkTensorTileS2GOp(),
|
||
c_tensor,
|
||
epi_smem_layout,
|
||
self.epi_tile,
|
||
)
|
||
|
||
# Compute grid size
|
||
self.tile_sched_params, grid = self._compute_grid(
|
||
c_tensor,
|
||
self.cta_tile_shape_mnk,
|
||
self.cluster_shape_mn,
|
||
max_active_clusters,
|
||
)
|
||
|
||
self.buffer_align_bytes = 1024
|
||
|
||
# Define shared storage for kernel
|
||
@cute.struct
|
||
class SharedStorage:
|
||
ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage]
|
||
ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage]
|
||
acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage]
|
||
acc_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage]
|
||
tmem_dealloc_mbar_ptr: cutlass.Int64
|
||
tmem_holding_buf: cutlass.Int32
|
||
# (EPI_TILE_M, EPI_TILE_N, STAGE)
|
||
sC: cute.struct.Align[
|
||
cute.struct.MemRange[
|
||
self.c_dtype,
|
||
cute.cosize(self.c_smem_layout_staged.outer),
|
||
],
|
||
self.buffer_align_bytes,
|
||
]
|
||
# (MMA, MMA_M, MMA_K, STAGE)
|
||
sA: cute.struct.Align[
|
||
cute.struct.MemRange[
|
||
self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer)
|
||
],
|
||
self.buffer_align_bytes,
|
||
]
|
||
# (MMA, MMA_N, MMA_K, STAGE)
|
||
sB: cute.struct.Align[
|
||
cute.struct.MemRange[
|
||
self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer)
|
||
],
|
||
self.buffer_align_bytes,
|
||
]
|
||
# (MMA, MMA_M, MMA_K, STAGE)
|
||
sSFA: cute.struct.Align[
|
||
cute.struct.MemRange[
|
||
self.sf_dtype, cute.cosize(self.sfa_smem_layout_staged)
|
||
],
|
||
self.buffer_align_bytes,
|
||
]
|
||
# (MMA, MMA_N, MMA_K, STAGE)
|
||
sSFB: cute.struct.Align[
|
||
cute.struct.MemRange[
|
||
self.sf_dtype, cute.cosize(self.sfb_smem_layout_staged)
|
||
],
|
||
self.buffer_align_bytes,
|
||
]
|
||
|
||
self.shared_storage = SharedStorage
|
||
|
||
# Launch the kernel synchronously
|
||
self.kernel(
|
||
tiled_mma,
|
||
tiled_mma_sfb,
|
||
tma_atom_a,
|
||
tma_tensor_a,
|
||
tma_atom_b,
|
||
tma_tensor_b,
|
||
tma_atom_sfa,
|
||
tma_tensor_sfa,
|
||
tma_atom_sfb,
|
||
tma_tensor_sfb,
|
||
tma_atom_c,
|
||
tma_tensor_c,
|
||
self.cluster_layout_vmnk,
|
||
self.cluster_layout_sfb_vmnk,
|
||
self.a_smem_layout_staged,
|
||
self.b_smem_layout_staged,
|
||
self.sfa_smem_layout_staged,
|
||
self.sfb_smem_layout_staged,
|
||
self.c_smem_layout_staged,
|
||
self.epi_tile,
|
||
self.tile_sched_params,
|
||
epilogue_op,
|
||
).launch(
|
||
grid=grid,
|
||
block=[self.threads_per_cta, 1, 1],
|
||
cluster=(*self.cluster_shape_mn, 1),
|
||
stream=stream,
|
||
min_blocks_per_mp=1,
|
||
)
|
||
return
|
||
|
||
# GPU device kernel
|
||
@cute.kernel
|
||
def kernel(
|
||
self,
|
||
tiled_mma: cute.TiledMma,
|
||
tiled_mma_sfb: cute.TiledMma,
|
||
tma_atom_a: cute.CopyAtom,
|
||
mA_mkl: cute.Tensor,
|
||
tma_atom_b: cute.CopyAtom,
|
||
mB_nkl: cute.Tensor,
|
||
tma_atom_sfa: cute.CopyAtom,
|
||
mSFA_mkl: cute.Tensor,
|
||
tma_atom_sfb: cute.CopyAtom,
|
||
mSFB_nkl: cute.Tensor,
|
||
tma_atom_c: cute.CopyAtom,
|
||
mC_mnl: cute.Tensor,
|
||
cluster_layout_vmnk: cute.Layout,
|
||
cluster_layout_sfb_vmnk: cute.Layout,
|
||
a_smem_layout_staged: cute.ComposedLayout,
|
||
b_smem_layout_staged: cute.ComposedLayout,
|
||
sfa_smem_layout_staged: cute.Layout,
|
||
sfb_smem_layout_staged: cute.Layout,
|
||
c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout],
|
||
epi_tile: cute.Tile,
|
||
tile_sched_params: utils.PersistentTileSchedulerParams,
|
||
epilogue_op: cutlass.Constexpr,
|
||
):
|
||
"""
|
||
GPU device kernel performing the Persistent batched GEMM computation.
|
||
"""
|
||
warp_idx = cute.arch.warp_idx()
|
||
warp_idx = cute.arch.make_warp_uniform(warp_idx)
|
||
|
||
#
|
||
# Prefetch tma desc
|
||
#
|
||
if warp_idx == self.tma_warp_id:
|
||
cpasync.prefetch_descriptor(tma_atom_a)
|
||
cpasync.prefetch_descriptor(tma_atom_b)
|
||
cpasync.prefetch_descriptor(tma_atom_sfa)
|
||
cpasync.prefetch_descriptor(tma_atom_sfb)
|
||
cpasync.prefetch_descriptor(tma_atom_c)
|
||
|
||
use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2
|
||
|
||
#
|
||
# Setup cta/thread coordinates
|
||
#
|
||
# Coords inside cluster
|
||
bidx, bidy, bidz = cute.arch.block_idx()
|
||
mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape)
|
||
is_leader_cta = mma_tile_coord_v == 0
|
||
cta_rank_in_cluster = cute.arch.make_warp_uniform(
|
||
cute.arch.block_idx_in_cluster()
|
||
)
|
||
block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(
|
||
cta_rank_in_cluster
|
||
)
|
||
block_in_cluster_coord_sfb_vmnk = cluster_layout_sfb_vmnk.get_flat_coord(
|
||
cta_rank_in_cluster
|
||
)
|
||
# Coord inside cta
|
||
tidx, _, _ = cute.arch.thread_idx()
|
||
|
||
#
|
||
# Alloc and init: a+b full/empty, accumulator full/empty, tensor memory dealloc barrier
|
||
#
|
||
smem = utils.SmemAllocator()
|
||
storage = smem.allocate(self.shared_storage)
|
||
|
||
# Initialize mainloop ab_pipeline (barrier) and states
|
||
ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
|
||
num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
|
||
ab_pipeline_consumer_group = pipeline.CooperativeGroup(
|
||
pipeline.Agent.Thread, num_tma_producer
|
||
)
|
||
ab_pipeline = pipeline.PipelineTmaUmma.create(
|
||
barrier_storage=storage.ab_full_mbar_ptr.data_ptr(),
|
||
num_stages=self.num_ab_stage,
|
||
producer_group=ab_pipeline_producer_group,
|
||
consumer_group=ab_pipeline_consumer_group,
|
||
tx_count=self.num_tma_load_bytes,
|
||
cta_layout_vmnk=cluster_layout_vmnk,
|
||
defer_sync=True,
|
||
)
|
||
|
||
# Initialize acc_pipeline (barrier) and states
|
||
acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
|
||
num_acc_consumer_threads = self.threads_per_warp * len(self.epilog_warp_id) * (
|
||
2 if use_2cta_instrs else 1
|
||
)
|
||
acc_pipeline_consumer_group = pipeline.CooperativeGroup(
|
||
pipeline.Agent.Thread, num_acc_consumer_threads
|
||
)
|
||
acc_pipeline = pipeline.PipelineUmmaAsync.create(
|
||
barrier_storage=storage.acc_full_mbar_ptr.data_ptr(),
|
||
num_stages=self.num_acc_stage,
|
||
producer_group=acc_pipeline_producer_group,
|
||
consumer_group=acc_pipeline_consumer_group,
|
||
cta_layout_vmnk=cluster_layout_vmnk,
|
||
defer_sync=True,
|
||
)
|
||
|
||
# Tensor memory dealloc barrier init
|
||
tmem = utils.TmemAllocator(
|
||
storage.tmem_holding_buf,
|
||
barrier_for_retrieve=self.tmem_alloc_barrier,
|
||
allocator_warp_id=self.epilog_warp_id[0],
|
||
is_two_cta=use_2cta_instrs,
|
||
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr,
|
||
)
|
||
|
||
# Cluster arrive after barrier init
|
||
pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mn, is_relaxed=True)
|
||
|
||
#
|
||
# Setup smem tensor A/B/SFA/SFB/C
|
||
#
|
||
# (EPI_TILE_M, EPI_TILE_N, STAGE)
|
||
sC = storage.sC.get_tensor(
|
||
c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner
|
||
)
|
||
# (MMA, MMA_M, MMA_K, STAGE)
|
||
sA = storage.sA.get_tensor(
|
||
a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner
|
||
)
|
||
# (MMA, MMA_N, MMA_K, STAGE)
|
||
sB = storage.sB.get_tensor(
|
||
b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner
|
||
)
|
||
# (MMA, MMA_M, MMA_K, STAGE)
|
||
sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged)
|
||
# (MMA, MMA_N, MMA_K, STAGE)
|
||
sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged)
|
||
|
||
#
|
||
# Compute multicast mask for A/B/SFA/SFB buffer full
|
||
#
|
||
a_full_mcast_mask = None
|
||
b_full_mcast_mask = None
|
||
sfa_full_mcast_mask = None
|
||
sfb_full_mcast_mask = None
|
||
if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs):
|
||
a_full_mcast_mask = cpasync.create_tma_multicast_mask(
|
||
cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2
|
||
)
|
||
b_full_mcast_mask = cpasync.create_tma_multicast_mask(
|
||
cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1
|
||
)
|
||
sfa_full_mcast_mask = cpasync.create_tma_multicast_mask(
|
||
cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2
|
||
)
|
||
sfb_full_mcast_mask = cpasync.create_tma_multicast_mask(
|
||
cluster_layout_sfb_vmnk, block_in_cluster_coord_sfb_vmnk, mcast_mode=1
|
||
)
|
||
|
||
#
|
||
# Local_tile partition global tensors
|
||
#
|
||
# (bM, bK, RestM, RestK, RestL)
|
||
gA_mkl = cute.local_tile(
|
||
mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None)
|
||
)
|
||
# (bN, bK, RestN, RestK, RestL)
|
||
gB_nkl = cute.local_tile(
|
||
mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None)
|
||
)
|
||
# (bM, bK, RestM, RestK, RestL)
|
||
gSFA_mkl = cute.local_tile(
|
||
mSFA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None)
|
||
)
|
||
# (bN, bK, RestN, RestK, RestL)
|
||
gSFB_nkl = cute.local_tile(
|
||
mSFB_nkl,
|
||
cute.slice_(self.mma_tiler_sfb, (0, None, None)),
|
||
(None, None, None),
|
||
)
|
||
# (bM, bN, RestM, RestN, RestL)
|
||
gC_mnl = cute.local_tile(
|
||
mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None)
|
||
)
|
||
k_tile_cnt = cute.size(gA_mkl, mode=[3])
|
||
|
||
#
|
||
# Partition global tensor for TiledMMA_A/B/C
|
||
#
|
||
thr_mma = tiled_mma.get_slice(mma_tile_coord_v)
|
||
thr_mma_sfb = tiled_mma_sfb.get_slice(mma_tile_coord_v)
|
||
# (MMA, MMA_M, MMA_K, RestM, RestK, RestL)
|
||
tCgA = thr_mma.partition_A(gA_mkl)
|
||
# (MMA, MMA_N, MMA_K, RestN, RestK, RestL)
|
||
tCgB = thr_mma.partition_B(gB_nkl)
|
||
# (MMA, MMA_M, MMA_K, RestM, RestK, RestL)
|
||
tCgSFA = thr_mma.partition_A(gSFA_mkl)
|
||
# (MMA, MMA_N, MMA_K, RestN, RestK, RestL)
|
||
tCgSFB = thr_mma_sfb.partition_B(gSFB_nkl)
|
||
# (MMA, MMA_M, MMA_N, RestM, RestN, RestL)
|
||
tCgC = thr_mma.partition_C(gC_mnl)
|
||
|
||
#
|
||
# Partition global/shared tensor for TMA load A/B
|
||
#
|
||
# TMA load A partition_S/D
|
||
a_cta_layout = cute.make_layout(
|
||
cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape
|
||
)
|
||
# ((atom_v, rest_v), STAGE)
|
||
# ((atom_v, rest_v), RestM, RestK, RestL)
|
||
tAsA, tAgA = cpasync.tma_partition(
|
||
tma_atom_a,
|
||
block_in_cluster_coord_vmnk[2],
|
||
a_cta_layout,
|
||
cute.group_modes(sA, 0, 3),
|
||
cute.group_modes(tCgA, 0, 3),
|
||
)
|
||
# TMA load B partition_S/D
|
||
b_cta_layout = cute.make_layout(
|
||
cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape
|
||
)
|
||
# ((atom_v, rest_v), STAGE)
|
||
# ((atom_v, rest_v), RestN, RestK, RestL)
|
||
tBsB, tBgB = cpasync.tma_partition(
|
||
tma_atom_b,
|
||
block_in_cluster_coord_vmnk[1],
|
||
b_cta_layout,
|
||
cute.group_modes(sB, 0, 3),
|
||
cute.group_modes(tCgB, 0, 3),
|
||
)
|
||
|
||
# TMA load scaled factor A partition_S/D
|
||
sfa_cta_layout = a_cta_layout
|
||
# ((atom_v, rest_v), STAGE)
|
||
# ((atom_v, rest_v), RestM, RestK, RestL)
|
||
tAsSFA, tAgSFA = cute.nvgpu.cpasync.tma_partition(
|
||
tma_atom_sfa,
|
||
block_in_cluster_coord_vmnk[2],
|
||
sfa_cta_layout,
|
||
cute.group_modes(sSFA, 0, 3),
|
||
cute.group_modes(tCgSFA, 0, 3),
|
||
)
|
||
tAsSFA = cute.filter_zeros(tAsSFA)
|
||
tAgSFA = cute.filter_zeros(tAgSFA)
|
||
|
||
# TMA load scaled factor B partition_S/D
|
||
sfb_cta_layout = cute.make_layout(
|
||
cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape
|
||
)
|
||
# ((atom_v, rest_v), STAGE)
|
||
# ((atom_v, rest_v), RestN, RestK, RestL)
|
||
tBsSFB, tBgSFB = cute.nvgpu.cpasync.tma_partition(
|
||
tma_atom_sfb,
|
||
block_in_cluster_coord_sfb_vmnk[1],
|
||
sfb_cta_layout,
|
||
cute.group_modes(sSFB, 0, 3),
|
||
cute.group_modes(tCgSFB, 0, 3),
|
||
)
|
||
tBsSFB = cute.filter_zeros(tBsSFB)
|
||
tBgSFB = cute.filter_zeros(tBgSFB)
|
||
|
||
#
|
||
# Partition shared/tensor memory tensor for TiledMMA_A/B/C
|
||
#
|
||
# (MMA, MMA_M, MMA_K, STAGE)
|
||
tCrA = tiled_mma.make_fragment_A(sA)
|
||
# (MMA, MMA_N, MMA_K, STAGE)
|
||
tCrB = tiled_mma.make_fragment_B(sB)
|
||
# (MMA, MMA_M, MMA_N)
|
||
acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
|
||
if cutlass.const_expr(self.overlapping_accum):
|
||
num_acc_stage_overlapped = 2
|
||
tCtAcc_fake = tiled_mma.make_fragment_C(
|
||
cute.append(acc_shape, num_acc_stage_overlapped)
|
||
)
|
||
# (MMA, MMA_M, MMA_N, STAGE)
|
||
tCtAcc_fake = cute.make_tensor(
|
||
tCtAcc_fake.iterator,
|
||
cute.make_layout(
|
||
tCtAcc_fake.shape,
|
||
stride=(
|
||
tCtAcc_fake.stride[0],
|
||
tCtAcc_fake.stride[1],
|
||
tCtAcc_fake.stride[2],
|
||
(256 - self.num_sf_tmem_cols) * tCtAcc_fake.stride[0][1],
|
||
),
|
||
),
|
||
)
|
||
else:
|
||
# (MMA, MMA_M, MMA_N, STAGE)
|
||
tCtAcc_fake = tiled_mma.make_fragment_C(
|
||
cute.append(acc_shape, self.num_acc_stage)
|
||
)
|
||
|
||
#
|
||
# Cluster wait before tensor memory alloc
|
||
#
|
||
pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mn)
|
||
|
||
#
|
||
# Specialized TMA load warp
|
||
#
|
||
if warp_idx == self.tma_warp_id:
|
||
#
|
||
# Persistent tile scheduling loop
|
||
#
|
||
tile_sched = utils.StaticPersistentTileScheduler.create(
|
||
tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()
|
||
)
|
||
work_tile = tile_sched.initial_work_tile_info()
|
||
|
||
ab_producer_state = pipeline.make_pipeline_state(
|
||
pipeline.PipelineUserType.Producer, self.num_ab_stage
|
||
)
|
||
|
||
while work_tile.is_valid_tile:
|
||
# Get tile coord from tile scheduler
|
||
cur_tile_coord = work_tile.tile_idx
|
||
mma_tile_coord_mnl = (
|
||
cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape),
|
||
cur_tile_coord[1],
|
||
cur_tile_coord[2],
|
||
)
|
||
|
||
#
|
||
# Slice to per mma tile index
|
||
#
|
||
# ((atom_v, rest_v), RestK)
|
||
tAgA_slice = tAgA[
|
||
(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])
|
||
]
|
||
# ((atom_v, rest_v), RestK)
|
||
tBgB_slice = tBgB[
|
||
(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])
|
||
]
|
||
|
||
# ((atom_v, rest_v), RestK)
|
||
tAgSFA_slice = tAgSFA[
|
||
(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])
|
||
]
|
||
|
||
slice_n = mma_tile_coord_mnl[1]
|
||
if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64):
|
||
slice_n = mma_tile_coord_mnl[1] // 2
|
||
# ((atom_v, rest_v), RestK)
|
||
tBgSFB_slice = tBgSFB[(None, slice_n, None, mma_tile_coord_mnl[2])]
|
||
|
||
# Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt
|
||
ab_producer_state.reset_count()
|
||
peek_ab_empty_status = cutlass.Boolean(1)
|
||
if ab_producer_state.count < k_tile_cnt:
|
||
peek_ab_empty_status = ab_pipeline.producer_try_acquire(
|
||
ab_producer_state
|
||
)
|
||
#
|
||
# Tma load loop
|
||
#
|
||
for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):
|
||
# Conditionally wait for AB buffer empty
|
||
ab_pipeline.producer_acquire(
|
||
ab_producer_state, peek_ab_empty_status
|
||
)
|
||
|
||
# TMA load A/B/SFA/SFB
|
||
cute.copy(
|
||
tma_atom_a,
|
||
tAgA_slice[(None, ab_producer_state.count)],
|
||
tAsA[(None, ab_producer_state.index)],
|
||
tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
|
||
mcast_mask=a_full_mcast_mask,
|
||
)
|
||
cute.copy(
|
||
tma_atom_b,
|
||
tBgB_slice[(None, ab_producer_state.count)],
|
||
tBsB[(None, ab_producer_state.index)],
|
||
tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
|
||
mcast_mask=b_full_mcast_mask,
|
||
)
|
||
cute.copy(
|
||
tma_atom_sfa,
|
||
tAgSFA_slice[(None, ab_producer_state.count)],
|
||
tAsSFA[(None, ab_producer_state.index)],
|
||
tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
|
||
mcast_mask=sfa_full_mcast_mask,
|
||
)
|
||
cute.copy(
|
||
tma_atom_sfb,
|
||
tBgSFB_slice[(None, ab_producer_state.count)],
|
||
tBsSFB[(None, ab_producer_state.index)],
|
||
tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
|
||
mcast_mask=sfb_full_mcast_mask,
|
||
)
|
||
|
||
# Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1
|
||
ab_producer_state.advance()
|
||
peek_ab_empty_status = cutlass.Boolean(1)
|
||
if ab_producer_state.count < k_tile_cnt:
|
||
peek_ab_empty_status = ab_pipeline.producer_try_acquire(
|
||
ab_producer_state
|
||
)
|
||
|
||
#
|
||
# Advance to next tile
|
||
#
|
||
tile_sched.advance_to_next_work()
|
||
work_tile = tile_sched.get_current_work()
|
||
|
||
#
|
||
# Wait A/B buffer empty
|
||
#
|
||
ab_pipeline.producer_tail(ab_producer_state)
|
||
|
||
#
|
||
# Specialized MMA warp
|
||
#
|
||
if warp_idx == self.mma_warp_id:
|
||
#
|
||
# Bar sync for retrieve tensor memory ptr from shared mem
|
||
#
|
||
tmem.wait_for_alloc()
|
||
|
||
#
|
||
# Retrieving tensor memory ptr and make accumulator/SFA/SFB tensor
|
||
#
|
||
acc_tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
|
||
# Make accumulator tmem tensor
|
||
# (MMA, MMA_M, MMA_N, STAGE)
|
||
tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout)
|
||
|
||
# Make SFA tmem tensor
|
||
sfa_tmem_ptr = cute.recast_ptr(
|
||
acc_tmem_ptr + self.num_accumulator_tmem_cols,
|
||
dtype=self.sf_dtype,
|
||
)
|
||
# (MMA, MMA_M, MMA_K)
|
||
tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa(
|
||
tiled_mma,
|
||
self.mma_tiler,
|
||
self.sf_vec_size,
|
||
cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)),
|
||
)
|
||
tCtSFA = cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout)
|
||
|
||
# Make SFB tmem tensor
|
||
sfb_tmem_ptr = cute.recast_ptr(
|
||
acc_tmem_ptr + self.num_accumulator_tmem_cols + self.num_sfa_tmem_cols,
|
||
dtype=self.sf_dtype,
|
||
)
|
||
# (MMA, MMA_N, MMA_K)
|
||
tCtSFB_layout = blockscaled_utils.make_tmem_layout_sfb(
|
||
tiled_mma,
|
||
self.mma_tiler,
|
||
self.sf_vec_size,
|
||
cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)),
|
||
)
|
||
tCtSFB = cute.make_tensor(sfb_tmem_ptr, tCtSFB_layout)
|
||
#
|
||
# Partition for S2T copy of SFA/SFB
|
||
#
|
||
(
|
||
tiled_copy_s2t_sfa,
|
||
tCsSFA_compact_s2t,
|
||
tCtSFA_compact_s2t,
|
||
) = self.mainloop_s2t_copy_and_partition(sSFA, tCtSFA)
|
||
(
|
||
tiled_copy_s2t_sfb,
|
||
tCsSFB_compact_s2t,
|
||
tCtSFB_compact_s2t,
|
||
) = self.mainloop_s2t_copy_and_partition(sSFB, tCtSFB)
|
||
|
||
#
|
||
# Persistent tile scheduling loop
|
||
#
|
||
tile_sched = utils.StaticPersistentTileScheduler.create(
|
||
tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()
|
||
)
|
||
work_tile = tile_sched.initial_work_tile_info()
|
||
|
||
ab_consumer_state = pipeline.make_pipeline_state(
|
||
pipeline.PipelineUserType.Consumer, self.num_ab_stage
|
||
)
|
||
acc_producer_state = pipeline.make_pipeline_state(
|
||
pipeline.PipelineUserType.Producer, self.num_acc_stage
|
||
)
|
||
|
||
while work_tile.is_valid_tile:
|
||
# Get tile coord from tile scheduler
|
||
cur_tile_coord = work_tile.tile_idx
|
||
mma_tile_coord_mnl = (
|
||
cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape),
|
||
cur_tile_coord[1],
|
||
cur_tile_coord[2],
|
||
)
|
||
|
||
# Get accumulator stage index
|
||
if cutlass.const_expr(self.overlapping_accum):
|
||
acc_stage_index = acc_producer_state.phase ^ 1
|
||
else:
|
||
acc_stage_index = acc_producer_state.index
|
||
|
||
# Set tensor memory buffer for current tile
|
||
# (MMA, MMA_M, MMA_N)
|
||
tCtAcc = tCtAcc_base[(None, None, None, acc_stage_index)]
|
||
|
||
# Peek (try_wait) AB buffer full for k_tile = 0
|
||
ab_consumer_state.reset_count()
|
||
peek_ab_full_status = cutlass.Boolean(1)
|
||
if ab_consumer_state.count < k_tile_cnt and is_leader_cta:
|
||
peek_ab_full_status = ab_pipeline.consumer_try_wait(
|
||
ab_consumer_state
|
||
)
|
||
|
||
#
|
||
# Wait for accumulator buffer empty
|
||
#
|
||
if is_leader_cta:
|
||
acc_pipeline.producer_acquire(acc_producer_state)
|
||
|
||
tCtSFB_mma = tCtSFB
|
||
if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192):
|
||
# If this is an ODD tile, shift the TMEM start address for cta_tile_shape_n=192 case by two words (ignores first 64 columns of SFB)
|
||
offset = (
|
||
cutlass.Int32(2)
|
||
if mma_tile_coord_mnl[1] % 2 == 1
|
||
else cutlass.Int32(0)
|
||
)
|
||
shifted_ptr = cute.recast_ptr(
|
||
acc_tmem_ptr
|
||
+ self.num_accumulator_tmem_cols
|
||
+ self.num_sfa_tmem_cols
|
||
+ offset,
|
||
dtype=self.sf_dtype,
|
||
)
|
||
tCtSFB_mma = cute.make_tensor(shifted_ptr, tCtSFB_layout)
|
||
elif cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64):
|
||
# Move in increments of 64 columns of SFB
|
||
offset = cutlass.Int32((mma_tile_coord_mnl[1] % 2) * 2)
|
||
shifted_ptr = cute.recast_ptr(
|
||
acc_tmem_ptr
|
||
+ self.num_accumulator_tmem_cols
|
||
+ self.num_sfa_tmem_cols
|
||
+ offset,
|
||
dtype=self.sf_dtype,
|
||
)
|
||
tCtSFB_mma = cute.make_tensor(shifted_ptr, tCtSFB_layout)
|
||
|
||
#
|
||
# Reset the ACCUMULATE field for each tile
|
||
#
|
||
tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||
|
||
#
|
||
# Mma mainloop
|
||
#
|
||
for k_tile in range(k_tile_cnt):
|
||
if is_leader_cta:
|
||
# Conditionally wait for AB buffer full
|
||
ab_pipeline.consumer_wait(
|
||
ab_consumer_state, peek_ab_full_status
|
||
)
|
||
|
||
# Copy SFA/SFB from smem to tmem
|
||
s2t_stage_coord = (
|
||
None,
|
||
None,
|
||
None,
|
||
None,
|
||
ab_consumer_state.index,
|
||
)
|
||
tCsSFA_compact_s2t_staged = tCsSFA_compact_s2t[s2t_stage_coord]
|
||
tCsSFB_compact_s2t_staged = tCsSFB_compact_s2t[s2t_stage_coord]
|
||
cute.copy(
|
||
tiled_copy_s2t_sfa,
|
||
tCsSFA_compact_s2t_staged,
|
||
tCtSFA_compact_s2t,
|
||
)
|
||
cute.copy(
|
||
tiled_copy_s2t_sfb,
|
||
tCsSFB_compact_s2t_staged,
|
||
tCtSFB_compact_s2t,
|
||
)
|
||
|
||
# tCtAcc += tCrA * tCrSFA * tCrB * tCrSFB
|
||
num_kblocks = cute.size(tCrA, mode=[2])
|
||
for kblock_idx in cutlass.range(num_kblocks, unroll_full=True):
|
||
kblock_coord = (
|
||
None,
|
||
None,
|
||
kblock_idx,
|
||
ab_consumer_state.index,
|
||
)
|
||
|
||
# Set SFA/SFB tensor to tiled_mma
|
||
sf_kblock_coord = (None, None, kblock_idx)
|
||
tiled_mma.set(
|
||
tcgen05.Field.SFA,
|
||
tCtSFA[sf_kblock_coord].iterator,
|
||
)
|
||
tiled_mma.set(
|
||
tcgen05.Field.SFB,
|
||
tCtSFB_mma[sf_kblock_coord].iterator,
|
||
)
|
||
|
||
cute.gemm(
|
||
tiled_mma,
|
||
tCtAcc,
|
||
tCrA[kblock_coord],
|
||
tCrB[kblock_coord],
|
||
tCtAcc,
|
||
)
|
||
|
||
# Enable accumulate on tCtAcc after first kblock
|
||
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||
|
||
# Async arrive AB buffer empty
|
||
ab_pipeline.consumer_release(ab_consumer_state)
|
||
|
||
# Peek (try_wait) AB buffer full for k_tile = k_tile + 1
|
||
ab_consumer_state.advance()
|
||
peek_ab_full_status = cutlass.Boolean(1)
|
||
if ab_consumer_state.count < k_tile_cnt:
|
||
if is_leader_cta:
|
||
peek_ab_full_status = ab_pipeline.consumer_try_wait(
|
||
ab_consumer_state
|
||
)
|
||
|
||
#
|
||
# Async arrive accumulator buffer full
|
||
#
|
||
if is_leader_cta:
|
||
acc_pipeline.producer_commit(acc_producer_state)
|
||
acc_producer_state.advance()
|
||
|
||
#
|
||
# Advance to next tile
|
||
#
|
||
tile_sched.advance_to_next_work()
|
||
work_tile = tile_sched.get_current_work()
|
||
|
||
#
|
||
# Wait for accumulator buffer empty
|
||
#
|
||
acc_pipeline.producer_tail(acc_producer_state)
|
||
#
|
||
# Specialized epilogue warps
|
||
#
|
||
if warp_idx < self.mma_warp_id:
|
||
#
|
||
# Alloc tensor memory buffer
|
||
#
|
||
tmem.allocate(self.num_tmem_alloc_cols)
|
||
|
||
#
|
||
# Bar sync for retrieve tensor memory ptr from shared memory
|
||
#
|
||
tmem.wait_for_alloc()
|
||
|
||
#
|
||
# Retrieving tensor memory ptr and make accumulator tensor
|
||
#
|
||
acc_tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
|
||
# (MMA, MMA_M, MMA_N, STAGE)
|
||
tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout)
|
||
|
||
#
|
||
# Partition for epilogue
|
||
#
|
||
epi_tidx = tidx
|
||
(
|
||
tiled_copy_t2r,
|
||
tTR_tAcc_base,
|
||
tTR_rAcc,
|
||
) = self.epilog_tmem_copy_and_partition(
|
||
epi_tidx, tCtAcc_base, tCgC, epi_tile, use_2cta_instrs
|
||
)
|
||
|
||
tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype)
|
||
tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition(
|
||
tiled_copy_t2r, tTR_rC, epi_tidx, sC
|
||
)
|
||
(
|
||
tma_atom_c,
|
||
bSG_sC,
|
||
bSG_gC_partitioned,
|
||
) = self.epilog_gmem_copy_and_partition(
|
||
epi_tidx, tma_atom_c, tCgC, epi_tile, sC
|
||
)
|
||
|
||
#
|
||
# Persistent tile scheduling loop
|
||
#
|
||
tile_sched = utils.StaticPersistentTileScheduler.create(
|
||
tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()
|
||
)
|
||
work_tile = tile_sched.initial_work_tile_info()
|
||
|
||
acc_consumer_state = pipeline.make_pipeline_state(
|
||
pipeline.PipelineUserType.Consumer, self.num_acc_stage
|
||
)
|
||
|
||
# Threads/warps participating in tma store pipeline
|
||
c_producer_group = pipeline.CooperativeGroup(
|
||
pipeline.Agent.Thread,
|
||
self.threads_per_warp * len(self.epilog_warp_id),
|
||
)
|
||
c_pipeline = pipeline.PipelineTmaStore.create(
|
||
num_stages=self.num_c_stage,
|
||
producer_group=c_producer_group,
|
||
)
|
||
|
||
while work_tile.is_valid_tile:
|
||
# Get tile coord from tile scheduler
|
||
cur_tile_coord = work_tile.tile_idx
|
||
mma_tile_coord_mnl = (
|
||
cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape),
|
||
cur_tile_coord[1],
|
||
cur_tile_coord[2],
|
||
)
|
||
|
||
#
|
||
# Slice to per mma tile index
|
||
#
|
||
# ((ATOM_V, REST_V), EPI_M, EPI_N)
|
||
bSG_gC = bSG_gC_partitioned[
|
||
(
|
||
None,
|
||
None,
|
||
None,
|
||
*mma_tile_coord_mnl,
|
||
)
|
||
]
|
||
|
||
# Get accumulator stage index
|
||
if cutlass.const_expr(self.overlapping_accum):
|
||
acc_stage_index = acc_consumer_state.phase
|
||
reverse_subtile = True if acc_stage_index == 0 else False
|
||
else:
|
||
acc_stage_index = acc_consumer_state.index
|
||
|
||
# Set tensor memory buffer for current tile
|
||
# (T2R, T2R_M, T2R_N, EPI_M, EPI_M)
|
||
tTR_tAcc = tTR_tAcc_base[
|
||
(None, None, None, None, None, acc_stage_index)
|
||
]
|
||
|
||
#
|
||
# Wait for accumulator buffer full
|
||
#
|
||
acc_pipeline.consumer_wait(acc_consumer_state)
|
||
|
||
tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc))
|
||
bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC))
|
||
|
||
#
|
||
# Store accumulator to global memory in subtiles
|
||
#
|
||
subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3])
|
||
num_prev_subtiles = tile_sched.num_tiles_executed * subtile_cnt
|
||
for subtile_idx in cutlass.range(subtile_cnt):
|
||
real_subtile_idx = subtile_idx
|
||
if cutlass.const_expr(self.overlapping_accum):
|
||
if reverse_subtile:
|
||
real_subtile_idx = (
|
||
self.cta_tile_shape_mnk[1] // self.epi_tile_n
|
||
- 1
|
||
- subtile_idx
|
||
)
|
||
#
|
||
# Load accumulator from tensor memory buffer to register
|
||
#
|
||
tTR_tAcc_mn = tTR_tAcc[(None, None, None, real_subtile_idx)]
|
||
cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc)
|
||
|
||
#
|
||
# Async arrive accumulator buffer empty ealier when overlapping_accum is enabled
|
||
#
|
||
if cutlass.const_expr(self.overlapping_accum):
|
||
if subtile_idx == self.iter_acc_early_release_in_epilogue:
|
||
# Fence for TMEM load
|
||
cute.arch.fence_view_async_tmem_load()
|
||
acc_pipeline.consumer_release(acc_consumer_state)
|
||
acc_consumer_state.advance()
|
||
|
||
#
|
||
# Convert to C type
|
||
#
|
||
acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load()
|
||
acc_vec = epilogue_op(acc_vec.to(self.c_dtype))
|
||
tRS_rC.store(acc_vec)
|
||
|
||
#
|
||
# Store C to shared memory
|
||
#
|
||
c_buffer = (num_prev_subtiles + subtile_idx) % self.num_c_stage
|
||
cute.copy(
|
||
tiled_copy_r2s,
|
||
tRS_rC,
|
||
tRS_sC[(None, None, None, c_buffer)],
|
||
)
|
||
# Fence and barrier to make sure shared memory store is visible to TMA store
|
||
cute.arch.fence_proxy(
|
||
"async.shared",
|
||
space="cta",
|
||
)
|
||
self.epilog_sync_barrier.arrive_and_wait()
|
||
|
||
#
|
||
# TMA store C to global memory
|
||
#
|
||
if warp_idx == self.epilog_warp_id[0]:
|
||
cute.copy(
|
||
tma_atom_c,
|
||
bSG_sC[(None, c_buffer)],
|
||
bSG_gC[(None, real_subtile_idx)],
|
||
)
|
||
# Fence and barrier to make sure shared memory store is visible to TMA store
|
||
c_pipeline.producer_commit()
|
||
c_pipeline.producer_acquire()
|
||
self.epilog_sync_barrier.arrive_and_wait()
|
||
|
||
#
|
||
# Async arrive accumulator buffer empty
|
||
#
|
||
if cutlass.const_expr(not self.overlapping_accum):
|
||
acc_pipeline.consumer_release(acc_consumer_state)
|
||
acc_consumer_state.advance()
|
||
|
||
#
|
||
# Advance to next tile
|
||
#
|
||
tile_sched.advance_to_next_work()
|
||
work_tile = tile_sched.get_current_work()
|
||
|
||
#
|
||
# Dealloc the tensor memory buffer
|
||
#
|
||
tmem.relinquish_alloc_permit()
|
||
self.epilog_sync_barrier.arrive_and_wait()
|
||
tmem.free(acc_tmem_ptr)
|
||
#
|
||
# Wait for C store complete
|
||
#
|
||
c_pipeline.producer_tail()
|
||
|
||
def mainloop_s2t_copy_and_partition(
|
||
self,
|
||
sSF: cute.Tensor,
|
||
tSF: cute.Tensor,
|
||
) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
|
||
"""
|
||
Make tiledCopy for smem to tmem load for scale factor tensor, then use it to partition smem memory (source) and tensor memory (destination).
|
||
|
||
:param sSF: The scale factor tensor in smem
|
||
:type sSF: cute.Tensor
|
||
:param tSF: The scale factor tensor in tmem
|
||
:type tSF: cute.Tensor
|
||
|
||
:return: A tuple containing (tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t) where:
|
||
- tiled_copy_s2t: The tiled copy operation for smem to tmem load for scale factor tensor(s2t)
|
||
- tCsSF_compact_s2t: The partitioned scale factor tensor in smem
|
||
- tSF_compact_s2t: The partitioned scale factor tensor in tmem
|
||
:rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]
|
||
"""
|
||
# (MMA, MMA_MN, MMA_K, STAGE)
|
||
tCsSF_compact = cute.filter_zeros(sSF)
|
||
# (MMA, MMA_MN, MMA_K)
|
||
tCtSF_compact = cute.filter_zeros(tSF)
|
||
|
||
# Make S2T CopyAtom and tiledCopy
|
||
copy_atom_s2t = cute.make_copy_atom(
|
||
tcgen05.Cp4x32x128bOp(self.cta_group),
|
||
self.sf_dtype,
|
||
)
|
||
tiled_copy_s2t = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSF_compact)
|
||
thr_copy_s2t = tiled_copy_s2t.get_slice(0)
|
||
|
||
# ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE)
|
||
tCsSF_compact_s2t_ = thr_copy_s2t.partition_S(tCsSF_compact)
|
||
# ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE)
|
||
tCsSF_compact_s2t = tcgen05.get_s2t_smem_desc_tensor(
|
||
tiled_copy_s2t, tCsSF_compact_s2t_
|
||
)
|
||
# ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K)
|
||
tCtSF_compact_s2t = thr_copy_s2t.partition_D(tCtSF_compact)
|
||
|
||
return tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t
|
||
|
||
def epilog_tmem_copy_and_partition(
|
||
self,
|
||
tidx: cutlass.Int32,
|
||
tAcc: cute.Tensor,
|
||
gC_mnl: cute.Tensor,
|
||
epi_tile: cute.Tile,
|
||
use_2cta_instrs: Union[cutlass.Boolean, bool],
|
||
) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
|
||
"""
|
||
Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination).
|
||
|
||
:param tidx: The thread index in epilogue warp groups
|
||
:type tidx: cutlass.Int32
|
||
:param tAcc: The accumulator tensor to be copied and partitioned
|
||
:type tAcc: cute.Tensor
|
||
:param gC_mnl: The global tensor C
|
||
:type gC_mnl: cute.Tensor
|
||
:param epi_tile: The epilogue tiler
|
||
:type epi_tile: cute.Tile
|
||
:param use_2cta_instrs: Whether use_2cta_instrs is enabled
|
||
:type use_2cta_instrs: bool
|
||
|
||
:return: A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc) where:
|
||
- tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r)
|
||
- tTR_tAcc: The partitioned accumulator tensor
|
||
- tTR_rAcc: The accumulated tensor in register used to hold t2r results
|
||
:rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]
|
||
"""
|
||
# Make tiledCopy for tensor memory load
|
||
copy_atom_t2r = sm100_utils.get_tmem_load_op(
|
||
self.cta_tile_shape_mnk,
|
||
self.c_layout,
|
||
self.c_dtype,
|
||
self.acc_dtype,
|
||
epi_tile,
|
||
use_2cta_instrs,
|
||
)
|
||
# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, STAGE)
|
||
tAcc_epi = cute.flat_divide(
|
||
tAcc[((None, None), 0, 0, None)],
|
||
epi_tile,
|
||
)
|
||
# (EPI_TILE_M, EPI_TILE_N)
|
||
tiled_copy_t2r = tcgen05.make_tmem_copy(
|
||
copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)]
|
||
)
|
||
|
||
thr_copy_t2r = tiled_copy_t2r.get_slice(tidx)
|
||
# (T2R, T2R_M, T2R_N, EPI_M, EPI_M, STAGE)
|
||
tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi)
|
||
|
||
# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL)
|
||
gC_mnl_epi = cute.flat_divide(
|
||
gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile
|
||
)
|
||
# (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL)
|
||
tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi)
|
||
# (T2R, T2R_M, T2R_N)
|
||
tTR_rAcc = cute.make_rmem_tensor(
|
||
tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype
|
||
)
|
||
return tiled_copy_t2r, tTR_tAcc, tTR_rAcc
|
||
|
||
def epilog_smem_copy_and_partition(
|
||
self,
|
||
tiled_copy_t2r: cute.TiledCopy,
|
||
tTR_rC: cute.Tensor,
|
||
tidx: cutlass.Int32,
|
||
sC: cute.Tensor,
|
||
) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
|
||
"""
|
||
Make tiledCopy for shared memory store, then use it to partition register array (source) and shared memory (destination).
|
||
|
||
:param tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r)
|
||
:type tiled_copy_t2r: cute.TiledCopy
|
||
:param tTR_rC: The partitioned accumulator tensor
|
||
:type tTR_rC: cute.Tensor
|
||
:param tidx: The thread index in epilogue warp groups
|
||
:type tidx: cutlass.Int32
|
||
:param sC: The shared memory tensor to be copied and partitioned
|
||
:type sC: cute.Tensor
|
||
:type sepi: cute.Tensor
|
||
|
||
:return: A tuple containing (tiled_copy_r2s, tRS_rC, tRS_sC) where:
|
||
- tiled_copy_r2s: The tiled copy operation for register to smem copy(r2s)
|
||
- tRS_rC: The partitioned tensor C (register source)
|
||
- tRS_sC: The partitioned tensor C (smem destination)
|
||
:rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]
|
||
"""
|
||
copy_atom_r2s = sm100_utils.get_smem_store_op(
|
||
self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r
|
||
)
|
||
tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r)
|
||
# (R2S, R2S_M, R2S_N, PIPE_D)
|
||
thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
|
||
tRS_sC = thr_copy_r2s.partition_D(sC)
|
||
# (R2S, R2S_M, R2S_N)
|
||
tRS_rC = tiled_copy_r2s.retile(tTR_rC)
|
||
return tiled_copy_r2s, tRS_rC, tRS_sC
|
||
|
||
def epilog_gmem_copy_and_partition(
|
||
self,
|
||
tidx: cutlass.Int32,
|
||
atom: Union[cute.CopyAtom, cute.TiledCopy],
|
||
gC_mnl: cute.Tensor,
|
||
epi_tile: cute.Tile,
|
||
sC: cute.Tensor,
|
||
) -> Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]:
|
||
"""Make tiledCopy for global memory store, then use it to:
|
||
partition shared memory (source) and global memory (destination) for TMA store version.
|
||
|
||
:param tidx: The thread index in epilogue warp groups
|
||
:type tidx: cutlass.Int32
|
||
:param atom: The copy_atom_c to be used for TMA store version, or tiled_copy_t2r for none TMA store version
|
||
:type atom: cute.CopyAtom or cute.TiledCopy
|
||
:param gC_mnl: The global tensor C
|
||
:type gC_mnl: cute.Tensor
|
||
:param epi_tile: The epilogue tiler
|
||
:type epi_tile: cute.Tile
|
||
:param sC: The shared memory tensor to be copied and partitioned
|
||
:type sC: cute.Tensor
|
||
|
||
:return: A tuple containing (tma_atom_c, bSG_sC, bSG_gC) where:
|
||
- tma_atom_c: The TMA copy atom
|
||
- bSG_sC: The partitioned shared memory tensor C
|
||
- bSG_gC: The partitioned global tensor C
|
||
:rtype: Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]
|
||
"""
|
||
# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL)
|
||
gC_epi = cute.flat_divide(
|
||
gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile
|
||
)
|
||
|
||
tma_atom_c = atom
|
||
sC_for_tma_partition = cute.group_modes(sC, 0, 2)
|
||
gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2)
|
||
# ((ATOM_V, REST_V), EPI_M, EPI_N)
|
||
# ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL)
|
||
bSG_sC, bSG_gC = cpasync.tma_partition(
|
||
tma_atom_c,
|
||
0,
|
||
cute.make_layout(1),
|
||
sC_for_tma_partition,
|
||
gC_for_tma_partition,
|
||
)
|
||
return tma_atom_c, bSG_sC, bSG_gC
|
||
|
||
@staticmethod
|
||
def _compute_stages(
|
||
tiled_mma: cute.TiledMma,
|
||
mma_tiler_mnk: Tuple[int, int, int],
|
||
a_dtype: Type[cutlass.Numeric],
|
||
b_dtype: Type[cutlass.Numeric],
|
||
epi_tile: cute.Tile,
|
||
c_dtype: Type[cutlass.Numeric],
|
||
c_layout: utils.LayoutEnum,
|
||
sf_dtype: Type[cutlass.Numeric],
|
||
sf_vec_size: int,
|
||
smem_capacity: int,
|
||
occupancy: int,
|
||
) -> Tuple[int, int, int]:
|
||
"""Computes the number of stages for A/B/C operands based on heuristics.
|
||
|
||
:param tiled_mma: The tiled MMA object defining the core computation.
|
||
:type tiled_mma: cute.TiledMma
|
||
:param mma_tiler_mnk: The shape (M, N, K) of the MMA tiler.
|
||
:type mma_tiler_mnk: tuple[int, int, int]
|
||
:param a_dtype: Data type of operand A.
|
||
:type a_dtype: type[cutlass.Numeric]
|
||
:param b_dtype: Data type of operand B.
|
||
:type b_dtype: type[cutlass.Numeric]
|
||
:param epi_tile: The epilogue tile shape.
|
||
:type epi_tile: cute.Tile
|
||
:param c_dtype: Data type of operand C (output).
|
||
:type c_dtype: type[cutlass.Numeric]
|
||
:param c_layout: Layout enum of operand C.
|
||
:type c_layout: utils.LayoutEnum
|
||
:param sf_dtype: Data type of Scale factor.
|
||
:type sf_dtype: type[cutlass.Numeric]
|
||
:param sf_vec_size: Scale factor vector size.
|
||
:type sf_vec_size: int
|
||
:param smem_capacity: Total available shared memory capacity in bytes.
|
||
:type smem_capacity: int
|
||
:param occupancy: Target number of CTAs per SM (occupancy).
|
||
:type occupancy: int
|
||
|
||
:return: A tuple containing the computed number of stages for:
|
||
(ACC stages, A/B operand stages, C stages)
|
||
:rtype: tuple[int, int, int]
|
||
"""
|
||
# ACC stages
|
||
num_acc_stage = 1 if mma_tiler_mnk[1] == 256 else 2
|
||
|
||
# Default C stages
|
||
num_c_stage = 2
|
||
|
||
# Calculate smem layout and size for one stage of A, B, SFA, SFB and C
|
||
a_smem_layout_stage_one = sm100_utils.make_smem_layout_a(
|
||
tiled_mma,
|
||
mma_tiler_mnk,
|
||
a_dtype,
|
||
1, # a tmp 1 stage is provided
|
||
)
|
||
b_smem_layout_staged_one = sm100_utils.make_smem_layout_b(
|
||
tiled_mma,
|
||
mma_tiler_mnk,
|
||
b_dtype,
|
||
1, # a tmp 1 stage is provided
|
||
)
|
||
sfa_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfa(
|
||
tiled_mma,
|
||
mma_tiler_mnk,
|
||
sf_vec_size,
|
||
1, # a tmp 1 stage is provided
|
||
)
|
||
sfb_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfb(
|
||
tiled_mma,
|
||
mma_tiler_mnk,
|
||
sf_vec_size,
|
||
1, # a tmp 1 stage is provided
|
||
)
|
||
|
||
c_smem_layout_staged_one = sm100_utils.make_smem_layout_epi(
|
||
c_dtype,
|
||
c_layout,
|
||
epi_tile,
|
||
1,
|
||
)
|
||
|
||
ab_bytes_per_stage = (
|
||
cute.size_in_bytes(a_dtype, a_smem_layout_stage_one)
|
||
+ cute.size_in_bytes(b_dtype, b_smem_layout_staged_one)
|
||
+ cute.size_in_bytes(sf_dtype, sfa_smem_layout_staged_one)
|
||
+ cute.size_in_bytes(sf_dtype, sfb_smem_layout_staged_one)
|
||
)
|
||
mbar_helpers_bytes = 1024
|
||
c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one)
|
||
c_bytes = c_bytes_per_stage * num_c_stage
|
||
|
||
# Calculate A/B/SFA/SFB stages:
|
||
# Start with total smem per CTA (capacity / occupancy)
|
||
# Subtract reserved bytes and initial C stages bytes
|
||
# Divide remaining by bytes needed per A/B/SFA/SFB stage
|
||
num_ab_stage = (
|
||
smem_capacity // occupancy - (mbar_helpers_bytes + c_bytes)
|
||
) // ab_bytes_per_stage
|
||
|
||
# Refine epilogue stages:
|
||
# Calculate remaining smem after allocating for A/B/SFA/SFB stages and reserved bytes
|
||
# Add remaining unused smem to epilogue
|
||
num_c_stage += (
|
||
smem_capacity
|
||
- occupancy * ab_bytes_per_stage * num_ab_stage
|
||
- occupancy * (mbar_helpers_bytes + c_bytes)
|
||
) // (occupancy * c_bytes_per_stage)
|
||
|
||
return num_acc_stage, num_ab_stage, num_c_stage
|
||
|
||
@staticmethod
|
||
def _compute_grid(
|
||
c: cute.Tensor,
|
||
cta_tile_shape_mnk: Tuple[int, int, int],
|
||
cluster_shape_mn: Tuple[int, int],
|
||
max_active_clusters: cutlass.Constexpr,
|
||
) -> Tuple[utils.PersistentTileSchedulerParams, Tuple[int, int, int]]:
|
||
"""Use persistent tile scheduler to compute the grid size for the output tensor C.
|
||
|
||
:param c: The output tensor C
|
||
:type c: cute.Tensor
|
||
:param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile.
|
||
:type cta_tile_shape_mnk: tuple[int, int, int]
|
||
:param cluster_shape_mn: Shape of each cluster in M, N dimensions.
|
||
:type cluster_shape_mn: tuple[int, int]
|
||
:param max_active_clusters: Maximum number of active clusters.
|
||
:type max_active_clusters: cutlass.Constexpr
|
||
|
||
:return: A tuple containing:
|
||
- tile_sched_params: Parameters for the persistent tile scheduler.
|
||
- grid: Grid shape for kernel launch.
|
||
:rtype: Tuple[utils.PersistentTileSchedulerParams, tuple[int, int, int]]
|
||
"""
|
||
c_shape = cute.slice_(cta_tile_shape_mnk, (None, None, 0))
|
||
gc = cute.zipped_divide(c, tiler=c_shape)
|
||
num_ctas_mnl = gc[(0, (None, None, None))].shape
|
||
cluster_shape_mnl = (*cluster_shape_mn, 1)
|
||
|
||
tile_sched_params = utils.PersistentTileSchedulerParams(
|
||
num_ctas_mnl, cluster_shape_mnl
|
||
)
|
||
grid = utils.StaticPersistentTileScheduler.get_grid_shape(
|
||
tile_sched_params, max_active_clusters
|
||
)
|
||
|
||
return tile_sched_params, grid
|
||
|
||
@staticmethod
|
||
def is_valid_dtypes_and_scale_factor_vec_size(
|
||
ab_dtype: Type[cutlass.Numeric],
|
||
sf_dtype: Type[cutlass.Numeric],
|
||
sf_vec_size: int,
|
||
c_dtype: Type[cutlass.Numeric],
|
||
) -> bool:
|
||
"""
|
||
Check if the dtypes and sf_vec_size are valid combinations
|
||
|
||
:param ab_dtype: The data type of the A and B operands
|
||
:type ab_dtype: Type[cutlass.Numeric]
|
||
:param sf_dtype: The data type of the scale factor
|
||
:type sf_dtype: Type[cutlass.Numeric]
|
||
:param sf_vec_size: The vector size of the scale factor
|
||
:type sf_vec_size: int
|
||
:param c_dtype: The data type of the output tensor
|
||
:type c_dtype: Type[cutlass.Numeric]
|
||
|
||
:return: True if the dtypes and sf_vec_size are valid, False otherwise
|
||
:rtype: bool
|
||
"""
|
||
is_valid = True
|
||
|
||
# Check valid ab_dtype
|
||
if ab_dtype not in {
|
||
cutlass.Float4E2M1FN,
|
||
cutlass.Float8E5M2,
|
||
cutlass.Float8E4M3FN,
|
||
}:
|
||
is_valid = False
|
||
|
||
# Check valid sf_vec_size
|
||
if sf_vec_size not in {16, 32}:
|
||
is_valid = False
|
||
|
||
# Check valid sf_dtype
|
||
if sf_dtype not in {cutlass.Float8E8M0FNU, cutlass.Float8E4M3FN}:
|
||
is_valid = False
|
||
|
||
# Check valid sf_dtype and sf_vec_size combinations
|
||
if sf_dtype == cutlass.Float8E4M3FN and sf_vec_size == 32:
|
||
is_valid = False
|
||
if ab_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN} and sf_vec_size == 16:
|
||
is_valid = False
|
||
|
||
# Check valid c_dtype
|
||
if c_dtype not in {
|
||
cutlass.Float32,
|
||
cutlass.Float16,
|
||
cutlass.BFloat16,
|
||
cutlass.Float8E5M2,
|
||
cutlass.Float8E4M3FN,
|
||
}:
|
||
is_valid = False
|
||
|
||
return is_valid
|
||
|
||
@staticmethod
|
||
def is_valid_layouts(
|
||
ab_dtype: Type[cutlass.Numeric],
|
||
c_dtype: Type[cutlass.Numeric],
|
||
a_major: Literal["m", "k"],
|
||
b_major: Literal["n", "k"],
|
||
c_major: Literal["m", "n"],
|
||
) -> bool:
|
||
"""
|
||
Check if layouts and dtypes are valid combinations
|
||
|
||
:param ab_dtype: The data type of the A and B operands
|
||
:type ab_dtype: Type[cutlass.Numeric]
|
||
:param c_dtype: The data type of the output tensor
|
||
:type c_dtype: Type[cutlass.Numeric]
|
||
:param a_major: The major dimension of the A tensor
|
||
:type a_major: Literal["m", "k"]
|
||
:param b_major: The major dimension of the B tensor
|
||
:type b_major: Literal["n", "k"]
|
||
:param c_major: The major dimension of the C tensor
|
||
:type c_major: Literal["m", "n"]
|
||
|
||
:return: True if the layouts are valid, False otherwise
|
||
:rtype: bool
|
||
"""
|
||
is_valid = True
|
||
|
||
if ab_dtype is cutlass.Float4E2M1FN and not (a_major == "k" and b_major == "k"):
|
||
is_valid = False
|
||
return is_valid
|
||
|
||
@staticmethod
|
||
def is_valid_mma_tiler_and_cluster_shape(
|
||
mma_tiler_mn: Tuple[int, int],
|
||
cluster_shape_mn: Tuple[int, int],
|
||
) -> bool:
|
||
"""
|
||
Check if the mma tiler and cluster shape are valid
|
||
|
||
:param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler
|
||
:type mma_tiler_mn: Tuple[int, int]
|
||
:param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster
|
||
:type cluster_shape_mn: Tuple[int, int]
|
||
|
||
:return: True if the mma tiler and cluster shape are valid, False otherwise
|
||
:rtype: bool
|
||
"""
|
||
is_valid = True
|
||
# Skip invalid mma tile shape
|
||
if mma_tiler_mn[0] not in [128, 256]:
|
||
is_valid = False
|
||
if mma_tiler_mn[1] not in [64, 128, 192, 256]:
|
||
is_valid = False
|
||
# Skip illegal cluster shape
|
||
if cluster_shape_mn[0] % (2 if mma_tiler_mn[0] == 256 else 1) != 0:
|
||
is_valid = False
|
||
# Skip invalid cluster shape
|
||
is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0
|
||
if (
|
||
cluster_shape_mn[0] * cluster_shape_mn[1] > 16
|
||
or cluster_shape_mn[0] <= 0
|
||
or cluster_shape_mn[1] <= 0
|
||
# Special cluster shape check for scale factor multicasts.
|
||
# Due to limited size of scale factors, we can't multicast among more than 4 CTAs.
|
||
or cluster_shape_mn[0] > 4
|
||
or cluster_shape_mn[1] > 4
|
||
or not is_power_of_2(cluster_shape_mn[0])
|
||
or not is_power_of_2(cluster_shape_mn[1])
|
||
):
|
||
is_valid = False
|
||
return is_valid
|
||
|
||
@staticmethod
|
||
def is_valid_tensor_alignment(
|
||
m: int,
|
||
n: int,
|
||
k: int,
|
||
l: int,
|
||
ab_dtype: Type[cutlass.Numeric],
|
||
c_dtype: Type[cutlass.Numeric],
|
||
a_major: Literal["m", "k"],
|
||
b_major: Literal["n", "k"],
|
||
c_major: Literal["m", "n"],
|
||
) -> bool:
|
||
"""
|
||
Check if the tensor alignment is valid
|
||
|
||
:param m: The number of rows in the A tensor
|
||
:type m: int
|
||
:param n: The number of columns in the B tensor
|
||
:type n: int
|
||
:param k: The number of columns in the A tensor
|
||
:type k: int
|
||
:param l: The number of columns in the C tensor
|
||
:type l: int
|
||
:param ab_dtype: The data type of the A and B operands
|
||
:type ab_dtype: Type[cutlass.Numeric]
|
||
:param c_dtype: The data type of the output tensor
|
||
:type c_dtype: Type[cutlass.Numeric]
|
||
:param a_major: The major axis of the A tensor
|
||
:type a_major: Literal["m", "k"]
|
||
:param b_major: The major axis of the B tensor
|
||
:type b_major: Literal["n", "k"]
|
||
:param c_major: The major axis of the C tensor
|
||
:type c_major: Literal["m", "n"]
|
||
|
||
:return: True if the problem shape is valid, False otherwise
|
||
:rtype: bool
|
||
"""
|
||
is_valid = True
|
||
|
||
def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape):
|
||
major_mode_idx = 0 if is_mode0_major else 1
|
||
num_major_elements = tensor_shape[major_mode_idx]
|
||
num_contiguous_elements = 16 * 8 // dtype.width
|
||
return num_major_elements % num_contiguous_elements == 0
|
||
|
||
if (
|
||
not check_contigous_16B_alignment(ab_dtype, a_major == "m", (m, k, l))
|
||
or not check_contigous_16B_alignment(ab_dtype, b_major == "n", (n, k, l))
|
||
or not check_contigous_16B_alignment(c_dtype, c_major == "m", (m, n, l))
|
||
):
|
||
is_valid = False
|
||
return is_valid
|
||
|
||
@staticmethod
|
||
def can_implement(
|
||
mnkl: Tuple[int, int, int, int],
|
||
ab_dtype: Type[cutlass.Numeric],
|
||
sf_dtype: Type[cutlass.Numeric],
|
||
c_dtype: Type[cutlass.Numeric],
|
||
a_major: Literal["m", "k"],
|
||
b_major: Literal["n", "k"],
|
||
c_major: Literal["m", "n"],
|
||
sf_vec_size: int,
|
||
mma_tiler_mn: Tuple[int, int],
|
||
cluster_shape_mn: Tuple[int, int],
|
||
) -> bool:
|
||
"""
|
||
Check if the gemm can be implemented
|
||
|
||
:param mnkl: The problem size as a tuple (M, N, K, L).
|
||
:type mnkl: Tuple[int, int, int, int]
|
||
:param ab_dtype: The data type of the A and B operands
|
||
:type ab_dtype: Type[cutlass.Numeric]
|
||
:param sf_dtype: The data type of the scale factor tensor
|
||
:type sf_dtype: Type[cutlass.Numeric]
|
||
:param a_major: The major axis of the A tensor
|
||
:type a_major: Literal["m", "k"]
|
||
:param b_major: The major axis of the B tensor
|
||
:type b_major: Literal["n", "k"]
|
||
:param c_major: The major axis of the C tensor
|
||
:type c_major: Literal["m", "n"]
|
||
:param sf_vec_size: The vector size
|
||
:type sf_vec_size: int
|
||
:param c_dtype: The data type of the output tensor
|
||
:type c_dtype: Type[cutlass.Numeric]
|
||
:param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler
|
||
:type mma_tiler_mn: Tuple[int, int]
|
||
:param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster
|
||
:type cluster_shape_mn: Tuple[int, int]
|
||
:return: True if the gemm can be implemented, False otherwise
|
||
:rtype: bool
|
||
"""
|
||
# Unpack parameters
|
||
m, n, k, l = mnkl
|
||
can_implement = True
|
||
# Skip unsupported types
|
||
if not Sm100BlockScaledPersistentDenseGemmKernel.is_valid_dtypes_and_scale_factor_vec_size(
|
||
ab_dtype, sf_dtype, sf_vec_size, c_dtype
|
||
):
|
||
can_implement = False
|
||
# Skip unsupported layouts
|
||
if not Sm100BlockScaledPersistentDenseGemmKernel.is_valid_layouts(
|
||
ab_dtype, c_dtype, a_major, b_major, c_major
|
||
):
|
||
can_implement = False
|
||
# Skip invalid mma tile shape and cluster shape
|
||
if not Sm100BlockScaledPersistentDenseGemmKernel.is_valid_mma_tiler_and_cluster_shape(
|
||
mma_tiler_mn, cluster_shape_mn
|
||
):
|
||
can_implement = False
|
||
# Skip illegal problem shape for load/store alignment
|
||
if not Sm100BlockScaledPersistentDenseGemmKernel.is_valid_tensor_alignment(
|
||
m, n, k, l, ab_dtype, c_dtype, a_major, b_major, c_major
|
||
):
|
||
can_implement = False
|
||
return can_implement
|
||
|
||
|
||
# Helper function to convert scale factor tensor from MKL layout to (32, 4, restM, 4, restK, l) format
|
||
@cute.jit
|
||
def cvt_sf_MKL_to_M32x4xrm_K4xrk_L(
|
||
sf_ref_ptr: cute.Pointer,
|
||
sf_mma_ptr: cute.Pointer,
|
||
mn: int,
|
||
sf_k: int,
|
||
l: int,
|
||
mma_shape: tuple,
|
||
):
|
||
mma_permute_order = (3, 4, 1, 5, 2, 0)
|
||
permuted_shape = tuple(mma_shape[i] for i in mma_permute_order)
|
||
cute_layout = cute.make_ordered_layout(permuted_shape, order=(2, 1, 4, 0, 3, 5))
|
||
|
||
sf_ref_tensor = cute.make_tensor(
|
||
sf_ref_ptr, cute.make_layout((mn, sf_k, l), stride=(sf_k, 1, mn * sf_k))
|
||
)
|
||
sf_mma_tensor = cute.make_tensor(sf_mma_ptr, cute_layout)
|
||
|
||
sf_mma_tensor = cute.group_modes(sf_mma_tensor, 0, 3)
|
||
sf_mma_tensor = cute.group_modes(sf_mma_tensor, 1, 3)
|
||
for i in cutlass.range(cute.size(sf_ref_tensor)):
|
||
mkl_coord = sf_ref_tensor.layout.get_hier_coord(i)
|
||
sf_mma_tensor[mkl_coord] = sf_ref_tensor[mkl_coord]
|
||
pass
|
||
|
||
|
||
# Helper function for ceil division
|
||
def ceil_div(a, b):
|
||
return (a + b - 1) // b
|
||
|
||
|
||
# Convert scale factor tensors from (m, k, l) to (32, 4, restM, 4, restK, l) format
|
||
def create_and_reorder_scale_factor_tensor(
|
||
l, mn, k, sf_vec_size, sf_dtype, torch_tensor
|
||
):
|
||
"""
|
||
Create the CUTE-format scale factor tensor on CUDA based on the reference tensor.
|
||
"""
|
||
sf_k = ceil_div(k, sf_vec_size)
|
||
atom_m = (32, 4)
|
||
atom_k = 4
|
||
mma_shape = (
|
||
l, # batch size
|
||
ceil_div(mn, atom_m[0] * atom_m[1]),
|
||
ceil_div(sf_k, atom_k),
|
||
atom_m[0],
|
||
atom_m[1],
|
||
atom_k,
|
||
)
|
||
|
||
# Generate a random int8 tensor, then convert to float8_e4m3fn
|
||
cute_tensor = torch.ones(mma_shape, dtype=cutlass_torch.dtype(sf_dtype)).permute(
|
||
3, 4, 1, 5, 2, 0
|
||
)
|
||
|
||
# Call the helper function to do layout conversion
|
||
cvt_sf_MKL_to_M32x4xrm_K4xrk_L(
|
||
make_ptr(
|
||
sf_dtype,
|
||
torch_tensor.data_ptr(),
|
||
cute.AddressSpace.gmem,
|
||
assumed_align=32,
|
||
),
|
||
make_ptr(
|
||
sf_dtype,
|
||
cute_tensor.data_ptr(),
|
||
cute.AddressSpace.gmem,
|
||
assumed_align=32,
|
||
),
|
||
mn,
|
||
sf_k,
|
||
l,
|
||
mma_shape,
|
||
)
|
||
return cute_tensor.cuda()
|
||
|
||
|
||
# Compile the persistent dense blockscaled GEMM operation
|
||
def scaled_mm(
|
||
gemm_obj: Sm100BlockScaledPersistentDenseGemmKernel,
|
||
ab_dtype: Type[cutlass.Numeric],
|
||
c_dtype: Type[cutlass.Numeric],
|
||
sf_dtype: Type[cutlass.Numeric],
|
||
a_major: Literal["m", "k"],
|
||
b_major: Literal["n", "k"],
|
||
c_major: Literal["m", "n"],
|
||
max_active_clusters: cutlass.Constexpr,
|
||
stream: cuda.CUstream,
|
||
epilogue_op: cutlass.Constexpr = lambda x: x,
|
||
options: str = "",
|
||
):
|
||
# Construct CuTe Pointers
|
||
a_ptr = make_ptr(ab_dtype, 0, cute.AddressSpace.gmem, assumed_align=16)
|
||
b_ptr = make_ptr(ab_dtype, 0, cute.AddressSpace.gmem, assumed_align=16)
|
||
c_ptr = make_ptr(c_dtype, 0, cute.AddressSpace.gmem, assumed_align=16)
|
||
sfa_ptr = make_ptr(sf_dtype, 0, cute.AddressSpace.gmem, assumed_align=32)
|
||
sfb_ptr = make_ptr(sf_dtype, 0, cute.AddressSpace.gmem, assumed_align=32)
|
||
|
||
a_major_mode = (
|
||
tcgen05.OperandMajorMode.K if a_major == "k" else tcgen05.OperandMajorMode.MN
|
||
)
|
||
b_major_mode = (
|
||
tcgen05.OperandMajorMode.K if b_major == "k" else tcgen05.OperandMajorMode.MN
|
||
)
|
||
c_layout = (
|
||
utils.LayoutEnum.ROW_MAJOR if c_major == "n" else utils.LayoutEnum.COL_MAJOR
|
||
)
|
||
return cute.compile(
|
||
gemm_obj,
|
||
a_ptr,
|
||
b_ptr,
|
||
sfa_ptr,
|
||
sfb_ptr,
|
||
c_ptr,
|
||
(a_major_mode, b_major_mode, c_layout),
|
||
(cutlass.Int32(0), cutlass.Int32(0), cutlass.Int32(0), cutlass.Int32(0)),
|
||
max_active_clusters,
|
||
stream,
|
||
epilogue_op,
|
||
options=options,
|
||
)
|
||
|
||
|
||
def is_emulated_dtype(
|
||
ab_dtype: Type[cutlass.Numeric],
|
||
sf_dtype: Type[cutlass.Numeric],
|
||
c_dtype: Type[cutlass.Numeric],
|
||
) -> bool:
|
||
if c_dtype in {
|
||
cutlass.Float32,
|
||
cutlass.Float16,
|
||
cutlass.BFloat16,
|
||
}:
|
||
if ab_dtype == cutlass.Float4E2M1FN and sf_dtype == cutlass.Float8E4M3FN:
|
||
return False
|
||
if ab_dtype == cutlass.Float8E4M3FN and sf_dtype == cutlass.Float8E8M0FNU:
|
||
return False
|
||
|
||
return True
|
||
|
||
|
||
# Convert scale factor tensor from MKL layout to blocked layout
|
||
def to_blocked(input_matrix):
|
||
rows, cols = input_matrix.shape
|
||
# Please ensure rows and cols are multiples of 128 and 4 respectively
|
||
n_row_blocks = ceil_div(rows, 128)
|
||
n_col_blocks = ceil_div(cols, 4)
|
||
padded_rows = n_row_blocks * 128
|
||
padded_cols = n_col_blocks * 4
|
||
|
||
# Pad the input matrix if necessary
|
||
if padded_rows != rows or padded_cols != cols:
|
||
# For FP8 types, convert to float32 for padding, then convert back
|
||
original_dtype = input_matrix.dtype
|
||
input_float32 = input_matrix.to(torch.float32)
|
||
padded = torch.nn.functional.pad(
|
||
input_float32,
|
||
(0, padded_cols - cols, 0, padded_rows - rows),
|
||
mode="constant",
|
||
value=0,
|
||
)
|
||
# Convert back to original dtype if needed
|
||
if original_dtype != input_float32.dtype:
|
||
padded = padded.to(original_dtype)
|
||
else:
|
||
padded = input_matrix
|
||
blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
|
||
rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
|
||
return rearranged.flatten()
|
||
|
||
|
||
# Reference implementation of the persistent dense blockscaled GEMM operation (emulated version)
|
||
def reference_scaled_mm_emulated(
|
||
a: torch.Tensor,
|
||
b: torch.Tensor,
|
||
sfa: torch.Tensor,
|
||
sfb: torch.Tensor,
|
||
c: torch.Tensor,
|
||
mnkl: Tuple[int, int, int, int],
|
||
sf_vec_size: int,
|
||
c_dtype: Type[cutlass.Numeric],
|
||
):
|
||
m, n, k, l = mnkl
|
||
sfa_expanded = (
|
||
torch.repeat_interleave(sfa, sf_vec_size, dim=1)[:, :k, :]
|
||
.to(dtype=torch.float32)
|
||
.cuda()
|
||
)
|
||
sfb_expanded = (
|
||
torch.repeat_interleave(sfb, sf_vec_size, dim=1)[:, :k, :]
|
||
.to(dtype=torch.float32)
|
||
.cuda()
|
||
)
|
||
res_a = torch.einsum("mkl,mkl->mkl", a, sfa_expanded)
|
||
res_b = torch.einsum("nkl,nkl->nkl", b, sfb_expanded)
|
||
# Cast res_a and res_b to float32 for einsum to avoid NotImplementedError on 'Byte'
|
||
ref = torch.einsum("mkl,nkl->mnl", res_a, res_b)
|
||
c_ref = ref.to(dtype=cutlass_torch.dtype(c_dtype))
|
||
return c_ref
|
||
|
||
|
||
# Reference implementation of the persistent dense blockscaled GEMM operation (non-emulated version)
|
||
def reference_scaled_mm(
|
||
a: torch.Tensor,
|
||
b: torch.Tensor,
|
||
sfa: torch.Tensor,
|
||
sfb: torch.Tensor,
|
||
c: torch.Tensor,
|
||
mnkl: Tuple[int, int, int, int],
|
||
c_dtype: Type[cutlass.Numeric],
|
||
):
|
||
m, n, k, l = mnkl
|
||
c_ref = torch.clone(c)
|
||
for l_idx in range(l):
|
||
# Convert the scale factor tensor to blocked format
|
||
scale_a = to_blocked(sfa[:, :, l_idx])
|
||
scale_b = to_blocked(sfb[:, :, l_idx])
|
||
# Ensure a_slice is row-major (M, K) with stride (K, 1)
|
||
a_slice = a[:, :, l_idx].contiguous()
|
||
# Ensure b_slice is row-major (N, K) so that transpose gives column-major (K, N)
|
||
b_slice = b[:, :, l_idx].contiguous()
|
||
# (m, k) @ (n, k).T -> (m, n)
|
||
res = torch._scaled_mm(
|
||
a_slice,
|
||
b_slice.transpose(0, 1),
|
||
scale_a.cuda(),
|
||
scale_b.cuda(),
|
||
bias=None,
|
||
out_dtype=c_ref.dtype,
|
||
)
|
||
c_ref[:, :, l_idx] = res
|
||
return c_ref
|
||
|
||
|
||
# Construct CuTe Pointers for the persistent dense blockscaled GEMM operation (emulated version)
|
||
def construct_cute_pointers_emulated(
|
||
a: torch.Tensor,
|
||
b: torch.Tensor,
|
||
sfa: torch.Tensor,
|
||
sfb: torch.Tensor,
|
||
c: torch.Tensor,
|
||
ab_dtype: Type[cutlass.Numeric],
|
||
sf_dtype: Type[cutlass.Numeric],
|
||
c_dtype: Type[cutlass.Numeric],
|
||
):
|
||
a_cute, _ = cutlass_torch.cute_tensor_like(
|
||
a.cpu(),
|
||
ab_dtype,
|
||
is_dynamic_layout=True,
|
||
assumed_align=16,
|
||
)
|
||
a_cute = cutlass_torch.convert_cute_tensor(
|
||
a,
|
||
a_cute,
|
||
ab_dtype,
|
||
is_dynamic_layout=True,
|
||
)
|
||
b_cute, _ = cutlass_torch.cute_tensor_like(
|
||
b.cpu(),
|
||
ab_dtype,
|
||
is_dynamic_layout=True,
|
||
assumed_align=16,
|
||
)
|
||
b_cute = cutlass_torch.convert_cute_tensor(
|
||
b,
|
||
b_cute,
|
||
ab_dtype,
|
||
is_dynamic_layout=True,
|
||
)
|
||
a_ptr = a_cute.iterator
|
||
b_ptr = b_cute.iterator
|
||
|
||
sfa_ptr = make_ptr(
|
||
sf_dtype, sfa.data_ptr(), cute.AddressSpace.gmem, assumed_align=32
|
||
)
|
||
sfb_ptr = make_ptr(
|
||
sf_dtype, sfb.data_ptr(), cute.AddressSpace.gmem, assumed_align=32
|
||
)
|
||
c_ptr = make_ptr(c_dtype, c.data_ptr(), cute.AddressSpace.gmem, assumed_align=16)
|
||
return a_ptr, b_ptr, c_ptr, sfa_ptr, sfb_ptr, a_cute, b_cute
|
||
|
||
|
||
# Construct CuTe Pointers for the persistent dense blockscaled GEMM operation (non-emulated version)
|
||
def construct_cute_pointers(
|
||
a: torch.Tensor,
|
||
b: torch.Tensor,
|
||
sfa: torch.Tensor,
|
||
sfb: torch.Tensor,
|
||
c: torch.Tensor,
|
||
ab_dtype: Type[cutlass.Numeric],
|
||
sf_dtype: Type[cutlass.Numeric],
|
||
c_dtype: Type[cutlass.Numeric],
|
||
):
|
||
a_ptr = make_ptr(ab_dtype, a.data_ptr(), cute.AddressSpace.gmem, assumed_align=16)
|
||
b_ptr = make_ptr(ab_dtype, b.data_ptr(), cute.AddressSpace.gmem, assumed_align=16)
|
||
sfa_ptr = make_ptr(
|
||
sf_dtype, sfa.data_ptr(), cute.AddressSpace.gmem, assumed_align=32
|
||
)
|
||
sfb_ptr = make_ptr(
|
||
sf_dtype, sfb.data_ptr(), cute.AddressSpace.gmem, assumed_align=32
|
||
)
|
||
c_ptr = make_ptr(c_dtype, c.data_ptr(), cute.AddressSpace.gmem, assumed_align=16)
|
||
return a_ptr, b_ptr, c_ptr, sfa_ptr, sfb_ptr
|
||
|
||
|
||
# Use uint8 and uint32 to emulate unsupported
|
||
# dtype in torch
|
||
def prepare_tensors_emulated(
|
||
mnkl: Tuple[int, int, int, int],
|
||
ab_dtype: Type[cutlass.Numeric],
|
||
sf_dtype: Type[cutlass.Numeric],
|
||
sf_vec_size: int,
|
||
c_dtype: Type[cutlass.Numeric],
|
||
a_major: Literal["m", "k"],
|
||
b_major: Literal["n", "k"],
|
||
c_major: Literal["m", "n"],
|
||
):
|
||
m, n, k, l = mnkl
|
||
sf_k = ceil_div(k, sf_vec_size)
|
||
|
||
# Create tensor SFA/SFB with values in [1, 3)
|
||
sfa = (
|
||
torch.randint(0, 3, (l, m, sf_k), dtype=torch.uint8)
|
||
.permute(1, 2, 0)
|
||
.to(dtype=cutlass_torch.dtype(sf_dtype))
|
||
)
|
||
sfb = (
|
||
torch.randint(0, 3, (l, n, sf_k), dtype=torch.uint8)
|
||
.permute(1, 2, 0)
|
||
.to(dtype=cutlass_torch.dtype(sf_dtype))
|
||
)
|
||
|
||
# Create tensor A/B with values in [0, 2)
|
||
if a_major == "k":
|
||
a = torch.randint(-2, 2, (l, m, k), dtype=torch.float32, device="cuda").permute(
|
||
1, 2, 0
|
||
)
|
||
else:
|
||
a = torch.randint(-2, 2, (l, k, m), dtype=torch.float32, device="cuda").permute(
|
||
2, 1, 0
|
||
)
|
||
if b_major == "k":
|
||
b = torch.randint(-2, 2, (l, n, k), dtype=torch.float32, device="cuda").permute(
|
||
1, 2, 0
|
||
)
|
||
else:
|
||
b = torch.randint(-2, 2, (l, k, n), dtype=torch.float32, device="cuda").permute(
|
||
2, 1, 0
|
||
)
|
||
if c_major == "n":
|
||
c = torch.empty(
|
||
(l, m, n), dtype=cutlass_torch.dtype(c_dtype), device="cuda"
|
||
).permute(1, 2, 0)
|
||
else:
|
||
c = torch.empty(
|
||
(l, n, m), dtype=cutlass_torch.dtype(c_dtype), device="cuda"
|
||
).permute(2, 1, 0)
|
||
return a, b, c, sfa, sfb
|
||
|
||
|
||
def prepare_tensors(
|
||
mnkl: Tuple[int, int, int, int],
|
||
ab_dtype: Type[cutlass.Numeric],
|
||
sf_dtype: Type[cutlass.Numeric],
|
||
sf_vec_size: int,
|
||
c_dtype: Type[cutlass.Numeric],
|
||
a_major: Literal["m", "k"],
|
||
b_major: Literal["n", "k"],
|
||
c_major: Literal["m", "n"],
|
||
):
|
||
m, n, k, l = mnkl
|
||
|
||
if ab_dtype == cutlass.Float4E2M1FN:
|
||
# Using int8 for torch.float4_e2m1fn_x2 tensor allocation
|
||
# Thus the size of k needs to be halved in this case.
|
||
k_fct = 2
|
||
else:
|
||
k_fct = 1
|
||
|
||
sf_k = ceil_div(k, sf_vec_size)
|
||
|
||
# Create tensor SFA/SFB
|
||
sfa = (
|
||
torch.randint(0, 3, (l, m, sf_k), dtype=torch.uint8)
|
||
.permute(1, 2, 0)
|
||
.to(dtype=cutlass_torch.dtype(sf_dtype))
|
||
)
|
||
sfb = (
|
||
torch.randint(0, 3, (l, n, sf_k), dtype=torch.uint8)
|
||
.permute(1, 2, 0)
|
||
.to(dtype=cutlass_torch.dtype(sf_dtype))
|
||
)
|
||
|
||
# Create tensor A/B/C
|
||
if a_major == "k":
|
||
a = torch.randint(
|
||
-2, 2, (l, m, k // k_fct), dtype=torch.int8, device="cuda"
|
||
).permute(1, 2, 0)
|
||
else:
|
||
a = torch.randint(-2, 2, (l, k, m), dtype=torch.int8, device="cuda").permute(
|
||
2, 1, 0
|
||
)
|
||
if b_major == "k":
|
||
b = torch.randint(
|
||
-2, 2, (l, n, k // k_fct), dtype=torch.int8, device="cuda"
|
||
).permute(1, 2, 0)
|
||
else:
|
||
b = torch.randint(-2, 2, (l, k, n), dtype=torch.int8, device="cuda").permute(
|
||
2, 1, 0
|
||
)
|
||
if c_major == "n":
|
||
c = torch.randint(
|
||
-2, 2, (l, m, n), dtype=cutlass_torch.dtype(c_dtype), device="cuda"
|
||
).permute(1, 2, 0)
|
||
else:
|
||
c = torch.randint(
|
||
-2, 2, (l, n, m), dtype=cutlass_torch.dtype(c_dtype), device="cuda"
|
||
).permute(2, 1, 0)
|
||
|
||
if ab_dtype == cutlass.Float4E2M1FN:
|
||
a = a.view(dtype=torch.float4_e2m1fn_x2)
|
||
b = b.view(dtype=torch.float4_e2m1fn_x2)
|
||
else:
|
||
a = a.to(dtype=cutlass_torch.dtype(ab_dtype))
|
||
b = b.to(dtype=cutlass_torch.dtype(ab_dtype))
|
||
|
||
c = c.to(dtype=cutlass_torch.dtype(c_dtype))
|
||
return a, b, c, sfa, sfb
|
||
|
||
|
||
# This will show how to covert torch tensor
|
||
# and pass to CuTe kernel
|
||
def run_scaled_mm(
|
||
mnkl: Tuple[int, int, int, int],
|
||
ab_dtype: Type[cutlass.Numeric],
|
||
sf_dtype: Type[cutlass.Numeric],
|
||
sf_vec_size: int,
|
||
c_dtype: Type[cutlass.Numeric],
|
||
a_major: Literal["m", "k"],
|
||
b_major: Literal["n", "k"],
|
||
c_major: Literal["m", "n"],
|
||
mma_tiler_mn: Tuple[int, int],
|
||
cluster_shape_mn: Tuple[int, int],
|
||
tolerance: float = 1e-01,
|
||
warmup_iterations: int = 0,
|
||
iterations: int = 1,
|
||
skip_ref_check: bool = False,
|
||
use_cold_l2: bool = False,
|
||
**kwargs,
|
||
):
|
||
"""Execute a persistent batched dense blockscaled GEMM operation on Blackwell architecture with performance benchmarking (non-emulated dtypes).
|
||
|
||
This function prepares input tensors, configures and launches the persistent GEMM kernel,
|
||
optionally performs reference validation, and benchmarks the execution performance.
|
||
|
||
:param mnkl: Problem size (M, N, K, L)
|
||
:type mnkl: Tuple[int, int, int, int]
|
||
:param ab_dtype: Data type for input tensors A and B
|
||
:type ab_dtype: Type[cutlass.Numeric]
|
||
:param sf_dtype: Data type for scale factor tensor
|
||
:type sf_dtype: Type[cutlass.Numeric]
|
||
:param sf_vec_size: Vector size for scale factor tensor
|
||
:type sf_vec_size: int
|
||
:param c_dtype: Data type for output tensor C
|
||
:type c_dtype: Type[cutlass.Numeric]
|
||
:param a_major/b_major/c_major: Memory layout of tensor A/B/C
|
||
:type a_major/b_major/c_major: Literal["m", "k", "n"]
|
||
:param mma_tiler_mn: MMA tiling size.
|
||
:type mma_tiler_mn: Tuple[int, int]
|
||
:param cluster_shape_mn: Cluster shape.
|
||
:type cluster_shape_mn: Tuple[int, int]
|
||
:param tolerance: Tolerance value for reference validation comparison, defaults to 1e-01
|
||
:type tolerance: float, optional
|
||
:param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0
|
||
:type warmup_iterations: int, optional
|
||
:param iterations: Number of benchmark iterations to run, defaults to 1
|
||
:type iterations: int, optional
|
||
:param skip_ref_check: Whether to skip reference result validation, defaults to False
|
||
:type skip_ref_check: bool, optional
|
||
:param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False
|
||
:type use_cold_l2: bool, optional
|
||
:raises RuntimeError: If CUDA GPU is not available
|
||
:raises ValueError: If the configuration is invalid or unsupported by the kernel
|
||
:return: Execution time of the GEMM kernel
|
||
:rtype: float
|
||
"""
|
||
print("Running Sm100 Persistent Dense BlockScaled GEMM test with:")
|
||
print(f"mnkl: {mnkl}")
|
||
print(f"AB dtype: {ab_dtype}, SF dtype: {sf_dtype}, SF Vec size: {sf_vec_size}")
|
||
print(f"C dtype: {c_dtype}")
|
||
print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}")
|
||
print(f"Mma Tiler (M, N): {mma_tiler_mn}, Cluster Shape (M, N): {cluster_shape_mn}")
|
||
print(f"Tolerance: {tolerance}")
|
||
print(f"Warmup iterations: {warmup_iterations}")
|
||
print(f"Iterations: {iterations}")
|
||
print(f"Skip reference checking: {skip_ref_check}")
|
||
print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}")
|
||
|
||
# Unpack parameters
|
||
m, n, k, l = mnkl
|
||
|
||
# Configure gemm kernel
|
||
gemm = Sm100BlockScaledPersistentDenseGemmKernel(
|
||
sf_vec_size,
|
||
mma_tiler_mn,
|
||
cluster_shape_mn,
|
||
)
|
||
|
||
# Skip unsupported testcase
|
||
if not gemm.can_implement(
|
||
mnkl,
|
||
ab_dtype,
|
||
sf_dtype,
|
||
c_dtype,
|
||
a_major,
|
||
b_major,
|
||
c_major,
|
||
sf_vec_size,
|
||
mma_tiler_mn,
|
||
cluster_shape_mn,
|
||
):
|
||
raise TypeError(
|
||
f"Unsupported testcase {ab_dtype}, {sf_dtype}, {sf_vec_size}, {c_dtype}, {mma_tiler_mn}, {cluster_shape_mn}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {c_major}"
|
||
)
|
||
|
||
if not torch.cuda.is_available():
|
||
raise RuntimeError("GPU is required to run this example!")
|
||
|
||
torch.manual_seed(1111)
|
||
|
||
# Get current CUDA stream from PyTorch
|
||
torch_stream = torch.cuda.current_stream()
|
||
# Get the raw stream pointer as a CUstream
|
||
current_stream = cuda.CUstream(torch_stream.cuda_stream)
|
||
|
||
# Check if configuration can be implemented
|
||
max_active_clusters = utils.HardwareInfo().get_max_active_clusters(
|
||
cluster_shape_mn[0] * cluster_shape_mn[1]
|
||
)
|
||
|
||
# Compile gemm kernel with fake tensors
|
||
compiled_gemm = scaled_mm(
|
||
gemm,
|
||
ab_dtype,
|
||
c_dtype,
|
||
sf_dtype,
|
||
a_major,
|
||
b_major,
|
||
c_major,
|
||
max_active_clusters,
|
||
current_stream,
|
||
options=f"--opt-level 2",
|
||
)
|
||
|
||
# Create Torch Tensors for A, scale factor A, B, scale factor B, C
|
||
a, b, c, sfa, sfb = prepare_tensors(
|
||
mnkl, ab_dtype, sf_dtype, sf_vec_size, c_dtype, a_major, b_major, c_major
|
||
)
|
||
# Reorder scale factor tensors to (32, 4, restM, 4, restK, l) format
|
||
sfa_reordered = create_and_reorder_scale_factor_tensor(
|
||
l, m, k, sf_vec_size, sf_dtype, sfa
|
||
)
|
||
sfb_reordered = create_and_reorder_scale_factor_tensor(
|
||
l, n, k, sf_vec_size, sf_dtype, sfb
|
||
)
|
||
# Construct CuTe Pointers
|
||
a_ptr, b_ptr, c_ptr, sfa_ptr, sfb_ptr = construct_cute_pointers(
|
||
a,
|
||
b,
|
||
sfa_reordered,
|
||
sfb_reordered,
|
||
c,
|
||
ab_dtype,
|
||
sf_dtype,
|
||
c_dtype,
|
||
)
|
||
|
||
# Compute reference result
|
||
if not skip_ref_check:
|
||
# Execute kernel once for reference checking
|
||
compiled_gemm(
|
||
a_ptr, b_ptr, sfa_ptr, sfb_ptr, c_ptr, (m, n, k, l), current_stream
|
||
)
|
||
c_ref = reference_scaled_mm(a, b, sfa, sfb, c, (m, n, k, l), c_dtype)
|
||
if c_dtype in (cutlass.Float8E5M2, cutlass.Float8E4M3FN):
|
||
# Rtol=0.001 and atol=0.1 are not supported for bitwise comparison of
|
||
# low dimensional floats. Please use rtol=0.0 and atol=0.0.
|
||
tolerance = 0.0
|
||
torch.testing.assert_close(c, c_ref, atol=tolerance, rtol=tolerance)
|
||
|
||
def generate_inputs():
|
||
a, b, c, sfa, sfb = prepare_tensors(
|
||
mnkl,
|
||
ab_dtype,
|
||
sf_dtype,
|
||
sf_vec_size,
|
||
c_dtype,
|
||
a_major,
|
||
b_major,
|
||
c_major,
|
||
)
|
||
# Reorder scale factor tensors to (32, 4, restM, 4, restK, l) format
|
||
sfa_reordered = create_and_reorder_scale_factor_tensor(
|
||
l, m, k, sf_vec_size, sf_dtype, sfa
|
||
)
|
||
sfb_reordered = create_and_reorder_scale_factor_tensor(
|
||
l, n, k, sf_vec_size, sf_dtype, sfb
|
||
)
|
||
# Construct CuTe Pointers
|
||
a_ptr, b_ptr, c_ptr, sfa_ptr, sfb_ptr = construct_cute_pointers(
|
||
a,
|
||
b,
|
||
sfa_reordered,
|
||
sfb_reordered,
|
||
c,
|
||
ab_dtype,
|
||
sf_dtype,
|
||
c_dtype,
|
||
)
|
||
jit_args = cute.testing.JitArguments(
|
||
a_ptr, b_ptr, sfa_ptr, sfb_ptr, c_ptr, (m, n, k, l), current_stream
|
||
)
|
||
# Keep references to external variables (e.g., Torch tensors when taking a view)
|
||
jit_args.add_to_scope([a, b, sfa_reordered, sfb_reordered, c])
|
||
return jit_args
|
||
|
||
workspace_count = 1
|
||
if use_cold_l2:
|
||
one_workspace_bytes = (
|
||
a.numel() * a.element_size()
|
||
+ b.numel() * b.element_size()
|
||
+ sfa.numel() * sfa.element_size()
|
||
+ sfb.numel() * sfb.element_size()
|
||
+ c.numel() * c.element_size()
|
||
)
|
||
workspace_count = cute.testing.get_workspace_count(
|
||
one_workspace_bytes, warmup_iterations, iterations
|
||
)
|
||
|
||
exec_time = cute.testing.benchmark(
|
||
compiled_gemm,
|
||
workspace_generator=generate_inputs,
|
||
workspace_count=workspace_count,
|
||
stream=current_stream,
|
||
warmup_iterations=warmup_iterations,
|
||
iterations=iterations,
|
||
)
|
||
return exec_time # Return execution time in microseconds
|
||
|
||
|
||
# This is to compatible with the other narrow
|
||
# precision combinations are not supported in either
|
||
# torch or dlpack. For example, Float4E2M1FN with Float8E8M0FNU.
|
||
def run_scaled_mm_with_emulated_dtype(
|
||
mnkl: Tuple[int, int, int, int],
|
||
ab_dtype: Type[cutlass.Numeric],
|
||
sf_dtype: Type[cutlass.Numeric],
|
||
sf_vec_size: int,
|
||
c_dtype: Type[cutlass.Numeric],
|
||
a_major: Literal["m", "k"],
|
||
b_major: Literal["n", "k"],
|
||
c_major: Literal["m", "n"],
|
||
mma_tiler_mn: Tuple[int, int],
|
||
cluster_shape_mn: Tuple[int, int],
|
||
tolerance: float = 1e-01,
|
||
warmup_iterations: int = 0,
|
||
iterations: int = 1,
|
||
skip_ref_check: bool = False,
|
||
use_cold_l2: bool = False,
|
||
**kwargs,
|
||
):
|
||
"""Execute a persistent batched dense blockscaled GEMM operation on Blackwell architecture with performance benchmarking (emulated dtypes).
|
||
|
||
This function prepares input tensors, configures and launches the persistent GEMM kernel,
|
||
optionally performs reference validation, and benchmarks the execution performance.
|
||
|
||
:param mnkl: Problem size (M, N, K, L)
|
||
:type mnkl: Tuple[int, int, int, int]
|
||
:param ab_dtype: Data type for input tensors A and B
|
||
:type ab_dtype: Type[cutlass.Numeric]
|
||
:param sf_dtype: Data type for scale factor tensor
|
||
:type sf_dtype: Type[cutlass.Numeric]
|
||
:param sf_vec_size: Vector size for scale factor tensor
|
||
:type sf_vec_size: int
|
||
:param c_dtype: Data type for output tensor C
|
||
:type c_dtype: Type[cutlass.Numeric]
|
||
:param a_major/b_major/c_major: Memory layout of tensor A/B/C
|
||
:type a_major/b_major/c_major: Literal["m", "n","k"]
|
||
:param mma_tiler_mn: MMA tiling size.
|
||
:type mma_tiler_mn: Tuple[int, int]
|
||
:param cluster_shape_mn: Cluster shape.
|
||
:type cluster_shape_mn: Tuple[int, int]
|
||
:param tolerance: Tolerance value for reference validation comparison, defaults to 1e-01
|
||
:type tolerance: float, optional
|
||
:param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0
|
||
:type warmup_iterations: int, optional
|
||
:param iterations: Number of benchmark iterations to run, defaults to 1
|
||
:type iterations: int, optional
|
||
:param skip_ref_check: Whether to skip reference result validation, defaults to False
|
||
:type skip_ref_check: bool, optional
|
||
:param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False
|
||
:type use_cold_l2: bool, optional
|
||
:raises RuntimeError: If CUDA GPU is not available
|
||
:raises ValueError: If the configuration is invalid or unsupported by the kernel
|
||
:return: Execution time of the GEMM kernel
|
||
:rtype: float
|
||
"""
|
||
print("Running Sm100 Persistent Dense BlockScaled GEMM test (Emulated) with:")
|
||
print(f"mnkl: {mnkl}")
|
||
print(f"AB dtype: {ab_dtype}, SF dtype: {sf_dtype}, SF Vec size: {sf_vec_size}")
|
||
print(f"C dtype: {c_dtype}")
|
||
print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}")
|
||
print(f"Mma Tiler (M, N): {mma_tiler_mn}, Cluster Shape (M, N): {cluster_shape_mn}")
|
||
print(f"Tolerance: {tolerance}")
|
||
print(f"Warmup iterations: {warmup_iterations}")
|
||
print(f"Iterations: {iterations}")
|
||
print(f"Skip reference checking: {skip_ref_check}")
|
||
print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}")
|
||
|
||
# Unpack parameters
|
||
m, n, k, l = mnkl
|
||
|
||
# Configure gemm kernel
|
||
gemm = Sm100BlockScaledPersistentDenseGemmKernel(
|
||
sf_vec_size,
|
||
mma_tiler_mn,
|
||
cluster_shape_mn,
|
||
)
|
||
|
||
# Skip unsupported testcase
|
||
if not gemm.can_implement(
|
||
mnkl,
|
||
ab_dtype,
|
||
sf_dtype,
|
||
c_dtype,
|
||
a_major,
|
||
b_major,
|
||
c_major,
|
||
sf_vec_size,
|
||
mma_tiler_mn,
|
||
cluster_shape_mn,
|
||
):
|
||
raise TypeError(
|
||
f"Unsupported testcase {ab_dtype}, {sf_dtype}, {sf_vec_size}, {c_dtype}, {mma_tiler_mn}, {cluster_shape_mn}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {c_major}"
|
||
)
|
||
|
||
if not torch.cuda.is_available():
|
||
raise RuntimeError("GPU is required to run this example!")
|
||
|
||
torch.manual_seed(1111)
|
||
|
||
# Get current CUDA stream from PyTorch
|
||
torch_stream = torch.cuda.current_stream()
|
||
# Get the raw stream pointer as a CUstream
|
||
current_stream = cuda.CUstream(torch_stream.cuda_stream)
|
||
|
||
# Check if configuration can be implemented
|
||
max_active_clusters = utils.HardwareInfo().get_max_active_clusters(
|
||
cluster_shape_mn[0] * cluster_shape_mn[1]
|
||
)
|
||
|
||
# Compile gemm kernel with fake tensors
|
||
compiled_gemm = scaled_mm(
|
||
gemm,
|
||
ab_dtype,
|
||
c_dtype,
|
||
sf_dtype,
|
||
a_major,
|
||
b_major,
|
||
c_major,
|
||
max_active_clusters,
|
||
current_stream,
|
||
options=f"--opt-level 2",
|
||
)
|
||
|
||
# Create Torch Tensors for A, scale factor A, B, scale factor B, C
|
||
a, b, c, sfa, sfb = prepare_tensors_emulated(
|
||
mnkl, ab_dtype, sf_dtype, sf_vec_size, c_dtype, a_major, b_major, c_major
|
||
)
|
||
# Reorder scale factor tensors to (32, 4, restM, 4, restK, l) format
|
||
sfa_reordered = create_and_reorder_scale_factor_tensor(
|
||
l, m, k, sf_vec_size, sf_dtype, sfa
|
||
)
|
||
sfb_reordered = create_and_reorder_scale_factor_tensor(
|
||
l, n, k, sf_vec_size, sf_dtype, sfb
|
||
)
|
||
# Construct CuTe Pointers
|
||
a_ptr, b_ptr, c_ptr, sfa_ptr, sfb_ptr, a_cute, b_cute = (
|
||
construct_cute_pointers_emulated(
|
||
a,
|
||
b,
|
||
sfa_reordered,
|
||
sfb_reordered,
|
||
c,
|
||
ab_dtype,
|
||
sf_dtype,
|
||
c_dtype,
|
||
)
|
||
)
|
||
|
||
# Compute reference result
|
||
if not skip_ref_check:
|
||
# Execute kernel once for reference checking
|
||
compiled_gemm(
|
||
a_ptr, b_ptr, sfa_ptr, sfb_ptr, c_ptr, (m, n, k, l), current_stream
|
||
)
|
||
c_ref = reference_scaled_mm_emulated(
|
||
a, b, sfa, sfb, c, (m, n, k, l), sf_vec_size, c_dtype
|
||
)
|
||
if c_dtype in (cutlass.Float8E5M2, cutlass.Float8E4M3FN):
|
||
# Rtol=0.001 and atol=0.1 are not supported for bitwise comparison of
|
||
# low dimensional floats. Please use rtol=0.0 and atol=0.0.
|
||
tolerance = 0.0
|
||
torch.testing.assert_close(c, c_ref, atol=tolerance, rtol=tolerance)
|
||
|
||
def generate_inputs():
|
||
a, b, c, sfa, sfb = prepare_tensors_emulated(
|
||
mnkl,
|
||
ab_dtype,
|
||
sf_dtype,
|
||
sf_vec_size,
|
||
c_dtype,
|
||
a_major,
|
||
b_major,
|
||
c_major,
|
||
)
|
||
# Reorder scale factor tensors to (32, 4, restM, 4, restK, l) format
|
||
sfa_reordered = create_and_reorder_scale_factor_tensor(
|
||
l, m, k, sf_vec_size, sf_dtype, sfa
|
||
)
|
||
sfb_reordered = create_and_reorder_scale_factor_tensor(
|
||
l, n, k, sf_vec_size, sf_dtype, sfb
|
||
)
|
||
# Construct CuTe Pointers
|
||
a_ptr, b_ptr, c_ptr, sfa_ptr, sfb_ptr, a_cute, b_cute = (
|
||
construct_cute_pointers_emulated(
|
||
a,
|
||
b,
|
||
sfa_reordered,
|
||
sfb_reordered,
|
||
c,
|
||
ab_dtype,
|
||
sf_dtype,
|
||
c_dtype,
|
||
)
|
||
)
|
||
jit_args = cute.testing.JitArguments(
|
||
a_ptr, b_ptr, sfa_ptr, sfb_ptr, c_ptr, (m, n, k, l), current_stream
|
||
)
|
||
# Keep references to external variables (e.g., Torch tensors when taking a view)
|
||
jit_args.add_to_scope([a, b, sfa_reordered, sfb_reordered, c, a_cute, b_cute])
|
||
return jit_args
|
||
|
||
|
||
workspace_count = 1
|
||
if use_cold_l2:
|
||
one_workspace_bytes = (
|
||
a.numel() * a.element_size()
|
||
+ b.numel() * b.element_size()
|
||
+ sfa.numel() * sfa.element_size()
|
||
+ sfb.numel() * sfb.element_size()
|
||
+ c.numel() * c.element_size()
|
||
)
|
||
workspace_count = cute.testing.get_workspace_count(
|
||
one_workspace_bytes, warmup_iterations, iterations
|
||
)
|
||
|
||
exec_time = cute.testing.benchmark(
|
||
compiled_gemm,
|
||
workspace_generator=generate_inputs,
|
||
workspace_count=workspace_count,
|
||
stream=current_stream,
|
||
warmup_iterations=warmup_iterations,
|
||
iterations=iterations,
|
||
)
|
||
return exec_time # Return execution time in microseconds
|
||
|
||
|
||
def run(
|
||
mnkl: Tuple[int, int, int, int],
|
||
ab_dtype: Type[cutlass.Numeric],
|
||
sf_dtype: Type[cutlass.Numeric],
|
||
sf_vec_size: int,
|
||
c_dtype: Type[cutlass.Numeric],
|
||
a_major: Literal["m", "k"],
|
||
b_major: Literal["n", "k"],
|
||
c_major: Literal["m", "n"],
|
||
mma_tiler_mn: Tuple[int, int],
|
||
cluster_shape_mn: Tuple[int, int],
|
||
tolerance: float = 1e-01,
|
||
warmup_iterations: int = 0,
|
||
iterations: int = 1,
|
||
skip_ref_check: bool = False,
|
||
use_cold_l2: bool = False,
|
||
**kwargs,
|
||
):
|
||
"""
|
||
Execute the appropriate GEMM function based on dtype.
|
||
|
||
Routes to either run_scaled_mm_with_emulated_dtype or run_scaled_mm
|
||
depending on whether the dtypes require emulation.
|
||
"""
|
||
if is_emulated_dtype(ab_dtype, sf_dtype, c_dtype):
|
||
exec_time = run_scaled_mm_with_emulated_dtype(
|
||
mnkl,
|
||
ab_dtype,
|
||
sf_dtype,
|
||
sf_vec_size,
|
||
c_dtype,
|
||
a_major,
|
||
b_major,
|
||
c_major,
|
||
mma_tiler_mn,
|
||
cluster_shape_mn,
|
||
tolerance,
|
||
warmup_iterations,
|
||
iterations,
|
||
skip_ref_check,
|
||
use_cold_l2,
|
||
)
|
||
else:
|
||
exec_time = run_scaled_mm(
|
||
mnkl,
|
||
ab_dtype,
|
||
sf_dtype,
|
||
sf_vec_size,
|
||
c_dtype,
|
||
a_major,
|
||
b_major,
|
||
c_major,
|
||
mma_tiler_mn,
|
||
cluster_shape_mn,
|
||
tolerance,
|
||
warmup_iterations,
|
||
iterations,
|
||
skip_ref_check,
|
||
use_cold_l2,
|
||
)
|
||
return exec_time
|
||
|
||
|
||
if __name__ == "__main__":
|
||
|
||
def parse_comma_separated_ints(s: str) -> Tuple[int, ...]:
|
||
try:
|
||
return tuple(int(x.strip()) for x in s.split(","))
|
||
except ValueError:
|
||
raise argparse.ArgumentTypeError(
|
||
"Invalid format. Expected comma-separated integers."
|
||
)
|
||
|
||
parser = argparse.ArgumentParser(
|
||
description="Example of Sm100 Dense Persistent BlockScaled GEMM."
|
||
)
|
||
|
||
parser.add_argument(
|
||
"--mnkl",
|
||
type=parse_comma_separated_ints,
|
||
default=(512, 256, 256, 1),
|
||
help="mnkl dimensions (comma-separated)",
|
||
)
|
||
parser.add_argument(
|
||
"--mma_tiler_mn",
|
||
type=parse_comma_separated_ints,
|
||
default=(128, 128),
|
||
help="Mma tile shape (comma-separated)",
|
||
)
|
||
parser.add_argument(
|
||
"--cluster_shape_mn",
|
||
type=parse_comma_separated_ints,
|
||
default=(1, 1),
|
||
help="Cluster shape (comma-separated)",
|
||
)
|
||
parser.add_argument("--ab_dtype", type=cutlass.dtype, default=cutlass.Float4E2M1FN)
|
||
parser.add_argument("--sf_dtype", type=cutlass.dtype, default=cutlass.Float8E4M3FN)
|
||
parser.add_argument("--sf_vec_size", type=int, default=16)
|
||
parser.add_argument("--c_dtype", type=cutlass.dtype, default=cutlass.Float16)
|
||
parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k")
|
||
parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k")
|
||
parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n")
|
||
parser.add_argument(
|
||
"--tolerance", type=float, default=1e-01, help="Tolerance for validation"
|
||
)
|
||
parser.add_argument(
|
||
"--warmup_iterations", type=int, default=0, help="Warmup iterations"
|
||
)
|
||
parser.add_argument(
|
||
"--iterations",
|
||
type=int,
|
||
default=1,
|
||
help="Number of iterations to run the kernel",
|
||
)
|
||
parser.add_argument(
|
||
"--skip_ref_check", action="store_true", help="Skip reference checking"
|
||
)
|
||
parser.add_argument(
|
||
"--use_cold_l2",
|
||
action="store_true",
|
||
default=False,
|
||
help="Use circular buffer tensor sets to ensure L2 cold cache",
|
||
)
|
||
|
||
args = parser.parse_args()
|
||
|
||
if len(args.mnkl) != 4:
|
||
parser.error("--mnkl must contain exactly 4 values")
|
||
|
||
if len(args.mma_tiler_mn) != 2:
|
||
parser.error("--mma_tiler_mn must contain exactly 2 values")
|
||
|
||
if len(args.cluster_shape_mn) != 2:
|
||
parser.error("--cluster_shape_mn must contain exactly 2 values")
|
||
|
||
# Execute GEMM with appropriate function based on dtype
|
||
run(
|
||
args.mnkl,
|
||
args.ab_dtype,
|
||
args.sf_dtype,
|
||
args.sf_vec_size,
|
||
args.c_dtype,
|
||
args.a_major,
|
||
args.b_major,
|
||
args.c_major,
|
||
args.mma_tiler_mn,
|
||
args.cluster_shape_mn,
|
||
args.tolerance,
|
||
args.warmup_iterations,
|
||
args.iterations,
|
||
args.skip_ref_check,
|
||
args.use_cold_l2,
|
||
)
|
||
print("PASS")
|