diff --git a/CHANGELOG.md b/CHANGELOG.md index 3aa3bf24d..f027e0d37 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,28 @@ # CUTLASS 4.x +## [4.5.1](https://github.com/NVIDIA/cutlass/releases/tag/v4.5.1) (2026-05-15) + +### CuTe DSL +* Bug fixing and improvements + - Fixed the following issues: + https://github.com/NVIDIA/cutlass/issues/3219 + https://github.com/NVIDIA/cutlass/issues/3218 + https://github.com/NVIDIA/cutlass/issues/3212 + https://github.com/NVIDIA/cutlass/issues/3210 + https://github.com/NVIDIA/cutlass/issues/3208 + https://github.com/NVIDIA/cutlass/issues/3201 + https://github.com/NVIDIA/cutlass/issues/3227 + - Fixed JAX int64 stride divisibility issue + - Fixed issues for SM120 blockscaled MMAs + - Added missing MXFP8MMAOP and MXF8F6F4MMAOP for SM120 + +### CUTLASS C++ +* Fix SM100 F8F6F4 SS MMA (1SM and 2SM) traits to use typed op templates. +* Add UE8M0 (uniform exponent distribution) initialization support in tensor fill utilities. +* Add `cvt.rn.bf16x2.e4m3x2` conversion instruction support to `numeric_conversion.h`. +* Update [example 93](https://github.com/NVIDIA/cutlass/tree/main/examples/93_blackwell_low_latency_gqa) with paged KV cache support for Blackwell low-latency GQA. + ## [4.5.0](https://github.com/NVIDIA/cutlass/releases/tag/v4.5.0) (2026-05-01) ### CuTe DSL @@ -20,7 +42,7 @@ - Improved source code correlation for profiling/debugging - Fixed an aarch64 segfault issue with tvm-ffi - Re-organization for CuTe DSL examples/tutorials for better discoverability - + * More examples of authorizing peak-performance kernels - MOE examles - A new style of grouped-gemm that aligns to torch's grouped_mm and scaled_groued_mm interface. 
diff --git a/README.md b/README.md index 0c7710c4d..28e2fe498 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,9 @@ ![ALT](./media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition") # Overview -# CUTLASS 4.5.0 +# CUTLASS 4.5.1 -_CUTLASS 4.5.0 - May 2026_ +_CUTLASS 4.5.1 - May 2026_ CUTLASS is a collection of abstractions for implementing high-performance matrix-matrix multiplication (GEMM) and related computations at all levels and scales within CUDA. It incorporates strategies for @@ -61,6 +61,17 @@ To get started quickly - please refer : - Improved source code correlation for profiling/debugging - Fixed an aarch64 segfault issue with tvm-ffi - Re-organization for CuTe DSL examples/tutorials for better discoverability + - Fixed the following issues: + https://github.com/NVIDIA/cutlass/issues/3219 + https://github.com/NVIDIA/cutlass/issues/3218 + https://github.com/NVIDIA/cutlass/issues/3212 + https://github.com/NVIDIA/cutlass/issues/3210 + https://github.com/NVIDIA/cutlass/issues/3208 + https://github.com/NVIDIA/cutlass/issues/3201 + https://github.com/NVIDIA/cutlass/issues/3227 + - Fixed JAX int64 stride divisibility issue + - Fixed issues for SM120 blockscaled MMAs + - Added missing MXFP8MMAOP and MXF8F6F4MMAOP for SM120 * More examples of authorizing peak-performance kernels - MOE examles @@ -90,11 +101,13 @@ To get started quickly - please refer : * Add [example 95](https://github.com/NVIDIA/cutlass/tree/main/examples/95_blackwell_gemm_green_context) to support green context SM partition - Enables launching GEMM on stream with partial SM allocation. * Add [Snake](https://github.com/NVIDIA/cutlass/blob/main/test/unit/epilogue/thread/activation.cu#L409) activation functor for EVT. +* Fix SM100 F8F6F4 SS MMA (1SM and 2SM) traits to use typed op templates. +* Add UE8M0 (uniform exponent distribution) initialization support in tensor fill utilities. 
+* Add `cvt.rn.bf16x2.e4m3x2` conversion instruction support to `numeric_conversion.h`. +* Update [example 93](https://github.com/NVIDIA/cutlass/tree/main/examples/93_blackwell_low_latency_gqa) with paged KV cache support for Blackwell low-latency GQA. * Fix some kernel issues: - Fix l2_capacity=0 handling in Blackwell SM100/SM120 kernel templates - Fix CUTLASS clang build issues - - Fix atomicCAS read-modify-write loop in `ConstSubbyteReference` - - Replace `__nv_atomic_load_n` with `volatile` for CUDA 11.4 compatibility in subbyte reference - Remove `PipelineStorage` shadowing in SM100 complex epilogue - Fix build issue in SM90 epilogue fusion visitor TMA warpspecialized * Fix some profiler issues: diff --git a/examples/77_blackwell_fmha/collective/fmha_common.hpp b/examples/77_blackwell_fmha/collective/fmha_common.hpp index 11e381b36..207056b05 100644 --- a/examples/77_blackwell_fmha/collective/fmha_common.hpp +++ b/examples/77_blackwell_fmha/collective/fmha_common.hpp @@ -75,19 +75,17 @@ CUTE_HOST_DEVICE constexpr auto to_tiled_mma_sm100_ts( TiledMMA, cute::C, - cute::integral_constant, - cute::integral_constant, - cute::integral_constant, - cute::integral_constant>, + SM100_MMA_F8F6F4_SS, TAs...>, TMs...>) { return TiledMMA>, + a_neg, b_neg, UMMA::Saturate::False>, TAs...>, TMs...>{}; } diff --git a/examples/93_blackwell_low_latency_gqa/CMakeLists.txt b/examples/93_blackwell_low_latency_gqa/CMakeLists.txt index 8f708db1e..8d034896b 100644 --- a/examples/93_blackwell_low_latency_gqa/CMakeLists.txt +++ b/examples/93_blackwell_low_latency_gqa/CMakeLists.txt @@ -31,6 +31,9 @@ if (NOT MSVC AND CUTLASS_NVCC_ARCHS MATCHES "100a|100f|103a|103f") cutlass_example_add_executable( 93_blackwell_low_latency_gqa tgv_gqa.cu + common.cuh + tgv_gqa.cuh + tgv_gqa_paged.cuh ) endif() diff --git a/examples/93_blackwell_low_latency_gqa/common.cuh b/examples/93_blackwell_low_latency_gqa/common.cuh new file mode 100644 index 000000000..7c32e7ebd --- /dev/null +++ 
b/examples/93_blackwell_low_latency_gqa/common.cuh @@ -0,0 +1,140 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cuda_runtime.h" + +#include + +#include +#include + +#include +#include + +#ifndef gpuErrChk +#define gpuErrChk(ans) { gpuAssert2((ans), __FILE__, __LINE__); } +inline void gpuAssert2(cudaError_t code, const char *file, int line, bool abort=true) { + if (code != cudaSuccess) { + fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line); + if (abort) exit(code); + } +} +#endif + +namespace TGV { + +using namespace cute; + +// Store value to remote shared memory in the cluster +CUTE_DEVICE void +store_shared_remote_f32(float value, uint32_t dsmem_addr, uint32_t remote_barrier_addr) { + asm volatile("st.async.shared::cluster.mbarrier::complete_tx::bytes.f32 [%0], %1, [%2];" + : : "r"(dsmem_addr), "f"(value), "r"(remote_barrier_addr)); +} + +// given a smem tensor, return the dsmem tensor for the given rank, the tensor addr is in smem addr space (not generic addr space) +template +CUTE_DEVICE auto +get_dsmem_tensor(Tensor tensor, int rank) { + using T = typename decltype(tensor)::value_type; + // tensor.data().get() is the smem addr in the generic addr space, in the generic addr space a region is reserved for smem + // doing ld/st to this region of the generic addr space will be converted into ld.shared/st.shared to the smem addr space by the compiler + // the mapa (and many inline ptx) instruction's input and output addr are in the smem/dsmem addr space, so we need to explicitly convert from generic to shared addr space + uint32_t smem_addr = __cvta_generic_to_shared(tensor.data().get()); // smem addr space + // mapa to get the dsmem addr of this tensor in another CTA + uint32_t dsmem_addr = set_block_rank(smem_addr, rank); // smem addr space + return make_tensor(make_smem_ptr((T*)dsmem_addr), tensor.layout()); +} + +// copied from SM100::TMEM::LOAD::copy_unpack cutlass/include/cute/atom/copy_traits_sm100.hpp +// 
what it does is given a tmem address, load the data into rmem tensor with the given tcgen05.ld copy op +template < + class CopyOp, + class TD, class DLayout> +CUTLASS_DEVICE void +tmem_load( + uint32_t tmem_addr, + Tensor& dst +) { + static_assert(is_rmem::value, "Expected RMEM dst."); + + using RegTypeDst = typename remove_extent::type; + Tensor rD = recast(dst); + + constexpr int RegNumDst = extent::value; + CUTE_STATIC_ASSERT_V(size(rD) == Int{}, + "The tcgen05.ld CopyOp's size does not match the destination tensor size."); + + detail::explode(CopyOp::copy, + &tmem_addr, seq<0>{}, + rD, make_seq{}); +} + +// copied from SM100::TMEM::STORE::copy_unpack cutlass/include/cute/atom/copy_traits_sm100.hpp +// what it does is given a tmem address, store the data in rmem tensor to the tmem address with the given tcgen05.st copy op +template < + class CopyOp, + class TS, class SLayout> +CUTLASS_DEVICE void +tmem_store( + Tensor& src, + uint32_t tmem_addr +) { + static_assert(is_rmem::value, "Expected RMEM src."); + + using RegTypeSrc = typename remove_extent::type; + Tensor rS = recast(src); + + constexpr int RegNumSrc = extent::value; + CUTE_STATIC_ASSERT_V(size(rS) == Int{}, + "The tcgen05.st CopyOp's size does not match the source tensor size."); + + detail::explode(CopyOp::copy, + rS, make_seq{}, + &tmem_addr, seq<0>{}); +} + +// issue cp.async to load 4 bytes (one int) from gmem to smem +CUTLASS_DEVICE void +cp_async( + int* gmem_addr, + int* smem_addr +) { + uint32_t smem_int_ptr = cute::cast_smem_ptr_to_uint(smem_addr); + asm volatile("cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" + :: "r"(smem_int_ptr), + "l"(gmem_addr), + "n"(sizeof(int))); +} + +} // namespace TGV diff --git a/examples/93_blackwell_low_latency_gqa/readme.md b/examples/93_blackwell_low_latency_gqa/readme.md index 09c62508a..1289296aa 100644 --- a/examples/93_blackwell_low_latency_gqa/readme.md +++ b/examples/93_blackwell_low_latency_gqa/readme.md @@ -1,6 +1,18 @@ # Blackwell Low Latency 
GQA This example introduces TGV GQA, a CuTe C++-based Blackwell kernel optimized for low latency (low batch) generation phase GQA. +The example ships two variants: + +- `tgv_gqa.cuh` — contiguous KV cache (default, `--mode 0`). +- `tgv_gqa_paged.cuh` — paged KV cache (`--mode 1`). Layout matches a typical paged-attention serving runtime: a combined KV + buffer of shape `(num_pages_total, 2, Page_Size, kvH, dH)` (BS folded into `num_pages_total`; tensor mode 1, the extent-2 dimension, selects K vs V and is unrelated to the `--mode` flag) plus a + `(kvL/Page_Size, BS)` page table that maps `(bs, per_batch_page_idx) -> physical page id`. The example harness builds the + page table host-side using a Fisher-Yates shuffle to stress non-contiguous mappings; replacing that with another placement + policy needs no kernel changes. + +`common.cuh` holds the shared inline PTX wrappers used by both variants (`cp_async`, `tmem_load`/`tmem_store`, +`store_shared_remote_f32`, `get_dsmem_tensor`). + To compile and run this example: ```bash # in cutlass top level directory @@ -8,7 +20,10 @@ mkdir build && cd build cmake .. 
-DCUTLASS_NVCC_ARCHS=100a -DCUTLASS_ENABLE_TESTS=OFF -DCUTLASS_ENABLE_EXAMPLES=ON -DCUTLASS_ENABLE_LIBRARY=OFF cd examples/93_blackwell_low_latency_gqa make +# contiguous KV cache (default) ./93_blackwell_low_latency_gqa --kvL 8192 --kvH 8 --qH 64 --BS 1 +# paged KV cache +./93_blackwell_low_latency_gqa --kvL 8192 --kvH 8 --qH 64 --BS 1 --mode 1 ``` Supported configs are: @@ -20,11 +35,11 @@ Supported configs are: - Flash decoding, configurable number of splits - Cluster reduction with configurable number of reduction cta - Attention sink and sliding window +- Paged KV cache (`--mode 1`) Unsupported features are: - Persistent schedule - MTP -- Paged KV cache ## Kernel Design diff --git a/examples/93_blackwell_low_latency_gqa/tgv_gqa.cu b/examples/93_blackwell_low_latency_gqa/tgv_gqa.cu index 90f847bb8..3d10af335 100644 --- a/examples/93_blackwell_low_latency_gqa/tgv_gqa.cu +++ b/examples/93_blackwell_low_latency_gqa/tgv_gqa.cu @@ -40,8 +40,14 @@ kvL is max_seq_len, seq_lens[BS] is the actual seq len for each batch sinks has shape (qHLocal * kvH), i.e. one sink per q head + --mode 1 enables the paged variant: + Combined KV cache shape (num_pages_total, 2, Page_Size, kvH, dH), with BS folded into num_pages_total + and KV(_,0,_,_,_) = K, KV(_,1,_,_,_) = V. + page_table shape (kvL/Page_Size, BS), entry [p, bs] = physical page id of batch bs's per-batch page p. 
+ Example usage: $ ./examples/93_blackwell_low_latency_gqa --kvL 8192 --kvH 8 --qH 64 --BS 1 + $ ./examples/93_blackwell_low_latency_gqa --kvL 8192 --kvH 8 --qH 64 --BS 1 --mode 1 */ // Standard library includes @@ -51,6 +57,8 @@ #include #include #include +#include +#include #include #include @@ -68,7 +76,9 @@ // CuTe includes #include // CuTe tensor implementation +#include "common.cuh" #include "tgv_gqa.cuh" +#include "tgv_gqa_paged.cuh" using namespace cute; @@ -411,9 +421,23 @@ struct ProblemStride { int stride_O_qL; int stride_O_dH; int stride_O_BS; + + // Combined KV (paged mode) layout: (num_pages_total, 2, Page_Size, kvH, dH). + // BS is folded into num_pages_total = BS * kvL / Page_Size (no explicit BS mode). + // Mode 1 is the K/V selector; harness picks dH-innermost packed strides below. + int stride_KV_pages; + int stride_KV_KV; + int stride_KV_ps; + int stride_KV_kvH; + int stride_KV_dH; + + // Page table (paged mode) layout: (kvL/Page_Size, BS), entry [p, bs] = physical page id of batch bs's per-batch + // page p. Mode 0 is innermost in memory (per-batch page id, stride 1); mode 1 advances by pages_per_batch across batches. + int stride_PT_p; + int stride_PT_BS; }; -ProblemStride make_gqa_stride(int kvH, int qHLocal, int qL, int kvL, int dH, int BS) { +ProblemStride make_gqa_stride(int kvH, int qHLocal, int qL, int kvL, int dH, int BS, int Page_Size) { ProblemStride stride; // Q shape ((qHLocal, qL), dH, kvH, BS), where dH is contiguous @@ -446,6 +470,20 @@ ProblemStride make_gqa_stride(int kvH, int qHLocal, int qL, int kvL, int dH, int stride.stride_O_dH = 1; stride.stride_O_BS = kvH * qHLocal * dH * qL; + // Combined KV (paged) shape (num_pages_total, 2, Page_Size, kvH, dH), dH innermost contiguous. + // BS is folded into num_pages_total; the (bs_idx, per_batch_page_idx) -> physical page mapping lives in the + // page_table tensor (see stride_PT_*). Strides slowest -> fastest: + // num_pages_total, KV (K/V selector), Page_Size, kvH, dH. 
+ stride.stride_KV_dH = 1; + stride.stride_KV_kvH = dH; + stride.stride_KV_ps = kvH * dH; + stride.stride_KV_KV = Page_Size * kvH * dH; + stride.stride_KV_pages = 2 * Page_Size * kvH * dH; + + // Page table shape (kvL/Page_Size, BS); per-batch page idx is mode 0 (contiguous, stride 1), batch is mode 1 (advances by pages_per_batch). + stride.stride_PT_p = 1; + stride.stride_PT_BS = kvL / Page_Size; + return stride; } @@ -459,17 +497,25 @@ public: static constexpr int CTA_qL = 1; static constexpr int CTA_kvL = 128; static constexpr int CTA_dH = 64; + // Page_Size only used by gqa_paged (mode 1). Page_Size must divide CTA_kvL; CTA_kvL/Page_Size = pages per CTA tile. + static constexpr int Page_Size = 32; static constexpr int BMM1_DMA_Stage = 3; static constexpr int BMM2_DMA_Stage = 3; + // Page-idx staging (mode 1 only). Num_Page_Idx_Per_Stage must be a multiple of CTA_kvL/Page_Size + // so a DMA stage's pages live in one pi stage. Page_Idx_Stage = pipeline depth on the page-idx side. + static constexpr int Page_Idx_Stage = 2; + static constexpr int Num_Page_Idx_Per_Stage = 8 * (CTA_kvL / Page_Size); static constexpr int MaxSplits = 8; static constexpr int NumReductionCTA = 8; static constexpr bool NoSink = true; + static constexpr bool VarSeqLens = false; private: int kvH_, qHLocal_, qL_, kvL_, dH_, BS_; float softmax_scale_; ProblemStride stride_; int sliding_window_size_; + int mode_; // 0: gqa, 1: gqa_paged // Host vectors thrust::host_vector host_Q_; @@ -480,17 +526,82 @@ private: thrust::host_vector host_seq_lens_; thrust::host_vector host_sinks_; - // Device vectors + // Device vectors thrust::device_vector device_Q_; thrust::device_vector device_K_; thrust::device_vector device_V_; + // Combined KV cache for paged mode (mode_ == 1). + // Layout: (num_pages_total, 2, Page_Size, kvH, dH) with dH innermost; BS is folded into num_pages_total. 
+ // The (bs, per_batch_page_idx) -> physical page mapping is held in device_page_table_ (built by + // build_random_page_table: a Fisher-Yates shuffle of [0, num_pages_total) sliced per batch). + // KV(_,0,_,_,_) is K, KV(_,1,_,_,_) is V. + thrust::device_vector device_KV_; + // Page table for paged mode. Logical shape (kvL/Page_Size, BS); entry [p, bs] is the physical page id for + // batch bs and per-batch page index p. Padded by Num_Page_Idx_Per_Stage ints at the tail so the device's + // last-pi-stage cp.async never reads past the allocation. + thrust::device_vector device_page_table_; thrust::device_vector device_O_; thrust::device_vector device_seq_lens_; thrust::device_vector device_sinks_; + // Build the host-side page table: tail-padded buffer with shape (kvL/Page_Size, BS) and contents = a random + // per-batch slice of a Fisher-Yates shuffle of [0, num_pages_total). Each batch's slice length is + // ceil_div(seq_len[bs], Page_Size), so total pages assigned = sum(seq_len_pages) <= num_pages_total. Tail padding + // (Num_Page_Idx_Per_Stage ints) covers the device's last-pi-stage cp.async OOB read. + thrust::host_vector build_random_page_table(int pages_per_batch, int num_pages_total) { + thrust::host_vector host_page_table(num_pages_total + Num_Page_Idx_Per_Stage, 0); + auto host_tensor_page_table = make_tensor(host_page_table.data(), + make_layout(make_shape(pages_per_batch, BS_), + make_stride(stride_.stride_PT_p, stride_.stride_PT_BS))); + std::vector perm(num_pages_total); + std::iota(perm.begin(), perm.end(), 0); + for (int i = num_pages_total - 1; i > 0; --i) { + std::swap(perm[i], perm[rand() % (i + 1)]); + } + int perm_offset = 0; + for (int bs = 0; bs < BS_; ++bs) { + // ceil_div: a partial tail page (seq_len % Page_Size != 0) still needs a physical page assigned + // so the kernel's address-by-page lookup for positions [floor(seq_len/Page_Size)*Page_Size, seq_len) + // hits real packed K/V data. 
Positions past seq_len are masked by the kernel and don't contribute. + int seq_len_pages = cutlass::ceil_div(host_seq_lens_[bs], Page_Size); + for (int p = 0; p < seq_len_pages; ++p) { + host_tensor_page_table(p, bs) = perm[perm_offset + p]; + } + perm_offset += seq_len_pages; + } + return host_page_table; + } + + // Pack the flat per-batch K/V buffers into the combined paged KV layout, using the page table for the + // (bs, per_batch_page) -> physical page mapping. Bounded by seq_len_pages per batch so out-of-seq-len entries + // (which the page table doesn't populate) aren't dereferenced. + template + void pack_combined_kv(HostTensorPageTable const& host_tensor_page_table, + HostTensorK const& host_tensor_K, HostTensorV const& host_tensor_V, + HostTensorKV& host_tensor_KV) { + for (int bs = 0; bs < BS_; ++bs) { + // ceil_div pages so the partial tail page (when seq_len isn't page-aligned) is packed. + // We pack the full Page_Size for the tail page; positions past seq_len carry whatever + // initialize_tensor wrote into the K/V buffers, but the kernel masks those out. 
+ int seq_len_pages = cutlass::ceil_div(host_seq_lens_[bs], Page_Size); + for (int p = 0; p < seq_len_pages; ++p) { + int global_page = host_tensor_page_table(p, bs); + for (int ps = 0; ps < Page_Size; ++ps) { + int kvl = p * Page_Size + ps; + for (int kvh = 0; kvh < kvH_; ++kvh) { + for (int dh = 0; dh < dH_; ++dh) { + host_tensor_KV(global_page, 0, ps, kvh, dh) = host_tensor_K(kvl, dh, kvh, bs); + host_tensor_KV(global_page, 1, ps, kvh, dh) = host_tensor_V(dh, kvl, kvh, bs); + } + } + } + } + } + } + public: - GQATester(int kvH, int qH, int qL, int kvL, int dH, int BS, float softmax_scale, int sliding_window_size) : - kvH_(kvH), qHLocal_(qH / kvH), qL_(qL), kvL_(kvL), dH_(dH), BS_(BS), softmax_scale_(softmax_scale), sliding_window_size_(sliding_window_size) { + GQATester(int kvH, int qH, int qL, int kvL, int dH, int BS, float softmax_scale, int sliding_window_size, int mode = 0) : + kvH_(kvH), qHLocal_(qH / kvH), qL_(qL), kvL_(kvL), dH_(dH), BS_(BS), softmax_scale_(softmax_scale), sliding_window_size_(sliding_window_size), mode_(mode) { assert(sliding_window_size_ >= 0); // Allocate host memory host_Q_.resize(kvH_ * qHLocal_ * qL_ * dH_ * BS_); @@ -501,7 +612,7 @@ public: host_seq_lens_.resize(BS_); host_sinks_.resize(qHLocal_ * kvH_); // one sink per q head - stride_ = make_gqa_stride(kvH_, qHLocal_, qL_, kvL_, dH_, BS_); + stride_ = make_gqa_stride(kvH_, qHLocal_, qL_, kvL_, dH_, BS_, Page_Size); // Create host CuTe tensors for initialization auto host_tensor_Q = make_tensor(host_Q_.data(), TGV::gqa::make_layout_Q(kvH_, qHLocal_, qL_, dH_, BS_, stride_.stride_Q_kvH, stride_.stride_Q_qHLocal, stride_.stride_Q_qL, stride_.stride_Q_dH, stride_.stride_Q_BS)); @@ -513,10 +624,8 @@ public: initialize_tensor(host_tensor_Q); initialize_tensor(host_tensor_K); initialize_tensor(host_tensor_V); - // have batch size matching kvL (i.e. 
max seq len) for now - bool test_var_seq_lens = false; for (int i = 0; i < BS_; ++i) { - if (test_var_seq_lens) { + if (VarSeqLens) { host_seq_lens_[i] = rand() % kvL_ + 1; } else { // all the batch have the same seq len @@ -535,27 +644,81 @@ public: device_seq_lens_ = host_seq_lens_; device_sinks_ = host_sinks_; + // For paged mode, build the page_table and combined KV tensor on device. The harness owns the layouts: + // stride_KV_* and stride_PT_* were computed in make_gqa_stride above and are used both here (host pack) + // and downstream (passed into gqa_paged_host as stride args). Combined KV shape: (num_pages_total, 2, + // Page_Size, kvH, dH), BS folded into num_pages_total. We populate the page_table via + // build_random_page_table (Fisher-Yates shuffle of [0, num_pages_total) sliced per batch) to stress + // non-contiguous mappings. + if (mode_ == 1) { + assert(kvL_ % Page_Size == 0); + assert(kvL_ % CTA_kvL == 0); + int pages_per_batch = kvL_ / Page_Size; + int num_pages_total = BS_ * pages_per_batch; + + auto host_page_table = build_random_page_table(pages_per_batch, num_pages_total); + auto host_tensor_page_table = make_tensor(host_page_table.data(), + make_layout(make_shape(pages_per_batch, BS_), + make_stride(stride_.stride_PT_p, stride_.stride_PT_BS))); + + thrust::host_vector host_KV(num_pages_total * stride_.stride_KV_pages); + auto host_tensor_KV = make_tensor(host_KV.data(), + make_layout(make_shape(num_pages_total, 2, Page_Size, kvH_, dH_), + make_stride(stride_.stride_KV_pages, stride_.stride_KV_KV, + stride_.stride_KV_ps, stride_.stride_KV_kvH, stride_.stride_KV_dH))); + pack_combined_kv(host_tensor_page_table, host_tensor_K, host_tensor_V, host_tensor_KV); + + device_KV_ = host_KV; + device_page_table_ = host_page_table; + } + gpuErrChk(cudaDeviceSynchronize()); } void run_kernel(bool pdl, int pdl_count = -1, cudaStream_t stream = 0) { - TGV::gqa::gqa_host< - TypeQKV, TypeO, TypeAcc, - CTA_qHLocal, CTA_qL, CTA_kvL, CTA_dH, - BMM1_DMA_Stage, 
BMM2_DMA_Stage, - MaxSplits, - NumReductionCTA>( - device_K_.data().get(), device_Q_.data().get(), device_V_.data().get(), device_O_.data().get(), - device_seq_lens_.data().get(), - NoSink ? nullptr : device_sinks_.data().get(), - kvH_, qHLocal_, qL_, kvL_, dH_, BS_, - stride_.stride_K_kvH, stride_.stride_K_kvL, stride_.stride_K_dH, stride_.stride_K_BS, - stride_.stride_Q_kvH, stride_.stride_Q_qHLocal, stride_.stride_Q_qL, stride_.stride_Q_dH, stride_.stride_Q_BS, - stride_.stride_V_kvH, stride_.stride_V_kvL, stride_.stride_V_dH, stride_.stride_V_BS, - stride_.stride_O_kvH, stride_.stride_O_qHLocal, stride_.stride_O_qL, stride_.stride_O_dH, stride_.stride_O_BS, - softmax_scale_, - sliding_window_size_, - pdl, pdl_count, stream); + if (mode_ == 0) { + TGV::gqa::gqa_host< + TypeQKV, TypeO, TypeAcc, + CTA_qHLocal, CTA_qL, CTA_kvL, CTA_dH, + BMM1_DMA_Stage, BMM2_DMA_Stage, + MaxSplits, + NumReductionCTA>( + device_K_.data().get(), device_Q_.data().get(), device_V_.data().get(), device_O_.data().get(), + device_seq_lens_.data().get(), + NoSink ? nullptr : device_sinks_.data().get(), + kvH_, qHLocal_, qL_, kvL_, dH_, BS_, + stride_.stride_K_kvH, stride_.stride_K_kvL, stride_.stride_K_dH, stride_.stride_K_BS, + stride_.stride_Q_kvH, stride_.stride_Q_qHLocal, stride_.stride_Q_qL, stride_.stride_Q_dH, stride_.stride_Q_BS, + stride_.stride_V_kvH, stride_.stride_V_kvL, stride_.stride_V_dH, stride_.stride_V_BS, + stride_.stride_O_kvH, stride_.stride_O_qHLocal, stride_.stride_O_qL, stride_.stride_O_dH, stride_.stride_O_BS, + softmax_scale_, + sliding_window_size_, + pdl, pdl_count, stream); + } + else { + TGV::gqa_paged::gqa_paged_host< + TypeQKV, TypeO, TypeAcc, + CTA_qHLocal, CTA_qL, CTA_kvL, CTA_dH, + Page_Size, + BMM1_DMA_Stage, BMM2_DMA_Stage, + Page_Idx_Stage, Num_Page_Idx_Per_Stage, + MaxSplits, + NumReductionCTA>( + device_KV_.data().get(), + device_Q_.data().get(), + device_O_.data().get(), + NoSink ? 
nullptr : device_sinks_.data().get(), + device_seq_lens_.data().get(), + device_page_table_.data().get(), + kvH_, qHLocal_, qL_, kvL_, dH_, BS_, + stride_.stride_KV_pages, stride_.stride_KV_KV, stride_.stride_KV_ps, stride_.stride_KV_kvH, stride_.stride_KV_dH, + stride_.stride_Q_kvH, stride_.stride_Q_qHLocal, stride_.stride_Q_qL, stride_.stride_Q_dH, stride_.stride_Q_BS, + stride_.stride_O_kvH, stride_.stride_O_qHLocal, stride_.stride_O_qL, stride_.stride_O_dH, stride_.stride_O_BS, + stride_.stride_PT_p, stride_.stride_PT_BS, + softmax_scale_, + sliding_window_size_, + pdl, pdl_count, stream); + } } bool verify() { @@ -604,16 +767,17 @@ public: }; -void benchmark_gqa(int kvH, int qH, int qL, int kvL, int dH, int BS, float softmax_scale, int sliding_window_size, bool pdl, int pdl_count, int num_testers = 4, int bench_iters = 100) { +void benchmark_gqa(int kvH, int qH, int qL, int kvL, int dH, int BS, float softmax_scale, int sliding_window_size, int mode, bool pdl, int pdl_count, int num_testers = 4, int bench_iters = 100) { std::cout << "=== GQA Benchmark ===" << std::endl; std::cout << "Problem size: kvH=" << kvH << ", qH=" << qH << ", qL=" << qL << ", kvL=" << kvL << ", dH=" << dH << ", BS=" << BS << ", sliding_window_size=" << sliding_window_size << std::endl; + std::cout << "Mode: " << mode << " (" << (mode == 0 ? 
"gqa" : "gqa_paged") << ")" << std::endl; std::cout << "Number of testers (L2 thrashing): " << num_testers << std::endl; std::cout << "Benchmark iterations: " << bench_iters << std::endl; // Create multiple tester instances to thrash L2 cache std::vector> testers; for (int i = 0; i < num_testers; ++i) { - testers.push_back(std::make_unique(kvH, qH, qL, kvL, dH, BS, softmax_scale, sliding_window_size)); + testers.push_back(std::make_unique(kvH, qH, qL, kvL, dH, BS, softmax_scale, sliding_window_size, mode)); } std::cout << "Created " << num_testers << " GQATester instances" << std::endl; @@ -708,6 +872,8 @@ int main(int argc, char* argv[]) { bool pdl = false; // don't support it yet int pdl_count = -1; + // 0: gqa (contiguous KV cache), 1: gqa_paged (paged KV cache) + int mode = 0; // arg parsing while (1) { @@ -718,6 +884,7 @@ int main(int argc, char* argv[]) { {"qL", required_argument, 0, 0}, {"BS", required_argument, 0, 0}, {"sliding_window_size", required_argument, 0, 0}, + {"mode", required_argument, 0, 0}, {0, 0, 0, 0} // denote end of array }; @@ -737,14 +904,17 @@ int main(int argc, char* argv[]) { else if (option_index == 3) qL = atoi(optarg); else if (option_index == 4) BS = atoi(optarg); else if (option_index == 5) sliding_window_size = atoi(optarg); + else if (option_index == 6) mode = atoi(optarg); break; default: assert(false); } } - GQATester tester(kvH, qH, qL, kvL, dH, BS, softmax_scale, sliding_window_size); + GQATester tester(kvH, qH, qL, kvL, dH, BS, softmax_scale, sliding_window_size, mode); bool success = tester.verify(); - std::cout << "Correctness test " << (success ? "PASSED" : "FAILED") << std::endl; + std::cout << "Correctness test" + << " mode=" << mode + << " " << (success ? 
"PASSED" : "FAILED") << std::endl; - benchmark_gqa(kvH, qH, qL, kvL, dH, BS, softmax_scale, sliding_window_size, pdl, pdl_count, 100, 1000); + benchmark_gqa(kvH, qH, qL, kvL, dH, BS, softmax_scale, sliding_window_size, mode, pdl, pdl_count, 100, 1000); } diff --git a/examples/93_blackwell_low_latency_gqa/tgv_gqa.cuh b/examples/93_blackwell_low_latency_gqa/tgv_gqa.cuh index 303aa876a..bdfd5cba7 100644 --- a/examples/93_blackwell_low_latency_gqa/tgv_gqa.cuh +++ b/examples/93_blackwell_low_latency_gqa/tgv_gqa.cuh @@ -50,13 +50,7 @@ #include // TMEM allocator for SM100 #include -#define gpuErrChk(ans) { gpuAssert2((ans), __FILE__, __LINE__); } -inline void gpuAssert2(cudaError_t code, const char *file, int line, bool abort=true) { - if (code != cudaSuccess) { - fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line); - if (abort) exit(code); - } -} +#include "common.cuh" namespace TGV { namespace gqa { @@ -93,27 +87,6 @@ acc = tl.dot(p.to(q.dtype), v) # [BLOCK_qL, BLOCK_dH] tl.store(acc_ptrs, acc) */ -// Store value to remote shared memory in the cluster -CUTE_DEVICE void -store_shared_remote_f32(float value, uint32_t dsmem_addr, uint32_t remote_barrier_addr) { - asm volatile("st.async.shared::cluster.mbarrier::complete_tx::bytes.f32 [%0], %1, [%2];" - : : "r"(dsmem_addr), "f"(value), "r"(remote_barrier_addr)); -} - -// given a smem tensor, return the dsmem tensor for the given rank, the tensor addr is in smem addr space (not generic addr space) -template -CUTE_DEVICE auto -get_dsmem_tensor(Tensor tensor, int rank) { - using T = typename decltype(tensor)::value_type; - // tensor.data().get() is the smem addr in the generic addr space, in the generic addr space a region is reserved for smem - // doing ld/st to this region of the generic addr space will be converted into ld.shared/st.shared to the smem addr space by the compiler - // the mapa (and many inline ptx) instruction's input and output addr are in the smem/dsmem addr space, so we need to 
explicitly convert from generic to shared addr space - uint32_t smem_addr = __cvta_generic_to_shared(tensor.data().get()); // smem addr space - // mapa to get the dsmem addr of this tensor in another CTA - uint32_t dsmem_addr = set_block_rank(smem_addr, rank); // smem addr space - return make_tensor(make_smem_ptr((T*)dsmem_addr), tensor.layout()); -} - // Helper methods to create layouts // K always has the shape (kvL, dH, kvH, BS) // kvH has to be the last dim because we do mma partitioning to the first two dims (M, K) in gemm terminology @@ -243,6 +216,8 @@ struct SharedStorage { alignas(16) cute::uint64_t bmm1_softmax_full_barrier; // Barrier between BMM1 and softmax, BMM1 tells softmax the tile is ready/full, softmax can start consuming it alignas(16) cute::uint64_t bmm2_epilog_full_barrier; // Barrier between BMM2 and epilog, BMM2 tells epilog the tile is ready/full, epilog can start consuming it + alignas(16) cute::uint64_t tmem_allocation_result_barrier; // Barrier between MMA and epilog, sync tmem allocation/deallocation status between MMA and epilogue warps within CTA + // for cluster reduction alignas(16) cute::uint64_t maxsum_mailbox_full_barrier; // barrier indicating the st.async of fmax and fsum are done alignas(16) cute::uint64_t acc2_mailbox_full_barrier; // barrier indicating the st.async of acc2 are done @@ -512,67 +487,6 @@ cta_reduce_transposed( return acc; } -// copied from SM100::TMEM::LOAD::copy_unpack cutlass/include/cute/atom/copy_traits_sm100.hpp -// what it does is given a tmem address, load the data into rmem tensor with the given tcgen05.ld copy op -template < - class CopyOp, - class TD, class DLayout> -CUTLASS_DEVICE void -tmem_load( - uint32_t tmem_addr, - Tensor& dst -) { - static_assert(is_rmem::value, "Expected RMEM dst."); - - using RegTypeDst = typename remove_extent::type; - Tensor rD = recast(dst); - - constexpr int RegNumDst = extent::value; - CUTE_STATIC_ASSERT_V(size(rD) == Int{}, - "The tcgen05.ld CopyOp's size does not 
match the destination tensor size."); - - detail::explode(CopyOp::copy, - &tmem_addr, seq<0>{}, - rD, make_seq{}); -} - -// copied from SM100::TMEM::STORE::copy_unpack cutlass/include/cute/atom/copy_traits_sm100.hpp -// what it does is given a tmem address, store the data in rmem tensor to the tmem address with the given tcgen05.st copy op -template < - class CopyOp, - class TS, class SLayout> -CUTLASS_DEVICE void -tmem_store( - Tensor& src, - uint32_t tmem_addr -) { - static_assert(is_rmem::value, "Expected RMEM src."); - - using RegTypeSrc = typename remove_extent::type; - Tensor rS = recast(src); - - constexpr int RegNumSrc = extent::value; - CUTE_STATIC_ASSERT_V(size(rS) == Int{}, - "The tcgen05.st CopyOp's size does not match the source tensor size."); - - detail::explode(CopyOp::copy, - rS, make_seq{}, - &tmem_addr, seq<0>{}); -} - -// issue cp.async -CUTLASS_DEVICE void -cp_async( - int* gmem_addr, - int* smem_addr -) { - uint32_t smem_int_ptr = cute::cast_smem_ptr_to_uint(smem_addr); - asm volatile("cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" - :: "r"(smem_int_ptr), - "l"(gmem_addr), - "n"(sizeof(int))); -} - // mapping between thread id (T) -> dH (row index of Acc2) template CUTLASS_DEVICE auto @@ -831,14 +745,13 @@ template < class TiledBMM1, class TiledBMM2, int CTA_qHLocal, int CTA_qL, int CTA_kvL, int CTA_dH> -CUTLASS_DEVICE void +CUTLASS_DEVICE void MMA_warp( SharedStorage& shared_storage, WorkTileInfo work_tile_info, OTensor mO, TiledBMM1 tiled_bmm1, - TiledBMM2 tiled_bmm2, - cutlass::arch::NamedBarrier& tmem_allocation_barrier + TiledBMM2 tiled_bmm2 ) { if (!work_tile_info.is_valid()) { // we don't allocate tmem for invalid tiles but we still need to relinquish the allocation lock @@ -900,7 +813,7 @@ MMA_warp( tmem_allocator.allocate(Acc1_col_max, &shared_storage.bmm1_tmem_base_ptr); tmem_allocator.allocate(Acc2_col_max, &shared_storage.bmm2_tmem_base_ptr); // notify epilog warp that tmem allocation is complete - 
tmem_allocation_barrier.arrive(); + arrive_barrier(shared_storage.tmem_allocation_result_barrier); // relinquish early so that prefetch cta can be launched tmem_allocator.release_allocation_lock(); @@ -963,7 +876,10 @@ MMA_warp( MMA_gemm(tCrV, tCrP, tCtAcc2, tiled_bmm2, bmm2_stage_idx, tma_bmm2_full_barrier_phase_bit, bmm2_accumulate, shared_storage.tmasoftmax_bmm2_full_barrier, shared_storage.tma_bmm2_empty_barrier, shared_storage.bmm2_epilog_full_barrier); // wait for tmem deallocation signal from epilog warp - tmem_allocation_barrier.arrive_and_wait(); + arrive_barrier(shared_storage.tmem_allocation_result_barrier); + // initial phase bit = 1 since it's already flipped once for tmem allocation + // it will flip to 0 when tmem can be deallocated, so we wait for old phase bit of 1 + wait_barrier(shared_storage.tmem_allocation_result_barrier, 1); // deallocate TMEM tmem_allocator.free(shared_storage.bmm1_tmem_base_ptr, Acc1_col_max); @@ -982,7 +898,7 @@ template < int CTA_qHLocal, int CTA_qL, int CTA_kvL, int CTA_dH, int NumReductionCTA, bool NoSink> -CUTLASS_DEVICE void +CUTLASS_DEVICE void EPILOG_warp( SharedStorage& shared_storage, WorkTileInfo work_tile_info, @@ -994,7 +910,6 @@ EPILOG_warp( TiledBMM2 tiled_bmm2, float softmax_scale_log2, int sliding_window_size, - cutlass::arch::NamedBarrier& tmem_allocation_barrier, cutlass::arch::NamedBarrier& epilog_barrier, int NumSplits, int tid, // tid local to epilog warp @@ -1024,7 +939,9 @@ EPILOG_warp( // wait for tmem allocation in mma warp to complete, only do the wait for valid tiles if (work_tile_info.is_valid()) { - tmem_allocation_barrier.arrive_and_wait(); + arrive_barrier(shared_storage.tmem_allocation_result_barrier); + // initial phase bit = 0, it will flip to 1 when tmem is allocated, so we wait for old phase bit of 0 + wait_barrier(shared_storage.tmem_allocation_result_barrier, 0); } // update tmem base ptr of the accumulator tensor @@ -1344,11 +1261,15 @@ EPILOG_warp( static_assert(MaxSplits <= 32, "we 
can use 1 warp to initialize mailbox"); // initialize mailbox tensor for fmax and fsum, when NumSplits < MaxSplits, we need to init those value to -inf and 0 // because we do reduction on the full tensor (of size MaxSplits) not just the valid splits - // there is no need to init sAcc2 because it will be scaled with beta which will be 0 for invalid splits if (tid < MaxSplits) { fill(sFmaxMailbox(tid, _), -cutlass::platform::numeric_limits::infinity()); clear(sFsumMailbox(tid, _)); } + // we also need to clear out acc2 mailbox for invalid splits, because acc2 value could be nan + // nan * 0 (beta) = nan, we still need to clear acc2 + if (tid < CTA_dH) { + clear(sAcc2Mailbox(tid, _, _)); + } // ensure initialized smem is visible to the entire cluster cutlass::arch::fence_view_async_shared(); @@ -1606,7 +1527,7 @@ EPILOG_warp( } // signal the mma warp tcgen05.ld of bmm2 is done, can start deallocate all tmem - tmem_allocation_barrier.arrive(); + arrive_barrier(shared_storage.tmem_allocation_result_barrier); } // only NumReductionCTA number of reduction ctas will do the reduction @@ -1734,15 +1655,16 @@ EPILOG_warp( }*/ } -// K has shape (kvL, dH, kvH, BS) -// Q has shape ((qHLocal, qL), dH, kvH, BS) -// V has shape (dH, kvL, kvH, BS) -// O has shape (dH, (qHLocal, qL), kvH, BS) -// sinks has shape ((qHLocal, qL), kvH) -// seq_len has shape (BS) +// mK has shape (kvL, dH, kvH, BS) +// mQ has shape ((qHLocal, qL), dH, kvH, BS) +// mV has shape (dH, kvL, kvH, BS) +// mO has shape (dH, (qHLocal, qL), kvH, BS) +// mSink has shape ((qHLocal, qL), kvH) +// mSeqLens has shape (BS) template < class SharedStorage, class KTensor, class QTensor, class VTensor, class OTensor, class SinkTensor, + class SeqLensTensor, class TmaAtomK, class TmaAtomQ, class TmaAtomV, class TiledBMM1, class TiledBMM2, class TypeAcc, @@ -1758,7 +1680,7 @@ gqa_device( VTensor mV, OTensor mO, SinkTensor mSink, - int* seq_lens, + SeqLensTensor mSeqLens, CUTE_GRID_CONSTANT TmaAtomK const tma_atom_K, 
CUTE_GRID_CONSTANT TmaAtomQ const tma_atom_Q, CUTE_GRID_CONSTANT TmaAtomV const tma_atom_V, @@ -1782,7 +1704,7 @@ gqa_device( int BS_idx = blockIdx.x / kvH; // only thread 0 issues cp.async to load the seq_len if (threadIdx.x == 0) { - cp_async(&seq_lens[BS_idx], &shared_storage.seq_len); + cp_async(&mSeqLens(BS_idx), &shared_storage.seq_len); } //if (threadIdx.x == 0) { @@ -1805,15 +1727,13 @@ gqa_device( cutlass::arch::detail::initialize_barrier_array_aligned(&shared_storage.bmm1_softmax_full_barrier, /* arrival count */ 1); // 1 thread (BMM2) arrive to signal epilog cutlass::arch::detail::initialize_barrier_array_aligned(&shared_storage.bmm2_epilog_full_barrier, /* arrival count */ 1); + // 32 (mma) + 128 (epilog) to signal tmem allocation/deallocation result + cutlass::arch::detail::initialize_barrier_array_aligned(&shared_storage.tmem_allocation_result_barrier, /* arrival count */ 32 + 128); // 1 thread (epilog) arrive to signal maxsum cutlass::arch::detail::initialize_barrier_array_aligned(&shared_storage.maxsum_mailbox_full_barrier, /* arrival count */ 1); // 1 thread (epilog) arrive to signal acc2 cutlass::arch::detail::initialize_barrier_array_aligned(&shared_storage.acc2_mailbox_full_barrier, /* arrival count */ 1); } - // Sync tmem allocation status between MMA and softmax/epilogue warps within CTA - // 32 threads (mma) + 128 threads (epilog) to sync - // also used for tmem deallocation between epilog warps and mma warps within CTA - cutlass::arch::NamedBarrier tmem_allocation_barrier(32 + 128, cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier); // syncing all threads (128) within 4 epilog warps cutlass::arch::NamedBarrier epilog_barrier(128, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); @@ -1913,13 +1833,13 @@ gqa_device( DMA_KV_warp(shared_storage, work_tile_info, mK, mV, &tma_atom_K, &tma_atom_V, tiled_bmm1, tiled_bmm2); } else if (warp_idx == 2) { - MMA_warp(shared_storage, work_tile_info, mO, tiled_bmm1, tiled_bmm2, 
tmem_allocation_barrier); - } + MMA_warp(shared_storage, work_tile_info, mO, tiled_bmm1, tiled_bmm2); + } else if (warp_idx >= 4) { // epilog tid is from 128 to 255, need to offset by -128 when getting the per thread slice int tid = threadIdx.x - 128; // warp_idx - 4 because epilog warp group starts from warp 4 - EPILOG_warp(shared_storage, work_tile_info, mK, mO, mSink, seq_len, tiled_bmm1, tiled_bmm2, softmax_scale_log2, sliding_window_size, tmem_allocation_barrier, epilog_barrier, NumSplits, tid, warp_idx - 4, rank); + EPILOG_warp(shared_storage, work_tile_info, mK, mO, mSink, seq_len, tiled_bmm1, tiled_bmm2, softmax_scale_log2, sliding_window_size, epilog_barrier, NumSplits, tid, warp_idx - 4, rank); } __syncthreads(); @@ -1980,6 +1900,7 @@ void gqa_host( Tensor mV = make_tensor(make_gmem_ptr(device_ptr_V), layout_V); // (dH, kvL, kvH, BS) Tensor mO = make_tensor(make_gmem_ptr(device_ptr_O), layout_O); // (dH, (qHLocal, qL), kvH, BS) Tensor mSink = make_tensor(make_gmem_ptr(device_ptr_sinks), layout_sinks); // ((qHLocal, qL), kvH) + Tensor mSeqLens = make_tensor(make_gmem_ptr(seq_lens), make_layout(make_shape(BS))); // (BS) //printf("mK: "); print(mK); printf("\n"); //printf("mQ: "); print(mQ); printf("\n"); @@ -2174,6 +2095,7 @@ void gqa_host( if (device_ptr_sinks != nullptr) { auto *kernel_instance = &gqa_device +#include +#include +#include +#include + +// Cutlass includes +#include +#include +#include +#include // mma/smem selector, umma::major +#include +#include + +// CuTe includes +#include // CuTe tensor implementation +#include // TMEM allocator for SM100 +#include +#include + +#include "common.cuh" +#include "tgv_gqa.cuh" // reuse layout helpers, WorkTileInfo, TMA_copy, MMA_gemm, reductions, DMA_Q/MMA/EPILOG warps + +// Grouped Query Attention (paged KV cache) with dual BMM + online softmax. 7 warps: 1 DMA_Q, 1 DMA_KV, 1 MMA, +// 4 EPILOG. Warp 3 is unused. +// Warp 0 (DMA_Q): Loads Q via TMA, single-stage. Reused from gqa namespace (gqa::DMA_Q_warp). 
+// Warp 1 (DMA_KV): Loads K then V via TMA, one cp.async.bulk per page. Issues its own cp.async of per-CTA-tile +// page indices into a smem staging buffer (lane-distributed, thread-local cp_async fence/wait). +// Defined locally in gqa_paged namespace. +// Warp 2 (MMA): Performs BMM1 (K@Q) and BMM2 (V@P). Reused from gqa namespace (gqa::MMA_warp). +// Warps 4-7 (EPILOG): Softmax partial max/sum warp reduction, cluster wide max/sum reduction, final flash-decode +// output with attention sink support. Reused from gqa namespace (gqa::EPILOG_warp). +// WorkTileInfo: Reused from gqa namespace. Attention-specific fields: BS_idx, kvH_idx, kvL_idx_start/end, dH_idx, qHLocal_idx, qL_idx. +// SharedStorage: Defined locally. Inherits from gqa::SharedStorage and adds the paged-only pieces: PageIdx smem +// buffer (double-buffered because K and V are offset by 1 tile -- around pi-stage boundaries they read different +// slots) and paged views (tensor_sK_paged / tensor_sV_paged / tensor_sPageIdx) over the inherited K/V smem buffers. + +namespace TGV { +namespace gqa_paged { + +using namespace cute; + +// Symbols reused identically from TGV::gqa -- see tgv_gqa.cuh for definitions. +// Log2_E and WorkTileInfo are pulled in here; everything else is referenced explicitly via gqa:: at call sites. +using TGV::gqa::Log2_E; +using TGV::gqa::WorkTileInfo; + +// The (bs_idx, per_batch_page_idx) -> physical page id mapping lives in a gmem page_table tensor of shape +// (kvL/Page_Size, BS), seeded by the host harness and fetched at runtime by DMA_KV_warp. To install a +// new placement policy, populate the page_table differently host-side; the kernel does not assume any structure. +// The shared memory buffers for Q, K, V matrices. 
+template < + class TypeQKV, // Tensor Q/K/V data type + class TypeAcc, // Tensor Acc data type + class KSmemLayout, // ((Mma_M, Mma_K), NumMma_M, NumMma_K, BMM1_DMA_Stage) + class KPagedSmemLayout,// (Page_Size, CTA_dH, NumPagePerCTATile, BMM1_DMA_Stage), same memory as KSmemLayout + class QSmemLayout, // ((Mma_N, Mma_K), NumMma_N, NumMma_K, 1) + class VSmemLayout, // ((Mma_M, Mma_K), NumMma_M, NumMma_K, BMM2_DMA_Stage) + class VPagedSmemLayout,// (CTA_dH, Page_Size, NumPagePerCTATile, BMM2_DMA_Stage), same memory as VSmemLayout + class SSmemLayout, // (CTA_kvL, (CTA_qHLocal, CTA_qL), 1) aka C matrix (M, N, 1) for bmm1 + class PSmemLayout, // ((CTA_qHLocal, CTA_qL), CTA_kvL, 1) aka B matrix (N, K, 1) for bmm2 + class WRSmemLayout, // (NumEpilogWarps, (CTA_qHLocal, CTA_qL)), WR stands for warp reduce + class MSMailboxSmemLayout,// (MaxSplits, CTA_qHLocal * CTA_qL / NumReductionCTA), MS stands for max and sum + class Acc1SmemLayout, // (CTA_kvL, (CTA_qHLocal, CTA_qL), 1) + class Acc2MailboxSmemLayout, // (CTA_dH, CTA_qHLocal * CTA_qL / NumReductionCTA, MaxSplits) + class SinksSmemLayout, // (CTA_qHLocal * CTA_qL / NumReductionCTA) + class PageIdxSmemLayout, // ((NumPagePerCTATile, Tiles_Per_Pi_Stage), Page_Idx_Stage), int32 page indices staged by DMA_KV warp + int BMM1_DMA_Stage, + int BMM2_DMA_Stage, + int Page_Idx_Stage> +// Paged kernel adds two pieces on top of the plain gqa SharedStorage: +// - PageIdx smem buffer (DMA_KV staging, int32 page indices). Double-buffered: K and V are 1 tile apart, so +// around a pi-stage boundary they read different slots. DMA_KV writes the next slot via lane-distributed +// cp.async at K_t_in_stage==1 (by then V has already read its last tile of the previous pi-stage, so the next slot is free).
+// - paged views of the inherited K/V smem buffers (same memory, different layout) +struct SharedStorage : TGV::gqa::SharedStorage< + TypeQKV, TypeAcc, + KSmemLayout, QSmemLayout, VSmemLayout, SSmemLayout, PSmemLayout, + WRSmemLayout, MSMailboxSmemLayout, Acc1SmemLayout, + Acc2MailboxSmemLayout, SinksSmemLayout, + BMM1_DMA_Stage, BMM2_DMA_Stage> { + // DMA_KV staging buffer for int32 page indices, layout ((NumPagePerCTATile, Tiles_Per_Pi_Stage), Page_Idx_Stage) + alignas(128) cute::ArrayEngine> PageIdx; + + // alternative paged view of the same K smem buffer used by TMA: (Page_Size, CTA_dH, NumPagePerCTATile, BMM1_DMA_Stage) + CUTE_DEVICE constexpr auto tensor_sK_paged() { return make_tensor(make_smem_ptr(this->K.begin()), KPagedSmemLayout{}); } + // alternative paged view of the same V smem buffer used by TMA: (CTA_dH, Page_Size, NumPagePerCTATile, BMM2_DMA_Stage) + CUTE_DEVICE constexpr auto tensor_sV_paged() { return make_tensor(make_smem_ptr(this->V.begin()), VPagedSmemLayout{}); } + // smem view of the PageIdx staging buffer, laid out by PageIdxSmemLayout: ((NumPagePerCTATile, Tiles_Per_Pi_Stage), Page_Idx_Stage) + CUTE_DEVICE constexpr auto tensor_sPageIdx() { return make_tensor(make_smem_ptr(PageIdx.begin()), PageIdxSmemLayout{}); } +}; + +// paged TMA copy: like gqa::TMA_copy, but issues NumPagePerCTATile back-to-back copies that share one full barrier. +// empty/full barrier semantics still operate at CTA-tile granularity -- one slot covers the whole (CTA_kvL, CTA_dH) tile, +// transaction bytes counts all pages, and a single set_barrier_transaction_bytes arrives once per stage. +// +// Page indices are read from smem -- DMA_KV stages them from the gmem page_table via lane-distributed +// cp.async and a thread-local cp_async_fence/wait. The caller passes a smem pointer to the +// NumPagePerCTATile page indices owned by this tile; this function performs an ld.shared per page and uses the +// resulting global page id as the gmem coordinate for the TMA copy.
+template < + class GTensor, + class STensor, + class TmaAtom, + char Name, + bool Print, + int DMA_Stage, + int NumPagePerCTATile> +CUTLASS_DEVICE void +TMA_copy_paged( + GTensor gTensor, // ((TMA, NumTma_K), Num_Page_Global) + STensor sTensor, // ((TMA, NumTma_K), NumPagePerCTATile, DMA_Stage) + int k_tile, // running tile counter; k_tile % DMA_Stage selects the smem stage slot and its full/empty barriers + int const* page_idx_smem, // pointer to NumPagePerCTATile contiguous int page indices in smem (this tile's slice) + int& tma_mma_empty_barrier_phase_bit, // in/out: phase bit passed to the empty-barrier wait; toggled below when the stage index wraps + int tma_transaction_bytes, // total bytes for one CTA tile (NumPagePerCTATile pages) + TmaAtom const* tma_atom, + cute::uint64_t* tma_mma_full_barrier, + cute::uint64_t* tma_mma_empty_barrier +) { + // wait for the smem slot to be empty before issuing pages for the next CTA tile + wait_barrier(tma_mma_empty_barrier[k_tile % DMA_Stage], tma_mma_empty_barrier_phase_bit); + + if constexpr (Print) { + if (elect_one_sync() && (blockIdx.x == 0) && (blockIdx.y == 0) && (blockIdx.z == 0)) { + printf("[DMA_%c] barrier empty, kblock %d (paged, %d pages)\n", Name, k_tile, NumPagePerCTATile); + } + } + + if (elect_one_sync()) { + // single set_barrier_transaction_bytes accounts for all pages; later cute::copy calls don't add arrivals. + set_barrier_transaction_bytes(tma_mma_full_barrier[k_tile % DMA_Stage], tma_transaction_bytes); + CUTE_UNROLL + for (int p = 0; p < NumPagePerCTATile; p++) { + // per-page lookup: load global page id from smem and feed it as the gmem coordinate for one TMA copy. + int global_page = page_idx_smem[p]; + copy(tma_atom->with(tma_mma_full_barrier[k_tile % DMA_Stage]), + gTensor(_, global_page), + sTensor(_, p, k_tile % DMA_Stage)); + } + } + + // flip the empty-barrier phase bit once the last stage slot has been used, i.e. right before the stage index wraps back to 0 + if ((k_tile % DMA_Stage) == (DMA_Stage - 1)) { + tma_mma_empty_barrier_phase_bit ^= 1; + } +} + +// Lane-distributed cp.async of one pi-stage's worth (Num_Page_Idx_Per_Stage ints) of page indices from gmem +// page table to a smem slot. Each lane loads at most ceil(Num_Page_Idx_Per_Stage/32) ints; the LSU coalesces them.
+// Caller pre-slices the BS mode of the page table (it is CTA-constant) and is responsible for +// cp_async_fence/wait + __syncwarp ordering after this returns. +template < + int Tiles_Per_Pi_Stage, int Num_Page_Idx_Per_Stage, int Page_Idx_Stage, + class GPageTableTensor, class SPageIdxTensor> +CUTLASS_DEVICE void +issue_pi_stage_cp_async( + GPageTableTensor gPageTable, // ((NumPagePerCTATile, kvL/CTA_kvL),), int gmem -- BS already sliced + SPageIdxTensor sPageIdx, // ((NumPagePerCTATile, Tiles_Per_Pi_Stage), Page_Idx_Stage), int smem + int kvL_idx_start, int pi_stage_idx) { // kvL_idx_start: CTA-tile index the pi-stages are counted from; pi_stage_idx: which group of Tiles_Per_Pi_Stage tiles to fetch + int tile_start = kvL_idx_start + pi_stage_idx * Tiles_Per_Pi_Stage; // index of the first CTA tile covered by this pi-stage + int slot = pi_stage_idx % Page_Idx_Stage; // destination smem slot; consecutive pi-stages rotate through the Page_Idx_Stage slots + // Shift gPageTable's data pointer to the first page of this pi-stage; flat-index `i` walks consecutive + // (page-within-CTA-tile, CTA-tile) pairs in column-major order across Num_Page_Idx_Per_Stage ints. + auto gp = domain_offset(make_coord(make_coord(0, tile_start)), gPageTable); + // Use threadIdx.x-derived lane (cheap: threadIdx.x is already live in a register from earlier in the kernel). + int lane = cutlass::canonical_lane_idx(); + // lane-strided copy: each lane handles ints lane, lane+32, lane+64, ... below Num_Page_Idx_Per_Stage + CUTE_UNROLL + for (int i = lane; i < Num_Page_Idx_Per_Stage; i += 32) { + cp_async(&gp(i), &sPageIdx(i, slot)); + } +} + +// Paged DMA_KV warp: TMA-loads K then V, one cp.async.bulk per page. Page indices are staged into smem by this +// same warp via lane-distributed cp.async (no separate Read_Page_Idx warp, no transaction-barrier handshake). +// See the in-body comment ahead of the prolog for the page-idx fetch pipeline. Tiles_Per_Pi_Stage >= 2 is +// required so the K_t_in_stage==1 issue hook exists.
+template < + class SharedStorage, + class WorkTileInfo, + class KTensor, + class VTensor, + class PageTableTensor, + class TmaAtomK, + class TmaAtomV, + int CTA_kvL, int CTA_dH, int Page_Size, + int Page_Idx_Stage, int Num_Page_Idx_Per_Stage> +CUTLASS_DEVICE void +DMA_KV_warp( + SharedStorage& shared_storage, + WorkTileInfo work_tile_info, + KTensor mK, + VTensor mV, + PageTableTensor mPageTable, // (kvL/Page_Size, BS), int gmem; underlying allocation tail-padded by Num_Page_Idx_Per_Stage ints. Mode-0 is partitioned by NumPagePerCTATile to access per-CTA-tile slices. + // when passing a tma descriptor as a function argument, it has to be passed by pointer/reference; if passed by value, it will live in local memory (i.e. the stack) + // and the tma unit cannot access the local memory, (even if it can, the local memory is strided by thread id, the content for each thread is strided) + TmaAtomK const* tma_atom_K, + TmaAtomV const* tma_atom_V) { + + if (!work_tile_info.is_valid()) { + return; + } + + // CTA_kvL % Page_Size == 0, Num_Page_Idx_Per_Stage % NumPagePerCTATile == 0, Page_Idx_Stage == 2, + // and Tiles_Per_Pi_Stage >= 2 are all checked in gqa_paged_host. + constexpr int NumPagePerCTATile = CTA_kvL / Page_Size; + constexpr int Tiles_Per_Pi_Stage = Num_Page_Idx_Per_Stage / NumPagePerCTATile; + + // setup code for K tensor + // paged smem view of K, NOT mma partitioned, used purely for TMA: (Page_Size, CTA_dH, NumPagePerCTATile, BMM1_DMA_Stage) + // the same smem buffer is also accessible via tensor_sK() in MMA-partitioned form for the MMA warp. + Tensor sK_paged = shared_storage.tensor_sK_paged(); // (Page_Size, CTA_dH, NumPagePerCTATile, BMM1_DMA_Stage) + // mK has shape (Page_Size, dH, Num_Page_Global, kvH); BS is folded into Num_Page_Global so there is no BS mode. + // local_tile with a static (Page_Size, CTA_dH) tile_shape materializes the leading two modes statically (TMA
+ // partition requires static mode-0 sizes); dH==CTA_dH and Page_Size already matches so the divisions collapse. + // The page selection is done at TMA-issue time via the smem page_idx slice (staged by DMA_KV's own cp.async), + // not by indexing on a BS mode here. + Tensor gK = local_tile(mK, make_shape(Int{}, Int{}), + make_coord(0, 0, _, work_tile_info.kvH_idx)); // (Page_Size, CTA_dH, Num_Page_Global) + + // group modes [0,2) on both sK_paged and gK so the TMA box (Page_Size, CTA_dH) is mode 0; outer modes are + // (NumPagePerCTATile, BMM1_DMA_Stage) for smem and Num_Page_Global for gmem. + auto [tAgK, tAsK] = tma_partition(*tma_atom_K, + Int<0>{}, // cta_coord: 1x1 cluster + Layout<_1>{}, // cta_layout: CTA coord -> logical multicast id, no multicast, just identity layout + group_modes<0, 2>(sK_paged), group_modes<0, 2>(gK)); + // tAsK: ((TMA, NumTma_K), NumPagePerCTATile, BMM1_DMA_Stage) -- 3 modes + // tAgK: ((TMA, NumTma_K), Num_Page_Global) -- 2 modes + // the shape of the TMA box is (Page_Size, CTA_dH) + + /*if (elect_one_sync() && (blockIdx.x == 0) && (blockIdx.y == 0) && (blockIdx.z == 0)) { + printf("[PAGED-K] sK_paged:\t"); print(sK_paged); printf("\n"); // (Page_Size, CTA_dH, NumPagePerCTATile, BMM1_DMA_Stage) + printf("[PAGED-K] gK:\t"); print(gK); printf("\n"); // (Page_Size, CTA_dH, Num_Page_Global) + printf("[PAGED-K] tAgK:\t"); print(tAgK); printf("\n"); // ((TMA, NumTma_K), Num_Page_Global) + printf("[PAGED-K] tAsK:\t"); print(tAsK); printf("\n"); // ((TMA, NumTma_K), NumPagePerCTATile, BMM1_DMA_Stage) + }*/ + + // setup code for V tensor + // paged smem view of V, NOT mma partitioned, used purely for TMA: (CTA_dH, Page_Size, NumPagePerCTATile, BMM2_DMA_Stage) + // the same smem buffer is also accessible via tensor_sV() in MMA-partitioned form for the MMA warp.
+ Tensor sV_paged = shared_storage.tensor_sV_paged(); // (CTA_dH, Page_Size, NumPagePerCTATile, BMM2_DMA_Stage) + // mV has shape (dH, Page_Size, Num_Page_Global, kvH); BS is folded into Num_Page_Global so there is no BS mode. + Tensor gV = local_tile(mV, make_shape(Int{}, Int{}), + make_coord(0, 0, _, work_tile_info.kvH_idx)); // (CTA_dH, Page_Size, Num_Page_Global) + + // group modes [0,2) on both sV_paged and gV so the TMA box (CTA_dH, Page_Size) is mode 0; outer modes are + // (NumPagePerCTATile, BMM2_DMA_Stage) for smem and Num_Page_Global for gmem. + auto [tAgV, tAsV] = tma_partition(*tma_atom_V, + Int<0>{}, // cta_coord: 1x1 cluster + Layout<_1>{}, // cta_layout: CTA coord -> logical multicast id, no multicast, just identity layout + group_modes<0, 2>(sV_paged), group_modes<0, 2>(gV)); + // tAsV: ((TMA, NumTma_K), NumPagePerCTATile, BMM2_DMA_Stage) -- 3 modes + // tAgV: ((TMA, NumTma_K), Num_Page_Global) -- 2 modes + // the shape of the TMA box is (CTA_dH, Page_Size) + + /*if (elect_one_sync() && (blockIdx.x == 0) && (blockIdx.y == 0) && (blockIdx.z == 0)) { + printf("[PAGED-V] sV_paged:\t"); print(sV_paged); printf("\n"); // (CTA_dH, Page_Size, NumPagePerCTATile, BMM2_DMA_Stage) + printf("[PAGED-V] gV:\t"); print(gV); printf("\n"); // (CTA_dH, Page_Size, Num_Page_Global) + printf("[PAGED-V] tAgV:\t"); print(tAgV); printf("\n"); // ((TMA, NumTma_K), Num_Page_Global) + printf("[PAGED-V] tAsV:\t"); print(tAsV); printf("\n"); // ((TMA, NumTma_K), NumPagePerCTATile, BMM2_DMA_Stage) + }*/ + + // total K/V bytes per stage slot = all NumPagePerCTATile pages combined (they share one full barrier). + // Slice tAsK/tAsV at one stage to get all pages in that stage; size_in_bytes is implicit via sizeof(tensor_like).
+ int tma_K_transaction_bytes = sizeof(make_tensor_like(tAsK(_, _, 0))); + int tma_V_transaction_bytes = sizeof(make_tensor_like(tAsV(_, _, 0))); + int k_tile_count = work_tile_info.kvL_idx_end - work_tile_info.kvL_idx_start; + // BMM1_DMA_Stage = mode 3 of rank-4 sK_paged (Page_Size, CTA_dH, NumPage, Stage) + // BMM2_DMA_Stage = mode 3 of rank-4 sV_paged (CTA_dH, Page_Size, NumPage, Stage) + int constexpr BMM1_DMA_Stage = size<3>(sK_paged); + int constexpr BMM2_DMA_Stage = size<3>(sV_paged); + + /*if (elect_one_sync() && (blockIdx.x == 0) && (blockIdx.y == 0) && (blockIdx.z == 0)) { + printf("[PAGED-K] tma_K_transaction_bytes (per CTA tile)=%d\n", tma_K_transaction_bytes); + printf("[PAGED-V] tma_V_transaction_bytes=%d, k_tile_count=%d, BMM1_DMA_Stage=%d, BMM2_DMA_Stage=%d\n", tma_V_transaction_bytes, k_tile_count, BMM1_DMA_Stage, BMM2_DMA_Stage); + }*/ + + // details of how the phase bit works is in DMA_Q_warp (in tgv_gqa.cuh) + int tma_bmm1_empty_barrier_phase_bit = 1; + int tma_bmm2_empty_barrier_phase_bit = 1; + + int bmm1_k_tile = 0; + int bmm2_k_tile = 0; + + bool constexpr Print = false; + + // sPageIdx layout: ((NumPagePerCTATile, Tiles_Per_Pi_Stage), Page_Idx_Stage). For tile t (CTA-tile idx): + // t_in_stage = t % Tiles_Per_Pi_Stage; pi_slot = (t / Tiles_Per_Pi_Stage) % Page_Idx_Stage; + // page_idx_smem_ptr = &sPageIdx(make_coord(0, t_in_stage), pi_slot). + Tensor sPageIdx = shared_storage.tensor_sPageIdx(); + + // gPageTable view: per-CTA slice of the page table, then split mode 0 by NumPagePerCTATile so the + // (page-within-CTA-tile, CTA-tile-idx) split is explicit. BS is sliced upfront -- it's CTA-constant. + // Tail-padded host allocation lets us read up to Tiles_Per_Pi_Stage tiles past kvL_idx_end without OOB. + Tensor gPageTable = logical_divide(mPageTable(_, work_tile_info.BS_idx), Shape>{}); // ((NumPagePerCTATile, kvL/CTA_kvL),) + + // SOL order: K0, (K1, V0), (K2, V1), ... + // + // Page-idx fetch pipeline.
The page-idx smem buffer is double-buffered (Page_Idx_Stage=2 slots) because + // K leads V by 1 tile -- around a pi-stage boundary K reads slot S%2 (stage S) while V is finishing the + // last tile of slot (S-1)%2 (stage S-1), so the two slots must coexist. Synchronization is thread-local + // cp_async_fence/cp_async_wait<0> + __syncwarp -- no cross-warp barrier. + // * Prolog: issue stage 0 -> slot 0, fence, wait, syncwarp, then K0 TMA. + // * Main loop K_t_in_stage == 0: drain the previously-issued cp.async (stage K_pi_stage's data). At this + // point at most one cp.async group is outstanding so wait<0> is correct. + // * Main loop K_t_in_stage == 1: pre-issue stage K_pi_stage+1 -> slot (K_pi_stage+1)%2 if it has tiles + // in this CTA's range. V trails K by exactly one tile, so V read its last tile of pi-stage K_pi_stage-1 at the + // previous iteration (it only enters pi-stage K_pi_stage later in this iteration), so slot + // (K_pi_stage+1)%2 = (K_pi_stage-1)%2 is no longer being ld.shared'd and is safe to overwrite. The + // consumer's wait at the next pi-stage boundary (Tiles_Per_Pi_Stage-1 tiles later) hides the RTT. + // NOTE(review): the overwrite is issued by all lanes with no __syncwarp after the elected lane's ld.shared of the + // previous V tile; presumably this relies on the warp reconverging after elect_one_sync -- confirm. + + // Prolog: stage 0 must be in slot 0 before K0 reads it. + { + issue_pi_stage_cp_async( + gPageTable, sPageIdx, work_tile_info.kvL_idx_start, /*pi_stage_idx=*/0); + cp_async_fence(); + cp_async_wait<0>(); + __syncwarp(); + + // K0 prolog (no V yet). + TMA_copy_paged(tAgK, tAsK, bmm1_k_tile, &sPageIdx(make_coord(0, 0), 0), tma_bmm1_empty_barrier_phase_bit, tma_K_transaction_bytes, tma_atom_K, shared_storage.tma_bmm1_full_barrier, shared_storage.tma_bmm1_empty_barrier); + bmm1_k_tile++; + } + + for (; bmm1_k_tile < k_tile_count; bmm1_k_tile++, bmm2_k_tile++) { + int K_t_in_stage = bmm1_k_tile % Tiles_Per_Pi_Stage; + int K_pi_stage = bmm1_k_tile / Tiles_Per_Pi_Stage; + int K_pi_slot = K_pi_stage % Page_Idx_Stage; + + // Drain stage K_pi_stage's cp.async before K reads its slot. + if (K_t_in_stage == 0) { + cp_async_wait<0>(); + __syncwarp(); + } + + // Pre-issue stage K_pi_stage+1 if there are still tiles in this CTA's range to consume from it.
+ if (K_t_in_stage == 1) { + int next_pi_stage = K_pi_stage + 1; + int next_tile_start = work_tile_info.kvL_idx_start + next_pi_stage * Tiles_Per_Pi_Stage; + if (next_tile_start < work_tile_info.kvL_idx_end) { + issue_pi_stage_cp_async( + gPageTable, sPageIdx, work_tile_info.kvL_idx_start, next_pi_stage); + cp_async_fence(); + } + } + + TMA_copy_paged(tAgK, tAsK, bmm1_k_tile, &sPageIdx(make_coord(0, K_t_in_stage), K_pi_slot), tma_bmm1_empty_barrier_phase_bit, tma_K_transaction_bytes, tma_atom_K, shared_storage.tma_bmm1_full_barrier, shared_storage.tma_bmm1_empty_barrier); + + int V_t_in_stage = bmm2_k_tile % Tiles_Per_Pi_Stage; + int V_pi_slot = (bmm2_k_tile / Tiles_Per_Pi_Stage) % Page_Idx_Stage; + TMA_copy_paged(tAgV, tAsV, bmm2_k_tile, &sPageIdx(make_coord(0, V_t_in_stage), V_pi_slot), tma_bmm2_empty_barrier_phase_bit, tma_V_transaction_bytes, tma_atom_V, shared_storage.tmasoftmax_bmm2_full_barrier, shared_storage.tma_bmm2_empty_barrier); + } + + // V epilog (last V). Its slot is already populated by an earlier cp.async. + { + int V_t_in_stage = bmm2_k_tile % Tiles_Per_Pi_Stage; + int V_pi_slot = (bmm2_k_tile / Tiles_Per_Pi_Stage) % Page_Idx_Stage; + TMA_copy_paged(tAgV, tAsV, bmm2_k_tile, &sPageIdx(make_coord(0, V_t_in_stage), V_pi_slot), tma_bmm2_empty_barrier_phase_bit, tma_V_transaction_bytes, tma_atom_V, shared_storage.tmasoftmax_bmm2_full_barrier, shared_storage.tma_bmm2_empty_barrier); + } + + cutlass::arch::launch_dependent_grids(); +} + +// mK has shape (Page_Size, dH, num_pages, kvH) +// mQ has shape ((qHLocal, qL), dH, kvH, BS) +// mV has shape (dH, Page_Size, num_pages, kvH) +// mO has shape (dH, (qHLocal, qL), kvH, BS) +// mSink has shape ((qHLocal, qL), kvH) +// mSeqLens has shape (BS) +// mPageTable has shape (kvL/Page_Size, BS), entry [p, bs] = physical page id of batch bs's per-batch page p +// (BS is folded into num_pages on the K/V side; mPageTable supplies the (bs, kvL) -> page mapping.)
+template < + class SharedStorage, + class KTensor, class QTensor, class VTensor, class OTensor, class SinkTensor, + class SeqLensTensor, class PageTableTensor, + class TmaAtomK, class TmaAtomQ, class TmaAtomV, + class TiledBMM1, class TiledBMM2, + class TypeAcc, + int CTA_qHLocal, int CTA_qL, int CTA_kvL, int CTA_dH, + int Page_Size, + int BMM1_DMA_Stage, int BMM2_DMA_Stage, + int Page_Idx_Stage, int Num_Page_Idx_Per_Stage, + int MaxSplits, int NumReductionCTA, + bool NoSink> +__maxnreg__(128) +__global__ void +gqa_paged_device( + KTensor mK, + QTensor mQ, + VTensor mV, + OTensor mO, + SinkTensor mSink, + SeqLensTensor mSeqLens, + PageTableTensor mPageTable, // (kvL/Page_Size, BS); underlying allocation tail-padded by Num_Page_Idx_Per_Stage ints + CUTE_GRID_CONSTANT TmaAtomK const tma_atom_K, + CUTE_GRID_CONSTANT TmaAtomQ const tma_atom_Q, + CUTE_GRID_CONSTANT TmaAtomV const tma_atom_V, + TiledBMM1 tiled_bmm1, + TiledBMM2 tiled_bmm2, + float softmax_scale_log2, + int sliding_window_size, + int pdl_count +) { + // Allocate SMEM + extern __shared__ char shared_memory[]; + SharedStorage& shared_storage = *reinterpret_cast(shared_memory); + + // WorkTileInfo, for non persistent static scheduler, cta id is the work tile info + // since loading the seq_lens is on the critical path of the prolog, we want to start it as soon as possible + // mK shape: (Page_Size, CTA_dH, Num_Page_Global, kvH) -- BS is folded into Num_Page_Global; kvH is mode 3. + int kvH = shape<3>(mK); + int BS_idx = blockIdx.x / kvH; + // only thread 0 issues cp.async to load the seq_len -- keep this as the first thing on the critical path. 
+ if (threadIdx.x == 0) { + cp_async(&mSeqLens(BS_idx), &shared_storage.seq_len); + } + + int warp_idx = cutlass::canonical_warp_idx_sync(); + + // barrier initialization, warp 0 does initialization + if (warp_idx == 0) { + // transaction barrier because tma arrive on it, 6 thread arrive: one for DMA_Q warp, one for DMA_KV (K fetch) warp, and 4 for softmax warp (tcgen05.ld Acc1 is done). + // For paged K, DMA_KV still arrives ONCE per stage (via set_barrier_transaction_bytes once with total bytes for all pages). + cutlass::arch::detail::initialize_barrier_array_aligned(shared_storage.tma_bmm1_full_barrier, /* arrival count */ 6); + // 1 thread (BMM1) arrive to signal DMA_Q and DMA_KV (K fetch) warp + cutlass::arch::detail::initialize_barrier_array_aligned(shared_storage.tma_bmm1_empty_barrier, /* arrival count */ 1); + // transaction barrier because tma arrive on it, 5 thread arrive: one for DMA_KV (V fetch) warp and 4 for softmax warp (S/P store) + cutlass::arch::detail::initialize_barrier_array_aligned(shared_storage.tmasoftmax_bmm2_full_barrier, /* arrival count */ 5); + // 1 thread (BMM2) arrive to signal DMA_KV (V fetch) and softmax warp (P store) + cutlass::arch::detail::initialize_barrier_array_aligned(shared_storage.tma_bmm2_empty_barrier, /* arrival count */ 1); + // 1 thread (BMM1) arrive to signal softmax + cutlass::arch::detail::initialize_barrier_array_aligned(&shared_storage.bmm1_softmax_full_barrier, /* arrival count */ 1); + // 1 thread (BMM2) arrive to signal epilog + cutlass::arch::detail::initialize_barrier_array_aligned(&shared_storage.bmm2_epilog_full_barrier, /* arrival count */ 1); + // 32 (mma) + 128 (epilog) to signal tmem allocation/deallocation result + cutlass::arch::detail::initialize_barrier_array_aligned(&shared_storage.tmem_allocation_result_barrier, /* arrival count */ 32 + 128); + // 1 thread (epilog) arrive to signal maxsum + cutlass::arch::detail::initialize_barrier_array_aligned(&shared_storage.maxsum_mailbox_full_barrier, /* 
arrival count */ 1); + // 1 thread (epilog) arrive to signal acc2 + cutlass::arch::detail::initialize_barrier_array_aligned(&shared_storage.acc2_mailbox_full_barrier, /* arrival count */ 1); + // No page-idx full/empty barriers: DMA_KV fetches its own page indices via cp.async with thread-local + // cp_async_fence/wait, so the cross-warp transaction-barrier handshake is gone. + } + // syncing all threads (128) within 4 epilog warps + cutlass::arch::NamedBarrier epilog_barrier(128, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + + // barrier initialization needs to be visible to all warps + // defer it as late as possible to allow some thread divergence in prolog + cutlass::arch::fence_barrier_init(); +#if 0 + // this will have a membar.gpu to ensure dsmem write visibility within the entire cluster, because there isn't a membar.cluster + // membar.gpu is 0.2us + cluster_sync(); +#else + // the alternative is to use proper fences + // at the ptx level, fence.mbarrier_init.release.cluster acts as a release fence (in cluster scope) for mbarrier init op + cutlass::arch::fence_barrier_init(); + // thread 0 waits for its previously issued cp.async to complete + // here we overlap the cp.async with the barrier initialization as much as possible + if (threadIdx.x == 0) { + cp_async_fence(); + cp_async_wait<0>(); + } + + // the cluster sync serves two purposes: + // 1. it waits for all threads in the cluster to see the barrier initialization + // 2. 
it waits for all threads in the cta to see the cp.async result in smem + cluster_arrive_relaxed(); + cluster_wait(); +#endif + + // bid.y includes rasterization of qHLocal and qL + int qHLocal = shape<0>(shape<0>(mQ)); + int num_qHLocal = cutlass::ceil_div(qHLocal, CTA_qHLocal); + // bid.z is the split id for kvL split, try to evenly distribute the kvL blocks to every CTA + int rank = blockIdx.z; + int seq_len = shared_storage.seq_len; + // Sliding window optimization: when enabled, we only process tokens in range + // [seq_len - sliding_window_size, seq_len). To simplify tile distribution, + // we align the skip boundary to CTA_kvL tiles. + int workload_seq_len = seq_len; + int seq_len_skip_offset = 0; + if (sliding_window_size > 0 && sliding_window_size < seq_len) { + int unaligned_skip = seq_len - sliding_window_size; + seq_len_skip_offset = (unaligned_skip / CTA_kvL) * CTA_kvL; + workload_seq_len = seq_len - seq_len_skip_offset; + } + int NumKVTiles = cutlass::ceil_div(workload_seq_len, CTA_kvL); + int NumSplits = cute::min(MaxSplits, NumKVTiles); + int kvL_tile_count_per_cta = NumKVTiles / MaxSplits; + int kvL_tile_count_remainder = NumKVTiles % MaxSplits; + int kvL_tile_count = kvL_tile_count_per_cta + (rank < kvL_tile_count_remainder ? 1 : 0); + int kvL_tile_count_skip_offset = seq_len_skip_offset / CTA_kvL; + int kvL_tile_count_start = rank * kvL_tile_count_per_cta + (rank < kvL_tile_count_remainder ? 
rank : kvL_tile_count_remainder) + kvL_tile_count_skip_offset; + int kvL_tile_count_end = kvL_tile_count_start + kvL_tile_count; + WorkTileInfo work_tile_info { + .BS_idx = (int32_t)BS_idx, + .kvH_idx = (int32_t)(blockIdx.x % kvH), + .kvL_idx_start = (int32_t)kvL_tile_count_start, + .kvL_idx_end = (int32_t)kvL_tile_count_end, + .dH_idx = 0, // no dH tiling for bmm2 + .qHLocal_idx = (int32_t)(blockIdx.y % num_qHLocal), + .qL_idx = (int32_t)(blockIdx.y / num_qHLocal), + .is_valid_tile = (kvL_tile_count > 0) + }; + + if (warp_idx == 0) { + gqa::DMA_Q_warp(shared_storage, work_tile_info, mQ, &tma_atom_Q, tiled_bmm1); + } + else if (warp_idx == 1) { + DMA_KV_warp(shared_storage, work_tile_info, mK, mV, mPageTable, &tma_atom_K, &tma_atom_V); + } + else if (warp_idx == 2) { + gqa::MMA_warp(shared_storage, work_tile_info, mO, tiled_bmm1, tiled_bmm2); + } + // warp_idx == 3 is unused (the previous Read_Page_Idx warp was folded into DMA_KV). + else if (warp_idx >= 4) { + // epilog tid is from 128 to 255, need to offset by -128 when getting the per thread slice + int tid = threadIdx.x - 128; + // EPILOG_warp only uses mK for shape(mK) (to build a kvL coord predicate). mK on the paged kernel has shape + // (Page_Size, CTA_dH, Num_Page_Global, kvH) with BS folded into Num_Page_Global, so we build a coord-only + // identity tensor with the original (kvL, dH, kvH, BS) shape. BS comes straight from mO (rank 4, mode 3). + // kvL is recovered from mPageTable shape (kvL/Page_Size, BS); Page_Size is a compile-time constant so this + // is a constant multiply (no integer divide), unlike deriving from shape<2>(mK)/BS which costs ~3%. 
+ int dH = shape<1>(mK); + int BS = shape<3>(mO); + int kvL = static_cast(shape<0>(mPageTable)) * Page_Size; + auto mK_coord = make_identity_tensor(make_shape(kvL, dH, kvH, BS)); + // warp_idx - 4 because epilog warp group starts from warp 4 + gqa::EPILOG_warp(shared_storage, work_tile_info, mK_coord, mO, mSink, seq_len, tiled_bmm1, tiled_bmm2, softmax_scale_log2, sliding_window_size, epilog_barrier, NumSplits, tid, warp_idx - 4, rank); + } + + __syncthreads(); +} + +// KV has shape (num_pages_total, 2, Page_Size, kvH, dH); KV(_,0,...) is K, KV(_,1,...) is V; BS is folded into num_pages_total +// Q has shape ((qHLocal, qL), dH, kvH, BS) +// O has shape (dH, (qHLocal, qL), kvH, BS) +// sinks has shape (qHLocal * kvH), i.e. one sink per q head, when device_ptr_sinks is nullptr, it's disabled +// seq_lens has shape (BS); kvL is max_seq_len, seq_lens[bs] is the actual seq len for batch bs +// page_table has shape (kvL/Page_Size, BS) +// sliding_window_size is the size of the sliding window, when it's 0, it's disabled +template< + class TypeQKV, class TypeO, class TypeAcc, + int CTA_qHLocal, int CTA_qL, int CTA_kvL, int CTA_dH, + int Page_Size, + int BMM1_DMA_Stage, int BMM2_DMA_Stage, + int Page_Idx_Stage, int Num_Page_Idx_Per_Stage, + int MaxSplits, int NumReductionCTA> +void gqa_paged_host( + TypeQKV* device_ptr_KV, + TypeQKV* device_ptr_Q, + TypeO* device_ptr_O, + TypeAcc* device_ptr_sinks, + int* seq_lens, + int* device_ptr_page_table, // (kvL/Page_Size, BS); underlying allocation must be tail-padded by Num_Page_Idx_Per_Stage ints + int kvH, int qHLocal, int qL, int kvL, int dH, int BS, + int stride_KV_pages, int stride_KV_KV, int stride_KV_ps, int stride_KV_kvH, int stride_KV_dH, + int stride_Q_kvH, int stride_Q_qHLocal, int stride_Q_qL, int stride_Q_dH, int stride_Q_BS, + int stride_O_kvH, int stride_O_qHLocal, int stride_O_qL, int stride_O_dH, int stride_O_BS, + int stride_PT_p, int stride_PT_BS, + float softmax_scale, + int sliding_window_size, + bool pdl, int 
pdl_count = -1, + cudaStream_t stream = 0 +) { + assert(kvL % Page_Size == 0); + int num_pages_total = BS * (kvL / Page_Size); + + // Reconstruct the combined KV gmem tensor exactly as the harness laid it out: shape + // (num_pages_total, 2, Page_Size, kvH, dH) with the 5 strides supplied by the caller. + auto layout_KV = make_layout( + make_shape(num_pages_total, Int<2>{}, Int{}, kvH, dH), + make_stride(stride_KV_pages, stride_KV_KV, stride_KV_ps, stride_KV_kvH, stride_KV_dH)); + Tensor mKV = make_tensor(make_gmem_ptr(device_ptr_KV), layout_KV); // (num_pages_total, 2, Page_Size, kvH, dH) + + // Slice on the K/V mode (=1). Each slice has shape (num_pages_total, Page_Size, kvH, dH); the kernel expects the + // modes in MMA order, so we permute via select<...>: K wants (Page_Size, dH, num_pages_total, kvH) -> indices + // (1,3,0,2); V wants (dH, Page_Size, num_pages_total, kvH) -> indices (3,1,0,2). + Tensor mKV_K = mKV(_, 0, _, _, _); + Tensor mKV_V = mKV(_, 1, _, _, _); + Tensor mK = make_tensor(mKV_K.data(), select<1, 3, 0, 2>(mKV_K.layout())); // (Page_Size, dH, num_pages_total, kvH) + Tensor mV = make_tensor(mKV_V.data(), select<3, 1, 0, 2>(mKV_V.layout())); // (dH, Page_Size, num_pages_total, kvH) + + Layout layout_Q = gqa::make_layout_Q(kvH, qHLocal, qL, dH, BS, stride_Q_kvH, stride_Q_qHLocal, stride_Q_qL, stride_Q_dH, stride_Q_BS); + Layout layout_O = gqa::make_layout_O(kvH, qHLocal, qL, dH, BS, stride_O_kvH, stride_O_qHLocal, stride_O_qL, stride_O_dH, stride_O_BS); + Layout layout_sinks = gqa::make_layout_sinks(qHLocal, qL, kvH); + + // Page table tensor as the harness owns it: rank-2 shape (kvL/Page_Size, BS), strides supplied by the caller. + // DMA_KV warp on the device side partitions mode-0 via logical_divide(_, Shape) into + // ((NumPagePerCTATile, MaxNumKVTiles), BS) -- the (page-within-CTA-tile, CTA-tile-idx, batch) view it actually + // indexes. Keeping the gmem-side layout flat per batch lets the harness pass any contiguous or page-strided table. 
+ // (DMA_KV uses 4-byte cp.async, so no cp.async.bulk-style 16B alignment requirements on bytes/base/stride.) + auto layout_PageTable = make_layout( + make_shape(kvL / Page_Size, BS), + make_stride(stride_PT_p, stride_PT_BS)); + Tensor mPageTable = make_tensor(make_gmem_ptr(device_ptr_page_table), layout_PageTable); // (kvL/Page_Size, BS) + + // how we handle oob: + // oob for K, Q, V are handled by TMA + // oob for O is explicitly handled by predicate in the epilog since it uses simple st.global epilog + // we partition kvL with tile size of CTA_kvL, and we evenly distribute the kvL blocks to MaxSplits number of cta in the cluster + assert(NumReductionCTA <= MaxSplits); + static_assert(((CTA_qHLocal * CTA_qL) % NumReductionCTA) == 0, "each reduction cta must have even number of q tokens"); + + // mK and mV are constructed above by slicing+permuting the combined mKV tensor. + Tensor mQ = make_tensor(make_gmem_ptr(device_ptr_Q), layout_Q); // ((qHLocal, qL), dH, kvH, BS) + Tensor mO = make_tensor(make_gmem_ptr(device_ptr_O), layout_O); // (dH, (qHLocal, qL), kvH, BS) + Tensor mSink = make_tensor(make_gmem_ptr(device_ptr_sinks), layout_sinks); // ((qHLocal, qL), kvH) + Tensor mSeqLens = make_tensor(make_gmem_ptr(seq_lens), make_layout(make_shape(BS))); // (BS) + + static_assert(CTA_kvL == 128, "BMM1's MMA_M needs to be 128 for tcgen05.ld->softmax"); + static_assert(((CTA_qHLocal * CTA_qL) % 8) == 0, "BMM1's MMA_N needs to be divisible by 8 for tcgen05.mma"); + assert(dH == CTA_dH); // bmm1 only has 1 kblock (i.e. 
1 Q tile), bmm2 deals with all dH for now, in the foreseeable future this is the hardest constraint to lift + static_assert((CTA_dH == 128) || (CTA_dH == 64), "BMM2's MMA_M needs to be at either 128 or 64 for tcgen05.ld->correction"); + // we swap AB so bmm1 is K (CTA_kvL, CTA_dH) x Q (CTA_dH, CTA_qHLocal * CTA_qL) + // both Q and K are dH (K in gemm terminology) major + // M = CTA_kvL, N = CTA_qHLocal * CTA_qL, K = CTA_dH + TiledMMA tiled_bmm1 = cutlass::gemm::collective::detail::sm100_make_1sm_trivial_tiled_mma< + TypeQKV, TypeQKV, TypeAcc, // Mma's A, B, and Accumulator types + Shape, Int, Int>, // TileShape_MNK + Shape<_1, _1, _1>, // ClusterShape_MNK + cute::UMMA::Major::K, cute::UMMA::Major::K>(); + + // we swap AB for bmm2 as well, V (dH, CTA_kvL) x P (CTA_kvL, CTA_qHLocal * CTA_qL) + // V is dH (M in gemm terminology) major, P is CTA_kvL (K in gemm terminology) major in smem after each thread writes P from rmem to smem + // M = CTA_dH, N = CTA_qHLocal * CTA_qL, K = CTA_kvL + TiledMMA tiled_bmm2 = cutlass::gemm::collective::detail::sm100_make_1sm_trivial_tiled_mma< + TypeQKV, TypeQKV, TypeAcc, // Mma's A, B, and Accumulator types + Shape, Int, Int>, // TileShape_MNK + Shape<_1, _1, _1>, // ClusterShape_MNK + cute::UMMA::Major::MN, cute::UMMA::Major::K>(); + + // Pre-partitioned smem Tile Shape to post-partitioned smem tile shape ((Mma_M, Mma_K), NumMma_M, NumMma_K, DMA_Stage) + auto shape_K = make_shape(Int{}, Int{}, Int{}); + auto shape_Q = make_shape(make_shape(Int{}, Int{}), Int{}, Int<1>{}); + auto shape_S = make_shape(Int{}, make_shape(Int{}, Int{}), Int<1>{}); + auto mma_shape_K = partition_shape_A(tiled_bmm1, shape_K); + auto mma_shape_Q = partition_shape_B(tiled_bmm1, shape_Q); + + auto shape_V = make_shape(Int{}, Int{}, Int{}); + auto shape_P = select<1, 0, 2>(shape_S); // just a permutation of shape_S + auto mma_shape_V = partition_shape_A(tiled_bmm2, shape_V); + + // choose the swizzle atom for K, Q, S, V and P + auto SmemLayoutAtomK = 
cutlass::gemm::collective::detail::sm100_smem_selector< + cute::UMMA::Major::K, // gmem layout of K + TypeQKV, // data type of K + decltype(shape<0>(shape_K)), decltype(shape<1>(shape_K))>(); // tile size of K + auto SmemLayoutAtomQ = cutlass::gemm::collective::detail::sm100_smem_selector< + cute::UMMA::Major::K, // gmem layout of Q + TypeQKV, // data type of Q + decltype(shape<0>(shape_Q)), decltype(shape<1>(shape_Q))>(); // tile size of Q + // for bmm1 tcgen05.ld, each register is holding a row of S (CTA_kvL is mapped to the thread dimension), if we do + // st.shared from rmem to smem, to avoid bank conflict, we need to put T0V0, T1V0, T2V0, ... T31V0 contiguously in smem. + // then the smem layout of S is M (CTA_kvL) major, so we choose MN major swizzle atom + auto SmemLayoutAtomS = UMMA::Layout_MN_SW128_Atom{}; + + auto SmemLayoutAtomV = cutlass::gemm::collective::detail::sm100_smem_selector< + cute::UMMA::Major::MN, // gmem layout of V + TypeQKV, // data type of V + decltype(shape<0>(shape_V)), decltype(shape<1>(shape_V))>(); // tile size of V + // swizzle atom for P should be the transpose of the swizzle atom for S, because they literally represent the same tensor just with different dimension order (aka a pytorch transpose) + auto SmemLayoutAtomP = UMMA::Layout_K_SW128_Atom{}; + + // finally construct the smem layout for tile K, Q, V, S, and P + auto sK_layout = UMMA::tile_to_mma_shape(SmemLayoutAtomK, mma_shape_K); // ((Mma_M, Mma_K), NumMma_M, NumMma_K, BMM1_DMA_Stage) + // paged K smem layout: same memory as sK_layout, viewed as (Page_Size, CTA_dH, NumPagePerCTATile, BMM1_DMA_Stage). + // tile_to_shape uses Step<_1, _3, _2, _4> to stack SmemLayoutAtomK the same way as sK_layout, + // i.e. 
first along CTA_kvL/(Page_Size, NumPagePerCTATile), then along CTA_dH, finally along BMM1_DMA_Stage + static_assert(CTA_kvL % Page_Size == 0, "Page_Size must divide CTA_kvL"); + constexpr int NumPagePerCTATile = CTA_kvL / Page_Size; + auto sK_paged_layout = tile_to_shape(SmemLayoutAtomK, + make_shape(Int{}, Int{}, Int{}, Int{}), + Step<_1, _3, _2, _4>{}); + + auto sQ_layout = UMMA::tile_to_mma_shape(SmemLayoutAtomQ, mma_shape_Q); // ((Mma_N, Mma_K), NumMma_N, NumMma_K, 1) + auto sV_layout = UMMA::tile_to_mma_shape(SmemLayoutAtomV, mma_shape_V); // ((Mma_M, Mma_K), NumMma_M, NumMma_K, BMM2_DMA_Stage) + // paged V smem layout: same memory as sV_layout, viewed as (CTA_dH, Page_Size, NumPagePerCTATile, BMM2_DMA_Stage). + // V is MN-major: atom is (atom_M=CTA_dH, atom_K) -> stacking inside the atom covers all of CTA_dH; multiple atoms tile + // along K (kvL). Splitting kvL into Page_Size + NumPage just renames the K-iter modes; default LayoutLeft step works + // because mode order (M -> K_inner -> K_outer -> Stage) matches the natural sV_layout stride sequence. + auto sV_paged_layout = tile_to_shape(SmemLayoutAtomV, + make_shape(Int{}, Int{}, Int{}, Int{})); + + // The paged and MMA-partitioned views of K/V must alias the same smem buffer -> cosize must match. 
+ static_assert(cute::cosize_v == cute::cosize_v, + "sK_paged_layout and sK_layout must alias the same smem buffer (cosize must match)"); + static_assert(cute::cosize_v == cute::cosize_v, + "sV_paged_layout and sV_layout must alias the same smem buffer (cosize must match)"); + // S and P use tile_to_shape as we do the mma partition in the kernel later + auto sS_layout = tile_to_shape(SmemLayoutAtomS, shape_S, Step<_1, _2, _3>{}); // (CTA_kvL, (CTA_qHLocal, CTA_qL), 1) + auto sP_layout = tile_to_shape(SmemLayoutAtomP, shape_P, Step<_2, _1, _3>{}); // ((CTA_qHLocal, CTA_qL), CTA_kvL, 1) + auto sAcc1_layout = make_layout(shape_S); // (CTA_kvL, (CTA_qHLocal, CTA_qL), 1) + // for storing fmax and fsum warp reduce partial results + int constexpr NumEpilogWarps = 4; + // NumEpilogWarps contiguous because we often ld.shared all NumEpilogWarps from 1/32 threads, this has best vectorization + auto sWarpReduce_layout = make_layout(make_shape(Int{}, make_shape(Int{}, Int{}))); // (NumEpilogWarps, (CTA_qHLocal, CTA_qL)) + // MaxSplits contiguous because we often ld.shared all MaxSplits from 1/32 threads, this has best vectorization + auto sMSMailbox_Layout = make_layout(make_shape(Int{}, Int{})); // (MaxSplits, CTA_qHLocal * CTA_qL / NumReductionCTA) + // default layout is CTA_dH contiguous to maximize st.async/ld.shared bw + auto sAcc2Mailbox_layout = make_layout(make_shape(Int{}, Int{}, Int{})); // (CTA_dH, CTA_qHLocal * CTA_qL / NumReductionCTA, MaxSplits) + auto sSinks_layout = make_layout(Int{}); // (CTA_qHLocal * CTA_qL / NumReductionCTA) + // DMA_KV's page-idx staging buffer. Page_Idx_Stage and Num_Page_Idx_Per_Stage are configured by the host + // harness (template params). Single static_assert block lives here so warp functions don't repeat them. 
+ static_assert(Num_Page_Idx_Per_Stage % NumPagePerCTATile == 0, + "Num_Page_Idx_Per_Stage must be a multiple of CTA_kvL/Page_Size so a DMA stage's pages live in one pi stage"); + // Page_Idx_Stage must be exactly 2: K leads V by 1 tile, so around a pi-stage boundary K reads slot S%2 + // (stage S) while V finishes the last tile of slot (S-1)%2 (stage S-1). =1 would alias these slots; >2 + // wastes smem because at most 2 pi-stage groups are live at once (the one being consumed + the pre-issued). + static_assert(Page_Idx_Stage == 2, "Page_Idx_Stage must be 2 for the K-leads-V-by-1 page-idx pipeline"); + constexpr int Tiles_Per_Pi_Stage_Host = Num_Page_Idx_Per_Stage / NumPagePerCTATile; + // DMA_KV's folded page-idx pipeline issues stage S's cp.async at K_t_in_stage==1 of stage S-1, so each + // pi-stage must hold at least 2 CTA tiles. Choose Num_Page_Idx_Per_Stage >= 2 * NumPagePerCTATile. + static_assert(Tiles_Per_Pi_Stage_Host >= 2, + "Tiles_Per_Pi_Stage (= Num_Page_Idx_Per_Stage / NumPagePerCTATile) must be >= 2 for the folded page-idx pipeline"); + // Hierarchical layout ((NumPagePerCTATile, Tiles_Per_Pi_Stage), Page_Idx_Stage), default LayoutLeft so + // mode 0 is fully contiguous (NumPagePerCTATile innermost). Letting cute carry the (p, t) split keeps the + // producer's inner write as sPageIdx(make_coord(p, t), stage_idx) without manual i/N + i%N arithmetic, and + // gives the consumer a contiguous NumPagePerCTATile-int slice via &sPageIdx(make_coord(0, t), stage_idx). + auto sPageIdx_layout = make_layout(make_shape(make_shape(Int{}, Int{}), Int{})); + + // Now we can find the SMEM allocation size + using SMEMStorage = SharedStorage; + + static_assert(BMM1_DMA_Stage >= BMM2_DMA_Stage, "otherwise you are wasting BMM2 stage because BMM1 TMA issue will block BMM2 TMA due to insufficient BMM1 stages"); + + // create TMA descriptors for K, Q, V matrices + // K TMA box is (Page_Size, CTA_dH) -- one page per TMA copy. 
The per-page SMEM destination layout points into the + // same memory as a single page slot inside sK_layout/sK_paged_layout. + Copy_Atom tma_atom_K = make_tma_atom( + SM90_TMA_LOAD{}, // TMA Load Op, sm100 reuses sm90 tma atom + mK, // Source GMEM tensor + take<0, 2>(sK_paged_layout), // Destination SMEM layout for 1 page = 1 TMA box, (Page_Size, CTA_dH) + make_shape(Int{}, Int{}) // TMA box shape + ); + Tensor mK_tma = tma_atom_K.get_tma_tensor(shape(mK)); // (Page_Size, dH, num_pages_total, kvH) + + Copy_Atom tma_atom_Q = make_tma_atom( + SM90_TMA_LOAD{}, // TMA Load Op, sm100 reuses sm90 tma atom + mQ, // Source GMEM tensor + // sQ_layout(_,_,_,Int<0>{}) doesn't work under some corner cases (composedlayout indexing), so we use + // the take method which is also correct. + take<0, 3>(sQ_layout), // Destination SMEM layout for 1 DMA_Stage, ((Mma_N, Mma_K), NumMma_N, NumMma_K) + make_shape(get<0>(shape_Q), get<1>(shape_Q)) // TMA box shape + ); + Tensor mQ_tma = tma_atom_Q.get_tma_tensor(shape(mQ)); // ((qHLocal, qL), dH, kvH, BS) + + // V TMA box is (CTA_dH, Page_Size) -- one page per TMA copy. Per-page SMEM destination layout points into the + // same memory as a single page slot inside sV_layout/sV_paged_layout. 
+ Copy_Atom tma_atom_V = make_tma_atom( + SM90_TMA_LOAD{}, // TMA Load Op, sm100 reuses sm90 tma atom + mV, // Source GMEM tensor + take<0, 2>(sV_paged_layout), // Destination SMEM layout for 1 page = 1 TMA box, (CTA_dH, Page_Size) + make_shape(Int{}, Int{}) // TMA box shape + ); + Tensor mV_tma = tma_atom_V.get_tma_tensor(shape(mV)); // (dH, Page_Size, num_pages_total, kvH) + + int smemBytes = sizeof(SMEMStorage); + + // invoke the kernel + cudaLaunchConfig_t config; + cudaLaunchAttribute attrs[2]; + // bid.x: kvH * BS, bid.y: qHLocal * qL, bid.z: kvL + uint32_t Cluster_Size = cute::max(MaxSplits, NumReductionCTA); + config.gridDim = dim3{ + (uint32_t)kvH * BS, + (uint32_t)cutlass::ceil_div(qHLocal, CTA_qHLocal) * cutlass::ceil_div(qL, CTA_qL), + Cluster_Size}; + config.blockDim = 256; // 8 warps + config.dynamicSmemBytes = smemBytes; + config.stream = stream; + attrs[0].id = cudaLaunchAttributeClusterDimension; + attrs[0].val.clusterDim = {1, 1, Cluster_Size}; + attrs[1].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[1].val.programmaticStreamSerializationAllowed = 1; + config.attrs = attrs; + config.numAttrs = pdl ? 
2 : 1; + + if (device_ptr_sinks != nullptr) { + auto *kernel_instance = + &gqa_paged_device; + gpuErrChk(cudaFuncSetAttribute(*kernel_instance, cudaFuncAttributeMaxDynamicSharedMemorySize, smemBytes)); + // portable max cluster size is 8, but sm100a supports 16, need explicit opt in + gpuErrChk(cudaFuncSetAttribute(*kernel_instance, cudaFuncAttributeNonPortableClusterSizeAllowed, 1)); + gpuErrChk(cudaLaunchKernelEx(&config, kernel_instance, mK_tma, mQ_tma, mV_tma, mO, mSink, + mSeqLens, mPageTable, + tma_atom_K, tma_atom_Q, tma_atom_V, + tiled_bmm1, tiled_bmm2, + softmax_scale * Log2_E, sliding_window_size, pdl_count)); + } + else { + auto *kernel_instance = + &gqa_paged_device; + gpuErrChk(cudaFuncSetAttribute(*kernel_instance, cudaFuncAttributeMaxDynamicSharedMemorySize, smemBytes)); + // portable max cluster size is 8, but sm100a supports 16, need explicit opt in + gpuErrChk(cudaFuncSetAttribute(*kernel_instance, cudaFuncAttributeNonPortableClusterSizeAllowed, 1)); + gpuErrChk(cudaLaunchKernelEx(&config, kernel_instance, mK_tma, mQ_tma, mV_tma, mO, mSink, + mSeqLens, mPageTable, + tma_atom_K, tma_atom_Q, tma_atom_V, + tiled_bmm1, tiled_bmm2, + softmax_scale * Log2_E, sliding_window_size, pdl_count)); + } +} + +} // namespace gqa_paged +} // namespace TGV diff --git a/include/cute/arch/mma_sm100_umma.hpp b/include/cute/arch/mma_sm100_umma.hpp index d7dfb7159..e7c262c52 100644 --- a/include/cute/arch/mma_sm100_umma.hpp +++ b/include/cute/arch/mma_sm100_umma.hpp @@ -1207,6 +1207,10 @@ struct SM100_MMA_S8_2x1SM_SS_SPARSE } }; +template struct SM100_MMA_F8F6F4_SS { using DRegisters = void; @@ -1452,6 +1456,10 @@ struct SM100_MMA_MXF8F6F4_SS_SPARSE } }; +template struct SM100_MMA_F8F6F4_2x1SM_SS { using DRegisters = void; diff --git a/include/cute/atom/mma_traits_sm100.hpp b/include/cute/atom/mma_traits_sm100.hpp index 5b6af4218..1949b2491 100644 --- a/include/cute/atom/mma_traits_sm100.hpp +++ b/include/cute/atom/mma_traits_sm100.hpp @@ -3327,12 +3327,9 @@ struct 
MMA_Traits -struct MMA_Traits, cute::C, - cute::integral_constant, - cute::integral_constant, - cute::integral_constant, - cute::integral_constant> +struct MMA_Traits> { using ValTypeD = c_type; using ValTypeA = a_type; @@ -3390,7 +3387,9 @@ struct MMA_Traits(traits.idesc_); - SM100_MMA_F8F6F4_SS::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + SM100_MMA_F8F6F4_SS::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); } }; @@ -3745,12 +3744,9 @@ struct MMA_Traits -struct MMA_Traits, cute::C, - cute::integral_constant, - cute::integral_constant, - cute::integral_constant, - cute::integral_constant> +struct MMA_Traits> { using ValTypeD = c_type; using ValTypeA = a_type; @@ -3808,7 +3804,9 @@ struct MMA_Traits(traits.idesc_); - SM100_MMA_F8F6F4_2x1SM_SS::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + SM100_MMA_F8F6F4_2x1SM_SS::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); } }; diff --git a/include/cutlass/float8.h b/include/cutlass/float8.h index 9aacd78fc..6019de643 100644 --- a/include/cutlass/float8.h +++ b/include/cutlass/float8.h @@ -57,6 +57,14 @@ # endif // (__CUDA_ARCH__ >= 900) #endif // defined(__CUDA_ARCH__) +#if defined(__CUDA_ARCH__) +# if (__CUDA_ARCH__ >= 1000) +# if (__CUDACC_VER_MAJOR__ > 13) || ((__CUDACC_VER_MAJOR__ >= 13) && (__CUDACC_VER_MINOR__ >= 2)) +# define CUDA_PTX_FP8_BF16_CVT_ENABLED 1 +# endif // (__CUDACC_VER_MAJOR__ > 13) || ((__CUDACC_VER_MAJOR__ >= 13) && (__CUDACC_VER_MINOR__ >= 2)) +# endif // (__CUDA_ARCH__ >= 1000) +#endif // defined(__CUDA_ARCH__) + #if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) ||\ defined(CUTLASS_ARCH_MMA_SM103A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM110A_ENABLED) ||\ diff --git a/include/cutlass/gemm/collective/builders/sm100_common.inl b/include/cutlass/gemm/collective/builders/sm100_common.inl index 0c8be4ad2..f6a316206 100644 --- a/include/cutlass/gemm/collective/builders/sm100_common.inl +++ 
b/include/cutlass/gemm/collective/builders/sm100_common.inl @@ -339,18 +339,16 @@ sm100_make_1sm_trivial_tiled_mma() { ) { return make_tiled_mma( - cute::MMA_Traits< - cute::SM100_MMA_F8F6F4_SS, + cute::SM100_MMA_F8F6F4_SS< ElementAMma, ElementBMma, ElementAMmaccumulator, - cute::C, - cute::C, - cute::integral_constant, - cute::integral_constant, - cute::integral_constant, - cute::integral_constant - >{} + M, + N, + UmmaMajorA, + UmmaMajorB, + ANeg, + BNeg>{} ); } else { @@ -407,18 +405,16 @@ sm100_make_2sm_trivial_tiled_mma() { ) { return make_tiled_mma( - cute::MMA_Traits< - cute::SM100_MMA_F8F6F4_2x1SM_SS, + cute::SM100_MMA_F8F6F4_2x1SM_SS< ElementAMma, ElementBMma, ElementAMmaccumulator, - cute::C, - cute::C, - cute::integral_constant, - cute::integral_constant, - cute::integral_constant, - cute::integral_constant - >{} + M, + N, + UmmaMajorA, + UmmaMajorB, + ANeg, + BNeg>{} ); } @@ -739,17 +735,16 @@ sm100_make_trivial_mixed_input_tiled_mma() { } if constexpr (cute::is_same_v) { return make_tiled_mma( - cute::MMA_Traits< - cute::SM100_MMA_F8F6F4_SS, + cute::SM100_MMA_F8F6F4_SS< ElementAMma, ElementBMma, ElementAccumulator, - cute::C, - cute::C, - cute::integral_constant, - cute::integral_constant, - cute::integral_constant, - cute::integral_constant>{}); + M, + N, + UmmaMajorA, + UmmaMajorB, + cute::UMMA::ScaleIn::One, + cute::UMMA::ScaleIn::One>{}); } } } diff --git a/include/cutlass/numeric_conversion.h b/include/cutlass/numeric_conversion.h index e9463b203..e57620be0 100644 --- a/include/cutlass/numeric_conversion.h +++ b/include/cutlass/numeric_conversion.h @@ -1820,14 +1820,21 @@ struct NumericArrayConverter(source); - - asm volatile( \ - "{\n" \ - "cvt.rn.f16x2.e4m3x2 %0, %1;\n" \ - "}\n" : "=r"(res_half): "h"(src_packed)); - float2 res_float = __half22float2(reinterpret_cast<__half2 &>(res_half)); - NumericArrayConverter converter; - return converter(reinterpret_cast const&>(res_float)); + #if defined(CUDA_PTX_FP8_BF16_CVT_ENABLED) + asm volatile( \ + 
"{\n" \ + "cvt.rn.bf16x2.e4m3x2 %0, %1;\n" \ + "}\n" : "=r"(res_half): "h"(src_packed)); + return reinterpret_cast(res_half); + #else + asm volatile( \ + "{\n" \ + "cvt.rn.f16x2.e4m3x2 %0, %1;\n" \ + "}\n" : "=r"(res_half): "h"(src_packed)); + float2 res_float = __half22float2(reinterpret_cast<__half2 &>(res_half)); + NumericArrayConverter converter; + return converter(reinterpret_cast const&>(res_float)); + #endif #else result_type result; NumericConverter converter; @@ -2961,19 +2968,31 @@ struct NumericArrayConverterPacked4Element src2float; - Array tmp_floats = src2float(source); - - // Convert float to bf16 result_type out; - Array* packed_tmp = reinterpret_cast*>(&tmp_floats); - Array* packed_out = reinterpret_cast*>(&out); - NumericArrayConverter float2result; - packed_out[0] = float2result(packed_tmp[0]); - packed_out[1] = float2result(packed_tmp[1]); + #if defined(CUDA_PTX_FP8_BF16_CVT_ENABLED) + uint32_t const& src_packed = reinterpret_cast(source); + Array& out_packed = reinterpret_cast&>(out); + asm volatile("{\n" + ".reg .b16 b0, b1;\n" + "mov.b32 {b0, b1}, %2;\n" + "cvt.rn.bf16x2.e4m3x2 %0, b0;\n" + "cvt.rn.bf16x2.e4m3x2 %1, b1;\n" + "}\n" + : "=r"(out_packed[0]), "=r"(out_packed[1]) + : "r"(src_packed)); + #else + // Convert f8 to float + NumericArrayConverterPacked4Element src2float; + Array tmp_floats = src2float(source); - return out; + // Convert float to bf16 + Array* packed_tmp = reinterpret_cast*>(&tmp_floats); + Array* packed_out = reinterpret_cast*>(&out); + NumericArrayConverter float2result; + packed_out[0] = float2result(packed_tmp[0]); + packed_out[1] = float2result(packed_tmp[1]); + #endif + return out; #else result_type result; NumericConverter converter; diff --git a/include/cutlass/version.h b/include/cutlass/version.h index 5c30d8c6a..08596546e 100644 --- a/include/cutlass/version.h +++ b/include/cutlass/version.h @@ -36,7 +36,7 @@ #define CUTLASS_MAJOR 4 #define CUTLASS_MINOR 5 -#define CUTLASS_PATCH 0 +#define CUTLASS_PATCH 1 
#ifdef CUTLASS_VERSIONS_GENERATED #include "cutlass/version_extended.h" diff --git a/media/docs/pythonDSL/cute_dsl.rst b/media/docs/pythonDSL/cute_dsl.rst index 50a7341d2..f5675a85e 100644 --- a/media/docs/pythonDSL/cute_dsl.rst +++ b/media/docs/pythonDSL/cute_dsl.rst @@ -23,3 +23,5 @@ CuTe DSL Compile with TVM FFI Ahead-of-Time (AOT) Compilation Talks and Presentations + Naming Conventions + MMA Programming Guides diff --git a/media/docs/pythonDSL/cute_dsl_general/dsl_jit_arg_generation.rst b/media/docs/pythonDSL/cute_dsl_general/dsl_jit_arg_generation.rst index 18970012f..99919e071 100644 --- a/media/docs/pythonDSL/cute_dsl_general/dsl_jit_arg_generation.rst +++ b/media/docs/pythonDSL/cute_dsl_general/dsl_jit_arg_generation.rst @@ -83,6 +83,9 @@ an elementwise lambda function can be passed in as the ``epilogue_op`` argument. Refer to the `Blackwell dense GEMM example `__ for a complete example. +.. note:: + For the per-thread/partition naming convention used above (``tTR_rAcc``, ``tTR_rC``, and related tokens such as ``tAgA``, ``bSG_sC``, ``tQgQ_qdl``, …), see the :ref:`cute_dsl_naming_conventions`. + Type safety ----------- diff --git a/media/docs/pythonDSL/cute_dsl_general/naming_conventions.rst b/media/docs/pythonDSL/cute_dsl_general/naming_conventions.rst new file mode 100644 index 000000000..c14d84ef4 --- /dev/null +++ b/media/docs/pythonDSL/cute_dsl_general/naming_conventions.rst @@ -0,0 +1,253 @@ +.. _cute_dsl_naming_conventions: + +CuTe DSL Naming Conventions +=========================== + +This page summarizes the Hungarian-style naming conventions used for identifiers across the DSL examples and epilogue helpers: tensor partitions, per-thread copy-partitioners, copy atoms, and the axis-order suffixes that encode tensor layouts. It is meant as a lookup reference while reading example code — not as a style rule enforced on new code. 
+ +Memory/space scopes +------------------- + +- ``g``: Global memory view (GMEM), e.g., ``gB_nkl``, ``tTR_gC`` +- ``s``: Shared memory view (SMEM), e.g., ``sA``, ``tRS_sC``, ``bSG_sC`` +- ``r``: Register view (RMEM), e.g., ``tTR_rAcc``, ``tRS_rC`` +- ``t``: Tensor-memory view (TMEM), used for any TMEM-resident fragment or layout regardless of role. The classical case is the accumulator (``tCtAcc``, ``tTR_tAcc``). The same scope letter also appears for non-accumulator TMEM tensors such as ``tCtE``, ``tCtState``, ``tCtQState``, ``tCtShared``. Read the operand suffix to distinguish the role from the memory scope. + +Per-thread/partitioned views and families +----------------------------------------- + +- ``tA…`` / ``tB…``: TMA load path for A/B + + - ``tAgA`` / ``tAsA``: per-thread partitioned global/shared A for TMA load + - ``tBgB`` / ``tBsB``: per-thread partitioned global/shared B for TMA load + - NVFP4/FP8 scale factors mirror this: ``tAgSFA`` / ``tAsSFA``, ``tBgSFB`` / ``tBsSFB`` + +- ``tC…``: Compute/epilogue path for C/Acc + + - ``tCgA`` / ``tCgB`` / ``tCgC``: per-thread partitions used by MMA/epilogue (derived from global tensors) + - ``tCrA`` / ``tCrB``: per-thread fragments used by MMA (derived from SMEM A/B) + - ``tCtAcc``: per-thread accumulator fragment/layout in TMEM + - Additional ``tC*`` tensors follow the same schema for kernels that carry more than the classical A/B/C/Acc operands (see Operands and roles below): e.g. ``tCtState`` / ``tCtQState`` / ``tCtShared`` (gated-delta-net recurrent state in TMEM), ``tCrValpha`` / ``tCrVbeta`` / ``tCrVbias`` (EVT/EFC broadcast vectors in registers), ``tCtAccInter`` / ``tCtAccIntra`` (hierarchical accumulators) + + .. {$nv-internal-release begin} + + - Sparse GEMM additionally defines ``tCtE`` for the sparsity metadata tensor in TMEM (sm_140 / Feynman sparse GEMM, not yet released) + + .. {$nv-internal-release end} + +- ``tTM…``: Per-thread TMEM tiled-copy partitions used by FMHA/attention kernels (e.g. 
``tTMrO`` as the register-side view of a TMEM load partitioned through ``thr_tmem_load``) + +- Attention/MLA path families (``tQ…``, ``tK…``, ``tV…``, ``tP…``, ``tO…``): same schema as ``tA…`` / ``tB…`` / ``tC…`` but specialised to the Q/K/V/P/O operands of attention kernels, e.g.: + + - ``tQsQ`` / ``tQgQ_qdl``: per-thread SMEM / GMEM partitions of Q for TMA load + - ``tKrK`` / ``tVrV``: per-thread register fragments for K / V + - ``tOtO`` / ``tOrO``: per-thread TMEM / register views of the attention output accumulator O + - ``tPrP``: per-thread register fragment for the softmax probability matrix P + +Data-movement copy paths +------------------------ + +- ``tTR_*``: TMEM → Register (T2R) + + - ``tTR_tAcc``: TMEM accumulator source for T2R + - ``tTR_rAcc``: Register destination for T2R + - ``tTR_gC``: When not using TMA store, Register → Global C destination partition + +- ``tRS_*``: Register → Shared (R2S) + + - ``tRS_rC``: Register source (C dtype) + - ``tRS_sC``: Shared destination + +- ``bSG_*``: Thread(b)lock partition for Shared → Global via TMA store + + - ``bSG_sC``: Shared source for TMA store + - ``bSG_gC``: Global destination for TMA store + - Also used for accumulator in some flows: ``bSG_sAcc``, ``bSG_gAcc`` + - The same schema extends to additional store operands: ``bSG_sD`` / ``bSG_gD``, ``bSG_sP`` / ``bSG_gP``, ``bSG_sY`` / ``bSG_gY`` + +- ``bGS_*``: Thread(b)lock partition for Global → Shared via TMA **load** (the load-path mirror of ``bSG_*``) + + - ``bGS_gC`` / ``bGS_sC``: Global source / Shared destination for TMA load of C-like operands (seen in EFC row/column broadcast prologues) + +- ``simt_atom``: SIMT copy path used when TMA store is disabled (Register → Global) +- Generic SIMT / tiled copy atoms ``2_atom[_suffix]`` name the copy direction between two memory scopes: + + - ``s2r_atom_*``: Shared → Register atom used in specialised epilogues and attention loads (e.g. 
``s2r_atom_delta``, ``s2r_atom_cumsum``, ``s2r_atom_d`` in Mamba2 SSD) + - ``r2s_atom``: Register → Shared atom + - ``t2r_atom`` / ``r2t_atom``: Tensor memory ↔ Register atoms (paired with ``thr_tmem_load`` / ``thr_tmem_store``) + - ``s2s_atom``: Shared → Shared atom (reshape/remap without register spill) + - ``s2t``: Shared → Tensor memory atom + + .. {$nv-internal-release begin} + + - ``sp2t_copy_op_*``: Sparse source → Tensor memory copy op (sm_140 / Feynman sparse GEMM, not yet released: e.g. ``Sp2TAsACopyOp``, ``Sp2TAsECopyOp``) + + .. {$nv-internal-release end} + + - Custom ``autovec_copy`` paths appear where the DSL auto-vectorises a bespoke layout + +Operands and roles +------------------ + +- ``A``, ``B``, ``C``: GEMM operands +- ``Acc``: Accumulator (TMEM/Register paths). Hierarchical MMA kernels split this into ``AccInter`` / ``AccIntra`` for the inter-/intra-CTA accumulator halves +- Classical extra outputs / intermediates: ``D`` (additional output), ``Y`` (fused output), ``SFA`` / ``SFB`` (per-operand scale-factor arrays for NVFP4/FP8), ``SF`` (generic scale factor) +- Attention / MLA operand letters (Q/K/V/P/O schema): + + - ``Q`` (query), ``K`` (key), ``V`` (value), ``P`` (softmax probability / score matrix), ``O`` (attention output) + - Variants: ``Kt`` / ``Vt`` for the transposed view of K/V, ``Qi`` / ``Ki`` / ``Vi`` for per-iteration slices, ``QK`` / ``PV`` / ``QKV`` where a single fragment spans multiple operands of the two back-to-back matmuls +- Mamba / recurrent-state letters: ``Delta`` / ``DeltaA`` (time-step and A-decay), ``State`` / ``QState`` / ``Shared`` (gated-delta-net recurrent state tensors), ``Cumsumlog`` / ``Cumprod`` (running reductions), ``Gate``, ``DecayV`` + +.. {$nv-internal-release begin} + +- Sparse-GEMM letters (sm_140 / Feynman, not yet released): ``E`` (sparsity metadata tensor in TMEM; paired with ``sp2t_*`` copy ops) + +.. 
{$nv-internal-release end} + +- EVT / EFC broadcast vectors: ``Valpha`` / ``Vbeta`` (alpha/beta scalars broadcast as vectors), ``Vbias`` (bias vector), ``Ainv`` (inverse of A for fused solvers) + +.. {$nv-internal-release begin} + +- LUT-based block-scaled GEMM letter (Rubin, not yet released): ``LutB`` (look-up-table operand) + +.. {$nv-internal-release end} +- Communication operands (multi-CTA / multicast flows): ``CommInMC`` / ``CommOutMC`` (multicast in/out), ``CommOutUC`` (unicast out) +- Head-dimension variants: ``Dv`` (value head dimension when distinct from Q/K dim), ``Nv`` (number of value heads) + +Axis-order suffixes +------------------- + +- Suffix encodes axis order of the view (lowercase letters each stand for one tensor mode): + + - GEMM layouts use ``m``/``n``/``k``/``l``: + + - ``_mnl``, ``_nkl``, ``_mkl``, … map to (M, N, K, L) ordering + - Example: ``gB_nkl`` is B with axes (N, K, L); ``gC_mnl`` is C with (M, N, L) + + - Attention / FMHA layouts use ``q``/``k``/``d``/``l`` (sequence-Q, sequence-K, head-dim, batch): + + - ``mQ_qdl``: Q tensor with axes (SeqQ, HeadDim, Batch) + - ``mK_kdl``: K tensor with axes (SeqK, HeadDim, Batch) + - ``mV_dkl``: V tensor with axes (HeadDim, SeqK, Batch) — the ``d``-first order reflects the V-transpose that makes the second matmul (P·V) a standard row-major ``MxK·KxN`` + + - Lower-rank 2D slices drop the batch letter: ``_mn``, ``_mk``, ``_nk`` + +- Internally, CuTe layouts also expose grouped modes like ``MMA_M/N/K``, ``EPI_M/N``, ``RestM/N/K/L``, ``STAGE``, etc. (these are typically implementation details not directly used in example code). 
+ +Reading compound tokens +----------------------- + +- From left to right: ``[t|b][A|B|C|Q|K|V|P|O|TR|RS|SG|GS|TM]_[g|s|r|t][Operand/Role][AxisSuffix?]`` + + - ``t`` = per-thread/partitioned view; ``b`` = block/threadblock partition context + - family/path letters: + + - Operand-based: ``A`` / ``B`` / ``C`` (GEMM), ``Q`` / ``K`` / ``V`` / ``P`` / ``O`` (attention) + - Direction-based: ``TR`` (TMEM → Register), ``RS`` (Register → Shared), ``SG`` (Shared → Global, store), ``GS`` (Global → Shared, load), ``TM`` (TMEM tiled-copy partition), ``R2G`` / ``S2R`` / ``T2R`` / ``R2T`` convenience aliases + - memory = ``g``/``s``/``r``/``t`` + - operand/role = ``A``/``B``/``C``/``Acc``/``SFA``/``SFB``/``Q``/``K``/``V``/``P``/``O``/``E``/``State``/… + - axis suffix = ``_mnl``, ``_nkl``, ``_qdl``, ``_kdl``, ``_dkl``, ``_mn``, … when applicable + +- Per-thread-partitioner objects follow a parallel ``thr_*`` vocabulary, grouped by role: + + - MMA partitioner: ``thr_mma`` + - Tiled-copy direction variants ``thr_copy_2``: ``thr_copy_g2s``, ``thr_copy_s2r``, ``thr_copy_t2r``, ``thr_copy_r2s``, ``thr_copy_r2t``, ``thr_copy_s2t`` + - Role-qualified copy variants: ``thr_copy_sfa``, ``thr_copy_sfb``, ``thr_copy_load``, ``thr_copy_beta_g2s`` + - MMA variants for multi-matmul kernels: ``thr_mma_qk``, ``thr_mma_pv``, ``thr_mma_kv``, ``thr_mma_qkv``, ``thr_mma_intra1`` / ``thr_mma_intra2``, ``thr_mma_leader_cta``, ``thr_mma_sfb`` + - TMEM access partitioners: ``thr_tmem_load``, ``thr_tmem_store`` (with ``_stats`` / ``_vec`` suffix variants) + + The tensor produced by ``thr_foo.partition_S(X)`` or ``.partition_D(X)`` is then named by the ``[t|b]FamilyPrefix_*`` convention above. 
+ +Concrete references +------------------- + +Open these files in the repository to see each pattern in context: + +- TMA load partitions for A/B: + + - ``tAgA``, ``tAsA``, ``tBgB``, ``tBsB`` + - ``CuTeDSL/cute/blackwell/kernel/dense_gemm/dense_gemm.py`` (around TMA partition of A/B) + +- Accumulator fragment in TMEM: + + - ``tCtAcc`` + - ``CuTeDSL/cute/blackwell/kernel/dense_gemm/dense_gemm.py`` (accumulator creation and use) + +- TMEM → Register (T2R): + + - ``tTR_tAcc``, ``tTR_rAcc``, ``tTR_gC`` + - ``CuTeDSL/cute/blackwell/kernel/dense_gemm/dense_gemm.py`` (``epilog_tmem_copy_and_partition``) + +- Register → Shared (R2S): + + - ``tRS_rC``, ``tRS_sC`` + - ``CuTeDSL/cute/blackwell/kernel/mixed_input_gemm/mixed_input_gemm.py`` (``epilog_smem_copy_and_partition``) + +- Shared → Global via TMA store: + + - ``bSG_sC``, ``bSG_gC`` + - ``CuTeDSL/cute/blackwell/kernel/blockscaled_gemm/dense_blockscaled_gemm_persistent.py`` (``epilog_gmem_copy_and_partition``) + +- NVFP4/FP8 scale factors: + + - ``tAgSFA``/``tAsSFA``, ``tBgSFB``/``tBsSFB`` + - ``CuTeDSL/cute/blackwell/tutorial/tutorial_gemm/nvfp4_gemm_0.py`` (scale factor partition and usage) + +- Additional examples across ``examples/``: + + - Register → Global helper naming in MLA: ``tR2G_rO_src``, ``tR2G_rO_dst`` + - ``CuTeDSL/cute/blackwell/kernel/attention/mla/mla_decode_fp16.py`` (output store section) + + - Shared → Register SIMT atoms in Mamba2 SSD: ``s2r_atom_delta``, ``s2r_atom_cumsum``, ``s2r_atom_d`` + - ``CuTeDSL/cute/blackwell/kernel/attention/mamba2_ssd/mamba2_ssd.py`` (SMEM load paths for delta and D) + + - ``thr_*`` slices for partitioning per-thread work: ``thr_mma``, ``thr_copy_t2r``, ``thr_copy_r2s``, etc. 
+ - ``CuTeDSL/cute/blackwell/kernel/dense_gemm/dense_gemm.py`` (``thr_mma``, ``thr_copy_t2r``, ``thr_copy_r2s``) + +- Axis-order suffix examples: + + - ``gB_nkl``, ``gC_mnl`` + - ``CuTeDSL/cute/blackwell/kernel/dense_gemm/dense_gemm.py`` (global tensor tiling and partitioning) + +- Global → Shared (TMA load) block partition ``bGS_*``: + + - ``bGS_gC``, ``bGS_sC`` + - ``CuTeDSL/cute/blackwell/efc/common_efc.py`` (row/column broadcast prologue building the C-like input for EVT) + +- Attention Q/K/V/P/O families and ``_qdl`` / ``_kdl`` / ``_dkl`` axis suffixes: + + - ``tQsQ``, ``tQgQ_qdl``, ``mK_kdl``, ``mV_dkl`` + - ``CuTeDSL/cute/hopper/kernel/attention/fmha.py`` (Q/K/V TMA partitions) + - ``tOtO``, ``tOrO``, ``tPrP`` + - ``CuTeDSL/cute/blackwell/tutorial/tutorial_fmha/fmha_0.py`` (output and softmax fragments) + - ``tKrK``, ``tVrV`` + - ``CuTeDSL/cute/blackwell/kernel/attention/mixed_input_fmha/mixed_input_fmha_decode.py`` (mixed-input K/V register fragments) + +- TMEM tiled-copy ``tTM*`` family and the generalised ``2_atom`` naming: + + - ``tTMrO`` driven by ``thr_tmem_load`` + - ``CuTeDSL/cute/blackwell/tutorial/tutorial_fmha/fmha_0.py`` + +- Recurrent-state operands (``State`` / ``QState`` / ``Shared``) in TMEM: + + - ``tCtState``, ``tCtQState``, ``tCtShared`` + - ``CuTeDSL/cute/blackwell/kernel/attention/gated_delta_net/gated_delta_net_chunked.py`` + +.. {$nv-internal-release begin} + +- Sparse-metadata operand ``E`` and ``sp2t_*`` copy ops (sm_140 / Feynman, not yet released): + + - ``tCtE``, ``sp2t_copy_op_A``, ``sp2t_copy_op_E`` + - ``CuTeDSL/internal/feynman/sm140_sparse_gemm.py`` and ``sm140_sparse_gemm_temporal_split_k.py`` + +- LUT-based block-scaled GEMM operand ``LutB`` (Rubin, not yet released): + + - ``CuTeDSL/cute/rubin/kernel/blockscaled_gemm/dense_blockscaled_gemm_lut.py`` + - ``CuTeDSL/cute_ext/rubin/dense_gemm_lutb.py`` + +.. 
{$nv-internal-release end} + +- Richer ``thr_*`` and ``thr_copy_*`` / ``thr_mma_*`` / ``thr_tmem_*`` partitioner taxonomy: + + - ``thr_copy_g2s``, ``thr_copy_s2r``, ``thr_copy_s2t``, ``thr_copy_r2t``, ``thr_mma_qk``, ``thr_mma_pv``, ``thr_tmem_load``, ``thr_tmem_store`` + - The attention and Mamba2 examples above are the densest references; any ``fmha_*.py`` or ``mamba2_ssd.py`` file will show the full vocabulary in use diff --git a/media/docs/pythonDSL/mma_docs/intro.rst b/media/docs/pythonDSL/mma_docs/intro.rst new file mode 100644 index 000000000..38605cdee --- /dev/null +++ b/media/docs/pythonDSL/mma_docs/intro.rst @@ -0,0 +1,11 @@ +Architecture-specific MMA Programming Guides +============================================= + +This section contains architecture-specific MMA programming guides. + +.. toctree:: + :maxdepth: 2 + + wmma_programming + wgmma_programming + tcgen05_programming \ No newline at end of file diff --git a/media/docs/pythonDSL/mma_docs/tcgen05_programming.rst b/media/docs/pythonDSL/mma_docs/tcgen05_programming.rst new file mode 100644 index 000000000..5d37124e7 --- /dev/null +++ b/media/docs/pythonDSL/mma_docs/tcgen05_programming.rst @@ -0,0 +1,1528 @@ +.. _tcgen05_programming: + +tcgen05 MMA Programming Guide +============================== + +Blackwell (SM100) introduces the **tcgen05** family of PTX instructions — the +5th-generation Tensor Core MMA (matrix multiply-accumulate) operations. They +compute ``D = A * B + C`` with 2x–4x the throughput of Hopper's WGMMA +instructions, depending on data type. + +Key architectural characteristics: + +* **Tensor Memory (TMEM):** A new on-chip memory dedicated to the accumulator + (and, optionally, operand A). tcgen05 MMA reads and writes the accumulator + in TMEM directly, freeing the register file for other work. +* **Single-thread launch:** Only one thread issues the MMA instruction. 
+* **CTA-pair cooperation:** Two adjacent CTAs can jointly execute a single + MMA, doubling the tile size without extra synchronization logic. + +This guide shows how to program these operations through the CuTe Python +DSL, using SMEM for operands A and B, and TMEM for the accumulator. + +.. contents:: **Contents** + :local: + :depth: 2 + +Global Memory (GMEM) to MMA data flow overview +---------------------------------------------- + +Tcgen05 MMA instructions requires us to stage A input operands in Shared Memory (SMEM) or Tensor Memory (TMEM), +and B input operands in SMEM. The accumulator is always stored in TMEM. + +The diagram below traces the full data flow of a tcgen05 GEMM kernel, for the most +common case where A and B matrices are stored in GMEM, and the output matrix --read from TMEM-- +is written to GMEM. + +There are 3 parallel tracks where each has 2 sub-tracks. Three parallel tracks are +for operands A, B, and C/D, respectively. The two sub-tracks are for copying data between different memory +spaces and for MMA execution. + +- **Operand A** (and symmetrically **Operand B**): + + - First, we need to create SMEM tensors for A and B matrices: ``sA`` and ``sB``. These + tensors are physically allocated tensors that are the destination of copy and the source operands + for the MMA instructions. + - Next the **data copy flow** creates the tensor views for copying data from GMEM to SMEM. + It starts with ``mA`` tensor that represents the matrix A in global memory. + ``mA`` → ``local_tile`` → ``gA`` operation creates the local tile view of A that is the + slice of A matrix needed to compute the given MMA's output tile partitioning. + ``gA`` → ``partition_A`` → ``tCgA`` partitions the full MMA sized tile into smaller tiles which + are needed to copy the correct portion of A/B matrix to SMEM by individual CTAs cooperating + for the MMA (1CTA vs 2CTA pair MMA cases). 
+ Then ``tma_partition`` produces TMA views ``tAsA``, ``tAgA``, and the loop copies tiles from + GMEM into SMEM via ``copy(tma, tAgA[k], tAsA[stage])``. + - In parallel, the **MMA flow** turns the SMEM tensors into iterable tensors of SMEM descriptors for MMA instructions. + ``sA`` (the same shared-memory allocation written by TMA) → ``make_fragment_A`` → ``tCrA`` + (they are passed to ``cute.gemm()``). Note that the SMEM descriptors are views created + from the SMEM tensor that is interpretable by the MMA instructions. + +- **Accumulator C/D**: + + - **TMEM accumulator flow** (gemm input/output): ``make_fragment_C(MMA_partition_shape_C)`` → ``tCtAcc``, + which serves as the accumulator input/output of ``cute.gemm()`` (and MMA instruction). + - **Output flow** (GMEM destination): The LDTM loads results into registers and a final store writes them + to global memory. ``mC`` → ``local_tile`` → ``gC`` + → ``partition_C`` → ``tCgC``. This path creates the tensor views that will be stored to GMEM. + +.. 
code-block:: text + + Operand A Dataflow Path Operand B Dataflow Path Accumulator C/D Dataflow Path + ─────────────────────── ─────────────────────── ───────────────────────────── + + mA: (M, K) [GMEM] mB: (N, K) [GMEM] ┌──── TMEM ──────────┐ + │ │ │ partition_shape_C()│ + │ local_tile(mA, mma_tiler, coord) │ local_tile(mB, mma_tiler, coord) │ make_fragment_C() │ + ▼ ▼ │ bind to tmem_ptr │ + gA: (BM, BK, k) [GMEM] gB: (BN, BK, k) [GMEM] └───────┬────────────┘ + │ │ │ + │ thr_mma.partition_A(gA) │ thr_mma.partition_B(gB) tCtAcc:(MMA,MMA_M,MMA_N) [TMEM] + ▼ ▼ │ + tCgA:(MMA,MMA_M, [GMEM] tCgB:(MMA,MMA_N, [GMEM] │ + MMA_K,k) MMA_K,k) │ + │ │ │ mC: (M, N) [GMEM] + │ ┌──── SMEM ─────────┐ │ ┌──── SMEM ─────────┐ │ │ + │ │ sA = alloc(layout)│ │ │ sB = alloc(layout)│ │ │ local_tile + │ └──┬────────┬───────┘ │ └──┬────────┬───────┘ │ ▼ + │ │ │ │ │ │ │ gC: (BM, BN) [GMEM] + │ │ make_fragment_A(sA) │ │ make_fragment_B(sB) │ │ partition_C + │ │ │ │ │ │ │ ▼ + │ │ ▼ │ │ ▼ │ tCgC:(MMA,MMA_M, + │ │ tCrA:(MMA,MMA_M, │ │ tCrB:(MMA,MMA_N, │ MMA_N) + │ │ MMA_K,STAGE) │ │ MMA_K,STAGE) │ [GMEM] (epi dest) + │ │ [SMEM descriptors] │ │ [SMEM descriptors] │ │ + │ │ └─────────────┐ │ │ └─────────────┐ │ │ + ╰─────┤ │ ╰─────┤ │ │ │ + ▼ │ ▼ │ │ │ + tma_partition(tma, │ tma_partition(tma, │ │ │ + sA, tCgA) │ sB, tCgB) │ │ │ + → tAsA, tAgA │ → tBsB, tBgB │ │ │ + ▼ │ ▼ │ │ │ + ┌───┴────────────────────┐ │ ┌──────┴─────────────────┐│ │ │ + │ TMA copy loop (A path):│ │ │ TMA copy loop (B path):││ │ │ + │ copy(tma, tAgA[k], │ │ │ copy(tma, tBgB[k], ││ │ │ + │ tAsA[stage]) │ │ │ tBsB[stage]) ││ │ │ + ┌─▶│ (writes into sA; │ │ ┌──▶│ (writes into sB; ││ │ │ + │ │ tCrA reads same sA) │ │ │ │ tCrB reads same sB) ││ │ │ + │ │ repeat for next k/stage│ │ │ │ repeat for next k/stage││ │ │ + │ └────────────────────────┘ │ │ └────────────────────────┘│ │ │ + │ │ │ │ │ │ │ │ + └────────┘ ▼ └─────────┘ ▼ ▼ │ + └───────┬───────────────────────────────┴──────────────────┘ │ + │ │ + ▼ │ + 
┌──────────────────────────────────────────────┐ │ + │ GEMM Loop: | │ + | cute.gemm(tiled_mma, │ │ + │ tCtAcc, D (output), │ │ + ┌──▶ │ tCrA[stage], A (SMEM desc -> sA), │ │ + │ │ tCrB[stage], B (SMEM desc -> sB), │ │ + │ │ tCtAcc) C (accumulator input) │ │ + │ └──────────────────────────────────────────────┘ │ + │ │ │ │ + └───────┘ | │ + ▼ │ + Epilogue: │ + t2r = make_tmem_copy(LdOp, tCtAcc) │ + tTR_tAcc = t2r.partition_S(tCtAcc) │ + tTR_gC = t2r.partition_D(tCgC) ◀────────────────────────────────┘ + tTR_rAcc = make_rmem_tensor(...) + │ + ▼ + LDTM: copy(t2r, tTR_tAcc, tTR_rAcc) + [TMEM → RMEM] + │ + ▼ + Store: copy(atom, tTR_rAcc, tTR_gC) + [RMEM → GMEM] + + +**Naming convention:** + +* ``mma_tiler_mnk`` = ``(BM, BN, BK)`` — per-CTA (or per-CTA-pair) MMA tile +* ``mX`` = a global tensor, such as ``mA``, ``mB``, ``mC`` +* ``gX`` = MMA-tiler tiled GMEM slice, e.g. ``(BM, BK, k)`` for A +* ``tCgX`` = CTA-partitioned GMEM tensor, e.g. ``(MMA, MMA_M, MMA_K, k)`` for A +* ``sX`` = SMEM allocation (``sA``, ``sB``) +* ``tCrX`` = SMEM-descriptor MMA fragment, e.g. ``(MMA, MMA_M, MMA_K, STAGE)`` for A +* ``tCtX`` = TMEM tensor; ``tCtAcc`` = TMEM accumulator ``(MMA, MMA_M, MMA_N[, ACC_STAGE])`` +* ``tAsA`` / ``tBsB`` = TMA-partitioned SMEM views of A / B +* ``tAgA`` / ``tBgB`` = TMA-partitioned GMEM views of A / B +* ``tTR_*`` = T2R (TMEM→RMEM) partitioned tensors used in the epilogue + + +Setting up the TiledMMA, MMA Ops +--------------------------------- + +As shown in the data flow overview, CuTe DSL provides many utilities to tile/partition +the global memory tensors, and create fragment views of SMEM and TMEM tensors for MMA instructions. + +To utilize these functions, we need to setup the TiledMMA, MMA Ops first. + +Creating a tcgen05 MMA Op +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +A tcgen05 MMA op describes the hardware instruction to use, it has parameters like +data types, instruction shape, CTA group, operand A source (SMEM or TMEM), +and operand major modes. + +.. 
code-block:: python + + import cutlass + import cutlass.cute as cute + from cutlass.cute.nvgpu import tcgen05, OperandMajorMode + + op = tcgen05.MmaF16BF16Op( + cutlass.Float16, # A/B element type + cutlass.Float32, # accumulator type + (128, 256, 16), # instruction shape (M, N, K) + tcgen05.CtaGroup.ONE, # CTA group + tcgen05.OperandSource.SMEM, # A operand from shared memory + OperandMajorMode.K, # A is K-major + OperandMajorMode.K, # B is K-major + ) + +The key parameters are: + +- **Instruction shape** ``(M, N, K)``: determines the size of one hardware MMA + instruction. Larger M and N amortize instruction overhead. +- **OperandSource**: ``SMEM`` reads A from a shared memory descriptor; ``TMEM`` + reads A directly from tensor memory. +- **OperandMajorMode**: ``K`` for K-major (default), ``MN`` for transposed layout. + Transpose A requires ``a_src=SMEM``; when ``a_src=TMEM``, A is always K-major. + + +CuTe DSL provides implementation of many tcgen05 MMA ops: + +.. list-table:: tcgen05 MMA ops + :header-rows: 1 + :widths: 30 24 46 + + * - PTX name + - Python class + - Constructor parameters + * - ``tcgen05.mma.cta_group::{cg}.kind::tf32`` + - ``tcgen05.MmaTF32Op`` + - ``instruction_shape, cta_group, a_src, a_major_mode, b_major_mode`` + * - ``tcgen05.mma.cta_group::{cg}.kind::f16`` + - ``tcgen05.MmaF16BF16Op`` + - ``ab_dtype, acc_dtype, instruction_shape, cta_group, a_src, a_major_mode, b_major_mode`` + * - ``tcgen05.mma.cta_group::{cg}.kind::i8`` + - ``tcgen05.MmaI8Op`` + - ``ab_dtype, instruction_shape, cta_group, a_src, a_major_mode, b_major_mode`` + * - ``tcgen05.mma.cta_group::{cg}.kind::f8f6f4`` + - ``tcgen05.MmaF8F6F4Op`` + - ``a_dtype, b_dtype, acc_dtype, instruction_shape, cta_group, a_src, a_major_mode, b_major_mode`` + * - ``tcgen05.mma.cta_group::{cg}.kind::mxf8f6f4.block_scale`` + - ``tcgen05.MmaMXF8F6F4Op`` + - ``a_dtype, b_dtype, instruction_shape, cta_group, a_src, a_major_mode, b_major_mode`` + * - 
``tcgen05.mma.cta_group::{cg}.kind::mxf4.block_scale`` + - ``tcgen05.MmaMXF4Op`` + - ``instruction_shape, cta_group, a_src`` + * - ``tcgen05.mma.cta_group::{cg}.kind::mxf4nvf4.block_scale`` + - ``tcgen05.MmaMXF4NVF4Op`` + - ``sf_dtype, instruction_shape, cta_group, a_src`` + + +Creating a Tiled MMA +~~~~~~~~~~~~~~~~~~~~~ + +A ``TiledMma`` tiles the MMA atom across the thread block. You can pass the op +directly or create an explicit atom first. + +.. code-block:: python + + # Option 1: directly from op (common shorthand) + tiled_mma = cute.make_tiled_mma(op) + + # Option 2: explicit atom creation + atom = cute.make_mma_atom(op) + tiled_mma = cute.make_tiled_mma(atom) + + +Spatial tiling with a repeat count (using ``atom_layout_mnk``) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +A repeat tuple ``(M_rep, N_rep, K_rep)`` replicates the atom across the M, N, and K dimensions, producing +a larger tiled MMA that covers a bigger CTA tile. + +.. code-block:: python + + tiled_mma = cute.make_tiled_mma(atom, (2, 2, 1)) + +The coordinates of atoms could be thought as a 4D coordinate: (v, m, n, k). +v is the CTAs for a single MMA (for CtaGroup.ONE always 0, for CtaGroup.TWO always 0 or 1), +m is the M dimension repeat count, n is the N dimension repeat count, and k is the K dimension repeat count. + +.. code-block:: text + + MMA Atom CtaGroup.ONE make_tiled_mma(atom, (2, 2, 1)) + +---------------+ +----------------+----------------+ + | | | | | ^ + | 128 x 256 | | Atom (0,0,0,0) | Atom (0,0,1,0) | | + | x 16 | --(2,2,1)--> | 128 x 256 | 128 x 256 | | 2 x M_atom + | | repeat | x 16 | x 16 | | = 256 + | | | | | | + +---------------+ +----------------+----------------+ | + | | | | + | Atom (0,1,0,0) | Atom (0,1,1,0) | | + | 128 x 256 | 128 x 256 | | + | x 16 | x 16 | | + | | | v + +----------------+----------------+ + <---- 2 x N_atom = 512 --------> + K unchanged = 16 + +.. 
code-block:: text + + MMA Atom CtaGroup.TWO make_tiled_mma(atom, (2, 2, 1)) + +---------------+ +----------------+----------------+ + | CTA v = 0 | | Atom (0,0,0,0) | Atom (0,0,1,0) | ^ + | 128 x 256 | | 128 x 256 | 128 x 256 | | + | x 16 | | x 16 | x 16 | | 2CTA Atom + +...............+ +................+................+ | + | CTA v = 1 | --(2,2,1)--> | Atom (1,0,0,0) | Atom (1,0,1,0) | | + | 128 x 256 | repeat | 128 x 256 | 128 x 256 | | + | x 16 | | x 16 | x 16 | v + +---------------+ +----------------+----------------+ + | Atom (0,1,0,0) | Atom (0,1,1,0) | ^ + | 128 x 256 | 128 x 256 | | + | x 16 | x 16 | | 2CTA Atom + +................+................+ | + | Atom (1,1,0,0) | Atom (1,1,1,0) | | + | 128 x 256 | 128 x 256 | | + | x 16 | x 16 | v + +----------------+----------------+ + <---- 2 x N_atom = 512 --------> + Per CTA: 2 x M_atom = 256 + Cluster M (v*m*128): 512 + K unchanged = 16 + + +Custom tile permutation with ``permutation_mnk`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +``make_tiled_mma`` accepts an optional ``permutation_mnk`` argument that +controls how the atom tiles are laid out across the M, N, and K dimensions. +``permutation_mnk`` is a tuple of layouts or ints that represent the tile size and reordering of values. +These permutation operations could be applied to optimize the data access patterns for MMAs. + +For example, with ``inst_m=256`` and 2 atoms in M (total M tile = 512), +a permutation can interleave the two atoms' M rows: + +.. code-block:: python + + # inst_m=256, inst_n=256, inst_k=16 + m_layout = cute.make_layout( + shape=(128, 2, 2), # (inst_m // 2, 2, 2) + stride=(1, 256, 128), # (1, inst_m, inst_m // 2) + ) + tiled_mma = cute.make_tiled_mma( + atom, + atom_layout_mnk=(1, 1, 1), + permutation_mnk=(m_layout, 256, 16), + ) + +The layout ``(128,2,2):(1,256,128)`` maps logical flat indices to physical +M rows in colex order (mode 0 fastest), interleaving the two atoms' halves: + +.. 
code-block:: text + + Without permutation With permutation_mnk + (sequential, default) m_layout = (128,2,2):(1,256,128) + + +---------------+ ^ ^ +---------------+ ^ + | MMA 0 top | | 128 CTA 0 | | MMA 0 top | | 128 CTA 0 + | rows 0-127 | | | | rows 0-127 | | + +...............+ + | Tile 0 +---------------+ v + | MMA 0 bottom | | 128 CTA 1 | | MMA 1 top | | 128 CTA 0 + | rows 128-255 | | | | rows 128-255 | | + +---------------+ v v +---------------+ v + | MMA 1 top | ^ ^ | MMA 0 bottom | | 128 CTA 1 + | rows 256-383 | | 128 CTA 0 | | rows 256-383 | | + +...............+ + | Tile 1 +---------------+ v + | MMA 1 bottom | | 128 CTA 1 | | MMA 1 bottom | | 128 CTA 1 + | rows 384-511 | | | | rows 384-511 | | + +---------------+ v v +---------------+ v + <-- inst_N=256 -> <-- inst_N=256 -> + inst_K = 16 inst_K = 16 + + Tile 0: rows 0-255 (contiguous) Tile 0: rows {0-127, 256-383} + Tile 1: rows 256-511 (contiguous) Tile 1: rows {128-255, 384-511} + CTA 0 owns rows {0-127, 256-383} CTA 0 owns rows {0-127, 256-383} + CTA 1 owns rows {128-255, 384-511} CTA 1 owns rows {128-255, 384-511} + +When ``permutation_mnk`` is not provided (default), the tile ordering is +sequential and no permutation is applied. + +Creating Trivial Tiled MMA +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Since tcgen05 MMAs have quite large instruction shapes, most common TiledMmas created are +trivial tiled MMAs, with single M, N repetitions, i.e., ``atom_layout_mnk``, and ``permutation_mnk`` are generally unused. +CuTe DSL provides a convenience function ``make_trivial_tiled_mma`` to create such trivial MMAs with +automatic MmaOp kind selection based on the data types. + +.. code-block:: python + + import cutlass.utils.blackwell_helpers as sm100_utils + + tiled_mma = sm100_utils.make_trivial_tiled_mma( + a_dtype, + b_dtype, + a_major_mode, + b_major_mode, + acc_dtype, + cta_group, + mma_tiler_mnk, + ) + + # Equivalent to + tiled_mma = cute.make_tiled_mma( + cute.make_mma_atom( + cute.MmaXyzOp( + # ... 
parameters of MmaXyzOp + ), + ), + ) + +Partitioning Tensors +-------------------- + +Before computing MMAs, we want to partition the global memory tensors according to the +tiled MMA layout. For tcgen05, this maps each CTA's work to the correct portion of +the global memory tensors. + +We have two steps to partition the global memory tensors: + +* Local tile partitioning: partition the global memory tensors into local tiles, each of size ``mma_tiler_mnk``. + This is the portion of the global memory tensors that will be processed by a single CTA MMA or a 2CTA cooperative MMA. +* MMA partition: partition the local tile into CTA-sized, per-MMA-instruction tiles (note that each CTA needs + to load its own portion to SMEM for 2CTA cooperative MMA). The per-operand shapes are + ``(MMA, MMA_M, MMA_K, ...)`` for A, ``(MMA, MMA_N, MMA_K, ...)`` for B, and ``(MMA, MMA_M, MMA_N, ...)`` for C. + +Note that for tcgen05, SMEM tensors are not partitioned. +See `Making Fragments`_ for more details. + +Trivial TiledMma with CtaGroup.ONE MMAs (single CTA): +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +For the trivial tiled MMAs with CtaGroup.ONE tcgen05 MMA operations, partitioning the mma_tiler sized tile +is an identity operation, i.e., single CTAs' tile is the same as the mma_tiler sized tile. +The main difference between the result of ``local_tile`` and ``partition_[A/B]`` is that, the latter +produces a view that can be iterated in per-MMA instruction fashion. + +Example: ``GEMM (M, N, K) = (512, 768, 384)``, ``mma_tiler_mnk = (128, 256, 64)``, +``CtaGroup.ONE``, F16 atom = 128x256x16 (inst_M=128, inst_N=256, inst_K=16). + +Global memory tensors: + +.. 
code-block:: text + + mA: (M, K) = (512, 384) mB: (N, K) = (768, 384) mC: (M, N) = (512, 768) + + K=384 K=384 N=768 + |<----------->| |<----------->| |<----------------->| + +-------------+ +-------------+ +---+---+---+-------+ + | | ^ | | ^ | | | | | ^ + | mA | | M=512 | mB | | N=768 | | | | | | M=512 + | | v | | v | | | | | v + +-------------+ +-------------+ +---+---+---+-------+ + +Tiling with ``mma_tiler_mnk = (BM, BN, BK) = (128, 256, 64)`` gives +M/BM = 512/128 = 4 tiles, N/BN = 768/256 = 3 tiles, K/BK = 384/64 = 6 tiles: + +.. code-block:: text + + mA tiled into (M/BM x K/BK) mB tiled into (N/BN x K/BK) mC tiled into (M/BM x N/BN) + = (4 x 6) blocks = (3 x 6) blocks = (4 x 3) blocks + * coordinates annotated on the matrix + are the mma_coord_mn of the GEMM. + + BK=64 x6 BK=64 x6 BN=256 x3 + |<--->| |<--->| |<----->| + +-----+-----+-- --+ +-----+-----+-- --+ +-------+-------+-------+ + | | |..| | ^ | | |..| | ^ | (0,0) | (0,1) | (0,2) | ^ + | | | | | | BM=128 | | | | | | BN=256 | | | | | BM=128 + +-----+-----+-- --+ v +-----+-----+-- --+ v +-------+-------+-------+ v + | | |..| | ^ | | |..| | ^ | (1,0) | (1,1) | (1,2) | ^ + | | | | | | BM=128 | | | | | | BN=256 | | | | | BM=128 + +-----+-----+-- --+ v +-----+-----+-- --+ v +-------+-------+-------+ v + | | |..| | ^ | | |..| | ^ | (2,0) | (2,1) | (2,2) | ^ + | | | | | | BM=128 | | | | | | BN=256 | | | | | BM=128 + +-----+-----+-- --+ v +-----+-----+-- --+ v +-------+-------+-------+ v + | | |..| | ^ | (3,0) | (3,1) | (3,2) | ^ + | | | | | | BM=128 | | | | | BM=128 + +-----+-----+-- --+ v +-------+-------+-------+ v + +Each CTA picks one (M-coord, N-coord) coordinate. +For example, CTA at ``mma_coord = (0, 1, :)``. + +After ``local_tile`` — one CTA's tile has ``k = K/BK = 384/64 = 6`` tiles to process +for A, B tensors, and a single tile for C tensor: + +.. 
code-block:: text + + gA: (BM, BK, k) = (128, 64, 6) gB: (BN, BK, k) = (256, 64, 6) gC: (BM, BN) = (128, 256) + (k has 6 tiles total: indices 0..5) + + BK=64 BK=64 BN=256 + |<----->| |<----->| |<--------->| + +-------+---------+-------+ +-------+---------+-------+ +-----------+ + | | | | | | | | | | ^ + BM= | gA k0 | k1...k4 | gA k5 | BN= | gB k0 | k1...k4 | gB k5 | BM= | gC | 128 + 128 | | | | 256 | | | | 128 | | v + +-------+---------+-------+ +-------+---------+-------+ +-----------+ + +``get_slice(0)`` — single CTA owns the full tile. +BM and BN match the atom, BK is split into MMA_K atom-sized steps: + +.. code-block:: text + + gA (BK split into MMA_K atoms) gC + inst_K inst_K inst_K inst_K + =16 =16 =16 =16 + |<--->|<--->|<--->|<--->| + +-----+-----+-----+-----+-- +-----------+ + | 0 | 1 | 2 | 3 |.. BM=128 (MMA_M=1) | | BM=128 + +-----+-----+-----+-----+ +-----------+ + |<-- MMA_K = BK/inst_K = 4 -->| + + gB (BK split into MMA_K atoms) + inst_K inst_K inst_K inst_K + =16 =16 =16 =16 + |<--->|<--->|<--->|<--->| + +-----+-----+-----+-----+-- + | 0 | 1 | 2 | 3 |.. BN=256 (MMA_N=1) + +-----+-----+-----+-----+ + +After partition (single CTA): + +- ``tCgA: (MMA, MMA_M, MMA_K, k) = (MMA, 1, 4, 6)`` — MMA_M = BM/inst_M = 128/128 = 1, MMA_K = BK/inst_K = 64/16 = 4 +- ``tCgB: (MMA, MMA_N, MMA_K, k) = (MMA, 1, 4, 6)`` — MMA_N = BN/inst_N = 256/256 = 1, MMA_K = BK/inst_K = 64/16 = 4 +- ``tCgC: (MMA, MMA_M, MMA_N) = (MMA, 1, 1)`` — MMA_M = BM/inst_M = 1, MMA_N = BN/inst_N = 1 + +With CuTe DSL, all these calculations are handled for you provided ``mma_tiler_mnk`` and ``mma_coord``. + +.. 
code-block:: python + + @cute.kernel + def kernel(tiled_mma: cute.TiledMma, ...): + gA = cute.local_tile(mA, mma_tiler_mnk, mma_coord, proj=(1, None, 1)) # (BM, BK, k) for A + gB = cute.local_tile(mB, mma_tiler_mnk, mma_coord, proj=(None, 1, 1)) # (BN, BK, k) for B + gC = cute.local_tile(mC, mma_tiler_mnk, mma_coord, proj=(1, 1, None)) # (BM, BN) for C + + # Single CTA MMA: cta index is always 0 + thr_mma = tiled_mma.get_slice(0) + + tCgA = thr_mma.partition_A(gA) # (MMA, MMA_M, MMA_K, num_k_tiles) for A + tCgB = thr_mma.partition_B(gB) # (MMA, MMA_N, MMA_K, num_k_tiles) for B + tCgC = thr_mma.partition_C(gC) # (MMA, MMA_M, MMA_N) for C + +Trivial TiledMma with CtaGroup.TWO MMAs (2-CTA cluster, each CTA owns half the M-tile): +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +With ``CtaGroup.TWO``, two CTAs cooperate on a single tile. The V-coordinate +(0 or 1) identifies which CTA within the pair. ``get_slice(V)`` gives each CTA +its half of the M dimension, while B is fully shared. + + +Example: ``GEMM (M, N, K) = (512, 768, 384)``, ``mma_tiler_mnk = (256, 256, 64)``, +``CtaGroup.TWO``, F16 MMA atom = 128x256x16 (inst_M=128, inst_N=256, inst_K=16). + +Global matrices: + +.. code-block:: text + + mA: (M, K) = (512, 384) mB: (N, K) = (768, 384) mC: (M, N) = (512, 768) + + K=384 K=384 N=768 + |<----------->| |<----------->| |<----------------->| + +-------------+ +-------------+ +---+---+---+-------+ + | | ^ | | ^ | | | | | ^ + | mA | | M=512 | mB | | N=768 | | | | | | M=512 + | | v | | v | | | | | v + +-------------+ +-------------+ +---+---+---+-------+ + +Tiling with ``mma_tiler_mnk = (BM, BN, BK) = (256, 256, 64)`` gives +M/BM = 512/256 = 2 tiles in M-mode, N/BN = 768/256 = 3 tiles in N-mode, K/BK = 384/64 = 6 tiles in K-mode: + +.. 
code-block:: text + + mA tiled into (M/BM x K/BK) mB tiled into (N/BN x K/BK) mC tiled into (M/BM x N/BN) + = (2 x 6) blocks = (3 x 6) blocks = (2 x 3) blocks + * coordinates annotated on the matrix + are the mma_coord_mn of the GEMM. + + BK=64 x6 BK=64 x6 BN=256 x3 + |<--->| |<--->| |<----->| + +-----+-----+-- --+ +-----+-----+-- --+ +-------+-------+-------+ + | | |..| | ^ | | |..| | ^ | (0,0) | (0,1) | (0,2) | ^ + | | | | | | BM=256 | | | | | | BN=256 | | | | | BM=256 + +-----+-----+-- --+ v +-----+-----+-- --+ v +-------+-------+-------+ v + | | |..| | ^ | | |..| | ^ | (1,0) | (1,1) | (1,2) | ^ + | | | | | | BM=256 | | | | | | BN=256 | | | | | BM=256 + +-----+-----+-- --+ v +-----+-----+-- --+ v +-------+-------+-------+ v + | | |..| | ^ + | | | | | | BN=256 + +-----+-----+-- --+ v + +Each CTA pair picks one (M-coord, N-coord) coordinate. +For example, CTA pair at ``mma_coord_mnk = (0, 0, :)``. + +.. code-block:: text + + gA: (BM, BK, k) = (256, 64, 6) gB: (BN, BK, k) = (256, 64, 6) gC: (BM, BN) = (256, 256) + + BK=64 BK=64 BN=256 + |<----->| |<----->| |<--------->| + +-------+-- +-------+-- +-----------+ + | |.. | |.. | | ^ + BM= | gA | k=6 BN= | gB | k=6 BM= | gC | 256 + 256 | | 256 | | | | v + +-------+ +-------+ +-----------+ + +``get_slice(V)`` splits BM between CTAs; BK is split into ``MMA_K`` steps: + +.. code-block:: text + + gA (BM split, BK split into MMA_K atoms) gC (BM split) + inst_K inst_K inst_K inst_K + =16 =16 =16 =16 + |<--->|<--->|<--->|<--->| + +-----+-----+-----+-----+-- +-----------+ + | 0 | 1 | 2 | 3 |.. ^ CTA 0 | CTA 0 | ^ + | | | | | | BM/2=128 (V=0) | (V=0) | | BM/2=128 + +-----+-----+-----+-----+ v +-----------+ v + | 0 | 1 | 2 | 3 |.. 
^ CTA 1 | CTA 1 | ^ + | | | | | | BM/2=128 (V=1) | (V=1) | | BM/2=128 + +-----+-----+-----+-----+ v +-----------+ v + |<-- MMA_K = BK/inst_K = 4 -->| + + gB (BN split for SMEM loading, BK split into MMA_K atoms) + inst_K inst_K inst_K inst_K + =16 =16 =16 =16 + |<--->|<--->|<--->|<--->| + +-----+-----+-----+-----+-- + | 0 | 1 | 2 | 3 |.. ^ CTA 0 + | | | | | | BN/2=128 (V=0) + +-----+-----+-----+-----+ v + | 0 | 1 | 2 | 3 |.. ^ CTA 1 + | | | | | | BN/2=128 (V=1) + +-----+-----+-----+-----+ v + +Both CTAs consume the full gB for MMA, but for SMEM loading each CTA loads +its N-half. + +After partition (per CTA, e.g. CTA 0): + +- ``tCgA: (MMA, MMA_M, MMA_K, k) = (MMA, 1, 4, 6)`` — MMA_M = (BM/2)/inst_M = 128/128 = 1, MMA_K = BK/inst_K = 64/16 = 4 +- ``tCgB: (MMA, MMA_N, MMA_K, k) = (MMA, 1, 4, 6)`` — MMA_N = BN/inst_N = 256/256 = 1, MMA_K = BK/inst_K = 64/16 = 4 +- ``tCgC: (MMA, MMA_M, MMA_N) = (MMA, 1, 1)`` — MMA_M = (BM/2)/inst_M = 1, MMA_N = BN/inst_N = 1 + +Of course with CuTe DSL none of these calculations are needed. +The DSL handles all the tiling and partitioning for you provided ``mma_tiler_mnk`` and ``mma_coord``. + +.. 
code-block:: python + + @cute.kernel + def kernel(tiled_mma: cute.TiledMma, cta_layout_vmnk: cute.Layout, ...): + bidx, bidy, _ = cute.arch.block_idx() + + # V-coordinate: which CTA within the 2-CTA group (0 or 1) + mma_coord_vmnk = ( + bidx % cute.size(cta_layout_vmnk, mode=[0]), # V (CTA rank) + bidx // cute.size(cta_layout_vmnk, mode=[0]), # M tile + bidy, # N tile + None, # K (all tiles) + ) + mma_coord_mnk = mma_coord_vmnk[1:] + + gA = cute.local_tile(mA, mma_tiler_mnk, mma_coord_mnk, proj=(1, None, 1)) + gB = cute.local_tile(mB, mma_tiler_mnk, mma_coord_mnk, proj=(None, 1, 1)) + gC = cute.local_tile(mC, mma_tiler_mnk, mma_coord_mnk, proj=(1, 1, None)) + + # 2-CTA: each CTA passes its V-coordinate to get its half of the work + thr_mma = tiled_mma.get_slice(mma_coord_vmnk[0]) + + tCgA = thr_mma.partition_A(gA) # (MMA, MMA_M, MMA_K, num_k_tiles) + tCgB = thr_mma.partition_B(gB) # (MMA, MMA_N, MMA_K, num_k_tiles) + tCgC = thr_mma.partition_C(gC) # (MMA, MMA_M, MMA_N) + +.. note:: Annotation ``tCgX`` means that the tensor is partitioned w.r.t. the C matrix coordinates, i.e., the output tile of each CTA. + +Pre and Post-Conditions for TiledMMA Partitioning +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +* The inputs of the partition should be at least rank-2 tensors. +* The output of the partition will have the layout that is compatible with the MMA atom's operand: + - For A, the output will have the layout (MMA, MMA_M, MMA_K, ...). + - For B, the output will have the layout (MMA, MMA_N, MMA_K, ...). + - For C, the output will have the layout (MMA, MMA_M, MMA_N, ...). +* Note that the partition doesn't enforce any rules on the tensor's memory space or the tensor's data type. It only cares about the layout. + + +What happens when we use ``atom_layout_mnk``? +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The valid coordinates to ``get_slice`` are the valid coordinates of the (v,m,n,k) coordinate space +of the tiled MMA. 
``mma_tiler_mnk`` should be updated such that ``mma_tiler_mnk[0] >= mma_shape[0] * |m|``, +``mma_tiler_mnk[1] >= mma_shape[1] * |n|``, and ``mma_tiler_mnk[2] >= mma_shape[2] * |k|``. + +The results of ``partition_A``, ``partition_B``, and ``partition_C`` remain the same. + +Making Fragments +----------------- + +Fragments are the descriptor-level tensors that the MMA +instruction operates on. For tcgen05: + +- **Fragment A**: SMEM descriptor when ``a_src=SMEM``, or a TMEM address when + ``a_src=TMEM``. +- **Fragment B**: SMEM descriptor pointing into staged shared memory buffers. +- **Fragment C (accumulator)**: lives in Tensor Memory (TMEM), allocated via + ``TmemAllocator``. + +Creating fragment descriptors and descriptor tensors +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Unlike older architectures where fragments live in per-thread registers, +tcgen05 fragments are **descriptors** pointing into SMEM (for A and B) or +**addresses** into TMEM (for the accumulator C). The fragment creation has +three parts: + +**1. A and B fragments** + +*When A comes from SMEM* (``a_src=OperandSource.SMEM``): + +``make_fragment_A`` and ``make_fragment_B`` take the staged SMEM tensors +(``sA``, ``sB``) and produce descriptor tensors that the MMA instruction +consumes. Each descriptor points to one stage's tile in shared memory. + +.. code-block:: python + + # 1. Build the SMEM layouts (see "Creating SMEM layouts for A and B") + # a_smem_layout = sm100_utils.make_smem_layout_a(tiled_mma, mma_tiler_mnk, ...) + # b_smem_layout = sm100_utils.make_smem_layout_b(tiled_mma, mma_tiler_mnk, ...) + + # 2. Allocate SMEM tensors from those layouts + # sA = smem.allocate_tensor(layout=a_smem_layout.outer, swizzle=a_smem_layout.inner, ...) + # sB = smem.allocate_tensor(layout=b_smem_layout.outer, swizzle=b_smem_layout.inner, ...) + + # 3. 
Create fragment descriptors from the SMEM tensors + tCrA = tiled_mma.make_fragment_A(sA) # (MMA, MMA_M, MMA_K, STAGE) + tCrB = tiled_mma.make_fragment_B(sB) # (MMA, MMA_N, MMA_K, STAGE) + +Continuing the CtaGroup.ONE example (m128n256k16 atom, +``mma_tiler_mnk = (128, 256, 64)``, 3 pipeline stages): + +.. code-block:: text + + sA is an SMEM tensor with shape (MMA, MMA_M, MMA_K, STAGES), + allocated with appropriate size (see "Creating SMEM layouts for A and B"). + + For 128x256x16 atom, mma_tiler_mnk = (BM, BN, BK) = (128, 256, 64), 3 stages: + MMA_M = BM/inst_M = 128/128 = 1, MMA_K = BK/inst_K = 64/16 = 4, STAGES = 3 + + sA: (MMA, MMA_M=1, MMA_K=4, STAGES=3) + + |<--MMA_K = BK/inst_K = 4-->| + stage 0: +------+------+------+------+ + | k=0 | k=1 | k=2 | k=3 | inst_M=128 + +------+------+------+------+ + stage 1: +------+------+------+------+ + | k=0 | k=1 | k=2 | k=3 | inst_M=128 + +------+------+------+------+ + stage 2: +------+------+------+------+ + | k=0 | k=1 | k=2 | k=3 | inst_M=128 + +------+------+------+------+ + + make_fragment_A(sA) produces SMEM descriptors with the same shape: + tCrA: (MMA, MMA_M, MMA_K, STAGES) = (MMA, 1, 4, 3) + + Each element is an SMEM descriptor — one per (MMA_K, STAGE) pair. + Similarly for sB/tCrB with shape (MMA, MMA_N=1, MMA_K=4, STAGE=3). + +Each element of ``tCrA`` / ``tCrB`` is an SMEM descriptor that the MMA +hardware reads directly — not a register value. Note that, when we print the +layout of ``tCrA`` (or similarly ``tCrB``), we will see that ``MMA`` dimension +of ``(MMA, MMA_M, MMA_K, STAGES)`` will appear to be ``1``. This is because +this mode is an indivisible SMEM descriptor representing the whole SMEM buffer +that a single MMA instruction will consume. + +*When A comes from TMEM* (``a_src=OperandSource.TMEM``): + +In use cases like FMHA or mixed-input GEMM, operand A can be sourced from +TMEM instead of SMEM. 
In this case, ``make_fragment_A`` is called to obtain +the layout, but the fragment is bound to a TMEM pointer instead of an SMEM +tensor: + +.. code-block:: python + + # Build the SMEM layout for A (see `Creating SMEM layouts for A and B`_). + # The layout defines the tile shape the MMA expects, even though the data + # will live in TMEM. + # a_smem_layout = sm100_utils.make_smem_layout_a(...) + + # Use make_fragment_A with the outer layout to get the expected shape + tCrA_layout = tiled_mma.make_fragment_A(a_smem_layout.outer).layout + + # Compute the TMEM pointer offset (A is placed after the accumulator columns). + # TMEM columns are 32-bit wide, so scale to element offset for narrower types + # (e.g. Float16: scale = 32 // 16 = 2). + column_to_element_scale = 32 // acc_dtype.width + tmem_ptr_a = cute.recast_ptr( + accumulators.iterator + num_acc_tmem_cols * column_to_element_scale, + dtype=mma_dtype, + ) + + # Bind to TMEM storage + tCrA = cute.make_tensor(tmem_ptr_a, tCrA_layout) + +The A fragment in TMEM is laid out after the accumulator's TMEM columns. + +**2. C fragment (accumulator) — TMEM allocation** + +The accumulator lives in Tensor Memory (TMEM), a dedicated on-chip memory +separate from registers and SMEM. Creating the C fragment is a four-step +process: + +.. code-block:: python + + # Step 1: Query the partitioned accumulator shape + acc_shape = tiled_mma.partition_shape_C(mma_tiler_mnk[:2]) + # acc_shape: (MMA, MMA_M, MMA_N) = (MMA, 1, 1) + + # Step 2: Append a staging dimension for ping-pong overlap + # Use 1 for simple kernels (no overlap), or 2+ to overlap + # MMA and epilogue on different TMEM buffers. 
+ num_acc_stages = 2 + acc_shape_staged = cute.append(acc_shape, num_acc_stages) + + # Step 3: Create a fragment to establish the layout + tCtAcc = tiled_mma.make_fragment_C(acc_shape_staged) + # tCtAcc: (MMA, MMA_M, MMA_N, ACC_STAGE) + + # Step 4: Bind to actual TMEM storage + tmem_ptr = tmem.retrieve_ptr(cutlass.Float32) + tCtAcc = cute.make_tensor(tmem_ptr, tCtAcc.layout) + +.. code-block:: text + + partition_shape_C((BM, BN) = (128, 256)) + -> (MMA, MMA_M, MMA_N) = (MMA, 1, 1) + MMA_M = BM/inst_M = 128/128 = 1 + MMA_N = BN/inst_N = 256/256 = 1 + + cute.append(acc_shape, 2) + -> (MMA, 1, 1, 2) + + make_fragment_C(acc_shape_staged) + -> tCtAcc layout: ((128, 256), 1, 1, 2) + + After binding to TMEM (2-stage ping-pong): + +---------------------------+---------------------------+ + | tCtAcc stage 0 | tCtAcc stage 1 | + | 128 x 256 accumulators | 128 x 256 accumulators | + | (Float32) | (Float32) | + +---------------------------+---------------------------+ + + +Creating SMEM layouts for A and B +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The SMEM layouts define how A and B tiles are stored in shared memory, including +swizzling for bank-conflict-free access. The helper functions handle the +details: partitioned shape from the tiled MMA, swizzle atom selection, tiling to +the MMA shape, and multi-stage buffering. + +**Host side** (``@cute.jit``): + +.. code-block:: python + + import cutlass.utils.blackwell_helpers as sm100_utils + + # Create SMEM layouts (includes swizzle + staging) + a_smem_layout = sm100_utils.make_smem_layout_a( + tiled_mma, mma_tiler_mnk, a.element_type, num_stages, + ) + b_smem_layout = sm100_utils.make_smem_layout_b( + tiled_mma, mma_tiler_mnk, b.element_type, num_stages, + ) + +``make_smem_layout_a`` and ``make_smem_layout_b`` are convenience helpers that +build a complete, staged SMEM layout in four steps: + +1. 
**Determine the major mode.** The major mode (K-major or MN-major) is read + from the MMA op's ``a_major_mode`` / ``b_major_mode`` attribute (or can be + overridden via the ``is_k_major`` keyword argument). + +2. **Compute the partitioned SMEM tile shape.** The tiled MMA is asked for the + partitioned shape of the operand via ``tiled_mma.partition_shape_A`` + (or ``partition_shape_B``). ``cute.dice`` strips the irrelevant mode first + — for A the ``(M, K)`` portion is kept, for B the ``(N, K)`` portion. The + result is a hierarchical shape ``((MMA, MMA_MN, MMA_K), repeat_MN, repeat_K)`` + that is flattened into a 2D ``(MN, K)`` size for swizzle selection. + +3. **Select and materialise the swizzle atom.** A heuristic + (``get_smem_layout_atom_ab``) picks the widest swizzle whose contiguous + size (in bits) evenly divides the major-mode dimension: + + ========== ================== + Swizzle Contiguous bits + ========== ================== + SW128 1024 (128 B) + SW64 512 (64 B) + SW32 256 (32 B) + Interleave 128 (16 B) + ========== ================== + + ``make_smem_layout_atom`` then combines the chosen swizzle with a compact + ``(MN_elems, 8)`` or ``(8, K_elems)`` outer layout (depending on the major + mode) into a ``ComposedLayout(swizzle, outer)``. + +4. **Tile to the MMA shape and append the staging dimension.** + ``tile_to_mma_shape`` broadcasts the atom to the full partitioned shape + (with ``num_stages`` appended). The ``order`` argument controls which + dimension is contiguous: ``(1, 2, 3)`` for K-major (K innermost), + ``(2, 1, 3)`` for MN-major (MN innermost). + +The resulting layout is then used when the SMEM tensors are allocated +in the kernel: + +**Kernel side** (``@cute.kernel``): + +.. 
code-block:: python + + smem = cutlass.utils.SmemAllocator() + sA = smem.allocate_tensor( + element_type=io_dtype, + layout=a_smem_layout.outer, + byte_alignment=128, + swizzle=a_smem_layout.inner, + ) + sB = smem.allocate_tensor( + element_type=io_dtype, + layout=b_smem_layout.outer, + byte_alignment=128, + swizzle=b_smem_layout.inner, + ) + +.. note:: **Creating SMEM layouts without utilities** + If you want to create SMEM layouts without using the utilities, you can do the following: + + .. code-block:: python + + swizzle = cute.Swizzle(3, 4, 3) + mma_tile = tiled_mma.partition_shape_A((mma_tiler_mnk[0], mma_tiler_mnk[2])) + smem_tile = tcgen05.tile_to_mma_shape(swizzle, mma_tile, order=(1, 2, 3)) + + +Executing the GEMM (Main Loop) +------------------------------- + +The main loop iterates over K-tiles, loading A and B from global memory via TMA +into staged SMEM buffers, then issuing ``cute.gemm`` for each tile. The TMA copy +details are omitted for brevity. + +.. code-block:: python + + for k_tile_idx in cutlass.range(num_k_tiles): + # Wait for TMA load to complete for this K-tile + ab_full = ab_consumer.wait_and_advance() + + # Set accumulate mode: first tile overwrites, subsequent tiles accumulate + tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile_idx != 0) + + # Issue MMA: tCtAcc += tCrA * tCrB + tile_crd = (None, None, None, ab_full.index) + cute.gemm(tiled_mma, tCtAcc, tCrA[tile_crd], tCrB[tile_crd], tCtAcc) + + # Release the SMEM buffer for the next TMA load + ab_full.release() + +Key points: + +- ``tcgen05.Field.ACCUMULATE`` controls whether the MMA accumulates into D + (``True``) or overwrites D with ``A * B`` (``False``). Set to ``False`` for + the first K-tile and ``True`` for all subsequent tiles. +- ``cute.gemm`` is asynchronous. Synchronization is handled by the pipeline + barriers (``cutlass.pipeline.sm100.PipelineTmaUmma``). +- The ``tile_crd`` selects which pipeline stage's SMEM buffer to read from. 
+ +Reading the accumulator from TMEM +---------------------------------- + +.. code-block:: python + + tCtAcc = tiled_mma.make_fragment_C(mma_tiler_mnk[:2]) # (MMA, MMA_M, MMA_N) for C + # TMEM allocation (done once, before the main loop) + tmem = cutlass.utils.TmemAllocator(...) + tmem.allocate(num_cols=512) + tmem_ptr = tmem.retrieve_ptr(cutlass.Float32) + tCtAcc = cute.make_tensor(tmem_ptr, tCtAcc.layout) + + # Build copy atom for TMEM → RMEM load + copy_atom_t2r = cute.make_copy_atom( + tcgen05.Ld32x32bOp(tcgen05.Repetition.x64), + cutlass.Float32, + ) + tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tCtAcc[(None, None), 0, 0]) + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + + # (T2R, T2R_M, NumTiles) + tTR_tAcc = thr_copy_t2r.partition_S(tCtAcc) + # (T2R, T2R_M, NumTiles) + tTR_gC = thr_copy_t2r.partition_D(tCgC) + + # (T2R, T2R_M) + tTR_rAcc = cute.make_rmem_tensor(tTR_gC[None, None, 0].shape, acc_dtype) + + cute.copy(tiled_copy_t2r, tTR_tAcc[None, None, i], tTR_rAcc) + + +Complete Workflow +------------------ + +Putting it all together, a typical Blackwell tcgen05 GEMM has this structure: + +**Host function** (``@cute.jit``): + +.. code-block:: python + + import cutlass + import cutlass.cute as cute + from cutlass.cute.nvgpu import cpasync, tcgen05, OperandMajorMode + import cutlass.utils.blackwell_helpers as sm100_utils + + @cute.jit + def host_function(a: cute.Tensor, b: cute.Tensor, c: cute.Tensor): + # 1. Create the MMA op and tiled MMA + op = tcgen05.MmaF16BF16Op( + cutlass.Float16, cutlass.Float32, + (128, 256, 16), + tcgen05.CtaGroup.ONE, + tcgen05.OperandSource.SMEM, + OperandMajorMode.K, + OperandMajorMode.K, + ) + tiled_mma = cute.make_tiled_mma(op) + + # 2. Create SMEM layouts for A and B + a_smem_layout = sm100_utils.make_smem_layout_a( + tiled_mma, mma_tiler_mnk, a.element_type, num_stages, + ) + b_smem_layout = sm100_utils.make_smem_layout_b( + tiled_mma, mma_tiler_mnk, b.element_type, num_stages, + ) + + # 3. 
Create TMA copy atoms for global -> shared memory loads + copy_op = cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE) + tma_a = cute.nvgpu.make_tiled_tma_atom_A( + copy_op, a, a_smem_layout, mma_tiler_mnk, tiled_mma, + ) + tma_b = cute.nvgpu.make_tiled_tma_atom_B( + copy_op, b, b_smem_layout, mma_tiler_mnk, tiled_mma, + ) + + # 4. Launch the kernel + grid = cute.ceil_div((*c.layout.shape, 1), mma_tiler_mnk[:2]) + kernel(tiled_mma, tma_a, tma_b, c).launch( + grid=grid, block=(128, 1, 1), + ) + +**Kernel function** (``@cute.kernel``): + +.. code-block:: python + + @cute.kernel + def kernel( + tiled_mma: cute.TiledMma, + tma_a: cpasync.TmaInfo, + tma_b: cpasync.TmaInfo, + mC: cute.Tensor, + ): + # -- Setup -- + bidx, bidy, _ = cute.arch.block_idx() + mma_coord_mnk = (bidx, bidy, None) + + # Global tensors for A and B live inside the TMA descriptor + mA = tma_a.tma_tensor # (M, K) + mB = tma_b.tma_tensor # (N, K) + + # Allocate SMEM for A, B (staged) and pipeline barriers + smem = cutlass.utils.SmemAllocator() + sA = smem.allocate_tensor(...) # staged SMEM for A + sB = smem.allocate_tensor(...) # staged SMEM for B + + # Allocate TMEM for the accumulator + tmem = cutlass.utils.TmemAllocator(...) 
+ tmem.allocate(num_cols=512) + + # -- Partition and make fragments -- + # (BM, BK, k) + gA = cute.local_tile(mA, mma_tiler_mnk, mma_coord_mnk, proj=(1, None, 1)) + # (BN, BK, k) + gB = cute.local_tile(mB, mma_tiler_mnk, mma_coord_mnk, proj=(None, 1, 1)) + # (BM, BN) + gC = cute.local_tile(mC, mma_tiler_mnk, mma_coord_mnk, proj=(1, 1, None)) + + thr_mma = tiled_mma.get_slice(0) + tCgA = thr_mma.partition_A(gA) # (MMA, MMA_M, MMA_K, num_k_tiles) + tCgB = thr_mma.partition_B(gB) # (MMA, MMA_N, MMA_K, num_k_tiles) + tCgC = thr_mma.partition_C(gC) # (MMA, MMA_M, MMA_N) + + # SMEM descriptor fragments + tCrA = tiled_mma.make_fragment_A(sA) # (MMA, MMA_M, MMA_K, STAGE) + tCrB = tiled_mma.make_fragment_B(sB) # (MMA, MMA_N, MMA_K, STAGE) + + # TMEM accumulator + acc_shape = tiled_mma.partition_shape_C(mma_tiler_mnk[:2]) + tCtAcc = tiled_mma.make_fragment_C(acc_shape) # (MMA, MMA_M, MMA_N) + + # Bind accumulator to TMEM + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(cutlass.Float32) + tCtAcc = cute.make_tensor(tmem_ptr, tCtAcc.layout) + + # TMA partition for global → shared memory copies + tAsA, tAgA = cute.nvgpu.cpasync.tma_partition( + tma_a.atom, 0, cute.make_layout(1), + cute.group_modes(sA, 0, 3), cute.group_modes(tCgA, 0, 3), + ) + tBsB, tBgB = cute.nvgpu.cpasync.tma_partition( + tma_b.atom, 0, cute.make_layout(1), + cute.group_modes(sB, 0, 3), cute.group_modes(tCgB, 0, 3), + ) + + # -- Main loop: iterate over K-tiles -- + num_k_tiles = cute.size(gA, mode=[2]) + for k_tile_idx in cutlass.range(num_k_tiles): + # TMA load A, B into staged SMEM (producer side) + # copy(tma_a.atom, tAgA[k], tAsA[stage]) + # copy(tma_b.atom, tBgB[k], tBsB[stage]) + # ... 
(see pipeline documentation) + + # Wait for data + ab_full = ab_consumer.wait_and_advance() + + # MMA + tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile_idx != 0) + tile_crd = (None, None, None, ab_full.index) + cute.gemm(tiled_mma, tCtAcc, tCrA[tile_crd], tCrB[tile_crd], tCtAcc) + + ab_full.release() + + # -- Epilogue: copy accumulator from TMEM to global memory -- + copy_atom_t2r = cute.make_copy_atom( + tcgen05.Ld32x32bOp(tcgen05.Repetition.x64), cutlass.Float32, + ) + tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tCtAcc[(None, None), 0, 0]) + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + + tTR_tAcc = thr_copy_t2r.partition_S(tCtAcc) + tTR_gC = thr_copy_t2r.partition_D(tCgC) + tTR_rAcc = cute.make_rmem_tensor(tTR_gC[None, None, 0].shape, acc_dtype) + + # TMEM → RMEM, then RMEM → GMEM + for i in cutlass.range(num_tiles): + cute.copy(tiled_copy_t2r, tTR_tAcc[None, None, i], tTR_rAcc) + cute.copy(store_atom, tTR_rAcc, tTR_gC[None, None, i]) + + +Beyond Simple Dense MMAs +------------------------ + +The tcgen05 MMA DSL supports more complex MMA operations than just the simple dense MMA: + + +- Block-scaled MMA + +.. {$nv-internal-release begin} + +Internal builds additionally provide: + +- Sparse MMA + +.. {$nv-internal-release end} + +.. {$nv-internal-release begin} + +Sparse MMA +~~~~~~~~~~ + +Sparse MMA exploits **X:Y = {1:2, 2:4, 4:8} structured sparsity** in operand A: out of every +Y consecutive K-elements, exactly X are non-zero and Y-X are zero. The +kernel stores the compressed A values separately from the **metadata** +tensor ``E``, which records which X of every Y positions are non-zero (e.g. which 2 of 4 for 2:4 sparsity). + +Compared to a dense MMA kernel, a sparse kernel differs in five areas: + +**1. MMA op creation** — use ``MmaF16BF16SparseOp`` with an extra +``sparse_metadata_format`` parameter. The instruction K is **doubled** +(32 vs 16 for dense F16/BF16) to account for the compressed operand. 
The +example here builds its sparse ``TiledMma`` through +``sm100_utils.make_sparse_trivial_tiled_mma(...)``: + +.. code-block:: python + + from cutlass.cute.nvgpu.warp.mma import SparseMetadataFormat + + # Dense F16 (for comparison): inst_K = 16 + dense_op = tcgen05.MmaF16BF16Op( + cutlass.Float16, cutlass.Float32, (128, 256, 16), + tcgen05.CtaGroup.ONE, tcgen05.OperandSource.SMEM, + cute.nvgpu.OperandMajorMode.K, cute.nvgpu.OperandMajorMode.K, + ) + + # Sparse F16: inst_K = 32 (2× dense, since A is 2:4 compressed) + sparse_op = tcgen05.MmaF16BF16SparseOp( + cutlass.Float16, cutlass.Float32, (128, 256, 32), + tcgen05.CtaGroup.ONE, tcgen05.OperandSource.SMEM, + cute.nvgpu.OperandMajorMode.K, cute.nvgpu.OperandMajorMode.K, + SparseMetadataFormat.TID, + ) + + # The sparse GEMM example uses the public helper + tiled_mma = sm100_utils.make_sparse_trivial_tiled_mma( + a_raw_dtype, a_major_mode, b_major_mode, acc_dtype, cta_group, + mma_tiler_mn=mma_tiler_mnk[:2], + sparse_metadata_format=SparseMetadataFormat.TID, + ) + +**2. Compressed A and metadata E tensors** — operand A is stored with +**half** the K-elements (the two non-zero values per group of 4), using a +``sparse_elem<2, dtype>`` type. The metadata tensor E is a compact bit-field +(``sparse_elem<8, uint8>`` for F16/BF16) that encodes the sparsity pattern. + +.. 
code-block:: python + + # Sparse element types + a_sparse_dtype = sm100_utils.make_sparse_a_dtype(a_raw_dtype) # sparse_elem<2, F16> + e_sparse_dtype = sm100_utils.make_sparse_e_dtype(a_raw_dtype) # sparse_elem<8, uint8> + + # GMEM layouts for compressed A and metadata E + sp_a_ptr = cute.recast_ptr(a.iterator, dtype=a_sparse_dtype) + sp_a_layout = sm100_utils.make_sparse_gmem_layout_a( + mnkl, + a_raw_dtype, + is_k_major=(a_major_mode == cute.nvgpu.OperandMajorMode.K), + sparsity=2, + ) + sp_a_tensor = cute.make_tensor(sp_a_ptr, sp_a_layout) + + sp_e_ptr = cute.recast_ptr(e.iterator, dtype=e_sparse_dtype) + sp_e_layout = sm100_utils.make_sparse_gmem_layout_e(mnkl, a_raw_dtype) + sp_e_tensor = cute.make_tensor(sp_e_ptr, sp_e_layout) + +.. code-block:: text + + Dense A: (M, K) Sparse A (compressed): (M, (2, K/2)) + +--+--+--+--+--+--+--+--+ +--+--+--+--+ + | a| 0| b| 0| c| 0| d| 0| → | a| b| c| d| (only non-zeros stored) + +--+--+--+--+--+--+--+--+ +--+--+--+--+ + + Metadata E encodes positions: E: [00, 10, 00, 10] + (which 2 of 4 are non-zero) ↑ ↑ + positions of a,b in each group + +**3. Extra SMEM layouts, TMA loads, and allocations** — sparse kernels use +dedicated layout helpers for A and E. An additional TMA descriptor loads +the metadata into SMEM alongside A and B. In the example here, +``E`` uses its own logical tile ``mma_tiler_e``: + +.. 
code-block:: python + + # Host side: SMEM layouts + a_smem_layout = sm100_utils.make_sparse_smem_layout_a( + tiled_mma, mma_tiler_mnk, a_raw_dtype, num_stages, sparsity=2, + ) + e_smem_layout = sm100_utils.make_sparse_smem_layout_e( + tiled_mma, mma_tiler_e, a_raw_dtype, num_stages, + ) + + # Host side: TMA atom for metadata E (note mma_tiler_e and internal_type=Uint64) + a_op = sm100_utils.cluster_shape_to_tma_atom_A(cluster_shape_mn, tiled_mma.thr_id) + tma_atom_e, tma_tensor_e = cute.nvgpu.make_tiled_tma_atom_A( + a_op, + sp_e_tensor, + cute.slice_(e_smem_layout, (None, None, None, 0)), + mma_tiler_e, + tiled_mma, + cluster_layout_vmnk.shape, + internal_type=cutlass.Uint64, + ) + + # Kernel side: SMEM allocation for metadata + sE = smem.allocate_tensor( + element_type=e_sparse_dtype, + layout=e_smem_layout.outer, + byte_alignment=128, + swizzle=e_smem_layout.inner, + ) + +**4. Metadata TMEM allocation and SMEM→TMEM copy (S2T)** — the metadata +must live in TMEM for the MMA instruction. It is placed **after** the +accumulator columns, and an S2T copy moves it from SMEM to TMEM each +K-tile. The example here also recasts both sides to raw ``uint8`` +before building the S2T copy and wraps the SMEM source in an S2T +descriptor tensor: + +.. 
code-block:: python + + # TMEM layout for metadata (placed after accumulator) + e_tmem_layout = sm100_utils.make_sparse_tmem_layout_e( + cute.slice_(e_smem_layout_staged, (None, None, None, 0)).shape, + a_raw_dtype, + ) + acc_tmem_col_offset = tcgen05.find_tmem_tensor_col_offset(tCtAcc_base) + if cutlass.const_expr(acc_dtype.width < 32): + acc_tmem_col_offset = acc_tmem_col_offset * (32 // acc_dtype.width) + e_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + acc_tmem_col_offset, dtype=e_sparse_dtype, + ) + tCtE = cute.make_tensor(e_tmem_ptr, e_tmem_layout) + + # S2T copy setup (SMEM → TMEM for metadata) + e_raw_dtype = cutlass.Uint8 + copy_atom_s2t_e = cute.make_copy_atom( + tcgen05.Cp128x128bOp(cta_group), e_raw_dtype, + ) + tCtE_recast = cute.recast_tensor(tCtE, e_raw_dtype) + tiled_copy_s2t_E = tcgen05.make_s2t_copy(copy_atom_s2t_e, tCtE_recast) + thr_copy_s2t_E = tiled_copy_s2t_E.get_slice(0) + + sE_recast = cute.recast_tensor(sE, e_raw_dtype) + thr_tCsE_s2t_ = thr_copy_s2t_E.partition_S(sE_recast) + thr_tCsE_s2t = tcgen05.get_s2t_smem_desc_tensor( + tiled_copy_s2t_E, thr_tCsE_s2t_ + ) + thr_tCtE_s2t = thr_copy_s2t_E.partition_D(tCtE_recast) + +**5. Modified main loop** — each K-tile iteration loads the metadata via +S2T, then sets the ``METADATA`` field on the atom before calling ``gemm``. +The ``gemm`` call signature itself is unchanged; the metadata is implicit +via the atom field. The full kernel also contains leader-CTA +synchronization and optional metadata reuse when ``utccp_reuse_cnt > 1``; +the schematic below keeps only the dataflow-relevant steps: + +.. code-block:: python + + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + + for k_tile in range(k_tile_cnt): + # S2T: move metadata for the current stage from SMEM to TMEM. 
+ cute.copy(tiled_copy_s2t_E, e_smem_stage, e_tmem_stage) + + for kblk_idx in cutlass.range(cute.size(tCrA, mode=[2]), unroll_full=True): + e_idx = metadata_index_for(k_tile, kblk_idx) + tiled_mma.set(tcgen05.Field.METADATA, tCtE[None, None, e_idx].iterator) + cute.gemm(tiled_mma, tCtAcc, tCrA_kblk, tCrB_kblk, tCtAcc) + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + +.. code-block:: text + + Dense main loop (per K-tile): + set(ACCUMULATE, ...) + gemm(tiled_mma, tCtAcc, tCrA[s], tCrB[s], tCtAcc) + + Sparse main loop (schematic, per K-tile): + copy(s2t_E, sE[stage], tCtE) ← metadata SMEM → TMEM + set(METADATA, tCtE[e_idx].iterator) ← point atom at metadata + set(ACCUMULATE, ...) + gemm(tiled_mma, tCtAcc, tCrA[s], tCrB[s], tCtAcc) ← same call + +The epilogue (TMEM → RMEM → GMEM) is identical to a dense kernel. + +.. {$nv-internal-release end} + + +Block-scaled MMA +~~~~~~~~~~~~~~~~ + +Block-scaled MMA multiplies narrow-type matrices (the tcgen05 MXF8 and +MXF4-family ops shown here) with **per-block scale factors** along GEMM-K. +Each ``sf_vec_size`` consecutive K-elements shares one scale factor, so the +hardware computes ``D = (SFA · A) * (SFB · B) + C``. Unlike dense A/B, +SFA/SFB must be staged in **TMEM** before ``gemm``. + +Supported ops: ``MmaMXF8F6F4Op``, ``MmaMXF4Op``, ``MmaMXF4NVF4Op``. + +Compared to a dense MMA kernel, a block-scaled kernel differs in five areas: + +**1. MMA op creation** — block-scaled ops fix the accumulator to FP32 and add +scale-factor typing. The examples usually build ``TiledMma`` through +``sm100_utils.make_blockscaled_trivial_tiled_mma(...)``, which dispatches to +``MmaMXF8F6F4Op``, ``MmaMXF4Op``, or ``MmaMXF4NVF4Op`` from +``(ab_dtype, sf_vec_size)``: + +.. 
code-block:: python + + # Direct op examples + mxf8_op = tcgen05.MmaMXF8F6F4Op( + cutlass.Float8E4M3FN, (128, 256, 32), + tcgen05.CtaGroup.ONE, tcgen05.OperandSource.SMEM, + cute.nvgpu.OperandMajorMode.K, cute.nvgpu.OperandMajorMode.K, + ) + + # MXF4/NVF4 example (MmaMXF4Op is the sf_vec_size=32 companion) + nvf4_op = tcgen05.MmaMXF4NVF4Op( + cutlass.Float8E8M0FNU, # sf_dtype: UE8M0 or UE4M3 + (128, 256, 64), + tcgen05.CtaGroup.ONE, tcgen05.OperandSource.SMEM, + ) + + # Helper used by the block-scaled examples here + tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma( + ab_dtype, a_major_mode, b_major_mode, sf_dtype, sf_vec_size, + cta_group, mma_tiler_mn, + ) + +**2. Extra scale-factor tensors and SMEM layouts** — derive SFA/SFB tensors +from the A/B shapes, then build staged SMEM layouts for them: + +.. code-block:: python + + import cutlass.utils.blockscaled_layout as blockscaled_utils + + # Scale-factor tensors + sfa_layout = blockscaled_utils.tile_atom_to_shape_SF(a_tensor.shape, sf_vec_size) + sfa_tensor = cute.make_tensor(sfa_ptr, sfa_layout) + sfb_layout = blockscaled_utils.tile_atom_to_shape_SF(b_tensor.shape, sf_vec_size) + sfb_tensor = cute.make_tensor(sfb_ptr, sfb_layout) + + # Staged SMEM layouts + sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa( + tiled_mma, mma_tiler_mnk, sf_vec_size, num_ab_stage, + ) + sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb( + tiled_mma, mma_tiler_mnk, sf_vec_size, num_ab_stage, + ) + +**3. Extra TMA loads and SMEM allocations** — there are four GMEM→SMEM loads +instead of two. SFA follows the A-side TMA path; SFB uses +``cluster_shape_to_tma_atom_SFB(...)`` and may use its own tiler/layout in +2CTA kernels. The pipeline byte count also includes the SFA/SFB traffic: + +.. 
code-block:: python + + # TMA atoms for SFA/SFB (note internal_type=Int16 for packing) + sfa_op = sm100_utils.cluster_shape_to_tma_atom_A(cluster_shape_mn, tiled_mma.thr_id) + tma_sfa = cute.nvgpu.make_tiled_tma_atom_A( + sfa_op, sfa_tensor, sfa_smem_layout_staged, + mma_tiler_mnk, tiled_mma, cluster_layout_vmnk.shape, + internal_type=cutlass.Int16, + ) + + sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(cluster_shape_mn, tiled_mma.thr_id) + tma_sfb = cute.nvgpu.make_tiled_tma_atom_B( + sfb_op, sfb_tensor, sfb_smem_layout_staged, + mma_tiler_sfb, tiled_mma_sfb, cluster_layout_sfb_vmnk.shape, + internal_type=cutlass.Int16, + ) + + # Kernel side: allocate staged SMEM for scale factors + sSFA = smem.allocate_tensor(element_type=sf_dtype, layout=tma_sfa.smem_layout, ...) + sSFB = smem.allocate_tensor(element_type=sf_dtype, layout=tma_sfb.smem_layout, ...) + +**4. Scale-factor TMEM allocation and SMEM→TMEM copy (S2T)** — before each +``gemm``, SFA/SFB must be copied from staged SMEM into TMEM. The examples +compact away zero-stride modes and wrap the SMEM source in an S2T descriptor +tensor: + +.. code-block:: python + + # TMEM allocation for scale factors + tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa( + tiled_mma, mma_tiler_mnk, sf_vec_size, + cute.slice_(tma_sfa.smem_layout, (None, None, None, 0)), + ) + tCtSFA = tmem_pool.allocate_tensor(tCtSFA_layout, sf_dtype) + tCtSFB_layout = blockscaled_utils.make_tmem_layout_sfb( + tiled_mma, mma_tiler_mnk, sf_vec_size, + cute.slice_(tma_sfb.smem_layout, (None, None, None, 0)), + ) + tCtSFB = tmem_pool.allocate_tensor(tCtSFB_layout, sf_dtype) + + # S2T copy setup (SMEM → TMEM for scale factors) + copy_atom_s2t = cute.make_copy_atom( + tcgen05.Cp4x32x128bOp(cta_group), sf_dtype, + ) + + # SFA shown; SFB follows the same pattern. 
+ tCtSFA_compact = cute.filter_zeros(tCtSFA) + tiled_copy_s2t_sfa = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSFA_compact) + thr_copy_s2t_sfa = tiled_copy_s2t_sfa.get_slice(0) + tCsSFA_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( + tiled_copy_s2t_sfa, + thr_copy_s2t_sfa.partition_S(cute.filter_zeros(sSFA)), + ) + tCtSFA_compact_s2t = thr_copy_s2t_sfa.partition_D(tCtSFA_compact) + # Repeat with tCtSFB / sSFB to produce: + # tiled_copy_s2t_sfb, tCsSFB_compact_s2t, and tCtSFB_compact_s2t. + +**5. Modified main loop** — per K-tile, load A/B/SFA/SFB into SMEM, copy +SFA/SFB to TMEM, then call ``gemm`` with ``[value, scale]`` operands. The +persistent kernel separates TMA and MMA warps; the tutorial-style loop below +keeps only the operand flow: + +.. code-block:: python + + for k_tile in cutlass.range(num_k_tiles): + # TMA load A, B, SFA, SFB into SMEM + cute.copy(tma_a.atom, tAgA[None, ab_empty.count], tAsA[None, ab_empty.index], ...) + cute.copy(tma_b.atom, tBgB[None, ab_empty.count], tBsB[None, ab_empty.index], ...) + cute.copy(tma_sfa.atom, tAgSFA[None, ab_empty.count], tAsSFA[None, ab_empty.index], ...) + cute.copy(tma_sfb.atom, tBgSFB[None, ab_empty.count], tBsSFB[None, ab_empty.index], ...) + + ab_full = ab_consumer.wait_and_advance() + + # S2T: copy scale factors from SMEM to TMEM + s2t_stage_coord = (None, None, None, None, ab_full.index) + cute.copy( + tiled_copy_s2t_sfa, + tCsSFA_compact_s2t[s2t_stage_coord], + tCtSFA_compact_s2t, + ) + cute.copy( + tiled_copy_s2t_sfb, + tCsSFB_compact_s2t[s2t_stage_coord], + tCtSFB_compact_s2t, + ) + + # MMA with scale factors passed as [value, scale] pairs + tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile != 0) + tile_crd = (None, None, None, ab_full.index) + cute.gemm( + tiled_mma, + tCtAcc, + [tCrA[tile_crd], tCtSFA], # A value (SMEM) + A scale (TMEM) + [tCrB[tile_crd], tCtSFB], # B value (SMEM) + B scale (TMEM) + tCtAcc, + ) + + ab_full.release() + +.. 
code-block:: text + + Dense tcgen05 mainloop (schematic): + gemm(tiled_mma, tCtAcc, tCrA[s], tCrB[s], tCtAcc) + + Block-scaled tcgen05 mainloop (schematic): + copy(s2t_sfa, sSFA[stage], tCtSFA) ← scale A to TMEM + copy(s2t_sfb, sSFB[stage], tCtSFB) ← scale B to TMEM + gemm(tiled_mma, tCtAcc, [tCrA[s], tCtSFA], [tCrB[s], tCtSFB], tCtAcc) + +The epilogue (TMEM → RMEM → GMEM) is identical to a dense kernel. + +See also: + +- Tutorial: step-by-step dense F16 GEMM — ``examples/cute/blackwell/tutorial/tutorial_gemm/fp16_gemm_0.py`` + (and ``fp16_gemm_1.py`` through ``fp16_gemm_6.py`` for progressive optimizations) +- Tutorial: block-scaled NVFP4 GEMM — ``examples/cute/blackwell/tutorial/tutorial_gemm/nvfp4_gemm_0.py`` +- Dense GEMM (production): ``examples/cute/blackwell/kernel/dense_gemm/dense_gemm.py`` +- Persistent dense GEMM: ``examples/cute/blackwell/kernel/dense_gemm/dense_gemm_persistent.py`` +- Block-scaled GEMM: ``examples/cute/blackwell/kernel/blockscaled_gemm/dense_blockscaled_gemm_persistent.py`` +- Sparse GEMM: ``examples/cute/blackwell/kernel/sparse_gemm/sparse_gemm_persistent.py`` +- Helper utilities: ``cutlass.utils.blackwell_helpers`` +- Block-scaled layout utilities: ``cutlass.utils.blockscaled_layout`` diff --git a/media/docs/pythonDSL/mma_docs/wgmma_programming.rst b/media/docs/pythonDSL/mma_docs/wgmma_programming.rst new file mode 100644 index 000000000..ba83af07d --- /dev/null +++ b/media/docs/pythonDSL/mma_docs/wgmma_programming.rst @@ -0,0 +1,987 @@ +.. _wgmma_programming: + +Warpgroup MMA Programming Guide +================================ + +Hopper (SM90a) introduced the **warpgroup-level MMA** PTX instruction family +``wgmma.mma_async.sync.aligned``. A warpgroup (128 threads / 4 warps) +cooperates on one asynchronous ``D = A * B + C`` matrix multiply-accumulate. + +Key architectural characteristics: + +* **Warpgroup scope:** One MMA is issued collectively by a 128-thread + warpgroup rather than by a single warp. 
+* **Asynchronous issue model:** WGMMA instructions are ordered with + ``cute.nvgpu.warpgroup.fence()``, ``commit_group()``, and ``wait_group()``. +* **Descriptor-based operand path:** Operand B is sourced from staged shared + memory. Operand A can be sourced either from shared memory descriptors or + from registers via ``OperandSource``. +* **Register accumulator:** The accumulator lives in RMEM and serves as both + the input C and output D of ``cute.gemm()``. +* **Architecture-specific operand layouts:** F16/BF16 supports K-major and + MN-major dense layouts when A comes from SMEM. FP8 and INT8 variants are + K-major only. + +The dense DSL op classes currently exposed are ``MmaF16BF16Op`` (F16/BF16), +``MmaF8Op`` (FP8 E4M3/E5M2), and ``MmaI8Op`` (INT8/UINT8); see +`Setting up the TiledMMA, MMA Ops`_ for their full constructor parameters, +instruction K extents, and major-mode constraints. + +This guide outlines the CuTe Python DSL programming model for WGMMA kernels: +stage operands in SMEM, build fragment descriptors, launch asynchronous +warpgroup MMAs, and stage the RMEM accumulator back to GMEM in the epilogue. + + +.. contents:: **Contents** + :local: + :depth: 2 + +Global Memory (GMEM) to MMA data flow overview +---------------------------------------------- + +WGMMA instructions require us to stage B input operands in Shared Memory (SMEM), +while A input operands can be sourced from either SMEM or registers (RMEM). +SMEM operands are read asynchronously by the hardware via SMEM descriptors. +The accumulator is always kept in registers (RMEM) of the warpgroup. + +The diagram below traces the full data flow of a WGMMA GEMM kernel, for the most +common case where A and B matrices are stored in GMEM and both are staged through +SMEM (``a_src=SMEM``), and the output matrix --accumulated in RMEM-- is written +back to GMEM through an SMEM staging buffer. + +There are 3 parallel tracks where each has 2 sub-tracks. 
The three parallel tracks are + for operands A, B, and C/D, respectively. The two sub-tracks are for copying data between different memory + spaces and for MMA execution. + +- **Operand A** (and symmetrically **Operand B**): + + - First, we need to create SMEM tensors for A and B matrices: ``sA`` and ``sB``. These + tensors are physically allocated tensors that are the destination of the TMA copy and + the source operands for the WGMMA instructions. + - Next, the **data copy flow** creates the tensor views for copying data from GMEM to SMEM. + It starts with the ``mA`` tensor that represents the matrix A in global memory. + Then the ``mA`` → ``local_tile`` → ``gA`` operation creates the local tile view of A that is the + slice of the A matrix needed to compute the given CTA's output tile. + Then ``tma_partition(tma, sA, gA)`` produces TMA views ``tAsA``, ``tAgA``, + and the loop copies tiles from GMEM into SMEM via ``copy(tma, tAgA[k], tAsA[stage])``. + - In parallel, the **MMA flow** turns the SMEM tensor into an iterable tensor of SMEM descriptors + for the WGMMA instructions. ``sA`` (the same shared-memory allocation written by TMA) + → ``partition_A`` → ``tCsA`` (MMA-partitioned SMEM view) + → ``make_fragment_A`` → ``tCrA`` (SMEM descriptor passed to ``cute.gemm()``). + Note that the SMEM descriptor is a view created from the SMEM tensor that is + interpretable by the WGMMA instructions. + +- **Accumulator C/D**: + + - **RMEM accumulator flow** (gemm input/output): ``partition_C(gC)`` → ``tCgC`` + → ``make_rmem_tensor(tCgC.shape)`` → ``acc``, which serves as both the accumulator + input (C) and output (D) of ``cute.gemm()`` (and the WGMMA instruction). 
+ - **Output flow** (RMEM → SMEM → GMEM): After the main loop, the accumulator is + type-converted and copied from registers to SMEM via ``stmatrix`` (R2S copy), + then stored to global memory via TMA store (S2G copy): + ``mC`` → ``local_tile`` → ``gC`` → ``partition_C`` → ``tCgC`` on the destination side, + and ``tRS_rAcc``/``tRS_sD`` / ``bSG_sD``/``bSG_gD`` views drive the two copy stages. + +.. code-block:: text + + Operand A Dataflow Path Operand B Dataflow Path Accumulator C/D Dataflow Path + ─────────────────────── ─────────────────────── ───────────────────────────── + + mA: (M, K) [GMEM] mB: (N, K) [GMEM] ┌──── RMEM ──────────┐ + │ │ │ make_rmem_tensor() │ + │ local_tile(mA, cta_tiler, coord) │ local_tile(mB, cta_tiler, coord) │ acc: accumulator │ + ▼ ▼ └───────┬────────────┘ + gA: (BM, BK, k) [GMEM] gB: (BN, BK, k) [GMEM] │ + │ │ acc:(MMA,MMA_M,MMA_N) [RMEM] + │ ┌──── SMEM ─────────┐ │ ┌──── SMEM ─────────┐ │ + │ │ sA = alloc(layout)│ │ │ sB = alloc(layout)│ │ mC: (M, N) [GMEM] + │ └──┬────────┬───────┘ │ └──┬────────┬───────┘ │ │ + │ │ │ │ │ │ │ │ local_tile + │ │ thr_mma.partition_A(sA) │ │ thr_mma.partition_B(sB) │ ▼ + │ │ ▼ │ │ ▼ │ gC: (BM, BN) [GMEM] + │ │ tCsA:(MMA,MMA_M, │ │ tCsB:(MMA,MMA_N, │ │ partition_C + │ │ MMA_K,PIPE) [SMEM] │ │ MMA_K,PIPE) [SMEM] │ ▼ + │ │ │ │ │ │ │ tCgC:(MMA,MMA_M, + │ │ make_fragment_A(tCsA) │ │ make_fragment_B(tCsB) │ MMA_N) + │ │ ▼ │ │ ▼ │ [GMEM] (epi dest) + │ │ tCrA:(MMA,MMA_M, │ │ tCrB:(MMA,MMA_N, │ │ + │ │ MMA_K,PIPE) │ │ MMA_K,PIPE) │ │ + │ │ [SMEM descriptors] │ │ [SMEM descriptors] │ │ + │ │ └─────────────┐ │ │ └─────────────┐ │ │ + ╰─────┤ │ ╰─────┤ │ │ │ + ▼ │ ▼ │ │ │ + tma_partition(tma, │ tma_partition(tma, │ │ │ + sA, gA) │ sB, gB) │ │ │ + → tAsA, tAgA │ → tBsB, tBgB │ │ │ + ▼ │ ▼ │ │ │ + ┌───┴────────────────────┐ │ ┌──────┴─────────────────┐│ │ │ + │ TMA copy loop (A path):│ │ │ TMA copy loop (B path):││ │ │ + │ copy(tma, tAgA[k], │ │ │ copy(tma, tBgB[k], ││ │ │ + │ tAsA[stage]) │ │ │ tBsB[stage]) ││ │ │ + 
┌─▶│ (writes into sA; │ │ ┌──▶│ (writes into sB; ││ │ │ + │ │ tCrA reads same sA) │ │ │ │ tCrB reads same sB) ││ │ │ + │ │ repeat for next k/stage│ │ │ │ repeat for next k/stage││ │ │ + │ └────────────────────────┘ │ │ └────────────────────────┘│ │ │ + │ │ │ │ │ │ │ │ + └────────┘ ▼ └─────────┘ ▼ ▼ │ + └───────┬───────────────────────────────┴───────────────────┘ │ + │ │ + ▼ │ + ┌──────────────────────────────────────────────┐ │ + │ GEMM Loop: | │ + │ warpgroup.fence() │ │ + │ cute.gemm(tiled_mma, │ │ + │ acc, D (output, RMEM), │ │ + ┌──▶ │ tCrA[stage], A (SMEM desc -> sA), │ │ + │ │ tCrB[stage], B (SMEM desc -> sB), │ │ + │ │ acc) C (accumulator, RMEM) │ │ + │ │ warpgroup.commit_group() │ │ + │ │ warpgroup.wait_group(n) │ │ + │ └──────────────────────────────────────────────┘ │ + │ │ │ │ + └───────┘ | │ + ▼ │ + Epilogue: │ + tRS_rAcc = retile(acc) │ + tRS_rD = type_convert(tRS_rAcc) │ + │ │ + ▼ │ + R2S: copy(tiled_copy_r2s, tRS_rD, tRS_sD) │ + [RMEM → SMEM via stmatrix] │ + │ │ + ▼ │ + sC = alloc(epi_layout) [SMEM] │ + bSG_sD, bSG_gD = tma_partition(tma_c, sC, gC) ◀───────────────────┘ + │ + ▼ + S2G: copy(tma_c, bSG_sD[stage], bSG_gD[coord]) + [SMEM → GMEM via TMA store] + +**Naming convention:** + +* cta_tiler = (BM, BN, BK) = CTA-wide tiler dimensions +* ``mX`` = a global tensor, e.g., (M, K) for A +* ``gX`` = CTA-tiled GMEM slice, e.g., (BM, BK, k) for A +* ``sX`` = SMEM allocation, e.g., (BM, BK, PIPE) for A +* ``tAsA``/``tBsB`` = TMA-partitioned SMEM views +* ``tAgA``/``tBgB`` = TMA-partitioned GMEM views +* ``tCsX`` = MMA-partitioned SMEM view, e.g., (MMA, MMA_M, MMA_K, PIPE) for A +* ``tCrX`` = SMEM descriptor fragment, e.g., (MMA, MMA_M, MMA_K, PIPE) for A +* ``acc`` = RMEM accumulator, (MMA, MMA_M, MMA_N) +* ``tCgC`` = MMA-partitioned GMEM, (MMA, MMA_M, MMA_N) +* ``tRS_rAcc``/``tRS_sD`` = epilogue retile views for R2S (RMEM → SMEM) copy +* ``bSG_sD``/``bSG_gD`` = TMA-partitioned SMEM/GMEM views for epilogue store +* MMA = warpgroup atom thread-value 
layout; MMA_M/MMA_N/MMA_K = repeat counts + (e.g., BM/inst_M), k = outer K-tiles, PIPE = pipeline stages + +Setting up the TiledMMA, MMA Ops +--------------------------------- + +As shown in the data flow overview, CuTe DSL provides many utilities to tile/partition +the global memory tensors, and create fragment views of SMEM tensors for MMA instructions. + +To utilize these functions, we need to set up the TiledMMA and MMA Ops first. + +Creating a WGMMA Op +~~~~~~~~~~~~~~~~~~~~ + +A WGMMA op describes the hardware instruction to use; it has parameters such as +data types, instruction shape, operand A source (SMEM or RMEM), +and operand major modes. + +.. code-block:: python + + import cutlass + import cutlass.cute as cute + from cutlass.cute.nvgpu import OperandMajorMode + import cutlass.cute.nvgpu.warpgroup as warpgroup + + op = warpgroup.MmaF16BF16Op( + cutlass.Float16, # A/B element type + cutlass.Float32, # accumulator type + (64, 128, 16), # instruction shape (M, N, K) + warpgroup.OperandSource.SMEM, # A operand from shared memory + OperandMajorMode.K, # A is K-major + OperandMajorMode.K, # B is K-major + ) + +The key parameters are: + +- **Instruction shape** ``(M, N, K)``: determines the size of one hardware MMA + instruction. WGMMA requires ``M = 64`` and ``8 <= N <= 256`` in steps of 8. + K is fixed by the op class (16 for F16/BF16, 32 for FP8 and INT8). +- **OperandSource**: ``SMEM`` reads A from a shared memory descriptor; ``RMEM`` + reads A directly from registers. +- **OperandMajorMode**: ``K`` for K-major (default), ``MN`` for transposed layout. + F16/BF16 supports both K-major and MN-major for A and B when ``a_src=SMEM``; + when ``a_src=RMEM``, only B can be transposed. FP8 and INT8 are K-major only. + + +CuTe DSL provides implementations of the following WGMMA ops: + +.. 
list-table:: WGMMA ops + :header-rows: 1 + :widths: 30 24 46 + + * - PTX name + - Python class + - Constructor parameters + * - ``wgmma.mma_async.m64n{N}k16.{acc}.f16.f16`` / ``.bf16.bf16`` + - ``warpgroup.MmaF16BF16Op`` + - ``ab_dtype, acc_dtype, instruction_shape, a_src, a_major_mode, b_major_mode`` + * - ``wgmma.mma_async.m64n{N}k32.{acc}.{e4m3|e5m2}.{e4m3|e5m2}`` + - ``warpgroup.MmaF8Op`` + - ``a_dtype, b_dtype, acc_dtype, instruction_shape, a_src, a_major_mode, b_major_mode`` + * - ``wgmma.mma_async.m64n{N}k32.s32.{s8|u8}.{s8|u8}`` + - ``warpgroup.MmaI8Op`` + - ``a_dtype, b_dtype, acc_dtype, instruction_shape, a_src, a_major_mode, b_major_mode`` + + +Creating a Tiled MMA +~~~~~~~~~~~~~~~~~~~~~ + +A ``TiledMma`` tiles the WGMMA atom across the CTA tile. You can pass the op +directly or create an explicit atom first. + +.. code-block:: python + + # Option 1: directly from op (common shorthand) + tiled_mma = cute.make_tiled_mma(op) + + # Option 2: explicit atom creation + atom = cute.make_mma_atom(op) + tiled_mma = cute.make_tiled_mma(atom) + +Spatial tiling with a repeat count +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +A repeat tuple ``(M_rep, N_rep, K_rep)`` replicates the WGMMA atom across the +M, N, and K dimensions, producing a larger tiled MMA that covers a bigger CTA +tile with a single ``cute.gemm`` call. Each replicated atom +is executed by its own **warpgroup** (128 threads / 4 warps), so ``(2, 1, 1)`` +uses two warpgroups — the standard configuration for large Hopper tiles: + +.. code-block:: python + + atom = cute.make_mma_atom(op) # op shape: (64, 128, 16) + tiled_mma = cute.make_tiled_mma( + atom, + atom_layout_mnk=(2, 1, 1), # 2 warpgroups in M + ) + +.. 
code-block:: text + + WGMMA Atom make_tiled_mma(atom, (2, 1, 1)) + +---------------+ +----------------+ + | | | | ^ + | 64 x 128 | | Atom (0,0,0) | | + | x 16 | --(2,1,1)--> | 64 x 128 | | 2 x M_atom + | | repeat | x 16 | | = 128 + | | | [Warpgroup 0] | | + +---------------+ +----------------+ | + | | | + | Atom (1,0,0) | | + | 64 x 128 | | + | x 16 | | + | [Warpgroup 1] | v + +----------------+ + <-- N_atom = 128 --> + K unchanged = 16 + +The Hopper dense GEMM examples +(``examples/cute/hopper/kernel/dense_gemm/dense_gemm.py``) use this pattern. +The helper ``sm90_utils.make_trivial_tiled_mma(...)`` selects the repeat count +automatically: + +- ``atom_layout_mnk = (2, 1, 1)`` when both ``tile_M > 64`` and + ``tile_N > 128`` (two warpgroups reduce register pressure). +- ``atom_layout_mnk = (1, 1, 1)`` otherwise (a single warpgroup suffices). + +.. code-block:: python + + import cutlass.utils.hopper_helpers as sm90_utils + + tiled_mma = sm90_utils.make_trivial_tiled_mma( + a_dtype, + b_dtype, + a_major_mode, + b_major_mode, + acc_dtype, + atom_layout_mnk=(2, 1, 1), + tiler_mn=(64, 128), # atom instruction shape (M, N) + ) + +Custom tile permutation with ``permutation_mnk`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +``make_tiled_mma`` also accepts an optional ``permutation_mnk`` argument that +controls how the tiled atom footprint is laid out across M, N, and K. At a +high level: + +- ``atom_layout_mnk`` tells CuTe how many atoms (warpgroups) to replicate. +- ``permutation_mnk`` tells CuTe how the final tiled footprint is ordered. + +``permutation_mnk`` is a tuple of layouts or integers that represent the +tile size and ordering of values along each dimension. When a mode's +permutation size is larger than the atom layout's natural coverage +(``atom_layout x inst_shape``), each warpgroup receives additional values +to fill the extended region — the warpgroup count stays the same, but each +warpgroup holds more data. + +.. 
code-block:: python + + atom = cute.make_mma_atom(op) # op shape: (64, 128, 16) + tiled_mma = cute.make_tiled_mma( + atom, + atom_layout_mnk=(2, 1, 1), + permutation_mnk=(128, 256, 16), # extend N from 128 to 256 + ) + +.. code-block:: text + + Without permutation — natural atom coverage (M = 128, N = 128): + + C tile (M=128, N=128) + +----------------+ + | | ^ + | [Warpgroup 0] | | + | 64 x 128 | | 2 x inst_M + | | | = 128 + +----------------+ | + | | | + | [Warpgroup 1] | | + | 64 x 128 | | + | | v + +----------------+ + <--- N = 128 ---> + (each warpgroup owns one (64, 128) atom) + + With permutation_mnk = (128, 256, 16) — N extended to 256: + + C tile (M=128, N=256) + +----------------+----------------+ + | | | ^ N = 128 → 256: + | [Warpgroup 0] | [Warpgroup 0] | | atom pattern repeats + | 64 x 128 | 64 x 128 | | along N. Each warpgroup + | | | | now holds 2x the values + +----------------+----------------+ | along N (same threads, + | | | | more data). + | [Warpgroup 1] | [Warpgroup 1] | | + | 64 x 128 | 64 x 128 | | + | | | v + +----------------+----------------+ + <------------ N = 256 ------------> + | atom coverage | value repeat | + +**Why WGMMA typically does not need permutation_mnk:** The WGMMA +instruction already has a large N dimension (64, 128, or 256), so the natural +atom coverage is wide enough that no permutation is needed to align with SMEM +swizzle widths. The Hopper +dense GEMM examples (``dense_gemm.py``, ``dense_gemm_persistent.py``) use +``atom_layout_mnk`` alone without ``permutation_mnk``. + +When ``permutation_mnk`` is not provided (default), the tile ordering is +sequential and no permutation is applied. + + +Partitioning Tensors +--------------------- + +Before computing, partition the CTA-tiled tensors according to the +tiled MMA layout. 
WGMMA partitioning is **warpgroup-oriented**: each +warpgroup (128 threads / 4 warps) receives its own slice of the CTA +tile, sized to match the SMEM descriptors and register accumulators +that the WGMMA instruction expects. + +**2-warpgroup example** + +``GEMM (M, N, K) = (512, 768, 256)``, ``tile_shape_mnk = (128, 256, 64)``, +F16 WGMMA atom = (64, 256, 16), ``atom_layout_mnk = (2, 1, 1)``, +``num_stages = 4``, 2 warpgroups = 256 threads. + +Global matrices: + +.. code-block:: text + + mA: (M, K) = (512, 256) mB: (N, K) = (768, 256) mC: (M, N) = (512, 768) + + K=256 K=256 N=768 + |<--------->| |<--------->| |<----------------->| + +-----------+ +-----------+ +---+---+---+-------+ + | | ^ | | ^ | | | | | ^ + | mA | | M=512 | mB | | N=768 | | | | | | M=512 + | | v | | v | | | | | v + +-----------+ +-----------+ +---+---+---+-------+ + +Tiling with ``tile_shape_mnk = (BM, BN, BK) = (128, 256, 64)`` gives +M/BM = 4 tiles, N/BN = 3 tiles, K/BK = 4 tiles: + +.. code-block:: text + + mA tiled into (M/BM x K/BK) mB tiled into (N/BN x K/BK) mC tiled into (M/BM x N/BN) + = (4 x 4) blocks = (3 x 4) blocks = (4 x 3) blocks + + BK=64 x4 BK=64 x4 BN=256 x3 + |<--->| |<--->| |<------>| + +-----+-----+-----+-----+ +-----+-----+-----+-----+ +--------+--------+--------+ + | | | | | ^ | | | | | ^ | (0,0) | (0,1) | (0,2) | ^ + | | | | | |128 | | | | | |256 | | | | |128 + +-----+-----+-----+-----+ v +-----+-----+-----+-----+ v +--------+--------+--------+ v + | | | | | ^ | | | | | ^ | (1,0) | (1,1) | (1,2) | ^ + | | | | | |128 | | | | | |256 | | | | |128 + +-----+-----+-----+-----+ v +-----+-----+-----+-----+ v +--------+--------+--------+ v + | | | | | | | | | | | (2,0) | (2,1) | (2,2) | + +-----+-----+-----+-----+ +-----+-----+-----+-----+ +--------+--------+--------+ + | | | | | | (3,0) | (3,1) | (3,2) | + +-----+-----+-----+-----+ +--------+--------+--------+ + +Each CTA picks one (M-tile, N-tile) coordinate. +For example, CTA at ``tile_coord = (1, 0, :)``. 
+ +After ``local_tile`` — one CTA's tile (``k = K/BK = 256/64 = 4``): + +.. code-block:: text + + gA: (BM, BK, k) = (128, 64, 4) gB: (BN, BK, k) = (256, 64, 4) gC: (BM, BN) = (128, 256) + + BK=64 BK=64 BN=256 + |<----->| |<----->| |<--------->| + +-------+-- +-------+-- +-----------+ + | |.. | |.. | | ^ + BM= | gA | k=4 BN= | gB | k=4 BM= | gC | | 128 + 128 | | 256 | | 128 | | v + +-------+ +-------+ +-----------+ + +SMEM tensors ``sA`` and ``sB`` include a pipeline staging dimension: + +.. code-block:: text + + sA: (BM, BK, PIPE) = (128, 64, 4) sB: (BN, BK, PIPE) = (256, 64, 4) + +``get_slice(warp_group_thread_layout(warp_group_idx))`` — each +warpgroup receives its slice of the tiled MMA footprint. +With ``atom_layout_mnk = (2, 1, 1)`` and inst shape ``(64, 256, 16)``, +the tiled MMA covers ``(2x64, 1x256, 16) = (128, 256, 16)`` which +exactly matches the CTA tile in M and N. Each warpgroup owns one +64-row slice of M: + +.. code-block:: text + + sA (one pipeline stage, BM=128, BK=64): + + Warpgroup 0's slice Warpgroup 1's slice + inst_K inst_K inst_K inst_K + =16 =16 =16 =16 + |<--->|<--->|<--->|<--->| |<--->|<--->|<--->|<--->| + +-----+-----+-----+-----+ ^ +-----+-----+-----+-----+ ^ + | 0 | 1 | 2 | 3 | |64 | 0 | 1 | 2 | 3 | |64 + +-----+-----+-----+-----+ v +-----+-----+-----+-----+ v + |<-- MMA_K = BK/inst_K = 4 -->| |<-- MMA_K = 4 ---------->| + MMA_M = 64/64 = 1 MMA_M = 64/64 = 1 + + gC (BM=128, BN=256): + + +---------------------------+ ^ + | Warpgroup 0: 64 x 256 | | 64 + | | | + +---------------------------+ v + | Warpgroup 1: 64 x 256 | ^ + | | | 64 + +---------------------------+ v + <--------- N = 256 --------> + MMA_M = 64/64 = 1, MMA_N = 256/256 = 1 + +After partition (per warpgroup): + +- ``tCsA: (MMA, MMA_M, MMA_K, PIPE) = (MMA, 1, 4, 4)`` — MMA_M = BM / (atom_M x inst_M) = 128 / (2x64) = 1, MMA_K = BK / inst_K = 64 / 16 = 4 +- ``tCsB: (MMA, MMA_N, MMA_K, PIPE) = (MMA, 1, 4, 4)`` — MMA_N = BN / (atom_N x inst_N) = 256 / (1x256) = 1, MMA_K = 4 +- 
``tCgC: (MMA, MMA_M, MMA_N) = (MMA, 1, 1)`` — MMA_M = 1, MMA_N = 1 + +The first mode ``MMA`` contains the atom's **thread x value** layout — it +encodes which registers within a warpgroup hold which matrix elements. +The remaining modes are repeat counts that tile the atom across the +full CTA tile. + +.. note:: Because the WGMMA instruction shape is large (64 x {64..256}), + the tiled MMA footprint typically covers the entire CTA tile in M and N + with just one or two warpgroups. This means MMA_M and MMA_N are often 1. + The MMA_K dimension is where the repeat count is non-trivial (BK / inst_K + iterations per pipeline stage). + +**1-warpgroup example (contrast)** + +For a smaller tile ``(128, 128, 64)`` with ``atom_layout_mnk = (1, 1, 1)``, +inst shape ``(64, 128, 16)``, and ``num_stages = 4``, +the tiled MMA covers only ``(64, 128, 16)``. +Now a single warpgroup must iterate over two atom-blocks along M: + +- ``tCsA: (MMA, MMA_M, MMA_K, PIPE) = (MMA, 2, 4, 4)`` — MMA_M = 128 / (1x64) = 2 +- ``tCsB: (MMA, MMA_N, MMA_K, PIPE) = (MMA, 1, 4, 4)`` — MMA_N = 128 / (1x128) = 1 +- ``tCgC: (MMA, MMA_M, MMA_N) = (MMA, 2, 1)`` + +.. code-block:: python + + # Based on examples/cute/hopper/kernel/dense_gemm/dense_gemm.py + @cute.kernel + def kernel(tiled_mma: cute.TiledMma, ...): + tidx, _, _ = cute.arch.thread_idx() + + # CTA-tiled global tensors + gA_mkl = cute.local_tile( + mA_mkl, tile_shape_mnk, tile_coord_mnkl, proj=(1, None, 1) + ) + gB_nkl = cute.local_tile( + mB_nkl, tile_shape_mnk, tile_coord_mnkl, proj=(None, 1, 1) + ) + gC_mnl = cute.local_tile( + mC_mnl, tile_shape_mnk, tile_coord_mnkl, proj=(1, 1, None) + ) + + # Warpgroup-oriented slicing (128 threads per warpgroup) + warp_group_idx = cute.arch.make_warp_uniform( + tidx // num_threads_per_warp_group # 128 + ) + warp_group_thread_layout = cute.make_layout( + mma_warp_groups, # e.g. 
2 + stride=num_threads_per_warp_group, # 128 + ) + thr_mma = tiled_mma.get_slice( + warp_group_thread_layout(warp_group_idx) + ) + + # Partition C from global + tCgC = thr_mma.partition_C(gC_mnl) # (MMA, MMA_M, MMA_N) + + # Partition A/B from staged SMEM + tCsA = thr_mma.partition_A(sA) # (MMA, MMA_M, MMA_K, PIPE) + tCsB = thr_mma.partition_B(sB) # (MMA, MMA_N, MMA_K, PIPE) + + +Pre and Post-Conditions for Partitioning +----------------------------------------- + +* The inputs of ``partition_A``, ``partition_B``, and ``partition_C`` should be + at least rank-2 tensors. +* The output layout is constrained by the selected MMA atom: + + - For A, the output has layout ``(MMA, MMA_M, MMA_K, ...)``. + - For B, the output has layout ``(MMA, MMA_N, MMA_K, ...)``. + - For C, the output has layout ``(MMA, MMA_M, MMA_N, ...)``. + +* Partitioning reasons about layout, not memory space or element type. + When ``a_src=OperandSource.RMEM``, the same tiled MMA shape still + determines the logical A footprint, but A is materialized as a register + fragment rather than a shared-memory descriptor. + + +Making Fragments +----------------- + +Fragments are the tensors that the WGMMA instruction operates on. For dense +WGMMA: + +- **Fragment A**: an SMEM descriptor when ``a_src=OperandSource.SMEM``, or an + RMEM register fragment when ``a_src=OperandSource.RMEM``. +- **Fragment B**: an SMEM descriptor pointing into staged shared memory buffers. +- **Fragment C (accumulator)**: an RMEM tensor that serves as both the input C + and output D of ``cute.gemm()``. + +WGMMA fragments for A and B are **SMEM descriptors** — the hardware reads +directly from shared memory. There is no explicit SMEM → RMEM copy step for +operands A and B. The accumulator, however, still lives in per-thread +registers (RMEM). + +Creating fragment descriptors and accumulator fragments +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Fragment creation has two parts: + +**1. 
A and B fragment descriptors** + +``make_fragment_A`` and ``make_fragment_B`` take the MMA-partitioned SMEM +views (``tCsA`` / ``tCsB``) and produce descriptor tensors that the WGMMA +instruction consumes. Each descriptor points to one tile within a pipeline +stage in shared memory. + +.. code-block:: python + + # MMA-partitioned SMEM views (see "Partitioning Tensors") + tCsA = thr_mma.partition_A(sA) # (MMA, MMA_M, MMA_K, PIPE) + tCsB = thr_mma.partition_B(sB) # (MMA, MMA_N, MMA_K, PIPE) + + # SMEM descriptor fragments consumed by cute.gemm() + tCrA = tiled_mma.make_fragment_A(tCsA) # (MMA, MMA_M, MMA_K, PIPE) + tCrB = tiled_mma.make_fragment_B(tCsB) # (MMA, MMA_N, MMA_K, PIPE) + +Continuing the 2-warpgroup example from `Partitioning Tensors`_ +(F16 atom = (64, 256, 16), ``tile_shape_mnk = (128, 256, 64)``, +``atom_layout_mnk = (2, 1, 1)``, ``num_stages = 4``): + +.. code-block:: text + + tCsA: (MMA, MMA_M=1, MMA_K=4, PIPE=4) + tCsB: (MMA, MMA_N=1, MMA_K=4, PIPE=4) + + make_fragment_A(tCsA) -> tCrA: (MMA, 1, 4, 4) + make_fragment_B(tCsB) -> tCrB: (MMA, 1, 4, 4) + + Each element of tCrA/tCrB is an SMEM descriptor — one per + (MMA_K, PIPE) pair. The hardware reads SMEM directly via the + descriptor; no explicit SMEM -> RMEM load is needed. + + tCrA per warpgroup (4 pipeline stages, 4 K-blocks each): + + |<-- MMA_K = BK/inst_K = 4 -->| + stage 0: +------+------+------+------+ + | k=0 | k=1 | k=2 | k=3 | inst_M=64 (MMA_M=1) + +------+------+------+------+ + stage 1: +------+------+------+------+ + | k=0 | k=1 | k=2 | k=3 | inst_M=64 + +------+------+------+------+ + stage 2: +------+------+------+------+ + | k=0 | k=1 | k=2 | k=3 | inst_M=64 + +------+------+------+------+ + stage 3: +------+------+------+------+ + | k=0 | k=1 | k=2 | k=3 | inst_M=64 + +------+------+------+------+ + + Similarly for tCrB with shape (MMA, MMA_N=1, MMA_K=4, PIPE=4). + +.. 
note:: WGMMA fragments for A and B are SMEM descriptors — the hardware + reads SMEM directly, so there is no ``ldmatrix`` retiling step required + before ``cute.gemm()``. + +**When A comes from registers (``OperandSource.RMEM``)** + +In fused kernels, the output of one MMA can become the A operand of the +next. The second ``TiledMma`` is created with +``a_src=OperandSource.RMEM``, and ``make_fragment_A`` is **not** used. +Instead: + +1. The accumulator's C layout ``(MMA, MMA_M, MMA_N)`` is converted to the + A layout ``(MMA, MMA_M, MMA_K)`` expected by the second ``TiledMma``. +2. The accumulator values are type-converted and stored into an RMEM tensor + with the A layout. +3. The resulting RMEM tensor is passed directly to ``cute.gemm()`` as the A + operand — no SMEM descriptor is involved. + +See the Hopper FMHA example (``examples/cute/hopper/kernel/attention/fmha.py``) for the complete pattern. + +**2. C fragment (accumulator)** + +The accumulator lives in per-thread registers (RMEM). Its shape is derived +from the partitioned C layout. The accumulator starts at zero before the K +loop and is updated in-place by each ``cute.gemm()`` call. + +.. code-block:: python + + # Partition C from global (see "Partitioning Tensors") + tCgC = thr_mma.partition_C(gC_mnl) # (MMA, MMA_M, MMA_N) + + # Allocate RMEM accumulator with the same shape + acc_shape = tCgC.shape + acc = cute.make_rmem_tensor(acc_shape, cutlass.Float32) + acc.fill(0.0) + +For the same running example: + +.. code-block:: text + + tCgC: (MMA, MMA_M=1, MMA_N=1) + + make_rmem_tensor(tCgC.shape, Float32) -> acc: (MMA, 1, 1) + + The accumulator stays in RMEM for the entire main loop. + cute.gemm() reads A/B from SMEM descriptors and accumulates into acc. 
+ + +-----------------------------------+ + | acc: (MMA, 1, 1) in RMEM | + | 64 x 256 elements per warpgroup | + | Float32 | + +-----------------------------------+ + + +Creating SMEM layouts for A and B +---------------------------------- + +The SMEM layouts define how A and B tiles are staged in shared memory, +including swizzling for bank-conflict-free descriptor access. The helper +functions in ``cutlass.utils.hopper_helpers`` handle the details. + +**Host side** (``@cute.jit``): + +.. code-block:: python + + import cutlass.utils.hopper_helpers as sm90_utils + + # Create SMEM layouts (includes swizzle + staging) + a_smem_layout = sm90_utils.make_smem_layout_a( + a_layout, # LayoutEnum — row-major or col-major + tile_shape_mnk, # CTA tile (M, N, K) + a_dtype, # element type (e.g. Float16) + num_stages, # pipeline depth + ) + b_smem_layout = sm90_utils.make_smem_layout_b( + b_layout, + tile_shape_mnk, + b_dtype, + num_stages, + ) + epi_smem_layout = sm90_utils.make_smem_layout_epi( + c_dtype, + c_layout, + epi_tile, + epi_stage, + ) + +``make_smem_layout_a`` and ``make_smem_layout_b`` are convenience helpers that +build a complete, staged SMEM layout in four steps: + +1. **Extract the operand tile shape.** For A the ``(M, K)`` portion of + ``tile_shape_mnk`` is kept via ``cute.slice_``; for B the ``(N, K)`` + portion. + +2. **Determine the major mode.** The major mode (K-major or MN-major) is read + from the layout enum (``a_layout.is_k_major_a()``). The major-mode + dimension size is used for swizzle selection. + +3. 
**Select and materialise the swizzle atom.** A heuristic + (``get_smem_layout_atom``) picks the widest swizzle whose contiguous + size (in bits) evenly divides the major-mode dimension: + + +------------+-----------------+ + | Swizzle | Contiguous bits | + +============+=================+ + | SW128 | 1024 (128 B) | + +------------+-----------------+ + | SW64 | 512 (64 B) | + +------------+-----------------+ + | SW32 | 256 (32 B) | + +------------+-----------------+ + | Interleave | 128 (16 B) | + +------------+-----------------+ + + ``make_smem_layout_atom`` then combines the chosen swizzle with a compact + outer layout into a ``ComposedLayout(swizzle, outer)``. + +4. **Tile to the operand shape and append the staging dimension.** + ``cute.tile_to_shape`` broadcasts the atom to the full ``(M_or_N, K)`` + shape with ``num_stages`` appended. The ``order`` argument controls which + dimension is contiguous: ``(0, 1, 2)`` for K-major (K innermost), + ``(1, 0, 2)`` for MN-major (MN innermost). + +For the running F16 example (``tile_shape_mnk = (128, 256, 64)``, +``num_stages = 4``, K-major A, K-major B): + +.. code-block:: text + + A operand (K-major, tile = (M=128, K=64)): + major_mode_size = 64 + 64 * 16 bits = 1024 bits → SW128 + atom = make_smem_layout_atom(K_SW128, Float16) + tile_to_shape(atom, (128, 64, 4), order=(0,1,2)) + -> a_smem_layout: ComposedLayout with shape (128, 64, 4) + + B operand (K-major, tile = (N=256, K=64)): + major_mode_size = 64 + 64 * 16 bits = 1024 bits → SW128 + atom = make_smem_layout_atom(K_SW128, Float16) + tile_to_shape(atom, (256, 64, 4), order=(0,1,2)) + -> b_smem_layout: ComposedLayout with shape (256, 64, 4) + +**Kernel side** (``@cute.kernel``): + +The layout and swizzle are passed to shared-memory allocation. The result +is a ``ComposedLayout`` whose ``.outer`` is the logical layout and ``.inner`` +is the swizzle: + +.. 
code-block:: python + + # Based on examples/cute/hopper/kernel/dense_gemm/dense_gemm.py + sA = storage.sA.get_tensor( + a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner + ) + sB = storage.sB.get_tensor( + b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner + ) + +After allocation: + +- ``sA`` has shape ``(BM, BK, PIPE) = (128, 64, 4)``. +- ``sB`` has shape ``(BN, BK, PIPE) = (256, 64, 4)``. + +These are the staged SMEM tensors consumed by ``partition_A`` / +``partition_B`` and ``make_fragment_A`` / ``make_fragment_B`` +(see `Making Fragments`_). + +.. note:: If you need finer control, you can build layout atoms directly with + ``cute.nvgpu.warpgroup.make_smem_layout_atom(...)`` and compose the final + SMEM layout manually via ``cute.tile_to_shape``. + + +Executing the GEMM (Main Loop) +------------------------------- + +The main loop iterates over K-tiles. The WGMMA-specific part of each +iteration is the **fence / gemm / commit / wait** sequence: + +.. code-block:: python + + acc.fill(0.0) + tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True) + + for k_tile in cutlass.range(k_pipe_mmas, k_tile_cnt, 1, unroll=1): + # ... wait for TMA load (pipeline details in dense_gemm.py) ... + + cute.nvgpu.warpgroup.fence() + tile_crd = (None, None, None, consumer_read.index) + cute.gemm(tiled_mma, acc, tCrA[tile_crd], tCrB[tile_crd], acc) + cute.nvgpu.warpgroup.commit_group() + cute.nvgpu.warpgroup.wait_group(k_pipe_mmas) + + # ... release buffer & advance pipeline (see dense_gemm.py) ... + + cute.nvgpu.warpgroup.wait_group(0) + +Key points: + +- ``fence()`` orders prior SMEM writes before WGMMA issue. +- ``commit_group()`` publishes queued WGMMA instructions as a group. +- ``wait_group(n)`` waits until at most ``n`` groups remain in flight. + ``wait_group(0)`` after the loop drains all work before the epilogue. +- ``Field.ACCUMULATE`` — ``True`` accumulates (``D += A*B``), + ``False`` overwrites (``D = A*B``). 
The dense GEMM sets ``True`` and + zero-fills ``acc`` so the first iteration computes ``0 + A*B``. + + +Complete Workflow +------------------ + +Putting it all together, a typical Hopper WGMMA GEMM has this structure. +The MMA-relevant steps are highlighted; see ``dense_gemm.py`` for the full +kernel including TMA, pipeline, and epilogue details. + +.. code-block:: python + + import cutlass + import cutlass.cute as cute + from cutlass.cute.nvgpu import OperandMajorMode + import cutlass.cute.nvgpu.warpgroup as warpgroup + import cutlass.utils.hopper_helpers as sm90_utils + + # --- Host side (@cute.jit) --- + + # 1. MMA op + tiled MMA + op = warpgroup.MmaF16BF16Op( + cutlass.Float16, cutlass.Float32, (64, 128, 16), + warpgroup.OperandSource.SMEM, OperandMajorMode.K, OperandMajorMode.K, + ) + tiled_mma = cute.make_tiled_mma(op) + + # 2. SMEM layouts + a_smem_layout = sm90_utils.make_smem_layout_a(a_layout, tile_shape_mnk, a_dtype, num_stages) + b_smem_layout = sm90_utils.make_smem_layout_b(b_layout, tile_shape_mnk, b_dtype, num_stages) + + # 3. TMA copy atoms + kernel launch (see dense_gemm.py) + +.. code-block:: python + + # --- Kernel side (@cute.kernel) --- + + # 4. Allocate SMEM + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + sA = storage.sA.get_tensor( + a_smem_layout.outer, swizzle=a_smem_layout.inner) # (BM, BK, PIPE) + sB = storage.sB.get_tensor( + b_smem_layout.outer, swizzle=b_smem_layout.inner) # (BN, BK, PIPE) + + # 5. CTA-tiled global tensors + gA_mkl = cute.local_tile(mA_mkl, tile_shape_mnk, tile_coord, proj=(1, None, 1)) + gB_nkl = cute.local_tile(mB_nkl, tile_shape_mnk, tile_coord, proj=(None, 1, 1)) + gC_mnl = cute.local_tile(mC_mnl, tile_shape_mnk, tile_coord, proj=(1, 1, None)) + + # 6. 
Warpgroup slice, partition & make fragments + warp_group_idx = cute.arch.make_warp_uniform(tidx // num_threads_per_warp_group) + warp_group_thread_layout = cute.make_layout(mma_warp_groups, stride=num_threads_per_warp_group) + thr_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)) + + tCsA = thr_mma.partition_A(sA) # (MMA, MMA_M, MMA_K, PIPE) + tCsB = thr_mma.partition_B(sB) # (MMA, MMA_N, MMA_K, PIPE) + tCrA = tiled_mma.make_fragment_A(tCsA) # SMEM descriptor + tCrB = tiled_mma.make_fragment_B(tCsB) # SMEM descriptor + tCgC = thr_mma.partition_C(gC_mnl) # (MMA, MMA_M, MMA_N) + acc = cute.make_rmem_tensor(tCgC.shape, acc_dtype) + + # 7. TMA pipeline setup + prefetch (see dense_gemm.py) + + # 8. Main loop — fence / gemm / commit / wait + acc.fill(0.0) + tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True) + + for k_tile in cutlass.range(k_pipe_mmas, k_tile_cnt, 1, unroll=1): + # ... wait for TMA load ... + cute.nvgpu.warpgroup.fence() + tile_crd = (None, None, None, consumer_read.index) + cute.gemm(tiled_mma, acc, tCrA[tile_crd], tCrB[tile_crd], acc) + cute.nvgpu.warpgroup.commit_group() + cute.nvgpu.warpgroup.wait_group(k_pipe_mmas) + # ... release buffer, advance pipeline ... + + cute.nvgpu.warpgroup.wait_group(0) + + # 9. Epilogue: RMEM → SMEM (stmatrix) → GMEM (TMA store) + # ... (see dense_gemm.py) + + +.. Beyond Simple Dense MMAs +.. ------------------------ + +.. The current Python DSL coverage for warpgroup MMA is centered on the three +.. dense ops above. PTX also defines additional WGMMA instruction families that +.. do **not** yet have DSL op classes. These are tracked in the source at +.. ``cutlass/cute/nvgpu/warpgroup/mma.py`` (marked ``✗`` in the instruction +.. table). + +.. **Structured-sparse WGMMA** (``wgmma.mma_async.sp``) + +.. 2:4 structured sparsity in operand A: out of every 4 consecutive K-elements, +.. exactly 2 are non-zero. The instruction K is **doubled** relative to the +.. dense counterpart (e.g. 
``m64nNk32`` for F16/BF16 vs ``m64nNk16`` dense) +.. because A is stored in compressed form. Supported data types include +.. F16/BF16, TF32, FP8, and INT8. + +.. Compared to the dense workflow, a sparse kernel would add: + +.. - A **compressed A tensor** storing only the non-zero values (half the +.. K-elements), and a **metadata tensor E** encoding which 2 of 4 positions +.. are non-zero. +.. - Extra SMEM layouts, TMA loads, and allocations for both the compressed A +.. and the metadata E. +.. - A metadata staging step each K-tile (SMEM to the MMA instruction). + +.. Once DSL support is added, the same fence/commit/wait workflow described in +.. this guide applies, with the additional metadata operand. + +.. **Dense TF32 WGMMA** (``m64nNk8``) + +.. TF32 (19-bit truncated FP32) inputs with FP32 accumulator. The instruction +.. K = 8 is smaller than F16's K = 16, so MMA_K repeat counts are larger for +.. the same BK tile size. Otherwise the workflow is identical to the dense +.. F16/BF16 path — the same SMEM layout, descriptor, and fence/commit/wait +.. pattern applies. + +.. **Dense B1 WGMMA** (``m64nNk256``) + +.. 1-bit (binary) inputs with INT32 accumulator. The very large instruction +.. K = 256 means each atom consumes 256 bits along K per operand, resulting in +.. small MMA_K repeat counts. This is a niche instruction for binary neural +.. networks. + + +See also: + +- Dense GEMM example: ``examples/cute/hopper/kernel/dense_gemm/dense_gemm.py`` +- Persistent GEMM example: ``examples/cute/hopper/kernel/dense_gemm/dense_gemm_persistent.py`` +- FMHA example (RMEM A path): ``examples/cute/hopper/kernel/attention/fmha.py`` +- Helper utilities: ``cutlass.utils.hopper_helpers`` diff --git a/media/docs/pythonDSL/mma_docs/wmma_programming.rst b/media/docs/pythonDSL/mma_docs/wmma_programming.rst new file mode 100644 index 000000000..8387f01a8 --- /dev/null +++ b/media/docs/pythonDSL/mma_docs/wmma_programming.rst @@ -0,0 +1,1358 @@ +.. 
_wmma_programming: + +Warp-Level MMA Instructions Programming Guide +============================================= + +Ampere (SM80) introduced the modern **warp-level MMA** PTX instruction +family ``mma.sync.aligned``. A warp (32 threads) cooperates on one +synchronous ``D = A * B + C`` matrix multiply-accumulate; later +architectures extended the family with new data types and shapes — FP8 on +Ada (SM89) and block-scaled MX FP4 on Blackwell (SM120a) — while keeping +the same warp-synchronous issue model. + +Key architectural characteristics: + +* **Warp scope:** One MMA is issued collectively by a 32-thread warp + rather than by a warpgroup or a single thread. +* **Synchronous issue model:** ``mma.sync.aligned`` completes in program + order within the warp; no fences or commit/wait groups are required. +* **Register-resident operands and accumulator:** A, B, and C/D all live + in the register file (RMEM). Each thread holds a small fragment of every + operand in its own registers. +* **SMEM → RMEM loading:** Operands A and B are staged in shared memory + and loaded into register fragments via ``ldmatrix`` — a warp-collective + SMEM→RMEM load that distributes tiles in the exact layout the MMA + expects — or via regular shared-memory loads. +* **Fixed operand layout:** A is row-major (K-major) and B is col-major + (K-major); transpose is not supported at the instruction level. + +The dense DSL op classes currently exposed are ``MmaF16BF16Op`` (F16/BF16, +SM80+), ``MmaFP8Op`` (FP8 E4M3/E5M2, SM89+), and ``MmaMXF4Op`` / +``MmaMXF4NVF4Op`` (block-scaled MX FP4, SM120a+); see `Setting up the +TiledMMA, MMA Ops`_ for their full constructor parameters, instruction +shapes, and architecture requirements. + +.. {$nv-internal-release begin} + +Internal builds additionally expose ``MmaF16BF16SparseOp`` (2:4 structured +sparsity, SM80+). + +.. 
{$nv-internal-release end} + +This guide outlines the CuTe Python DSL programming model for warp-level +MMA kernels: stage operands in SMEM, load register fragments with +``ldmatrix`` or regular shared-memory loads, launch warp-synchronous MMAs, and stage the RMEM accumulator +back to GMEM in the epilogue. + + +.. contents:: **Contents** + :local: + :depth: 2 + +Global Memory (GMEM) to MMA data flow overview +---------------------------------------------- + +Warp MMA (``mma.sync.aligned``) instructions require all operands --A, B, and +the accumulator C/D-- to live in registers (RMEM) of the 32 threads of the warp. +Operand data must therefore be explicitly loaded into registers before each MMA +instruction. The most common way to implement these GEMMs is to stage A and B +from GMEM into SMEM with ``cp.async``, then use ``ldmatrix`` (an SMEM→RMEM +warp-collective load) to fill the A/B register fragments just before ``cute.gemm()``. + +The diagram below traces the full data flow of a warp MMA GEMM kernel, for the most +common case where A and B matrices are stored in GMEM and staged through SMEM +via ``cp.async``, and the output matrix --accumulated in RMEM-- is written back +to GMEM through an SMEM staging buffer for coalesced vectorized stores. + +The diagram has three parallel tracks, one each for operand A, operand B, and the +accumulator C/D. Each track has two sub-tracks: one for copying data between memory +spaces and one for MMA execution. + +- **Operand A** (and symmetrically **Operand B**): + + - First, we need to create SMEM tensors for A and B matrices: ``sA`` and ``sB``. + These tensors are physically allocated tensors that are the staging destination + of ``cp.async`` and the source of ``ldmatrix`` for the warp MMA instructions. + - Next, the **data copy flow** creates the tensor views for copying data from GMEM to SMEM. + It starts with the ``mA`` tensor, which represents matrix A in global memory. 
+ Then ``mA`` → ``local_tile`` → ``gA`` operation creates the local tile view of A that is + the slice of A matrix needed to compute the given CTA's output tile. + A copy partition maps this tile to per-thread copy views (``tAgA``, ``tAsA``), + and the multi-stage ``cp.async`` pipeline performs + ``copy(tiled_copy_A, tAgA[k], tAsA[stage])``. + - In parallel, the **MMA flow** turns the staged SMEM tensor into register fragments + consumed by the warp MMA. From the SMEM allocation ``sA``, MMA partitioning + produces the SMEM operand view ``tCsA = partition_A(sA)`` and the register-fragment + layout ``tCrA = make_fragment_A(tCsA)``. A dedicated S2R/``ldmatrix`` path then + retiles the source and destination (``partition_S`` on SMEM, ``retile`` on RMEM) + and executes ``copy(s2r_A, tCsA_copy_view[k_blk], tCrA_copy_view[k_blk])`` + per k-block, filling the ``tCrA`` registers read by ``cute.gemm()``. + +- **Accumulator C/D**: + + - **RMEM accumulator flow** (MMA input/output): output tile views are formed by + ``mC`` → ``local_tile`` → ``gC`` → ``partition_C`` → ``tCgC``, then + ``make_fragment_C(tCgC)`` creates the register accumulator ``tCrC``. + Warp MMA keeps C/D entirely in RMEM, and ``tCrC`` is both the input C + and output D of ``cute.gemm()``. + - **Epilogue flow** (RMEM → SMEM → RMEM → GMEM): the epilogue converts accumulator + values (for example ``tCrD = epilogue_op(tCrC)``), stages them through SMEM + (``autovec_copy(tCrD, tCsC)``), reloads them into registers with the epilogue + copy layout, and performs coalesced vectorized GMEM stores via + ``copy(tiled_copy_C, tCrC_epi, tCgC_epi)``. + +.. 
code-block:: text + + Operand A Dataflow Path Operand B Dataflow Path Accumulator C/D Dataflow Path + ─────────────────────── ─────────────────────── ───────────────────────────── + + mA: (M, K) [GMEM] mB: (N, K) [GMEM] ┌──── RMEM ──────────┐ + │ │ │ make_fragment_C() │ + │ local_tile(mA, cta_tiler, coord) │ local_tile(mB, cta_tiler, coord) │ tCrC: accumulator │ + ▼ ▼ └───────┬────────────┘ + gA: (BM, BK, k) [GMEM] gB: (BN, BK, k) [GMEM] │ + │ │ tCrC:(MMA,MMA_M,MMA_N) [RMEM] + │ ┌──── SMEM ─────────┐ │ ┌──── SMEM ─────────┐ │ + │ │ sA: (BM,BK,PIPE) │ │ │ sB: (BN,BK,PIPE) │ │ mC: (M, N) [GMEM] + │ └──┬────────┬───────┘ │ └──┬────────┬───────┘ │ │ + │ │ │ │ │ │ │ │ local_tile + │ │ thr_mma.partition_A(sA) │ │ thr_mma.partition_B(sB) │ ▼ + │ │ ▼ │ │ ▼ │ gC: (BM, BN) [GMEM] + │ │ tCsA:(MMA,MMA_M, │ │ tCsB:(MMA,MMA_N, │ │ partition_C + │ │ MMA_K,PIPE) [SMEM] │ │ MMA_K,PIPE) [SMEM] │ ▼ + │ │ │ │ │ │ │ tCgC:(MMA,MMA_M, + │ │ make_fragment_A(tCsA) │ │ make_fragment_B(tCsB) │ MMA_N) + │ │ ▼ │ │ ▼ │ [GMEM] (epi dest) + │ │ tCrA:(MMA,MMA_M, │ │ tCrB:(MMA,MMA_N, │ │ + │ │ MMA_K) [RMEM] │ │ MMA_K) [RMEM] │ │ + │ │ │ │ │ │ │ │ + │ │ S2R retiling (ldmatrix): │ │ S2R retiling (ldmatrix): │ │ + │ │ s2r_A = make_tiled_copy_A( │ │ s2r_B = make_tiled_copy_B( │ │ + │ │ ldmatrix, mma) │ │ ldmatrix, mma) │ │ + │ │ tCsA_copy_view = │ │ tCsB_copy_view = │ │ + │ │ s2r_A.partition_S(sA) │ │ s2r_B.partition_S(sB) │ │ + │ │ tCrA_copy_view = retile(tCrA) │ │ tCrB_copy_view = retile(tCrB) │ │ + │ │ └─────────────┐ │ │ └─────────────┐ │ │ + ╰─────┤ │ ╰─────┤ │ │ │ + ▼ │ ▼ │ │ │ + tAgA = thr_copy_A. │ tBgB = thr_copy_B. │ │ │ + partition_S(gA) │ partition_S(gB) │ │ │ + tAsA = thr_copy_A. │ tBsB = thr_copy_B. 
│ │ │ + partition_D(sA) │ partition_D(sB) │ │ │ + | │ | │ │ │ + ▼ │ ▼ │ │ │ + ┌───┴────────────────────┐ │ ┌──────┴─────────────────┐│ │ │ + │ cp.async loop (k-tile):│ │ │ cp.async loop (k-tile):││ │ │ + │ copy(tiled_copy_A, │ │ │ copy(tiled_copy_B, ││ │ │ + │ tAgA[k], │ │ │ tBgB[k], ││ │ │ + ┌─▶│ tAsA[stage]) │ │ ┌──▶│ tBsB[stage]) ││ │ │ + │ │ (writes into sA; │ │ │ │ (writes into sB; ││ │ │ + │ │ ldmatrix reads sA) │ │ │ │ ldmatrix reads sB) ││ │ │ + │ │ repeat for next k/stage│ │ │ │ repeat for next k/stage││ │ │ + │ └────────────────────────┘ │ │ └────────────────────────┘│ │ │ + │ │ │ │ │ │ │ │ + └────────┘ ▼ └─────────┘ ▼ ▼ │ + └───────┬───────────────────────────────┴───────────────────┘ │ + │ │ + ▼ │ + ┌────────────────────────────────────────────────────────┐ │ + │ MMA loop (k_blk): │ │ + │ S2R: copy(s2r_A, tCsA_copy_view[k_blk], │ │ + │ tCrA_copy_view[k_blk]) │ │ + │ S2R: copy(s2r_B, tCsB_copy_view[k_blk], │ │ + │ tCrB_copy_view[k_blk]) │ │ + │ [SMEM → RMEM via ldmatrix; fills tCrA/tCrB] │ │ + │ │ │ + │ cute.gemm(tiled_mma, │ │ + ┌──▶ │ tCrC, D (output, RMEM), │ │ + │ │ tCrA[k_blk], A (RMEM), │ │ + │ │ tCrB[k_blk], B (RMEM), │ │ + │ │ tCrC) C (accumulator, RMEM) │ │ + │ └────────────────────────────────────────────────────────┘ │ + │ │ │ │ + └───────┘ | │ + ▼ │ + Epilogue: │ + tCrD = epilogue_op(tCrC) [RMEM] │ + │ │ + ▼ │ + sC = alloc(sC_layout) [SMEM] │ + tCsC = thr_mma.partition_C(sC) │ + R2S: autovec_copy(tCrD, tCsC) │ + [RMEM → SMEM] │ + │ │ + ▼ │ + tCsC_epi = thr_copy_C.partition_S(sC) │ + tCgC_epi = thr_copy_C.partition_D(gC) ◀─────────────────────────────────┘ + tCrC_epi = make_fragment_like(...) 
+ S2R: autovec_copy(tCsC_epi, tCrC_epi) + [SMEM → RMEM] + │ + ▼ + Store: copy(tiled_copy_C, tCrC_epi, tCgC_epi) + [RMEM → GMEM] + +**Naming convention:** + +* ``cta_tiler`` = ``(BM, BN, BK)`` (CTA tiler dimensions) +* ``mX`` = global tensor (for example A as ``(M, K)``) +* ``gX`` = CTA-tiled GMEM slice (for example ``(BM, BK, k)`` for A) +* ``sX`` = SMEM allocation (for example ``(BM, BK, PIPE)``) +* ``tAgA`` / ``tAsA`` = ``cp.async`` source/destination partitions + (``CPY, CPY_M, CPY_K, ...``) +* ``tCsX`` = MMA-partitioned SMEM view (for example ``(MMA, MMA_M, MMA_K, PIPE)``) +* ``tCrX`` = register fragment (for example ``(MMA, MMA_M, MMA_K)``) +* ``tCrC`` = RMEM accumulator (``MMA, MMA_M, MMA_N``) +* ``tCgC`` = MMA-partitioned GMEM view for output (``MMA, MMA_M, MMA_N``) +* ``tCsA_copy_view`` / ``tCrA_copy_view`` = ``ldmatrix`` retile views for SMEM→RMEM + copy (from ``partition_S(sA)`` and ``retile(tCrA)`` on the S2R tiled copy; + C++ equivalents: ``tXsA`` / ``tXrA``) +* ``MMA`` = atom thread-value layout; ``MMA_M/MMA_N/MMA_K`` = repeat counts + (for example ``BM/inst_M``), ``k`` = outer K-tiles, ``PIPE`` = pipeline stages + + +Setting up the TiledMMA, MMA Ops +--------------------------------- + +As shown in the data flow overview, CuTe DSL provides many utilities to tile/partition +the global memory tensors, and create fragment views of SMEM and register tensors for MMA instructions. + +To use these utilities, we first need to set up the MMA op and the TiledMMA. + +Creating a Warp MMA Op +~~~~~~~~~~~~~~~~~~~~~~~ + +A warp MMA op describes the hardware ``mma.sync.aligned`` instruction to use; +its parameters include the data types and the instruction shape. The operand layout is +fixed (A = row-major, B = col-major). + +.. 
code-block:: python + + import cutlass + import cutlass.cute as cute + from cutlass.cute.nvgpu import warp + + op = warp.MmaF16BF16Op( + cutlass.Float16, # A/B element type + cutlass.Float32, # accumulator type + (16, 8, 16), # instruction shape (M, N, K) + ) + +The key parameters are: + +- **Instruction shape** ``(M, N, K)``: determines the size of one hardware MMA + instruction. Valid shapes depend on the data type (see ops table below). +- **A/B element type** (``ab_dtype``) and **accumulator type** (``acc_dtype``): + ``Float32`` is always a valid accumulator; ``Float16`` is only valid for F16 + inputs. Each op restricts ``ab_dtype`` to a specific family (F16/BF16, FP8, + MXF4, etc.). +- **Operand layout**: fixed to A = row-major (K-major), B = col-major (K-major). + Transpose is not supported. All 32 threads in a warp cooperate on each + instruction. + + +CuTe DSL provides implementations of the following warp-level MMA ops: + +.. list-table:: Warp-level MMA ops + :header-rows: 1 + :widths: 34 22 34 10 + + * - PTX name + - Python class + - Constructor parameters + - SM Arch + * - ``mma.sync.aligned.m16n8k{K}.row.col.{acc}.f16.f16`` / ``.bf16.bf16`` + - ``warp.MmaF16BF16Op`` + - ``ab_dtype, acc_dtype, shape_mnk`` + - ``sm_80+`` + * - ``mma.sync.aligned.m16n8k{K}.row.col.{acc}.{e4m3|e5m2}.{e4m3|e5m2}`` + - ``warp.MmaFP8Op`` + - ``ab_dtype, acc_dtype, shape_mnk`` + - ``sm_89+`` + * - ``mma.sync.aligned.kind::mxf4.block_scale.m16n8k64`` + - ``warp.MmaMXF4Op`` + - ``ab_dtype, acc_dtype, sf_type`` + - ``sm_120a+`` + * - ``mma.sync.aligned.kind::mxf4nvf4.block_scale.m16n8k64`` + - ``warp.MmaMXF4NVF4Op`` + - ``ab_dtype, acc_dtype, sf_type`` + - ``sm_120a+`` + +.. {$nv-internal-release begin} + +Internal builds additionally provide: + +.. 
list-table:: Internal warp-level MMA ops + :header-rows: 1 + :widths: 34 22 34 10 + + * - PTX name + - Python class + - Constructor parameters + - SM Arch + * - ``mma.sp.sync.aligned.m16n8k{K}.row.col.{acc}.f16.f16`` / ``.bf16.bf16`` + - ``warp.MmaF16BF16SparseOp`` + - ``ab_dtype, acc_dtype, shape_mnk, sparse_metadata_format`` + - ``sm_80+`` + +.. {$nv-internal-release end} + +Creating a Tiled MMA +~~~~~~~~~~~~~~~~~~~~~ + +A ``TiledMma`` tiles the MMA atom across the thread block so that multiple +warps cooperate on a larger tile. You can pass the op directly or create an +explicit atom first: + +.. code-block:: python + + # Option 1: directly from op (common shorthand) + tiled_mma = cute.make_tiled_mma(op) + + # Option 2: explicit atom creation + atom = cute.make_mma_atom(op) + tiled_mma = cute.make_tiled_mma(atom) + +With no extra arguments this wraps a single atom — one warp, one +``(16, 8, K)`` tile. The optional ``atom_layout_mnk`` and +``permutation_mnk`` parameters (described in the subsections below) +control multi-warp tiling and per-thread value layout respectively. + +Spatial tiling with a repeat count +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +A repeat tuple ``(M_rep, N_rep, K_rep)`` passed as ``atom_layout_mnk`` +replicates the warp MMA atom across the M, N, and K dimensions, producing a +larger tiled MMA that is executed cooperatively by ``M_rep * N_rep * K_rep`` +warps in a single ``cute.gemm`` call. Each replicated atom is executed by +one **warp** (32 threads), so ``(2, 2, 1)`` uses four warps — +a common configuration for warp-specialized SM80/SM89 kernels: + +.. code-block:: python + + atom = cute.make_mma_atom(op) # op shape: (16, 8, 16) + tiled_mma = cute.make_tiled_mma( + atom, + atom_layout_mnk=(2, 2, 1), # 4 warps: 2 in M, 2 in N + ) # total tiled-MMA tile = (32, 16, 16) + +Each atom can be identified by a 3D coordinate ``(m, n, k)``. 
+``m`` is the M repeat index, ``n`` is the N repeat index, and ``k`` is the K +repeat index. Each warp MMA atom is executed by a single warp within a +single CTA. + + +.. code-block:: text + + Warp MMA Atom (16x8x16) make_tiled_mma(atom, (2, 2, 1)) + +----------------+ +----------------+----------------+ + | | | | | ^ + | 16 x 8 | | Atom (0,0,0) | Atom (0,1,0) | | + | x 16 | --(2,2,1)--> | 16 x 8 | 16 x 8 | | 2 x inst_M + | | repeat | x 16 | x 16 | | = 32 + | | | [Warp 0] | [Warp 2] | | + +----------------+ +----------------+----------------+ | + | | | | + | Atom (1,0,0) | Atom (1,1,0) | | + | 16 x 8 | 16 x 8 | | + | x 16 | x 16 | | + | [Warp 1] | [Warp 3] | v + +----------------+----------------+ + <--- 2 x inst_N = 16 ---> + K unchanged = 16 + + +Custom tile permutation with ``permutation_mnk`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +``permutation_mnk`` is an optional third argument to ``make_tiled_mma``. +Each of its three entries is a **per-mode permutation** of the M, N, and +K coordinates inside the tiled MMA. In the common case shown in this +section, each entry is just a size, which is the identity permutation of +that size; in that case ``permutation_mnk`` simply sets the **total tile +footprint** of the tiled MMA along each dimension. When a mode's size is +larger than the atom layout's natural coverage +(``atom_layout x inst_shape``), each thread receives additional values to +fill the extended region — the thread count stays the same, but every +thread holds more data. The general form, where an entry is a +``Layout`` that reorders coordinates inside a mode, is covered in the +subsection below. + +The standard convention for warp MMA (used in ``tensorop_gemm.py`` and +throughout the Ampere examples) doubles the N dimension: + +.. 
code-block:: python + + # From examples/cute/ampere/kernel/dense_gemm/tensorop_gemm.py + permutation_mnk = ( + atom_layout_mnk[0] * mma_inst_shape[0], # M: matches atom coverage + atom_layout_mnk[1] * mma_inst_shape[1] * 2, # N: 2x atom coverage + atom_layout_mnk[2] * mma_inst_shape[2], # K: matches atom coverage + ) + + tC = cute.make_layout(atom_layout_mnk) + tiled_mma = cute.make_tiled_mma( + op, + tC, + permutation_mnk=permutation_mnk, + ) + +**Why double N?** The atom's N dimension is only 8 (inst_N = 8). Without +a permutation, each thread's B-operand values span a single 8-wide +N-range, which may not align well with SMEM load widths. The ``* 2`` +on N gives each thread's B fragment two 8-wide N-ranges instead of one, +aligning the access pattern with wider contiguous SMEM regions for more +efficient loads. + +For ``atom_layout_mnk = (2, 2, 1)`` and ``inst_shape = (16, 8, 16)``: + +- Atom coverage = ``(2x16, 2x8, 1x16) = (32, 16, 16)`` +- ``permutation_mnk = (32, 32, 16)`` — N extended from 16 to 32 + +.. code-block:: text + + Without permutation — natural atom coverage (M = 32, N = 16): + + C tile (M=32, N=16) + +----------------+----------------+ + | | | ^ + | [Warp 0] | [Warp 2] | | + | 16 x 8 | 16 x 8 | | 2 x inst_M + | | | | = 32 + +----------------+----------------+ | + | | | | + | [Warp 1] | [Warp 3] | | + | 16 x 8 | 16 x 8 | | + | | | v + +----------------+----------------+ + <------------- N = 16 ----------> + (each warp owns one (16, 8) atom; + thread T0 of Warp 0 holds 4 C values in its 16x8 block) + + With permutation_mnk = (32, 32, 16) — N extended from 16 to 32: + + C tile (M=32, N=32) + +----------------+----------------+----------------+----------------+ + | | | | | ^ N = 16 → 32: + | [Warp 0] | [Warp 2] | [Warp 0] | [Warp 2] | | atom pattern repeats + | 16 x 8 | 16 x 8 | 16 x 8 | 16 x 8 | | along N. 
Each thread + | | | | | | now holds 2x the + +----------------+----------------+----------------+----------------+ | values along N + | | | | | | (same threads, more + | [Warp 1] | [Warp 3] | [Warp 1] | [Warp 3] | | values per thread). + | 16 x 8 | 16 x 8 | 16 x 8 | 16 x 8 | | + | | | | | v + +----------------+----------------+----------------+----------------+ + <---------------------------- N = 32 ----------------------------> + | atom coverage | value repeat | + + +Reordering coordinates with a per-mode ``Layout`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +So far each entry of ``permutation_mnk`` has been an integer, which is +shorthand for the identity layout ``Layout<Shape<S>, Stride<_1>>`` — the +atom pattern simply tiles to fill an ``S``-wide footprint. The general +form lets each entry be a ``Layout`` that **reorders coordinates inside +that mode** while keeping the same total size. That reordering is what +gives the parameter its name; the integer-only cases used earlier are +just the identity permutation. + +The canonical illustration is the SM70 example from +`0t_mma_atom.md <../../cpp/cute/0t_mma_atom.md>`_. Take a 2x2 tiled MMA +of ``SM70_8x8x4_F32F16F16F32_NT`` atoms with a ``32x32x4`` footprint. +Without any M-mode permutation, thread ``T0``'s 8 A-values land at the +following ``(m, k)`` coordinates:: + + T0V0 => (0, 0) T0V4 => (16, 0) + T0V1 => (1, 0) T0V5 => (17, 0) + T0V2 => (2, 0) T0V6 => (18, 0) + T0V3 => (3, 0) T0V7 => (19, 0) + +— two separate runs of 4 along M, with a gap from m=4 to m=15. We may +prefer those 8 values to sit in **one contiguous run** in the logical +M-coordinates (e.g. so register or SMEM layouts pack cleanly). Passing +the M-mode layout ``(4, 4, 2):(1, 8, 4)`` does exactly that: it is a +scatter permutation telling each old m-coord where to go in the new +image. + +.. 
code-block:: text + + old m-coord: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 + new m-coord: 0 1 2 3 8 9 10 11 16 17 18 19 24 25 26 27 4 5 6 7 12 13 14 15 20 21 22 23 28 29 30 31 + +After the permutation, ``T0``'s 8 A-values occupy ``m = 0..7`` — one +contiguous run — and every other thread's M-values become equally +contiguous. Thread-data ownership and value counts are unchanged; only +the **mapping from values to m-coordinates** is permuted. + +In CuTeDSL the permuted entry is built with ``cute.make_layout``; +identity entries stay as integers: + +.. code-block:: python + + m_perm = cute.make_layout((4, 4, 2), stride=(1, 8, 4)) + tiled_mma = cute.make_tiled_mma( + op, # SM70_8x8x4 NT atom + atom_layout_mnk=(2, 2, 1), + permutation_mnk=(m_perm, 32, 4), # M: scatter, N/K: identity sizes + ) + +The same mechanism applies to the N and K modes — any subset of the +three entries can be an integer (identity) or a ``Layout`` (real +permutation). For warp MMAs the most common case in practice is still +the integer-only form shown earlier in this section; the ``Layout`` form +is the tool you reach for when a register or SMEM layout wants each +thread's fragment to be contiguous in logical coordinates. + + +Partitioning Tensors +--------------------- + +Before computing, partition the CTA-tiled tensors according to the +tiled MMA layout. Warp MMA partitioning is **per-thread**: each of +the 32 threads in a warp (or 128 threads across 4 warps) receives +its own slice of the data, sized to match the register fragments +the MMA instruction expects. + +Example: ``GEMM (M, N, K) = (512, 512, 256)``, +``cta_tiler = (128, 128, 32)``, ``atom_layout_mnk = (2, 2, 1)``, +F16 atom = m16n8k16, ``permutation_mnk = (32, 32, 16)``, +``num_stages = 4``, 4 warps = 128 threads. + +Global matrices: + +.. 
code-block:: text + + mA: (M, K) = (512, 256) mB: (N, K) = (512, 256) mC: (M, N) = (512, 512) + + K=256 K=256 N=512 + |<--------->| |<--------->| |<---------------->| + +-----------+ +-----------+ +----+----+----+---+ + | | ^ | | ^ | | | | | ^ + | mA | | M=512 | mB | | N=512 | | | | | | M=512 + | | v | | v | | | | | v + +-----------+ +-----------+ +----+----+----+---+ + +Tiling with ``cta_tiler = (BM, BN, BK) = (128, 128, 32)`` gives +M/BM = 4 tiles, N/BN = 4 tiles, K/BK = 8 tiles: + +.. code-block:: text + + mA tiled into (M/BM x K/BK) mB tiled into (N/BN x K/BK) mC tiled into (M/BM x N/BN) + = (4 x 8) blocks = (4 x 8) blocks = (4 x 4) blocks + + BK=32 x8 BK=32 x8 BN=128 x4 + |<-->| |<-->| |<------>| + +----+----+-- --+ +----+----+-- --+ +--------+--------+-- --+ + | | |..| | ^ BM=128 | | |..| | ^ BN=128 | (0,0) | (0,1) |.. | ^ BM=128 + +----+----+-- --+ v +----+----+-- --+ v +--------+--------+ + v + | | |..| | ^ BM=128 | | |..| | ^ BN=128 | (1,0) | (1,1) |.. | ^ BM=128 + +----+----+-- --+ v +----+----+-- --+ v +--------+--------+ + v + | | |..| | ^ | | |..| | ^ | ... | ... |.. | ^ + +----+----+-- --+ v +----+----+-- --+ v +--------+--------+-- --+ v + | | |..| | ^ | | |..| | ^ | (3,0) | (3,1) |.. | ^ + +----+----+-- --+ v +----+----+-- --+ v +--------+--------+-- --+ v + +Each CTA picks one (M-tile, N-tile) coordinate. +For example, CTA at ``tiler_coord = (0, 1, :)``. + +After ``local_tile`` — one CTA's tile (``k = K/BK = 256/32 = 8``): + +.. code-block:: text + + gA: (BM, BK, k) = (128, 32, 8) gB: (BN, BK, k) = (128, 32, 8) gC: (BM, BN) = (128, 128) + + BK=32 BK=32 BN=128 + |<----->| |<----->| |<-------->| + +-------+-- +-------+-- +----------+ + | |.. | |.. | | ^ + BM= | gA | k=8 BN= | gB | k=8 BM= | gC | | 128 + 128 | | 128 | | 128 | | v + +-------+ +-------+ +----------+ + +SMEM tensors ``sA`` and ``sB`` have a pipeline staging dimension: + +.. 
code-block:: text + + sA: (BM, BK, PIPE) = (128, 32, 4) sB: (BN, BK, PIPE) = (128, 32, 4) + +``get_slice(tidx)`` — each thread receives its own per-thread partition. +The tiled MMA footprint is ``permutation_mnk = (32, 32, 16)``, so BM, +BN, and BK are each subdivided into MMA-sized blocks: + +.. code-block:: text + + sA: partition into (MMA, MMA_M, MMA_K, PIPE) + + Each SMEM stage (BM=128, BK=32): + + perm_K perm_K perm_M=32 + =16 =16 |<---->| + |<--->|<--->| +------+------+------+------+ + +-----+-----+ ^ | | | | | ^ + | 0 | 1 | | perm_M=32 | 0 | 1 | 2 | 3 | | perm_N + +-----+-----+ v | | | | | v =32 + | 0 | 1 | ^ +------+------+------+------+ + | | | | perm_M=32 MMA_N = BN/perm_N = 4 + +-----+-----+ v + | 0 | 1 | ^ sB: partition into (MMA, MMA_N, MMA_K, PIPE) + | | | | + +-----+-----+ v gC: partition into (MMA, MMA_M, MMA_N) + | 0 | 1 | ^ + | | | | + +-----+-----+ v + MMA_K = BK/perm_K = 2 + MMA_M = BM/perm_M = 4 + +After partition (per thread, e.g. thread ``tidx``): + +- ``tCsA: (MMA, MMA_M, MMA_K, PIPE) = (MMA, 4, 2, 4)`` — MMA_M = BM/perm_M = 128/32 = 4, MMA_K = BK/perm_K = 32/16 = 2 +- ``tCsB: (MMA, MMA_N, MMA_K, PIPE) = (MMA, 4, 2, 4)`` — MMA_N = BN/perm_N = 128/32 = 4, MMA_K = BK/perm_K = 32/16 = 2 +- ``tCgC: (MMA, MMA_M, MMA_N) = (MMA, 4, 4)`` — MMA_M = 128/32 = 4, MMA_N = 128/32 = 4 + +The first mode ``MMA`` contains the atom's **thread × value** layout — it +encodes which registers within a single thread hold which matrix +elements. The remaining modes are repeat counts that tile the atom +across the full CTA tile. + +.. 
code-block:: python + + @cute.kernel + def kernel(tiled_mma: cute.TiledMma, ...): + tidx, _, _ = cute.arch.thread_idx() + + # CTA-tiled global tensors + gA = cute.local_tile(mA, cta_tiler, tiler_coord, proj=(1, None, 1)) + gB = cute.local_tile(mB, cta_tiler, tiler_coord, proj=(None, 1, 1)) + gC = cute.local_tile(mC, cta_tiler, tiler_coord, proj=(1, 1, None)) + + # Per-thread partition via the thread index + thr_mma = tiled_mma.get_slice(tidx) + + # SMEM partitions (used by make_fragment_A/B and ldmatrix retiling) + tCsA = thr_mma.partition_A(sA) # (MMA, MMA_M, MMA_K, PIPE) + tCsB = thr_mma.partition_B(sB) # (MMA, MMA_N, MMA_K, PIPE) + + # C partitions for epilogue staging (SMEM) and destination (GMEM) + tCsC = thr_mma.partition_C(sC) # (MMA, MMA_M, MMA_N) + tCgC = thr_mma.partition_C(gC) # (MMA, MMA_M, MMA_N) + +.. note:: The ``tCsA`` / ``tCsB`` SMEM partitions are not read directly + by the GEMM — they establish the **shape** that + ``make_fragment_A`` / ``make_fragment_B`` use to allocate register + fragments. Actual SMEM→RMEM data movement goes through the S2R + ``ldmatrix`` retiling path (see `Making Fragments`_). + +Pre and Post-Conditions for Partitioning +----------------------------------------- + +* The inputs of the partition should be at least rank-2 tensors. +* The output of the partition will have the layout that is compatible with the MMA atom's operand: + + - For A, the output will have the layout ``(MMA, MMA_M, MMA_K, ...)``. + - For B, the output will have the layout ``(MMA, MMA_N, MMA_K, ...)``. + - For C, the output will have the layout ``(MMA, MMA_M, MMA_N, ...)``. + +* Note that the partition doesn't enforce any rules on the tensor's memory space or the tensor's data type. It only cares about the layout. + + +Making Fragments +----------------- + +Fragments are the tensors that the warp MMA instruction operates on. For +warp MMA: + +- **Fragment A**: per-thread register fragment holding one operand-A K-block. 
+- **Fragment B**: per-thread register fragment holding one operand-B K-block. +- **Fragment C (accumulator)**: per-thread register fragment that lives in + RMEM and serves as both the input C and output D of ``cute.gemm()``. + +Creating register fragments and ``ldmatrix`` copy views +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Warp MMA fragments are actual per-thread register tensors, not descriptors. +Fragment creation has three parts: + +**1. A and B fragments** + +``make_fragment_A`` and ``make_fragment_B`` take one stage of the +MMA-partitioned SMEM views (``tCsA`` / ``tCsB``) and allocate register +fragments with a matching thread-local layout. This establishes the shape +only; no data is loaded yet. + +.. code-block:: python + + # Per-thread MMA partitions + # (sA/sB are the staged SMEM tensors — see "Creating SMEM layouts for A and B") + tCsA = thr_mma.partition_A(sA) # (MMA, MMA_M, MMA_K, PIPE) + tCsB = thr_mma.partition_B(sB) # (MMA, MMA_N, MMA_K, PIPE) + + # Register fragments for one pipeline stage + tCrA = tiled_mma.make_fragment_A( + tCsA[None, None, None, 0] + ) # (MMA, MMA_M, MMA_K) + tCrB = tiled_mma.make_fragment_B( + tCsB[None, None, None, 0] + ) # (MMA, MMA_N, MMA_K) + +Continuing the running example from `Partitioning Tensors`_ (F16 +``m16n8k16``, ``cta_tiler = (128, 128, 32)``, ``permutation_mnk = (32, 32, +16)``, ``num_stages = 4``): + +.. code-block:: text + + tCsA: (MMA, MMA_M=4, MMA_K=2, PIPE=4) + tCsB: (MMA, MMA_N=4, MMA_K=2, PIPE=4) + + make_fragment_A(tCsA[..., stage]) -> tCrA: (MMA, 4, 2) + make_fragment_B(tCsB[..., stage]) -> tCrB: (MMA, 4, 2) + +Each element of ``tCrA`` / ``tCrB`` is a register value owned by the current +thread. Together, the 32 threads in the warp hold the full operand fragment +that one ``mma.sync.aligned`` instruction consumes. + +**2. C fragment (accumulator)** + +``make_fragment_C`` allocates the accumulator registers for the CTA tile +slice owned by the current thread. 
The accumulator usually starts at zero +before the K loop and is updated in-place by each ``cute.gemm()`` call. + +.. code-block:: python + + tCgC = thr_mma.partition_C(gC) # (MMA, MMA_M, MMA_N) + tCrC = tiled_mma.make_fragment_C(tCgC) + tCrC.fill(0.0) + +For the same running example: + +.. code-block:: text + + tCgC: (MMA, MMA_M=4, MMA_N=4) + make_fragment_C(tCgC) -> tCrC: (MMA, 4, 4) + +``tCrC`` stays in registers for the entire main loop and serves as both the +input C and output D argument of ``cute.gemm()``. + +**3. SMEM → RMEM load (``ldmatrix`` retiling)** + +The register fragments above are storage only — before ``cute.gemm()`` can +consume ``tCrA`` and ``tCrB``, each K-block must be loaded from shared +memory into those registers. This is done via a separate tiled copy built +from an ``ldmatrix`` copy atom and linked to the tiled MMA with +``make_tiled_copy_A`` / ``make_tiled_copy_B``. The copy's ``retile()`` +call remaps the MMA fragment's register layout to match what the +``ldmatrix`` instruction writes. + +.. code-block:: python + + # 1. Create ldmatrix copy atom → tiled copy tied to the MMA layout + s2r_atom_A = cute.make_copy_atom(LdMatrix8x8x16bOp(...), dtype) + s2r_tiled_A = cute.make_tiled_copy_A(s2r_atom_A, tiled_mma) + + # 2. Build SMEM-side and RMEM-side views for the copy + thr_s2r_A = s2r_tiled_A.get_slice(tidx) + tCsA_copy_view = thr_s2r_A.partition_S(sA) # SMEM source + tCrA_copy_view = thr_s2r_A.retile(tCrA) # RMEM dest (retiled) + + # 3. Load one k-block from SMEM into the MMA fragment (in the main loop) + cute.copy(s2r_tiled_A, tCsA_copy_view[None, None, k_block], + tCrA_copy_view[None, None, k_block]) + +See ``tensorop_gemm.py`` for the complete implementation including the +``ldmatrix`` transpose flag, FP8 variants, and operand B. + + +Creating SMEM layouts for A and B +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The SMEM layouts define how A and B tiles are staged in shared memory before +the ``ldmatrix`` loads. 
For warp MMA, these layouts must satisfy two goals at +the same time: + +- **Efficient GMEM -> SMEM copy:** ``cp.async`` should write contiguous + 16-byte regions for each thread. +- **Bank-conflict-free SMEM -> RMEM load:** the later ``ldmatrix`` loads + should see a swizzled layout that matches the warp MMA operand access + pattern. + +The Ampere dense GEMM example +(``examples/cute/ampere/kernel/dense_gemm/tensorop_gemm.py``) builds these +layouts with a helper method named ``_make_smem_layout_AB``. + +**Host side** (``@cute.jit``): + +.. code-block:: python + + # 16 bytes per thread for GMEM -> SMEM copies + ab_copy_bits = 128 + + sA_layout, sA_swizzle = self._make_smem_layout_AB( + mA.element_type, # dtype (e.g. Float16) + self.a_major_mode, # row-major or col-major + ab_copy_bits, # copy width in bits (128 = 16 bytes) + (self.cta_tiler[0], # BM + self.cta_tiler[2], # BK + self.num_stages), # PIPE + ) + sB_layout, sB_swizzle = self._make_smem_layout_AB( + mB.element_type, + self.b_major_mode, + ab_copy_bits, + (self.cta_tiler[1], # BN + self.cta_tiler[2], # BK + self.num_stages), # PIPE + ) + +Here ``smem_tiler`` — the tuple passed as the last argument in each call +above — is ``(M_or_N, K, PIPE)``: ``(BM, BK, PIPE)`` for A and +``(BN, BK, PIPE)`` for B. The helper returns: + +- ``sX_layout``: the logical SMEM layout with shape ``(BM_or_BN, BK, PIPE)``. +- ``sX_swizzle``: the swizzle applied when the tensor is materialized in SMEM. + +The helper from ``tensorop_gemm.py`` implements the following four steps: + +1. **Pick the major-mode size.** For a row-major operand, the contiguous + dimension is K, so the helper uses ``smem_tiler[1]``. For a col-major + operand, the contiguous dimension is M or N, so it uses ``smem_tiler[0]``. + +2. **Cap the contiguous span at 128 bytes.** This keeps the layout atom within + the swizzle span used by the example. The cap is 64 elements for F16/BF16 + and 128 elements for FP8. + +3.
**Build the swizzle.** With ``copy_bits = 128`` (16 bytes), the helper + derives three arguments for ``make_swizzle``: + + - ``swizzle_bits = log2(major_mode_size * dtype.width / copy_bits)``, + capped at 3. This is the number of address bits that get XOR'd. + - ``base_bits = log2(copy_bits / 8)`` — log2 of the copy width in + bytes (= 4 for 16-byte copies). + - ``shift_bits = log2(copy_bits / dtype.width)`` — log2 of the copy + width in elements (= 3 for F16 with 128-bit copies, i.e. 8 elements). + +4. **Build an 8-row layout atom and tile it.** The constant 8 comes from + ``ldmatrix``: each warp-level load touches 8 rows of shared memory + (32 threads, 4 matrices per load). Row-major uses an atom + ``(8, major_mode_size):(major_mode_size, 1)`` — 8 rows of contiguous + K-elements. Col-major uses + ``(major_mode_size, 8):(1, major_mode_size)`` — contiguous MN-elements + across 8 K-rows. ``tile_to_shape`` then broadcasts that atom across the + full ``(M_or_N, K, PIPE)`` SMEM tensor. + +For the running F16 example (``cta_tiler = (128, 128, 32)``, +``num_stages = 4``, ``copy_bits = 128``): + +.. code-block:: text + + A operand (row-major, smem_tiler = (128, 32, 4)): + major_mode_size = 32 + atom = (8, 32):(32, 1) + swizzle = make_swizzle(2, 4, 3) + tiled layout -> sA: (128, 32, 4) + + B operand (col-major, smem_tiler = (128, 32, 4)): + major_mode_size = min(128, 64) = 64 + atom = (64, 8):(1, 64) + swizzle = make_swizzle(3, 4, 3) + tiled layout -> sB: (128, 32, 4) + + +**Kernel side** (``@cute.kernel``): + +The layout and swizzle are passed to shared-memory allocation: + +.. 
code-block:: python + + @cute.struct + class SharedStorageAB: + a: cute.struct.Align[ + cute.struct.MemRange[mA.element_type, cute.cosize(sA_layout)], + 16, + ] + b: cute.struct.Align[ + cute.struct.MemRange[mB.element_type, cute.cosize(sB_layout)], + 16, + ] + + sA = SharedStorageAB(storage).a.get_tensor(sA_layout, swizzle=sA_swizzle) + sB = SharedStorageAB(storage).b.get_tensor(sB_layout, swizzle=sB_swizzle) + +After allocation: + +- ``sA`` has shape ``(BM, BK, PIPE)``. +- ``sB`` has shape ``(BN, BK, PIPE)``. + +These are the staged SMEM tensors written by ``cp.async`` and later consumed by +``partition_A`` / ``partition_B``, ``make_fragment_A`` / ``make_fragment_B``, +and the ``ldmatrix`` copy views described in `Making Fragments`_. + + +Executing the GEMM (Main Loop) +------------------------------- + +The main loop iterates over K-tiles and, within each tile, over k-blocks +(``num_k_block = BK / perm_K``). Each k-block loads A and B from SMEM into +registers via ``ldmatrix``, then issues ``cute.gemm``. + +.. code-block:: python + + tCrC.fill(0.0) + + for k_tile in range(k_tile_count): + for k_block in cutlass.range(num_k_block, unroll_full=True): + # Wait for next SMEM stage at the tile boundary + if k_block == num_k_block - 1: + cute.arch.cp_async_wait_group(num_smem_stages - 2) + cute.arch.sync_threads() + + # ldmatrix: prefetch next k-block from SMEM → RMEM + k_block_next = (k_block + 1) % num_k_block + cute.copy(tiled_copy_s2r_A, tCsA_p[None, None, k_block_next], + tCrA_copy_view[None, None, k_block_next]) + cute.copy(tiled_copy_s2r_B, tCsB_p[None, None, k_block_next], + tCrB_copy_view[None, None, k_block_next]) + + # cp.async: issue GMEM → SMEM for next K-tile + # ... 
(see tensorop_gemm.py for pipeline pointer management) + + # MMA: tCrC += tCrA * tCrB + cute.gemm(tiled_mma, tCrC, tCrA[None, None, k_block], tCrB[None, None, k_block], tCrC) + + cute.arch.cp_async_wait_group(0) + cute.arch.sync_threads() + +Key points: + +- ``cute.gemm`` is **synchronous** — it emits ``mma.sync.aligned`` + instructions. There is no accumulate-mode flag; the accumulator + (``tCrC``) is always read and written. +- All operands must be in **registers** before ``cute.gemm`` is called. + The ``ldmatrix`` copies above prefetch the next k-block into + ``tCrA`` / ``tCrB`` from SMEM each iteration. +- The ``cp.async`` / ``cp_async_wait_group`` calls manage the GMEM→SMEM + pipeline; see ``tensorop_gemm.py`` for predication, K-residue handling, + and pipeline pointer management. + + +Complete Workflow +------------------ + +Putting it all together, a typical Ampere warp MMA GEMM has this structure: + +**Host function** (``@cute.jit``): + +.. code-block:: python + + import cutlass + import cutlass.cute as cute + + @cute.jit + def host_function(mA: cute.Tensor, mB: cute.Tensor, mC: cute.Tensor, stream): + # 1. Create the MMA op and tiled MMA + op = cute.nvgpu.warp.MmaF16BF16Op(cutlass.Float16, cutlass.Float32, (16, 8, 16)) + atom_layout_mnk = (2, 2, 1) + permutation_mnk = ( + atom_layout_mnk[0] * 16, + atom_layout_mnk[1] * 8 * 2, + atom_layout_mnk[2] * 16, + ) + tC = cute.make_layout(atom_layout_mnk) + tiled_mma = cute.make_tiled_mma(op, tC, permutation_mnk=permutation_mnk) + + # 2. Create SMEM layouts + ab_copy_bits = 128 + sA_layout, sA_swizzle = _make_smem_layout_AB( + mA.element_type, a_major_mode, ab_copy_bits, + (cta_tiler[0], cta_tiler[2], num_stages), + ) + sB_layout, sB_swizzle = _make_smem_layout_AB( + mB.element_type, b_major_mode, ab_copy_bits, + (cta_tiler[1], cta_tiler[2], num_stages), + ) + + # 3. 
Launch the kernel + kernel(mA, mB, mC, ..., tiled_mma, sA_layout, sA_swizzle, + sB_layout, sB_swizzle).launch( + grid=grid, block=[128, 1, 1], stream=stream, + ) + +**Kernel function** (``@cute.kernel``): + +.. code-block:: python + + @cute.kernel + def kernel(mA: cute.Tensor, mB: cute.Tensor, mC: cute.Tensor, + ..., tiled_mma: cute.TiledMma): + tidx, _, _ = cute.arch.thread_idx() + bidx, bidy, bidz = cute.arch.block_idx() + + # -- CTA-tiled global tensors -- + gA = cute.local_tile(mA[None, None, bidz], cta_tiler, (bidx, bidy, None), proj=(1, None, 1)) + gB = cute.local_tile(mB[None, None, bidz], cta_tiler, (bidx, bidy, None), proj=(None, 1, 1)) + gC = cute.local_tile(mC[None, None, bidz], cta_tiler, (bidx, bidy, None), proj=(1, 1, None)) + + # -- Allocate SMEM -- + @cute.struct + class SharedStorageAB: + a: cute.struct.Align[cute.struct.MemRange[mA.element_type, cute.cosize(sA_layout)], 16] + b: cute.struct.Align[cute.struct.MemRange[mB.element_type, cute.cosize(sB_layout)], 16] + + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(SharedStorageAB) + sA = SharedStorageAB(storage).a.get_tensor(sA_layout, swizzle=sA_swizzle) # (BM, BK, PIPE) + sB = SharedStorageAB(storage).b.get_tensor(sB_layout, swizzle=sB_swizzle) # (BN, BK, PIPE) + sC = ... # (BM, BN) SMEM for epilogue (non-MMA, see tensorop_gemm.py) + + # -- GMEM → SMEM copy partitions (cp.async) -- + # ... 
setup tAgA, tAsA, tBgB, tBsB (see tensorop_gemm.py) + + # -- MMA partitions and fragments -- + thr_mma = tiled_mma.get_slice(tidx) + tCsA = thr_mma.partition_A(sA) # (MMA, MMA_M, MMA_K, PIPE) + tCsB = thr_mma.partition_B(sB) # (MMA, MMA_N, MMA_K, PIPE) + tCsC = thr_mma.partition_C(sC) # (MMA, MMA_M, MMA_N) + tCgC = thr_mma.partition_C(gC) # (MMA, MMA_M, MMA_N) + tCrA = tiled_mma.make_fragment_A(tCsA[None, None, None, 0]) # (MMA, MMA_M, MMA_K) + tCrB = tiled_mma.make_fragment_B(tCsB[None, None, None, 0]) # (MMA, MMA_N, MMA_K) + tCrC = tiled_mma.make_fragment_C(tCgC) # (MMA, MMA_M, MMA_N) + tCrC.fill(0.0) + + # -- ldmatrix retiling (see "Making Fragments" § SMEM → RMEM load) -- + # ... build tiled_copy_s2r_A/B from LdMatrix8x8x16bOp + make_tiled_copy_A/B + # ... then: tCsA_copy_view = partition_S(sA), tCrA_copy_view = retile(tCrA), etc. + + # -- Prologue: cp.async fills num_stages-1 SMEM buffers -- + # -- Prefetch first k-block into registers via ldmatrix -- + # ... (see tensorop_gemm.py for predication, residual_k, and pipeline setup) + + # -- Main loop -- + for k_tile in range(k_tile_count): + for k_block in cutlass.range(num_k_block, unroll_full=True): + if k_block == num_k_block - 1: + cute.arch.cp_async_wait_group(num_smem_stages - 2) + cute.arch.sync_threads() + + # ldmatrix: prefetch next k-block from SMEM → RMEM + # tCsA_p / tCsB_p are per-pipeline-stage slices, e.g.: + # tCsA_p = tCsA_copy_view[None, None, None, smem_pipe_read] + k_block_next = (k_block + 1) % num_k_block + cute.copy(tiled_copy_s2r_A, tCsA_p[None, None, k_block_next], + tCrA_copy_view[None, None, k_block_next]) + cute.copy(tiled_copy_s2r_B, tCsB_p[None, None, k_block_next], + tCrB_copy_view[None, None, k_block_next]) + + # cp.async: issue GMEM → SMEM for next K-tile + # ... 
(see tensorop_gemm.py for pipeline pointer management) + + # MMA + cute.gemm(tiled_mma, tCrC, tCrA[None, None, k_block], + tCrB[None, None, k_block], tCrC) + + # -- Epilogue: RMEM → SMEM → RMEM → GMEM -- + cute.arch.cp_async_wait_group(0) + cute.arch.sync_threads() + tCrD = cute.make_fragment_like(tCrC, c_dtype) + tCrD[None] = epilogue_op(tCrC.load()).to(c_dtype) + cute.autovec_copy(tCrD, tCsC) # RMEM → SMEM + cute.arch.sync_threads() + # ... reload with epilogue thread layout, then vectorized store to GMEM + + +Beyond Simple Dense MMAs +------------------------ + +The warp MMA DSL supports more complex MMA operations beyond simple dense MMA: + +- Block-scaled MMA + +.. {$nv-internal-release begin} + +Internal builds additionally provide: + +- Sparse MMA + +.. {$nv-internal-release end} + +.. {$nv-internal-release begin} + +Sparse MMA +~~~~~~~~~~ + +Sparse MMA exploits **2:4 structured sparsity** in operand A: out of every +4 consecutive K-elements, exactly 2 are non-zero. The hardware consumes a +compressed A operand together with a compact **metadata** tensor ``E`` that +encodes which 2 of 4 positions are non-zero. + +Compared to dense MMA, the MMA API differences are: + +**1. MMA op creation** — use ``MmaF16BF16SparseOp`` with an extra +``sparse_metadata_format`` parameter. The sparse instruction K is doubled +relative to dense (dense ``m16n8k8`` → sparse ``m16n8k16``, dense +``m16n8k16`` → sparse ``m16n8k32``) because operand A is 2:4 compressed: + +.. 
code-block:: python + + from cutlass.cute.nvgpu.warp.mma import SparseMetadataFormat + + # Dense F16 (for comparison): inst_K = 16 + dense_op = cute.nvgpu.warp.MmaF16BF16Op( + cutlass.Float16, cutlass.Float32, (16, 8, 16), + ) + + # Sparse F16: inst_K = 32 (2× dense, since A is 2:4 compressed) + sparse_op = cute.nvgpu.warp.MmaF16BF16SparseOp( + cutlass.Float16, # A/B element type + cutlass.Float32, # accumulator type + (16, 8, 32), # instruction shape (M, N, K) + SparseMetadataFormat.TID, # metadata format + ) + tiled_mma = cute.make_tiled_mma(sparse_op, cute.make_layout((1, 1, 1))) + +.. code-block:: text + + Supported instruction shapes for MmaF16BF16SparseOp: + + | A/B Type | Acc Type | Inst Shape | + |----------|-----------|----------------| + | F16 | F16, F32 | (16,8,16), (16,8,32) | + | BF16 | F32 | (16,8,16), (16,8,32) | + +**2. Compressed A tensor and metadata E** — operand A stores only the +two non-zero values per group of 4 K-elements (half the storage). The +metadata tensor ``E`` records which 2 of 4 positions are non-zero. The +exact bit encoding depends on ``SparseMetadataFormat`` and on how the +implementation packs metadata. In this repository, helper code that +generates 2:4 test inputs packs two 4-bit metadata entries into each +``uint8`` value: + +.. code-block:: python + + # Example metadata values used by examples/CuTeDSL/helpers/sparse_utils.py + # Each nibble selects which 2 of 4 positions are non-zero. + metadata_values = [0x4, 0x8, 0x9, 0xC, 0xD, 0xE] + +.. code-block:: text + + Dense A: (M, K) Sparse operands: + +--+--+--+--+--+--+--+--+ +--+--+--+--+ + | a| 0| b| 0| c| 0| d| 0| → | a| b| c| d| (compressed A values) + +--+--+--+--+--+--+--+--+ +--+--+--+--+ + + E stores the non-zero positions + for each 2:4 group. + +**3. Fragments** — the dense-style fragment APIs for A, B, and C still +apply to the sparse atom: + +.. 
code-block:: python + + # A/B/C fragments — same public API shape as dense + tCsA = thr_mma.partition_A(sA) + tCsB = thr_mma.partition_B(sB) + tCgC = thr_mma.partition_C(gC) + + tCrA = tiled_mma.make_fragment_A(tCsA[None, None, None, 0]) + tCrB = tiled_mma.make_fragment_B(tCsB[None, None, None, 0]) + tCrC = tiled_mma.make_fragment_C(tCgC) + tCrC.fill(0.0) + +Sparse metadata ``E`` is an auxiliary operand associated with A. The +public warp API and tests in this repository verify op construction and +the ``cute.gemm(..., [A, E], B, ...)`` calling convention, but they do +not provide an end-to-end warp sparse kernel showing the exact +``partition`` / ``copy`` / ``make_fragment`` sequence for ``E``. For +that reason, this document intentionally does not spell out an ``E`` +fragment construction sequence that has no example backing it. + +**4. Modified gemm call** — the metadata E is passed alongside operand A +as a list. This part of the API is verified by ``cutlass.cute.algorithm.gemm``: + +.. code-block:: python + + # Schematic only: E_k is the metadata operand for the same k-slice as A_k. + A_k = tCrA[None, None, k_block] + E_k = metadata_k + B_k = tCrB[None, None, k_block] + + cute.gemm( + tiled_mma, + tCrC, + [A_k, E_k], # [A, E] + B_k, + tCrC, + ) + +.. code-block:: text + + Dense gemm call: + cute.gemm(tiled_mma, tCrC, A_k, B_k, tCrC) + + Sparse gemm call: + cute.gemm(tiled_mma, tCrC, [A_k, E_k], B_k, tCrC) + ^^^^ ^^^ + A metadata + +The epilogue (RMEM → SMEM → GMEM) is identical to a dense kernel. + +.. note:: An end-to-end warp sparse GEMM example is not yet available in the + examples directory. The closest verified references in this repository are + ``cutlass_ir/compiler/test/python/not_pytest/sm_80/test_mma_atom.py`` for + op construction, ``cutlass_ir/compiler/test/python/api/sm_120a/test_nvgpu_warp_mma.py`` + for tiled sparse MMA construction, and + ``examples/CuTeDSL/helpers/sparse_utils.py`` for + 2:4 metadata packing. + +.. 
{$nv-internal-release end} + + +Block-scaled MMA +~~~~~~~~~~~~~~~~ + +Block-scaled MMA multiplies narrow-type matrices (FP4) while applying +**per-block scale factors** along the GEMM-K dimension. Each vector of +``sf_vec_size`` consecutive K-elements shares a single scale factor, so the +hardware computes ``D = (SFA · A) * (SFB · B) + C``. The scale factors live +in **registers** alongside the operands and must be loaded from SMEM before +each ``gemm`` call. + +Supported ops: ``MmaMXF4Op`` (SM120a+), ``MmaMXF4NVF4Op`` (SM120a+). + +Compared to a dense MMA kernel, a block-scaled kernel has four additional concerns: + +**1. MMA op creation** — block-scaled ops fix the data type to FP4 +(E2M1) and the accumulator to FP32. The scale-factor type and vector +size distinguish the two ops: + +.. code-block:: python + + # MXF4: UE8M0 scales, sf_vec_size = 32 + op = cute.nvgpu.warp.MmaMXF4Op( + cutlass.Float4E2M1FN, # A/B element type (fixed: E2M1) + cutlass.Float32, # accumulator type (fixed: F32) + cutlass.Float8E8M0FNU, # scale-factor type + ) # instruction shape = (16, 8, 64), sf_vec_size = 32 + + # MXF4NVF4: UE4M3 scales, sf_vec_size = 16 + op = cute.nvgpu.warp.MmaMXF4NVF4Op( + cutlass.Float4E2M1FN, # A/B element type (fixed: E2M1) + cutlass.Float32, # accumulator type (fixed: F32) + cutlass.Float8E4M3FN, # scale-factor type + ) # instruction shape = (16, 8, 64), sf_vec_size = 16 + +.. code-block:: text + + | Op | A/B Type | SF Type | Acc | Inst Shape | SF Vec Size | + |---------------|----------|---------|------|-------------|-------------| + | MmaMXF4Op | E2M1 | UE8M0 | F32 | (16,8,64) | 32 | + | MmaMXF4NVF4Op | E2M1 | UE4M3 | F32 | (16,8,64) | 16 | + +**2. Extra global tensors and SMEM layouts for scale factors** — the host +function creates SFA/SFB tensors and allocates SMEM layouts for them +alongside A and B: + +.. 
code-block:: python + + import cutlass.utils.blockscaled_layout as blockscaled_utils + import cutlass.utils.blackwell_helpers as sm120_utils + + # Scale-factor global tensors (host side) + sfa_layout = blockscaled_utils.tile_atom_to_shape_SF(a.shape, sf_vec_size) + sfa_tensor = cute.make_tensor(sfa.iterator, sfa_layout) + sfb_layout = blockscaled_utils.tile_atom_to_shape_SF(b.shape, sf_vec_size) + sfb_tensor = cute.make_tensor(sfb.iterator, sfb_layout) + + # SMEM layouts for scale factors (SM120-specific helper) + sfa_smem_layout = blockscaled_utils.sm120_make_smem_layout_sfa( + tiled_mma, tile_shape_mnk, sf_vec_size, num_stages, + ) + sfb_smem_layout = blockscaled_utils.sm120_make_smem_layout_sfb( + tiled_mma, tile_shape_mnk, sf_vec_size, num_stages, + ) + +**3. SF fragment creation and SMEM→RMEM retiling** — scale-factor +fragments use a ``CopyUniversalOp`` with thread-value layouts derived +from the tiled MMA, rather than the ``ldmatrix``-based path used for +A and B: + +.. code-block:: python + + # A/B fragments (same as dense) + tCrA = tiled_mma.make_fragment_A(tCsA[None, None, None, 0]) + tCrB = tiled_mma.make_fragment_B(tCsB[None, None, None, 0]) + + # SF fragments (SM120-specific partition helpers) + tCrSFA = sm120_utils.partition_fragment_SFA(sSFA[None, None, 0], thr_mma, tidx) + tCrSFB = sm120_utils.partition_fragment_SFB(sSFB[None, None, 0], thr_mma, tidx) + + # A/B: ldmatrix retiling (same as dense) + atom_copy_A = cute.make_copy_atom(cute.nvgpu.warp.LdMatrix8x8x16bOp(...), a_dtype) + smem_tiled_copy_A = cute.make_tiled_copy_A(atom_copy_A, tiled_mma) + + # SF: CopyUniversal with SF-specific thread-value layout + atom_copy_SF = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), sf_dtype) + smem_tiled_copy_SFA = cute.make_tiled_copy( + atom_copy_SF, + sm120_utils.get_layoutSFA_TV(tiled_mma), + (cute.size(tiled_mma.permutation_mnk[0]), cute.size(tiled_mma.permutation_mnk[2])), + ) + smem_tiled_copy_SFB = cute.make_tiled_copy( + atom_copy_SF, + 
sm120_utils.get_layoutSFB_TV(tiled_mma), + (cute.size(tiled_mma.permutation_mnk[1]), cute.size(tiled_mma.permutation_mnk[2])), + ) + +**4. Modified main loop** — each k-block loads A, B, SFA, and SFB from +SMEM into registers. The ``cute.gemm`` call passes ``[A, SFA]`` and +``[B, SFB]`` as operand lists: + +.. code-block:: python + + for k_block_idx in cutlass.range(num_k_blocks, unroll_full=True): + # ldmatrix: load A and B from SMEM → RMEM (same as dense) + cute.copy(smem_tiled_copy_A, tCsA_p[None, None, k_block_next], + tCrA_copy_view[None, None, k_block_next]) + cute.copy(smem_tiled_copy_B, tCsB_p[None, None, k_block_next], + tCrB_copy_view[None, None, k_block_next]) + + # CopyUniversal: load SFA and SFB from SMEM → RMEM # NEW + cute.copy(smem_tiled_copy_SFA, + cute.filter_zeros(tCsSFA_p)[None, None, k_block_next], + cute.filter_zeros(tCrSFA_copy_view)[None, None, k_block_next]) + cute.copy(smem_tiled_copy_SFB, + cute.filter_zeros(tCsSFB_p)[None, None, k_block_next], + cute.filter_zeros(tCrSFB_copy_view)[None, None, k_block_next]) + + # MMA with scale factors passed as [value, scale] pairs + cute.gemm( + tiled_mma, + accumulators, + [tCrA[None, None, k_block_idx], tCrSFA[None, None, k_block_idx]], # [A, SFA] + [tCrB[None, None, k_block_idx], tCrSFB[None, None, k_block_idx]], # [B, SFB] + accumulators, + ) + +.. code-block:: text + + Dense gemm call: + cute.gemm(tiled_mma, acc, tCrA[k], tCrB[k], acc) + + Block-scaled gemm call: + cute.gemm(tiled_mma, acc, [tCrA[k], tCrSFA[k]], [tCrB[k], tCrSFB[k]], acc) + ^^^^^^^^ ^^^^^^^^^ ^^^^^^^^ ^^^^^^^^^ + value scale value scale + (RMEM) (RMEM) (RMEM) (RMEM) + +Note that ``cute.filter_zeros`` is applied to the SF copy views because +the scale-factor layouts contain stride-0 (broadcast) modes: every group +of ``sf_vec_size`` K-elements shares a single scale factor, so several +logical coordinates alias the same stored value. ``filter_zeros`` +collapses those stride-0 modes to size 1, so the copy moves each stored +scale factor once instead of once per aliased coordinate. + +The epilogue (RMEM → SMEM → GMEM) is identical to a dense kernel.
+ + +See also: + +- Dense GEMM example (Ampere): ``examples/cute/ampere/kernel/dense_gemm/tensorop_gemm.py`` +- Block-scaled GEMM example (SM120a): ``examples/cute/blackwell_geforce/kernel/blockscaled_gemm/dense_blockscaled_gemm_persistent_pingpong.py`` +- Block-scaled layout utilities: ``cutlass.utils.blockscaled_layout`` +- SM120 helper utilities: ``cutlass.utils.blackwell_helpers`` diff --git a/python/CuTeDSL/cutlass/base_dsl/ast_preprocessor.py b/python/CuTeDSL/cutlass/base_dsl/ast_preprocessor.py index 052b5d8d4..1f1d8a480 100644 --- a/python/CuTeDSL/cutlass/base_dsl/ast_preprocessor.py +++ b/python/CuTeDSL/cutlass/base_dsl/ast_preprocessor.py @@ -720,7 +720,9 @@ class DSLPreprocessor(ast.NodeTransformer): offset = len(all_args) - len(func_ast.args.defaults) for i, default_node in enumerate(func_ast.args.defaults): ast_defaults[all_args[offset + i].arg] = default_node - for kwarg, kw_default in zip(func_ast.args.kwonlyargs, func_ast.args.kw_defaults): + for kwarg, kw_default in zip( + func_ast.args.kwonlyargs, func_ast.args.kw_defaults + ): if kw_default is not None: ast_defaults[kwarg.arg] = kw_default for param_name, default_val in params_with_defaults.items(): diff --git a/python/CuTeDSL/cutlass/base_dsl/dsl.py b/python/CuTeDSL/cutlass/base_dsl/dsl.py index 6d6ec9e57..c136b595a 100644 --- a/python/CuTeDSL/cutlass/base_dsl/dsl.py +++ b/python/CuTeDSL/cutlass/base_dsl/dsl.py @@ -1865,7 +1865,7 @@ class BaseDSL(metaclass=DSLSingletonMeta): sources = set(x.value for x in link_libraries_attributes) link_libraries = ( link_libraries - + ("," if len(link_libraries) > 0 else "") + + ("," if link_libraries and len(sources) > 0 else "") + ",".join(sources) ) self.compile_options.options[LinkLibraries] = LinkLibraries( diff --git a/python/CuTeDSL/cutlass/base_dsl/runtime/cuda.py b/python/CuTeDSL/cutlass/base_dsl/runtime/cuda.py index 37285456a..7d41754ff 100644 --- a/python/CuTeDSL/cutlass/base_dsl/runtime/cuda.py +++ b/python/CuTeDSL/cutlass/base_dsl/runtime/cuda.py 
@@ -88,6 +88,11 @@ def _get_gpu_arch_info(major: int, minor: int) -> tuple[str, str, list[str]]: "sm_120a", ["sm_120a"], ), # RTX PRO 6000 / RTX 50 Series + (12, 1): ( + "Blackwell", + "sm_121a", + ["sm_121a"], + ), # DGX Spark } return gpu_arch_map.get( (major, minor), ("Unknown", f"sm_{major}{minor}", [f"sm_{major}{minor}"]) diff --git a/python/CuTeDSL/cutlass/cute/core.py b/python/CuTeDSL/cutlass/cute/core.py index 84387e62d..004a2755f 100644 --- a/python/CuTeDSL/cutlass/cute/core.py +++ b/python/CuTeDSL/cutlass/cute/core.py @@ -3330,7 +3330,12 @@ def filter_zeros( if not isinstance(input, (Layout, Tensor)): raise TypeError(f"Expected layout or tensor as input, but got {type(input)=}") if isinstance(input, Tensor): - input = input.value + return _op_wrapper( + partial(_cute_ir.filter_zeros, target_profile=target_profile), + input, + loc=loc, + ip=ip, + ) return _cute_ir.filter_zeros(input, target_profile=target_profile, loc=loc, ip=ip) @@ -3388,7 +3393,7 @@ def filter( input.inner, input.offset, filter(input.outer, loc=loc, ip=ip) ) elif isinstance(input, _Tensor): - return _cute_ir.filter(input.value, loc=loc, ip=ip) + return _op_wrapper(_cute_ir.filter, input, loc=loc, ip=ip) else: return _cute_ir.filter(input, loc=loc, ip=ip) @@ -5020,10 +5025,9 @@ def local_partition( raise NotImplementedError( f"Index value should be 32-bit or smaller integer type, but got {index_val.type}" ) - return _cute_ir.local_partition( - input=target.value, - tiler=dice(tiler, proj), - index=index_val, + return _op_wrapper( + partial(_cute_ir.local_partition, tiler=dice(tiler, proj), index=index_val), + target, loc=loc, ip=ip, ) @@ -5114,11 +5118,9 @@ def local_tile( proj_val = _pack_coord(proj, loc=loc, ip=ip) proj = proj_val.type.attribute - return _cute_ir.local_tile( - input=input.value, - tile=tiler_val, - coord=coord_val, - proj=proj, + return _op_wrapper( + partial(_cute_ir.local_tile, tile=tiler_val, coord=coord_val, proj=proj), + input, loc=loc, ip=ip, ) diff --git 
a/python/CuTeDSL/cutlass/cute/nvgpu/warp/__init__.py b/python/CuTeDSL/cutlass/cute/nvgpu/warp/__init__.py index bacbf63ee..2c0c32be6 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/warp/__init__.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/warp/__init__.py @@ -21,6 +21,9 @@ __all__ = [ "MmaFP8Op", "MmaMXF4Op", "MmaMXF4NVF4Op", + "MmaMXF8Op", + "MmaMXF8F6F4Op", + "MXF8F6F4_SUPPORTED_PAIRS", # copy.py "LdMatrix8x8x16bOp", "LdMatrix16x8x8bOp", diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/warp/mma.py b/python/CuTeDSL/cutlass/cute/nvgpu/warp/mma.py index 4c51d7f59..dbdc296a4 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/warp/mma.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/warp/mma.py @@ -224,7 +224,9 @@ class MmaSM120BlockScaledOp(MmaOp): admissible_archs = [ Arch.sm_120a, + Arch.sm_120f, Arch.sm_121a, + Arch.sm_121f, ] def __post_init__(self) -> None: @@ -239,29 +241,44 @@ class MmaSM120BlockScaledOp(MmaOp): "CUTE_DSL_ARCH set to sm_120a or sm_121a", suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", ) - if self.ab_dtype != Float4E2M1FN: + # (ab_dtype, shape_mnk) consistency: FP4 uses (16,8,64); FP8 uses (16,8,32). 
+ if self.ab_dtype == Float4E2M1FN: + if self.shape_mnk != (16, 8, 64): + raise OpError( + self, + "expects the 'shape_mnk' Op parameter to be (16,8,64) for Float4E2M1FN", + ) + elif self.ab_dtype in (Float8E4M3FN, Float8E5M2): + if self.shape_mnk != (16, 8, 32): + raise OpError( + self, + "expects the 'shape_mnk' Op parameter to be (16,8,32) for Float8E4M3FN/Float8E5M2", + ) + else: raise OpError( self, - "expects the 'ab_dtype' Op parameter to be Float4E2M1FN", + "expects the 'ab_dtype' Op parameter to be Float4E2M1FN, Float8E4M3FN, or Float8E5M2", ) if self.acc_dtype != Float32: raise OpError( self, "expects the 'acc_dtype' Op parameter to be Float32", ) - if self.shape_mnk != (16, 8, 64): - raise OpError( - self, - "expects the 'shape_mnk' Op parameter to be (16,8,64)", - ) if self.sf_vec_size == 16: + # vec_size=16 is only valid for FP4 (NVFP4) with E4M3 scale. + if self.ab_dtype != Float4E2M1FN: + raise OpError( + self, + "expects the 'sf_vec_size' Op parameter to be 32 for Float8E4M3FN/Float8E5M2", + ) if self.sf_type != Float8E4M3FN: raise OpError( self, "expects the 'sf_type' Op parameter to be Float8E4M3FN", ) elif self.sf_vec_size == 32: + # vec_size=32 path uses UE8M0 scale for both FP4 (MXF4) and FP8 (MXF8). if self.sf_type != Float8E8M0FNU: raise OpError( self, @@ -275,7 +292,7 @@ class MmaSM120BlockScaledOp(MmaOp): def __str__(self) -> str: return ( - "warp-level MXF4/MXF4NVF4 MMA Operation" + "warp-level MXF4/MXF4NVF4/MXF8 MMA Operation" + f"\n A/B data type = {self.ab_dtype}" + f"\n Accumulator data type = {self.acc_dtype}" + f"\n Instruction shape MNK = {self.shape_mnk}" @@ -474,3 +491,214 @@ class MmaMXF4NVF4Op(MmaSM120BlockScaledOp): class MmaMXF4NVF4Trait(MmaBlockScaledTrait): pass + + +# +# MXF8 MMA +# + + +@dataclass(frozen=True) +class MmaMXF8Op(MmaSM120BlockScaledOp): + """ + MXF8 warp-level MMA Operation. + + See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-mma>`__.
+ This Operation covers the instructions using the ``.e4m3`` / ``.e5m2`` qualifiers for the input operands. + .kind = {.kind::mxf8}; + .scale_vec_size = {.scale_vec::1X}; + .stype = {.ue8m0}; + """ + + descriptive_name = "warp-level MXF8 MMA Operation" + + def __init__( + self, + ab_dtype: Type[Numeric], + acc_dtype: Type[Numeric], + sf_type: Type[Numeric], + ) -> None: + super().__init__( + ab_dtype, + acc_dtype, + (16, 8, 32), + sf_type, + 32, + ) + + def _make_trait( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, + ) -> "MmaMXF8Trait": + shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) + ty = _cute_nvgpu_ir.MmaAtomSM120BlockScaledType.get( + shape_mnk.type.attribute, + 32, + False, + self.ab_dtype.mlir_type, + self.ab_dtype.mlir_type, + self.acc_dtype.mlir_type, + self.sf_type.mlir_type, + ) + return MmaMXF8Trait(make_atom(ty, loc=loc, ip=ip)) + + +class MmaMXF8Trait(MmaBlockScaledTrait): + pass + + +# +# MXF8F6F4 mixed-precision MMA (independent A/B dtypes) +# + + +MXF8F6F4_SUPPORTED_PAIRS = frozenset( + { + (Float4E2M1FN, Float8E4M3FN), + (Float4E2M1FN, Float8E5M2), + (Float8E4M3FN, Float4E2M1FN), + (Float8E5M2, Float4E2M1FN), + } +) + + +@dataclass(frozen=True) +class MmaMXF8F6F4Op(MmaOp): + """ + SM120 MXF8F6F4 mixed-precision warp-level block-scaled MMA Operation. + + Covers the PTX instructions using independent ``.<atype>.<btype>`` + qualifiers (one of e2m1.e4m3, e2m1.e5m2, e4m3.e2m1, e5m2.e2m1): + + .kind = {.kind::mxf8f6f4}; + .scale_vec_size = {.scale_vec::1X}; + .stype = {.ue8m0}; + + A and B operand dtypes are independent. Same-dtype FP4/FP4 and FP8/FP8 + paths remain on ``MmaMXF4Op`` / ``MmaMXF4NVF4Op`` / ``MmaMXF8Op`` + respectively. Same-width mixed-FP8 (E4M3 + E5M2) and FP6 mixed pairs + are not supported.
+ """ + + a_dtype: Type[Numeric] + b_dtype: Type[Numeric] + acc_dtype: Type[Numeric] + sf_type: Type[Numeric] + + descriptive_name = "warp-level MXF8F6F4 mixed-precision MMA Operation" + + shape_mnk = (16, 8, 32) + sf_vec_size = 32 + use_sf_layout_TV = False + + admissible_archs = [ + Arch.sm_120a, + Arch.sm_121a, + ] + + def __post_init__(self) -> None: + # Verify arch + arch = BaseDSL._get_dsl().get_arch_enum() + if arch not in self.admissible_archs: + raise OpError( + self, + f"expects arch to be one of {self.admissible_archs}, but got {arch}", + suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", + ) + if self.acc_dtype != Float32: + raise OpError( + self, + "expects the 'acc_dtype' Op parameter to be Float32", + ) + if self.sf_type != Float8E8M0FNU: + raise OpError( + self, + "expects the 'sf_type' Op parameter to be Float8E8M0FNU", + ) + # Reject same-dtype pairs explicitly (route to dedicated ops). + if self.a_dtype == self.b_dtype: + if self.a_dtype == Float4E2M1FN: + raise OpError( + self, + "same-dtype Float4E2M1FN/Float4E2M1FN is not supported by MmaMXF8F6F4Op; " + "use MmaMXF4Op (sf_vec_size=32) or MmaMXF4NVF4Op (sf_vec_size=16) instead", + ) + if self.a_dtype in (Float8E4M3FN, Float8E5M2): + raise OpError( + self, + "same-dtype FP8/FP8 is not supported by MmaMXF8F6F4Op; " + "use MmaMXF8Op instead", + ) + # Reject same-width mixed-FP8 (E4M3 + E5M2) explicitly. + fp8_dtypes = (Float8E4M3FN, Float8E5M2) + if self.a_dtype in fp8_dtypes and self.b_dtype in fp8_dtypes: + raise OpError( + self, + "same-width mixed-FP8 (Float8E4M3FN + Float8E5M2) is not supported; " + "supported MXF8F6F4 pairs are (Float4E2M1FN x Float8E4M3FN/Float8E5M2) " + "and the reverse", + ) + # Final allow-list check (catches FP6 and any other unsupported dtype). 
+ if (self.a_dtype, self.b_dtype) not in MXF8F6F4_SUPPORTED_PAIRS: + raise OpError( + self, + f"unsupported (a_dtype, b_dtype) = ({self.a_dtype}, {self.b_dtype}) " + f"for MmaMXF8F6F4Op; supported pairs are " + f"{sorted(repr(p) for p in MXF8F6F4_SUPPORTED_PAIRS)}. " + f"FP6 mixed pairs are not supported.", + ) + + def __str__(self) -> str: + return ( + "warp-level MXF8F6F4 mixed-precision MMA Operation" + + f"\n A data type = {self.a_dtype}" + + f"\n B data type = {self.b_dtype}" + + f"\n Accumulator data type = {self.acc_dtype}" + + f"\n Instruction shape MNK = {self.shape_mnk}" + + f"\n Vector size = {self.sf_vec_size}" + + f"\n SF data type = {self.sf_type}" + ) + + def _verify_fragment_A( + self, + input: Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: + pass + + def _verify_fragment_B( + self, + input: Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: + pass + + def _make_trait( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, + ) -> "MmaMXF8F6F4Trait": + shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) + ty = _cute_nvgpu_ir.MmaAtomSM120BlockScaledType.get( + shape_mnk.type.attribute, + self.sf_vec_size, + self.use_sf_layout_TV, + self.a_dtype.mlir_type, + self.b_dtype.mlir_type, + self.acc_dtype.mlir_type, + self.sf_type.mlir_type, + ) + return MmaMXF8F6F4Trait(make_atom(ty, loc=loc, ip=ip)) + + +class MmaMXF8F6F4Trait(MmaBlockScaledTrait): + pass diff --git a/python/CuTeDSL/cutlass/jax/primitive.py b/python/CuTeDSL/cutlass/jax/primitive.py index 0dab44306..2114d9ce2 100644 --- a/python/CuTeDSL/cutlass/jax/primitive.py +++ b/python/CuTeDSL/cutlass/jax/primitive.py @@ -21,7 +21,11 @@ from jax._src.interpreters import batching from .compile import get_or_compile_kernel, build_function_spec -from .types import cutlass_to_jax_layout_order, default_tensor_spec, TensorSpec +from .types 
import ( + cutlass_to_jax_layout_order, + default_tensor_spec, + TensorSpec, +) from .ffi import get_cutlass_call_ffi_name, is_ffi_registered, register_ffi @@ -77,8 +81,10 @@ def cutlass_call( objects with ``.shape`` and ``.dtype`` attributes) describing each output buffer. input_spec: A :class:`TensorSpec` or list thereof providing - layout/mode/divisibility hints for input tensors. ``None`` infers - defaults from each array. + layout/mode/divisibility hints for input tensors. ``None`` infers + defaults from each array. A ``TensorSpec`` with ``layout=None`` uses + and constrains row-major physical layout; use ``mode`` to remap + physical dimensions to the kernel's logical modes. output_spec: Same as *input_spec* but applied to output tensors. input_output_aliases: ``{input_index: output_index}`` mapping that allows an input buffer to alias an output, avoiding an extra copy. @@ -308,7 +314,9 @@ def cutlass_call_inner_p_impl( call_name = get_cutlass_call_ffi_name(allow_cuda_graph) - # Convert layout from CuTeDSL to JAX order as ffi_call expects this. + # Convert explicit layout constraints from CuTeDSL to JAX order. ``None`` is + # passed through intentionally: jax.ffi.ffi_call treats it as default + # row-major layout. 
input_layouts = [cutlass_to_jax_layout_order(s.layout) for s in input_spec_flat] output_layouts = [cutlass_to_jax_layout_order(s.layout) for s in output_spec_flat] diff --git a/python/CuTeDSL/cutlass/jax/testing.py b/python/CuTeDSL/cutlass/jax/testing.py index 8ba31a3e5..dbf68ecc8 100644 --- a/python/CuTeDSL/cutlass/jax/testing.py +++ b/python/CuTeDSL/cutlass/jax/testing.py @@ -15,12 +15,18 @@ import jax.numpy as jnp import cutlass.cute as cute from cutlass.cutlass_dsl import dsl_user_op -from typing import Optional +from typing import Optional, Sequence from cutlass._mlir import ir -def reorder_modes(src: str, target: str) -> tuple[int, ...]: - """Computes the mode given a source and target order.""" +def reorder_modes(src: Sequence[str], target: Sequence[str]) -> tuple[int, ...]: + """Compute a ``TensorSpec.mode`` from physical input order to kernel order. + + ``src`` names the JAX array's physical dimension order. ``target`` names the + logical mode order that the CuTe kernel expects. The returned tuple can be + passed as ``TensorSpec(mode=...)`` while leaving ``layout`` at its default + row-major value when the JAX buffer is physically row-major. + """ src = tuple(src) target = tuple(target) src_map = {} @@ -29,52 +35,64 @@ def reorder_modes(src: str, target: str) -> tuple[int, ...]: return tuple([src_map[d] for d in target]) -def gemm_a_major(d: str): - """Returns order for A tensor major mode.""" +def gemm_a_major(d: str) -> str: + """Return the physical JAX dimension order for an A tensor major mode. + + The returned string is not the kernel's canonical logical order. Use + :func:`gemm_a_mode` to map this physical order to kernel logical ``mkl``. 
+ """ return {"k": "lmk", "m": "lkm"}[d] def gemm_a_mode(d: str) -> tuple[int, ...]: - """Returns mode for A tensor major mode.""" + """Return ``TensorSpec.mode`` for A, mapping physical order to logical ``mkl``.""" return reorder_modes(gemm_a_major(d), "mkl") -def gemm_b_major(d: str): - """Returns order for B tensor major mode.""" +def gemm_b_major(d: str) -> str: + """Return the physical JAX dimension order for a B tensor major mode. + + The returned string is not the kernel's canonical logical order. Use + :func:`gemm_b_mode` to map this physical order to kernel logical ``nkl``. + """ return {"k": "lnk", "n": "lkn"}[d] def gemm_b_mode(d: str) -> tuple[int, ...]: - """Returns mode for B tensor major mode.""" + """Return ``TensorSpec.mode`` for B, mapping physical order to logical ``nkl``.""" return reorder_modes(gemm_b_major(d), "nkl") -def gemm_c_major(d: str): - """Returns order for C tensor major mode.""" +def gemm_c_major(d: str) -> str: + """Return the physical JAX dimension order for a C/D tensor major mode. + + The returned string is not the kernel's canonical logical order. Use + :func:`gemm_c_mode` to map this physical order to kernel logical ``mnl``. 
+ """ return {"n": "lmn", "m": "lnm"}[d] def gemm_c_mode(d: str) -> tuple[int, ...]: - """Returns mode for C tensor major mode.""" + """Return ``TensorSpec.mode`` for C/D, mapping physical order to logical ``mnl``.""" return reorder_modes(gemm_c_major(d), "mnl") -def gemm_a_shape(l, m, k, major) -> tuple[int, ...]: - """Returns shape for A tensor given major mode.""" +def gemm_a_shape(l: int, m: int, k: int, major: str) -> tuple[int, ...]: + """Return the physical row-major JAX shape for A with the requested major mode.""" assert major in ("k", "m") shape = (l, m, k) if major == "k" else (l, k, m) return shape -def gemm_b_shape(l, n, k, major) -> tuple[int, ...]: - """Returns shape for B tensor given major mode.""" +def gemm_b_shape(l: int, n: int, k: int, major: str) -> tuple[int, ...]: + """Return the physical row-major JAX shape for B with the requested major mode.""" assert major in ("k", "n") shape = (l, n, k) if major == "k" else (l, k, n) return shape -def gemm_c_shape(l, m, n, major) -> tuple[int, ...]: - """Returns shape for C tensor given major mode.""" +def gemm_c_shape(l: int, m: int, n: int, major: str) -> tuple[int, ...]: + """Return the physical row-major JAX shape for C/D with the requested major mode.""" assert major in ("m", "n") shape = (l, m, n) if major == "n" else (l, n, m) return shape diff --git a/python/CuTeDSL/cutlass/jax/types.py b/python/CuTeDSL/cutlass/jax/types.py index 3f22dfb1d..caf53ffaf 100644 --- a/python/CuTeDSL/cutlass/jax/types.py +++ b/python/CuTeDSL/cutlass/jax/types.py @@ -9,7 +9,7 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. 
-from typing import Optional, Sequence +from typing import Any, Optional, Sequence from dataclasses import dataclass, field @@ -18,6 +18,7 @@ import jax.numpy as jnp import cutlass import cutlass.cute as cute +from cutlass.cute.core import IntValue from cutlass.cute.runtime import from_dlpack as _from_dlpack from cutlass.cute import AddressSpace from cutlass._mlir import ir @@ -58,35 +59,69 @@ DEFAULT_CUTLASS_DEVICE_BUFFER_ALIGNMENT = 256 class TensorSpec: """Specifies the layout and metadata for a JAX array passed to a CuTe kernel. - TensorSpec controls how a JAX array's dimensions are mapped to a cute.Tensor - during jit lowering, including stride ordering, mode permutation, and whether - shapes/strides are compiled as static constants. + TensorSpec controls how a JAX array's input dimensions are mapped to a + ``cute.Tensor`` during jit lowering, including compact stride ordering, + mode permutation, and whether shapes/strides are compiled as static + constants. The JAX bridge models tensors as compact layouts: runtime + strides are derived from runtime shapes using ``layout`` order rather than + loaded from a strided view descriptor. + + A useful way to choose a spec is to separate physical storage from logical + kernel modes: + + 1. First choose the public JAX array shape and its compact physical memory + order. If the buffer is a standard row-major JAX array, leave + ``layout=None``. ``cutlass_call`` will constrain the FFI operand/result + to row-major physical layout, matching the CuTe tensor strides that are + built from the default. + 2. Then use ``mode`` only when the kernel should see those input dimensions + in a different logical order. ``mode`` is applied after the compact + layout is built; it is not a request for JAX/XLA to transpose data. + + For example, a row-major JAX buffer shaped ``(expert_count, N, K)`` can be + presented to a kernel expecting logical ``(N, K, expert_count)`` with + ``TensorSpec(mode=(1, 2, 0))``. 
No explicit ``layout`` is needed because the + physical buffer is still ordinary row-major, and the FFI call will be + constrained accordingly. Use ``layout`` only when the compact physical + stride order itself differs from the default row-major order, such as a + column-major compact buffer. Attributes: layout: A minor-to-major stride ordering in CuTeDSL convention. ``layout[i]`` - gives the stride rank of dimension ``i``, where rank 0 means the smallest - (innermost) stride. For example, row-major order for a 3-D tensor is - ``(2, 1, 0)``. If ``None``, row-major is assumed. Use - :func:`jax_to_cutlass_layout_order` to convert from JAX's major-to-minor - convention. - mode: A permutation that maps the stride-ordered dimensions to the mode - positions of the resulting ``cute.Layout``. For example, ``mode=(2, 0, 1)`` - reorders an ``(M, K, L)`` layout into ``(K, L, M)`` mode order inside the - kernel. If ``None``, modes match the natural dimension order ``(0, 1, ..., N-1)``. + gives the compact physical stride rank of input dimension ``i``, + where rank 0 means the smallest (innermost) stride. For example, + row-major order for a 3-D tensor is ``(2, 1, 0)``. If ``None``, + row-major is assumed. Use :func:`jax_to_cutlass_layout_order` to + convert from JAX's major-to-minor convention. ``layout`` does not + change which logical mode a dimension represents; combine it with + ``mode`` when physical order and kernel-logical order differ. + mode: A permutation applied after the compact layout is constructed. It + selects input dimensions into the mode positions seen by the kernel. + For example, ``mode=(2, 0, 1)`` presents an input shaped + ``(M, K, L)`` to the kernel as logical ``(L, M, K)``. If ``None``, + modes match the natural input-dimension order ``(0, 1, ..., N-1)``. + ``mode`` changes the tensor layout object seen by CuTe code but + does not materialize a transpose or change the underlying buffer. 
static: If ``True``, shapes and strides are compiled as static ``constexpr`` values, which may enable additional compiler optimisations. Kernels that do not support static shapes will raise a compile error. Must be ``False`` when any dimension is symbolic (e.g. under ``jax.export``). ptr_assumed_align: Assumed byte alignment of the tensor's data pointer. Overrides the default of 256 bytes. Rarely needs to change. - divisibility: Optional per-mode divisibility hints. If a single int is passed - divisibility will be applied to the leading (stride=1) dimension only. + divisibility: Optional divisibility hints for input dimensions, in the + same order as the JAX array shape and before any ``mode`` reordering. + Positive hints constrain dynamic shape values and are propagated + through compact stride construction: a stride inherits the product + of the divisibilities for dimensions with lower stride rank. + Positive explicit hints take precedence over inferred concrete + extents. If a single int is passed, it is applied to the leading + compact dimension only, where ``layout[i] == 0``. """ # Minor-to-major stride ordering in CuTeDSL convention (layout[i] = stride rank # of dimension i, 0 = innermost). Defaults to row-major if None. layout: tuple[int, ...] | None = field(metadata=dict(static=True), default=None) - # Permutation from stride-ordered dimensions to cute.Layout mode positions. + # Permutation from input dimensions to cute.Layout mode positions. # Defaults to identity (0, 1, ..., N-1) if None. mode: tuple[int, ...] | None = field(metadata=dict(static=True), default=None) # If True, shapes and strides are embedded as compile-time constants. @@ -96,7 +131,7 @@ class TensorSpec: ptr_assumed_align: int = field( metadata=dict(static=True), default=DEFAULT_CUTLASS_DEVICE_BUFFER_ALIGNMENT ) - # Per-mode divisibility hints. + # Per-input-dimension divisibility hints, before mode reordering. divisibility: tuple[int | None, ...] 
| int | None = field( metadata=dict(static=True), default=None ) @@ -128,9 +163,10 @@ def row_major_layout(shaped): def default_tensor_mode(shaped): """Returns the identity mode permutation for an N-dimensional tensor. - The mode permutation maps stride-ordered dimensions to ``cute.Layout`` mode - positions. The default identity ``(0, 1, ..., N-1)`` leaves the mode order - unchanged relative to the dimension order. + The mode permutation maps JAX input dimensions to ``cute.Layout`` mode + positions after the compact layout has been constructed. The default + identity ``(0, 1, ..., N-1)`` leaves the mode order unchanged relative to + the JAX shape order. Args: shaped: An object with a ``.shape`` attribute, or a shape tuple/sequence. @@ -151,12 +187,22 @@ def default_tensor_spec(shaped) -> TensorSpec: TensorSpec(layout=(N-1, ..., 1, 0), mode=(0, 1, ..., N-1), divisibility=(D0, D1, ... DN-1)) This is appropriate for standard row-major (C-contiguous) JAX arrays that - do not require dimension reordering inside the kernel. + do not require dimension reordering inside the kernel. The resulting JAX + CuTe tensor is treated as compact: strides are derived from shapes using the + row-major layout order. - Divisibility hints are inferred only for concrete integer dimensions. - Symbolic dimensions always produce ``None`` for their slot; pass an - explicit ``TensorSpec`` with ``divisibility`` set if you need alignment - hints for symbolic shapes. + If the JAX buffer is row-major but the kernel expects a different logical + mode order, use an explicit :class:`TensorSpec` with ``mode`` set and leave + ``layout`` unset. ``cutlass_call`` still constrains the FFI buffer to + row-major layout in this case. For example, ``TensorSpec(mode=(1, 2, 0))`` + maps a physical ``(L, M, K)`` row-major input to a logical ``(M, K, L)`` + tensor. + + Divisibility hints are inferred only for concrete integer input dimensions. 
+ Symbolic dimensions always produce ``None`` for their slot; pass an explicit + ``TensorSpec`` with ``divisibility`` set if you need alignment hints for + symbolic shapes or want a weaker explicit constraint than the concrete + extent. Args: shaped: An object with a ``.shape`` attribute, or a shape tuple/sequence. @@ -179,11 +225,12 @@ def default_tensor_spec(shaped) -> TensorSpec: def _expand_divisibility( divisibility, order: tuple[int, ...], ndim: int ) -> tuple[int | None, ...] | None: - """Expand a divisibility spec to a full per-dimension tuple. + """Expand a divisibility spec to a full per-input-dimension tuple. A bare ``int`` is placed at the leading-dimension slot (where ``order[i] == 0``, i.e. stride == 1) and ``None`` everywhere else. - A tuple is returned unchanged. ``None`` returns ``None``. + A tuple is already in JAX input-dimension order and is returned unchanged. + ``None`` returns ``None``. """ if divisibility is None or isinstance(divisibility, tuple): return divisibility @@ -268,7 +315,20 @@ def from_dlpack(array, assumed_align: int = DEFAULT_CUTLASS_DEVICE_BUFFER_ALIGNM return _from_dlpack(array, assumed_align=assumed_align) -def _validate_permutation(name: str, perm, shape): +def _assume_divisible_int( + value: Any, + divby: int, + *, + loc: ir.Location | None = None, + ip: ir.InsertionPoint | None = None, +) -> Any: + """Attach a divisibility assumption to an integer value without narrowing it.""" + if divby <= 1: + return value + return cute.assume(IntValue(value, loc=loc, ip=ip), divby=divby, loc=loc, ip=ip) + + +def _validate_permutation(name: str, perm: Sequence[int], shape: Sequence[Any]) -> None: if len(perm) != len(shape): raise ValueError(f"{name} must be same length as shape", perm, shape) for s in perm: @@ -292,9 +352,11 @@ class JaxArray: can be concrete or symbolic in the case of jax.export. 3. mem_space: The memory space of the tensor. Defaults to gmem. 4. assumed_align: The alignment of the tensor. Defaults to XLA alignment. 
- 5. order: Specifies the order of the shape to determine strides. - 6. mode: Specifies how to map ordered elements to the modes od a cute.Layout. + 5. order: Specifies the compact physical stride order of the shape. + 6. mode: Specifies how to map input dimensions to the logical modes seen by + the kernel after the compact layout is constructed. 7. static: If True, tensor shapes and strides are compiled statically. + 8. divisibility: Optional divisibility hints in input-dimension order. """ def __init__( @@ -381,6 +443,21 @@ class JaxArrayValue(JaxArray): ip: Optional[ir.InsertionPoint] = None, ): i32 = ir.IntegerType.get_signless(32) + + # Track the divisibility available for each input dimension. Explicit + # hints win; otherwise concrete dimensions contribute their known extent. + dim_divisibility = None + if self.divisibility is not None: + dim_divisibility = [] + for div_spec, static_s in zip(self.divisibility, self.shape): + if div_spec is not None and div_spec > 0: + dim_divisibility.append(div_spec) + elif isinstance(static_s, int): + dim_divisibility.append(static_s) + else: + dim_divisibility.append(1) + dim_divisibility = tuple(dim_divisibility) + pairs = sorted(zip(shape, order), key=lambda x: x[1]) # Compute strides for each element in order. @@ -395,28 +472,29 @@ class JaxArrayValue(JaxArray): for i in range(len(shape)): strides_ordered.append(strides[order[i]]) + if dim_divisibility is not None: + # A compact stride is the product of all dimensions with a lower + # stride order, so it inherits the product of their divisibility. 
+ stride_divisibility = [] + for dim_order in order: + divby = 1 + for other_dim, other_order in enumerate(order): + if other_order < dim_order: + divby *= dim_divisibility[other_dim] + stride_divisibility.append(divby) + + strides_ordered = [ + _assume_divisible_int(s, divby, loc=loc, ip=ip) + for s, divby in zip(strides_ordered, stride_divisibility) + ] + # Shapes are expected to be int32 so truncate to that before creating layout shape_i32 = tuple(arith.trunci(i32, s) for s in shape) - - # Apply per-mode divisibility assumptions so the compiler can exploit alignment. - if self.divisibility is not None: - assumed = [] - for s32, div_spec, static_s in zip( - shape_i32, self.divisibility, self.shape - ): - if isinstance(static_s, int): - # Pure static shape is known even though a dynamic shape is - # used. We can assume the exact shape here. We keep the shape - # as a dynamic value to avoid breaking code that may expect - # a dynamic value. - assumed.append(cute.assume(s32, divby=static_s)) - elif div_spec is not None: - # Using a dynamic value so apply the div_spec if its provided. 
- assumed.append(cute.assume(s32, divby=div_spec)) - else: - # No divisibility specification for this shape - assumed.append(s32) - shape_i32 = tuple(assumed) + if dim_divisibility is not None: + shape_i32 = tuple( + _assume_divisible_int(s, divby, loc=loc, ip=ip) + for s, divby in zip(shape_i32, dim_divisibility) + ) return cute.make_layout(shape_i32, stride=tuple(strides_ordered)) diff --git a/python/CuTeDSL/cutlass/utils/__init__.py b/python/CuTeDSL/cutlass/utils/__init__.py index 5929b4520..6c4fdf005 100644 --- a/python/CuTeDSL/cutlass/utils/__init__.py +++ b/python/CuTeDSL/cutlass/utils/__init__.py @@ -84,6 +84,8 @@ from .tmem_allocator import ( from .layout import LayoutEnum +from .block import block_copy + from .mixed_input_helpers import ( TransformMode, scale_tma_partition, @@ -176,6 +178,7 @@ __all__ = [ "sm90", "sm100", "gemm", + "block_copy", "ClcDynamicPersistentTileSchedulerParams", "ClcDynamicPersistentTileScheduler", "print_latex", diff --git a/python/CuTeDSL/cutlass/utils/blackwell_helpers.py b/python/CuTeDSL/cutlass/utils/blackwell_helpers.py index bc75656fa..3ff569c3b 100644 --- a/python/CuTeDSL/cutlass/utils/blackwell_helpers.py +++ b/python/CuTeDSL/cutlass/utils/blackwell_helpers.py @@ -612,7 +612,7 @@ def get_tmem_load_op( def get_smem_layout_atom_ab( major_mode: OperandMajorMode, element_type: Type[Numeric], - smem_shape_mn_k: Tuple[int, int], + smem_shape_mn_k: cute.Tile, *, loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None, @@ -625,13 +625,16 @@ def get_smem_layout_atom_ab( :param element_type: The element type for the SMEM tensor. :type element_type: Type[Numeric] :param smem_shape_mn_k: The shape of the SMEM tensor. 
- :type smem_shape_mn_k: Tuple[int, int] + :type smem_shape_mn_k: cute.Tile :return: The SMEM layout atom kind :rtype: cutlass.cute.nvgpu.tcgen05.SmemLayoutAtomKind """ is_k_major = major_mode == OperandMajorMode.K - major_mode_size = smem_shape_mn_k[1] if is_k_major else smem_shape_mn_k[0] - + major_mode_size = ( + cute.size(smem_shape_mn_k, mode=[1]) + if is_k_major + else cute.size(smem_shape_mn_k, mode=[0]) + ) assert major_mode_size % 8 == 0 sw128_num_contiguous_bits = 1024 sw64_num_contiguous_bits = 512 @@ -711,6 +714,7 @@ def make_smem_layout( cute.append(smem_tile_shape, num_stages), order=(0, 1, 2) if is_k_major else (1, 0, 2), ) + return cute.coalesce(smem_layout, target_profile=(1, 1, 1), loc=loc, ip=ip) @@ -1956,12 +1960,35 @@ def thrfrg_SFA( """Thread-fragment scale factor A tensor for SM120 block-scaled MMA. Implements the ThrFrg partitioning for scale factor A according to the - corresponding C++ code. + corresponding C++ code in cutlass/include/cute/atom/mma_traits_sm120.hpp: + SFALayout for SM120 MXF4 16x8x64 uses K=64, SM120 MXF8F6F4 16x8x32 uses + K=32; the stride pattern ``((_8,_0,_1), _16)`` is shared. """ assert cute.rank(sfa_tensor) >= 2 atom_shape_mnk = tiled_mma.shape_mnk - atom_sfa_layout = cute.make_layout(shape=((2, 2, 8), 64), stride=((8, 0, 1), 16)) + # K-dim of the warp-MMA atom: FP4 -> 64, FP8 -> 32 (per mma_traits_sm120.hpp). + # For FP8 (atom_K=32) where mma_nsf=1, wrap K in a 2-tuple ``(atom_K, 1)`` + # so the layout's K mode keeps its 2D structure and the resulting fragment + # has the same rank as the FP4 path. For FP4 (atom_K=64) the original 1D + # layout already produces a 2D K decomposition through SMEM-layout + # composition, so we keep the original shape. 
+ atom_K = atom_shape_mnk[2] + if atom_K == 32: + atom_sfa_layout = cute.make_layout( + shape=((2, 2, 8), (atom_K, 1)), + stride=((8, 0, 1), (16, 0)), + ) + elif atom_K == 64: + atom_sfa_layout = cute.make_layout( + shape=((2, 2, 8), atom_K), + stride=((8, 0, 1), 16), + ) + else: + raise ValueError( + f"thrfrg_SFA: unsupported atom_K={atom_K}; SM120 block-scaled atoms " + f"use atom_K=32 (mxf8/mxf8f6f4) or atom_K=64 (mxf4/mxf4nvf4)" + ) permutation_mnk = tiled_mma.permutation_mnk thr_layout_vmnk = tiled_mma.thr_layout_vmnk @@ -2000,12 +2027,32 @@ def thrfrg_SFB( """Thread-fragment scale factor B tensor for SM120 block-scaled MMA. Implements the ThrFrg partitioning for scale factor B according to the - corresponding C++ code. + corresponding C++ code in cutlass/include/cute/atom/mma_traits_sm120.hpp: + SFBLayout for SM120 MXF4 16x8x64 uses K=64, SM120 MXF8F6F4 16x8x32 uses + K=32; the stride pattern ``((_0,_1), _8)`` is shared. """ assert cute.rank(sfb_tensor) >= 2 atom_shape_mnk = tiled_mma.shape_mnk - atom_sfb_layout = cute.make_layout(shape=((4, 8), 64), stride=((0, 1), 8)) + # K-dim of the warp-MMA atom: FP4 -> 64, FP8 -> 32 (per mma_traits_sm120.hpp). + # See :func:`thrfrg_SFA` for the rationale behind the FP8-only + # ``(atom_K, 1)`` wrapping. 
+ atom_K = atom_shape_mnk[2] + if atom_K == 32: + atom_sfb_layout = cute.make_layout( + shape=((4, 8), (atom_K, 1)), + stride=((0, 1), (8, 0)), + ) + elif atom_K == 64: + atom_sfb_layout = cute.make_layout( + shape=((4, 8), atom_K), + stride=((0, 1), 8), + ) + else: + raise ValueError( + f"thrfrg_SFB: unsupported atom_K={atom_K}; SM120 block-scaled atoms " + f"use atom_K=32 (mxf8/mxf8f6f4) or atom_K=64 (mxf4/mxf4nvf4)" + ) permutation_mnk = tiled_mma.permutation_mnk thr_layout_vmnk = tiled_mma.thr_layout_vmnk diff --git a/python/CuTeDSL/cutlass/utils/block.py b/python/CuTeDSL/cutlass/utils/block.py new file mode 100644 index 000000000..e2c53696d --- /dev/null +++ b/python/CuTeDSL/cutlass/utils/block.py @@ -0,0 +1,248 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. 
from cutlass.cutlass_dsl import dsl_user_op, CuTeDSL

from cutlass.cute.typing import Tensor
from cutlass.cute.core import make_layout, filter_zeros
from cutlass.cute.atom import TiledCopy
from cutlass.cute.algorithm import copy
from cutlass.cute.nvgpu import tcgen05
from cutlass.cute.nvgpu.cpasync.copy import (
    TmaCopyOp,
    CopyBulkTensorTileG2SOp,
    CopyBulkTensorTileG2SMulticastOp,
)
from cutlass.cute.nvgpu.cpasync.helpers import tma_partition
from cutlass.cute.nvgpu.tcgen05.copy import _S2TCopyBase
from typing import Any, Optional
from cutlass._mlir import ir


# TMA copy ops whose source is the global tensor and whose destination is the
# shared-memory tensor. Both the barrier requirement and the src/dst role
# assignment in `_tma_copy_impl` are derived from this single tuple so the two
# checks cannot drift apart.
_G2S_TMA_OPS = (
    CopyBulkTensorTileG2SOp,
    CopyBulkTensorTileG2SMulticastOp,
)


def _check_required_args(
    required_args: list[str], kwargs: dict, condition: bool = True
) -> None:
    """Raise if any name in ``required_args`` is missing from ``kwargs``.

    :param required_args: Keyword names that must be present.
    :param kwargs: The keyword arguments received by the caller.
    :param condition: When ``False`` the check is skipped entirely.
    :raises ValueError: For the first required name that is absent.
    """
    if not condition:
        return
    for arg in required_args:
        if arg not in kwargs:
            raise ValueError(f"Argument {arg} is required.")


def _tma_copy_impl(
    tiled_copy: TiledCopy,
    src: Tensor,
    dst: Tensor,
    *,
    loc: Optional[ir.Location] = None,
    ip: Optional[ir.InsertionPoint] = None,
    **kwargs: Any,
) -> None:
    """Internal implementation for TMA-based block-level copy.

    Partitions ``src``/``dst`` with ``tma_partition`` (CTA coordinate 0 in a
    size-1 CTA layout), drops zero-stride modes with ``filter_zeros`` and
    forwards the result, together with the remaining ``kwargs``, to ``copy``.

    :param tiled_copy: TMA copy atom / tiled copy; ``tiled_copy.op`` selects
        the direction and the required keyword arguments.
    :param src: Source tensor.
    :param dst: Destination tensor.
    :param kwargs: Forwarded to ``copy``. ``tma_bar_ptr`` is required for G2S
        ops; ``tma_multicast`` is only accepted with the non-multicast
        ``CopyBulkTensorTileG2SOp``.
    :raises ValueError: If ``tma_multicast`` is combined with another op, or
        if ``tma_bar_ptr`` is missing for a G2S op.
    """
    #
    # Handle tma_multicast argument
    #
    if "tma_multicast" in kwargs:
        if not isinstance(tiled_copy.op, CopyBulkTensorTileG2SOp):
            raise ValueError(
                "block_copy with tma_multicast expects a non-multicast G2S TMA copy atom "
                "(CopyBulkTensorTileG2SOp) for compiler-driven multicast"
            )
        # Mark as coming from block API. The marker is added to a shallow copy
        # so the dict object owned by the caller is never modified.
        kwargs["tma_multicast"] = {
            **kwargs["tma_multicast"],
            "from_block_api": True,
        }

    #
    # Check if required arguments are provided
    #
    is_g2s = isinstance(tiled_copy.op, _G2S_TMA_OPS)
    _check_required_args(["tma_bar_ptr"], kwargs, is_g2s)

    #
    # TMA bulk tensor copies: partition via tma_partition
    #
    # For G2S ops (multicast atoms included) the shared-memory tensor is the
    # destination; for every other TMA op the roles are reversed.
    stensor = dst if is_g2s else src
    gtensor = src if is_g2s else dst
    cta_coord = 0
    cta_layout = make_layout(1, loc=loc, ip=ip)
    s_ptn, g_ptn = tma_partition(
        tiled_copy, cta_coord, cta_layout, stensor, gtensor, loc=loc, ip=ip
    )

    s_ptn = filter_zeros(s_ptn)
    g_ptn = filter_zeros(g_ptn)

    src_arg = g_ptn if is_g2s else s_ptn
    dst_arg = s_ptn if is_g2s else g_ptn
    return copy(tiled_copy, src_arg, dst_arg, loc=loc, ip=ip, **kwargs)


def _utccp_copy_impl(
    tiled_copy: TiledCopy,
    src: Tensor,
    dst: Tensor,
    *,
    loc: Optional[ir.Location] = None,
    ip: Optional[ir.InsertionPoint] = None,
    **kwargs: Any,
) -> None:
    """Internal implementation for S2T (SMEM to TMEM) copy operations.

    This function abstracts the S2T copy pattern which involves:
    1. Filtering zeros from src (smem) and dst (tmem) tensors
    2. Taking slice 0 of the supplied ``tiled_copy``
    3. Partitioning source and destination with that slice
    4. Getting the SMEM descriptor tensor for the partitioned source
    5. Executing the copy from the descriptor tensor to the partitioned
       destination

    :param tiled_copy: The tiled copy for S2T operations.
    :type tiled_copy: TiledCopy
    :param src: The source tensor in shared memory.
    :type src: Tensor
    :param dst: The destination tensor in TMEM.
    :type dst: Tensor
    :param kwargs: Forwarded unchanged to ``copy``.
    """
    # Filter zeros from src (smem) and dst (tmem) tensors
    src_compact = filter_zeros(src)
    dst_compact = filter_zeros(dst)

    # S2T has a single thread slice; election handled automatically in lowering
    thr_copy = tiled_copy.get_slice(0)

    # Partition source and destination
    src_partitioned = thr_copy.partition_S(src_compact, loc=loc, ip=ip)
    dst_partitioned = thr_copy.partition_D(dst_compact, loc=loc, ip=ip)

    # Get SMEM descriptor tensor for the source
    smem_desc_tensor = tcgen05.get_s2t_smem_desc_tensor(
        tiled_copy, src_partitioned, loc=loc, ip=ip
    )

    # Execute the copy
    return copy(tiled_copy, smem_desc_tensor, dst_partitioned, loc=loc, ip=ip, **kwargs)


@dsl_user_op
@CuTeDSL.jit
def block_copy(
    tiled_copy: TiledCopy,
    src: Tensor,
    dst: Tensor,
    *,
    loc: Optional[ir.Location] = None,
    ip: Optional[ir.InsertionPoint] = None,
    **kwargs: Any,
) -> None:
    """Performs a block-level copy operation.

    This function adds an abstraction layer over the `cute.copy` usage model by
    allowing operands with layouts shaped like tiles to be passed directly. This
    removes the need to manually partition. The API is designed to support multiple
    copy kinds; currently TMA-based copies and S2T (SMEM to TMEM) copies are supported.

    **TMA copy requirements**:

    When using TMA-based tiled copies, the ``src`` and ``dst`` tensors must have
    their first mode representing the TMATile, i.e. tensors shaped as ``(TMATile, Rest...)``.
    For a rank-2 tensor with logical layout (e.g., ``(TILE_M, TILE_N)``), call
    ``group_modes(tensor, 0, 2)`` before passing it to this function.

    **TMA multicast support**:

    For TMA-based copies that enable compiler-driven multicast in a 2D cluster, pass the
    ``tma_multicast`` argument as a dict with the following keys:

    - ``cluster_shape``: a tuple of 2 integers ``(cluster_m, cluster_n)``
      representing the **2D cluster shape**.
    - ``multicast_dim``: either ``"M"`` or ``"N"`` indicating which
      cluster dimension the multicast happens along.
    - ``use_2cta_mma_inst`` (optional): a ``bool`` indicating whether to
      use 2CTA MMA instructions when the loaded data is consumed by MMA.
      Defaults to ``False`` when omitted.

    The dict passed by the caller is not modified.

    **S2T (SMEM to TMEM) copy**:

    When using S2T copy operations (e.g., ``tcgen05.Cp4x32x128bOp``), the function
    automatically handles the filtering, partitioning, and SMEM descriptor creation.
    Pass a copy atom created with ``cute.make_copy_atom(tcgen05.Cp*Op(...), dtype)``
    along with source (SMEM) and destination (TMEM) tensors.

    Examples:

    .. code-block:: python

        # 1) TMA load without compiler-driven multicast
        # Note: group_modes is called to make the first mode TMATile
        block_copy(tma_atom_a, group_modes(tCgA_, 0, 2), group_modes(tCsA_, 0, 2),
                   tma_bar_ptr=tma_bar_ptr)

        # 2) TMA load with compiler-driven multicast along M in a (4,2) cluster
        block_copy(
            tma_atom_a,
            group_modes(tCgA_, 0, 2),
            group_modes(tCsA_, 0, 2),
            tma_multicast={
                "cluster_shape": (4, 2),
                "multicast_dim": "M",
                "use_2cta_mma_inst": True,
            },
            tma_bar_ptr=tma_bar_ptr,
        )

        # 3) TMA store
        # Note that `tma_bar_ptr` and CTA params (`cta_coord` and `cta_layout`)
        # are not needed for TMA store
        block_copy(tma_atom_c, group_modes(tCsC_, 0, 2), group_modes(tCgC_, 0, 2))

        # 4) S2T copy (SMEM to TMEM)
        copy_atom_s2t = cute.make_copy_atom(
            tcgen05.Cp4x32x128bOp(tcgen05.CtaGroup.ONE), sf_dtype
        )
        block_copy(copy_atom_s2t, tCsSF, tCtSF)

    :param tiled_copy: The tiled_copy or copy_atom of the current copy operation.
    :type tiled_copy: TiledCopy
    :param src: The source tensor.
    :type src: Tensor
    :param dst: The destination tensor.
    :type dst: Tensor
    :param tma_multicast: Optional dict for TMA multicast configuration with keys
                          ``cluster_shape``, ``multicast_dim``, and optionally
                          ``use_2cta_mma_inst``.
    :type tma_multicast: dict, optional
    :param tma_bar_ptr: Required when ``tiled_copy.op`` is a G2S TMA op.
    :raises ValueError: If a required keyword argument is missing or
        ``tma_multicast`` is used with an unsupported TMA op.
    :raises NotImplementedError: If ``tiled_copy.op`` is neither a TMA op nor
        an S2T op.
    """
    import cutlass  # local import to avoid circular import at module load time

    # Dispatch on the op type; the isinstance checks are wrapped in
    # cutlass.const_expr, as in the rest of this function's callers' style.
    if cutlass.const_expr(isinstance(tiled_copy.op, TmaCopyOp)):
        return _tma_copy_impl(tiled_copy, src, dst, loc=loc, ip=ip, **kwargs)
    elif cutlass.const_expr(isinstance(tiled_copy.op, _S2TCopyBase)):
        return _utccp_copy_impl(tiled_copy, src, dst, loc=loc, ip=ip, **kwargs)
    else:
        raise NotImplementedError(
            f"Copy op {type(tiled_copy.op).__name__} is not supported yet."
        )
+51,7 @@ setup_pycute.perform_setup() setup( name='cutlass_cppgen', - version='4.5.0', + version='4.5.1', description='CUTLASS Pythonic Interface', package_dir={'': '.'}, packages=[ diff --git a/python/setup_library.py b/python/setup_library.py index c88e3320c..229e7f145 100644 --- a/python/setup_library.py +++ b/python/setup_library.py @@ -36,7 +36,7 @@ from setuptools import setup def perform_setup(): setup( name='cutlass_library', - version='4.5.0', + version='4.5.1', description='CUTLASS library generation scripts', packages=['cutlass_library'] ) diff --git a/python/setup_pycute.py b/python/setup_pycute.py index 7892f866c..21ee8e059 100644 --- a/python/setup_pycute.py +++ b/python/setup_pycute.py @@ -36,7 +36,7 @@ from setuptools import setup def perform_setup(): setup( name='pycute', - version='4.5.0', + version='4.5.1', description='Python implementation of CuTe', packages=['pycute'], ) diff --git a/tools/util/include/cutlass/util/reference/device/tensor_fill.h b/tools/util/include/cutlass/util/reference/device/tensor_fill.h index 3f7b5b629..a027be935 100644 --- a/tools/util/include/cutlass/util/reference/device/tensor_fill.h +++ b/tools/util/include/cutlass/util/reference/device/tensor_fill.h @@ -554,6 +554,112 @@ struct RandomUniformFunc { } }; +/// Computes an exponent-uniform random distribution for UE8M0 scale factors. +template <> +struct RandomUniformFunc { + + using Element = float_ue8m0_t; + using FloatType = float; + + /// Parameters structure + struct Params { + + // + // Data members + // + + uint64_t seed; + int exp_min; + int exp_range; + int int_scale; ///< Retained for Params compatibility; exponent is integral. + double pnan; + int exclude_zero; ///< Retained for Params compatibility; unused for UE8M0. 
+ + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + CUTLASS_HOST_DEVICE + static int closest_log2_exp(FloatType value) { + using CUTLASS_CMATH_NAMESPACE :: log2; + using CUTLASS_CMATH_NAMESPACE :: nearbyint; + + // UE8M0 scale factors are strictly positive. Keep invalid lower bounds + // finite so callers using the generic [0, max] default do not produce NaN. + FloatType min_scale = FloatType(Element::bitcast(0x01)); + FloatType positive_value = value > FloatType(0) ? value : min_scale; + return int(nearbyint(log2(positive_value))); + } + + /// Construction of uniform RNG functor. + Params( + uint64_t seed_ = 0, + FloatType max_ = FloatType(1), + FloatType min_ = FloatType(0), + int int_scale_ = -1, + double pnan_ = 0, + int exclude_zero_ = -1 + ): + seed(seed_), + exp_min(closest_log2_exp(min_)), + exp_range(closest_log2_exp(max_) - closest_log2_exp(min_)), + int_scale(int_scale_), + pnan(pnan_), + exclude_zero(exclude_zero_) { + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + /// RNG state object + curandState_t rng_state; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + RandomUniformFunc(Params const ¶ms): params(params) { + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x; + + curand_init(params.seed, gtid, 0, &rng_state); + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + Element operator()() { + + // Draw random float in [0.0, 1.0] to determine if element should be NaN. + if constexpr (std::numeric_limits::has_quiet_NaN) { + if (params.pnan > 0 && (curand_uniform(&rng_state) < (params.pnan))) { + return Element(NAN); + } + } + + using CUTLASS_CMATH_NAMESPACE :: pow; + + FloatType rnd = random_uniform_float(&rng_state); + int exponent_count = params.exp_range + 1; + int exponent_offset = int(rnd * FloatType(exponent_count)); + exponent_offset = exponent_offset < exponent_count ? 
exponent_offset : exponent_count - 1; + FloatType exp = FloatType(params.exp_min + exponent_offset); + FloatType sf = FloatType(pow(FloatType(2), exp)); + + return Element(sf); + } +}; + /// Computes a random Gaussian distribution template struct RandomUniformFunc> { @@ -763,6 +869,16 @@ struct TensorFillRandomUniformFunc { } }; +template +struct UniformDistributionValueType { + using Type = typename RealType::Type; +}; + +template <> +struct UniformDistributionValueType { + using Type = float; +}; + } // namespace detail /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -774,8 +890,10 @@ template < void TensorFillRandomUniform( TensorView view, ///< destination tensor uint64_t seed, ///< seed for RNG - typename RealType::Type max = Element(1), ///< upper bound of distribution - typename RealType::Type min = Element(0), ///< lower bound for distribution + typename detail::UniformDistributionValueType::Type max = + typename detail::UniformDistributionValueType::Type(1), ///< upper bound of distribution + typename detail::UniformDistributionValueType::Type min = + typename detail::UniformDistributionValueType::Type(0), ///< lower bound for distribution int bits = -1, ///< If non-negative, specifies number of fractional bits that /// are not truncated to zero. Permits reducing precision of /// data. @@ -805,8 +923,8 @@ void BlockFillRandomUniform( Element *ptr, size_t capacity, uint64_t seed, ///< seed for RNG - typename RealType::Type max, ///< upper bound of distribution - typename RealType::Type min, ///< lower bound for distribution + typename detail::UniformDistributionValueType::Type max, ///< upper bound of distribution + typename detail::UniformDistributionValueType::Type min, ///< lower bound for distribution int bits = -1, ///< If non-negative, specifies number of fractional bits that /// are not truncated to zero. Permits reducing precision of /// data. 
@@ -1768,6 +1886,7 @@ void TensorFillRandom( ) { using Real = typename RealType::Type; + using UniformReal = typename detail::UniformDistributionValueType::Type; if (dist.kind == Distribution::Gaussian) { TensorFillRandomGaussian( @@ -1782,8 +1901,8 @@ void TensorFillRandom( TensorFillRandomUniform( view, seed, - static_cast(dist.uniform.max), - static_cast(dist.uniform.min), + static_cast(dist.uniform.max), + static_cast(dist.uniform.min), dist.int_scale, dist.uniform.pnan, exclude_zero, @@ -1830,6 +1949,7 @@ void BlockFillRandom( cudaStream_t stream = nullptr) { using Real = typename RealType::Type; + using UniformReal = typename detail::UniformDistributionValueType::Type; if (dist.kind == Distribution::Gaussian) { BlockFillRandomGaussian( @@ -1846,8 +1966,8 @@ void BlockFillRandom( ptr, capacity, seed, - static_cast(dist.uniform.max), - static_cast(dist.uniform.min), + static_cast(dist.uniform.max), + static_cast(dist.uniform.min), dist.int_scale, dist.uniform.pnan, stream); diff --git a/tools/util/include/cutlass/util/reference/host/tensor_fill.h b/tools/util/include/cutlass/util/reference/host/tensor_fill.h index 8c1bf7e7b..5967c7541 100644 --- a/tools/util/include/cutlass/util/reference/host/tensor_fill.h +++ b/tools/util/include/cutlass/util/reference/host/tensor_fill.h @@ -658,6 +658,74 @@ public: } }; +/// Computes an exponent-uniform random distribution for UE8M0 scale factors. +template <> +struct RandomUniformFunc { + + using Element = float_ue8m0_t; + + uint64_t seed; + int exp_min; + int exp_range; + int int_scale; ///< Retained for Params compatibility; exponent is integral. + + double pnan; +private: + using engine_type = std::mt19937; +public: + engine_type bernoulli_rnd; + std::bernoulli_distribution bernoulli_dist; + + bool exclude_zero; ///< Retained for Params compatibility; unused for UE8M0. + + static int closest_log2_exp(double value) { + // UE8M0 scale factors are strictly positive. 
Keep invalid lower bounds + // finite so callers using the generic [0, max] default do not produce NaN. + double min_scale = double(Element::bitcast(0x01)); + double positive_value = value > 0.0 ? value : min_scale; + return int(std::nearbyint(std::log2(positive_value))); + } + + RandomUniformFunc( + uint64_t seed_ = 0, + double max = 1, + double min_ = 0, + int int_scale_ = -1, + double pnan_ = 0, + bool exclude_zero_ = false + ): + seed(seed_), + exp_min(closest_log2_exp(min_)), + exp_range(closest_log2_exp(max) - closest_log2_exp(min_)), + int_scale(int_scale_), + pnan(pnan_), + bernoulli_rnd{static_cast(seed_)}, + bernoulli_dist(pnan_), + exclude_zero(exclude_zero_) + { + std::srand((unsigned)seed); + } + + /// Compute random value and update RNG state + Element operator()() { + + // Sample from NaN distribution. + if constexpr (std::numeric_limits::has_quiet_NaN) { + if (pnan > 0 && bernoulli_dist(bernoulli_rnd)) { + return Element(NAN); + } + } + + double rnd = double(std::rand()) / double(RAND_MAX); + int exponent_count = exp_range + 1; + int exponent_offset = int(rnd * double(exponent_count)); + exponent_offset = exponent_offset < exponent_count ? exponent_offset : exponent_count - 1; + double sf = std::pow(2.0, double(exp_min + exponent_offset)); + + return Element(sf); + } +}; + /// Partial specialization for initializing a complex value. template struct RandomUniformFunc > {