.. _tcgen05_programming:

tcgen05 MMA Programming Guide
==============================

Blackwell (SM100) introduces the **tcgen05** family of PTX instructions: the
5th-generation Tensor Core MMA (matrix multiply-accumulate) operations. They
compute ``D = A * B + C`` with 2x–4x the throughput of Hopper's WGMMA
instructions, depending on data type.

Key architectural characteristics:

* **Tensor Memory (TMEM):** A new on-chip memory dedicated to the accumulator
  (and, optionally, operand A). tcgen05 MMA reads and writes the accumulator
  in TMEM directly, freeing the register file for other work.
* **Single-thread launch:** Only one thread issues the MMA instruction.
* **CTA-pair cooperation:** Two adjacent CTAs can jointly execute a single
  MMA, doubling the tile size without extra synchronization logic.

This guide shows how to program these operations through the CuTe Python
DSL, using SMEM for operands A and B, and TMEM for the accumulator.
|
||
|
||
.. contents:: **Contents**
   :local:
   :depth: 2
|
||
|
||
Global Memory (GMEM) to MMA data flow overview
----------------------------------------------

tcgen05 MMA instructions require operand A to be staged in Shared Memory (SMEM) or Tensor Memory (TMEM),
and operand B to be staged in SMEM. The accumulator is always stored in TMEM.

The diagram below traces the full data flow of a tcgen05 GEMM kernel for the most
common case, where the A and B matrices are stored in GMEM and the output matrix, read from TMEM,
is written back to GMEM.

The flow has three parallel tracks, one each for operands A, B, and C/D. Each track has
two sub-tracks: one that copies data between memory spaces and one that feeds
MMA execution.
|
||
|
||
- **Operand A** (and symmetrically **Operand B**):

  - First, create the SMEM tensors for the A and B matrices: ``sA`` and ``sB``. These
    are physically allocated tensors that serve as both the destination of the copy and
    the source operands of the MMA instructions.
  - Next, the **data copy flow** creates the tensor views for copying data from GMEM to SMEM.
    It starts with the ``mA`` tensor, which represents matrix A in global memory.
    ``mA`` → ``local_tile`` → ``gA`` creates the local tile view of A: the
    slice of the A matrix needed to compute the given MMA output tile.
    ``gA`` → ``partition_A`` → ``tCgA`` partitions the MMA-sized tile into the smaller tiles
    that each CTA cooperating on the MMA must copy to SMEM
    (1-CTA vs. 2-CTA pair MMA cases).
    Then ``tma_partition`` produces the TMA views ``tAsA`` and ``tAgA``, and the loop copies tiles from
    GMEM into SMEM via ``copy(tma, tAgA[k], tAsA[stage])``.
  - In parallel, the **MMA flow** turns the SMEM tensors into iterable tensors of SMEM descriptors for the MMA instructions:
    ``sA`` (the same shared-memory allocation written by TMA) → ``make_fragment_A`` → ``tCrA``,
    which is passed to ``cute.gemm()``. The SMEM descriptors are views of
    the SMEM tensor in a form that the MMA instructions can interpret.

- **Accumulator C/D**:

  - **TMEM accumulator flow** (gemm input/output): ``make_fragment_C(MMA_partition_shape_C)`` → ``tCtAcc``,
    which serves as the accumulator input/output of ``cute.gemm()`` (and of the MMA instruction).
  - **Output flow** (GMEM destination): ``mC`` → ``local_tile`` → ``gC``
    → ``partition_C`` → ``tCgC`` creates the tensor views that will be stored to GMEM.
    In the epilogue, LDTM loads the results from TMEM into registers and a final store writes them
    to global memory.
|
||
|
||
.. code-block:: text
|
||
|
||
Operand A Dataflow Path Operand B Dataflow Path Accumulator C/D Dataflow Path
|
||
─────────────────────── ─────────────────────── ─────────────────────────────
|
||
|
||
mA: (M, K) [GMEM] mB: (N, K) [GMEM] ┌──── TMEM ──────────┐
|
||
│ │ │ partition_shape_C()│
|
||
│ local_tile(mA, mma_tiler, coord) │ local_tile(mB, mma_tiler, coord) │ make_fragment_C() │
|
||
▼ ▼ │ bind to tmem_ptr │
|
||
gA: (BM, BK, k) [GMEM] gB: (BN, BK, k) [GMEM] └───────┬────────────┘
|
||
│ │ │
|
||
│ thr_mma.partition_A(gA) │ thr_mma.partition_B(gB) tCtAcc:(MMA,MMA_M,MMA_N) [TMEM]
|
||
▼ ▼ │
|
||
tCgA:(MMA,MMA_M, [GMEM] tCgB:(MMA,MMA_N, [GMEM] │
|
||
MMA_K,k) MMA_K,k) │
|
||
│ │ │ mC: (M, N) [GMEM]
|
||
│ ┌──── SMEM ─────────┐ │ ┌──── SMEM ─────────┐ │ │
|
||
│ │ sA = alloc(layout)│ │ │ sB = alloc(layout)│ │ │ local_tile
|
||
│ └──┬────────┬───────┘ │ └──┬────────┬───────┘ │ ▼
|
||
│ │ │ │ │ │ │ gC: (BM, BN) [GMEM]
|
||
│ │ make_fragment_A(sA) │ │ make_fragment_B(sB) │ │ partition_C
|
||
│ │ │ │ │ │ │ ▼
|
||
│ │ ▼ │ │ ▼ │ tCgC:(MMA,MMA_M,
|
||
│ │ tCrA:(MMA,MMA_M, │ │ tCrB:(MMA,MMA_N, │ MMA_N)
|
||
│ │ MMA_K,STAGE) │ │ MMA_K,STAGE) │ [GMEM] (epi dest)
|
||
│ │ [SMEM descriptors] │ │ [SMEM descriptors] │ │
|
||
│ │ └─────────────┐ │ │ └─────────────┐ │ │
|
||
╰─────┤ │ ╰─────┤ │ │ │
|
||
▼ │ ▼ │ │ │
|
||
tma_partition(tma, │ tma_partition(tma, │ │ │
|
||
sA, tCgA) │ sB, tCgB) │ │ │
|
||
→ tAsA, tAgA │ → tBsB, tBgB │ │ │
|
||
▼ │ ▼ │ │ │
|
||
┌───┴────────────────────┐ │ ┌──────┴─────────────────┐│ │ │
|
||
│ TMA copy loop (A path):│ │ │ TMA copy loop (B path):││ │ │
|
||
│ copy(tma, tAgA[k], │ │ │ copy(tma, tBgB[k], ││ │ │
|
||
│ tAsA[stage]) │ │ │ tBsB[stage]) ││ │ │
|
||
┌─▶│ (writes into sA; │ │ ┌──▶│ (writes into sB; ││ │ │
|
||
│ │ tCrA reads same sA) │ │ │ │ tCrB reads same sB) ││ │ │
|
||
│ │ repeat for next k/stage│ │ │ │ repeat for next k/stage││ │ │
|
||
│ └────────────────────────┘ │ │ └────────────────────────┘│ │ │
|
||
│ │ │ │ │ │ │ │
|
||
└────────┘ ▼ └─────────┘ ▼ ▼ │
|
||
└───────┬───────────────────────────────┴──────────────────┘ │
|
||
│ │
|
||
▼ │
|
||
┌──────────────────────────────────────────────┐ │
|
||
│ GEMM Loop: | │
|
||
| cute.gemm(tiled_mma, │ │
|
||
│ tCtAcc, D (output), │ │
|
||
┌──▶ │ tCrA[stage], A (SMEM desc -> sA), │ │
|
||
│ │ tCrB[stage], B (SMEM desc -> sB), │ │
|
||
│ │ tCtAcc) C (accumulator input) │ │
|
||
│ └──────────────────────────────────────────────┘ │
|
||
│ │ │ │
|
||
└───────┘ | │
|
||
▼ │
|
||
Epilogue: │
|
||
t2r = make_tmem_copy(LdOp, tCtAcc) │
|
||
tTR_tAcc = t2r.partition_S(tCtAcc) │
|
||
tTR_gC = t2r.partition_D(tCgC) ◀────────────────────────────────┘
|
||
tTR_rAcc = make_rmem_tensor(...)
|
||
│
|
||
▼
|
||
LDTM: copy(t2r, tTR_tAcc, tTR_rAcc)
|
||
[TMEM → RMEM]
|
||
│
|
||
▼
|
||
Store: copy(atom, tTR_rAcc, tTR_gC)
|
||
[RMEM → GMEM]
|
||
|
||
|
||
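The control flow of the copy loop and the GEMM loop can be modeled on the host in plain
Python. The sketch below is not CuTe DSL code and issues no tensor-core instruction. It
only mirrors the structure of the diagram, assuming a hypothetical two-stage ring buffer
(``NUM_STAGES``) and using Python lists in place of GMEM, SMEM, and TMEM. Every name in it
is illustrative.

.. code-block:: python

   # Pure-Python model of the staged K-loop; no GPU or CuTe DSL required.
   # All names (gemm_tile, smem_a, NUM_STAGES, ...) are hypothetical.

   NUM_STAGES = 2  # assumed depth of the SMEM ring buffer


   def gemm_tile(a, b, bk):
       """Compute C = A * B^T for one output tile, with A: (BM, K) and B: (BN, K)."""
       bm, bn, k_total = len(a), len(b), len(a[0])
       num_k_tiles = k_total // bk

       # "SMEM": one slot per stage, each holding a (rows, BK) slice.
       smem_a = [None] * NUM_STAGES
       smem_b = [None] * NUM_STAGES
       # "TMEM": the accumulator persists across the whole K-loop.
       acc = [[0.0] * bn for _ in range(bm)]

       for k in range(num_k_tiles):
           stage = k % NUM_STAGES
           # Copy step, standing in for copy(tma, tAgA[k], tAsA[stage]).
           smem_a[stage] = [row[k * bk:(k + 1) * bk] for row in a]
           smem_b[stage] = [row[k * bk:(k + 1) * bk] for row in b]
           # MMA step, standing in for cute.gemm(..., tCrA[stage], tCrB[stage], ...).
           for i in range(bm):
               for j in range(bn):
                   acc[i][j] += sum(
                       x * y for x, y in zip(smem_a[stage][i], smem_b[stage][j])
                   )

       # Epilogue: the accumulator is read back and returned.
       return acc


   a = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]  # (BM=2, K=4)
   b = [[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0], [1.0, 1.0, 1.0, 1.0]]  # (BN=3, K=4)
   c = gemm_tile(a, b, bk=2)
   print(c)

A real kernel overlaps the copy and MMA steps asynchronously across stages; this model runs
them back to back, so it shows only which buffer each step reads and writes.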
**Naming convention:**
|
||
|
||
* ``mma_tiler_mnk`` = ``(BM, BN, BK)`` — per-CTA (or per-CTA-pair) MMA tile
|
||
* ``mX`` = a global tensor, such as ``mA``, ``mB``, ``mC``
|
||
* ``gX`` = MMA-tiler tiled GMEM slice, e.g. ``(BM, BK, k)`` for A
|
||
* ``tCgX`` = CTA-partitioned GMEM tensor, e.g. ``(MMA, MMA_M, MMA_K, k)`` for A
|
||
* ``sX`` = SMEM allocation (``sA``, ``sB``)
|
||
* ``tCrX`` = SMEM-descriptor MMA fragment, e.g. ``(MMA, MMA_M, MMA_K, STAGE)`` for A
|
||
* ``tCtX`` = TMEM tensor; ``tCtAcc`` = TMEM accumulator ``(MMA, MMA_M, MMA_N[, ACC_STAGE])``
|
||
* ``tAsA`` / ``tBsB`` = TMA-partitioned SMEM views of A / B
|
||
* ``tAgA`` / ``tBgB`` = TMA-partitioned GMEM views of A / B
|
||
* ``tTR_*`` = T2R (TMEM→RMEM) partitioned tensors used in the epilogue
|
||
|
||
|
||
Setting up the TiledMMA, MMA Ops
---------------------------------

As shown in the data flow overview, CuTe DSL provides utilities to tile and partition
the global memory tensors and to create fragment views of the SMEM and TMEM tensors for MMA instructions.

To use these utilities, we first need to set up the MMA op and the TiledMMA.
|
||
|
||
Creating a tcgen05 MMA Op
~~~~~~~~~~~~~~~~~~~~~~~~~~

A tcgen05 MMA op describes the hardware instruction to use. Its parameters include the
data types, instruction shape, CTA group, operand A source (SMEM or TMEM),
and operand major modes.
|
||
|
||
.. code-block:: python

   import cutlass
   import cutlass.cute as cute
   from cutlass.cute.nvgpu import tcgen05, OperandMajorMode

   op = tcgen05.MmaF16BF16Op(
       cutlass.Float16,             # A/B element type
       cutlass.Float32,             # accumulator type
       (128, 256, 16),              # instruction shape (M, N, K)
       tcgen05.CtaGroup.ONE,        # CTA group
       tcgen05.OperandSource.SMEM,  # A operand from shared memory
       OperandMajorMode.K,          # A is K-major
       OperandMajorMode.K,          # B is K-major
   )
|
||
|
||
The key parameters are:

- **Instruction shape** ``(M, N, K)``: determines the size of one hardware MMA
  instruction. Larger M and N amortize instruction overhead.
- **OperandSource**: ``SMEM`` reads A from a shared memory descriptor; ``TMEM``
  reads A directly from tensor memory.
- **OperandMajorMode**: ``K`` for K-major (default), ``MN`` for the transposed (MN-major) layout.
  An MN-major A requires ``a_src=SMEM``; when ``a_src=TMEM``, A is always K-major.
|
||
|
||
|
||
CuTe DSL provides implementations of many tcgen05 MMA ops:
|
||
|
||
.. list-table:: tcgen05 MMA ops
   :header-rows: 1
   :widths: 30 24 46

   * - PTX name
     - Python class
     - Constructor parameters
   * - ``tcgen05.mma.cta_group::{cg}.kind::tf32``
     - ``tcgen05.MmaTF32Op``
     - ``instruction_shape, cta_group, a_src, a_major_mode, b_major_mode``
   * - ``tcgen05.mma.cta_group::{cg}.kind::f16``
     - ``tcgen05.MmaF16BF16Op``
     - ``ab_dtype, acc_dtype, instruction_shape, cta_group, a_src, a_major_mode, b_major_mode``
   * - ``tcgen05.mma.cta_group::{cg}.kind::i8``
     - ``tcgen05.MmaI8Op``
     - ``ab_dtype, instruction_shape, cta_group, a_src, a_major_mode, b_major_mode``
   * - ``tcgen05.mma.cta_group::{cg}.kind::f8f6f4``
     - ``tcgen05.MmaF8F6F4Op``
     - ``a_dtype, b_dtype, acc_dtype, instruction_shape, cta_group, a_src, a_major_mode, b_major_mode``
   * - ``tcgen05.mma.cta_group::{cg}.kind::mxf8f6f4.block_scale``
     - ``tcgen05.MmaMXF8F6F4Op``
     - ``a_dtype, b_dtype, instruction_shape, cta_group, a_src, a_major_mode, b_major_mode``
   * - ``tcgen05.mma.cta_group::{cg}.kind::mxf4.block_scale``
     - ``tcgen05.MmaMXF4Op``
     - ``instruction_shape, cta_group, a_src``
   * - ``tcgen05.mma.cta_group::{cg}.kind::mxf4nvf4.block_scale``
     - ``tcgen05.MmaMXF4NVF4Op``
     - ``sf_dtype, instruction_shape, cta_group, a_src``
|
||
|
||
|
||
Creating a Tiled MMA
~~~~~~~~~~~~~~~~~~~~~

A ``TiledMma`` tiles the MMA atom across the thread block. You can pass the op
directly or create an explicit atom first.

.. code-block:: python

   # Option 1: directly from op (common shorthand)
   tiled_mma = cute.make_tiled_mma(op)

   # Option 2: explicit atom creation
   atom = cute.make_mma_atom(op)
   tiled_mma = cute.make_tiled_mma(atom)
|
||
|
||
|
||
Spatial tiling with a repeat count (using ``atom_layout_mnk``)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

A repeat tuple ``(M_rep, N_rep, K_rep)`` replicates the atom across the M, N, and K dimensions, producing
a larger tiled MMA that covers a bigger CTA tile.

.. code-block:: python

   tiled_mma = cute.make_tiled_mma(atom, (2, 2, 1))

The position of each atom can be thought of as a 4D coordinate ``(v, m, n, k)``:
``v`` identifies the CTA within a single MMA (always 0 for ``CtaGroup.ONE``; 0 or 1 for ``CtaGroup.TWO``),
and ``m``, ``n``, and ``k`` are the atom's repeat indices along the M, N, and K dimensions.
|
||
|
||
.. code-block:: text
|
||
|
||
MMA Atom CtaGroup.ONE make_tiled_mma(atom, (2, 2, 1))
|
||
+---------------+ +----------------+----------------+
|
||
| | | | | ^
|
||
| 128 x 256 | | Atom (0,0,0,0) | Atom (0,0,1,0) | |
|
||
| x 16 | --(2,2,1)--> | 128 x 256 | 128 x 256 | | 2 x M_atom
|
||
| | repeat | x 16 | x 16 | | = 256
|
||
| | | | | |
|
||
+---------------+ +----------------+----------------+ |
|
||
| | | |
|
||
| Atom (0,1,0,0) | Atom (0,1,1,0) | |
|
||
| 128 x 256 | 128 x 256 | |
|
||
| x 16 | x 16 | |
|
||
| | | v
|
||
+----------------+----------------+
|
||
<---- 2 x N_atom = 512 -------->
|
||
K unchanged = 16
|
||
|
||
.. code-block:: text
|
||
|
||
MMA Atom CtaGroup.TWO make_tiled_mma(atom, (2, 2, 1))
|
||
+---------------+ +----------------+----------------+
|
||
| CTA v = 0 | | Atom (0,0,0,0) | Atom (0,0,1,0) | ^
|
||
| 128 x 256 | | 128 x 256 | 128 x 256 | |
|
||
| x 16 | | x 16 | x 16 | | 2CTA Atom
|
||
+...............+ +................+................+ |
|
||
| CTA v = 1 | --(2,2,1)--> | Atom (1,0,0,0) | Atom (1,0,1,0) | |
|
||
| 128 x 256 | repeat | 128 x 256 | 128 x 256 | |
|
||
| x 16 | | x 16 | x 16 | v
|
||
+---------------+ +----------------+----------------+
|
||
| Atom (0,1,0,0) | Atom (0,1,1,0) | ^
|
||
| 128 x 256 | 128 x 256 | |
|
||
| x 16 | x 16 | | 2CTA Atom
|
||
+................+................+ |
|
||
| Atom (1,1,0,0) | Atom (1,1,1,0) | |
|
||
| 128 x 256 | 128 x 256 | |
|
||
| x 16 | x 16 | v
|
||
+----------------+----------------+
|
||
<---- 2 x N_atom = 512 -------->
|
||
Per CTA: 2 x M_atom = 256
|
||
Cluster M (v*m*128): 512
|
||
K unchanged = 16
|
||
|
||
|
||
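The extents annotated in the two diagrams above follow from simple arithmetic on the
instruction shape and the repeat tuple. The sketch below reproduces that arithmetic in plain
Python; ``tiled_mma_extent`` is a hypothetical helper written for this illustration, not part
of CuTe DSL, and it assumes the instruction shape is given per CTA (128x256x16, as in the
diagrams).

.. code-block:: python

   from itertools import product


   def tiled_mma_extent(inst_mnk, ctas_per_mma, repeat_mnk):
       """Return (per-CTA tile, cluster M extent, atom coordinates)."""
       inst_m, inst_n, inst_k = inst_mnk  # per-CTA instruction shape
       m_rep, n_rep, k_rep = repeat_mnk
       per_cta = (m_rep * inst_m, n_rep * inst_n, k_rep * inst_k)
       # With CtaGroup.TWO, two CTAs stack along M, doubling the cluster extent.
       cluster_m = ctas_per_mma * per_cta[0]
       # Every valid (v, m, n, k) atom coordinate.
       coords = list(
           product(range(ctas_per_mma), range(m_rep), range(n_rep), range(k_rep))
       )
       return per_cta, cluster_m, coords


   one = tiled_mma_extent((128, 256, 16), ctas_per_mma=1, repeat_mnk=(2, 2, 1))
   two = tiled_mma_extent((128, 256, 16), ctas_per_mma=2, repeat_mnk=(2, 2, 1))
   print(one[0], one[1], len(one[2]))
   print(two[0], two[1], len(two[2]))

For ``CtaGroup.ONE`` this yields a 256 x 512 x 16 tile with four atom coordinates; for
``CtaGroup.TWO`` the per-CTA tile is the same, the cluster M extent is 512, and there are
eight atom coordinates.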
Custom tile permutation with ``permutation_mnk``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

``make_tiled_mma`` accepts an optional ``permutation_mnk`` argument that
controls how the atom tiles are laid out across the M, N, and K dimensions.
``permutation_mnk`` is a tuple of layouts or ints that give the tile size and the reordering of values in each dimension.
Such permutations can be used to optimize the data access patterns of the MMAs.
|
||
|
||
For example, with ``inst_m=256`` and 2 atoms in M (total M tile = 512),
|
||
a permutation can interleave the two atoms' M rows:
|
||
|
||
.. code-block:: python

   # inst_m=256, inst_n=256, inst_k=16
   m_layout = cute.make_layout(
       shape=(128, 2, 2),     # (inst_m // 2, 2, 2)
       stride=(1, 256, 128),  # (1, inst_m, inst_m // 2)
   )
   tiled_mma = cute.make_tiled_mma(
       atom,
       atom_layout_mnk=(1, 1, 1),
       permutation_mnk=(m_layout, 256, 16),
   )
|
||
|
||
The layout ``(128,2,2):(1,256,128)`` maps logical flat indices to physical
|
||
M rows in colex order (mode 0 fastest), interleaving the two atoms' halves:
|
||
|
||
.. code-block:: text
|
||
|
||
Without permutation With permutation_mnk
|
||
(sequential, default) m_layout = (128,2,2):(1,256,128)
|
||
|
||
+---------------+ ^ ^ +---------------+ ^
|
||
| MMA 0 top | | 128 CTA 0 | | MMA 0 top | | 128 CTA 0
|
||
| rows 0-127 | | | | rows 0-127 | |
|
||
+...............+ + | Tile 0 +---------------+ v
|
||
| MMA 0 bottom | | 128 CTA 1 | | MMA 1 top | | 128 CTA 0
|
||
| rows 128-255 | | | | rows 128-255 | |
|
||
+---------------+ v v +---------------+ v
|
||
| MMA 1 top | ^ ^ | MMA 0 bottom | | 128 CTA 1
|
||
| rows 256-383 | | 128 CTA 0 | | rows 256-383 | |
|
||
+...............+ + | Tile 1 +---------------+ v
|
||
| MMA 1 bottom | | 128 CTA 1 | | MMA 1 bottom | | 128 CTA 1
|
||
| rows 384-511 | | | | rows 384-511 | |
|
||
+---------------+ v v +---------------+ v
|
||
<-- inst_N=256 -> <-- inst_N=256 ->
|
||
inst_K = 16 inst_K = 16
|
||
|
||
Tile 0: rows 0-255 (contiguous) Tile 0: rows {0-127, 256-383}
|
||
Tile 1: rows 256-511 (contiguous) Tile 1: rows {128-255, 384-511}
|
||
CTA 0 owns rows {0-127, 256-383} CTA 0 owns rows {0-127, 256-383}
|
||
CTA 1 owns rows {128-255, 384-511} CTA 1 owns rows {128-255, 384-511}
|
||
|
||
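The row mapping of the layout ``(128,2,2):(1,256,128)`` can be checked by evaluating it by
hand. The sketch below does so in plain Python; ``layout_index`` is a hypothetical helper
written for this illustration (it is not a CuTe DSL function) that decomposes a flat index
with mode 0 varying fastest and takes the dot product with the strides.

.. code-block:: python

   def layout_index(i, shape, stride):
       """Map flat logical index i to an offset, with mode 0 varying fastest (colex)."""
       offset = 0
       for extent, step in zip(shape, stride):
           offset += (i % extent) * step
           i //= extent
       return offset


   SHAPE = (128, 2, 2)
   STRIDE = (1, 256, 128)

   rows = [layout_index(i, SHAPE, STRIDE) for i in range(512)]
   # First and last physical row of each block of 128 logical rows.
   blocks = [(rows[s], rows[s + 127]) for s in range(0, 512, 128)]
   print(blocks)

Logical rows 0-255 land on physical rows 0-127 and 256-383, and logical rows 256-511 land on
physical rows 128-255 and 384-511, which matches the tile annotations in the right-hand
column of the diagram.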
When ``permutation_mnk`` is not provided (default), the tile ordering is
|
||
sequential and no permutation is applied.
|
||
|
||
Creating Trivial Tiled MMA
~~~~~~~~~~~~~~~~~~~~~~~~~~~

Because tcgen05 MMAs have large instruction shapes, most TiledMmas in practice are
trivial: they use a single repetition in M and N, so ``atom_layout_mnk`` and ``permutation_mnk`` are generally unused.
CuTe DSL provides a convenience function, ``make_trivial_tiled_mma``, that creates such a trivial tiled MMA and
selects the MmaOp kind automatically from the data types.
|
||
|
||
.. code-block:: python

   import cutlass.utils.blackwell_helpers as sm100_utils

   tiled_mma = sm100_utils.make_trivial_tiled_mma(
       a_dtype,
       b_dtype,
       a_major_mode,
       b_major_mode,
       acc_dtype,
       cta_group,
       mma_tiler_mnk,
   )

   # Equivalent to the following, where MmaXyzOp is a placeholder for the
   # op class selected from the data types:
   tiled_mma = cute.make_tiled_mma(
       cute.make_mma_atom(
           tcgen05.MmaXyzOp(
               # ... parameters of MmaXyzOp
           ),
       ),
   )
|
||
|
||
Partitioning Tensors
--------------------

Before computing MMAs, we partition the global memory tensors according to the
tiled MMA layout. For tcgen05, this maps each CTA's work to the correct portion of
the global memory tensors.

Partitioning the global memory tensors takes two steps:

* Local tile partitioning: partition the global memory tensors into local tiles, each of size ``mma_tiler_mnk``.
  A local tile is the portion of the global memory tensors processed by a single-CTA MMA or a 2-CTA cooperative MMA.
* MMA partition: partition the local tile into CTA-sized, per-MMA-instruction tiles (for a 2-CTA cooperative MMA,
  each CTA loads its own portion to SMEM). The per-operand shapes are
  ``(MMA, MMA_M, MMA_K, ...)`` for A, ``(MMA, MMA_N, MMA_K, ...)`` for B, and ``(MMA, MMA_M, MMA_N, ...)`` for C.

Note that for tcgen05, SMEM tensors are not partitioned.
See `Making Fragments`_ for more details.
|
||
|
||
Trivial TiledMma with CtaGroup.ONE MMAs (single CTA):
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

For a trivial tiled MMA with ``CtaGroup.ONE``, partitioning the ``mma_tiler``-sized tile
is an identity operation: the single CTA's tile is the whole ``mma_tiler``-sized tile.
The main difference between the results of ``local_tile`` and ``partition_[A/B]`` is that the latter
is a view that can be iterated one MMA instruction at a time.
|
||
|
||
Example: ``GEMM (M, N, K) = (512, 768, 384)``, ``mma_tiler_mnk = (128, 256, 64)``,
|
||
``CtaGroup.ONE``, F16 atom = 128x256x16 (inst_M=128, inst_N=256, inst_K=16).
|
||
|
||
Global memory tensors:
|
||
|
||
.. code-block:: text
|
||
|
||
mA: (M, K) = (512, 384) mB: (N, K) = (768, 384) mC: (M, N) = (512, 768)
|
||
|
||
K=384 K=384 N=768
|
||
|<----------->| |<----------->| |<----------------->|
|
||
+-------------+ +-------------+ +---+---+---+-------+
|
||
| | ^ | | ^ | | | | | ^
|
||
| mA | | M=512 | mB | | N=768 | | | | | | M=512
|
||
| | v | | v | | | | | v
|
||
+-------------+ +-------------+ +---+---+---+-------+
|
||
|
||
Tiling with ``mma_tiler_mnk = (BM, BN, BK) = (128, 256, 64)`` gives
|
||
M/BM = 512/128 = 4 tiles, N/BN = 768/256 = 3 tiles, K/BK = 384/64 = 6 tiles:
|
||
|
||
.. code-block:: text
|
||
|
||
mA tiled into (M/BM x K/BK) mB tiled into (N/BN x K/BK) mC tiled into (M/BM x N/BN)
|
||
= (4 x 6) blocks = (3 x 6) blocks = (4 x 3) blocks
|
||
* coordinates annotated on the matrix
|
||
are the mma_coord_mn of the GEMM.
|
||
|
||
BK=64 x6 BK=64 x6 BN=256 x3
|
||
|<--->| |<--->| |<----->|
|
||
+-----+-----+-- --+ +-----+-----+-- --+ +-------+-------+-------+
|
||
| | |..| | ^ | | |..| | ^ | (0,0) | (0,1) | (0,2) | ^
|
||
| | | | | | BM=128 | | | | | | BN=256 | | | | | BM=128
|
||
+-----+-----+-- --+ v +-----+-----+-- --+ v +-------+-------+-------+ v
|
||
| | |..| | ^ | | |..| | ^ | (1,0) | (1,1) | (1,2) | ^
|
||
| | | | | | BM=128 | | | | | | BN=256 | | | | | BM=128
|
||
+-----+-----+-- --+ v +-----+-----+-- --+ v +-------+-------+-------+ v
|
||
| | |..| | ^ | | |..| | ^ | (2,0) | (2,1) | (2,2) | ^
|
||
| | | | | | BM=128 | | | | | | BN=256 | | | | | BM=128
|
||
+-----+-----+-- --+ v +-----+-----+-- --+ v +-------+-------+-------+ v
|
||
| | |..| | ^ | (3,0) | (3,1) | (3,2) | ^
|
||
| | | | | | BM=128 | | | | | BM=128
|
||
+-----+-----+-- --+ v +-------+-------+-------+ v
|
||
|
||
Each CTA picks one ``(M-coord, N-coord)`` coordinate,
for example ``mma_coord = (0, 1, :)``.

After ``local_tile``, that CTA has ``k = K/BK = 384/64 = 6`` tiles to process
for the A and B tensors, and a single tile for the C tensor:
|
||
|
||
.. code-block:: text
|
||
|
||
gA: (BM, BK, k) = (128, 64, 6) gB: (BN, BK, k) = (256, 64, 6) gC: (BM, BN) = (128, 256)
|
||
(k has 6 tiles total: indices 0..5)
|
||
|
||
BK=64 BK=64 BN=256
|
||
|<----->| |<----->| |<--------->|
|
||
+-------+---------+-------+ +-------+---------+-------+ +-----------+
|
||
| | | | | | | | | | ^
|
||
BM= | gA k0 | k1...k4 | gA k5 | BN= | gB k0 | k1...k4 | gB k5 | BM= | gC | 128
|
||
128 | | | | 256 | | | | 128 | | v
|
||
+-------+---------+-------+ +-------+---------+-------+ +-----------+
|
||
|
||
With ``get_slice(0)``, the single CTA owns the full tile.
BM and BN match the atom, and BK is split into ``MMA_K`` atom-sized steps:
|
||
|
||
.. code-block:: text
|
||
|
||
gA (BK split into MMA_K atoms) gC
|
||
inst_K inst_K inst_K inst_K
|
||
=16 =16 =16 =16
|
||
|<--->|<--->|<--->|<--->|
|
||
+-----+-----+-----+-----+-- +-----------+
|
||
| 0 | 1 | 2 | 3 |.. BM=128 (MMA_M=1) | | BM=128
|
||
+-----+-----+-----+-----+ +-----------+
|
||
|<-- MMA_K = BK/inst_K = 4 -->|
|
||
|
||
gB (BK split into MMA_K atoms)
|
||
inst_K inst_K inst_K inst_K
|
||
=16 =16 =16 =16
|
||
|<--->|<--->|<--->|<--->|
|
||
+-----+-----+-----+-----+--
|
||
| 0 | 1 | 2 | 3 |.. BN=256 (MMA_N=1)
|
||
+-----+-----+-----+-----+
|
||
|
||
After partition (single CTA):
|
||
|
||
- ``tCgA: (MMA, MMA_M, MMA_K, k) = (MMA, 1, 4, 6)`` — MMA_M = BM/inst_M = 128/128 = 1, MMA_K = BK/inst_K = 64/16 = 4
|
||
- ``tCgB: (MMA, MMA_N, MMA_K, k) = (MMA, 1, 4, 6)`` — MMA_N = BN/inst_N = 256/256 = 1, MMA_K = BK/inst_K = 64/16 = 4
|
||
- ``tCgC: (MMA, MMA_M, MMA_N) = (MMA, 1, 1)`` — MMA_M = BM/inst_M = 1, MMA_N = BN/inst_N = 1
|
||
|
||
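The tile counts and partition shapes listed above can be reproduced with a few integer
divisions. The sketch below is plain Python for illustration only; ``partition_shapes`` and
its ``ctas_per_mma`` parameter are hypothetical names, not CuTe DSL API, and the ``MMA`` mode
is kept as a string placeholder because its contents depend on the atom.

.. code-block:: python

   def partition_shapes(gemm_mnk, tiler_mnk, inst_mnk, ctas_per_mma=1):
       """Derive tile counts and per-CTA partition shapes from the problem sizes."""
       m, n, k = gemm_mnk
       bm, bn, bk = tiler_mnk
       inst_m, inst_n, inst_k = inst_mnk

       tiles_mnk = (m // bm, n // bn, k // bk)
       # Each CTA of a cooperative MMA owns BM / ctas_per_mma rows of the tile.
       mma_m = (bm // ctas_per_mma) // inst_m
       mma_n = bn // inst_n
       mma_k = bk // inst_k
       return {
           "tiles_mnk": tiles_mnk,
           "tCgA": ("MMA", mma_m, mma_k, tiles_mnk[2]),
           "tCgB": ("MMA", mma_n, mma_k, tiles_mnk[2]),
           "tCgC": ("MMA", mma_m, mma_n),
       }


   shapes = partition_shapes(
       gemm_mnk=(512, 768, 384),
       tiler_mnk=(128, 256, 64),
       inst_mnk=(128, 256, 16),
   )
   for name, shape in shapes.items():
       print(name, shape)

For this example the helper reports 4 x 3 x 6 tiles and the shapes ``(MMA, 1, 4, 6)`` for A
and B and ``(MMA, 1, 1)`` for C.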
CuTe DSL handles all of these calculations for you, given ``mma_tiler_mnk`` and ``mma_coord``.
|
||
|
||
.. code-block:: python

   @cute.kernel
   def kernel(tiled_mma: cute.TiledMma, ...):
       gA = cute.local_tile(mA, mma_tiler_mnk, mma_coord, proj=(1, None, 1))  # (BM, BK, k) for A
       gB = cute.local_tile(mB, mma_tiler_mnk, mma_coord, proj=(None, 1, 1))  # (BN, BK, k) for B
       gC = cute.local_tile(mC, mma_tiler_mnk, mma_coord, proj=(1, 1, None))  # (BM, BN) for C

       # Single CTA MMA: cta index is always 0
       thr_mma = tiled_mma.get_slice(0)

       tCgA = thr_mma.partition_A(gA)  # (MMA, MMA_M, MMA_K, num_k_tiles) for A
       tCgB = thr_mma.partition_B(gB)  # (MMA, MMA_N, MMA_K, num_k_tiles) for B
       tCgC = thr_mma.partition_C(gC)  # (MMA, MMA_M, MMA_N) for C
|
||
|
||
Trivial TiledMma with CtaGroup.TWO MMAs (2-CTA cluster, each CTA owns half the M-tile):
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

With ``CtaGroup.TWO``, two CTAs cooperate on a single tile. The V-coordinate
(0 or 1) identifies the CTA within the pair. ``get_slice(V)`` gives each CTA
its half of the M dimension, while B is fully shared.
|
||
|
||
|
||
Example: ``GEMM (M, N, K) = (512, 768, 384)``, ``mma_tiler_mnk = (256, 256, 64)``,
|
||
``CtaGroup.TWO``, F16 MMA atom = 128x256x16 (inst_M=128, inst_N=256, inst_K=16).
|
||
|
||
Global matrices:
|
||
|
||
.. code-block:: text
|
||
|
||
mA: (M, K) = (512, 384) mB: (N, K) = (768, 384) mC: (M, N) = (512, 768)
|
||
|
||
K=384 K=384 N=768
|
||
|<----------->| |<----------->| |<----------------->|
|
||
+-------------+ +-------------+ +---+---+---+-------+
|
||
| | ^ | | ^ | | | | | ^
|
||
| mA | | M=512 | mB | | N=768 | | | | | | M=512
|
||
| | v | | v | | | | | v
|
||
+-------------+ +-------------+ +---+---+---+-------+
|
||
|
||
Tiling with ``mma_tiler_mnk = (BM, BN, BK) = (256, 256, 64)`` gives
|
||
M/BM = 512/256 = 2 tiles in M-mode, N/BN = 768/256 = 3 tiles in N-mode, K/BK = 384/64 = 6 tiles in K-mode:
|
||
|
||
.. code-block:: text
|
||
|
||
mA tiled into (M/BM x K/BK) mB tiled into (N/BN x K/BK) mC tiled into (M/BM x N/BN)
|
||
= (2 x 6) blocks = (3 x 6) blocks = (2 x 3) blocks
|
||
* coordinates annotated on the matrix
|
||
are the mma_coord_mn of the GEMM.
|
||
|
||
BK=64 x6 BK=64 x6 BN=256 x3
|
||
|<--->| |<--->| |<----->|
|
||
+-----+-----+-- --+ +-----+-----+-- --+ +-------+-------+-------+
|
||
| | |..| | ^ | | |..| | ^ | (0,0) | (0,1) | (0,2) | ^
|
||
| | | | | | BM=256 | | | | | | BN=256 | | | | | BM=256
|
||
+-----+-----+-- --+ v +-----+-----+-- --+ v +-------+-------+-------+ v
|
||
| | |..| | ^ | | |..| | ^ | (1,0) | (1,1) | (1,2) | ^
|
||
| | | | | | BM=256 | | | | | | BN=256 | | | | | BM=256
|
||
+-----+-----+-- --+ v +-----+-----+-- --+ v +-------+-------+-------+ v
|
||
| | |..| | ^
|
||
| | | | | | BN=256
|
||
+-----+-----+-- --+ v
|
||
|
||
Each CTA pair picks one (M-coord, N-coord) coordinate.
|
||
For example, CTA pair at ``mma_coord_mnk = (0, 0, :)``.
|
||
|
||
.. code-block:: text
|
||
|
||
gA: (BM, BK, k) = (256, 64, 6) gB: (BN, BK, k) = (256, 64, 6) gC: (BM, BN) = (256, 256)
|
||
|
||
BK=64 BK=64 BN=256
|
||
|<----->| |<----->| |<--------->|
|
||
+-------+-- +-------+-- +-----------+
|
||
| |.. | |.. | | ^
|
||
BM= | gA | k=6 BN= | gB | k=6 BM= | gC | 256
|
||
256 | | 256 | | | | v
|
||
+-------+ +-------+ +-----------+
|
||
|
||
``get_slice(V)`` splits BM between CTAs; BK is split into ``MMA_K`` steps:
|
||
|
||
.. code-block:: text
|
||
|
||
gA (BM split, BK split into MMA_K atoms) gC (BM split)
|
||
inst_K inst_K inst_K inst_K
|
||
=16 =16 =16 =16
|
||
|<--->|<--->|<--->|<--->|
|
||
+-----+-----+-----+-----+-- +-----------+
|
||
| 0 | 1 | 2 | 3 |.. ^ CTA 0 | CTA 0 | ^
|
||
| | | | | | BM/2=128 (V=0) | (V=0) | | BM/2=128
|
||
+-----+-----+-----+-----+ v +-----------+ v
|
||
| 0 | 1 | 2 | 3 |.. ^ CTA 1 | CTA 1 | ^
|
||
| | | | | | BM/2=128 (V=1) | (V=1) | | BM/2=128
|
||
+-----+-----+-----+-----+ v +-----------+ v
|
||
|<-- MMA_K = BK/inst_K = 4 -->|
|
||
|
||
gB (BN split for SMEM loading, BK split into MMA_K atoms)
|
||
inst_K inst_K inst_K inst_K
|
||
=16 =16 =16 =16
|
||
|<--->|<--->|<--->|<--->|
|
||
+-----+-----+-----+-----+--
|
||
| 0 | 1 | 2 | 3 |.. ^ CTA 0
|
||
| | | | | | BN/2=128 (V=0)
|
||
+-----+-----+-----+-----+ v
|
||
| 0 | 1 | 2 | 3 |.. ^ CTA 1
|
||
| | | | | | BN/2=128 (V=1)
|
||
+-----+-----+-----+-----+ v
|
||
|
||
Both CTAs consume the full gB for MMA, but for SMEM loading each CTA loads
|
||
its N-half.
|
||
|
||
After partition (per CTA, e.g. CTA 0):
|
||
|
||
- ``tCgA: (MMA, MMA_M, MMA_K, k) = (MMA, 1, 4, 6)`` — MMA_M = (BM/2)/inst_M = 128/128 = 1, MMA_K = BK/inst_K = 64/16 = 4
|
||
- ``tCgB: (MMA, MMA_N, MMA_K, k) = (MMA, 1, 4, 6)`` — MMA_N = BN/inst_N = 256/256 = 1, MMA_K = BK/inst_K = 64/16 = 4
|
||
- ``tCgC: (MMA, MMA_M, MMA_N) = (MMA, 1, 1)`` — MMA_M = (BM/2)/inst_M = 1, MMA_N = BN/inst_N = 1
|
||
|
||
With CuTe DSL, none of these calculations need to be done by hand.
The DSL handles all the tiling and partitioning, given ``mma_tiler_mnk`` and ``mma_coord``.
|
||
|
||
.. code-block:: python

   @cute.kernel
   def kernel(tiled_mma: cute.TiledMma, cta_layout_vmnk: cute.Layout, ...):
       bidx, bidy, _ = cute.arch.block_idx()

       # V-coordinate: which CTA within the 2-CTA group (0 or 1)
       mma_coord_vmnk = (
           bidx % cute.size(cta_layout_vmnk, mode=[0]),   # V (CTA rank)
           bidx // cute.size(cta_layout_vmnk, mode=[0]),  # M tile
           bidy,                                          # N tile
           None,                                          # K (all tiles)
       )
       mma_coord_mnk = mma_coord_vmnk[1:]

       gA = cute.local_tile(mA, mma_tiler_mnk, mma_coord_mnk, proj=(1, None, 1))
       gB = cute.local_tile(mB, mma_tiler_mnk, mma_coord_mnk, proj=(None, 1, 1))
       gC = cute.local_tile(mC, mma_tiler_mnk, mma_coord_mnk, proj=(1, 1, None))

       # 2-CTA: each CTA passes its V-coordinate to get its half of the work
       thr_mma = tiled_mma.get_slice(mma_coord_vmnk[0])

       tCgA = thr_mma.partition_A(gA)  # (MMA, MMA_M, MMA_K, num_k_tiles)
       tCgB = thr_mma.partition_B(gB)  # (MMA, MMA_N, MMA_K, num_k_tiles)
       tCgC = thr_mma.partition_C(gC)  # (MMA, MMA_M, MMA_N)
|
||
|
||
.. note:: The ``tCgX`` annotation means that the tensor is partitioned with respect to the C matrix coordinates, i.e., the output tile of each CTA.
|
||
|
||
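The block-index arithmetic in the 2-CTA kernel can be traced on the host. The sketch below is
plain Python, not CuTe DSL; ``cta_work`` is a hypothetical helper, and it assumes the
2-CTA example sizes (``BM = BN = 256``, two CTAs per MMA) with no ``permutation_mnk``, so
that CTA ``V`` owns the contiguous rows ``V * BM/2`` through ``(V + 1) * BM/2 - 1`` of its
tile.

.. code-block:: python

   CTAS_PER_MMA = 2   # assumed size of the V mode of cta_layout_vmnk
   BM, BN = 256, 256  # M and N extents of mma_tiler_mnk in the 2-CTA example


   def cta_work(bidx, bidy):
       """Return the (v, m, n) coordinate and the C rows/columns owned by one CTA."""
       v = bidx % CTAS_PER_MMA        # CTA rank within the pair
       m_tile = bidx // CTAS_PER_MMA  # M tile index of the pair
       n_tile = bidy                  # N tile index
       half = BM // CTAS_PER_MMA
       m_start = m_tile * BM + v * half
       n_start = n_tile * BN
       return {
           "vmn": (v, m_tile, n_tile),
           "m_rows": (m_start, m_start + half - 1),
           "n_cols": (n_start, n_start + BN - 1),
       }


   # GEMM (512, 768): 2 M tiles x 2 CTAs = 4 blocks in x, 3 N tiles in y.
   for bidx in range(4):
       print(bidx, cta_work(bidx, 0))

Blocks 0 and 1 form the pair for M tile 0, and blocks 2 and 3 form the pair for M tile 1; each
block owns 128 of the 256 rows in its pair's tile.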
Pre and Post-Conditions for TiledMMA Partitioning
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

* The inputs to a partition must be tensors of rank 2 or higher.
* The output of a partition has a layout that is compatible with the MMA atom's operand:

  - For A, the output has the layout ``(MMA, MMA_M, MMA_K, ...)``.
  - For B, the output has the layout ``(MMA, MMA_N, MMA_K, ...)``.
  - For C, the output has the layout ``(MMA, MMA_M, MMA_N, ...)``.

* The partition does not enforce any rules on the tensor's memory space or data type. It only depends on the layout.
|
||
|
||
|
||
What happens when we use ``atom_layout_mnk``?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The valid arguments to ``get_slice`` are the valid coordinates in the ``(v, m, n, k)`` coordinate space
of the tiled MMA. ``mma_tiler_mnk`` should be updated such that ``mma_tiler_mnk[0] >= mma_shape[0] * |m|``,
``mma_tiler_mnk[1] >= mma_shape[1] * |n|``, and ``mma_tiler_mnk[2] >= mma_shape[2] * |k|``,
where ``|m|``, ``|n|``, and ``|k|`` denote the sizes of the m, n, and k modes.

The results of ``partition_A``, ``partition_B``, and ``partition_C`` remain the same.
|
||
|
||
Making Fragments
-----------------

Fragments are the descriptor-level tensors that the MMA
instruction operates on. For tcgen05:

- **Fragment A**: SMEM descriptor when ``a_src=SMEM``, or a TMEM address when
  ``a_src=TMEM``.
- **Fragment B**: SMEM descriptor pointing into staged shared memory buffers.
- **Fragment C (accumulator)**: lives in Tensor Memory (TMEM), allocated via
  ``TmemAllocator``.
|
||
|
||
Creating fragment descriptors and descriptor tensors
|
||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||
|
||
Unlike older architectures where fragments live in per-thread registers,
|
||
tcgen05 fragments are **descriptors** pointing into SMEM (for A and B) or
|
||
**addresses** into TMEM (for the accumulator C). The fragment creation has
|
||
three parts:
|
||
|
||
**1. A and B fragments**
|
||
|
||
*When A comes from SMEM* (``a_src=OperandSource.SMEM``):
|
||
|
||
``make_fragment_A`` and ``make_fragment_B`` take the staged SMEM tensors
|
||
(``sA``, ``sB``) and produce descriptor tensors that the MMA instruction
|
||
consumes. Each descriptor points to one stage's tile in shared memory.
|
||
|
||
.. code-block:: python

   # 1. Build the SMEM layouts (see "Creating SMEM layouts for A and B")
   # a_smem_layout = sm100_utils.make_smem_layout_a(tiled_mma, mma_tiler_mnk, ...)
   # b_smem_layout = sm100_utils.make_smem_layout_b(tiled_mma, mma_tiler_mnk, ...)

   # 2. Allocate SMEM tensors from those layouts
   # sA = smem.allocate_tensor(layout=a_smem_layout.outer, swizzle=a_smem_layout.inner, ...)
   # sB = smem.allocate_tensor(layout=b_smem_layout.outer, swizzle=b_smem_layout.inner, ...)

   # 3. Create fragment descriptors from the SMEM tensors
   tCrA = tiled_mma.make_fragment_A(sA)  # (MMA, MMA_M, MMA_K, STAGE)
   tCrB = tiled_mma.make_fragment_B(sB)  # (MMA, MMA_N, MMA_K, STAGE)
|
||
|
||
Continuing the ``CtaGroup.ONE`` example (m128n256k16 atom,
``mma_tiler_mnk = (128, 256, 64)``, 3 pipeline stages):

.. code-block:: text

    sA is an SMEM tensor with shape (MMA, MMA_M, MMA_K, STAGES),
    allocated with appropriate size (see "Creating SMEM layouts for A and B").

    For 128x256x16 atom, mma_tiler_mnk = (BM, BN, BK) = (128, 256, 64), 3 stages:
      MMA_M = BM/inst_M = 128/128 = 1, MMA_K = BK/inst_K = 64/16 = 4, STAGES = 3

    sA: (MMA, MMA_M=1, MMA_K=4, STAGES=3)

              |<--MMA_K = BK/inst_K = 4-->|
    stage 0:  +------+------+------+------+
              | k=0  | k=1  | k=2  | k=3  |  inst_M=128
              +------+------+------+------+
    stage 1:  +------+------+------+------+
              | k=0  | k=1  | k=2  | k=3  |  inst_M=128
              +------+------+------+------+
    stage 2:  +------+------+------+------+
              | k=0  | k=1  | k=2  | k=3  |  inst_M=128
              +------+------+------+------+

    make_fragment_A(sA) produces SMEM descriptors with the same shape:
      tCrA: (MMA, MMA_M, MMA_K, STAGES) = (MMA, 1, 4, 3)

    Each element is an SMEM descriptor, one per (MMA_K, STAGE) pair.
    Similarly for sB/tCrB with shape (MMA, MMA_N=1, MMA_K=4, STAGES=3).

Each element of ``tCrA`` / ``tCrB`` is an SMEM descriptor that the MMA
hardware reads directly, not a register value. When you print the layout of
``tCrA`` (or ``tCrB``), the ``MMA`` mode of ``(MMA, MMA_M, MMA_K, STAGES)``
has size ``1``: it is an indivisible SMEM descriptor that covers the whole
SMEM buffer a single MMA instruction consumes.

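The mode sizes in the example follow from integer division of the CTA tile by
the instruction shape. The sketch below reproduces that arithmetic in plain
Python; ``fragment_modes`` is a hypothetical helper written for illustration
and is not part of the CuTe DSL.

```python
def fragment_modes(tiler_mnk, inst_mnk, num_stages):
    """Return the (MMA_M, MMA_N, MMA_K, STAGES) mode sizes for one CTA tile."""
    bm, bn, bk = tiler_mnk
    inst_m, inst_n, inst_k = inst_mnk
    # The CTA tile must hold a whole number of MMA instructions in every mode.
    assert bm % inst_m == 0 and bn % inst_n == 0 and bk % inst_k == 0
    return bm // inst_m, bn // inst_n, bk // inst_k, num_stages


# Running example: 128x256x16 atom, (128, 256, 64) CTA tile, 3 pipeline stages.
mma_m, mma_n, mma_k, stages = fragment_modes((128, 256, 64), (128, 256, 16), 3)
print("tCrA modes:", ("MMA", mma_m, mma_k, stages))
print("tCrB modes:", ("MMA", mma_n, mma_k, stages))
```

Both fragments come out as ``(MMA, 1, 4, 3)``, matching the shapes listed in the diagram.
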
*When A comes from TMEM* (``a_src=OperandSource.TMEM``):

In use cases like FMHA or mixed-input GEMM, operand A can be sourced from
TMEM instead of SMEM. In this case, ``make_fragment_A`` is still called to
obtain the layout, but the fragment is bound to a TMEM pointer instead of an
SMEM tensor:

.. code-block:: python

    # Build the SMEM layout for A (see "Creating SMEM layouts for A and B").
    # The layout defines the tile shape the MMA expects, even though the data
    # will live in TMEM.
    # a_smem_layout = sm100_utils.make_smem_layout_a(...)

    # Use make_fragment_A with the outer layout to get the expected shape
    tCrA_layout = tiled_mma.make_fragment_A(a_smem_layout.outer).layout

    # Compute the TMEM pointer offset (A is placed after the accumulator columns).
    # TMEM columns are 32-bit wide, so scale to element offset for narrower types
    # (e.g. Float16: scale = 32 // 16 = 2).
    column_to_element_scale = 32 // acc_dtype.width
    tmem_ptr_a = cute.recast_ptr(
        accumulators.iterator + num_acc_tmem_cols * column_to_element_scale,
        dtype=mma_dtype,
    )

    # Bind to TMEM storage
    tCrA = cute.make_tensor(tmem_ptr_a, tCrA_layout)

The A fragment in TMEM is laid out after the accumulator's TMEM columns.

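The column-to-element scaling can be checked with plain integers. This is a
minimal sketch of the offset arithmetic only, assuming 32-bit TMEM columns as
stated in the code comment; ``a_offset_in_acc_elements`` is a hypothetical
name, and no TMEM pointer is involved.

```python
TMEM_COLUMN_BITS = 32  # each TMEM column holds 32 bits per lane


def a_offset_in_acc_elements(num_acc_tmem_cols, acc_width_bits):
    """Offset of operand A, counted in accumulator-typed elements."""
    column_to_element_scale = TMEM_COLUMN_BITS // acc_width_bits
    return num_acc_tmem_cols * column_to_element_scale


# 256 accumulator columns with a 32-bit accumulator: one element per column.
print(a_offset_in_acc_elements(256, 32))
# The same 256 columns with a 16-bit accumulator: two elements per column.
print(a_offset_in_acc_elements(256, 16))
```

The offset is added to a pointer typed as the accumulator, which is why the scale depends on the accumulator width rather than on the width of A.
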
**2. C fragment (accumulator): TMEM allocation**

The accumulator lives in Tensor Memory (TMEM), a dedicated on-chip memory
separate from registers and SMEM. Creating the C fragment is a four-step
process:

.. code-block:: python

    # Step 1: Query the partitioned accumulator shape
    acc_shape = tiled_mma.partition_shape_C(mma_tiler_mnk[:2])
    # acc_shape: (MMA, MMA_M, MMA_N) = (MMA, 1, 1)

    # Step 2: Append a staging dimension for ping-pong overlap.
    # Use 1 for simple kernels (no overlap), or 2+ to overlap
    # MMA and epilogue on different TMEM buffers.
    num_acc_stages = 2
    acc_shape_staged = cute.append(acc_shape, num_acc_stages)

    # Step 3: Create a fragment to establish the layout
    tCtAcc = tiled_mma.make_fragment_C(acc_shape_staged)
    # tCtAcc: (MMA, MMA_M, MMA_N, ACC_STAGE)

    # Step 4: Bind to actual TMEM storage
    tmem_ptr = tmem.retrieve_ptr(cutlass.Float32)
    tCtAcc = cute.make_tensor(tmem_ptr, tCtAcc.layout)

.. code-block:: text

    partition_shape_C((BM, BN) = (128, 256))
      -> (MMA, MMA_M, MMA_N) = (MMA, 1, 1)
         MMA_M = BM/inst_M = 128/128 = 1
         MMA_N = BN/inst_N = 256/256 = 1

    cute.append(acc_shape, 2)
      -> (MMA, 1, 1, 2)

    make_fragment_C(acc_shape_staged)
      -> tCtAcc layout: ((128, 256), 1, 1, 2)

    After binding to TMEM (2-stage ping-pong):
    +---------------------------+---------------------------+
    |      tCtAcc stage 0       |      tCtAcc stage 1       |
    |  128 x 256 accumulators   |  128 x 256 accumulators   |
    |         (Float32)         |         (Float32)         |
    +---------------------------+---------------------------+

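A quick budget check shows why two accumulator stages are the limit for this
tile. The sketch assumes the TMEM geometry described in the PTX documentation
(128 lanes by 512 columns of 32 bits) and that each lane holds one accumulator
row; ``acc_tmem_columns`` is a hypothetical helper, not a DSL call.

```python
TMEM_LANES = 128    # assumption: TMEM rows (lanes) addressed by one CTA
TMEM_COLUMNS = 512  # assumption: 32-bit TMEM columns available


def acc_tmem_columns(bn, acc_width_bits, num_acc_stages):
    """Columns used when each of the 128 lanes holds one accumulator row."""
    cols_per_stage = bn * acc_width_bits // 32
    return cols_per_stage * num_acc_stages


# 128 x 256 Float32 accumulators, 2-stage ping-pong.
cols = acc_tmem_columns(bn=256, acc_width_bits=32, num_acc_stages=2)
print(cols, "of", TMEM_COLUMNS, "columns")
assert cols <= TMEM_COLUMNS
```

Under these assumptions the two-stage accumulator fills all 512 columns, so a kernel that also keeps operand A or other data in TMEM needs a narrower tile or a single accumulator stage.
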
Creating SMEM layouts for A and B
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The SMEM layouts define how A and B tiles are stored in shared memory, including
swizzling for bank-conflict-free access. The helper functions handle the
details: partitioned shape from the tiled MMA, swizzle atom selection, tiling to
the MMA shape, and multi-stage buffering.

**Host side** (``@cute.jit``):

.. code-block:: python

    import cutlass.utils.blackwell_helpers as sm100_utils

    # Create SMEM layouts (includes swizzle + staging)
    a_smem_layout = sm100_utils.make_smem_layout_a(
        tiled_mma, mma_tiler_mnk, a.element_type, num_stages,
    )
    b_smem_layout = sm100_utils.make_smem_layout_b(
        tiled_mma, mma_tiler_mnk, b.element_type, num_stages,
    )

``make_smem_layout_a`` and ``make_smem_layout_b`` are convenience helpers that
build a complete, staged SMEM layout in four steps:

1. **Determine the major mode.** The major mode (K-major or MN-major) is read
   from the MMA op's ``a_major_mode`` / ``b_major_mode`` attribute (or can be
   overridden via the ``is_k_major`` keyword argument).

2. **Compute the partitioned SMEM tile shape.** The tiled MMA is asked for the
   partitioned shape of the operand via ``tiled_mma.partition_shape_A``
   (or ``partition_shape_B``). ``cute.dice`` strips the irrelevant mode first:
   for A the ``(M, K)`` portion is kept, for B the ``(N, K)`` portion. The
   result is a hierarchical shape ``((MMA, MMA_MN, MMA_K), repeat_MN, repeat_K)``
   that is flattened into a 2D ``(MN, K)`` size for swizzle selection.

3. **Select and materialise the swizzle atom.** A heuristic
   (``get_smem_layout_atom_ab``) picks the widest swizzle whose contiguous
   size (in bits) evenly divides the major-mode dimension:

   ========== ==================
   Swizzle    Contiguous bits
   ========== ==================
   SW128      1024 (128 B)
   SW64       512 (64 B)
   SW32       256 (32 B)
   Interleave 128 (16 B)
   ========== ==================

   ``make_smem_layout_atom`` then combines the chosen swizzle with a compact
   ``(MN_elems, 8)`` or ``(8, K_elems)`` outer layout (depending on the major
   mode) into a ``ComposedLayout(swizzle, outer)``.

4. **Tile to the MMA shape and append the staging dimension.**
   ``tile_to_mma_shape`` broadcasts the atom to the full partitioned shape
   (with ``num_stages`` appended). The ``order`` argument controls which
   dimension is contiguous: ``(1, 2, 3)`` for K-major (K innermost),
   ``(2, 1, 3)`` for MN-major (MN innermost).

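The selection rule in step 3 can be sketched in a few lines of plain Python.
This is an illustration of the rule as described in the table, not the
implementation of ``get_smem_layout_atom_ab``; ``pick_swizzle`` and
``SWIZZLE_MODES`` are hypothetical names.

```python
# (name, contiguous bits), widest first, taken from the table above.
SWIZZLE_MODES = [("SW128", 1024), ("SW64", 512), ("SW32", 256), ("Interleave", 128)]


def pick_swizzle(major_mode_elems, elem_bits):
    """Pick the widest swizzle whose contiguous size divides the major mode."""
    major_bits = major_mode_elems * elem_bits
    for name, contiguous_bits in SWIZZLE_MODES:
        if major_bits % contiguous_bits == 0:
            return name
    raise ValueError(f"major mode of {major_bits} bits fits no swizzle atom")


# K-major Float16 with BK = 64: 64 * 16 = 1024 bits along K.
print(pick_swizzle(64, 16))
# K-major Float16 with BK = 16: only 256 bits along K.
print(pick_swizzle(16, 16))
```

For the running example (Float16, ``BK = 64``) the rule selects the 128-byte swizzle; a shorter K extent falls back to a narrower atom.
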
The staged layout produced by these helpers is then used to allocate the SMEM
tensors:

**Kernel side** (``@cute.kernel``):

.. code-block:: python

    smem = cutlass.utils.SmemAllocator()
    sA = smem.allocate_tensor(
        element_type=io_dtype,
        layout=a_smem_layout.outer,
        byte_alignment=128,
        swizzle=a_smem_layout.inner,
    )
    sB = smem.allocate_tensor(
        element_type=io_dtype,
        layout=b_smem_layout.outer,
        byte_alignment=128,
        swizzle=b_smem_layout.inner,
    )

.. note:: **Creating SMEM layouts without utilities**

   If you want to create SMEM layouts without using the utilities, you can do the following:

   .. code-block:: python

       swizzle = cute.Swizzle(3, 4, 3)
       mma_tile = tiled_mma.partition_shape_A((mma_tiler_mnk[0], mma_tiler_mnk[2]))
       smem_tile = tcgen05.tile_to_mma_shape(swizzle, mma_tile, order=(1, 2, 3))

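To see what a ``(3, 4, 3)`` swizzle does to an offset, the sketch below applies
the XOR pattern used by the CuTe C++ ``Swizzle<BBits, MBase, SShift>`` template.
It assumes the three arguments of ``cute.Swizzle`` carry the same meaning; the
function ``swizzle`` is a stand-alone illustration, not a DSL call.

```python
def swizzle(offset, bits=3, base=4, shift=3):
    """XOR `bits` bits starting at (base + shift) into the bits starting at base."""
    mask = ((1 << bits) - 1) << (base + shift)
    return offset ^ ((offset & mask) >> shift)


# The low `base` bits are never touched; bits 7..9 are XOR-ed into bits 4..6.
for off in (0x000, 0x010, 0x080, 0x090, 0x380):
    print(hex(off), "->", hex(swizzle(off)))
```

The mapping is its own inverse and permutes offsets within each 1024-offset window, which is what lets consecutive rows land in different banks without changing the footprint of the tile.
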
Executing the GEMM (Main Loop)
-------------------------------

The main loop iterates over K-tiles, loading A and B from global memory via TMA
into staged SMEM buffers, then issuing ``cute.gemm`` for each tile. The TMA copy
details are omitted for brevity.

.. code-block:: python

    for k_tile_idx in cutlass.range(num_k_tiles):
        # Wait for TMA load to complete for this K-tile
        ab_full = ab_consumer.wait_and_advance()

        # Set accumulate mode: first tile overwrites, subsequent tiles accumulate
        tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile_idx != 0)

        # Issue MMA: tCtAcc += tCrA * tCrB
        tile_crd = (None, None, None, ab_full.index)
        cute.gemm(tiled_mma, tCtAcc, tCrA[tile_crd], tCrB[tile_crd], tCtAcc)

        # Release the SMEM buffer for the next TMA load
        ab_full.release()

Key points:

- ``tcgen05.Field.ACCUMULATE`` controls whether the MMA accumulates into D
  (``True``) or overwrites D with ``A * B`` (``False``). Set to ``False`` for
  the first K-tile and ``True`` for all subsequent tiles.
- ``cute.gemm`` is asynchronous. Synchronization is handled by the pipeline
  barriers (``cutlass.pipeline.sm100.PipelineTmaUmma``).
- ``tile_crd`` selects which pipeline stage's SMEM buffer to read from.

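The effect of the ``ACCUMULATE`` field can be emulated on the host with a
scalar dot product split into K-tiles. This is a plain-Python sketch of the
semantics only; ``mma`` and ``tile_product`` are hypothetical stand-ins for the
hardware operation, and the ``nan`` models uninitialised TMEM.

```python
def tile_product(a_tile, b_tile):
    """Dot product of one K-tile."""
    return sum(x * y for x, y in zip(a_tile, b_tile))


def mma(acc, a_tile, b_tile, accumulate):
    """D = A*B + D when accumulate is True, D = A*B otherwise."""
    partial = tile_product(a_tile, b_tile)
    return acc + partial if accumulate else partial


a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
b = [8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0]
bk = 2  # K-elements per tile

acc = float("nan")  # stale accumulator contents before the first MMA
for k_tile_idx in range(len(a) // bk):
    lo = k_tile_idx * bk
    acc = mma(acc, a[lo:lo + bk], b[lo:lo + bk], accumulate=(k_tile_idx != 0))

print(acc)  # 120.0, the full-K dot product
```

Because the first tile overwrites the accumulator, the kernel never has to zero TMEM before the main loop.
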
Reading the accumulator from TMEM
----------------------------------

Once the main loop has finished, the epilogue loads the accumulator from TMEM
into registers (RMEM) with a tiled TMEM copy:

.. code-block:: python

    # TMEM allocation (done once, before the main loop)
    tmem = cutlass.utils.TmemAllocator(...)
    tmem.allocate(num_cols=512)

    # Accumulator fragment (MMA, MMA_M, MMA_N), bound to TMEM storage
    acc_shape = tiled_mma.partition_shape_C(mma_tiler_mnk[:2])
    tCtAcc = tiled_mma.make_fragment_C(acc_shape)
    tmem_ptr = tmem.retrieve_ptr(cutlass.Float32)
    tCtAcc = cute.make_tensor(tmem_ptr, tCtAcc.layout)

    # Build copy atom for TMEM → RMEM load
    copy_atom_t2r = cute.make_copy_atom(
        tcgen05.Ld32x32bOp(tcgen05.Repetition.x64),
        cutlass.Float32,
    )
    tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tCtAcc[(None, None), 0, 0])
    thr_copy_t2r = tiled_copy_t2r.get_slice(tidx)

    # (T2R, T2R_M, NumTiles)
    tTR_tAcc = thr_copy_t2r.partition_S(tCtAcc)
    # (T2R, T2R_M, NumTiles)
    tTR_gC = thr_copy_t2r.partition_D(tCgC)

    # (T2R, T2R_M)
    tTR_rAcc = cute.make_rmem_tensor(tTR_gC[None, None, 0].shape, acc_dtype)

    # Load one subtile at a time from TMEM into registers
    for i in cutlass.range(cute.size(tTR_tAcc, mode=[2])):
        cute.copy(tiled_copy_t2r, tTR_tAcc[None, None, i], tTR_rAcc)

Complete Workflow
------------------

Putting it all together, a typical Blackwell tcgen05 GEMM has this structure:

**Host function** (``@cute.jit``):

.. code-block:: python

    import cutlass
    import cutlass.cute as cute
    from cutlass.cute.nvgpu import cpasync, tcgen05, OperandMajorMode
    import cutlass.utils.blackwell_helpers as sm100_utils

    @cute.jit
    def host_function(a: cute.Tensor, b: cute.Tensor, c: cute.Tensor):
        # Running example from above: (128, 256, 64) CTA tile, 3 pipeline stages
        mma_tiler_mnk = (128, 256, 64)
        num_stages = 3

        # 1. Create the MMA op and tiled MMA
        op = tcgen05.MmaF16BF16Op(
            cutlass.Float16, cutlass.Float32,
            (128, 256, 16),
            tcgen05.CtaGroup.ONE,
            tcgen05.OperandSource.SMEM,
            OperandMajorMode.K,
            OperandMajorMode.K,
        )
        tiled_mma = cute.make_tiled_mma(op)

        # 2. Create SMEM layouts for A and B
        a_smem_layout = sm100_utils.make_smem_layout_a(
            tiled_mma, mma_tiler_mnk, a.element_type, num_stages,
        )
        b_smem_layout = sm100_utils.make_smem_layout_b(
            tiled_mma, mma_tiler_mnk, b.element_type, num_stages,
        )

        # 3. Create TMA copy atoms for global -> shared memory loads
        copy_op = cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE)
        tma_a = cute.nvgpu.make_tiled_tma_atom_A(
            copy_op, a, a_smem_layout, mma_tiler_mnk, tiled_mma,
        )
        tma_b = cute.nvgpu.make_tiled_tma_atom_B(
            copy_op, b, b_smem_layout, mma_tiler_mnk, tiled_mma,
        )

        # 4. Launch the kernel: one CTA per (BM, BN) output tile
        grid = cute.ceil_div((*c.layout.shape, 1), mma_tiler_mnk[:2])
        kernel(tiled_mma, tma_a, tma_b, c).launch(
            grid=grid, block=(128, 1, 1),
        )

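The grid computation assigns one CTA to each ``(BM, BN)`` output tile. The
sketch below redoes that arithmetic in plain Python, assuming (as in the CuTe
C++ ``ceil_div``) that a shorter divisor tuple is padded with ``1``;
``grid_shape`` is a hypothetical helper.

```python
def ceil_div(a, b):
    """Integer division rounded up."""
    return (a + b - 1) // b


def grid_shape(c_shape_mn, tiler_mn):
    """CTAs needed to cover an (M, N) output with (BM, BN) tiles."""
    m, n = c_shape_mn
    bm, bn = tiler_mn
    return (ceil_div(m, bm), ceil_div(n, bn), 1)


print(grid_shape((4096, 4096), (128, 256)))  # (32, 16, 1)
print(grid_shape((1000, 1000), (128, 256)))  # (8, 4, 1): edge tiles are partial
```

When M or N is not a multiple of the tile, the last CTA along that mode covers a partial tile, so the epilogue store needs bounds handling.
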
**Kernel function** (``@cute.kernel``):

.. code-block:: python

    @cute.kernel
    def kernel(
        tiled_mma: cute.TiledMma,
        tma_a: cpasync.TmaInfo,
        tma_b: cpasync.TmaInfo,
        mC: cute.Tensor,
    ):
        # -- Setup --
        tidx, _, _ = cute.arch.thread_idx()
        bidx, bidy, _ = cute.arch.block_idx()
        mma_coord_mnk = (bidx, bidy, None)

        # Global tensors for A and B live inside the TMA descriptor
        mA = tma_a.tma_tensor  # (M, K)
        mB = tma_b.tma_tensor  # (N, K)

        # Allocate SMEM for A, B (staged) and pipeline barriers
        smem = cutlass.utils.SmemAllocator()
        sA = smem.allocate_tensor(...)  # staged SMEM for A
        sB = smem.allocate_tensor(...)  # staged SMEM for B

        # Allocate TMEM for the accumulator
        tmem = cutlass.utils.TmemAllocator(...)
        tmem.allocate(num_cols=512)

        # -- Partition and make fragments --
        # (BM, BK, k)
        gA = cute.local_tile(mA, mma_tiler_mnk, mma_coord_mnk, proj=(1, None, 1))
        # (BN, BK, k)
        gB = cute.local_tile(mB, mma_tiler_mnk, mma_coord_mnk, proj=(None, 1, 1))
        # (BM, BN)
        gC = cute.local_tile(mC, mma_tiler_mnk, mma_coord_mnk, proj=(1, 1, None))

        thr_mma = tiled_mma.get_slice(0)
        tCgA = thr_mma.partition_A(gA)  # (MMA, MMA_M, MMA_K, num_k_tiles)
        tCgB = thr_mma.partition_B(gB)  # (MMA, MMA_N, MMA_K, num_k_tiles)
        tCgC = thr_mma.partition_C(gC)  # (MMA, MMA_M, MMA_N)

        # SMEM descriptor fragments
        tCrA = tiled_mma.make_fragment_A(sA)  # (MMA, MMA_M, MMA_K, STAGE)
        tCrB = tiled_mma.make_fragment_B(sB)  # (MMA, MMA_N, MMA_K, STAGE)

        # TMEM accumulator
        acc_shape = tiled_mma.partition_shape_C(mma_tiler_mnk[:2])
        tCtAcc = tiled_mma.make_fragment_C(acc_shape)  # (MMA, MMA_M, MMA_N)

        # Bind accumulator to TMEM
        tmem.wait_for_alloc()
        tmem_ptr = tmem.retrieve_ptr(cutlass.Float32)
        tCtAcc = cute.make_tensor(tmem_ptr, tCtAcc.layout)

        # TMA partition for global → shared memory copies
        tAsA, tAgA = cute.nvgpu.cpasync.tma_partition(
            tma_a.atom, 0, cute.make_layout(1),
            cute.group_modes(sA, 0, 3), cute.group_modes(tCgA, 0, 3),
        )
        tBsB, tBgB = cute.nvgpu.cpasync.tma_partition(
            tma_b.atom, 0, cute.make_layout(1),
            cute.group_modes(sB, 0, 3), cute.group_modes(tCgB, 0, 3),
        )

        # -- Main loop: iterate over K-tiles --
        # ab_consumer is the consumer side of the TMA → MMA pipeline
        # (pipeline setup omitted; see pipeline documentation)
        num_k_tiles = cute.size(gA, mode=[2])
        for k_tile_idx in cutlass.range(num_k_tiles):
            # TMA load A, B into staged SMEM (producer side)
            #   copy(tma_a.atom, tAgA[k], tAsA[stage])
            #   copy(tma_b.atom, tBgB[k], tBsB[stage])
            #   ... (see pipeline documentation)

            # Wait for data
            ab_full = ab_consumer.wait_and_advance()

            # MMA
            tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile_idx != 0)
            tile_crd = (None, None, None, ab_full.index)
            cute.gemm(tiled_mma, tCtAcc, tCrA[tile_crd], tCrB[tile_crd], tCtAcc)

            ab_full.release()

        # -- Epilogue: copy accumulator from TMEM to global memory --
        copy_atom_t2r = cute.make_copy_atom(
            tcgen05.Ld32x32bOp(tcgen05.Repetition.x64), cutlass.Float32,
        )
        tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tCtAcc[(None, None), 0, 0])
        thr_copy_t2r = tiled_copy_t2r.get_slice(tidx)

        tTR_tAcc = thr_copy_t2r.partition_S(tCtAcc)
        tTR_gC = thr_copy_t2r.partition_D(tCgC)
        tTR_rAcc = cute.make_rmem_tensor(tTR_gC[None, None, 0].shape, acc_dtype)

        # TMEM → RMEM, then RMEM → GMEM
        # (store_atom is the RMEM → GMEM copy atom; its construction is omitted)
        for i in cutlass.range(cute.size(tTR_tAcc, mode=[2])):
            cute.copy(tiled_copy_t2r, tTR_tAcc[None, None, i], tTR_rAcc)
            cute.copy(store_atom, tTR_rAcc, tTR_gC[None, None, i])

Beyond Simple Dense MMAs
------------------------

The tcgen05 MMA DSL also supports MMA variants beyond the simple dense case:

- Block-scaled MMA

.. {$nv-internal-release begin}

Internal builds additionally provide:

- Sparse MMA

.. {$nv-internal-release end}

.. {$nv-internal-release begin}

Sparse MMA
~~~~~~~~~~

Sparse MMA exploits **X:Y = {1:2, 2:4, 4:8} structured sparsity** in operand A: out of every
Y consecutive K-elements, exactly X are non-zero and Y-X are zero. The
kernel stores the compressed A values separately from the **metadata**
tensor ``E``, which records which X of the Y positions are non-zero
(2 of 4 in the 2:4 example used here).

Compared to a dense MMA kernel, a sparse kernel differs in five areas:

**1. MMA op creation**: use ``MmaF16BF16SparseOp`` with an extra
``sparse_metadata_format`` parameter. The instruction K is **doubled**
(32 vs 16 for dense F16/BF16) to account for the compressed operand. The
example here builds its sparse ``TiledMma`` through
``sm100_utils.make_sparse_trivial_tiled_mma(...)``:

.. code-block:: python

    from cutlass.cute.nvgpu.warp.mma import SparseMetadataFormat

    # Dense F16 (for comparison): inst_K = 16
    dense_op = tcgen05.MmaF16BF16Op(
        cutlass.Float16, cutlass.Float32, (128, 256, 16),
        tcgen05.CtaGroup.ONE, tcgen05.OperandSource.SMEM,
        cute.nvgpu.OperandMajorMode.K, cute.nvgpu.OperandMajorMode.K,
    )

    # Sparse F16: inst_K = 32 (2× dense, since A is 2:4 compressed)
    sparse_op = tcgen05.MmaF16BF16SparseOp(
        cutlass.Float16, cutlass.Float32, (128, 256, 32),
        tcgen05.CtaGroup.ONE, tcgen05.OperandSource.SMEM,
        cute.nvgpu.OperandMajorMode.K, cute.nvgpu.OperandMajorMode.K,
        SparseMetadataFormat.TID,
    )

    # The sparse GEMM example uses the public helper
    tiled_mma = sm100_utils.make_sparse_trivial_tiled_mma(
        a_raw_dtype, a_major_mode, b_major_mode, acc_dtype, cta_group,
        mma_tiler_mn=mma_tiler_mnk[:2],
        sparse_metadata_format=SparseMetadataFormat.TID,
    )

**2. Compressed A and metadata E tensors**: operand A is stored with
**half** the K-elements (the two non-zero values per group of 4), using a
``sparse_elem<2, dtype>`` type. The metadata tensor E is a compact bit-field
(``sparse_elem<8, uint8>`` for F16/BF16) that encodes the sparsity pattern.

.. code-block:: python

    # Sparse element types
    a_sparse_dtype = sm100_utils.make_sparse_a_dtype(a_raw_dtype)  # sparse_elem<2, F16>
    e_sparse_dtype = sm100_utils.make_sparse_e_dtype(a_raw_dtype)  # sparse_elem<8, uint8>

    # GMEM layouts for compressed A and metadata E
    sp_a_ptr = cute.recast_ptr(a.iterator, dtype=a_sparse_dtype)
    sp_a_layout = sm100_utils.make_sparse_gmem_layout_a(
        mnkl,
        a_raw_dtype,
        is_k_major=(a_major_mode == cute.nvgpu.OperandMajorMode.K),
        sparsity=2,
    )
    sp_a_tensor = cute.make_tensor(sp_a_ptr, sp_a_layout)

    sp_e_ptr = cute.recast_ptr(e.iterator, dtype=e_sparse_dtype)
    sp_e_layout = sm100_utils.make_sparse_gmem_layout_e(mnkl, a_raw_dtype)
    sp_e_tensor = cute.make_tensor(sp_e_ptr, sp_e_layout)

.. code-block:: text

    Dense A: (M, K)                    Sparse A (compressed): (M, (2, K/2))
    +--+--+--+--+--+--+--+--+          +--+--+--+--+
    | a| 0| b| 0| c| 0| d| 0|    →     | a| b| c| d|   (only non-zeros stored)
    +--+--+--+--+--+--+--+--+          +--+--+--+--+

    Metadata E encodes positions:      E: [00, 10, 00, 10]
    (which 2 of 4 are non-zero)            ↑   ↑
                                           positions of a,b in each group

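The compression shown in the diagram can be reproduced on the host with a few
lines of plain Python. This is a logical sketch only: it records one 2-bit
position per kept value and does not model the packed metadata format the
hardware expects. ``compress_2_4`` and ``decompress_2_4`` are hypothetical
helpers.

```python
def compress_2_4(row):
    """Split a 2:4-sparse row into kept values and their in-group positions."""
    assert len(row) % 4 == 0
    values, metadata = [], []
    for g in range(0, len(row), 4):
        group = row[g:g + 4]
        kept = [i for i, v in enumerate(group) if v != 0]
        assert len(kept) <= 2, "row is not 2:4 sparse"
        # Pad with unused positions so exactly two indices are recorded.
        for i in range(4):
            if len(kept) == 2:
                break
            if i not in kept:
                kept.append(i)
        kept.sort()
        values += [group[i] for i in kept]
        metadata += kept
    return values, metadata


def decompress_2_4(values, metadata):
    """Rebuild the dense row from compressed values and positions."""
    row = []
    for g in range(0, len(values), 2):
        group = [0] * 4
        for v, i in zip(values[g:g + 2], metadata[g:g + 2]):
            group[i] = v
        row += group
    return row


dense = [5, 0, 7, 0, 3, 0, 9, 0]
values, meta = compress_2_4(dense)
print(values)                             # [5, 7, 3, 9]
print([format(i, "02b") for i in meta])   # ['00', '10', '00', '10']
```

The compressed row holds half the K-elements, and the position list matches the ``E`` pattern in the diagram.
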
**3. Extra SMEM layouts, TMA loads, and allocations**: sparse kernels use
dedicated layout helpers for A and E. An additional TMA descriptor loads
the metadata into SMEM alongside A and B. In the example here,
``E`` uses its own logical tile ``mma_tiler_e``:

.. code-block:: python

    # Host side: SMEM layouts
    a_smem_layout = sm100_utils.make_sparse_smem_layout_a(
        tiled_mma, mma_tiler_mnk, a_raw_dtype, num_stages, sparsity=2,
    )
    e_smem_layout = sm100_utils.make_sparse_smem_layout_e(
        tiled_mma, mma_tiler_e, a_raw_dtype, num_stages,
    )

    # Host side: TMA atom for metadata E (note mma_tiler_e and internal_type=Uint64)
    a_op = sm100_utils.cluster_shape_to_tma_atom_A(cluster_shape_mn, tiled_mma.thr_id)
    tma_atom_e, tma_tensor_e = cute.nvgpu.make_tiled_tma_atom_A(
        a_op,
        sp_e_tensor,
        cute.slice_(e_smem_layout, (None, None, None, 0)),
        mma_tiler_e,
        tiled_mma,
        cluster_layout_vmnk.shape,
        internal_type=cutlass.Uint64,
    )

    # Kernel side: SMEM allocation for metadata
    sE = smem.allocate_tensor(
        element_type=e_sparse_dtype,
        layout=e_smem_layout.outer,
        byte_alignment=128,
        swizzle=e_smem_layout.inner,
    )

**4. Metadata TMEM allocation and SMEM→TMEM copy (S2T)**: the metadata
must live in TMEM for the MMA instruction. It is placed **after** the
accumulator columns, and an S2T copy moves it from SMEM to TMEM each
K-tile. The example here also recasts both sides to raw ``uint8``
before building the S2T copy and wraps the SMEM source in an S2T
descriptor tensor:

.. code-block:: python

    # TMEM layout for metadata (placed after accumulator)
    e_tmem_layout = sm100_utils.make_sparse_tmem_layout_e(
        cute.slice_(e_smem_layout, (None, None, None, 0)).shape,
        a_raw_dtype,
    )
    acc_tmem_col_offset = tcgen05.find_tmem_tensor_col_offset(tCtAcc_base)
    if cutlass.const_expr(acc_dtype.width < 32):
        acc_tmem_col_offset = acc_tmem_col_offset * (32 // acc_dtype.width)
    e_tmem_ptr = cute.recast_ptr(
        acc_tmem_ptr + acc_tmem_col_offset, dtype=e_sparse_dtype,
    )
    tCtE = cute.make_tensor(e_tmem_ptr, e_tmem_layout)

    # S2T copy setup (SMEM → TMEM for metadata)
    e_raw_dtype = cutlass.Uint8
    copy_atom_s2t_e = cute.make_copy_atom(
        tcgen05.Cp128x128bOp(cta_group), e_raw_dtype,
    )
    tCtE_recast = cute.recast_tensor(tCtE, e_raw_dtype)
    tiled_copy_s2t_E = tcgen05.make_s2t_copy(copy_atom_s2t_e, tCtE_recast)
    thr_copy_s2t_E = tiled_copy_s2t_E.get_slice(0)

    sE_recast = cute.recast_tensor(sE, e_raw_dtype)
    thr_tCsE_s2t_ = thr_copy_s2t_E.partition_S(sE_recast)
    thr_tCsE_s2t = tcgen05.get_s2t_smem_desc_tensor(
        tiled_copy_s2t_E, thr_tCsE_s2t_
    )
    thr_tCtE_s2t = thr_copy_s2t_E.partition_D(tCtE_recast)

**5. Modified main loop**: each K-tile iteration loads the metadata via
S2T, then sets the ``METADATA`` field on the atom before calling ``gemm``.
The ``gemm`` call signature itself is unchanged; the metadata is implicit
via the atom field. The full kernel also contains leader-CTA
synchronization and optional metadata reuse when ``utccp_reuse_cnt > 1``;
the schematic below keeps only the dataflow-relevant steps:

.. code-block:: python

    # Schematic: e_smem_stage, e_tmem_stage, metadata_index_for, tCrA_kblk and
    # tCrB_kblk stand for the per-stage / per-K-block slices of the full kernel.
    tiled_mma.set(tcgen05.Field.ACCUMULATE, False)

    for k_tile in range(k_tile_cnt):
        # S2T: move metadata for the current stage from SMEM to TMEM.
        cute.copy(tiled_copy_s2t_E, e_smem_stage, e_tmem_stage)

        for kblk_idx in cutlass.range(cute.size(tCrA, mode=[2]), unroll_full=True):
            e_idx = metadata_index_for(k_tile, kblk_idx)
            tiled_mma.set(tcgen05.Field.METADATA, tCtE[None, None, e_idx].iterator)
            cute.gemm(tiled_mma, tCtAcc, tCrA_kblk, tCrB_kblk, tCtAcc)
            tiled_mma.set(tcgen05.Field.ACCUMULATE, True)

.. code-block:: text

    Dense main loop (per K-tile):
      set(ACCUMULATE, ...)
      gemm(tiled_mma, tCtAcc, tCrA[s], tCrB[s], tCtAcc)

    Sparse main loop (schematic, per K-tile):
      copy(s2t_E, sE[stage], tCtE)                        ← metadata SMEM → TMEM
      set(METADATA, tCtE[e_idx].iterator)                 ← point atom at metadata
      set(ACCUMULATE, ...)
      gemm(tiled_mma, tCtAcc, tCrA[s], tCrB[s], tCtAcc)   ← same call

The epilogue (TMEM → RMEM → GMEM) is identical to a dense kernel.

.. {$nv-internal-release end}

Block-scaled MMA
~~~~~~~~~~~~~~~~

Block-scaled MMA multiplies narrow-type matrices (the tcgen05 MXF8 and
MXF4-family ops shown here) with **per-block scale factors** along GEMM-K.
Each group of ``sf_vec_size`` consecutive K-elements shares one scale factor,
so the hardware computes ``D = (SFA · A) * (SFB · B) + C``. Unlike dense A/B,
SFA/SFB must be staged in **TMEM** before ``gemm``.

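The formula can be made concrete with a scalar reference in plain Python. This
is a numerical sketch of the block-scaling semantics for a single output
element, using ordinary floats rather than MXF8/MXF4 encodings;
``block_scaled_dot`` is a hypothetical helper.

```python
def block_scaled_dot(a, sfa, b, sfb, sf_vec_size, c=0.0):
    """D = sum over K-blocks of (sfa * sfb) * dot(a_block, b_block) + C."""
    assert len(a) == len(b) and len(a) % sf_vec_size == 0
    acc = c
    for blk in range(len(a) // sf_vec_size):
        lo = blk * sf_vec_size
        partial = sum(
            x * y for x, y in zip(a[lo:lo + sf_vec_size], b[lo:lo + sf_vec_size])
        )
        # One scale factor per operand per block of sf_vec_size K-elements.
        acc += sfa[blk] * sfb[blk] * partial
    return acc


a = [1.0, 2.0, 3.0, 4.0]
b = [1.0, 1.0, 2.0, 2.0]
sfa = [0.5, 2.0]   # K / sf_vec_size = 2 scale factors for A
sfb = [4.0, 0.25]  # and 2 for B
print(block_scaled_dot(a, sfa, b, sfb, sf_vec_size=2))  # 13.0
```

Each row of A (and each column of B) therefore needs ``K / sf_vec_size`` scale factors, which is the extra SFA/SFB traffic the kernel has to stage.
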
Supported ops: ``MmaMXF8F6F4Op``, ``MmaMXF4Op``, ``MmaMXF4NVF4Op``.

Compared to a dense MMA kernel, a block-scaled kernel differs in five areas:

**1. MMA op creation**: block-scaled ops fix the accumulator to FP32 and add
scale-factor typing. The examples usually build ``TiledMma`` through
``sm100_utils.make_blockscaled_trivial_tiled_mma(...)``, which dispatches to
``MmaMXF8F6F4Op``, ``MmaMXF4Op``, or ``MmaMXF4NVF4Op`` from
``(ab_dtype, sf_vec_size)``:

.. code-block:: python

    # Direct op examples
    mxf8_op = tcgen05.MmaMXF8F6F4Op(
        cutlass.Float8E4M3FN, (128, 256, 32),
        tcgen05.CtaGroup.ONE, tcgen05.OperandSource.SMEM,
        cute.nvgpu.OperandMajorMode.K, cute.nvgpu.OperandMajorMode.K,
    )

    # MXF4/NVF4 example (MmaMXF4Op is the sf_vec_size=32 companion)
    nvf4_op = tcgen05.MmaMXF4NVF4Op(
        cutlass.Float8E8M0FNU,  # sf_dtype: UE8M0 or UE4M3
        (128, 256, 64),
        tcgen05.CtaGroup.ONE, tcgen05.OperandSource.SMEM,
    )

    # Helper used by the block-scaled examples here
    tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma(
        ab_dtype, a_major_mode, b_major_mode, sf_dtype, sf_vec_size,
        cta_group, mma_tiler_mn,
    )

**2. Extra scale-factor tensors and SMEM layouts**: derive SFA/SFB tensors
from the A/B shapes, then build staged SMEM layouts for them:

.. code-block:: python

    import cutlass.utils.blockscaled_layout as blockscaled_utils

    # Scale-factor tensors
    sfa_layout = blockscaled_utils.tile_atom_to_shape_SF(a_tensor.shape, sf_vec_size)
    sfa_tensor = cute.make_tensor(sfa_ptr, sfa_layout)
    sfb_layout = blockscaled_utils.tile_atom_to_shape_SF(b_tensor.shape, sf_vec_size)
    sfb_tensor = cute.make_tensor(sfb_ptr, sfb_layout)

    # Staged SMEM layouts
    sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa(
        tiled_mma, mma_tiler_mnk, sf_vec_size, num_ab_stage,
    )
    sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb(
        tiled_mma, mma_tiler_mnk, sf_vec_size, num_ab_stage,
    )

**3. Extra TMA loads and SMEM allocations**: there are four GMEM→SMEM loads
instead of two. SFA follows the A-side TMA path; SFB uses
``cluster_shape_to_tma_atom_SFB(...)`` and may use its own tiler/layout in
2CTA kernels. The pipeline byte count also includes the SFA/SFB traffic:

.. code-block:: python

    # TMA atoms for SFA/SFB (note internal_type=Int16 for packing)
    sfa_op = sm100_utils.cluster_shape_to_tma_atom_A(cluster_shape_mn, tiled_mma.thr_id)
    tma_sfa = cute.nvgpu.make_tiled_tma_atom_A(
        sfa_op, sfa_tensor, sfa_smem_layout_staged,
        mma_tiler_mnk, tiled_mma, cluster_layout_vmnk.shape,
        internal_type=cutlass.Int16,
    )

    sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(cluster_shape_mn, tiled_mma.thr_id)
    tma_sfb = cute.nvgpu.make_tiled_tma_atom_B(
        sfb_op, sfb_tensor, sfb_smem_layout_staged,
        mma_tiler_sfb, tiled_mma_sfb, cluster_layout_sfb_vmnk.shape,
        internal_type=cutlass.Int16,
    )

    # Kernel side: allocate staged SMEM for scale factors
    sSFA = smem.allocate_tensor(element_type=sf_dtype, layout=tma_sfa.smem_layout, ...)
    sSFB = smem.allocate_tensor(element_type=sf_dtype, layout=tma_sfb.smem_layout, ...)

**4. Scale-factor TMEM allocation and SMEM→TMEM copy (S2T)**: before each
``gemm``, SFA/SFB must be copied from staged SMEM into TMEM. The examples
compact away zero-stride modes and wrap the SMEM source in an S2T descriptor
tensor:

.. code-block:: python

    # TMEM allocation for scale factors
    tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa(
        tiled_mma, mma_tiler_mnk, sf_vec_size,
        cute.slice_(tma_sfa.smem_layout, (None, None, None, 0)),
    )
    tCtSFA = tmem_pool.allocate_tensor(tCtSFA_layout, sf_dtype)
    tCtSFB_layout = blockscaled_utils.make_tmem_layout_sfb(
        tiled_mma, mma_tiler_mnk, sf_vec_size,
        cute.slice_(tma_sfb.smem_layout, (None, None, None, 0)),
    )
    tCtSFB = tmem_pool.allocate_tensor(tCtSFB_layout, sf_dtype)

    # S2T copy setup (SMEM → TMEM for scale factors)
    copy_atom_s2t = cute.make_copy_atom(
        tcgen05.Cp4x32x128bOp(cta_group), sf_dtype,
    )

    # SFA shown; SFB follows the same pattern.
    tCtSFA_compact = cute.filter_zeros(tCtSFA)
    tiled_copy_s2t_sfa = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSFA_compact)
    thr_copy_s2t_sfa = tiled_copy_s2t_sfa.get_slice(0)
    tCsSFA_compact_s2t = tcgen05.get_s2t_smem_desc_tensor(
        tiled_copy_s2t_sfa,
        thr_copy_s2t_sfa.partition_S(cute.filter_zeros(sSFA)),
    )
    tCtSFA_compact_s2t = thr_copy_s2t_sfa.partition_D(tCtSFA_compact)
    # Repeat with tCtSFB / sSFB to produce:
    # tiled_copy_s2t_sfb, tCsSFB_compact_s2t, and tCtSFB_compact_s2t.

**5. Modified main loop**: per K-tile, load A/B/SFA/SFB into SMEM, copy
SFA/SFB to TMEM, then call ``gemm`` with ``[value, scale]`` operands. The
persistent kernel separates TMA and MMA warps; the tutorial-style loop below
keeps only the operand flow:

.. code-block:: python

    for k_tile in cutlass.range(num_k_tiles):
        # TMA load A, B, SFA, SFB into SMEM.
        # ab_empty is the producer-side handle of a free SMEM stage
        # (its acquisition is omitted here).
        cute.copy(tma_a.atom, tAgA[None, ab_empty.count], tAsA[None, ab_empty.index], ...)
        cute.copy(tma_b.atom, tBgB[None, ab_empty.count], tBsB[None, ab_empty.index], ...)
        cute.copy(tma_sfa.atom, tAgSFA[None, ab_empty.count], tAsSFA[None, ab_empty.index], ...)
        cute.copy(tma_sfb.atom, tBgSFB[None, ab_empty.count], tBsSFB[None, ab_empty.index], ...)

        ab_full = ab_consumer.wait_and_advance()

        # S2T: copy scale factors from SMEM to TMEM
        s2t_stage_coord = (None, None, None, None, ab_full.index)
        cute.copy(
            tiled_copy_s2t_sfa,
            tCsSFA_compact_s2t[s2t_stage_coord],
            tCtSFA_compact_s2t,
        )
        cute.copy(
            tiled_copy_s2t_sfb,
            tCsSFB_compact_s2t[s2t_stage_coord],
            tCtSFB_compact_s2t,
        )

        # MMA with scale factors passed as [value, scale] pairs
        tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile != 0)
        tile_crd = (None, None, None, ab_full.index)
        cute.gemm(
            tiled_mma,
            tCtAcc,
            [tCrA[tile_crd], tCtSFA],  # A value (SMEM) + A scale (TMEM)
            [tCrB[tile_crd], tCtSFB],  # B value (SMEM) + B scale (TMEM)
            tCtAcc,
        )

        ab_full.release()

.. code-block:: text

    Dense tcgen05 mainloop (schematic):
      gemm(tiled_mma, tCtAcc, tCrA[s], tCrB[s], tCtAcc)

    Block-scaled tcgen05 mainloop (schematic):
      copy(s2t_sfa, sSFA[stage], tCtSFA)          ← scale A to TMEM
      copy(s2t_sfb, sSFB[stage], tCtSFB)          ← scale B to TMEM
      gemm(tiled_mma, tCtAcc, [tCrA[s], tCtSFA], [tCrB[s], tCtSFB], tCtAcc)

The epilogue (TMEM → RMEM → GMEM) is identical to a dense kernel.

See also:

- Tutorial, step-by-step dense F16 GEMM: ``examples/cute/blackwell/tutorial/tutorial_gemm/fp16_gemm_0.py``
  (and ``fp16_gemm_1.py`` through ``fp16_gemm_6.py`` for progressive optimizations)
- Tutorial, block-scaled NVFP4 GEMM: ``examples/cute/blackwell/tutorial/tutorial_gemm/nvfp4_gemm_0.py``
- Dense GEMM (production): ``examples/cute/blackwell/kernel/dense_gemm/dense_gemm.py``
- Persistent dense GEMM: ``examples/cute/blackwell/kernel/dense_gemm/dense_gemm_persistent.py``
- Block-scaled GEMM: ``examples/cute/blackwell/kernel/blockscaled_gemm/dense_blockscaled_gemm_persistent.py``
- Sparse GEMM: ``examples/cute/blackwell/kernel/sparse_gemm/sparse_gemm_persistent.py``
- Helper utilities: ``cutlass.utils.blackwell_helpers``
- Block-scaled layout utilities: ``cutlass.utils.blockscaled_layout``