mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-24 06:44:30 +00:00
v4.5.1 update. (#3237)
This commit is contained in:
@@ -23,3 +23,5 @@ CuTe DSL
|
||||
Compile with TVM FFI <cute_dsl_general/compile_with_tvm_ffi.rst>
|
||||
Ahead-of-Time (AOT) Compilation <cute_dsl_general/dsl_ahead_of_time_compilation.rst>
|
||||
Talks and Presentations <cute_dsl_general/resources.rst>
|
||||
Naming Conventions <cute_dsl_general/naming_conventions.rst>
|
||||
MMA Programming Guides <mma_docs/intro.rst>
|
||||
|
||||
@@ -83,6 +83,9 @@ an elementwise lambda function can be passed in as the ``epilogue_op`` argument.
|
||||
|
||||
Refer to the `Blackwell dense GEMM example <https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py>`__ for a complete example.
|
||||
|
||||
.. note::
|
||||
For the per-thread/partition naming convention used above (``tTR_rAcc``, ``tTR_rC``, and related tokens such as ``tAgA``, ``bSG_sC``, ``tQgQ_qdl``, …), see the :ref:`cute_dsl_naming_conventions`.
|
||||
|
||||
Type safety
|
||||
-----------
|
||||
|
||||
|
||||
253
media/docs/pythonDSL/cute_dsl_general/naming_conventions.rst
Normal file
253
media/docs/pythonDSL/cute_dsl_general/naming_conventions.rst
Normal file
@@ -0,0 +1,253 @@
|
||||
.. _cute_dsl_naming_conventions:
|
||||
|
||||
CuTe DSL Naming Conventions
|
||||
===========================
|
||||
|
||||
This page summarizes the Hungarian-style naming conventions used for identifiers across the DSL examples and epilogue helpers: tensor partitions, per-thread copy-partitioners, copy atoms, and the axis-order suffixes that encode tensor layouts. It is meant as a lookup reference while reading example code — not as a style rule enforced on new code.
|
||||
|
||||
Memory/space scopes
|
||||
-------------------
|
||||
|
||||
- ``g``: Global memory view (GMEM), e.g., ``gB_nkl``, ``tTR_gC``
|
||||
- ``s``: Shared memory view (SMEM), e.g., ``sA``, ``tRS_sC``, ``bSG_sC``
|
||||
- ``r``: Register view (RMEM), e.g., ``tTR_rAcc``, ``tRS_rC``
|
||||
- ``t``: Tensor-memory view (TMEM), used for any TMEM-resident fragment or layout regardless of role. The classical case is the accumulator (``tCtAcc``, ``tTR_tAcc``). The same scope letter also appears for non-accumulator TMEM tensors such as ``tCtE``, ``tCtState``, ``tCtQState``, ``tCtShared``. Read the operand suffix to distinguish the role from the memory scope.
|
||||
|
||||
Per-thread/partitioned views and families
|
||||
-----------------------------------------
|
||||
|
||||
- ``tA…`` / ``tB…``: TMA load path for A/B
|
||||
|
||||
- ``tAgA`` / ``tAsA``: per-thread partitioned global/shared A for TMA load
|
||||
- ``tBgB`` / ``tBsB``: per-thread partitioned global/shared B for TMA load
|
||||
- NVFP4/FP8 scale factors mirror this: ``tAgSFA`` / ``tAsSFA``, ``tBgSFB`` / ``tBsSFB``
|
||||
|
||||
- ``tC…``: Compute/epilogue path for C/Acc
|
||||
|
||||
- ``tCgA`` / ``tCgB`` / ``tCgC``: per-thread partitions used by MMA/epilogue (derived from global tensors)
|
||||
- ``tCrA`` / ``tCrB``: per-thread fragments used by MMA (derived from SMEM A/B)
|
||||
- ``tCtAcc``: per-thread accumulator fragment/layout in TMEM
|
||||
- Additional ``tC*`` tensors follow the same schema for kernels that carry more than the classical A/B/C/Acc operands (see Operands and roles below): e.g. ``tCtState`` / ``tCtQState`` / ``tCtShared`` (gated-delta-net recurrent state in TMEM), ``tCrValpha`` / ``tCrVbeta`` / ``tCrVbias`` (EVT/EFC broadcast vectors in registers), ``tCtAccInter`` / ``tCtAccIntra`` (hierarchical accumulators)
|
||||
|
||||
.. {$nv-internal-release begin}
|
||||
|
||||
- Sparse GEMM additionally defines ``tCtE`` for the sparsity metadata tensor in TMEM (sm_140 / Feynman sparse GEMM, not yet released)
|
||||
|
||||
.. {$nv-internal-release end}
|
||||
|
||||
- ``tTM…``: Per-thread TMEM tiled-copy partitions used by FMHA/attention kernels (e.g. ``tTMrO`` as the register-side view of a TMEM load partitioned through ``thr_tmem_load``)
|
||||
|
||||
- Attention/MLA path families (``tQ…``, ``tK…``, ``tV…``, ``tP…``, ``tO…``): same schema as ``tA…`` / ``tB…`` / ``tC…`` but specialised to the Q/K/V/P/O operands of attention kernels, e.g.:
|
||||
|
||||
- ``tQsQ`` / ``tQgQ_qdl``: per-thread SMEM / GMEM partitions of Q for TMA load
|
||||
- ``tKrK`` / ``tVrV``: per-thread register fragments for K / V
|
||||
- ``tOtO`` / ``tOrO``: per-thread TMEM / register views of the attention output accumulator O
|
||||
- ``tPrP``: per-thread register fragment for the softmax probability matrix P
|
||||
|
||||
Data-movement copy paths
|
||||
------------------------
|
||||
|
||||
- ``tTR_*``: TMEM → Register (T2R)
|
||||
|
||||
- ``tTR_tAcc``: TMEM accumulator source for T2R
|
||||
- ``tTR_rAcc``: Register destination for T2R
|
||||
- ``tTR_gC``: When not using TMA store, Register → Global C destination partition
|
||||
|
||||
- ``tRS_*``: Register → Shared (R2S)
|
||||
|
||||
- ``tRS_rC``: Register source (C dtype)
|
||||
- ``tRS_sC``: Shared destination
|
||||
|
||||
- ``bSG_*``: Thread(b)lock partition for Shared → Global via TMA store
|
||||
|
||||
- ``bSG_sC``: Shared source for TMA store
|
||||
- ``bSG_gC``: Global destination for TMA store
|
||||
- Also used for accumulator in some flows: ``bSG_sAcc``, ``bSG_gAcc``
|
||||
- The same schema extends to additional store operands: ``bSG_sD`` / ``bSG_gD``, ``bSG_sP`` / ``bSG_gP``, ``bSG_sY`` / ``bSG_gY``
|
||||
|
||||
- ``bGS_*``: Thread(b)lock partition for Global → Shared via TMA **load** (the load-path mirror of ``bSG_*``)
|
||||
|
||||
- ``bGS_gC`` / ``bGS_sC``: Global source / Shared destination for TMA load of C-like operands (seen in EFC row/column broadcast prologues)
|
||||
|
||||
- ``simt_atom``: SIMT copy path used when TMA store is disabled (Register → Global)
|
||||
- Generic SIMT / tiled copy atoms ``<src>2<dst>_atom[_suffix]`` name the copy direction between two memory scopes:
|
||||
|
||||
- ``s2r_atom_*``: Shared → Register atom used in specialised epilogues and attention loads (e.g. ``s2r_atom_delta``, ``s2r_atom_cumsum``, ``s2r_atom_d`` in Mamba2 SSD)
|
||||
- ``r2s_atom``: Register → Shared atom
|
||||
- ``t2r_atom`` / ``r2t_atom``: Tensor memory ↔ Register atoms (paired with ``thr_tmem_load`` / ``thr_tmem_store``)
|
||||
- ``s2s_atom``: Shared → Shared atom (reshape/remap without register spill)
|
||||
- ``s2t``: Shared → Tensor memory atom
|
||||
|
||||
.. {$nv-internal-release begin}
|
||||
|
||||
- ``sp2t_copy_op_*``: Sparse source → Tensor memory copy op (sm_140 / Feynman sparse GEMM, not yet released: e.g. ``Sp2TAsACopyOp``, ``Sp2TAsECopyOp``)
|
||||
|
||||
.. {$nv-internal-release end}
|
||||
|
||||
- Custom ``autovec_copy`` paths appear where the DSL auto-vectorises a bespoke layout
|
||||
|
||||
Operands and roles
|
||||
------------------
|
||||
|
||||
- ``A``, ``B``, ``C``: GEMM operands
|
||||
- ``Acc``: Accumulator (TMEM/Register paths). Hierarchical MMA kernels split this into ``AccInter`` / ``AccIntra`` for the inter-/intra-CTA accumulator halves
|
||||
- Classical extra outputs / intermediates: ``D`` (additional output), ``Y`` (fused output), ``SFA`` / ``SFB`` (per-operand scale-factor arrays for NVFP4/FP8), ``SF`` (generic scale factor)
|
||||
- Attention / MLA operand letters (Q/K/V/P/O schema):
|
||||
|
||||
- ``Q`` (query), ``K`` (key), ``V`` (value), ``P`` (softmax probability / score matrix), ``O`` (attention output)
|
||||
- Variants: ``Kt`` / ``Vt`` for the transposed view of K/V, ``Qi`` / ``Ki`` / ``Vi`` for per-iteration slices, ``QK`` / ``PV`` / ``QKV`` where a single fragment spans multiple operands of the two back-to-back matmuls
|
||||
- Mamba / recurrent-state letters: ``Delta`` / ``DeltaA`` (time-step and A-decay), ``State`` / ``QState`` / ``Shared`` (gated-delta-net recurrent state tensors), ``Cumsumlog`` / ``Cumprod`` (running reductions), ``Gate``, ``DecayV``
|
||||
|
||||
.. {$nv-internal-release begin}
|
||||
|
||||
- Sparse-GEMM letters (sm_140 / Feynman, not yet released): ``E`` (sparsity metadata tensor in TMEM; paired with ``sp2t_*`` copy ops)
|
||||
|
||||
.. {$nv-internal-release end}
|
||||
|
||||
- EVT / EFC broadcast vectors: ``Valpha`` / ``Vbeta`` (alpha/beta scalars broadcast as vectors), ``Vbias`` (bias vector), ``Ainv`` (inverse of A for fused solvers)
|
||||
|
||||
.. {$nv-internal-release begin}
|
||||
|
||||
- LUT-based block-scaled GEMM letter (Rubin, not yet released): ``LutB`` (look-up-table operand)
|
||||
|
||||
.. {$nv-internal-release end}
|
||||
- Communication operands (multi-CTA / multicast flows): ``CommInMC`` / ``CommOutMC`` (multicast in/out), ``CommOutUC`` (unicast out)
|
||||
- Head-dimension variants: ``Dv`` (value head dimension when distinct from Q/K dim), ``Nv`` (number of value heads)
|
||||
|
||||
Axis-order suffixes
|
||||
-------------------
|
||||
|
||||
- Suffix encodes axis order of the view (lowercase letters each stand for one tensor mode):
|
||||
|
||||
- GEMM layouts use ``m``/``n``/``k``/``l``:
|
||||
|
||||
- ``_mnl``, ``_nkl``, ``_mkl``, … map to (M, N, K, L) ordering
|
||||
- Example: ``gB_nkl`` is B with axes (N, K, L); ``gC_mnl`` is C with (M, N, L)
|
||||
|
||||
- Attention / FMHA layouts use ``q``/``k``/``d``/``l`` (sequence-Q, sequence-K, head-dim, batch):
|
||||
|
||||
- ``mQ_qdl``: Q tensor with axes (SeqQ, HeadDim, Batch)
|
||||
- ``mK_kdl``: K tensor with axes (SeqK, HeadDim, Batch)
|
||||
- ``mV_dkl``: V tensor with axes (HeadDim, SeqK, Batch) — the ``d``-first order reflects the V-transpose that makes the second matmul (P·V) a standard row-major ``MxK·KxN``
|
||||
|
||||
- Lower-rank 2D slices drop the batch letter: ``_mn``, ``_mk``, ``_nk``
|
||||
|
||||
- Internally, CuTe layouts also expose grouped modes like ``MMA_M/N/K``, ``EPI_M/N``, ``RestM/N/K/L``, ``STAGE``, etc. (these are typically implementation details not directly used in example code).
|
||||
|
||||
Reading compound tokens
|
||||
-----------------------
|
||||
|
||||
- From left to right: ``[t|b][A|B|C|Q|K|V|P|O|TR|RS|SG|GS|TM]_[g|s|r|t][Operand/Role][AxisSuffix?]``
|
||||
|
||||
- ``t`` = per-thread/partitioned view; ``b`` = block/threadblock partition context
|
||||
- family/path letters:
|
||||
|
||||
- Operand-based: ``A`` / ``B`` / ``C`` (GEMM), ``Q`` / ``K`` / ``V`` / ``P`` / ``O`` (attention)
|
||||
- Direction-based: ``TR`` (TMEM → Register), ``RS`` (Register → Shared), ``SG`` (Shared → Global, store), ``GS`` (Global → Shared, load), ``TM`` (TMEM tiled-copy partition), ``R2G`` / ``S2R`` / ``T2R`` / ``R2T`` convenience aliases
|
||||
- memory = ``g``/``s``/``r``/``t``
|
||||
- operand/role = ``A``/``B``/``C``/``Acc``/``SFA``/``SFB``/``Q``/``K``/``V``/``P``/``O``/``E``/``State``/…
|
||||
- axis suffix = ``_mnl``, ``_nkl``, ``_qdl``, ``_kdl``, ``_dkl``, ``_mn``, … when applicable
|
||||
|
||||
- Per-thread-partitioner objects follow a parallel ``thr_*`` vocabulary, grouped by role:
|
||||
|
||||
- MMA partitioner: ``thr_mma``
|
||||
- Tiled-copy direction variants ``thr_copy_<src>2<dst>``: ``thr_copy_g2s``, ``thr_copy_s2r``, ``thr_copy_t2r``, ``thr_copy_r2s``, ``thr_copy_r2t``, ``thr_copy_s2t``
|
||||
- Role-qualified copy variants: ``thr_copy_sfa``, ``thr_copy_sfb``, ``thr_copy_load``, ``thr_copy_beta_g2s``
|
||||
- MMA variants for multi-matmul kernels: ``thr_mma_qk``, ``thr_mma_pv``, ``thr_mma_kv``, ``thr_mma_qkv``, ``thr_mma_intra1`` / ``thr_mma_intra2``, ``thr_mma_leader_cta``, ``thr_mma_sfb``
|
||||
- TMEM access partitioners: ``thr_tmem_load``, ``thr_tmem_store`` (with ``_stats`` / ``_vec`` suffix variants)
|
||||
|
||||
The tensor produced by ``thr_foo.partition_S(X)`` or ``.partition_D(X)`` is then named by the ``[t|b]FamilyPrefix_*`` convention above.
|
||||
|
||||
Concrete references
|
||||
-------------------
|
||||
|
||||
Open these files in the repository to see each pattern in context:
|
||||
|
||||
- TMA load partitions for A/B:
|
||||
|
||||
- ``tAgA``, ``tAsA``, ``tBgB``, ``tBsB``
|
||||
- ``CuTeDSL/cute/blackwell/kernel/dense_gemm/dense_gemm.py`` (around TMA partition of A/B)
|
||||
|
||||
- Accumulator fragment in TMEM:
|
||||
|
||||
- ``tCtAcc``
|
||||
- ``CuTeDSL/cute/blackwell/kernel/dense_gemm/dense_gemm.py`` (accumulator creation and use)
|
||||
|
||||
- TMEM → Register (T2R):
|
||||
|
||||
- ``tTR_tAcc``, ``tTR_rAcc``, ``tTR_gC``
|
||||
- ``CuTeDSL/cute/blackwell/kernel/dense_gemm/dense_gemm.py`` (``epilog_tmem_copy_and_partition``)
|
||||
|
||||
- Register → Shared (R2S):
|
||||
|
||||
- ``tRS_rC``, ``tRS_sC``
|
||||
- ``CuTeDSL/cute/blackwell/kernel/mixed_input_gemm/mixed_input_gemm.py`` (``epilog_smem_copy_and_partition``)
|
||||
|
||||
- Shared → Global via TMA store:
|
||||
|
||||
- ``bSG_sC``, ``bSG_gC``
|
||||
- ``CuTeDSL/cute/blackwell/kernel/blockscaled_gemm/dense_blockscaled_gemm_persistent.py`` (``epilog_gmem_copy_and_partition``)
|
||||
|
||||
- NVFP4/FP8 scale factors:
|
||||
|
||||
- ``tAgSFA``/``tAsSFA``, ``tBgSFB``/``tBsSFB``
|
||||
- ``CuTeDSL/cute/blackwell/tutorial/tutorial_gemm/nvfp4_gemm_0.py`` (scale factor partition and usage)
|
||||
|
||||
- Additional examples across ``examples/``:
|
||||
|
||||
- Register → Global helper naming in MLA: ``tR2G_rO_src``, ``tR2G_rO_dst``
|
||||
- ``CuTeDSL/cute/blackwell/kernel/attention/mla/mla_decode_fp16.py`` (output store section)
|
||||
|
||||
- Shared → Register SIMT atoms in Mamba2 SSD: ``s2r_atom_delta``, ``s2r_atom_cumsum``, ``s2r_atom_d``
|
||||
- ``CuTeDSL/cute/blackwell/kernel/attention/mamba2_ssd/mamba2_ssd.py`` (SMEM load paths for delta and D)
|
||||
|
||||
- ``thr_*`` slices for partitioning per-thread work: ``thr_mma``, ``thr_copy_t2r``, ``thr_copy_r2s``, etc.
|
||||
- ``CuTeDSL/cute/blackwell/kernel/dense_gemm/dense_gemm.py`` (``thr_mma``, ``thr_copy_t2r``, ``thr_copy_r2s``)
|
||||
|
||||
- Axis-order suffix examples:
|
||||
|
||||
- ``gB_nkl``, ``gC_mnl``
|
||||
- ``CuTeDSL/cute/blackwell/kernel/dense_gemm/dense_gemm.py`` (global tensor tiling and partitioning)
|
||||
|
||||
- Global → Shared (TMA load) block partition ``bGS_*``:
|
||||
|
||||
- ``bGS_gC``, ``bGS_sC``
|
||||
- ``CuTeDSL/cute/blackwell/efc/common_efc.py`` (row/column broadcast prologue building the C-like input for EVT)
|
||||
|
||||
- Attention Q/K/V/P/O families and ``_qdl`` / ``_kdl`` / ``_dkl`` axis suffixes:
|
||||
|
||||
- ``tQsQ``, ``tQgQ_qdl``, ``mK_kdl``, ``mV_dkl``
|
||||
- ``CuTeDSL/cute/hopper/kernel/attention/fmha.py`` (Q/K/V TMA partitions)
|
||||
- ``tOtO``, ``tOrO``, ``tPrP``
|
||||
- ``CuTeDSL/cute/blackwell/tutorial/tutorial_fmha/fmha_0.py`` (output and softmax fragments)
|
||||
- ``tKrK``, ``tVrV``
|
||||
- ``CuTeDSL/cute/blackwell/kernel/attention/mixed_input_fmha/mixed_input_fmha_decode.py`` (mixed-input K/V register fragments)
|
||||
|
||||
- TMEM tiled-copy ``tTM*`` family and the generalised ``<src>2<dst>_atom`` naming:
|
||||
|
||||
- ``tTMrO`` driven by ``thr_tmem_load``
|
||||
- ``CuTeDSL/cute/blackwell/tutorial/tutorial_fmha/fmha_0.py``
|
||||
|
||||
- Recurrent-state operands (``State`` / ``QState`` / ``Shared``) in TMEM:
|
||||
|
||||
- ``tCtState``, ``tCtQState``, ``tCtShared``
|
||||
- ``CuTeDSL/cute/blackwell/kernel/attention/gated_delta_net/gated_delta_net_chunked.py``
|
||||
|
||||
.. {$nv-internal-release begin}
|
||||
|
||||
- Sparse-metadata operand ``E`` and ``sp2t_*`` copy ops (sm_140 / Feynman, not yet released):
|
||||
|
||||
- ``tCtE``, ``sp2t_copy_op_A``, ``sp2t_copy_op_E``
|
||||
- ``CuTeDSL/internal/feynman/sm140_sparse_gemm.py`` and ``sm140_sparse_gemm_temporal_split_k.py``
|
||||
|
||||
- LUT-based block-scaled GEMM operand ``LutB`` (Rubin, not yet released):
|
||||
|
||||
- ``CuTeDSL/cute/rubin/kernel/blockscaled_gemm/dense_blockscaled_gemm_lut.py``
|
||||
- ``CuTeDSL/cute_ext/rubin/dense_gemm_lutb.py``
|
||||
|
||||
.. {$nv-internal-release end}
|
||||
|
||||
- Richer ``thr_*`` and ``thr_copy_*`` / ``thr_mma_*`` / ``thr_tmem_*`` partitioner taxonomy:
|
||||
|
||||
- ``thr_copy_g2s``, ``thr_copy_s2r``, ``thr_copy_s2t``, ``thr_copy_r2t``, ``thr_mma_qk``, ``thr_mma_pv``, ``thr_tmem_load``, ``thr_tmem_store``
|
||||
- The attention and Mamba2 examples above are the densest references; any ``fmha_*.py`` or ``mamba2_ssd.py`` file will show the full vocabulary in use
|
||||
11
media/docs/pythonDSL/mma_docs/intro.rst
Normal file
11
media/docs/pythonDSL/mma_docs/intro.rst
Normal file
@@ -0,0 +1,11 @@
|
||||
Architecture-specific MMA Programming Guides
|
||||
=============================================
|
||||
|
||||
This section contains architecture-specific MMA programming guides.
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 2
|
||||
|
||||
wmma_programming
|
||||
wgmma_programming
|
||||
tcgen05_programming
|
||||
1528
media/docs/pythonDSL/mma_docs/tcgen05_programming.rst
Normal file
1528
media/docs/pythonDSL/mma_docs/tcgen05_programming.rst
Normal file
File diff suppressed because it is too large
Load Diff
987
media/docs/pythonDSL/mma_docs/wgmma_programming.rst
Normal file
987
media/docs/pythonDSL/mma_docs/wgmma_programming.rst
Normal file
@@ -0,0 +1,987 @@
|
||||
.. _wgmma_programming:
|
||||
|
||||
Warpgroup MMA Programming Guide
|
||||
================================
|
||||
|
||||
Hopper (SM90a) introduced the **warpgroup-level MMA** PTX instruction family
|
||||
``wgmma.mma_async.sync.aligned``. A warpgroup (128 threads / 4 warps)
|
||||
cooperates on one asynchronous ``D = A * B + C`` matrix multiply-accumulate.
|
||||
|
||||
Key architectural characteristics:
|
||||
|
||||
* **Warpgroup scope:** One MMA is issued collectively by a 128-thread
|
||||
warpgroup rather than by a single warp.
|
||||
* **Asynchronous issue model:** WGMMA instructions are ordered with
|
||||
``cute.nvgpu.warpgroup.fence()``, ``commit_group()``, and ``wait_group()``.
|
||||
* **Descriptor-based operand path:** Operand B is sourced from staged shared
|
||||
memory. Operand A can be sourced either from shared memory descriptors or
|
||||
from registers via ``OperandSource``.
|
||||
* **Register accumulator:** The accumulator lives in RMEM and serves as both
|
||||
the input C and output D of ``cute.gemm()``.
|
||||
* **Architecture-specific operand layouts:** F16/BF16 supports K-major and
|
||||
MN-major dense layouts when A comes from SMEM. FP8 and INT8 variants are
|
||||
K-major only.
|
||||
|
||||
The dense DSL op classes currently exposed are ``MmaF16BF16Op`` (F16/BF16),
|
||||
``MmaF8Op`` (FP8 E4M3/E5M2), and ``MmaI8Op`` (INT8/UINT8); see
|
||||
`Setting up the TiledMMA, MMA Ops`_ for their full constructor parameters,
|
||||
instruction K extents, and major-mode constraints.
|
||||
|
||||
This guide outlines the CuTe Python DSL programming model for WGMMA kernels:
|
||||
stage operands in SMEM, build fragment descriptors, launch asynchronous
|
||||
warpgroup MMAs, and stage the RMEM accumulator back to GMEM in the epilogue.
|
||||
|
||||
|
||||
.. contents:: **Contents**
|
||||
:local:
|
||||
:depth: 2
|
||||
|
||||
Global Memory (GMEM) to MMA data flow overview
|
||||
----------------------------------------------
|
||||
|
||||
WGMMA instructions require us to stage B input operands in Shared Memory (SMEM),
|
||||
while A input operands can be sourced from either SMEM or registers (RMEM).
|
||||
SMEM operands are read asynchronously by the hardware via SMEM descriptors.
|
||||
The accumulator is always kept in registers (RMEM) of the warpgroup.
|
||||
|
||||
The diagram below traces the full data flow of a WGMMA GEMM kernel, for the most
|
||||
common case where A and B matrices are stored in GMEM and both are staged through
|
||||
SMEM (``a_src=SMEM``), and the output matrix --accumulated in RMEM-- is written
|
||||
back to GMEM through an SMEM staging buffer.
|
||||
|
||||
There are 3 parallel tracks where each has 2 sub-tracks. Three parallel tracks are
|
||||
for operands A, B, and C/D, respectively. The two sub-tracks are for copying data between different memory
|
||||
spaces and for MMA execution.
|
||||
|
||||
- **Operand A** (and symmetrically **Operand B**):
|
||||
|
||||
- First, we need to create SMEM tensors for A and B matrices: ``sA`` and ``sB``. These
|
||||
tensors are physically allocated tensors that are the destination of TMA copy and
|
||||
the source operands for the WGMMA instructions.
|
||||
- Next the **data copy flow** creates the tensor views for copying data from GMEM to SMEM.
|
||||
It starts with ``mA`` tensor that represents the matrix A in global memory.
|
||||
Then ``mA`` → ``local_tile`` → ``gA`` operation creates the local tile view of A that is the
|
||||
slice of A matrix needed to compute the given CTA's output tile.
|
||||
Then ``tma_partition(tma, sA, gA)`` produces TMA views ``tAsA``, ``tAgA``,
|
||||
and the loop copies tiles from GMEM into SMEM via ``copy(tma, tAgA[k], tAsA[stage])``.
|
||||
- In parallel, the **MMA flow** turns the SMEM tensor into an iterable tensor of SMEM descriptors
|
||||
for the WGMMA instructions. ``sA`` (the same shared-memory allocation written by TMA)
|
||||
→ ``partition_A`` → ``tCsA`` (MMA-partitioned SMEM view)
|
||||
→ ``make_fragment_A`` → ``tCrA`` (SMEM descriptor passed to ``cute.gemm()``).
|
||||
Note that the SMEM descriptor is a view created from the SMEM tensor that is
|
||||
interpretable by the WGMMA instructions.
|
||||
|
||||
- **Accumulator C/D**:
|
||||
|
||||
- **RMEM accumulator flow** (gemm input/output): ``partition_C(gC)`` → ``tCgC``
|
||||
→ ``make_rmem_tensor(tCgC.shape)`` → ``acc``, which serves as both the accumulator
|
||||
input (C) and output (D) of ``cute.gemm()`` (and the WGMMA instruction).
|
||||
- **Output flow** (RMEM → SMEM → GMEM): After the main loop, the accumulator is
|
||||
type-converted and copied from registers to SMEM via ``stmatrix`` (R2S copy),
|
||||
then stored to global memory via TMA store (S2G copy):
|
||||
``mC`` → ``local_tile`` → ``gC`` → ``partition_C`` → ``tCgC`` on the destination side,
|
||||
and ``tRS_rAcc``/``tRS_sD`` / ``bSG_sD``/``bSG_gD`` views drive the two copy stages.
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
Operand A Dataflow Path Operand B Dataflow Path Accumulator C/D Dataflow Path
|
||||
─────────────────────── ─────────────────────── ─────────────────────────────
|
||||
|
||||
mA: (M, K) [GMEM] mB: (N, K) [GMEM] ┌──── RMEM ──────────┐
|
||||
│ │ │ make_rmem_tensor() │
|
||||
│ local_tile(mA, cta_tiler, coord) │ local_tile(mB, cta_tiler, coord) │ acc: accumulator │
|
||||
▼ ▼ └───────┬────────────┘
|
||||
gA: (BM, BK, k) [GMEM] gB: (BN, BK, k) [GMEM] │
|
||||
│ │ acc:(MMA,MMA_M,MMA_N) [RMEM]
|
||||
│ ┌──── SMEM ─────────┐ │ ┌──── SMEM ─────────┐ │
|
||||
│ │ sA = alloc(layout)│ │ │ sB = alloc(layout)│ │ mC: (M, N) [GMEM]
|
||||
│ └──┬────────┬───────┘ │ └──┬────────┬───────┘ │ │
|
||||
│ │ │ │ │ │ │ │ local_tile
|
||||
│ │ thr_mma.partition_A(sA) │ │ thr_mma.partition_B(sB) │ ▼
|
||||
│ │ ▼ │ │ ▼ │ gC: (BM, BN) [GMEM]
|
||||
│ │ tCsA:(MMA,MMA_M, │ │ tCsB:(MMA,MMA_N, │ │ partition_C
|
||||
│ │ MMA_K,PIPE) [SMEM] │ │ MMA_K,PIPE) [SMEM] │ ▼
|
||||
│ │ │ │ │ │ │ tCgC:(MMA,MMA_M,
|
||||
│ │ make_fragment_A(tCsA) │ │ make_fragment_B(tCsB) │ MMA_N)
|
||||
│ │ ▼ │ │ ▼ │ [GMEM] (epi dest)
|
||||
│ │ tCrA:(MMA,MMA_M, │ │ tCrB:(MMA,MMA_N, │ │
|
||||
│ │ MMA_K,PIPE) │ │ MMA_K,PIPE) │ │
|
||||
│ │ [SMEM descriptors] │ │ [SMEM descriptors] │ │
|
||||
│ │ └─────────────┐ │ │ └─────────────┐ │ │
|
||||
╰─────┤ │ ╰─────┤ │ │ │
|
||||
▼ │ ▼ │ │ │
|
||||
tma_partition(tma, │ tma_partition(tma, │ │ │
|
||||
sA, gA) │ sB, gB) │ │ │
|
||||
→ tAsA, tAgA │ → tBsB, tBgB │ │ │
|
||||
▼ │ ▼ │ │ │
|
||||
┌───┴────────────────────┐ │ ┌──────┴─────────────────┐│ │ │
|
||||
│ TMA copy loop (A path):│ │ │ TMA copy loop (B path):││ │ │
|
||||
│ copy(tma, tAgA[k], │ │ │ copy(tma, tBgB[k], ││ │ │
|
||||
│ tAsA[stage]) │ │ │ tBsB[stage]) ││ │ │
|
||||
┌─▶│ (writes into sA; │ │ ┌──▶│ (writes into sB; ││ │ │
|
||||
│ │ tCrA reads same sA) │ │ │ │ tCrB reads same sB) ││ │ │
|
||||
│ │ repeat for next k/stage│ │ │ │ repeat for next k/stage││ │ │
|
||||
│ └────────────────────────┘ │ │ └────────────────────────┘│ │ │
|
||||
│ │ │ │ │ │ │ │
|
||||
└────────┘ ▼ └─────────┘ ▼ ▼ │
|
||||
└───────┬───────────────────────────────┴───────────────────┘ │
|
||||
│ │
|
||||
▼ │
|
||||
┌──────────────────────────────────────────────┐ │
|
||||
│ GEMM Loop: | │
|
||||
│ warpgroup.fence() │ │
|
||||
│ cute.gemm(tiled_mma, │ │
|
||||
│ acc, D (output, RMEM), │ │
|
||||
┌──▶ │ tCrA[stage], A (SMEM desc -> sA), │ │
|
||||
│ │ tCrB[stage], B (SMEM desc -> sB), │ │
|
||||
│ │ acc) C (accumulator, RMEM) │ │
|
||||
│ │ warpgroup.commit_group() │ │
|
||||
│ │ warpgroup.wait_group(n) │ │
|
||||
│ └──────────────────────────────────────────────┘ │
|
||||
│ │ │ │
|
||||
└───────┘ | │
|
||||
▼ │
|
||||
Epilogue: │
|
||||
tRS_rAcc = retile(acc) │
|
||||
tRS_rD = type_convert(tRS_rAcc) │
|
||||
│ │
|
||||
▼ │
|
||||
R2S: copy(tiled_copy_r2s, tRS_rD, tRS_sD) │
|
||||
[RMEM → SMEM via stmatrix] │
|
||||
│ │
|
||||
▼ │
|
||||
sC = alloc(epi_layout) [SMEM] │
|
||||
bSG_sD, bSG_gD = tma_partition(tma_c, sC, gC) ◀───────────────────┘
|
||||
│
|
||||
▼
|
||||
S2G: copy(tma_c, bSG_sD[stage], bSG_gD[coord])
|
||||
[SMEM → GMEM via TMA store]
|
||||
|
||||
**Naming convention:**
|
||||
|
||||
* cta_tiler = (BM, BN, BK) = CTA-wide tiler dimensions
|
||||
* ``mX`` = a global tensor, e.g., (M, K) for A
|
||||
* ``gX`` = CTA-tiled GMEM slice, e.g., (BM, BK, k) for A
|
||||
* ``sX`` = SMEM allocation, e.g., (BM, BK, PIPE) for A
|
||||
* ``tAsA``/``tBsB`` = TMA-partitioned SMEM views
|
||||
* ``tAgA``/``tBgB`` = TMA-partitioned GMEM views
|
||||
* ``tCsX`` = MMA-partitioned SMEM view, e.g., (MMA, MMA_M, MMA_K, PIPE) for A
|
||||
* ``tCrX`` = SMEM descriptor fragment, e.g., (MMA, MMA_M, MMA_K, PIPE) for A
|
||||
* ``acc`` = RMEM accumulator, (MMA, MMA_M, MMA_N)
|
||||
* ``tCgC`` = MMA-partitioned GMEM, (MMA, MMA_M, MMA_N)
|
||||
* ``tRS_rAcc``/``tRS_sD`` = epilogue retile views for R2S (RMEM → SMEM) copy
|
||||
* ``bSG_sD``/``bSG_gD`` = TMA-partitioned SMEM/GMEM views for epilogue store
|
||||
* MMA = warpgroup atom thread-value layout; MMA_M/MMA_N/MMA_K = repeat counts
|
||||
(e.g., BM/inst_M), k = outer K-tiles, PIPE = pipeline stages
|
||||
|
||||
Setting up the TiledMMA, MMA Ops
|
||||
---------------------------------
|
||||
|
||||
As shown in the data flow overview, CuTe DSL provides many utilities to tile/partition
|
||||
the global memory tensors, and create fragment views of SMEM tensors for MMA instructions.
|
||||
|
||||
To utilize these functions, we need to setup the TiledMMA, MMA Ops first.
|
||||
|
||||
Creating a WGMMA Op
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
A WGMMA op describes the hardware instruction to use, it has parameters like
|
||||
data types, instruction shape, operand A source (SMEM or RMEM),
|
||||
and operand major modes.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cute.nvgpu import OperandMajorMode
|
||||
import cutlass.cute.nvgpu.warpgroup as warpgroup
|
||||
|
||||
op = warpgroup.MmaF16BF16Op(
|
||||
cutlass.Float16, # A/B element type
|
||||
cutlass.Float32, # accumulator type
|
||||
(64, 128, 16), # instruction shape (M, N, K)
|
||||
warpgroup.OperandSource.SMEM, # A operand from shared memory
|
||||
OperandMajorMode.K, # A is K-major
|
||||
OperandMajorMode.K, # B is K-major
|
||||
)
|
||||
|
||||
The key parameters are:
|
||||
|
||||
- **Instruction shape** ``(M, N, K)``: determines the size of one hardware MMA
|
||||
instruction. WGMMA requires ``M = 64`` and ``8 <= N <= 256`` in steps of 8.
|
||||
K is fixed by the op class (16 for F16/BF16, 32 for FP8 and INT8).
|
||||
- **OperandSource**: ``SMEM`` reads A from a shared memory descriptor; ``RMEM``
|
||||
reads A directly from registers.
|
||||
- **OperandMajorMode**: ``K`` for K-major (default), ``MN`` for transposed layout.
|
||||
F16/BF16 supports both K-major and MN-major for A and B when ``a_src=SMEM``;
|
||||
when ``a_src=RMEM``, only B can be transposed. FP8 and INT8 are K-major only.
|
||||
|
||||
|
||||
CuTe DSL provides implementation of the following WGMMA ops:
|
||||
|
||||
.. list-table:: WGMMA ops
|
||||
:header-rows: 1
|
||||
:widths: 30 24 46
|
||||
|
||||
* - PTX name
|
||||
- Python class
|
||||
- Constructor parameters
|
||||
* - ``wgmma.mma_async.m64n{N}k16.{acc}.f16.f16`` / ``.bf16.bf16``
|
||||
- ``warpgroup.MmaF16BF16Op``
|
||||
- ``ab_dtype, acc_dtype, instruction_shape, a_src, a_major_mode, b_major_mode``
|
||||
* - ``wgmma.mma_async.m64n{N}k32.{acc}.{e4m3|e5m2}.{e4m3|e5m2}``
|
||||
- ``warpgroup.MmaF8Op``
|
||||
- ``a_dtype, b_dtype, acc_dtype, instruction_shape, a_src, a_major_mode, b_major_mode``
|
||||
* - ``wgmma.mma_async.m64n{N}k32.s32.{s8|u8}.{s8|u8}``
|
||||
- ``warpgroup.MmaI8Op``
|
||||
- ``a_dtype, b_dtype, acc_dtype, instruction_shape, a_src, a_major_mode, b_major_mode``
|
||||
|
||||
|
||||
Creating a Tiled MMA
|
||||
~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
A ``TiledMma`` tiles the WGMMA atom across the CTA tile. You can pass the op
|
||||
directly or create an explicit atom first.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Option 1: directly from op (common shorthand)
|
||||
tiled_mma = cute.make_tiled_mma(op)
|
||||
|
||||
# Option 2: explicit atom creation
|
||||
atom = cute.make_mma_atom(op)
|
||||
tiled_mma = cute.make_tiled_mma(atom)
|
||||
|
||||
Spatial tiling with a repeat count
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
A repeat tuple ``(M_rep, N_rep, K_rep)`` replicates the WGMMA atom across the
|
||||
M, N, and K dimensions, producing a larger tiled MMA that covers a bigger CTA
|
||||
tile with a single ``cute.gemm`` call. Each entry in the repeat tuple
|
||||
corresponds to one **warpgroup** (128 threads / 4 warps), so ``(2, 1, 1)``
|
||||
uses two warpgroups — the standard configuration for large Hopper tiles:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
atom = cute.make_mma_atom(op) # op shape: (64, 128, 16)
|
||||
tiled_mma = cute.make_tiled_mma(
|
||||
atom,
|
||||
atom_layout_mnk=(2, 1, 1), # 2 warpgroups in M
|
||||
)
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
WGMMA Atom make_tiled_mma(atom, (2, 1, 1))
|
||||
+---------------+ +----------------+
|
||||
| | | | ^
|
||||
| 64 x 128 | | Atom (0,0,0) | |
|
||||
| x 16 | --(2,1,1)--> | 64 x 128 | | 2 x M_atom
|
||||
| | repeat | x 16 | | = 128
|
||||
| | | [Warpgroup 0] | |
|
||||
+---------------+ +----------------+ |
|
||||
| | |
|
||||
| Atom (1,0,0) | |
|
||||
| 64 x 128 | |
|
||||
| x 16 | |
|
||||
| [Warpgroup 1] | v
|
||||
+----------------+
|
||||
<-- N_atom = 128 -->
|
||||
K unchanged = 16
|
||||
|
||||
The Hopper dense GEMM examples
|
||||
(``examples/cute/hopper/kernel/dense_gemm/dense_gemm.py``) use this pattern.
|
||||
The helper ``sm90_utils.make_trivial_tiled_mma(...)`` selects the repeat count
|
||||
automatically:
|
||||
|
||||
- ``atom_layout_mnk = (2, 1, 1)`` when both ``tile_M > 64`` and
|
||||
``tile_N > 128`` (two warpgroups reduce register pressure).
|
||||
- ``atom_layout_mnk = (1, 1, 1)`` otherwise (a single warpgroup suffices).
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import cutlass.utils.hopper_helpers as sm90_utils
|
||||
|
||||
tiled_mma = sm90_utils.make_trivial_tiled_mma(
|
||||
a_dtype,
|
||||
b_dtype,
|
||||
a_major_mode,
|
||||
b_major_mode,
|
||||
acc_dtype,
|
||||
atom_layout_mnk=(2, 1, 1),
|
||||
tiler_mn=(64, 128), # atom instruction shape (M, N)
|
||||
)
|
||||
|
||||
Custom tile permutation with ``permutation_mnk``
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
``make_tiled_mma`` also accepts an optional ``permutation_mnk`` argument that
|
||||
controls how the tiled atom footprint is laid out across M, N, and K. At a
|
||||
high level:
|
||||
|
||||
- ``atom_layout_mnk`` tells CuTe how many atoms (warpgroups) to replicate.
|
||||
- ``permutation_mnk`` tells CuTe how the final tiled footprint is ordered.
|
||||
|
||||
``permutation_mnk`` is a tuple of layouts or integers that represent the
|
||||
tile size and ordering of values along each dimension. When a mode's
|
||||
permutation size is larger than the atom layout's natural coverage
|
||||
(``atom_layout x inst_shape``), each warpgroup receives additional values
|
||||
to fill the extended region — the warpgroup count stays the same, but each
|
||||
warpgroup holds more data.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
atom = cute.make_mma_atom(op) # op shape: (64, 128, 16)
|
||||
tiled_mma = cute.make_tiled_mma(
|
||||
atom,
|
||||
atom_layout_mnk=(2, 1, 1),
|
||||
permutation_mnk=(128, 256, 16), # extend N from 128 to 256
|
||||
)
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
Without permutation — natural atom coverage (M = 128, N = 128):
|
||||
|
||||
C tile (M=128, N=128)
|
||||
+----------------+
|
||||
| | ^
|
||||
| [Warpgroup 0] | |
|
||||
| 64 x 128 | | 2 x inst_M
|
||||
| | | = 128
|
||||
+----------------+ |
|
||||
| | |
|
||||
| [Warpgroup 1] | |
|
||||
| 64 x 128 | |
|
||||
| | v
|
||||
+----------------+
|
||||
<--- N = 128 --->
|
||||
(each warpgroup owns one (64, 128) atom)
|
||||
|
||||
With permutation_mnk = (128, 256, 16) — N extended to 256:
|
||||
|
||||
C tile (M=128, N=256)
|
||||
+----------------+----------------+
|
||||
| | | ^ N = 128 → 256:
|
||||
| [Warpgroup 0] | [Warpgroup 0] | | atom pattern repeats
|
||||
| 64 x 128 | 64 x 128 | | along N. Each warpgroup
|
||||
| | | | now holds 2x the values
|
||||
+----------------+----------------+ | along N (same threads,
|
||||
| | | | more data).
|
||||
| [Warpgroup 1] | [Warpgroup 1] | |
|
||||
| 64 x 128 | 64 x 128 | |
|
||||
| | | v
|
||||
+----------------+----------------+
|
||||
<------------ N = 256 ------------>
|
||||
| atom coverage | value repeat |
|
||||
|
||||
**Why WGMMA typically does not need permutation_mnk:** The WGMMA
|
||||
instruction already has a large N dimension (64, 128, or 256), so the natural
|
||||
atom coverage is wide enough that no permutation is needed to align with SMEM
|
||||
swizzle widths. The Hopper
|
||||
dense GEMM examples (``dense_gemm.py``, ``dense_gemm_persistent.py``) use
|
||||
``atom_layout_mnk`` alone without ``permutation_mnk``.
|
||||
|
||||
When ``permutation_mnk`` is not provided (default), the tile ordering is
|
||||
sequential and no permutation is applied.
|
||||
|
||||
|
||||
Partitioning Tensors
|
||||
---------------------
|
||||
|
||||
Before computing, partition the CTA-tiled tensors according to the
|
||||
tiled MMA layout. WGMMA partitioning is **warpgroup-oriented**: each
|
||||
warpgroup (128 threads / 4 warps) receives its own slice of the CTA
|
||||
tile, sized to match the SMEM descriptors and register accumulators
|
||||
that the WGMMA instruction expects.
|
||||
|
||||
**2-warpgroup example**
|
||||
|
||||
``GEMM (M, N, K) = (512, 768, 256)``, ``tile_shape_mnk = (128, 256, 64)``,
|
||||
F16 WGMMA atom = (64, 256, 16), ``atom_layout_mnk = (2, 1, 1)``,
|
||||
``num_stages = 4``, 2 warpgroups = 256 threads.
|
||||
|
||||
Global matrices:
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
mA: (M, K) = (512, 256) mB: (N, K) = (768, 256) mC: (M, N) = (512, 768)
|
||||
|
||||
K=256 K=256 N=768
|
||||
|<--------->| |<--------->| |<----------------->|
|
||||
+-----------+ +-----------+ +---+---+---+-------+
|
||||
| | ^ | | ^ | | | | | ^
|
||||
| mA | | M=512 | mB | | N=768 | | | | | | M=512
|
||||
| | v | | v | | | | | v
|
||||
+-----------+ +-----------+ +---+---+---+-------+
|
||||
|
||||
Tiling with ``tile_shape_mnk = (BM, BN, BK) = (128, 256, 64)`` gives
|
||||
M/BM = 4 tiles, N/BN = 3 tiles, K/BK = 4 tiles:
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
mA tiled into (M/BM x K/BK) mB tiled into (N/BN x K/BK) mC tiled into (M/BM x N/BN)
|
||||
= (4 x 4) blocks = (3 x 4) blocks = (4 x 3) blocks
|
||||
|
||||
BK=64 x4 BK=64 x4 BN=256 x3
|
||||
|<--->| |<--->| |<------>|
|
||||
+-----+-----+-----+-----+ +-----+-----+-----+-----+ +--------+--------+--------+
|
||||
| | | | | ^ | | | | | ^ | (0,0) | (0,1) | (0,2) | ^
|
||||
| | | | | |128 | | | | | |256 | | | | |128
|
||||
+-----+-----+-----+-----+ v +-----+-----+-----+-----+ v +--------+--------+--------+ v
|
||||
| | | | | ^ | | | | | ^ | (1,0) | (1,1) | (1,2) | ^
|
||||
| | | | | |128 | | | | | |256 | | | | |128
|
||||
+-----+-----+-----+-----+ v +-----+-----+-----+-----+ v +--------+--------+--------+ v
|
||||
| | | | | | | | | | | (2,0) | (2,1) | (2,2) |
|
||||
+-----+-----+-----+-----+ +-----+-----+-----+-----+ +--------+--------+--------+
|
||||
| | | | | | (3,0) | (3,1) | (3,2) |
|
||||
+-----+-----+-----+-----+ +--------+--------+--------+
|
||||
|
||||
Each CTA picks one (M-tile, N-tile) coordinate.
|
||||
For example, CTA at ``tile_coord = (1, 0, :)``.
|
||||
|
||||
After ``local_tile`` — one CTA's tile (``k = K/BK = 256/64 = 4``):
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
gA: (BM, BK, k) = (128, 64, 4) gB: (BN, BK, k) = (256, 64, 4) gC: (BM, BN) = (128, 256)
|
||||
|
||||
BK=64 BK=64 BN=256
|
||||
|<----->| |<----->| |<--------->|
|
||||
+-------+-- +-------+-- +-----------+
|
||||
| |.. | |.. | | ^
|
||||
BM= | gA | k=4 BN= | gB | k=4 BM= | gC | | 128
|
||||
128 | | 256 | | 128 | | v
|
||||
+-------+ +-------+ +-----------+
|
||||
|
||||
SMEM tensors ``sA`` and ``sB`` include a pipeline staging dimension:
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
sA: (BM, BK, PIPE) = (128, 64, 4) sB: (BN, BK, PIPE) = (256, 64, 4)
|
||||
|
||||
``get_slice(warp_group_thread_layout(warp_group_idx))`` — each
|
||||
warpgroup receives its slice of the tiled MMA footprint.
|
||||
With ``atom_layout_mnk = (2, 1, 1)`` and inst shape ``(64, 256, 16)``,
|
||||
the tiled MMA covers ``(2x64, 1x256, 16) = (128, 256, 16)`` which
|
||||
exactly matches the CTA tile in M and N. Each warpgroup owns one
|
||||
64-row slice of M:
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
sA (one pipeline stage, BM=128, BK=64):
|
||||
|
||||
Warpgroup 0's slice Warpgroup 1's slice
|
||||
inst_K inst_K inst_K inst_K
|
||||
=16 =16 =16 =16
|
||||
|<--->|<--->|<--->|<--->| |<--->|<--->|<--->|<--->|
|
||||
+-----+-----+-----+-----+ ^ +-----+-----+-----+-----+ ^
|
||||
| 0 | 1 | 2 | 3 | |64 | 0 | 1 | 2 | 3 | |64
|
||||
+-----+-----+-----+-----+ v +-----+-----+-----+-----+ v
|
||||
|<-- MMA_K = BK/inst_K = 4 -->| |<-- MMA_K = 4 ---------->|
|
||||
MMA_M = 64/64 = 1 MMA_M = 64/64 = 1
|
||||
|
||||
gC (BM=128, BN=256):
|
||||
|
||||
+---------------------------+ ^
|
||||
| Warpgroup 0: 64 x 256 | | 64
|
||||
| | |
|
||||
+---------------------------+ v
|
||||
| Warpgroup 1: 64 x 256 | ^
|
||||
| | | 64
|
||||
+---------------------------+ v
|
||||
<--------- N = 256 -------->
|
||||
MMA_M = 64/64 = 1, MMA_N = 256/256 = 1
|
||||
|
||||
After partition (per warpgroup):
|
||||
|
||||
- ``tCsA: (MMA, MMA_M, MMA_K, PIPE) = (MMA, 1, 4, 4)`` — MMA_M = BM / (atom_M x inst_M) = 128 / (2x64) = 1, MMA_K = BK / inst_K = 64 / 16 = 4
|
||||
- ``tCsB: (MMA, MMA_N, MMA_K, PIPE) = (MMA, 1, 4, 4)`` — MMA_N = BN / (atom_N x inst_N) = 256 / (1x256) = 1, MMA_K = 4
|
||||
- ``tCgC: (MMA, MMA_M, MMA_N) = (MMA, 1, 1)`` — MMA_M = 1, MMA_N = 1
|
||||
|
||||
The first mode ``MMA`` contains the atom's **thread x value** layout — it
|
||||
encodes which registers within a warpgroup hold which matrix elements.
|
||||
The remaining modes are repeat counts that tile the atom across the
|
||||
full CTA tile.
|
||||
|
||||
.. note:: Because the WGMMA instruction shape is large (64 x {64..256}),
|
||||
the tiled MMA footprint typically covers the entire CTA tile in M and N
|
||||
with just one or two warpgroups. This means MMA_M and MMA_N are often 1.
|
||||
The MMA_K dimension is where the repeat count is non-trivial (BK / inst_K
|
||||
iterations per pipeline stage).
|
||||
|
||||
**1-warpgroup example (contrast)**
|
||||
|
||||
For a smaller tile ``(128, 128, 64)`` with ``atom_layout_mnk = (1, 1, 1)``,
|
||||
inst shape ``(64, 128, 16)``, and ``num_stages = 4``,
|
||||
the tiled MMA covers only ``(64, 128, 16)``.
|
||||
Now a single warpgroup must iterate over two atom-blocks along M:
|
||||
|
||||
- ``tCsA: (MMA, MMA_M, MMA_K, PIPE) = (MMA, 2, 4, 4)`` — MMA_M = 128 / (1x64) = 2
|
||||
- ``tCsB: (MMA, MMA_N, MMA_K, PIPE) = (MMA, 1, 4, 4)`` — MMA_N = 128 / (1x128) = 1
|
||||
- ``tCgC: (MMA, MMA_M, MMA_N) = (MMA, 2, 1)``
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Based on examples/cute/hopper/kernel/dense_gemm/dense_gemm.py
|
||||
@cute.kernel
|
||||
def kernel(tiled_mma: cute.TiledMma, ...):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
|
||||
# CTA-tiled global tensors
|
||||
gA_mkl = cute.local_tile(
|
||||
mA_mkl, tile_shape_mnk, tile_coord_mnkl, proj=(1, None, 1)
|
||||
)
|
||||
gB_nkl = cute.local_tile(
|
||||
mB_nkl, tile_shape_mnk, tile_coord_mnkl, proj=(None, 1, 1)
|
||||
)
|
||||
gC_mnl = cute.local_tile(
|
||||
mC_mnl, tile_shape_mnk, tile_coord_mnkl, proj=(1, 1, None)
|
||||
)
|
||||
|
||||
# Warpgroup-oriented slicing (128 threads per warpgroup)
|
||||
warp_group_idx = cute.arch.make_warp_uniform(
|
||||
tidx // num_threads_per_warp_group # 128
|
||||
)
|
||||
warp_group_thread_layout = cute.make_layout(
|
||||
mma_warp_groups, # e.g. 2
|
||||
stride=num_threads_per_warp_group, # 128
|
||||
)
|
||||
thr_mma = tiled_mma.get_slice(
|
||||
warp_group_thread_layout(warp_group_idx)
|
||||
)
|
||||
|
||||
# Partition C from global
|
||||
tCgC = thr_mma.partition_C(gC_mnl) # (MMA, MMA_M, MMA_N)
|
||||
|
||||
# Partition A/B from staged SMEM
|
||||
tCsA = thr_mma.partition_A(sA) # (MMA, MMA_M, MMA_K, PIPE)
|
||||
tCsB = thr_mma.partition_B(sB) # (MMA, MMA_N, MMA_K, PIPE)
|
||||
|
||||
|
||||
Pre and Post-Conditions for Partitioning
|
||||
-----------------------------------------
|
||||
|
||||
* The inputs of ``partition_A``, ``partition_B``, and ``partition_C`` should be
|
||||
at least rank-2 tensors.
|
||||
* The output layout is constrained by the selected MMA atom:
|
||||
|
||||
- For A, the output has layout ``(MMA, MMA_M, MMA_K, ...)``.
|
||||
- For B, the output has layout ``(MMA, MMA_N, MMA_K, ...)``.
|
||||
- For C, the output has layout ``(MMA, MMA_M, MMA_N, ...)``.
|
||||
|
||||
* Partitioning reasons about layout, not memory space or element type.
|
||||
When ``a_src=OperandSource.RMEM``, the same tiled MMA shape still
|
||||
determines the logical A footprint, but A is materialized as a register
|
||||
fragment rather than a shared-memory descriptor.
|
||||
|
||||
|
||||
Making Fragments
|
||||
-----------------
|
||||
|
||||
Fragments are the tensors that the WGMMA instruction operates on. For dense
|
||||
WGMMA:
|
||||
|
||||
- **Fragment A**: an SMEM descriptor when ``a_src=OperandSource.SMEM``, or an
|
||||
RMEM register fragment when ``a_src=OperandSource.RMEM``.
|
||||
- **Fragment B**: an SMEM descriptor pointing into staged shared memory buffers.
|
||||
- **Fragment C (accumulator)**: an RMEM tensor that serves as both the input C
|
||||
and output D of ``cute.gemm()``.
|
||||
|
||||
WGMMA fragments for A and B are **SMEM descriptors** — the hardware reads
|
||||
directly from shared memory. There is no explicit SMEM → RMEM copy step for
|
||||
operands A and B. The accumulator, however, still lives in per-thread
|
||||
registers (RMEM).
|
||||
|
||||
Creating fragment descriptors and accumulator fragments
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Fragment creation has two parts:
|
||||
|
||||
**1. A and B fragment descriptors**
|
||||
|
||||
``make_fragment_A`` and ``make_fragment_B`` take the MMA-partitioned SMEM
|
||||
views (``tCsA`` / ``tCsB``) and produce descriptor tensors that the WGMMA
|
||||
instruction consumes. Each descriptor points to one tile within a pipeline
|
||||
stage in shared memory.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# MMA-partitioned SMEM views (see "Partitioning Tensors")
|
||||
tCsA = thr_mma.partition_A(sA) # (MMA, MMA_M, MMA_K, PIPE)
|
||||
tCsB = thr_mma.partition_B(sB) # (MMA, MMA_N, MMA_K, PIPE)
|
||||
|
||||
# SMEM descriptor fragments consumed by cute.gemm()
|
||||
tCrA = tiled_mma.make_fragment_A(tCsA) # (MMA, MMA_M, MMA_K, PIPE)
|
||||
tCrB = tiled_mma.make_fragment_B(tCsB) # (MMA, MMA_N, MMA_K, PIPE)
|
||||
|
||||
Continuing the 2-warpgroup example from `Partitioning Tensors`_
|
||||
(F16 atom = (64, 256, 16), ``tile_shape_mnk = (128, 256, 64)``,
|
||||
``atom_layout_mnk = (2, 1, 1)``, ``num_stages = 4``):
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
tCsA: (MMA, MMA_M=1, MMA_K=4, PIPE=4)
|
||||
tCsB: (MMA, MMA_N=1, MMA_K=4, PIPE=4)
|
||||
|
||||
make_fragment_A(tCsA) -> tCrA: (MMA, 1, 4, 4)
|
||||
make_fragment_B(tCsB) -> tCrB: (MMA, 1, 4, 4)
|
||||
|
||||
Each element of tCrA/tCrB is an SMEM descriptor — one per
|
||||
(MMA_K, PIPE) pair. The hardware reads SMEM directly via the
|
||||
descriptor; no explicit SMEM -> RMEM load is needed.
|
||||
|
||||
tCrA per warpgroup (4 pipeline stages, 4 K-blocks each):
|
||||
|
||||
|<-- MMA_K = BK/inst_K = 4 -->|
|
||||
stage 0: +------+------+------+------+
|
||||
| k=0 | k=1 | k=2 | k=3 | inst_M=64 (MMA_M=1)
|
||||
+------+------+------+------+
|
||||
stage 1: +------+------+------+------+
|
||||
| k=0 | k=1 | k=2 | k=3 | inst_M=64
|
||||
+------+------+------+------+
|
||||
stage 2: +------+------+------+------+
|
||||
| k=0 | k=1 | k=2 | k=3 | inst_M=64
|
||||
+------+------+------+------+
|
||||
stage 3: +------+------+------+------+
|
||||
| k=0 | k=1 | k=2 | k=3 | inst_M=64
|
||||
+------+------+------+------+
|
||||
|
||||
Similarly for tCrB with shape (MMA, MMA_N=1, MMA_K=4, PIPE=4).
|
||||
|
||||
.. note:: WGMMA fragments for A and B are SMEM descriptors — the hardware
|
||||
reads SMEM directly, so there is no ``ldmatrix`` retiling step required
|
||||
before ``cute.gemm()``.
|
||||
|
||||
**When A comes from registers (``OperandSource.RMEM``)**
|
||||
|
||||
In fused kernels, the output of one MMA can become the A operand of the
|
||||
next. The second ``TiledMma`` is created with
|
||||
``a_src=OperandSource.RMEM``, and ``make_fragment_A`` is **not** used.
|
||||
Instead:
|
||||
|
||||
1. The accumulator's C layout ``(MMA, MMA_M, MMA_N)`` is converted to the
|
||||
A layout ``(MMA, MMA_M, MMA_K)`` expected by the second ``TiledMma``.
|
||||
2. The accumulator values are type-converted and stored into an RMEM tensor
|
||||
with the A layout.
|
||||
3. The resulting RMEM tensor is passed directly to ``cute.gemm()`` as the A
|
||||
operand — no SMEM descriptor is involved.
|
||||
|
||||
See the Hopper FMHA example (``examples/cute/hopper/kernel/attention/fmha.py``) for the complete pattern.
|
||||
|
||||
**2. C fragment (accumulator)**
|
||||
|
||||
The accumulator lives in per-thread registers (RMEM). Its shape is derived
|
||||
from the partitioned C layout. The accumulator starts at zero before the K
|
||||
loop and is updated in-place by each ``cute.gemm()`` call.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Partition C from global (see "Partitioning Tensors")
|
||||
tCgC = thr_mma.partition_C(gC_mnl) # (MMA, MMA_M, MMA_N)
|
||||
|
||||
# Allocate RMEM accumulator with the same shape
|
||||
acc_shape = tCgC.shape
|
||||
acc = cute.make_rmem_tensor(acc_shape, cutlass.Float32)
|
||||
acc.fill(0.0)
|
||||
|
||||
For the same running example:
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
tCgC: (MMA, MMA_M=1, MMA_N=1)
|
||||
|
||||
make_rmem_tensor(tCgC.shape, Float32) -> acc: (MMA, 1, 1)
|
||||
|
||||
The accumulator stays in RMEM for the entire main loop.
|
||||
cute.gemm() reads A/B from SMEM descriptors and accumulates into acc.
|
||||
|
||||
+-----------------------------------+
|
||||
| acc: (MMA, 1, 1) in RMEM |
|
||||
| 64 x 256 elements per warpgroup |
|
||||
| Float32 |
|
||||
+-----------------------------------+
|
||||
|
||||
|
||||
Creating SMEM layouts for A and B
|
||||
----------------------------------
|
||||
|
||||
The SMEM layouts define how A and B tiles are staged in shared memory,
|
||||
including swizzling for bank-conflict-free descriptor access. The helper
|
||||
functions in ``cutlass.utils.hopper_helpers`` handle the details.
|
||||
|
||||
**Host side** (``@cute.jit``):
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import cutlass.utils.hopper_helpers as sm90_utils
|
||||
|
||||
# Create SMEM layouts (includes swizzle + staging)
|
||||
a_smem_layout = sm90_utils.make_smem_layout_a(
|
||||
a_layout, # LayoutEnum — row-major or col-major
|
||||
tile_shape_mnk, # CTA tile (M, N, K)
|
||||
a_dtype, # element type (e.g. Float16)
|
||||
num_stages, # pipeline depth
|
||||
)
|
||||
b_smem_layout = sm90_utils.make_smem_layout_b(
|
||||
b_layout,
|
||||
tile_shape_mnk,
|
||||
b_dtype,
|
||||
num_stages,
|
||||
)
|
||||
epi_smem_layout = sm90_utils.make_smem_layout_epi(
|
||||
c_dtype,
|
||||
c_layout,
|
||||
epi_tile,
|
||||
epi_stage,
|
||||
)
|
||||
|
||||
``make_smem_layout_a`` and ``make_smem_layout_b`` are convenience helpers that
|
||||
build a complete, staged SMEM layout in four steps:
|
||||
|
||||
1. **Extract the operand tile shape.** For A the ``(M, K)`` portion of
|
||||
``tile_shape_mnk`` is kept via ``cute.slice_``; for B the ``(N, K)``
|
||||
portion.
|
||||
|
||||
2. **Determine the major mode.** The major mode (K-major or MN-major) is read
|
||||
from the layout enum (``a_layout.is_k_major_a()``). The major-mode
|
||||
dimension size is used for swizzle selection.
|
||||
|
||||
3. **Select and materialise the swizzle atom.** A heuristic
|
||||
(``get_smem_layout_atom``) picks the widest swizzle whose contiguous
|
||||
size (in bits) evenly divides the major-mode dimension:
|
||||
|
||||
+------------+-----------------+
|
||||
| Swizzle | Contiguous bits |
|
||||
+============+=================+
|
||||
| SW128 | 1024 (128 B) |
|
||||
+------------+-----------------+
|
||||
| SW64 | 512 (64 B) |
|
||||
+------------+-----------------+
|
||||
| SW32 | 256 (32 B) |
|
||||
+------------+-----------------+
|
||||
| Interleave | 128 (16 B) |
|
||||
+------------+-----------------+
|
||||
|
||||
``make_smem_layout_atom`` then combines the chosen swizzle with a compact
|
||||
outer layout into a ``ComposedLayout(swizzle, outer)``.
|
||||
|
||||
4. **Tile to the operand shape and append the staging dimension.**
|
||||
``cute.tile_to_shape`` broadcasts the atom to the full ``(M_or_N, K)``
|
||||
shape with ``num_stages`` appended. The ``order`` argument controls which
|
||||
dimension is contiguous: ``(0, 1, 2)`` for K-major (K innermost),
|
||||
``(1, 0, 2)`` for MN-major (MN innermost).
|
||||
|
||||
For the running F16 example (``tile_shape_mnk = (128, 256, 64)``,
|
||||
``num_stages = 4``, K-major A, K-major B):
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
A operand (K-major, tile = (M=128, K=64)):
|
||||
major_mode_size = 64
|
||||
64 * 16 bits = 1024 bits → SW128
|
||||
atom = make_smem_layout_atom(K_SW128, Float16)
|
||||
tile_to_shape(atom, (128, 64, 4), order=(0,1,2))
|
||||
-> a_smem_layout: ComposedLayout with shape (128, 64, 4)
|
||||
|
||||
B operand (K-major, tile = (N=256, K=64)):
|
||||
major_mode_size = 64
|
||||
64 * 16 bits = 1024 bits → SW128
|
||||
atom = make_smem_layout_atom(K_SW128, Float16)
|
||||
tile_to_shape(atom, (256, 64, 4), order=(0,1,2))
|
||||
-> b_smem_layout: ComposedLayout with shape (256, 64, 4)
|
||||
|
||||
**Kernel side** (``@cute.kernel``):
|
||||
|
||||
The layout and swizzle are passed to shared-memory allocation. The result
|
||||
is a ``ComposedLayout`` whose ``.outer`` is the logical layout and ``.inner``
|
||||
is the swizzle:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Based on examples/cute/hopper/kernel/dense_gemm/dense_gemm.py
|
||||
sA = storage.sA.get_tensor(
|
||||
a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner
|
||||
)
|
||||
sB = storage.sB.get_tensor(
|
||||
b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner
|
||||
)
|
||||
|
||||
After allocation:
|
||||
|
||||
- ``sA`` has shape ``(BM, BK, PIPE) = (128, 64, 4)``.
|
||||
- ``sB`` has shape ``(BN, BK, PIPE) = (256, 64, 4)``.
|
||||
|
||||
These are the staged SMEM tensors consumed by ``partition_A`` /
|
||||
``partition_B`` and ``make_fragment_A`` / ``make_fragment_B``
|
||||
(see `Making Fragments`_).
|
||||
|
||||
.. note:: If you need finer control, you can build layout atoms directly with
|
||||
``cute.nvgpu.warpgroup.make_smem_layout_atom(...)`` and compose the final
|
||||
SMEM layout manually via ``cute.tile_to_shape``.
|
||||
|
||||
|
||||
Executing the GEMM (Main Loop)
|
||||
-------------------------------
|
||||
|
||||
The main loop iterates over K-tiles. The WGMMA-specific part of each
|
||||
iteration is the **fence / gemm / commit / wait** sequence:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
acc.fill(0.0)
|
||||
tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True)
|
||||
|
||||
for k_tile in cutlass.range(k_pipe_mmas, k_tile_cnt, 1, unroll=1):
|
||||
# ... wait for TMA load (pipeline details in dense_gemm.py) ...
|
||||
|
||||
cute.nvgpu.warpgroup.fence()
|
||||
tile_crd = (None, None, None, consumer_read.index)
|
||||
cute.gemm(tiled_mma, acc, tCrA[tile_crd], tCrB[tile_crd], acc)
|
||||
cute.nvgpu.warpgroup.commit_group()
|
||||
cute.nvgpu.warpgroup.wait_group(k_pipe_mmas)
|
||||
|
||||
# ... release buffer & advance pipeline (see dense_gemm.py) ...
|
||||
|
||||
cute.nvgpu.warpgroup.wait_group(0)
|
||||
|
||||
Key points:
|
||||
|
||||
- ``fence()`` orders prior SMEM writes before WGMMA issue.
|
||||
- ``commit_group()`` publishes queued WGMMA instructions as a group.
|
||||
- ``wait_group(n)`` waits until at most ``n`` groups remain in flight.
|
||||
``wait_group(0)`` after the loop drains all work before the epilogue.
|
||||
- ``Field.ACCUMULATE`` — ``True`` accumulates (``D += A*B``),
|
||||
``False`` overwrites (``D = A*B``). The dense GEMM sets ``True`` and
|
||||
zero-fills ``acc`` so the first iteration computes ``0 + A*B``.
|
||||
|
||||
|
||||
Complete Workflow
|
||||
------------------
|
||||
|
||||
Putting it all together, a typical Hopper WGMMA GEMM has this structure.
|
||||
The MMA-relevant steps are highlighted; see ``dense_gemm.py`` for the full
|
||||
kernel including TMA, pipeline, and epilogue details.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cute.nvgpu import OperandMajorMode
|
||||
import cutlass.cute.nvgpu.warpgroup as warpgroup
|
||||
import cutlass.utils.hopper_helpers as sm90_utils
|
||||
|
||||
# --- Host side (@cute.jit) ---
|
||||
|
||||
# 1. MMA op + tiled MMA
|
||||
op = warpgroup.MmaF16BF16Op(
|
||||
cutlass.Float16, cutlass.Float32, (64, 128, 16),
|
||||
warpgroup.OperandSource.SMEM, OperandMajorMode.K, OperandMajorMode.K,
|
||||
)
|
||||
tiled_mma = cute.make_tiled_mma(op)
|
||||
|
||||
# 2. SMEM layouts
|
||||
a_smem_layout = sm90_utils.make_smem_layout_a(a_layout, tile_shape_mnk, a_dtype, num_stages)
|
||||
b_smem_layout = sm90_utils.make_smem_layout_b(b_layout, tile_shape_mnk, b_dtype, num_stages)
|
||||
|
||||
# 3. TMA copy atoms + kernel launch (see dense_gemm.py)
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# --- Kernel side (@cute.kernel) ---
|
||||
|
||||
# 4. Allocate SMEM
|
||||
smem = cutlass.utils.SmemAllocator()
|
||||
storage = smem.allocate(SharedStorage)
|
||||
sA = storage.sA.get_tensor(
|
||||
a_smem_layout.outer, swizzle=a_smem_layout.inner) # (BM, BK, PIPE)
|
||||
sB = storage.sB.get_tensor(
|
||||
b_smem_layout.outer, swizzle=b_smem_layout.inner) # (BN, BK, PIPE)
|
||||
|
||||
# 5. CTA-tiled global tensors
|
||||
gA_mkl = cute.local_tile(mA_mkl, tile_shape_mnk, tile_coord, proj=(1, None, 1))
|
||||
gB_nkl = cute.local_tile(mB_nkl, tile_shape_mnk, tile_coord, proj=(None, 1, 1))
|
||||
gC_mnl = cute.local_tile(mC_mnl, tile_shape_mnk, tile_coord, proj=(1, 1, None))
|
||||
|
||||
# 6. Warpgroup slice, partition & make fragments
|
||||
warp_group_idx = cute.arch.make_warp_uniform(tidx // num_threads_per_warp_group)
|
||||
warp_group_thread_layout = cute.make_layout(mma_warp_groups, stride=num_threads_per_warp_group)
|
||||
thr_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx))
|
||||
|
||||
tCsA = thr_mma.partition_A(sA) # (MMA, MMA_M, MMA_K, PIPE)
|
||||
tCsB = thr_mma.partition_B(sB) # (MMA, MMA_N, MMA_K, PIPE)
|
||||
tCrA = tiled_mma.make_fragment_A(tCsA) # SMEM descriptor
|
||||
tCrB = tiled_mma.make_fragment_B(tCsB) # SMEM descriptor
|
||||
tCgC = thr_mma.partition_C(gC_mnl) # (MMA, MMA_M, MMA_N)
|
||||
acc = cute.make_rmem_tensor(tCgC.shape, acc_dtype)
|
||||
|
||||
# 7. TMA pipeline setup + prefetch (see dense_gemm.py)
|
||||
|
||||
# 8. Main loop — fence / gemm / commit / wait
|
||||
acc.fill(0.0)
|
||||
tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True)
|
||||
|
||||
for k_tile in cutlass.range(k_pipe_mmas, k_tile_cnt, 1, unroll=1):
|
||||
# ... wait for TMA load ...
|
||||
cute.nvgpu.warpgroup.fence()
|
||||
tile_crd = (None, None, None, consumer_read.index)
|
||||
cute.gemm(tiled_mma, acc, tCrA[tile_crd], tCrB[tile_crd], acc)
|
||||
cute.nvgpu.warpgroup.commit_group()
|
||||
cute.nvgpu.warpgroup.wait_group(k_pipe_mmas)
|
||||
# ... release buffer, advance pipeline ...
|
||||
|
||||
cute.nvgpu.warpgroup.wait_group(0)
|
||||
|
||||
# 9. Epilogue: RMEM → SMEM (stmatrix) → GMEM (TMA store)
|
||||
# ... (see dense_gemm.py)
|
||||
|
||||
|
||||
.. Beyond Simple Dense MMAs
|
||||
.. ------------------------
|
||||
|
||||
.. The current Python DSL coverage for warpgroup MMA is centered on the three
|
||||
.. dense ops above. PTX also defines additional WGMMA instruction families that
|
||||
.. do **not** yet have DSL op classes. These are tracked in the source at
|
||||
.. ``cutlass/cute/nvgpu/warpgroup/mma.py`` (marked ``✗`` in the instruction
|
||||
.. table).
|
||||
|
||||
.. **Structured-sparse WGMMA** (``wgmma.mma_async.sp``)
|
||||
|
||||
.. 2:4 structured sparsity in operand A: out of every 4 consecutive K-elements,
|
||||
.. exactly 2 are non-zero. The instruction K is **doubled** relative to the
|
||||
.. dense counterpart (e.g. ``m64nNk32`` for F16/BF16 vs ``m64nNk16`` dense)
|
||||
.. because A is stored in compressed form. Supported data types include
|
||||
.. F16/BF16, TF32, FP8, and INT8.
|
||||
|
||||
.. Compared to the dense workflow, a sparse kernel would add:
|
||||
|
||||
.. - A **compressed A tensor** storing only the non-zero values (half the
|
||||
.. K-elements), and a **metadata tensor E** encoding which 2 of 4 positions
|
||||
.. are non-zero.
|
||||
.. - Extra SMEM layouts, TMA loads, and allocations for both the compressed A
|
||||
.. and the metadata E.
|
||||
.. - A metadata staging step each K-tile (SMEM to the MMA instruction).
|
||||
|
||||
.. Once DSL support is added, the same fence/commit/wait workflow described in
|
||||
.. this guide applies, with the additional metadata operand.
|
||||
|
||||
.. **Dense TF32 WGMMA** (``m64nNk8``)
|
||||
|
||||
.. TF32 (19-bit truncated FP32) inputs with FP32 accumulator. The instruction
|
||||
.. K = 8 is smaller than F16's K = 16, so MMA_K repeat counts are larger for
|
||||
.. the same BK tile size. Otherwise the workflow is identical to the dense
|
||||
.. F16/BF16 path — the same SMEM layout, descriptor, and fence/commit/wait
|
||||
.. pattern applies.
|
||||
|
||||
.. **Dense B1 WGMMA** (``m64nNk256``)
|
||||
|
||||
.. 1-bit (binary) inputs with INT32 accumulator. The very large instruction
|
||||
.. K = 256 means each atom consumes 256 bits along K per operand, resulting in
|
||||
.. small MMA_K repeat counts. This is a niche instruction for binary neural
|
||||
.. networks.
|
||||
|
||||
|
||||
See also:
|
||||
|
||||
- Dense GEMM example: ``examples/cute/hopper/kernel/dense_gemm/dense_gemm.py``
|
||||
- Persistent GEMM example: ``examples/cute/hopper/kernel/dense_gemm/dense_gemm_persistent.py``
|
||||
- FMHA example (RMEM A path): ``examples/cute/hopper/kernel/attention/fmha.py``
|
||||
- Helper utilities: ``cutlass.utils.hopper_helpers``
|
||||
1358
media/docs/pythonDSL/mma_docs/wmma_programming.rst
Normal file
1358
media/docs/pythonDSL/mma_docs/wmma_programming.rst
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user