mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-23 14:25:03 +00:00
.. _wmma_programming:

Warp-Level MMA Instructions Programming Guide
=============================================

Ampere (SM80) introduced the modern **warp-level MMA** PTX instruction
family ``mma.sync.aligned``. A warp (32 threads) cooperates on one
synchronous ``D = A * B + C`` matrix multiply-accumulate. Later
architectures extended the family with new data types and shapes, such as
FP8 on Ada (SM89) and block-scaled MX FP4 on Blackwell (SM120a), while
keeping the same warp-synchronous issue model.

Key architectural characteristics:

* **Warp scope:** One MMA is issued collectively by a 32-thread warp
  rather than by a warpgroup or a single thread.
* **Synchronous issue model:** ``mma.sync.aligned`` completes in program
  order within the warp; no fences or commit/wait groups are required.
* **Register-resident operands and accumulator:** A, B, and C/D all live
  in the register file (RMEM). Each thread holds a small fragment of every
  operand in its own registers.
* **SMEM → RMEM loading:** Operands A and B are staged in shared memory
  and loaded into register fragments either via ``ldmatrix``, a
  warp-collective SMEM→RMEM load that distributes tiles in the exact
  layout the MMA expects, or via regular shared-memory loads.
* **Fixed operand layout:** A is row-major (K-major) and B is col-major
  (K-major); transpose is not supported at the instruction level.

The dense DSL op classes currently exposed are ``MmaF16BF16Op`` (F16/BF16,
SM80+), ``MmaFP8Op`` (FP8 E4M3/E5M2, SM89+), and ``MmaMXF4Op`` /
``MmaMXF4NVF4Op`` (block-scaled MX FP4, SM120a+); see
`Setting up the TiledMMA, MMA Ops`_ for their full constructor parameters,
instruction shapes, and architecture requirements.

.. {$nv-internal-release begin}

Internal builds additionally expose ``MmaF16BF16SparseOp`` (2:4 structured
sparsity, SM80+).

.. {$nv-internal-release end}

This guide outlines the CuTe Python DSL programming model for warp-level
MMA kernels: stage operands in SMEM, load register fragments with
``ldmatrix`` or regular shared-memory loads, issue warp-synchronous MMAs,
and write the RMEM accumulator back to GMEM in the epilogue.

.. contents:: **Contents**
   :local:
   :depth: 2

Global Memory (GMEM) to MMA data flow overview
----------------------------------------------

Warp MMA (``mma.sync.aligned``) instructions require all operands (A, B, and
the accumulator C/D) to live in the registers (RMEM) of the 32 threads of the
warp. Operand data must therefore be explicitly loaded into registers before
each MMA instruction. The most common way to implement these GEMMs is to stage
A and B from GMEM into SMEM with ``cp.async``, then use ``ldmatrix`` (an
SMEM→RMEM warp-collective load) to fill the A/B register fragments just before
``cute.gemm()``.

The diagram below traces the full data flow of a warp MMA GEMM kernel for the
most common case: the A and B matrices are stored in GMEM and staged through
SMEM via ``cp.async``, and the output matrix, accumulated in RMEM, is written
back to GMEM through an SMEM staging buffer for coalesced vectorized stores.

The diagram has three parallel tracks, one each for operand A, operand B, and
the accumulator C/D. Each track has two sub-tracks: one for copying data
between memory spaces and one for MMA execution.

- **Operand A** (and symmetrically **Operand B**):

  - First, create the SMEM tensors for the A and B matrices: ``sA`` and ``sB``.
    These physically allocated tensors are the staging destination of
    ``cp.async`` and the source of ``ldmatrix`` for the warp MMA instructions.
  - Next, the **data copy flow** creates the tensor views for copying data
    from GMEM to SMEM. It starts with the ``mA`` tensor, which represents
    matrix A in global memory. The ``mA`` → ``local_tile`` → ``gA`` step
    creates the local tile view of A, that is, the slice of the A matrix
    needed to compute the given CTA's output tile. A copy partition maps this
    tile to per-thread copy views (``tAgA``, ``tAsA``), and the multi-stage
    ``cp.async`` pipeline performs
    ``copy(tiled_copy_A, tAgA[k], tAsA[stage])``.
  - In parallel, the **MMA flow** turns the staged SMEM tensor into register
    fragments consumed by the warp MMA. From the SMEM allocation ``sA``, MMA
    partitioning produces the SMEM operand view ``tCsA = partition_A(sA)`` and
    the register-fragment layout ``tCrA = make_fragment_A(tCsA)``. A dedicated
    S2R/``ldmatrix`` path then retiles the source and destination
    (``partition_S`` on SMEM, ``retile`` on RMEM) and executes
    ``copy(s2r_A, tCsA_copy_view[k_blk], tCrA_copy_view[k_blk])``
    per k-block, filling the ``tCrA`` registers read by ``cute.gemm()``.

- **Accumulator C/D**:

  - **RMEM accumulator flow** (MMA input/output): output tile views are formed
    by ``mC`` → ``local_tile`` → ``gC`` → ``partition_C`` → ``tCgC``, then
    ``make_fragment_C(tCgC)`` creates the register accumulator ``tCrC``.
    Warp MMA keeps C/D entirely in RMEM, and ``tCrC`` is both the input C
    and output D of ``cute.gemm()``.
  - **Epilogue flow** (RMEM → SMEM → RMEM → GMEM): the epilogue converts
    accumulator values (for example ``tCrD = epilogue_op(tCrC)``), stages them
    through SMEM (``autovec_copy(tCrD, tCsC)``), reloads them into registers
    with the epilogue copy layout, and performs coalesced vectorized GMEM
    stores via ``copy(tiled_copy_C, tCrC_epi, tCgC_epi)``.
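
The hand-off between the copy flow and the MMA flow can be modeled in plain
Python. The sketch below is not taken from ``tensorop_gemm.py``: it assumes a
circular SMEM buffer with ``num_stages`` slots and a prefetch depth of
``num_stages - 1`` k-tiles, and the helper name ``stage_schedule`` is
hypothetical. It only shows which stage index is read and which is written
on each main-loop iteration.

.. code-block:: python

    def stage_schedule(num_k_tiles, num_stages):
        """Return (k_tile, read_stage, prefetch) for a circular SMEM buffer.

        ``prefetch`` is (k_tile, write_stage) for the copy issued during that
        iteration, or None once every k-tile has already been requested.
        Assumption: num_stages - 1 tiles are in flight before the main loop.
        """
        prefetch_depth = num_stages - 1
        schedule = []
        for k in range(num_k_tiles):
            read_stage = k % num_stages      # stage consumed by ldmatrix + MMA
            k_next = k + prefetch_depth      # k-tile whose copy is issued now
            if k_next < num_k_tiles:
                prefetch = (k_next, k_next % num_stages)
            else:
                prefetch = None
            schedule.append((k, read_stage, prefetch))
        return schedule


    for row in stage_schedule(num_k_tiles=8, num_stages=4):
        print(row)

Under these assumptions the stage being written is never the stage being
read in the same iteration, which is the property a multi-stage pipeline
relies on.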
|
||
|
||
.. code-block:: text
|
||
|
||
Operand A Dataflow Path Operand B Dataflow Path Accumulator C/D Dataflow Path
|
||
─────────────────────── ─────────────────────── ─────────────────────────────
|
||
|
||
mA: (M, K) [GMEM] mB: (N, K) [GMEM] ┌──── RMEM ──────────┐
|
||
│ │ │ make_fragment_C() │
|
||
│ local_tile(mA, cta_tiler, coord) │ local_tile(mB, cta_tiler, coord) │ tCrC: accumulator │
|
||
▼ ▼ └───────┬────────────┘
|
||
gA: (BM, BK, k) [GMEM] gB: (BN, BK, k) [GMEM] │
|
||
│ │ tCrC:(MMA,MMA_M,MMA_N) [RMEM]
|
||
│ ┌──── SMEM ─────────┐ │ ┌──── SMEM ─────────┐ │
|
||
│ │ sA: (BM,BK,PIPE) │ │ │ sB: (BN,BK,PIPE) │ │ mC: (M, N) [GMEM]
|
||
│ └──┬────────┬───────┘ │ └──┬────────┬───────┘ │ │
|
||
│ │ │ │ │ │ │ │ local_tile
|
||
│ │ thr_mma.partition_A(sA) │ │ thr_mma.partition_B(sB) │ ▼
|
||
│ │ ▼ │ │ ▼ │ gC: (BM, BN) [GMEM]
|
||
│ │ tCsA:(MMA,MMA_M, │ │ tCsB:(MMA,MMA_N, │ │ partition_C
|
||
│ │ MMA_K,PIPE) [SMEM] │ │ MMA_K,PIPE) [SMEM] │ ▼
|
||
│ │ │ │ │ │ │ tCgC:(MMA,MMA_M,
|
||
│ │ make_fragment_A(tCsA) │ │ make_fragment_B(tCsB) │ MMA_N)
|
||
│ │ ▼ │ │ ▼ │ [GMEM] (epi dest)
|
||
│ │ tCrA:(MMA,MMA_M, │ │ tCrB:(MMA,MMA_N, │ │
|
||
│ │ MMA_K) [RMEM] │ │ MMA_K) [RMEM] │ │
|
||
│ │ │ │ │ │ │ │
|
||
│ │ S2R retiling (ldmatrix): │ │ S2R retiling (ldmatrix): │ │
|
||
│ │ s2r_A = make_tiled_copy_A( │ │ s2r_B = make_tiled_copy_B( │ │
|
||
│ │ ldmatrix, mma) │ │ ldmatrix, mma) │ │
|
||
│ │ tCsA_copy_view = │ │ tCsB_copy_view = │ │
|
||
│ │ s2r_A.partition_S(sA) │ │ s2r_B.partition_S(sB) │ │
|
||
│ │ tCrA_copy_view = retile(tCrA) │ │ tCrB_copy_view = retile(tCrB) │ │
|
||
│ │ └─────────────┐ │ │ └─────────────┐ │ │
|
||
╰─────┤ │ ╰─────┤ │ │ │
|
||
▼ │ ▼ │ │ │
|
||
tAgA = thr_copy_A. │ tBgB = thr_copy_B. │ │ │
|
||
partition_S(gA) │ partition_S(gB) │ │ │
|
||
tAsA = thr_copy_A. │ tBsB = thr_copy_B. │ │ │
|
||
partition_D(sA) │ partition_D(sB) │ │ │
|
||
| │ | │ │ │
|
||
▼ │ ▼ │ │ │
|
||
┌───┴────────────────────┐ │ ┌──────┴─────────────────┐│ │ │
|
||
│ cp.async loop (k-tile):│ │ │ cp.async loop (k-tile):││ │ │
|
||
│ copy(tiled_copy_A, │ │ │ copy(tiled_copy_B, ││ │ │
|
||
│ tAgA[k], │ │ │ tBgB[k], ││ │ │
|
||
┌─▶│ tAsA[stage]) │ │ ┌──▶│ tBsB[stage]) ││ │ │
|
||
│ │ (writes into sA; │ │ │ │ (writes into sB; ││ │ │
|
||
│ │ ldmatrix reads sA) │ │ │ │ ldmatrix reads sB) ││ │ │
|
||
│ │ repeat for next k/stage│ │ │ │ repeat for next k/stage││ │ │
|
||
│ └────────────────────────┘ │ │ └────────────────────────┘│ │ │
|
||
│ │ │ │ │ │ │ │
|
||
└────────┘ ▼ └─────────┘ ▼ ▼ │
|
||
└───────┬───────────────────────────────┴───────────────────┘ │
|
||
│ │
|
||
▼ │
|
||
┌────────────────────────────────────────────────────────┐ │
|
||
│ MMA loop (k_blk): │ │
|
||
│ S2R: copy(s2r_A, tCsA_copy_view[k_blk], │ │
|
||
│ tCrA_copy_view[k_blk]) │ │
|
||
│ S2R: copy(s2r_B, tCsB_copy_view[k_blk], │ │
|
||
│ tCrB_copy_view[k_blk]) │ │
|
||
│ [SMEM → RMEM via ldmatrix; fills tCrA/tCrB] │ │
|
||
│ │ │
|
||
│ cute.gemm(tiled_mma, │ │
|
||
┌──▶ │ tCrC, D (output, RMEM), │ │
|
||
│ │ tCrA[k_blk], A (RMEM), │ │
|
||
│ │ tCrB[k_blk], B (RMEM), │ │
|
||
│ │ tCrC) C (accumulator, RMEM) │ │
|
||
│ └────────────────────────────────────────────────────────┘ │
|
||
│ │ │ │
|
||
└───────┘ | │
|
||
▼ │
|
||
Epilogue: │
|
||
tCrD = epilogue_op(tCrC) [RMEM] │
|
||
│ │
|
||
▼ │
|
||
sC = alloc(sC_layout) [SMEM] │
|
||
tCsC = thr_mma.partition_C(sC) │
|
||
R2S: autovec_copy(tCrD, tCsC) │
|
||
[RMEM → SMEM] │
|
||
│ │
|
||
▼ │
|
||
tCsC_epi = thr_copy_C.partition_S(sC) │
|
||
tCgC_epi = thr_copy_C.partition_D(gC) ◀─────────────────────────────────┘
|
||
tCrC_epi = make_fragment_like(...)
|
||
S2R: autovec_copy(tCsC_epi, tCrC_epi)
|
||
[SMEM → RMEM]
|
||
│
|
||
▼
|
||
Store: copy(tiled_copy_C, tCrC_epi, tCgC_epi)
|
||
[RMEM → GMEM]
|
||
|
||
**Naming convention:**

* ``cta_tiler`` = ``(BM, BN, BK)`` (CTA tiler dimensions)
* ``mX`` = global tensor (for example A as ``(M, K)``)
* ``gX`` = CTA-tiled GMEM slice (for example ``(BM, BK, k)`` for A)
* ``sX`` = SMEM allocation (for example ``(BM, BK, PIPE)``)
* ``tAgA`` / ``tAsA`` = ``cp.async`` source/destination partitions
  (``CPY, CPY_M, CPY_K, ...``)
* ``tCsX`` = MMA-partitioned SMEM view (for example ``(MMA, MMA_M, MMA_K, PIPE)``)
* ``tCrX`` = register fragment (for example ``(MMA, MMA_M, MMA_K)``)
* ``tCrC`` = RMEM accumulator (``MMA, MMA_M, MMA_N``)
* ``tCgC`` = MMA-partitioned GMEM view for output (``MMA, MMA_M, MMA_N``)
* ``tCsA_copy_view`` / ``tCrA_copy_view`` = ``ldmatrix`` retile views for the
  SMEM→RMEM copy (from ``partition_S(sA)`` and ``retile(tCrA)`` on the S2R
  tiled copy; C++ equivalents: ``tXsA`` / ``tXrA``)
* ``MMA`` = atom thread-value layout; ``MMA_M/MMA_N/MMA_K`` = repeat counts
  (for example ``BM/inst_M``); ``k`` = outer K-tiles; ``PIPE`` = pipeline stages

Setting up the TiledMMA, MMA Ops
---------------------------------

As shown in the data flow overview, CuTe DSL provides many utilities to
tile and partition the global memory tensors and to create fragment views of
SMEM and register tensors for MMA instructions.

To use these functions, first set up the MMA op and the TiledMMA.

Creating a Warp MMA Op
~~~~~~~~~~~~~~~~~~~~~~~

A warp MMA op describes the hardware ``mma.sync.aligned`` instruction to use.
Its parameters include the data types and the instruction shape. The operand
layout is fixed (A = row-major, B = col-major).

.. code-block:: python

    import cutlass
    import cutlass.cute as cute
    from cutlass.cute.nvgpu import warp

    op = warp.MmaF16BF16Op(
        cutlass.Float16,   # A/B element type
        cutlass.Float32,   # accumulator type
        (16, 8, 16),       # instruction shape (M, N, K)
    )

The key parameters are:

- **Instruction shape** ``(M, N, K)``: determines the size of one hardware MMA
  instruction. Valid shapes depend on the data type (see the ops table below).
- **A/B element type** (``ab_dtype``) and **accumulator type** (``acc_dtype``):
  ``Float32`` is always a valid accumulator; ``Float16`` is only valid for F16
  inputs. Each op restricts ``ab_dtype`` to a specific family (F16/BF16, FP8,
  MXF4, etc.).
- **Operand layout**: fixed to A = row-major (K-major), B = col-major (K-major).
  Transpose is not supported. All 32 threads in a warp cooperate on each
  instruction.

CuTe DSL provides implementations of the following warp-level MMA ops:

.. list-table:: Warp-level MMA ops
   :header-rows: 1
   :widths: 34 22 34 10

   * - PTX name
     - Python class
     - Constructor parameters
     - SM Arch
   * - ``mma.sync.aligned.m16n8k{K}.row.col.{acc}.f16.f16`` / ``.bf16.bf16``
     - ``warp.MmaF16BF16Op``
     - ``ab_dtype, acc_dtype, shape_mnk``
     - ``sm_80+``
   * - ``mma.sync.aligned.m16n8k{K}.row.col.{acc}.{e4m3|e5m2}.{e4m3|e5m2}``
     - ``warp.MmaFP8Op``
     - ``ab_dtype, acc_dtype, shape_mnk``
     - ``sm_89+``
   * - ``mma.sync.aligned.kind::mxf4.block_scale.m16n8k64``
     - ``warp.MmaMXF4Op``
     - ``ab_dtype, acc_dtype, sf_type``
     - ``sm_120a+``
   * - ``mma.sync.aligned.kind::mxf4nvf4.block_scale.m16n8k64``
     - ``warp.MmaMXF4NVF4Op``
     - ``ab_dtype, acc_dtype, sf_type``
     - ``sm_120a+``

.. {$nv-internal-release begin}

Internal builds additionally provide:

.. list-table:: Internal warp-level MMA ops
   :header-rows: 1
   :widths: 34 22 34 10

   * - PTX name
     - Python class
     - Constructor parameters
     - SM Arch
   * - ``mma.sp.sync.aligned.m16n8k{K}.row.col.{acc}.f16.f16`` / ``.bf16.bf16``
     - ``warp.MmaF16BF16SparseOp``
     - ``ab_dtype, acc_dtype, shape_mnk, sparse_metadata_format``
     - ``sm_80+``

.. {$nv-internal-release end}

Creating a Tiled MMA
~~~~~~~~~~~~~~~~~~~~~

A ``TiledMma`` tiles the MMA atom across the thread block so that multiple
warps cooperate on a larger tile. You can pass the op directly or create an
explicit atom first:

.. code-block:: python

    # Option 1: directly from op (common shorthand)
    tiled_mma = cute.make_tiled_mma(op)

    # Option 2: explicit atom creation
    atom = cute.make_mma_atom(op)
    tiled_mma = cute.make_tiled_mma(atom)

With no extra arguments this wraps a single atom: one warp and one
``(16, 8, K)`` tile. The optional ``atom_layout_mnk`` and
``permutation_mnk`` parameters (described in the subsections below)
control multi-warp tiling and per-thread value layout respectively.

Spatial tiling with a repeat count
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

A repeat tuple ``(M_rep, N_rep, K_rep)`` passed as ``atom_layout_mnk``
replicates the warp MMA atom across the M, N, and K dimensions, producing a
larger tiled MMA that is executed cooperatively by ``M_rep * N_rep * K_rep``
warps in a single ``cute.gemm`` call. Each replica corresponds to one
**warp** (32 threads), so ``(2, 2, 1)`` uses four warps,
a common configuration for warp-specialized SM80/SM89 kernels:

.. code-block:: python

    atom = cute.make_mma_atom(op)        # op shape: (16, 8, 16)
    tiled_mma = cute.make_tiled_mma(
        atom,
        atom_layout_mnk=(2, 2, 1),       # 4 warps: 2 in M, 2 in N
    )                                    # total tiled-MMA tile = (32, 16, 16)

Each atom can be addressed by a 3D coordinate ``(m, n, k)``, where ``m`` is
the M repeat index, ``n`` is the N repeat index, and ``k`` is the K repeat
index. Each warp MMA atom is executed by a single warp within a
single CTA.
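
The warp-to-atom assignment can be worked out by hand. The pure-Python
sketch below assumes the warp index is unrolled with M varying fastest,
which is the ordering the diagram in this section shows (Warp 1 at atom
``(1, 0, 0)``, Warp 2 at atom ``(0, 1, 0)``). The helper names
``atom_coord`` and ``tiled_tile`` are illustrative and are not part of the
CuTe DSL.

.. code-block:: python

    def atom_coord(warp_idx, atom_layout_mnk):
        """Map a warp index to its (m, n, k) atom coordinate, M fastest."""
        m_rep, n_rep, _ = atom_layout_mnk
        m = warp_idx % m_rep
        n = (warp_idx // m_rep) % n_rep
        k = warp_idx // (m_rep * n_rep)
        return (m, n, k)


    def tiled_tile(inst_shape_mnk, atom_layout_mnk):
        """Tile covered by all atoms together: inst shape times repeat count."""
        return tuple(i * r for i, r in zip(inst_shape_mnk, atom_layout_mnk))


    atom_layout_mnk = (2, 2, 1)
    inst_shape_mnk = (16, 8, 16)

    print(tiled_tile(inst_shape_mnk, atom_layout_mnk))
    for warp in range(4):
        m, n, k = atom_coord(warp, atom_layout_mnk)
        # Top-left corner of this warp's 16 x 8 block inside the C tile
        print(warp, (m, n, k), (m * inst_shape_mnk[0], n * inst_shape_mnk[1]))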

.. code-block:: text

    Warp MMA Atom (16x8x16)              make_tiled_mma(atom, (2, 2, 1))
    +----------------+                   +----------------+----------------+
    |                |                   |                |                |  ^
    |     16 x 8     |                   |  Atom (0,0,0)  |  Atom (0,1,0)  |  |
    |      x 16      |    --(2,2,1)-->   |     16 x 8     |     16 x 8     |  |  2 x inst_M
    |                |       repeat      |      x 16      |      x 16      |  |  = 32
    |                |                   |    [Warp 0]    |    [Warp 2]    |  |
    +----------------+                   +----------------+----------------+  |
                                         |                |                |  |
                                         |  Atom (1,0,0)  |  Atom (1,1,0)  |  |
                                         |     16 x 8     |     16 x 8     |  |
                                         |      x 16      |      x 16      |  |
                                         |    [Warp 1]    |    [Warp 3]    |  v
                                         +----------------+----------------+
                                         <------- 2 x inst_N = 16 -------->
                                                  K unchanged = 16

Custom tile permutation with ``permutation_mnk``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

``permutation_mnk`` is an optional third argument to ``make_tiled_mma``.
Each of its three entries is a **per-mode permutation** of the M, N, and
K coordinates inside the tiled MMA. In the common case shown in this
section, each entry is just a size, which is the identity permutation of
that size; ``permutation_mnk`` then simply sets the **total tile
footprint** of the tiled MMA along each dimension. When a mode's size is
larger than the atom layout's natural coverage
(``atom_layout x inst_shape``), each thread receives additional values to
fill the extended region: the thread count stays the same, but every
thread holds more data. The general form, where an entry is a
``Layout`` that reorders coordinates inside a mode, is covered in the
subsection below.

The standard convention for warp MMA (used in ``tensorop_gemm.py`` and
throughout the Ampere examples) doubles the N dimension:

.. code-block:: python

    # From examples/cute/ampere/kernel/dense_gemm/tensorop_gemm.py
    # (the two tuples below are the example values used in this guide)
    atom_layout_mnk = (2, 2, 1)     # warps along (M, N, K)
    mma_inst_shape = (16, 8, 16)    # instruction shape of ``op``

    permutation_mnk = (
        atom_layout_mnk[0] * mma_inst_shape[0],      # M: matches atom coverage
        atom_layout_mnk[1] * mma_inst_shape[1] * 2,  # N: 2x atom coverage
        atom_layout_mnk[2] * mma_inst_shape[2],      # K: matches atom coverage
    )

    tC = cute.make_layout(atom_layout_mnk)
    tiled_mma = cute.make_tiled_mma(
        op,
        tC,
        permutation_mnk=permutation_mnk,
    )

**Why double N?** The atom's N dimension is only 8 (inst_N = 8). Without
a permutation, each thread's B-operand values span a single 8-wide
N-range, which may not align well with SMEM load widths. The ``* 2``
on N gives each thread's B fragment two 8-wide N-ranges instead of one,
aligning the access pattern with wider contiguous SMEM regions for more
efficient loads.

For ``atom_layout_mnk = (2, 2, 1)`` and ``inst_shape = (16, 8, 16)``:

- Atom coverage = ``(2x16, 2x8, 1x16) = (32, 16, 16)``
- ``permutation_mnk = (32, 32, 16)``, with N extended from 16 to 32
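
The arithmetic above can be checked with a few lines of plain Python. This
is only a sketch of the integer (identity-permutation) case; the helper
names ``natural_coverage`` and ``value_repeat`` are hypothetical and do not
exist in the CuTe DSL.

.. code-block:: python

    def natural_coverage(atom_layout_mnk, inst_shape_mnk):
        """Tile covered by the atoms alone: repeat count times inst shape."""
        return tuple(r * i for r, i in zip(atom_layout_mnk, inst_shape_mnk))


    def value_repeat(permutation_mnk, coverage_mnk):
        """How many times the atom pattern repeats per mode (integer case)."""
        for p, c in zip(permutation_mnk, coverage_mnk):
            if p % c != 0:
                raise ValueError("footprint must be a multiple of the coverage")
        return tuple(p // c for p, c in zip(permutation_mnk, coverage_mnk))


    atom_layout_mnk = (2, 2, 1)
    inst_shape_mnk = (16, 8, 16)

    coverage = natural_coverage(atom_layout_mnk, inst_shape_mnk)
    permutation_mnk = (coverage[0], coverage[1] * 2, coverage[2])  # double N

    print(coverage)                                   # (32, 16, 16)
    print(permutation_mnk)                            # (32, 32, 16)
    print(value_repeat(permutation_mnk, coverage))    # (1, 2, 1)

A repeat factor of 2 along N with an unchanged warp count is what "same
threads, more values per thread" means in this section.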

.. code-block:: text

    Without permutation, natural atom coverage (M = 32, N = 16):

           C tile (M=32, N=16)
    +----------------+----------------+
    |                |                |  ^
    |    [Warp 0]    |    [Warp 2]    |  |
    |     16 x 8     |     16 x 8     |  |  2 x inst_M
    |                |                |  |  = 32
    +----------------+----------------+  |
    |                |                |  |
    |    [Warp 1]    |    [Warp 3]    |  |
    |     16 x 8     |     16 x 8     |  |
    |                |                |  v
    +----------------+----------------+
    <------------- N = 16 ----------->
    (each warp owns one (16, 8) atom;
     thread T0 of Warp 0 holds 4 C values in its 16x8 block)

    With permutation_mnk = (32, 32, 16), N extended from 16 to 32:

                            C tile (M=32, N=32)
    +----------------+----------------+----------------+----------------+
    |                |                |                |                |  ^  N = 16 → 32:
    |    [Warp 0]    |    [Warp 2]    |    [Warp 0]    |    [Warp 2]    |  |  atom pattern repeats
    |     16 x 8     |     16 x 8     |     16 x 8     |     16 x 8     |  |  along N. Each thread
    |                |                |                |                |  |  now holds 2x the
    +----------------+----------------+----------------+----------------+  |  values along N
    |                |                |                |                |  |  (same threads, more
    |    [Warp 1]    |    [Warp 3]    |    [Warp 1]    |    [Warp 3]    |  |  values per thread).
    |     16 x 8     |     16 x 8     |     16 x 8     |     16 x 8     |  |
    |                |                |                |                |  v
    +----------------+----------------+----------------+----------------+
    <----------------------------- N = 32 ----------------------------->
    |          atom coverage          |           value repeat          |

Reordering coordinates with a per-mode ``Layout``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

So far each entry of ``permutation_mnk`` has been an integer, which is
shorthand for the identity layout ``Layout<Shape<S>, Stride<_1>>``: the
atom pattern simply tiles to fill an ``S``-wide footprint. The general
form lets each entry be a ``Layout`` that **reorders coordinates inside
that mode** while keeping the same total size. That reordering is what
gives the parameter its name; the integer-only cases used earlier are
just the identity permutation.

The canonical illustration is the SM70 example from
`0t_mma_atom.md <../../cpp/cute/0t_mma_atom.md>`_. Take a 2x2 tiled MMA
of ``SM70_8x8x4_F32F16F16F32_NT`` atoms with a ``32x32x4`` footprint.
Without any M-mode permutation, thread ``T0``'s 8 A-values land at the
following ``(m, k)`` coordinates::

    T0V0 => (0, 0)    T0V4 => (16, 0)
    T0V1 => (1, 0)    T0V5 => (17, 0)
    T0V2 => (2, 0)    T0V6 => (18, 0)
    T0V3 => (3, 0)    T0V7 => (19, 0)

These are two separate runs of 4 along M, with a gap from m=4 to m=15. We
may prefer those 8 values to sit in **one contiguous run** in the logical
M-coordinates (e.g. so register or SMEM layouts pack cleanly). Passing
the M-mode layout ``(4, 4, 2):(1, 8, 4)`` does exactly that: it is a
scatter permutation telling each old m-coord where to go in the new
image.

.. code-block:: text

    old m-coord:  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
    new m-coord:  0  1  2  3  8  9 10 11 16 17 18 19 24 25 26 27  4  5  6  7 12 13 14 15 20 21 22 23 28 29 30 31

After the permutation, ``T0``'s 8 A-values occupy ``m = 0..7``, one
contiguous run, and every other thread's M-values become equally
contiguous. Thread-data ownership and value counts are unchanged; only
the **mapping from values to m-coordinates** is permuted.
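
The old-to-new table can be reproduced by evaluating the layout
``(4, 4, 2):(1, 8, 4)`` at every linear index. The sketch below is plain
Python, not the CuTe DSL: ``layout_map`` is a hypothetical helper that
assumes CuTe's usual convention of unrolling the leftmost mode fastest.

.. code-block:: python

    def layout_map(shape, stride):
        """Evaluate a shape:stride layout at each linear index (leftmost mode fastest)."""
        total = 1
        for extent in shape:
            total *= extent
        offsets = []
        for idx in range(total):
            remaining, offset = idx, 0
            for extent, step in zip(shape, stride):
                offset += (remaining % extent) * step
                remaining //= extent
            offsets.append(offset)
        return offsets


    new_m = layout_map((4, 4, 2), (1, 8, 4))
    print(new_m)

    # T0's A-values sit at old m = 0..3 and 16..19 before the permutation.
    t0_old = [0, 1, 2, 3, 16, 17, 18, 19]
    print([new_m[m] for m in t0_old])   # [0, 1, 2, 3, 4, 5, 6, 7]

Because the result is a rearrangement of ``0..31``, the layout is a true
permutation of the M mode rather than a resize.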

In CuTeDSL the permuted entry is built with ``cute.make_layout``;
identity entries stay as integers:

.. code-block:: python

    m_perm = cute.make_layout((4, 4, 2), stride=(1, 8, 4))
    tiled_mma = cute.make_tiled_mma(
        op,                                  # SM70_8x8x4 NT atom
        atom_layout_mnk=(2, 2, 1),
        permutation_mnk=(m_perm, 32, 4),     # M: scatter, N/K: identity sizes
    )

The same mechanism applies to the N and K modes: each of the three
entries can independently be an integer (identity) or a ``Layout`` (real
permutation). For warp MMAs the most common case in practice is still
the integer-only form shown earlier in this section; the ``Layout`` form
is the tool you reach for when a register or SMEM layout wants each
thread's fragment to be contiguous in logical coordinates.

Partitioning Tensors
---------------------

Before computing, partition the CTA-tiled tensors according to the
tiled MMA layout. Warp MMA partitioning is **per-thread**: each of
the 32 threads in a warp (or 128 threads across 4 warps) receives
its own slice of the data, sized to match the register fragments
the MMA instruction expects.

Example: ``GEMM (M, N, K) = (512, 512, 256)``,
``cta_tiler = (128, 128, 32)``, ``atom_layout_mnk = (2, 2, 1)``,
F16 atom = m16n8k16, ``permutation_mnk = (32, 32, 16)``,
``num_stages = 4``, 4 warps = 128 threads.

Global matrices:

.. code-block:: text

    mA: (M, K) = (512, 256)       mB: (N, K) = (512, 256)       mC: (M, N) = (512, 512)

        K=256                         K=256                           N=512
    |<--------->|                 |<--------->|                 |<---------------->|
    +-----------+                 +-----------+                 +----+----+----+---+
    |           | ^               |           | ^               |    |    |    |   | ^
    |    mA     | | M=512         |    mB     | | N=512         |    |    |    |   | | M=512
    |           | v               |           | v               |    |    |    |   | v
    +-----------+                 +-----------+                 +----+----+----+---+

Tiling with ``cta_tiler = (BM, BN, BK) = (128, 128, 32)`` gives
M/BM = 4 tiles, N/BN = 4 tiles, K/BK = 8 tiles:

.. code-block:: text

    mA tiled into (M/BM x K/BK)      mB tiled into (N/BN x K/BK)      mC tiled into (M/BM x N/BN)
    = (4 x 8) blocks                 = (4 x 8) blocks                 = (4 x 4) blocks

    BK=32   x8                       BK=32   x8                        BN=128     x4
    |<-->|                           |<-->|                           |<------>|
    +----+----+--  --+               +----+----+--  --+               +--------+--------+--  --+
    |    |    |..|   | ^ BM=128      |    |    |..|   | ^ BN=128      | (0,0)  | (0,1)  |..    | ^ BM=128
    +----+----+--  --+ v             +----+----+--  --+ v             +--------+--------+      + v
    |    |    |..|   | ^ BM=128      |    |    |..|   | ^ BN=128      | (1,0)  | (1,1)  |..    | ^ BM=128
    +----+----+--  --+ v             +----+----+--  --+ v             +--------+--------+      + v
    |    |    |..|   | ^             |    |    |..|   | ^             |  ...   |  ...   |..    | ^
    +----+----+--  --+ v             +----+----+--  --+ v             +--------+--------+--  --+ v
    |    |    |..|   | ^             |    |    |..|   | ^             | (3,0)  | (3,1)  |..    | ^
    +----+----+--  --+ v             +----+----+--  --+ v             +--------+--------+--  --+ v

Each CTA picks one (M-tile, N-tile) coordinate.
For example, CTA at ``tiler_coord = (0, 1, :)``.
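
To make the tile counts concrete, the following plain-Python sketch computes
the CTA grid, the number of k-tiles, and the row ranges of ``mA`` and ``mB``
that the CTA at ``(0, 1, :)`` reads. The helper name ``cta_tile_ranges`` is
hypothetical, and the sketch assumes the problem size is an exact multiple of
the tiler, as it is in this example.

.. code-block:: python

    def cta_tile_ranges(problem_mnk, cta_tiler, tile_coord_mn):
        """Tile counts plus the half-open row ranges one CTA reads from A and B."""
        M, N, K = problem_mnk
        BM, BN, BK = cta_tiler
        m_tile, n_tile = tile_coord_mn
        return {
            "grid_mn": (M // BM, N // BN),             # CTAs along M and N
            "k_tiles": K // BK,                        # trailing mode of gA / gB
            "a_rows": (m_tile * BM, (m_tile + 1) * BM),
            "b_rows": (n_tile * BN, (n_tile + 1) * BN),
        }


    info = cta_tile_ranges((512, 512, 256), (128, 128, 32), (0, 1))
    print(info)

For this CTA the output block is rows 0..127 and columns 128..255 of ``mC``,
and both ``gA`` and ``gB`` carry all 8 k-tiles.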

After ``local_tile``, one CTA's tile looks like this (``k = K/BK = 256/32 = 8``):

.. code-block:: text

    gA: (BM, BK, k) = (128, 32, 8)      gB: (BN, BK, k) = (128, 32, 8)      gC: (BM, BN) = (128, 128)

           BK=32                               BK=32                               BN=128
         |<----->|                           |<----->|                           |<-------->|
         +-------+--                         +-------+--                         +----------+
         |       |..                         |       |..                         |          | ^
    BM=  |  gA   |   k=8                BN=  |  gB   |   k=8                BM=  |    gC    | | 128
    128  |       |                      128  |       |                      128  |          | v
         +-------+                           +-------+                           +----------+

SMEM tensors ``sA`` and ``sB`` have a pipeline staging dimension:

.. code-block:: text

    sA: (BM, BK, PIPE) = (128, 32, 4)      sB: (BN, BK, PIPE) = (128, 32, 4)
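
As a rough sizing check, the staging buffers for this example can be
estimated in plain Python. The sketch assumes 16-bit (F16) elements and
ignores any padding, swizzle overhead, or the epilogue buffer ``sC``; the
helper name ``smem_bytes`` is hypothetical.

.. code-block:: python

    def smem_bytes(tile_rows, tile_k, num_stages, bits_per_element):
        """Bytes needed for a (rows, k, stages) staging buffer without padding."""
        return tile_rows * tile_k * num_stages * bits_per_element // 8


    sA_bytes = smem_bytes(128, 32, 4, 16)   # (BM, BK, PIPE) with F16 elements
    sB_bytes = smem_bytes(128, 32, 4, 16)   # (BN, BK, PIPE) with F16 elements

    print(sA_bytes, sB_bytes, (sA_bytes + sB_bytes) // 1024)   # 32768 32768 64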

``get_slice(tidx)`` gives each thread its own per-thread partition.
The tiled MMA footprint is ``permutation_mnk = (32, 32, 16)``, so BM,
BN, and BK are each subdivided into MMA-sized blocks:

.. code-block:: text

    sA: partition into (MMA, MMA_M, MMA_K, PIPE)

    Each SMEM stage (BM=128, BK=32):

    perm_K perm_K                          perm_M=32
     =16    =16                            |<---->|
    |<--->|<--->|                          +------+------+------+------+
    +-----+-----+ ^                        |      |      |      |      | ^
    |  0  |  1  | | perm_M=32              |  0   |  1   |  2   |  3   | | perm_N
    +-----+-----+ v                        |      |      |      |      | v =32
    |  0  |  1  | ^                        +------+------+------+------+
    |     |     | | perm_M=32              MMA_N = BN/perm_N = 4
    +-----+-----+ v
    |  0  |  1  | ^                        sB: partition into (MMA, MMA_N, MMA_K, PIPE)
    |     |     | |
    +-----+-----+ v                        gC: partition into (MMA, MMA_M, MMA_N)
    |  0  |  1  | ^
    |     |     | |
    +-----+-----+ v
    MMA_K = BK/perm_K = 2
    MMA_M = BM/perm_M = 4

After partition (per thread, e.g. thread ``tidx``):

- ``tCsA: (MMA, MMA_M, MMA_K, PIPE) = (MMA, 4, 2, 4)``, where MMA_M = BM/perm_M = 128/32 = 4 and MMA_K = BK/perm_K = 32/16 = 2
- ``tCsB: (MMA, MMA_N, MMA_K, PIPE) = (MMA, 4, 2, 4)``, where MMA_N = BN/perm_N = 128/32 = 4 and MMA_K = BK/perm_K = 32/16 = 2
- ``tCgC: (MMA, MMA_M, MMA_N) = (MMA, 4, 4)``, where MMA_M = 128/32 = 4 and MMA_N = 128/32 = 4
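
The repeat counts follow directly from the division rules stated above. The
plain-Python sketch below applies those rules as written in this guide
(``BM/perm_M``, ``BN/perm_N``, ``BK/perm_K``); the helper name
``mma_repeat_counts`` is hypothetical, and the leading ``MMA`` mode is left
symbolic because it depends on the atom's thread-value layout.

.. code-block:: python

    def mma_repeat_counts(cta_tiler, permutation_mnk):
        """Repeat counts (MMA_M, MMA_N, MMA_K) = CTA tile / tiled-MMA footprint."""
        counts = []
        for tile, perm in zip(cta_tiler, permutation_mnk):
            if tile % perm != 0:
                raise ValueError("CTA tile must be a multiple of the footprint")
            counts.append(tile // perm)
        return tuple(counts)


    mma_m, mma_n, mma_k = mma_repeat_counts((128, 128, 32), (32, 32, 16))
    num_stages = 4

    tCsA_shape = ("MMA", mma_m, mma_k, num_stages)
    tCsB_shape = ("MMA", mma_n, mma_k, num_stages)
    tCgC_shape = ("MMA", mma_m, mma_n)
    print(tCsA_shape, tCsB_shape, tCgC_shape)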

The first mode ``MMA`` contains the atom's **thread × value** layout: it
encodes which registers within a single thread hold which matrix
elements. The remaining modes are repeat counts that tile the atom
across the full CTA tile.

.. code-block:: python

    @cute.kernel
    def kernel(tiled_mma: cute.TiledMma, ...):
        tidx, _, _ = cute.arch.thread_idx()

        # CTA-tiled global tensors
        gA = cute.local_tile(mA, cta_tiler, tiler_coord, proj=(1, None, 1))
        gB = cute.local_tile(mB, cta_tiler, tiler_coord, proj=(None, 1, 1))
        gC = cute.local_tile(mC, cta_tiler, tiler_coord, proj=(1, 1, None))

        # Per-thread partition via the thread index
        thr_mma = tiled_mma.get_slice(tidx)

        # SMEM partitions (used by make_fragment_A/B and ldmatrix retiling)
        tCsA = thr_mma.partition_A(sA)   # (MMA, MMA_M, MMA_K, PIPE)
        tCsB = thr_mma.partition_B(sB)   # (MMA, MMA_N, MMA_K, PIPE)

        # C partitions for epilogue staging (SMEM) and destination (GMEM)
        tCsC = thr_mma.partition_C(sC)   # (MMA, MMA_M, MMA_N)
        tCgC = thr_mma.partition_C(gC)   # (MMA, MMA_M, MMA_N)

.. note:: The ``tCsA`` / ``tCsB`` SMEM partitions are not read directly
   by the GEMM. They establish the **shape** that
   ``make_fragment_A`` / ``make_fragment_B`` use to allocate register
   fragments. Actual SMEM→RMEM data movement goes through the S2R
   ``ldmatrix`` retiling path (see `Making Fragments`_).

Pre and Post-Conditions for Partitioning
-----------------------------------------

* The inputs of the partition must be tensors of rank 2 or higher.
* The output of the partition has a layout that is compatible with the MMA atom's operand:

  - For A, the output has the layout ``(MMA, MMA_M, MMA_K, ...)``.
  - For B, the output has the layout ``(MMA, MMA_N, MMA_K, ...)``.
  - For C, the output has the layout ``(MMA, MMA_M, MMA_N, ...)``.

* The partition does not enforce any rules on the tensor's memory space or data type. It depends only on the layout.

Making Fragments
-----------------

Fragments are the tensors that the warp MMA instruction operates on. For
warp MMA:

- **Fragment A**: per-thread register fragment holding one operand-A K-block.
- **Fragment B**: per-thread register fragment holding one operand-B K-block.
- **Fragment C (accumulator)**: per-thread register fragment that lives in
  RMEM and serves as both the input C and output D of ``cute.gemm()``.

Creating register fragments and ``ldmatrix`` copy views
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Warp MMA fragments are actual per-thread register tensors, not descriptors.
Fragment creation has three parts:

**1. A and B fragments**

``make_fragment_A`` and ``make_fragment_B`` take one stage of the
MMA-partitioned SMEM views (``tCsA`` / ``tCsB``) and allocate register
fragments with a matching thread-local layout. This establishes the shape
only; no data is loaded yet.

.. code-block:: python

    # Per-thread MMA partitions
    # (sA/sB are the staged SMEM tensors; see "Creating SMEM layouts for A and B")
    tCsA = thr_mma.partition_A(sA)   # (MMA, MMA_M, MMA_K, PIPE)
    tCsB = thr_mma.partition_B(sB)   # (MMA, MMA_N, MMA_K, PIPE)

    # Register fragments for one pipeline stage
    tCrA = tiled_mma.make_fragment_A(
        tCsA[None, None, None, 0]
    )   # (MMA, MMA_M, MMA_K)
    tCrB = tiled_mma.make_fragment_B(
        tCsB[None, None, None, 0]
    )   # (MMA, MMA_N, MMA_K)

Continuing the running example from `Partitioning Tensors`_ (F16
``m16n8k16``, ``cta_tiler = (128, 128, 32)``,
``permutation_mnk = (32, 32, 16)``, ``num_stages = 4``):

.. code-block:: text

    tCsA: (MMA, MMA_M=4, MMA_K=2, PIPE=4)
    tCsB: (MMA, MMA_N=4, MMA_K=2, PIPE=4)

    make_fragment_A(tCsA[..., stage])  ->  tCrA: (MMA, 4, 2)
    make_fragment_B(tCsB[..., stage])  ->  tCrB: (MMA, 4, 2)

Each element of ``tCrA`` / ``tCrB`` is a register value owned by the current
thread. Together, the 32 threads in the warp hold the full operand fragment
that one ``mma.sync.aligned`` instruction consumes.
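
How much of one instruction's operands each thread holds follows from
dividing each operand tile evenly across the 32 threads of the warp. The
sketch below works this out for a single ``m16n8k16`` instruction with F16
inputs and an F32 accumulator; the helper names are hypothetical, and the
register counts assume 32-bit registers with two F16 values packed per
register.

.. code-block:: python

    WARP_SIZE = 32


    def per_thread_elements(inst_shape_mnk):
        """Elements of A (M x K), B (N x K) and C (M x N) held by each thread."""
        m, n, k = inst_shape_mnk
        return {
            "A": m * k // WARP_SIZE,
            "B": n * k // WARP_SIZE,
            "C": m * n // WARP_SIZE,
        }


    def registers(num_elements, bits_per_element):
        """32-bit registers needed to hold the given number of elements."""
        return num_elements * bits_per_element // 32


    elems = per_thread_elements((16, 8, 16))
    print(elems)                                   # {'A': 8, 'B': 4, 'C': 4}
    print(registers(elems["A"], 16),               # 4 registers for A (F16)
          registers(elems["B"], 16),               # 2 registers for B (F16)
          registers(elems["C"], 32))               # 4 registers for C (F32)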

**2. C fragment (accumulator)**

``make_fragment_C`` allocates the accumulator registers for the CTA tile
slice owned by the current thread. The accumulator usually starts at zero
before the K loop and is updated in-place by each ``cute.gemm()`` call.

.. code-block:: python

    tCgC = thr_mma.partition_C(gC)   # (MMA, MMA_M, MMA_N)
    tCrC = tiled_mma.make_fragment_C(tCgC)
    tCrC.fill(0.0)

For the same running example:

.. code-block:: text

    tCgC: (MMA, MMA_M=4, MMA_N=4)
    make_fragment_C(tCgC)  ->  tCrC: (MMA, 4, 4)

``tCrC`` stays in registers for the entire main loop and serves as both the
input C and output D argument of ``cute.gemm()``.
|
||
|
||
**3. SMEM → RMEM load (ldmatrix retiling)**

The register fragments above are storage only: before ``cute.gemm()`` can
consume ``tCrA`` and ``tCrB``, each k-block must be loaded from shared
memory into those registers. This is done via a separate tiled copy built
from an ``ldmatrix`` copy atom and linked to the tiled MMA with
``make_tiled_copy_A`` / ``make_tiled_copy_B``. The copy's ``retile()``
call remaps the MMA fragment's register layout to match what the
``ldmatrix`` instruction writes.

.. code-block:: python

    # 1. Create ldmatrix copy atom → tiled copy tied to the MMA layout
    s2r_atom_A  = cute.make_copy_atom(LdMatrix8x8x16bOp(...), dtype)
    s2r_tiled_A = cute.make_tiled_copy_A(s2r_atom_A, tiled_mma)

    # 2. Build SMEM-side and RMEM-side views for the copy
    thr_s2r_A      = s2r_tiled_A.get_slice(tidx)
    tCsA_copy_view = thr_s2r_A.partition_S(sA)     # SMEM source (all pipeline stages)
    tCrA_copy_view = thr_s2r_A.retile(tCrA)        # RMEM dest (retiled)

    # 3. Load one k-block from SMEM into the MMA fragment (in the main loop).
    #    tCsA_p is the slice for the pipeline stage currently being read.
    tCsA_p = tCsA_copy_view[None, None, None, smem_pipe_read]
    cute.copy(s2r_tiled_A, tCsA_p[None, None, k_block],
              tCrA_copy_view[None, None, k_block])

See ``tensorop_gemm.py`` for the complete implementation including the
``ldmatrix`` transpose flag, FP8 variants, and operand B.


Creating SMEM layouts for A and B
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The SMEM layouts define how A and B tiles are staged in shared memory before
the ``ldmatrix`` loads. For warp MMA, these layouts must satisfy two goals at
the same time:

- **Efficient GMEM -> SMEM copy:** ``cp.async`` should write contiguous
  16-byte regions for each thread.
- **Bank-conflict-free SMEM -> RMEM load:** the later ``ldmatrix`` loads
  should see a swizzled layout that matches the warp MMA operand access
  pattern.

The Ampere dense GEMM example
(``examples/cute/ampere/kernel/dense_gemm/tensorop_gemm.py``) builds these
layouts inline with a helper named ``_make_smem_layout_AB``.

**Host side** (``@cute.jit``):

.. code-block:: python

    # 16 bytes per thread for GMEM -> SMEM copies
    ab_copy_bits = 128

    sA_layout, sA_swizzle = self._make_smem_layout_AB(
        mA.element_type,          # dtype (e.g. Float16)
        self.a_major_mode,        # row-major or col-major
        ab_copy_bits,             # copy width in bits (128 = 16 bytes)
        (self.cta_tiler[0],       # BM
         self.cta_tiler[2],       # BK
         self.num_stages),        # PIPE
    )
    sB_layout, sB_swizzle = self._make_smem_layout_AB(
        mB.element_type,
        self.b_major_mode,
        ab_copy_bits,
        (self.cta_tiler[1],       # BN
         self.cta_tiler[2],       # BK
         self.num_stages),        # PIPE
    )

Here the last argument, ``smem_tiler``, is ``(M_or_N, K, PIPE)``:
``(BM, BK, PIPE)`` for A and ``(BN, BK, PIPE)`` for B. The helper returns:

- ``sX_layout``: the logical SMEM layout with shape ``(BM_or_BN, BK, PIPE)``.
- ``sX_swizzle``: the swizzle applied when the tensor is materialized in SMEM.

The helper from ``tensorop_gemm.py`` implements the following four steps:

1. **Pick the major-mode size.** For a row-major operand, the contiguous
   dimension is K, so the helper uses ``smem_tiler[1]``. For a col-major
   operand, the contiguous dimension is M or N, so it uses ``smem_tiler[0]``.

2. **Cap the contiguous span at 128 bytes.** This keeps the layout atom within
   the swizzle span used by the example. The cap is 64 elements for F16/BF16
   and 128 elements for FP8.

3. **Build the swizzle.** With ``copy_bits = 128`` (16 bytes), the helper
   derives three arguments for ``make_swizzle``:

   - ``swizzle_bits = log2(major_mode_size * dtype.width / copy_bits)``,
     capped at 3. This is the number of address bits that get XOR'd.
   - ``base_bits = log2(copy_bits / 8)``: log2 of the copy width in
     bytes (= 4 for 16-byte copies).
   - ``shift_bits = log2(copy_bits / dtype.width)``: log2 of the copy
     width in elements (= 3 for F16 with 128-bit copies, i.e. 8 elements).

4. **Build an 8-row layout atom and tile it.** The constant 8 comes from
   ``ldmatrix``: each 8x8 matrix it loads spans 8 rows of shared memory
   (an ``.x4`` load fetches four such matrices, with each of the 32 threads
   supplying one row address). Row-major uses an atom
   ``(8, major_mode_size):(major_mode_size, 1)``, i.e. 8 rows of contiguous
   K-elements. Col-major uses
   ``(major_mode_size, 8):(1, major_mode_size)``, i.e. contiguous MN-elements
   for 8 consecutive K indices. ``tile_to_shape`` then repeats that atom to
   cover the full ``(M_or_N, K, PIPE)`` SMEM tensor.

For the running F16 example (``cta_tiler = (128, 128, 32)``,
``num_stages = 4``, ``copy_bits = 128``):

.. code-block:: text

    A operand (row-major, smem_tiler = (128, 32, 4)):
      major_mode_size = 32
      atom            = (8, 32):(32, 1)
      swizzle         = make_swizzle(2, 4, 3)
      tiled layout   -> sA: (128, 32, 4)

    B operand (col-major, smem_tiler = (128, 32, 4)):
      major_mode_size = min(128, 64) = 64
      atom            = (64, 8):(1, 64)
      swizzle         = make_swizzle(3, 4, 3)
      tiled layout   -> sB: (128, 32, 4)


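
The four steps can be checked with a small plain-Python sketch. It mirrors only
the arithmetic described above and makes no CuTe calls; ``derive_smem_params``
is a hypothetical name, not the helper from ``tensorop_gemm.py``, and the atom
is returned as a ``(shape, stride)`` pair of tuples purely for illustration.

.. code-block:: python

    import math

    def derive_smem_params(dtype_bits, row_major, copy_bits, smem_tiler):
        """Return (major_mode_size, swizzle args, atom) for one operand."""
        # Step 1: extent of the contiguous (major) mode
        major = smem_tiler[1] if row_major else smem_tiler[0]
        # Step 2: cap the contiguous span at 128 bytes
        major = min(major, 128 * 8 // dtype_bits)
        # Step 3: the three make_swizzle arguments
        swizzle_bits = min(int(math.log2(major * dtype_bits // copy_bits)), 3)
        base_bits = int(math.log2(copy_bits // 8))
        shift_bits = int(math.log2(copy_bits // dtype_bits))
        # Step 4: 8-row atom, written as (shape, stride)
        if row_major:
            atom = ((8, major), (major, 1))
        else:
            atom = ((major, 8), (1, major))
        return major, (swizzle_bits, base_bits, shift_bits), atom

    # Running example: F16 (16 bits), 128-bit copies, smem_tiler = (128, 32, 4)
    a_params = derive_smem_params(16, True, 128, (128, 32, 4))
    b_params = derive_smem_params(16, False, 128, (128, 32, 4))
    print(a_params)
    print(b_params)

For the running example this yields the same major-mode sizes, swizzle
arguments, and atoms as the listing above.
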
**Kernel side** (``@cute.kernel``):

The layout and swizzle are passed to shared-memory allocation:

.. code-block:: python

    @cute.struct
    class SharedStorageAB:
        a: cute.struct.Align[
            cute.struct.MemRange[mA.element_type, cute.cosize(sA_layout)],
            16,
        ]
        b: cute.struct.Align[
            cute.struct.MemRange[mB.element_type, cute.cosize(sB_layout)],
            16,
        ]

    sA = SharedStorageAB(storage).a.get_tensor(sA_layout, swizzle=sA_swizzle)
    sB = SharedStorageAB(storage).b.get_tensor(sB_layout, swizzle=sB_swizzle)

After allocation:

- ``sA`` has shape ``(BM, BK, PIPE)``.
- ``sB`` has shape ``(BN, BK, PIPE)``.

These are the staged SMEM tensors written by ``cp.async`` and later consumed by
``partition_A`` / ``partition_B``, ``make_fragment_A`` / ``make_fragment_B``,
and the ``ldmatrix`` copy views described in `Making Fragments`_.


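
A quick estimate of the A/B staging footprint helps when choosing
``cta_tiler`` and ``num_stages``. The sketch below is plain Python with a
hypothetical helper name; it assumes the layouts are unpadded, so that
``cute.cosize`` equals the number of elements in the ``(rows, BK, PIPE)``
tensor, and it ignores the epilogue buffer.

.. code-block:: python

    def smem_bytes(smem_tiler, dtype_bits):
        """Bytes for one staged operand, assuming an unpadded layout."""
        rows, cols, stages = smem_tiler
        return rows * cols * stages * dtype_bits // 8

    # Running example: cta_tiler = (128, 128, 32), num_stages = 4, F16
    a_bytes = smem_bytes((128, 32, 4), 16)
    b_bytes = smem_bytes((128, 32, 4), 16)
    total_kib = (a_bytes + b_bytes) / 1024
    print(a_bytes, b_bytes, total_kib)

The total can then be compared against the per-CTA shared-memory capacity of
the target GPU.
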
Executing the GEMM (Main Loop)
-------------------------------

The main loop iterates over K-tiles and, within each tile, over k-blocks
(``num_k_block = BK / perm_K``). Each k-block loads A and B from SMEM into
registers via ``ldmatrix``, then issues ``cute.gemm``.

.. code-block:: python

    tCrC.fill(0.0)

    for k_tile in range(k_tile_count):
        for k_block in cutlass.range(num_k_block, unroll_full=True):
            # Wait for next SMEM stage at the tile boundary
            # (num_smem_stages is the PIPE extent of the SMEM tensors)
            if k_block == num_k_block - 1:
                cute.arch.cp_async_wait_group(num_smem_stages - 2)
                cute.arch.sync_threads()

            # ldmatrix: prefetch next k-block from SMEM → RMEM
            k_block_next = (k_block + 1) % num_k_block
            cute.copy(tiled_copy_s2r_A, tCsA_p[None, None, k_block_next],
                      tCrA_copy_view[None, None, k_block_next])
            cute.copy(tiled_copy_s2r_B, tCsB_p[None, None, k_block_next],
                      tCrB_copy_view[None, None, k_block_next])

            # cp.async: issue GMEM → SMEM for next K-tile
            # ... (see tensorop_gemm.py for pipeline pointer management)

            # MMA: tCrC += tCrA * tCrB
            cute.gemm(tiled_mma, tCrC, tCrA[None, None, k_block], tCrB[None, None, k_block], tCrC)

    cute.arch.cp_async_wait_group(0)
    cute.arch.sync_threads()

Key points:

- ``cute.gemm`` is **synchronous**: it emits ``mma.sync.aligned``
  instructions. There is no accumulate-mode flag; the accumulator
  (``tCrC``) is always read and written.
- All operands must be in **registers** before ``cute.gemm`` is called.
  The ``ldmatrix`` copies above prefetch the next k-block into
  ``tCrA`` / ``tCrB`` from SMEM each iteration.
- The ``cp.async`` / ``cp_async_wait_group`` calls manage the GMEM→SMEM
  pipeline; see ``tensorop_gemm.py`` for predication, K-residue handling,
  and pipeline pointer management.


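
The register-level prefetch order in the loop above can be traced without a
GPU. The following plain-Python sketch (``trace_schedule`` is a hypothetical
name) records, for every inner iteration, whether the loop waits on the next
SMEM stage, which k-block ``ldmatrix`` loads, and which k-block ``cute.gemm``
consumes. It models only the index arithmetic, not the copies themselves.

.. code-block:: python

    def trace_schedule(k_tile_count, num_k_block):
        """List of (k_tile, waits_for_stage, loaded_block, computed_block)."""
        trace = []
        for k_tile in range(k_tile_count):
            for k_block in range(num_k_block):
                waits = k_block == num_k_block - 1
                k_block_next = (k_block + 1) % num_k_block
                trace.append((k_tile, waits, k_block_next, k_block))
        return trace

    # Running example: BK = 32, perm_K = 16  ->  num_k_block = 2
    trace = trace_schedule(k_tile_count=2, num_k_block=32 // 16)
    for step in trace:
        print(step)

    # Every block consumed after the first iteration was loaded one step earlier.
    software_pipelined = all(
        trace[i - 1][2] == trace[i][3] for i in range(1, len(trace))
    )
    print(software_pipelined)

The very first ``cute.gemm`` consumes k-block 0, which no loop iteration has
loaded; that is why a prologue has to prefetch it into registers before the
loop starts. On the last k-block of a tile the prefetch index wraps to 0, which
is the first k-block of the stage the loop has just waited for.
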
Complete Workflow
------------------

Putting it all together, a typical Ampere warp MMA GEMM has this structure:

**Host function** (``@cute.jit``):

.. code-block:: python

    import cutlass
    import cutlass.cute as cute

    @cute.jit
    def host_function(mA: cute.Tensor, mB: cute.Tensor, mC: cute.Tensor, stream):
        # 1. Create the MMA op and tiled MMA
        op = cute.nvgpu.warp.MmaF16BF16Op(cutlass.Float16, cutlass.Float32, (16, 8, 16))
        atom_layout_mnk = (2, 2, 1)
        permutation_mnk = (
            atom_layout_mnk[0] * 16,
            atom_layout_mnk[1] * 8 * 2,
            atom_layout_mnk[2] * 16,
        )
        tC = cute.make_layout(atom_layout_mnk)
        tiled_mma = cute.make_tiled_mma(op, tC, permutation_mnk=permutation_mnk)

        # 2. Create SMEM layouts
        ab_copy_bits = 128
        sA_layout, sA_swizzle = _make_smem_layout_AB(
            mA.element_type, a_major_mode, ab_copy_bits,
            (cta_tiler[0], cta_tiler[2], num_stages),
        )
        sB_layout, sB_swizzle = _make_smem_layout_AB(
            mB.element_type, b_major_mode, ab_copy_bits,
            (cta_tiler[1], cta_tiler[2], num_stages),
        )

        # 3. Launch the kernel
        kernel(mA, mB, mC, ..., tiled_mma, sA_layout, sA_swizzle,
               sB_layout, sB_swizzle).launch(
            grid=grid, block=[128, 1, 1], stream=stream,
        )

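
The numbers in the host function are tied together by simple arithmetic. The
plain-Python sketch below recomputes them for the running example; the
variable names are illustrative, and the divisibility check is an assumption
about what a valid configuration looks like rather than a rule quoted from the
library.

.. code-block:: python

    import math

    inst_shape_mnk = (16, 8, 16)       # m16n8k16
    atom_layout_mnk = (2, 2, 1)        # warps along M, N, K
    cta_tiler = (128, 128, 32)         # (BM, BN, BK)

    # Same expressions as in the host function above
    permutation_mnk = (
        atom_layout_mnk[0] * inst_shape_mnk[0],
        atom_layout_mnk[1] * inst_shape_mnk[1] * 2,
        atom_layout_mnk[2] * inst_shape_mnk[2],
    )

    # One warp (32 threads) per atom in the atom layout
    threads_per_cta = 32 * math.prod(atom_layout_mnk)

    # k-blocks per K-tile, as used by the main loop
    num_k_block = cta_tiler[2] // permutation_mnk[2]

    # Assumed sanity check: the CTA tile is a whole number of MMA tiles
    tiles_evenly = all(c % p == 0 for c, p in zip(cta_tiler, permutation_mnk))

    print(permutation_mnk, threads_per_cta, num_k_block, tiles_evenly)

The thread count matches the ``block=[128, 1, 1]`` launch argument, and
``permutation_mnk`` matches the ``(32, 32, 16)`` used throughout the running
example.
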
**Kernel function** (``@cute.kernel``):

.. code-block:: python

    @cute.kernel
    def kernel(mA: cute.Tensor, mB: cute.Tensor, mC: cute.Tensor,
               ..., tiled_mma: cute.TiledMma):
        tidx, _, _ = cute.arch.thread_idx()
        bidx, bidy, bidz = cute.arch.block_idx()

        # -- CTA-tiled global tensors --
        gA = cute.local_tile(mA[None, None, bidz], cta_tiler, (bidx, bidy, None), proj=(1, None, 1))
        gB = cute.local_tile(mB[None, None, bidz], cta_tiler, (bidx, bidy, None), proj=(None, 1, 1))
        gC = cute.local_tile(mC[None, None, bidz], cta_tiler, (bidx, bidy, None), proj=(1, 1, None))

        # -- Allocate SMEM --
        @cute.struct
        class SharedStorageAB:
            a: cute.struct.Align[cute.struct.MemRange[mA.element_type, cute.cosize(sA_layout)], 16]
            b: cute.struct.Align[cute.struct.MemRange[mB.element_type, cute.cosize(sB_layout)], 16]

        smem = cutlass.utils.SmemAllocator()
        storage = smem.allocate(SharedStorageAB)
        sA = SharedStorageAB(storage).a.get_tensor(sA_layout, swizzle=sA_swizzle)  # (BM, BK, PIPE)
        sB = SharedStorageAB(storage).b.get_tensor(sB_layout, swizzle=sB_swizzle)  # (BN, BK, PIPE)
        sC = ...  # (BM, BN) SMEM for epilogue (non-MMA, see tensorop_gemm.py)

        # -- GMEM → SMEM copy partitions (cp.async) --
        # ... setup tAgA, tAsA, tBgB, tBsB (see tensorop_gemm.py)

        # -- MMA partitions and fragments --
        thr_mma = tiled_mma.get_slice(tidx)
        tCsA = thr_mma.partition_A(sA)   # (MMA, MMA_M, MMA_K, PIPE)
        tCsB = thr_mma.partition_B(sB)   # (MMA, MMA_N, MMA_K, PIPE)
        tCsC = thr_mma.partition_C(sC)   # (MMA, MMA_M, MMA_N)
        tCgC = thr_mma.partition_C(gC)   # (MMA, MMA_M, MMA_N)
        tCrA = tiled_mma.make_fragment_A(tCsA[None, None, None, 0])  # (MMA, MMA_M, MMA_K)
        tCrB = tiled_mma.make_fragment_B(tCsB[None, None, None, 0])  # (MMA, MMA_N, MMA_K)
        tCrC = tiled_mma.make_fragment_C(tCgC)                       # (MMA, MMA_M, MMA_N)
        tCrC.fill(0.0)

        # -- ldmatrix retiling (see "Making Fragments" § SMEM → RMEM load) --
        # ... build tiled_copy_s2r_A/B from LdMatrix8x8x16bOp + make_tiled_copy_A/B
        # ... then: tCsA_copy_view = partition_S(sA), tCrA_copy_view = retile(tCrA), etc.

        # -- Prologue: cp.async fills num_stages-1 SMEM buffers --
        # -- Prefetch first k-block into registers via ldmatrix --
        # ... (see tensorop_gemm.py for predication, residual_k, and pipeline setup)

        # -- Main loop --
        for k_tile in range(k_tile_count):
            for k_block in cutlass.range(num_k_block, unroll_full=True):
                if k_block == num_k_block - 1:
                    cute.arch.cp_async_wait_group(num_smem_stages - 2)
                    cute.arch.sync_threads()

                # ldmatrix: prefetch next k-block from SMEM → RMEM
                # tCsA_p / tCsB_p are per-pipeline-stage slices, e.g.:
                #   tCsA_p = tCsA_copy_view[None, None, None, smem_pipe_read]
                k_block_next = (k_block + 1) % num_k_block
                cute.copy(tiled_copy_s2r_A, tCsA_p[None, None, k_block_next],
                          tCrA_copy_view[None, None, k_block_next])
                cute.copy(tiled_copy_s2r_B, tCsB_p[None, None, k_block_next],
                          tCrB_copy_view[None, None, k_block_next])

                # cp.async: issue GMEM → SMEM for next K-tile
                # ... (see tensorop_gemm.py for pipeline pointer management)

                # MMA
                cute.gemm(tiled_mma, tCrC, tCrA[None, None, k_block],
                          tCrB[None, None, k_block], tCrC)

        # -- Epilogue: RMEM → SMEM → RMEM → GMEM --
        cute.arch.cp_async_wait_group(0)
        cute.arch.sync_threads()
        tCrD = cute.make_fragment_like(tCrC, c_dtype)
        tCrD[None] = epilogue_op(tCrC.load()).to(c_dtype)
        cute.autovec_copy(tCrD, tCsC)       # RMEM → SMEM
        cute.arch.sync_threads()
        # ... reload with epilogue thread layout, then vectorized store to GMEM


Beyond Simple Dense MMAs
------------------------

The warp MMA DSL supports more complex MMA operations beyond simple dense MMA:

- Block-scaled MMA

.. {$nv-internal-release begin}

Internal builds additionally provide:

- Sparse MMA

.. {$nv-internal-release end}

.. {$nv-internal-release begin}

Sparse MMA
~~~~~~~~~~

Sparse MMA exploits **2:4 structured sparsity** in operand A: out of every
4 consecutive K-elements, exactly 2 are non-zero. The hardware consumes a
compressed A operand together with a compact **metadata** tensor ``E`` that
encodes which 2 of 4 positions are non-zero.

Compared to dense MMA, the MMA API differences are:

**1. MMA op creation**: use ``MmaF16BF16SparseOp`` with an extra
``sparse_metadata_format`` parameter. The sparse instruction K is doubled
relative to dense (dense ``m16n8k8`` → sparse ``m16n8k16``, dense
``m16n8k16`` → sparse ``m16n8k32``) because operand A is 2:4 compressed:

.. code-block:: python

    from cutlass.cute.nvgpu.warp.mma import SparseMetadataFormat

    # Dense F16 (for comparison): inst_K = 16
    dense_op = cute.nvgpu.warp.MmaF16BF16Op(
        cutlass.Float16, cutlass.Float32, (16, 8, 16),
    )

    # Sparse F16: inst_K = 32 (2× dense, since A is 2:4 compressed)
    sparse_op = cute.nvgpu.warp.MmaF16BF16SparseOp(
        cutlass.Float16,                    # A/B element type
        cutlass.Float32,                    # accumulator type
        (16, 8, 32),                        # instruction shape (M, N, K)
        SparseMetadataFormat.TID,           # metadata format
    )
    tiled_mma = cute.make_tiled_mma(sparse_op, cute.make_layout((1, 1, 1)))

.. code-block:: text

    Supported instruction shapes for MmaF16BF16SparseOp:

    | A/B Type | Acc Type | Inst Shape           |
    |----------|----------|----------------------|
    | F16      | F16, F32 | (16,8,16), (16,8,32) |
    | BF16     | F32      | (16,8,16), (16,8,32) |

**2. Compressed A tensor and metadata E**: operand A stores only the
two non-zero values per group of 4 K-elements (half the storage). The
metadata tensor ``E`` records which 2 of 4 positions are non-zero. The
exact bit encoding depends on ``SparseMetadataFormat`` and on how the
implementation packs metadata. In this repository, helper code that
generates 2:4 test inputs packs two 4-bit metadata entries into each
``uint8`` value:

.. code-block:: python

    # Example metadata values used by examples/CuTeDSL/helpers/sparse_utils.py
    # Each nibble selects which 2 of 4 positions are non-zero.
    metadata_values = [0x4, 0x8, 0x9, 0xC, 0xD, 0xE]

.. code-block:: text

    Dense A: (M, K)                 Sparse operands:
    +--+--+--+--+--+--+--+--+       +--+--+--+--+
    | a| 0| b| 0| c| 0| d| 0|   →   | a| b| c| d|   (compressed A values)
    +--+--+--+--+--+--+--+--+       +--+--+--+--+

                                    E stores the non-zero positions
                                    for each 2:4 group.

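
To make the compression concrete, here is a plain-Python sketch that splits one
dense row into its kept values and one metadata nibble per group of four.
``compress_2to4`` is a hypothetical helper, not code from ``sparse_utils.py``.
The nibble encoding (two 2-bit position indices, lower index in the low bits)
is an assumption; it was chosen because it reproduces exactly the six metadata
values listed above, but the format the hardware expects should be confirmed
against ``SparseMetadataFormat``.

.. code-block:: python

    def compress_2to4(row):
        """Return (kept values, metadata nibbles) for a 2:4-sparse row."""
        assert len(row) % 4 == 0, "K must be a multiple of 4"
        values, nibbles = [], []
        for start in range(0, len(row), 4):
            group = row[start:start + 4]
            keep = [i for i, v in enumerate(group) if v != 0]
            assert 2 >= len(keep), "group violates 2:4 sparsity"
            # Groups with fewer than two non-zeros still keep two positions.
            for i in range(4):
                if len(keep) == 2:
                    break
                if i not in keep:
                    keep.append(i)
            lo, hi = sorted(keep)
            values += [group[lo], group[hi]]
            nibbles.append(lo + hi * 4)  # assumed encoding: two 2-bit indices
        return values, nibbles

    # The row from the diagram, with a, b, c, d = 1, 2, 3, 4
    values, nibbles = compress_2to4([1, 0, 2, 0, 3, 0, 4, 0])
    print(values, [hex(n) for n in nibbles])

Both groups keep positions 0 and 2, so under the assumed encoding each group
maps to the nibble ``0x8``.
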
**3. Fragments**: the dense-style fragment APIs for A, B, and C still
apply to the sparse atom:

.. code-block:: python

    # A/B/C fragments: same public API shape as dense
    tCsA = thr_mma.partition_A(sA)
    tCsB = thr_mma.partition_B(sB)
    tCgC = thr_mma.partition_C(gC)

    tCrA = tiled_mma.make_fragment_A(tCsA[None, None, None, 0])
    tCrB = tiled_mma.make_fragment_B(tCsB[None, None, None, 0])
    tCrC = tiled_mma.make_fragment_C(tCgC)
    tCrC.fill(0.0)

Sparse metadata ``E`` is an auxiliary operand associated with A. The
public warp API and tests in this repository verify op construction and
the ``cute.gemm(..., [A, E], B, ...)`` calling convention, but they do
not provide an end-to-end warp sparse kernel showing the exact
``partition`` / ``copy`` / ``make_fragment`` sequence for ``E``. For
that reason, this document intentionally does not spell out an ``E``
fragment construction sequence that has no example backing it.

**4. Modified gemm call**: the metadata E is passed alongside operand A
as a list. This part of the API is verified by ``cutlass.cute.algorithm.gemm``:

.. code-block:: python

    # Schematic only: E_k is the metadata operand for the same k-slice as A_k.
    A_k = tCrA[None, None, k_block]
    E_k = metadata_k
    B_k = tCrB[None, None, k_block]

    cute.gemm(
        tiled_mma,
        tCrC,
        [A_k, E_k],   # [A, E]
        B_k,
        tCrC,
    )

.. code-block:: text

    Dense gemm call:
      cute.gemm(tiled_mma, tCrC, A_k, B_k, tCrC)

    Sparse gemm call:
      cute.gemm(tiled_mma, tCrC, [A_k, E_k], B_k, tCrC)
                                  ^^^  ^^^
                                  A    metadata

The epilogue (RMEM → SMEM → GMEM) is identical to a dense kernel.

.. note:: An end-to-end warp sparse GEMM example is not yet available in the
   examples directory. The closest verified references in this repository are
   ``cutlass_ir/compiler/test/python/not_pytest/sm_80/test_mma_atom.py`` for
   op construction, ``cutlass_ir/compiler/test/python/api/sm_120a/test_nvgpu_warp_mma.py``
   for tiled sparse MMA construction, and
   ``examples/CuTeDSL/helpers/sparse_utils.py`` for
   2:4 metadata packing.

.. {$nv-internal-release end}


Block-scaled MMA
~~~~~~~~~~~~~~~~

Block-scaled MMA multiplies narrow-type matrices (FP4) while applying
**per-block scale factors** along the GEMM-K dimension. Each vector of
``sf_vec_size`` consecutive K-elements shares a single scale factor, so the
hardware computes ``D = (SFA · A) * (SFB · B) + C``. The scale factors live
in **registers** alongside the operands and must be loaded from SMEM before
each ``gemm`` call.

Supported ops: ``MmaMXF4Op`` (SM120a+), ``MmaMXF4NVF4Op`` (SM120a+).

Compared to a dense MMA kernel, a block-scaled kernel has four additional concerns:

**1. MMA op creation**: block-scaled ops fix the data type to FP4
(E2M1) and the accumulator to FP32. The scale-factor type and vector
size distinguish the two ops:

.. code-block:: python

    # MXF4: UE8M0 scales, sf_vec_size = 32
    op = cute.nvgpu.warp.MmaMXF4Op(
        cutlass.Float4E2M1FN,    # A/B element type (fixed: E2M1)
        cutlass.Float32,         # accumulator type (fixed: F32)
        cutlass.Float8E8M0FNU,   # scale-factor type
    )  # instruction shape = (16, 8, 64), sf_vec_size = 32

    # MXF4NVF4: UE4M3 scales, sf_vec_size = 16
    op = cute.nvgpu.warp.MmaMXF4NVF4Op(
        cutlass.Float4E2M1FN,    # A/B element type (fixed: E2M1)
        cutlass.Float32,         # accumulator type (fixed: F32)
        cutlass.Float8E4M3FN,    # scale-factor type
    )  # instruction shape = (16, 8, 64), sf_vec_size = 16

.. code-block:: text

    | Op            | A/B Type | SF Type | Acc  | Inst Shape  | SF Vec Size |
    |---------------|----------|---------|------|-------------|-------------|
    | MmaMXF4Op     | E2M1     | UE8M0   | F32  | (16,8,64)   | 32          |
    | MmaMXF4NVF4Op | E2M1     | UE4M3   | F32  | (16,8,64)   | 16          |

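
The formula ``D = (SFA · A) * (SFB · B) + C`` can be written out as a scalar
reference for a single output element. The sketch below is plain Python and
uses ordinary floats in place of the FP4 values and 8-bit scale factors;
``blockscaled_dot`` is a hypothetical name. It illustrates only where each
scale factor applies along K and makes no claim about hardware rounding.

.. code-block:: python

    def blockscaled_dot(a, sfa, b, sfb, c, sf_vec_size):
        """One output element: sum_k (sfa * a[k]) * (sfb * b[k]) + c."""
        acc = c
        for k in range(len(a)):
            blk = k // sf_vec_size          # K-elements in a block share a scale
            acc += (sfa[blk] * a[k]) * (sfb[blk] * b[k])
        return acc

    INST_K = 64  # K extent of both block-scaled instructions

    # Scale factors consumed per row (or column) by one instruction
    sf_per_inst = {
        "MmaMXF4Op": INST_K // 32,
        "MmaMXF4NVF4Op": INST_K // 16,
    }

    # Toy data: K = 64, sf_vec_size = 32, so two scale factors per operand row
    a = [1.0] * INST_K
    b = [2.0] * INST_K
    d = blockscaled_dot(a, [0.5, 2.0], b, [1.0, 1.0], c=1.0, sf_vec_size=32)
    print(sf_per_inst, d)

With these toy inputs the first 32 products are each scaled by 0.5 and the
last 32 by 2.0, which is the per-block behaviour the ``sf_vec_size`` column in
the table describes.
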
**2. Extra global tensors and SMEM layouts for scale factors**: the host
function creates SFA/SFB tensors and allocates SMEM layouts for them
alongside A and B:

.. code-block:: python

    import cutlass.utils.blockscaled_layout as blockscaled_utils
    import cutlass.utils.blackwell_helpers as sm120_utils

    # Scale-factor global tensors (host side)
    sfa_layout = blockscaled_utils.tile_atom_to_shape_SF(a.shape, sf_vec_size)
    sfa_tensor = cute.make_tensor(sfa.iterator, sfa_layout)
    sfb_layout = blockscaled_utils.tile_atom_to_shape_SF(b.shape, sf_vec_size)
    sfb_tensor = cute.make_tensor(sfb.iterator, sfb_layout)

    # SMEM layouts for scale factors (SM120-specific helper)
    sfa_smem_layout = blockscaled_utils.sm120_make_smem_layout_sfa(
        tiled_mma, tile_shape_mnk, sf_vec_size, num_stages,
    )
    sfb_smem_layout = blockscaled_utils.sm120_make_smem_layout_sfb(
        tiled_mma, tile_shape_mnk, sf_vec_size, num_stages,
    )

**3. SF fragment creation and SMEM→RMEM retiling**: scale-factor
fragments use a ``CopyUniversalOp`` with thread-value layouts derived
from the tiled MMA, rather than the ``ldmatrix``-based path used for
A and B:

.. code-block:: python

    # A/B fragments (same as dense)
    tCrA = tiled_mma.make_fragment_A(tCsA[None, None, None, 0])
    tCrB = tiled_mma.make_fragment_B(tCsB[None, None, None, 0])

    # SF fragments (SM120-specific partition helpers)
    tCrSFA = sm120_utils.partition_fragment_SFA(sSFA[None, None, 0], thr_mma, tidx)
    tCrSFB = sm120_utils.partition_fragment_SFB(sSFB[None, None, 0], thr_mma, tidx)

    # A/B: ldmatrix retiling (same as dense)
    atom_copy_A = cute.make_copy_atom(cute.nvgpu.warp.LdMatrix8x8x16bOp(...), a_dtype)
    smem_tiled_copy_A = cute.make_tiled_copy_A(atom_copy_A, tiled_mma)

    # SF: CopyUniversal with SF-specific thread-value layout
    atom_copy_SF = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), sf_dtype)
    smem_tiled_copy_SFA = cute.make_tiled_copy(
        atom_copy_SF,
        sm120_utils.get_layoutSFA_TV(tiled_mma),
        (cute.size(tiled_mma.permutation_mnk[0]), cute.size(tiled_mma.permutation_mnk[2])),
    )
    smem_tiled_copy_SFB = cute.make_tiled_copy(
        atom_copy_SF,
        sm120_utils.get_layoutSFB_TV(tiled_mma),
        (cute.size(tiled_mma.permutation_mnk[1]), cute.size(tiled_mma.permutation_mnk[2])),
    )

**4. Modified main loop**: each k-block loads A, B, SFA, and SFB from
SMEM into registers. The ``cute.gemm`` call passes ``[A, SFA]`` and
``[B, SFB]`` as operand lists:

.. code-block:: python

    for k_block_idx in cutlass.range(num_k_blocks, unroll_full=True):
        # Index of the k-block to prefetch (same scheme as the dense loop)
        k_block_next = (k_block_idx + 1) % num_k_blocks

        # ldmatrix: load A and B from SMEM → RMEM (same as dense)
        cute.copy(smem_tiled_copy_A, tCsA_p[None, None, k_block_next],
                  tCrA_copy_view[None, None, k_block_next])
        cute.copy(smem_tiled_copy_B, tCsB_p[None, None, k_block_next],
                  tCrB_copy_view[None, None, k_block_next])

        # CopyUniversal: load SFA and SFB from SMEM → RMEM        # NEW
        cute.copy(smem_tiled_copy_SFA,
                  cute.filter_zeros(tCsSFA_p)[None, None, k_block_next],
                  cute.filter_zeros(tCrSFA_copy_view)[None, None, k_block_next])
        cute.copy(smem_tiled_copy_SFB,
                  cute.filter_zeros(tCsSFB_p)[None, None, k_block_next],
                  cute.filter_zeros(tCrSFB_copy_view)[None, None, k_block_next])

        # MMA with scale factors passed as [value, scale] pairs
        cute.gemm(
            tiled_mma,
            accumulators,
            [tCrA[None, None, k_block_idx], tCrSFA[None, None, k_block_idx]],  # [A, SFA]
            [tCrB[None, None, k_block_idx], tCrSFB[None, None, k_block_idx]],  # [B, SFB]
            accumulators,
        )

.. code-block:: text

    Dense gemm call:
      cute.gemm(tiled_mma, acc, tCrA[k], tCrB[k], acc)

    Block-scaled gemm call:
      cute.gemm(tiled_mma, acc, [tCrA[k], tCrSFA[k]], [tCrB[k], tCrSFB[k]], acc)
                                 ^^^^^^^  ^^^^^^^^^    ^^^^^^^  ^^^^^^^^^
                                 value    scale        value    scale
                                 (RMEM)   (RMEM)       (RMEM)   (RMEM)

Note that ``cute.filter_zeros`` is applied to the SF copy views because
the scale-factor layouts contain stride-0 (broadcast) modes: every group of
``sf_vec_size`` K-elements maps to the same scale factor. ``filter_zeros``
collapses those stride-0 modes to size 1, so the copy moves each scale
factor once instead of once per K-element.

The epilogue (RMEM → SMEM → GMEM) is identical to a dense kernel.


See also:

- Dense GEMM example (Ampere): ``examples/cute/ampere/kernel/dense_gemm/tensorop_gemm.py``
- Block-scaled GEMM example (SM120a): ``examples/cute/blackwell_geforce/kernel/blockscaled_gemm/dense_blockscaled_gemm_persistent_pingpong.py``
- Block-scaled layout utilities: ``cutlass.utils.blockscaled_layout``
- SM120 helper utilities: ``cutlass.utils.blackwell_helpers``