v4.5.1 update. (#3237)

This commit is contained in:
Junkai-Wu
2026-05-19 10:35:08 +08:00
committed by GitHub
parent e406c186f5
commit 982cb9e718
42 changed files with 6487 additions and 336 deletions

View File

@@ -2,6 +2,28 @@
# CUTLASS 4.x
## [4.5.1](https://github.com/NVIDIA/cutlass/releases/tag/v4.5.1) (2026-05-15)
### CuTe DSL
* Bug fixing and improvements
- Fixed following issues:
https://github.com/NVIDIA/cutlass/issues/3219
https://github.com/NVIDIA/cutlass/issues/3218
https://github.com/NVIDIA/cutlass/issues/3212
https://github.com/NVIDIA/cutlass/issues/3210
https://github.com/NVIDIA/cutlass/issues/3208
https://github.com/NVIDIA/cutlass/issues/3201
https://github.com/NVIDIA/cutlass/issues/3227
- Fixed Jax int64 stride divisibility issue
- Fixed issues for SM120 blockscaled MMAs
- added missing MXFP8MMAOP and MXF8F6F4MMAOP for sm120.
### CUTLASS C++
* Fix SM100 F8F6F4 SS MMA (1SM and 2SM) traits to use typed op templates.
* Add UE8M0 (uniform exponent distribution) initialization support in tensor fill utilities.
* Add `cvt.rn.bf16x2.e4m3x2` conversion instruction support to `numeric_conversion.h`.
* Update [example 93](https://github.com/NVIDIA/cutlass/tree/main/examples/93_blackwell_low_latency_gqa) with paged KV cache support for Blackwell low-latency GQA.
## [4.5.0](https://github.com/NVIDIA/cutlass/releases/tag/v4.5.0) (2026-05-01)
### CuTe DSL
@@ -20,7 +42,7 @@
- Improved source code correlation for profiling/debugging
- Fixed an aarch64 segfault issue with tvm-ffi
- Re-organization for CuTe DSL examples/tutorials for better discoverability
* More examples of authorizing peak-performance kernels
- MOE examles
- A new style of grouped-gemm that aligns to torch's grouped_mm and scaled_groued_mm interface.

View File

@@ -1,9 +1,9 @@
![ALT](./media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition")
# Overview
# CUTLASS 4.5.0
# CUTLASS 4.5.1
_CUTLASS 4.5.0 - May 2026_
_CUTLASS 4.5.1 - May 2026_
CUTLASS is a collection of abstractions for implementing high-performance matrix-matrix multiplication (GEMM)
and related computations at all levels and scales within CUDA. It incorporates strategies for
@@ -61,6 +61,17 @@ To get started quickly - please refer :
- Improved source code correlation for profiling/debugging
- Fixed an aarch64 segfault issue with tvm-ffi
- Re-organization for CuTe DSL examples/tutorials for better discoverability
- Fixed following issues:
https://github.com/NVIDIA/cutlass/issues/3219
https://github.com/NVIDIA/cutlass/issues/3218
https://github.com/NVIDIA/cutlass/issues/3212
https://github.com/NVIDIA/cutlass/issues/3210
https://github.com/NVIDIA/cutlass/issues/3208
https://github.com/NVIDIA/cutlass/issues/3201
https://github.com/NVIDIA/cutlass/issues/3227
- Fixed Jax int64 stride divisibility issue
- Fixed issues for SM120 blockscaled MMAs
- added missing MXFP8MMAOP and MXF8F6F4MMAOP for sm120.
* More examples of authorizing peak-performance kernels
- MOE examles
@@ -90,11 +101,13 @@ To get started quickly - please refer :
* Add [example 95](https://github.com/NVIDIA/cutlass/tree/main/examples/95_blackwell_gemm_green_context) to support green context SM partition
- Enables launching GEMM on stream with partial SM allocation.
* Add [Snake](https://github.com/NVIDIA/cutlass/blob/main/test/unit/epilogue/thread/activation.cu#L409) activation functor for EVT.
* Fix SM100 F8F6F4 SS MMA (1SM and 2SM) traits to use typed op templates.
* Add UE8M0 (uniform exponent distribution) initialization support in tensor fill utilities.
* Add `cvt.rn.bf16x2.e4m3x2` conversion instruction support to `numeric_conversion.h`.
* Update [example 93](https://github.com/NVIDIA/cutlass/tree/main/examples/93_blackwell_low_latency_gqa) with paged KV cache support for Blackwell low-latency GQA.
* Fix some kernel issues:
- Fix l2_capacity=0 handling in Blackwell SM100/SM120 kernel templates
- Fix CUTLASS clang build issues
- Fix atomicCAS read-modify-write loop in `ConstSubbyteReference`
- Replace `__nv_atomic_load_n` with `volatile` for CUDA 11.4 compatibility in subbyte reference
- Remove `PipelineStorage` shadowing in SM100 complex epilogue
- Fix build issue in SM90 epilogue fusion visitor TMA warpspecialized
* Fix some profiler issues:

View File

@@ -75,19 +75,17 @@ CUTE_HOST_DEVICE constexpr
auto
to_tiled_mma_sm100_ts(
TiledMMA<MMA_Atom<
MMA_Traits<SM100_MMA_F8F6F4_SS, a_type, b_type, c_type,
cute::C<M>, cute::C<N>,
cute::integral_constant<UMMA::Major, a_major>,
cute::integral_constant<UMMA::Major, b_major>,
cute::integral_constant<UMMA::ScaleIn, a_neg>,
cute::integral_constant<UMMA::ScaleIn, b_neg>>,
SM100_MMA_F8F6F4_SS<a_type, b_type, c_type,
M, N,
a_major, b_major,
a_neg, b_neg>,
TAs...>, TMs...>) {
return TiledMMA<MMA_Atom<
MMA_Traits<SM100_MMA_F8F6F4_TS<a_type, b_type, c_type,
SM100_MMA_F8F6F4_TS<a_type, b_type, c_type,
M, N,
a_major, b_major,
a_neg, b_neg, UMMA::Saturate::False>>,
a_neg, b_neg, UMMA::Saturate::False>,
TAs...>, TMs...>{};
}

View File

@@ -31,6 +31,9 @@ if (NOT MSVC AND CUTLASS_NVCC_ARCHS MATCHES "100a|100f|103a|103f")
cutlass_example_add_executable(
93_blackwell_low_latency_gqa
tgv_gqa.cu
common.cuh
tgv_gqa.cuh
tgv_gqa_paged.cuh
)
endif()

View File

@@ -0,0 +1,140 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cuda_runtime.h"
#include <cutlass/cutlass.h>
#include <cute/tensor.hpp>
#include <cute/arch/copy_sm90_desc.hpp>
#include <iostream>
#include <type_traits>
#ifndef gpuErrChk
#define gpuErrChk(ans) { gpuAssert2((ans), __FILE__, __LINE__); }
inline void gpuAssert2(cudaError_t code, const char *file, int line, bool abort=true) {
if (code != cudaSuccess) {
fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
if (abort) exit(code);
}
}
#endif
namespace TGV {
using namespace cute;
// Store value to remote shared memory in the cluster
CUTE_DEVICE void
store_shared_remote_f32(float value, uint32_t dsmem_addr, uint32_t remote_barrier_addr) {
asm volatile("st.async.shared::cluster.mbarrier::complete_tx::bytes.f32 [%0], %1, [%2];"
: : "r"(dsmem_addr), "f"(value), "r"(remote_barrier_addr));
}
// given a smem tensor, return the dsmem tensor for the given rank, the tensor addr is in smem addr space (not generic addr space)
template <class Tensor>
CUTE_DEVICE auto
get_dsmem_tensor(Tensor tensor, int rank) {
using T = typename decltype(tensor)::value_type;
// tensor.data().get() is the smem addr in the generic addr space, in the generic addr space a region is reserved for smem
// doing ld/st to this region of the generic addr space will be converted into ld.shared/st.shared to the smem addr space by the compiler
// the mapa (and many inline ptx) instruction's input and output addr are in the smem/dsmem addr space, so we need to explicitly convert from generic to shared addr space
uint32_t smem_addr = __cvta_generic_to_shared(tensor.data().get()); // smem addr space
// mapa to get the dsmem addr of this tensor in another CTA
uint32_t dsmem_addr = set_block_rank(smem_addr, rank); // smem addr space
return make_tensor(make_smem_ptr((T*)dsmem_addr), tensor.layout());
}
// copied from SM100::TMEM::LOAD::copy_unpack cutlass/include/cute/atom/copy_traits_sm100.hpp
// what it does is given a tmem address, load the data into rmem tensor with the given tcgen05.ld copy op
template <
class CopyOp,
class TD, class DLayout>
CUTLASS_DEVICE void
tmem_load(
uint32_t tmem_addr,
Tensor<TD,DLayout>& dst
) {
static_assert(is_rmem<TD>::value, "Expected RMEM dst.");
using RegTypeDst = typename remove_extent<typename CopyOp::DRegisters>::type;
Tensor rD = recast<RegTypeDst>(dst);
constexpr int RegNumDst = extent<typename CopyOp::DRegisters>::value;
CUTE_STATIC_ASSERT_V(size(rD) == Int<RegNumDst>{},
"The tcgen05.ld CopyOp's size does not match the destination tensor size.");
detail::explode(CopyOp::copy,
&tmem_addr, seq<0>{},
rD, make_seq<RegNumDst>{});
}
// copied from SM100::TMEM::STORE::copy_unpack cutlass/include/cute/atom/copy_traits_sm100.hpp
// what it does is given a tmem address, store the data in rmem tensor to the tmem address with the given tcgen05.st copy op
template <
class CopyOp,
class TS, class SLayout>
CUTLASS_DEVICE void
tmem_store(
Tensor<TS,SLayout>& src,
uint32_t tmem_addr
) {
static_assert(is_rmem<TS>::value, "Expected RMEM src.");
using RegTypeSrc = typename remove_extent<typename CopyOp::SRegisters>::type;
Tensor rS = recast<RegTypeSrc>(src);
constexpr int RegNumSrc = extent<typename CopyOp::SRegisters>::value;
CUTE_STATIC_ASSERT_V(size(rS) == Int<RegNumSrc>{},
"The tcgen05.st CopyOp's size does not match the source tensor size.");
detail::explode(CopyOp::copy,
rS, make_seq<RegNumSrc>{},
&tmem_addr, seq<0>{});
}
// issue cp.async to load 4 bytes (one int) from gmem to smem
CUTLASS_DEVICE void
cp_async(
int* gmem_addr,
int* smem_addr
) {
uint32_t smem_int_ptr = cute::cast_smem_ptr_to_uint(smem_addr);
asm volatile("cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n"
:: "r"(smem_int_ptr),
"l"(gmem_addr),
"n"(sizeof(int)));
}
} // namespace TGV

View File

@@ -1,6 +1,18 @@
# Blackwell Low Latency GQA
This example introduces TGV GQA, a CuTe C++-based Blackwell kernel optimized for low latency (low batch) generation phase GQA.
The example ships two variants:
- `tgv_gqa.cuh` — contiguous KV cache (default, `--mode 0`).
- `tgv_gqa_paged.cuh` — paged KV cache (`--mode 1`). Layout matches a typical paged-attention serving runtime: a combined KV
buffer of shape `(num_pages_total, 2, Page_Size, kvH, dH)` (BS folded into `num_pages_total`, mode-1 selects K vs V) plus a
`(kvL/Page_Size, BS)` page table that maps `(bs, per_batch_page_idx) -> physical page id`. The example harness builds the
page table host-side using a Fisher-Yates shuffle to stress non-contiguous mappings; replacing that with another placement
policy needs no kernel changes.
`common.cuh` holds the shared inline PTX wrappers used by both variants (`cp_async`, `tmem_load`/`tmem_store`,
`store_shared_remote_f32`, `get_dsmem_tensor`).
To compile and run this example:
```bash
# in cutlass top level directory
@@ -8,7 +20,10 @@ mkdir build && cd build
cmake .. -DCUTLASS_NVCC_ARCHS=100a -DCUTLASS_ENABLE_TESTS=OFF -DCUTLASS_ENABLE_EXAMPLES=ON -DCUTLASS_ENABLE_LIBRARY=OFF
cd examples/93_blackwell_low_latency_gqa
make
# contiguous KV cache (default)
./93_blackwell_low_latency_gqa --kvL 8192 --kvH 8 --qH 64 --BS 1
# paged KV cache
./93_blackwell_low_latency_gqa --kvL 8192 --kvH 8 --qH 64 --BS 1 --mode 1
```
Supported configs are:
@@ -20,11 +35,11 @@ Supported configs are:
- Flash decoding, configurable number of splits
- Cluster reduction with configurable number of reduction cta
- Attention sink and sliding window
- Paged KV cache (`--mode 1`)
Unsupported features are:
- Persistent schedule
- MTP
- Paged KV cache
## Kernel Design

View File

@@ -40,8 +40,14 @@
kvL is max_seq_len, seq_lens[BS] is the actual seq len for each batch
sinks has shape (qHLocal * kvH), i.e. one sink per q head
--mode 1 enables the paged variant:
Combined KV cache shape (num_pages_total, 2, Page_Size, kvH, dH), with BS folded into num_pages_total
and KV(_,0,_,_,_) = K, KV(_,1,_,_,_) = V.
page_table shape (kvL/Page_Size, BS), entry [p, bs] = physical page id of batch bs's per-batch page p.
Example usage:
$ ./examples/93_blackwell_low_latency_gqa --kvL 8192 --kvH 8 --qH 64 --BS 1
$ ./examples/93_blackwell_low_latency_gqa --kvL 8192 --kvH 8 --qH 64 --BS 1 --mode 1
*/
// Standard library includes
@@ -51,6 +57,8 @@
#include <cmath>
#include <iostream>
#include <ctime>
#include <numeric>
#include <utility>
#include <getopt.h>
#include <cuda_runtime.h>
@@ -68,7 +76,9 @@
// CuTe includes
#include <cute/tensor.hpp> // CuTe tensor implementation
#include "common.cuh"
#include "tgv_gqa.cuh"
#include "tgv_gqa_paged.cuh"
using namespace cute;
@@ -411,9 +421,23 @@ struct ProblemStride {
int stride_O_qL;
int stride_O_dH;
int stride_O_BS;
// Combined KV (paged mode) layout: (num_pages_total, 2, Page_Size, kvH, dH).
// BS is folded into num_pages_total = BS * kvL / Page_Size (no explicit BS mode).
// Mode 1 is the K/V selector; harness picks dH-innermost packed strides below.
int stride_KV_pages;
int stride_KV_KV;
int stride_KV_ps;
int stride_KV_kvH;
int stride_KV_dH;
// Page table (paged mode) layout: (kvL/Page_Size, BS), entry [p, bs] = physical page id of batch bs's per-batch
// page p. Mode 0 is innermost in memory (per-batch page id, stride 1); mode 1 advances by pages_per_batch across batches.
int stride_PT_p;
int stride_PT_BS;
};
ProblemStride make_gqa_stride(int kvH, int qHLocal, int qL, int kvL, int dH, int BS) {
ProblemStride make_gqa_stride(int kvH, int qHLocal, int qL, int kvL, int dH, int BS, int Page_Size) {
ProblemStride stride;
// Q shape ((qHLocal, qL), dH, kvH, BS), where dH is contiguous
@@ -446,6 +470,20 @@ ProblemStride make_gqa_stride(int kvH, int qHLocal, int qL, int kvL, int dH, int
stride.stride_O_dH = 1;
stride.stride_O_BS = kvH * qHLocal * dH * qL;
// Combined KV (paged) shape (num_pages_total, 2, Page_Size, kvH, dH), dH innermost contiguous.
// BS is folded into num_pages_total; the (bs_idx, per_batch_page_idx) -> physical page mapping lives in the
// page_table tensor (see stride_PT_*). Strides slowest -> fastest:
// num_pages_total, KV (K/V selector), Page_Size, kvH, dH.
stride.stride_KV_dH = 1;
stride.stride_KV_kvH = dH;
stride.stride_KV_ps = kvH * dH;
stride.stride_KV_KV = Page_Size * kvH * dH;
stride.stride_KV_pages = 2 * Page_Size * kvH * dH;
// Page table shape (kvL/Page_Size, BS); per-batch page idx is mode 0 (contiguous, stride 1), batch is mode 1 (advances by pages_per_batch).
stride.stride_PT_p = 1;
stride.stride_PT_BS = kvL / Page_Size;
return stride;
}
@@ -459,17 +497,25 @@ public:
static constexpr int CTA_qL = 1;
static constexpr int CTA_kvL = 128;
static constexpr int CTA_dH = 64;
// Page_Size only used by gqa_paged (mode 1). Page_Size must divide CTA_kvL; CTA_kvL/Page_Size = pages per CTA tile.
static constexpr int Page_Size = 32;
static constexpr int BMM1_DMA_Stage = 3;
static constexpr int BMM2_DMA_Stage = 3;
// Page-idx staging (mode 1 only). Num_Page_Idx_Per_Stage must be a multiple of CTA_kvL/Page_Size
// so a DMA stage's pages live in one pi stage. Page_Idx_Stage = pipeline depth on the page-idx side.
static constexpr int Page_Idx_Stage = 2;
static constexpr int Num_Page_Idx_Per_Stage = 8 * (CTA_kvL / Page_Size);
static constexpr int MaxSplits = 8;
static constexpr int NumReductionCTA = 8;
static constexpr bool NoSink = true;
static constexpr bool VarSeqLens = false;
private:
int kvH_, qHLocal_, qL_, kvL_, dH_, BS_;
float softmax_scale_;
ProblemStride stride_;
int sliding_window_size_;
int mode_; // 0: gqa, 1: gqa_paged
// Host vectors
thrust::host_vector<TypeQKV> host_Q_;
@@ -480,17 +526,82 @@ private:
thrust::host_vector<int> host_seq_lens_;
thrust::host_vector<TypeAcc> host_sinks_;
// Device vectors
// Device vectors
thrust::device_vector<TypeQKV> device_Q_;
thrust::device_vector<TypeQKV> device_K_;
thrust::device_vector<TypeQKV> device_V_;
// Combined KV cache for paged mode (mode_ == 1).
// Layout: (num_pages_total, 2, Page_Size, kvH, dH) with dH innermost; BS is folded into num_pages_total.
// The (bs, per_batch_page_idx) -> physical page mapping is held in device_page_table_ (built by
// build_random_page_table: a Fisher-Yates shuffle of [0, num_pages_total) sliced per batch).
// KV(_,0,_,_,_) is K, KV(_,1,_,_,_) is V.
thrust::device_vector<TypeQKV> device_KV_;
// Page table for paged mode. Logical shape (kvL/Page_Size, BS); entry [p, bs] is the physical page id for
// batch bs and per-batch page index p. Padded by Num_Page_Idx_Per_Stage ints at the tail so the device's
// last-pi-stage cp.async never reads past the allocation.
thrust::device_vector<int> device_page_table_;
thrust::device_vector<TypeO> device_O_;
thrust::device_vector<int> device_seq_lens_;
thrust::device_vector<TypeAcc> device_sinks_;
// Build the host-side page table: tail-padded buffer with shape (kvL/Page_Size, BS) and contents = a random
// per-batch slice of a Fisher-Yates shuffle of [0, num_pages_total). Each batch's slice length is
// seq_len[bs]/Page_Size, so total pages assigned = sum(seq_len_pages) <= num_pages_total. Tail padding
// (Num_Page_Idx_Per_Stage ints) covers the device's last-pi-stage cp.async OOB read.
thrust::host_vector<int> build_random_page_table(int pages_per_batch, int num_pages_total) {
thrust::host_vector<int> host_page_table(num_pages_total + Num_Page_Idx_Per_Stage, 0);
auto host_tensor_page_table = make_tensor(host_page_table.data(),
make_layout(make_shape(pages_per_batch, BS_),
make_stride(stride_.stride_PT_p, stride_.stride_PT_BS)));
std::vector<int> perm(num_pages_total);
std::iota(perm.begin(), perm.end(), 0);
for (int i = num_pages_total - 1; i > 0; --i) {
std::swap(perm[i], perm[rand() % (i + 1)]);
}
int perm_offset = 0;
for (int bs = 0; bs < BS_; ++bs) {
// ceil_div: a partial tail page (seq_len % Page_Size != 0) still needs a physical page assigned
// so the kernel's address-by-page lookup for positions [floor(seq_len/Page_Size)*Page_Size, seq_len)
// hits real packed K/V data. Positions past seq_len are masked by the kernel and don't contribute.
int seq_len_pages = cutlass::ceil_div(host_seq_lens_[bs], Page_Size);
for (int p = 0; p < seq_len_pages; ++p) {
host_tensor_page_table(p, bs) = perm[perm_offset + p];
}
perm_offset += seq_len_pages;
}
return host_page_table;
}
// Pack the flat per-batch K/V buffers into the combined paged KV layout, using the page table for the
// (bs, per_batch_page) -> physical page mapping. Bounded by seq_len_pages per batch so out-of-seq-len entries
// (which the page table doesn't populate) aren't dereferenced.
template <class HostTensorPageTable, class HostTensorK, class HostTensorV, class HostTensorKV>
void pack_combined_kv(HostTensorPageTable const& host_tensor_page_table,
HostTensorK const& host_tensor_K, HostTensorV const& host_tensor_V,
HostTensorKV& host_tensor_KV) {
for (int bs = 0; bs < BS_; ++bs) {
// ceil_div pages so the partial tail page (when seq_len isn't page-aligned) is packed.
// We pack the full Page_Size for the tail page; positions past seq_len carry whatever
// initialize_tensor wrote into the K/V buffers, but the kernel masks those out.
int seq_len_pages = cutlass::ceil_div(host_seq_lens_[bs], Page_Size);
for (int p = 0; p < seq_len_pages; ++p) {
int global_page = host_tensor_page_table(p, bs);
for (int ps = 0; ps < Page_Size; ++ps) {
int kvl = p * Page_Size + ps;
for (int kvh = 0; kvh < kvH_; ++kvh) {
for (int dh = 0; dh < dH_; ++dh) {
host_tensor_KV(global_page, 0, ps, kvh, dh) = host_tensor_K(kvl, dh, kvh, bs);
host_tensor_KV(global_page, 1, ps, kvh, dh) = host_tensor_V(dh, kvl, kvh, bs);
}
}
}
}
}
}
public:
GQATester(int kvH, int qH, int qL, int kvL, int dH, int BS, float softmax_scale, int sliding_window_size) :
kvH_(kvH), qHLocal_(qH / kvH), qL_(qL), kvL_(kvL), dH_(dH), BS_(BS), softmax_scale_(softmax_scale), sliding_window_size_(sliding_window_size) {
GQATester(int kvH, int qH, int qL, int kvL, int dH, int BS, float softmax_scale, int sliding_window_size, int mode = 0) :
kvH_(kvH), qHLocal_(qH / kvH), qL_(qL), kvL_(kvL), dH_(dH), BS_(BS), softmax_scale_(softmax_scale), sliding_window_size_(sliding_window_size), mode_(mode) {
assert(sliding_window_size_ >= 0);
// Allocate host memory
host_Q_.resize(kvH_ * qHLocal_ * qL_ * dH_ * BS_);
@@ -501,7 +612,7 @@ public:
host_seq_lens_.resize(BS_);
host_sinks_.resize(qHLocal_ * kvH_); // one sink per q head
stride_ = make_gqa_stride(kvH_, qHLocal_, qL_, kvL_, dH_, BS_);
stride_ = make_gqa_stride(kvH_, qHLocal_, qL_, kvL_, dH_, BS_, Page_Size);
// Create host CuTe tensors for initialization
auto host_tensor_Q = make_tensor(host_Q_.data(), TGV::gqa::make_layout_Q(kvH_, qHLocal_, qL_, dH_, BS_, stride_.stride_Q_kvH, stride_.stride_Q_qHLocal, stride_.stride_Q_qL, stride_.stride_Q_dH, stride_.stride_Q_BS));
@@ -513,10 +624,8 @@ public:
initialize_tensor(host_tensor_Q);
initialize_tensor(host_tensor_K);
initialize_tensor(host_tensor_V);
// have batch size matching kvL (i.e. max seq len) for now
bool test_var_seq_lens = false;
for (int i = 0; i < BS_; ++i) {
if (test_var_seq_lens) {
if (VarSeqLens) {
host_seq_lens_[i] = rand() % kvL_ + 1;
}
else { // all the batch have the same seq len
@@ -535,27 +644,81 @@ public:
device_seq_lens_ = host_seq_lens_;
device_sinks_ = host_sinks_;
// For paged mode, build the page_table and combined KV tensor on device. The harness owns the layouts:
// stride_KV_* and stride_PT_* were computed in make_gqa_stride above and are used both here (host pack)
// and downstream (passed into gqa_paged_host as stride args). Combined KV shape: (num_pages_total, 2,
// Page_Size, kvH, dH), BS folded into num_pages_total. We populate the page_table via
// build_random_page_table (Fisher-Yates shuffle of [0, num_pages_total) sliced per batch) to stress
// non-contiguous mappings.
if (mode_ == 1) {
assert(kvL_ % Page_Size == 0);
assert(kvL_ % CTA_kvL == 0);
int pages_per_batch = kvL_ / Page_Size;
int num_pages_total = BS_ * pages_per_batch;
auto host_page_table = build_random_page_table(pages_per_batch, num_pages_total);
auto host_tensor_page_table = make_tensor(host_page_table.data(),
make_layout(make_shape(pages_per_batch, BS_),
make_stride(stride_.stride_PT_p, stride_.stride_PT_BS)));
thrust::host_vector<TypeQKV> host_KV(num_pages_total * stride_.stride_KV_pages);
auto host_tensor_KV = make_tensor(host_KV.data(),
make_layout(make_shape(num_pages_total, 2, Page_Size, kvH_, dH_),
make_stride(stride_.stride_KV_pages, stride_.stride_KV_KV,
stride_.stride_KV_ps, stride_.stride_KV_kvH, stride_.stride_KV_dH)));
pack_combined_kv(host_tensor_page_table, host_tensor_K, host_tensor_V, host_tensor_KV);
device_KV_ = host_KV;
device_page_table_ = host_page_table;
}
gpuErrChk(cudaDeviceSynchronize());
}
void run_kernel(bool pdl, int pdl_count = -1, cudaStream_t stream = 0) {
TGV::gqa::gqa_host<
TypeQKV, TypeO, TypeAcc,
CTA_qHLocal, CTA_qL, CTA_kvL, CTA_dH,
BMM1_DMA_Stage, BMM2_DMA_Stage,
MaxSplits,
NumReductionCTA>(
device_K_.data().get(), device_Q_.data().get(), device_V_.data().get(), device_O_.data().get(),
device_seq_lens_.data().get(),
NoSink ? nullptr : device_sinks_.data().get(),
kvH_, qHLocal_, qL_, kvL_, dH_, BS_,
stride_.stride_K_kvH, stride_.stride_K_kvL, stride_.stride_K_dH, stride_.stride_K_BS,
stride_.stride_Q_kvH, stride_.stride_Q_qHLocal, stride_.stride_Q_qL, stride_.stride_Q_dH, stride_.stride_Q_BS,
stride_.stride_V_kvH, stride_.stride_V_kvL, stride_.stride_V_dH, stride_.stride_V_BS,
stride_.stride_O_kvH, stride_.stride_O_qHLocal, stride_.stride_O_qL, stride_.stride_O_dH, stride_.stride_O_BS,
softmax_scale_,
sliding_window_size_,
pdl, pdl_count, stream);
if (mode_ == 0) {
TGV::gqa::gqa_host<
TypeQKV, TypeO, TypeAcc,
CTA_qHLocal, CTA_qL, CTA_kvL, CTA_dH,
BMM1_DMA_Stage, BMM2_DMA_Stage,
MaxSplits,
NumReductionCTA>(
device_K_.data().get(), device_Q_.data().get(), device_V_.data().get(), device_O_.data().get(),
device_seq_lens_.data().get(),
NoSink ? nullptr : device_sinks_.data().get(),
kvH_, qHLocal_, qL_, kvL_, dH_, BS_,
stride_.stride_K_kvH, stride_.stride_K_kvL, stride_.stride_K_dH, stride_.stride_K_BS,
stride_.stride_Q_kvH, stride_.stride_Q_qHLocal, stride_.stride_Q_qL, stride_.stride_Q_dH, stride_.stride_Q_BS,
stride_.stride_V_kvH, stride_.stride_V_kvL, stride_.stride_V_dH, stride_.stride_V_BS,
stride_.stride_O_kvH, stride_.stride_O_qHLocal, stride_.stride_O_qL, stride_.stride_O_dH, stride_.stride_O_BS,
softmax_scale_,
sliding_window_size_,
pdl, pdl_count, stream);
}
else {
TGV::gqa_paged::gqa_paged_host<
TypeQKV, TypeO, TypeAcc,
CTA_qHLocal, CTA_qL, CTA_kvL, CTA_dH,
Page_Size,
BMM1_DMA_Stage, BMM2_DMA_Stage,
Page_Idx_Stage, Num_Page_Idx_Per_Stage,
MaxSplits,
NumReductionCTA>(
device_KV_.data().get(),
device_Q_.data().get(),
device_O_.data().get(),
NoSink ? nullptr : device_sinks_.data().get(),
device_seq_lens_.data().get(),
device_page_table_.data().get(),
kvH_, qHLocal_, qL_, kvL_, dH_, BS_,
stride_.stride_KV_pages, stride_.stride_KV_KV, stride_.stride_KV_ps, stride_.stride_KV_kvH, stride_.stride_KV_dH,
stride_.stride_Q_kvH, stride_.stride_Q_qHLocal, stride_.stride_Q_qL, stride_.stride_Q_dH, stride_.stride_Q_BS,
stride_.stride_O_kvH, stride_.stride_O_qHLocal, stride_.stride_O_qL, stride_.stride_O_dH, stride_.stride_O_BS,
stride_.stride_PT_p, stride_.stride_PT_BS,
softmax_scale_,
sliding_window_size_,
pdl, pdl_count, stream);
}
}
bool verify() {
@@ -604,16 +767,17 @@ public:
};
void benchmark_gqa(int kvH, int qH, int qL, int kvL, int dH, int BS, float softmax_scale, int sliding_window_size, bool pdl, int pdl_count, int num_testers = 4, int bench_iters = 100) {
void benchmark_gqa(int kvH, int qH, int qL, int kvL, int dH, int BS, float softmax_scale, int sliding_window_size, int mode, bool pdl, int pdl_count, int num_testers = 4, int bench_iters = 100) {
std::cout << "=== GQA Benchmark ===" << std::endl;
std::cout << "Problem size: kvH=" << kvH << ", qH=" << qH << ", qL=" << qL << ", kvL=" << kvL << ", dH=" << dH << ", BS=" << BS << ", sliding_window_size=" << sliding_window_size << std::endl;
std::cout << "Mode: " << mode << " (" << (mode == 0 ? "gqa" : "gqa_paged") << ")" << std::endl;
std::cout << "Number of testers (L2 thrashing): " << num_testers << std::endl;
std::cout << "Benchmark iterations: " << bench_iters << std::endl;
// Create multiple tester instances to thrash L2 cache
std::vector<std::unique_ptr<GQATester>> testers;
for (int i = 0; i < num_testers; ++i) {
testers.push_back(std::make_unique<GQATester>(kvH, qH, qL, kvL, dH, BS, softmax_scale, sliding_window_size));
testers.push_back(std::make_unique<GQATester>(kvH, qH, qL, kvL, dH, BS, softmax_scale, sliding_window_size, mode));
}
std::cout << "Created " << num_testers << " GQATester instances" << std::endl;
@@ -708,6 +872,8 @@ int main(int argc, char* argv[]) {
bool pdl = false;
// don't support it yet
int pdl_count = -1;
// 0: gqa (contiguous KV cache), 1: gqa_paged (paged KV cache)
int mode = 0;
// arg parsing
while (1) {
@@ -718,6 +884,7 @@ int main(int argc, char* argv[]) {
{"qL", required_argument, 0, 0},
{"BS", required_argument, 0, 0},
{"sliding_window_size", required_argument, 0, 0},
{"mode", required_argument, 0, 0},
{0, 0, 0, 0} // denote end of array
};
@@ -737,14 +904,17 @@ int main(int argc, char* argv[]) {
else if (option_index == 3) qL = atoi(optarg);
else if (option_index == 4) BS = atoi(optarg);
else if (option_index == 5) sliding_window_size = atoi(optarg);
else if (option_index == 6) mode = atoi(optarg);
break;
default: assert(false);
}
}
GQATester tester(kvH, qH, qL, kvL, dH, BS, softmax_scale, sliding_window_size);
GQATester tester(kvH, qH, qL, kvL, dH, BS, softmax_scale, sliding_window_size, mode);
bool success = tester.verify();
std::cout << "Correctness test " << (success ? "PASSED" : "FAILED") << std::endl;
std::cout << "Correctness test"
<< " mode=" << mode
<< " " << (success ? "PASSED" : "FAILED") << std::endl;
benchmark_gqa(kvH, qH, qL, kvL, dH, BS, softmax_scale, sliding_window_size, pdl, pdl_count, 100, 1000);
benchmark_gqa(kvH, qH, qL, kvL, dH, BS, softmax_scale, sliding_window_size, mode, pdl, pdl_count, 100, 1000);
}

View File

@@ -50,13 +50,7 @@
#include <cute/arch/tmem_allocator_sm100.hpp> // TMEM allocator for SM100
#include <cute/arch/copy_sm90_desc.hpp>
#define gpuErrChk(ans) { gpuAssert2((ans), __FILE__, __LINE__); }
inline void gpuAssert2(cudaError_t code, const char *file, int line, bool abort=true) {
if (code != cudaSuccess) {
fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
if (abort) exit(code);
}
}
#include "common.cuh"
namespace TGV {
namespace gqa {
@@ -93,27 +87,6 @@ acc = tl.dot(p.to(q.dtype), v) # [BLOCK_qL, BLOCK_dH]
tl.store(acc_ptrs, acc)
*/
// Store value to remote shared memory in the cluster
CUTE_DEVICE void
store_shared_remote_f32(float value, uint32_t dsmem_addr, uint32_t remote_barrier_addr) {
asm volatile("st.async.shared::cluster.mbarrier::complete_tx::bytes.f32 [%0], %1, [%2];"
: : "r"(dsmem_addr), "f"(value), "r"(remote_barrier_addr));
}
// given a smem tensor, return the dsmem tensor for the given rank, the tensor addr is in smem addr space (not generic addr space)
template <class Tensor>
CUTE_DEVICE auto
get_dsmem_tensor(Tensor tensor, int rank) {
using T = typename decltype(tensor)::value_type;
// tensor.data().get() is the smem addr in the generic addr space, in the generic addr space a region is reserved for smem
// doing ld/st to this region of the generic addr space will be converted into ld.shared/st.shared to the smem addr space by the compiler
// the mapa (and many inline ptx) instruction's input and output addr are in the smem/dsmem addr space, so we need to explicitly convert from generic to shared addr space
uint32_t smem_addr = __cvta_generic_to_shared(tensor.data().get()); // smem addr space
// mapa to get the dsmem addr of this tensor in another CTA
uint32_t dsmem_addr = set_block_rank(smem_addr, rank); // smem addr space
return make_tensor(make_smem_ptr((T*)dsmem_addr), tensor.layout());
}
// Helper methods to create layouts
// K always has the shape (kvL, dH, kvH, BS)
// kvH has to be the last dim because we do mma partitioning to the first two dims (M, K) in gemm terminology
@@ -243,6 +216,8 @@ struct SharedStorage {
alignas(16) cute::uint64_t bmm1_softmax_full_barrier; // Barrier between BMM1 and softmax, BMM1 tells softmax the tile is ready/full, softmax can start consuming it
alignas(16) cute::uint64_t bmm2_epilog_full_barrier; // Barrier between BMM2 and epilog, BMM2 tells epilog the tile is ready/full, epilog can start consuming it
alignas(16) cute::uint64_t tmem_allocation_result_barrier; // Barrier between MMA and epilog, sync tmem allocation/deallocation status between MMA and epilogue warps within CTA
// for cluster reduction
alignas(16) cute::uint64_t maxsum_mailbox_full_barrier; // barrier indicating the st.async of fmax and fsum are done
alignas(16) cute::uint64_t acc2_mailbox_full_barrier; // barrier indicating the st.async of acc2 are done
@@ -512,67 +487,6 @@ cta_reduce_transposed(
return acc;
}
// copied from SM100::TMEM::LOAD::copy_unpack cutlass/include/cute/atom/copy_traits_sm100.hpp
// what it does is given a tmem address, load the data into rmem tensor with the given tcgen05.ld copy op
template <
class CopyOp,
class TD, class DLayout>
CUTLASS_DEVICE void
tmem_load(
uint32_t tmem_addr,
Tensor<TD,DLayout>& dst
) {
static_assert(is_rmem<TD>::value, "Expected RMEM dst.");
using RegTypeDst = typename remove_extent<typename CopyOp::DRegisters>::type;
Tensor rD = recast<RegTypeDst>(dst);
constexpr int RegNumDst = extent<typename CopyOp::DRegisters>::value;
CUTE_STATIC_ASSERT_V(size(rD) == Int<RegNumDst>{},
"The tcgen05.ld CopyOp's size does not match the destination tensor size.");
detail::explode(CopyOp::copy,
&tmem_addr, seq<0>{},
rD, make_seq<RegNumDst>{});
}
// copied from SM100::TMEM::STORE::copy_unpack cutlass/include/cute/atom/copy_traits_sm100.hpp
// what it does is given a tmem address, store the data in rmem tensor to the tmem address with the given tcgen05.st copy op
template <
class CopyOp,
class TS, class SLayout>
CUTLASS_DEVICE void
tmem_store(
Tensor<TS,SLayout>& src,
uint32_t tmem_addr
) {
static_assert(is_rmem<TS>::value, "Expected RMEM src.");
using RegTypeSrc = typename remove_extent<typename CopyOp::SRegisters>::type;
Tensor rS = recast<RegTypeSrc>(src);
constexpr int RegNumSrc = extent<typename CopyOp::SRegisters>::value;
CUTE_STATIC_ASSERT_V(size(rS) == Int<RegNumSrc>{},
"The tcgen05.st CopyOp's size does not match the source tensor size.");
detail::explode(CopyOp::copy,
rS, make_seq<RegNumSrc>{},
&tmem_addr, seq<0>{});
}
// issue cp.async
CUTLASS_DEVICE void
cp_async(
int* gmem_addr,
int* smem_addr
) {
uint32_t smem_int_ptr = cute::cast_smem_ptr_to_uint(smem_addr);
asm volatile("cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n"
:: "r"(smem_int_ptr),
"l"(gmem_addr),
"n"(sizeof(int)));
}
// mapping between thread id (T) -> dH (row index of Acc2)
template <int CTA_dH>
CUTLASS_DEVICE auto
@@ -831,14 +745,13 @@ template <
class TiledBMM1,
class TiledBMM2,
int CTA_qHLocal, int CTA_qL, int CTA_kvL, int CTA_dH>
CUTLASS_DEVICE void
CUTLASS_DEVICE void
MMA_warp(
SharedStorage& shared_storage,
WorkTileInfo work_tile_info,
OTensor mO,
TiledBMM1 tiled_bmm1,
TiledBMM2 tiled_bmm2,
cutlass::arch::NamedBarrier& tmem_allocation_barrier
TiledBMM2 tiled_bmm2
) {
if (!work_tile_info.is_valid()) {
// we don't allocate tmem for invalid tiles but we still need to relinquish the allocation lock
@@ -900,7 +813,7 @@ MMA_warp(
tmem_allocator.allocate(Acc1_col_max, &shared_storage.bmm1_tmem_base_ptr);
tmem_allocator.allocate(Acc2_col_max, &shared_storage.bmm2_tmem_base_ptr);
// notify epilog warp that tmem allocation is complete
tmem_allocation_barrier.arrive();
arrive_barrier(shared_storage.tmem_allocation_result_barrier);
// relinquish early so that prefetch cta can be launched
tmem_allocator.release_allocation_lock();
@@ -963,7 +876,10 @@ MMA_warp(
MMA_gemm<decltype(tCrV), decltype(tCrP), decltype(tCtAcc2), TiledBMM2, '2', Print>(tCrV, tCrP, tCtAcc2, tiled_bmm2, bmm2_stage_idx, tma_bmm2_full_barrier_phase_bit, bmm2_accumulate, shared_storage.tmasoftmax_bmm2_full_barrier, shared_storage.tma_bmm2_empty_barrier, shared_storage.bmm2_epilog_full_barrier);
// wait for tmem deallocation signal from epilog warp
tmem_allocation_barrier.arrive_and_wait();
arrive_barrier(shared_storage.tmem_allocation_result_barrier);
// initial phase bit = 1 since it's already flipped once for tmem allocation
// it will flip to 0 when tmem can be deallocated, so we wait for old phase bit of 1
wait_barrier(shared_storage.tmem_allocation_result_barrier, 1);
// deallocate TMEM
tmem_allocator.free(shared_storage.bmm1_tmem_base_ptr, Acc1_col_max);
@@ -982,7 +898,7 @@ template <
int CTA_qHLocal, int CTA_qL, int CTA_kvL, int CTA_dH,
int NumReductionCTA,
bool NoSink>
CUTLASS_DEVICE void
CUTLASS_DEVICE void
EPILOG_warp(
SharedStorage& shared_storage,
WorkTileInfo work_tile_info,
@@ -994,7 +910,6 @@ EPILOG_warp(
TiledBMM2 tiled_bmm2,
float softmax_scale_log2,
int sliding_window_size,
cutlass::arch::NamedBarrier& tmem_allocation_barrier,
cutlass::arch::NamedBarrier& epilog_barrier,
int NumSplits,
int tid, // tid local to epilog warp
@@ -1024,7 +939,9 @@ EPILOG_warp(
// wait for tmem allocation in mma warp to complete, only do the wait for valid tiles
if (work_tile_info.is_valid()) {
tmem_allocation_barrier.arrive_and_wait();
arrive_barrier(shared_storage.tmem_allocation_result_barrier);
// initial phase bit = 0, it will flip to 1 when tmem is allocated, so we wait for old phase bit of 0
wait_barrier(shared_storage.tmem_allocation_result_barrier, 0);
}
// update tmem base ptr of the accumulator tensor
@@ -1344,11 +1261,15 @@ EPILOG_warp(
static_assert(MaxSplits <= 32, "we can use 1 warp to initialize mailbox");
// initialize mailbox tensor for fmax and fsum, when NumSplits < MaxSplits, we need to init those value to -inf and 0
// because we do reduction on the full tensor (of size MaxSplits) not just the valid splits
// there is no need to init sAcc2 because it will be scaled with beta which will be 0 for invalid splits
if (tid < MaxSplits) {
fill(sFmaxMailbox(tid, _), -cutlass::platform::numeric_limits<TypeAcc>::infinity());
clear(sFsumMailbox(tid, _));
}
// we also need to clear out acc2 mailbox for invalid splits, because acc2 value could be nan
// nan * 0 (beta) = nan, we still need to clear acc2
if (tid < CTA_dH) {
clear(sAcc2Mailbox(tid, _, _));
}
// ensure initialized smem is visible to the entire cluster
cutlass::arch::fence_view_async_shared();
@@ -1606,7 +1527,7 @@ EPILOG_warp(
}
// signal the mma warp tcgen05.ld of bmm2 is done, can start deallocate all tmem
tmem_allocation_barrier.arrive();
arrive_barrier(shared_storage.tmem_allocation_result_barrier);
}
// only NumReductionCTA number of reduction ctas will do the reduction
@@ -1734,15 +1655,16 @@ EPILOG_warp(
}*/
}
// K has shape (kvL, dH, kvH, BS)
// Q has shape ((qHLocal, qL), dH, kvH, BS)
// V has shape (dH, kvL, kvH, BS)
// O has shape (dH, (qHLocal, qL), kvH, BS)
// sinks has shape ((qHLocal, qL), kvH)
// seq_len has shape (BS)
// mK has shape (kvL, dH, kvH, BS)
// mQ has shape ((qHLocal, qL), dH, kvH, BS)
// mV has shape (dH, kvL, kvH, BS)
// mO has shape (dH, (qHLocal, qL), kvH, BS)
// mSink has shape ((qHLocal, qL), kvH)
// mSeqLens has shape (BS)
template <
class SharedStorage,
class KTensor, class QTensor, class VTensor, class OTensor, class SinkTensor,
class SeqLensTensor,
class TmaAtomK, class TmaAtomQ, class TmaAtomV,
class TiledBMM1, class TiledBMM2,
class TypeAcc,
@@ -1758,7 +1680,7 @@ gqa_device(
VTensor mV,
OTensor mO,
SinkTensor mSink,
int* seq_lens,
SeqLensTensor mSeqLens,
CUTE_GRID_CONSTANT TmaAtomK const tma_atom_K,
CUTE_GRID_CONSTANT TmaAtomQ const tma_atom_Q,
CUTE_GRID_CONSTANT TmaAtomV const tma_atom_V,
@@ -1782,7 +1704,7 @@ gqa_device(
int BS_idx = blockIdx.x / kvH;
// only thread 0 issues cp.async to load the seq_len
if (threadIdx.x == 0) {
cp_async(&seq_lens[BS_idx], &shared_storage.seq_len);
cp_async(&mSeqLens(BS_idx), &shared_storage.seq_len);
}
//if (threadIdx.x == 0) {
@@ -1805,15 +1727,13 @@ gqa_device(
cutlass::arch::detail::initialize_barrier_array_aligned<cutlass::arch::ClusterBarrier, 1>(&shared_storage.bmm1_softmax_full_barrier, /* arrival count */ 1);
// 1 thread (BMM2) arrive to signal epilog
cutlass::arch::detail::initialize_barrier_array_aligned<cutlass::arch::ClusterBarrier, 1>(&shared_storage.bmm2_epilog_full_barrier, /* arrival count */ 1);
// 32 (mma) + 128 (epilog) to signal tmem allocation/deallocation result
cutlass::arch::detail::initialize_barrier_array_aligned<cutlass::arch::ClusterBarrier, 1>(&shared_storage.tmem_allocation_result_barrier, /* arrival count */ 32 + 128);
// 1 thread (epilog) arrive to signal maxsum
cutlass::arch::detail::initialize_barrier_array_aligned<cutlass::arch::ClusterTransactionBarrier, 1>(&shared_storage.maxsum_mailbox_full_barrier, /* arrival count */ 1);
// 1 thread (epilog) arrive to signal acc2
cutlass::arch::detail::initialize_barrier_array_aligned<cutlass::arch::ClusterTransactionBarrier, 1>(&shared_storage.acc2_mailbox_full_barrier, /* arrival count */ 1);
}
// Sync tmem allocation status between MMA and softmax/epilogue warps within CTA
// 32 threads (mma) + 128 threads (epilog) to sync
// also used for tmem deallocation between epilog warps and mma warps within CTA
cutlass::arch::NamedBarrier tmem_allocation_barrier(32 + 128, cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier);
// syncing all threads (128) within 4 epilog warps
cutlass::arch::NamedBarrier epilog_barrier(128, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
@@ -1913,13 +1833,13 @@ gqa_device(
DMA_KV_warp<SharedStorage, WorkTileInfo, decltype(mK), decltype(mV), TmaAtomK, TmaAtomV, TiledBMM1, TiledBMM2, CTA_kvL, CTA_dH>(shared_storage, work_tile_info, mK, mV, &tma_atom_K, &tma_atom_V, tiled_bmm1, tiled_bmm2);
}
else if (warp_idx == 2) {
MMA_warp<SharedStorage, WorkTileInfo, decltype(mO), TiledBMM1, TiledBMM2, CTA_qHLocal, CTA_qL, CTA_kvL, CTA_dH>(shared_storage, work_tile_info, mO, tiled_bmm1, tiled_bmm2, tmem_allocation_barrier);
}
MMA_warp<SharedStorage, WorkTileInfo, decltype(mO), TiledBMM1, TiledBMM2, CTA_qHLocal, CTA_qL, CTA_kvL, CTA_dH>(shared_storage, work_tile_info, mO, tiled_bmm1, tiled_bmm2);
}
else if (warp_idx >= 4) {
// epilog tid is from 128 to 255, need to offset by -128 when getting the per thread slice
int tid = threadIdx.x - 128;
// warp_idx - 4 because epilog warp group starts from warp 4
EPILOG_warp<SharedStorage, WorkTileInfo, decltype(mK), decltype(mO), decltype(mSink), TiledBMM1, TiledBMM2, CTA_qHLocal, CTA_qL, CTA_kvL, CTA_dH, NumReductionCTA, NoSink>(shared_storage, work_tile_info, mK, mO, mSink, seq_len, tiled_bmm1, tiled_bmm2, softmax_scale_log2, sliding_window_size, tmem_allocation_barrier, epilog_barrier, NumSplits, tid, warp_idx - 4, rank);
EPILOG_warp<SharedStorage, WorkTileInfo, decltype(mK), decltype(mO), decltype(mSink), TiledBMM1, TiledBMM2, CTA_qHLocal, CTA_qL, CTA_kvL, CTA_dH, NumReductionCTA, NoSink>(shared_storage, work_tile_info, mK, mO, mSink, seq_len, tiled_bmm1, tiled_bmm2, softmax_scale_log2, sliding_window_size, epilog_barrier, NumSplits, tid, warp_idx - 4, rank);
}
__syncthreads();
@@ -1980,6 +1900,7 @@ void gqa_host(
Tensor mV = make_tensor(make_gmem_ptr(device_ptr_V), layout_V); // (dH, kvL, kvH, BS)
Tensor mO = make_tensor(make_gmem_ptr(device_ptr_O), layout_O); // (dH, (qHLocal, qL), kvH, BS)
Tensor mSink = make_tensor(make_gmem_ptr(device_ptr_sinks), layout_sinks); // ((qHLocal, qL), kvH)
Tensor mSeqLens = make_tensor(make_gmem_ptr(seq_lens), make_layout(make_shape(BS))); // (BS)
//printf("mK: "); print(mK); printf("\n");
//printf("mQ: "); print(mQ); printf("\n");
@@ -2174,6 +2095,7 @@ void gqa_host(
if (device_ptr_sinks != nullptr) {
auto *kernel_instance =
&gqa_device<SMEMStorage, decltype(mK_tma), decltype(mQ_tma), decltype(mV_tma), decltype(mO), decltype(mSink),
decltype(mSeqLens),
decltype(tma_atom_K), decltype(tma_atom_Q), decltype(tma_atom_V),
decltype(tiled_bmm1), decltype(tiled_bmm2),
TypeAcc,
@@ -2185,14 +2107,15 @@ void gqa_host(
// portable max cluster size is 8, but sm100a supports 16, need explicit opt in
gpuErrChk(cudaFuncSetAttribute(*kernel_instance, cudaFuncAttributeNonPortableClusterSizeAllowed, 1));
gpuErrChk(cudaLaunchKernelEx(&config, kernel_instance, mK_tma, mQ_tma, mV_tma, mO, mSink,
seq_lens,
mSeqLens,
tma_atom_K, tma_atom_Q, tma_atom_V,
tiled_bmm1, tiled_bmm2,
softmax_scale * Log2_E, sliding_window_size, pdl_count));
}
}
else {
auto *kernel_instance =
&gqa_device<SMEMStorage, decltype(mK_tma), decltype(mQ_tma), decltype(mV_tma), decltype(mO), decltype(mSink),
decltype(mSeqLens),
decltype(tma_atom_K), decltype(tma_atom_Q), decltype(tma_atom_V),
decltype(tiled_bmm1), decltype(tiled_bmm2),
TypeAcc,
@@ -2204,7 +2127,7 @@ void gqa_host(
// portable max cluster size is 8, but sm100a supports 16, need explicit opt in
gpuErrChk(cudaFuncSetAttribute(*kernel_instance, cudaFuncAttributeNonPortableClusterSizeAllowed, 1));
gpuErrChk(cudaLaunchKernelEx(&config, kernel_instance, mK_tma, mQ_tma, mV_tma, mO, mSink,
seq_lens,
mSeqLens,
tma_atom_K, tma_atom_Q, tma_atom_V,
tiled_bmm1, tiled_bmm2,
softmax_scale * Log2_E, sliding_window_size, pdl_count));

View File

@@ -0,0 +1,867 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cassert>
#include <cstdint>
#include <iostream>
#include <cstdio>
#include <cmath>
// Cutlass includes
#include <cutlass/cutlass.h>
#include <cutlass/numeric_types.h>
#include <cutlass/arch/barrier.h>
#include <cutlass/gemm/collective/builders/sm100_common.inl> // mma/smem selector, umma::major
#include <cutlass/numeric_conversion.h>
#include <cutlass/arch/grid_dependency_control.h>
// CuTe includes
#include <cute/tensor.hpp> // CuTe tensor implementation
#include <cute/arch/tmem_allocator_sm100.hpp> // TMEM allocator for SM100
#include <cute/arch/copy_sm90_desc.hpp>
#include <cute/arch/copy_sm90_tma.hpp>
#include "common.cuh"
#include "tgv_gqa.cuh" // reuse layout helpers, WorkTileInfo, TMA_copy, MMA_gemm, reductions, DMA_Q/MMA/EPILOG warps
// Grouped Query Attention (paged KV cache) with dual BMM + online softmax. 7 warps: 1 DMA_Q, 1 DMA_KV, 1 MMA,
// 4 EPILOG. Warp 3 is unused.
// Warp 0 (DMA_Q): Loads Q via TMA, single-stage. Reused from gqa namespace (gqa::DMA_Q_warp).
// Warp 1 (DMA_KV): Loads K then V via TMA, one cp.async.bulk per page. Issues its own cp.async of per-CTA-tile
// page indices into a smem staging buffer (lane-distributed, thread-local cp_async fence/wait).
// Defined locally in gqa_paged namespace.
// Warp 2 (MMA): Performs BMM1 (K@Q) and BMM2 (V@P). Reused from gqa namespace (gqa::MMA_warp).
// Warps 4-7 (EPILOG): Softmax partial max/sum warp reduction, cluster wide max/sum reduction, final flash-decode
// output with attention sink support. Reused from gqa namespace (gqa::EPILOG_warp).
// WorkTileInfo: Reused from gqa namespace. Attention-specific fields: BS_idx, kvH_idx, kvL_idx_start/end, dH_idx, qHLocal_idx, qL_idx.
// SharedStorage: Defined locally. Inherits from gqa::SharedStorage and adds the paged-only pieces: PageIdx smem
// buffer (double-buffered because K and V are offset by 1 tile -- around pi-stage boundaries they read different
// slots) and paged views (tensor_sK_paged / tensor_sV_paged / tensor_sPageIdx) over the inherited K/V smem buffers.
namespace TGV {
namespace gqa_paged {
using namespace cute;
// Symbols reused identically from TGV::gqa -- see tgv_gqa.cuh for definitions.
// Log2_E and WorkTileInfo are pulled in here; everything else is referenced explicitly via gqa:: at call sites.
using TGV::gqa::Log2_E;
using TGV::gqa::WorkTileInfo;
// The (bs_idx, per_batch_page_idx) -> physical page id mapping lives in a gmem page_table tensor of shape
// (kvL/Page_Size, BS), seeded by the host harness and fetched at runtime by DMA_KV_warp. To install a
// new placement policy, populate the page_table differently host-side; the kernel does not assume any structure.
// The shared memory buffers for Q, K, V matrices.
template <
class TypeQKV, // Tensor Q/K/V data type
class TypeAcc, // Tensor Acc data type
class KSmemLayout, // ((Mma_M, Mma_K), NumMma_M, NumMma_K, BMM1_DMA_Stage)
class KPagedSmemLayout,// (Page_Size, CTA_dH, NumPagePerCTATile, BMM1_DMA_Stage), same memory as KSmemLayout
class QSmemLayout, // ((Mma_N, Mma_K), NumMma_N, NumMma_K, 1)
class VSmemLayout, // ((Mma_M, Mma_K), NumMma_M, NumMma_K, BMM2_DMA_Stage)
class VPagedSmemLayout,// (CTA_dH, Page_Size, NumPagePerCTATile, BMM2_DMA_Stage), same memory as VSmemLayout
class SSmemLayout, // (CTA_kvL, (CTA_qHLocal, CTA_qL), 1) aka C matrix (M, N, 1) for bmm1
class PSmemLayout, // ((CTA_qHLocal, CTA_qL), CTA_kvL, 1) aka B matrix (N, K, 1) for bmm2
class WRSmemLayout, // (NumEpilogWarps, (CTA_qHLocal, CTA_qL)), WR stands for warp reduce
class MSMailboxSmemLayout,// (MaxSplits, CTA_qHLocal * CTA_qL / NumReductionCTA), MS stands max and sum
class Acc1SmemLayout, // (CTA_kvL, (CTA_qHLocal, CTA_qL), 1)
class Acc2MailboxSmemLayout, // (CTA_dH, CTA_qHLocal * CTA_qL / NumReductionCTA, MaxSplits)
class SinksSmemLayout, // (CTA_qHLocal * CTA_qL / NumReductionCTA)
class PageIdxSmemLayout, // ((NumPagePerCTATile, Tiles_Per_Pi_Stage), Page_Idx_Stage), int32 page indices staged by DMA_KV warp
int BMM1_DMA_Stage,
int BMM2_DMA_Stage,
int Page_Idx_Stage>
// Paged kernel adds two pieces on top of the plain gqa SharedStorage:
// - PageIdx smem buffer (DMA_KV staging, int32 page indices). Double-buffered: K and V are 1 tile apart, so
// around a pi-stage boundary they read different slots. DMA_KV writes the next slot via lane-distributed
// cp.async at K_t_in_stage==1 (when V has crossed into the current pi-stage and the next slot is free).
// - paged views of the inherited K/V smem buffers (same memory, different layout)
struct SharedStorage : TGV::gqa::SharedStorage<
TypeQKV, TypeAcc,
KSmemLayout, QSmemLayout, VSmemLayout, SSmemLayout, PSmemLayout,
WRSmemLayout, MSMailboxSmemLayout, Acc1SmemLayout,
Acc2MailboxSmemLayout, SinksSmemLayout,
BMM1_DMA_Stage, BMM2_DMA_Stage> {
// DMA_KV staging buffer for int32 page indices, layout ((NumPagePerCTATile, Tiles_Per_Pi_Stage), Page_Idx_Stage)
alignas(128) cute::ArrayEngine<int, cute::cosize_v<PageIdxSmemLayout>> PageIdx;
// alternative paged view of the same K smem buffer used by TMA: (Page_Size, CTA_dH, NumPagePerCTATile, BMM1_DMA_Stage)
CUTE_DEVICE constexpr auto tensor_sK_paged() { return make_tensor(make_smem_ptr(this->K.begin()), KPagedSmemLayout{}); }
// alternative paged view of the same V smem buffer used by TMA: (CTA_dH, Page_Size, NumPagePerCTATile, BMM2_DMA_Stage)
CUTE_DEVICE constexpr auto tensor_sV_paged() { return make_tensor(make_smem_ptr(this->V.begin()), VPagedSmemLayout{}); }
CUTE_DEVICE constexpr auto tensor_sPageIdx() { return make_tensor(make_smem_ptr(PageIdx.begin()), PageIdxSmemLayout{}); }
};
// paged TMA copy: like gqa::TMA_copy, but issues NumPagePerCTATile back-to-back copies that share one full barrier.
// empty/full barrier semantics still operate at CTA-tile granularity -- one slot covers the whole (CTA_kvL, CTA_dH) tile,
// transaction bytes counts all pages, and a single set_barrier_transaction_bytes arrives once per stage.
//
// Page indices are read from smem -- DMA_KV stages them from the gmem page_table via lane-distributed
// cp.async and a thread-local cp_async_fence/wait. The caller passes a smem pointer to the
// NumPagePerCTATile page indices owned by this tile; this function performs an ld.shared per page and uses the
// resulting global page id as the gmem coordinate for the TMA copy.
template <
class GTensor,
class STensor,
class TmaAtom,
char Name,
bool Print,
int DMA_Stage,
int NumPagePerCTATile>
CUTLASS_DEVICE void
TMA_copy_paged(
GTensor gTensor, // ((TMA, NumTma_K), Num_Page_Global)
STensor sTensor, // ((TMA, NumTma_K), NumPagePerCTATile, DMA_Stage)
int k_tile,
int const* page_idx_smem, // pointer to NumPagePerCTATile contiguous int page indices in smem (this tile's slice)
int& tma_mma_empty_barrier_phase_bit,
int tma_transaction_bytes, // total bytes for one CTA tile (NumPagePerCTATile pages)
TmaAtom const* tma_atom,
cute::uint64_t* tma_mma_full_barrier,
cute::uint64_t* tma_mma_empty_barrier
) {
// wait for the smem slot to be empty before issuing pages for the next CTA tile
wait_barrier(tma_mma_empty_barrier[k_tile % DMA_Stage], tma_mma_empty_barrier_phase_bit);
if constexpr (Print) {
if (elect_one_sync() && (blockIdx.x == 0) && (blockIdx.y == 0) && (blockIdx.z == 0)) {
printf("[DMA_%c] barrier empty, kblock %d (paged, %d pages)\n", Name, k_tile, NumPagePerCTATile);
}
}
if (elect_one_sync()) {
// single set_barrier_transaction_bytes accounts for all pages; later cute::copy calls don't add arrivals.
set_barrier_transaction_bytes(tma_mma_full_barrier[k_tile % DMA_Stage], tma_transaction_bytes);
CUTE_UNROLL
for (int p = 0; p < NumPagePerCTATile; p++) {
// per-page lookup: load global page id from smem and feed it as the gmem coordinate for one TMA copy.
int global_page = page_idx_smem[p];
copy(tma_atom->with(tma_mma_full_barrier[k_tile % DMA_Stage]),
gTensor(_, global_page),
sTensor(_, p, k_tile % DMA_Stage));
}
}
if ((k_tile % DMA_Stage) == (DMA_Stage - 1)) {
tma_mma_empty_barrier_phase_bit ^= 1;
}
}
// Lane-distributed cp.async of one pi-stage's worth (Num_Page_Idx_Per_Stage ints) of page indices from gmem
// page table to a smem slot. Each lane loads ceil(Num_Page_Idx_Per_Stage/32) ints; the LSU coalesces them.
// Caller pre-slices the BS mode of the page table (it is CTA-constant) and is responsible for
// cp_async_fence/wait + __syncwarp ordering after this returns.
template <
int Tiles_Per_Pi_Stage, int Num_Page_Idx_Per_Stage, int Page_Idx_Stage,
class GPageTableTensor, class SPageIdxTensor>
CUTLASS_DEVICE void
issue_pi_stage_cp_async(
GPageTableTensor gPageTable, // ((NumPagePerCTATile, kvL/CTA_kvL),), int gmem -- BS already sliced
SPageIdxTensor sPageIdx, // ((NumPagePerCTATile, Tiles_Per_Pi_Stage), Page_Idx_Stage), int smem
int kvL_idx_start, int pi_stage_idx) {
int tile_start = kvL_idx_start + pi_stage_idx * Tiles_Per_Pi_Stage;
int slot = pi_stage_idx % Page_Idx_Stage;
// Shift gPageTable's data pointer to the first page of this pi-stage; flat-index `i` walks consecutive
// (page-within-CTA-tile, CTA-tile) pairs in column-major order across Num_Page_Idx_Per_Stage ints.
auto gp = domain_offset(make_coord(make_coord(0, tile_start)), gPageTable);
// Use threadIdx.x-derived lane (cheap: threadIdx.x is already live in a register from earlier in the kernel).
int lane = cutlass::canonical_lane_idx();
CUTE_UNROLL
for (int i = lane; i < Num_Page_Idx_Per_Stage; i += 32) {
cp_async(&gp(i), &sPageIdx(i, slot));
}
}
// Paged DMA_KV warp: TMA-loads K then V, one cp.async.bulk per page. Page indices are staged into smem by this
// same warp via lane-distributed cp.async (no separate Read_Page_Idx warp, no transaction-barrier handshake).
// See the in-body comment ahead of the prolog for the page-idx fetch pipeline. Tiles_Per_Pi_Stage >= 2 is
// required so the K_t_in_stage==1 issue hook exists.
template <
class SharedStorage,
class WorkTileInfo,
class KTensor,
class VTensor,
class PageTableTensor,
class TmaAtomK,
class TmaAtomV,
int CTA_kvL, int CTA_dH, int Page_Size,
int Page_Idx_Stage, int Num_Page_Idx_Per_Stage>
CUTLASS_DEVICE void
DMA_KV_warp(
SharedStorage& shared_storage,
WorkTileInfo work_tile_info,
KTensor mK,
VTensor mV,
PageTableTensor mPageTable, // (kvL/Page_Size, BS), int gmem; underlying allocation tail-padded by Num_Page_Idx_Per_Stage ints. Mode-0 is partitioned by NumPagePerCTATile to access per-CTA-tile slices.
// when passing tma descriptor as function argument, it has to be pass by pointer/reference, if pass by value, it will live on local memory (i.e. the stack)
// and the tma unit cannot access the local memory, (even if it can, the local memory is strided by thread id, the content for each thread is strided)
TmaAtomK const* tma_atom_K,
TmaAtomV const* tma_atom_V) {
if (!work_tile_info.is_valid()) {
return;
}
// CTA_kvL % Page_Size == 0, Num_Page_Idx_Per_Stage % NumPagePerCTATile == 0, Page_Idx_Stage == 2,
// and Tiles_Per_Pi_Stage >= 2 are all checked in gqa_paged_host.
constexpr int NumPagePerCTATile = CTA_kvL / Page_Size;
constexpr int Tiles_Per_Pi_Stage = Num_Page_Idx_Per_Stage / NumPagePerCTATile;
// setup code for K tensor
// paged smem view of K, NOT mma partitioned, used purely for TMA: (Page_Size, CTA_dH, NumPagePerCTATile, BMM1_DMA_Stage)
// the same smem buffer is also accessible via tensor_sK() in MMA-partitioned form for the MMA warp.
Tensor sK_paged = shared_storage.tensor_sK_paged(); // (Page_Size, CTA_dH, NumPagePerCTATile, BMM1_DMA_Stage)
// mK has shape (Page_Size, dH, Num_Page_Global, kvH); BS is folded into Num_Page_Global so there is no BS mode.
// local_tile with a static (Page_Size, CTA_dH) tile_shape materializes the leading two modes statically (TMA
// partition requires static mode-0 sizes); dH==CTA_dH and Page_Size already matches so the divisions collapse.
// The page selection is done at TMA-issue time via the smem page_idx slice (staged by DMA_KV's own cp.async),
// not by indexing on a BS mode here.
Tensor gK = local_tile(mK, make_shape(Int<Page_Size>{}, Int<CTA_dH>{}),
make_coord(0, 0, _, work_tile_info.kvH_idx)); // (Page_Size, CTA_dH, Num_Page_Global)
// group modes [0,2) on both sK_paged and gK so the TMA box (Page_Size, CTA_dH) is mode 0; outer modes are
// (NumPagePerCTATile, BMM1_DMA_Stage) for smem and Num_Page_Global for gmem.
auto [tAgK, tAsK] = tma_partition(*tma_atom_K,
Int<0>{}, // cta_coord: 1x1 cluster
Layout<_1>{}, // cta_layout: CTA coord -> logical multicast id, no multicast, just identity layout
group_modes<0, 2>(sK_paged), group_modes<0, 2>(gK));
// tAsK: ((TMA, NumTma_K), NumPagePerCTATile, BMM1_DMA_Stage) -- 3 modes
// tAgK: ((TMA, NumTma_K), Num_Page_Global) -- 2 modes
// the shape of the TMA box is (Page_Size, CTA_dH)
/*if (elect_one_sync() && (blockIdx.x == 0) && (blockIdx.y == 0) && (blockIdx.z == 0)) {
printf("[PAGED-K] sK_paged:\t"); print(sK_paged); printf("\n"); // (Page_Size, CTA_dH, NumPagePerCTATile, BMM1_DMA_Stage)
printf("[PAGED-K] gK:\t"); print(gK); printf("\n"); // (Page_Size, CTA_dH, Num_Page_Global)
printf("[PAGED-K] tAgK:\t"); print(tAgK); printf("\n"); // ((TMA, NumTma_K), Num_Page_Global)
printf("[PAGED-K] tAsK:\t"); print(tAsK); printf("\n"); // ((TMA, NumTma_K), NumPagePerCTATile, BMM1_DMA_Stage)
}*/
// setup code for V tensor
// paged smem view of V, NOT mma partitioned, used purely for TMA: (CTA_dH, Page_Size, NumPagePerCTATile, BMM2_DMA_Stage)
// the same smem buffer is also accessible via tensor_sV() in MMA-partitioned form for the MMA warp.
Tensor sV_paged = shared_storage.tensor_sV_paged(); // (CTA_dH, Page_Size, NumPagePerCTATile, BMM2_DMA_Stage)
// mV has shape (dH, Page_Size, Num_Page_Global, kvH); BS is folded into Num_Page_Global so there is no BS mode.
Tensor gV = local_tile(mV, make_shape(Int<CTA_dH>{}, Int<Page_Size>{}),
make_coord(0, 0, _, work_tile_info.kvH_idx)); // (CTA_dH, Page_Size, Num_Page_Global)
// group modes [0,2) on both sV_paged and gV so the TMA box (CTA_dH, Page_Size) is mode 0; outer modes are
// (NumPagePerCTATile, BMM2_DMA_Stage) for smem and Num_Page_Global for gmem.
auto [tAgV, tAsV] = tma_partition(*tma_atom_V,
Int<0>{}, // cta_coord: 1x1 cluster
Layout<_1>{}, // cta_layout: CTA coord -> logical multicast id, no multicast, just identity layout
group_modes<0, 2>(sV_paged), group_modes<0, 2>(gV));
// tAsV: ((TMA, NumTma_K), NumPagePerCTATile, BMM2_DMA_Stage) -- 3 modes
// tAgV: ((TMA, NumTma_K), Num_Page_Global) -- 2 modes
// the shape of the TMA box is (CTA_dH, Page_Size)
/*if (elect_one_sync() && (blockIdx.x == 0) && (blockIdx.y == 0) && (blockIdx.z == 0)) {
printf("[PAGED-V] sV_paged:\t"); print(sV_paged); printf("\n"); // (CTA_dH, Page_Size, NumPagePerCTATile, BMM2_DMA_Stage)
printf("[PAGED-V] gV:\t"); print(gV); printf("\n"); // (CTA_dH, Page_Size, Num_Page_Global)
printf("[PAGED-V] tAgV:\t"); print(tAgV); printf("\n"); // ((TMA, NumTma_K), Num_Page_Global)
printf("[PAGED-V] tAsV:\t"); print(tAsV); printf("\n"); // ((TMA, NumTma_K), NumPagePerCTATile, BMM2_DMA_Stage)
}*/
// total K/V bytes per stage slot = all NumPagePerCTATile pages combined (they share one full barrier).
// Slice tAsK/tAsV at one stage to get all pages in that stage; size_in_bytes is implicit via sizeof(tensor_like).
int tma_K_transaction_bytes = sizeof(make_tensor_like(tAsK(_, _, 0)));
int tma_V_transaction_bytes = sizeof(make_tensor_like(tAsV(_, _, 0)));
int k_tile_count = work_tile_info.kvL_idx_end - work_tile_info.kvL_idx_start;
// BMM1_DMA_Stage = mode 3 of rank-4 sK_paged (Page_Size, CTA_dH, NumPage, Stage)
// BMM2_DMA_Stage = mode 3 of rank-4 sV_paged (CTA_dH, Page_Size, NumPage, Stage)
int constexpr BMM1_DMA_Stage = size<3>(sK_paged);
int constexpr BMM2_DMA_Stage = size<3>(sV_paged);
/*if (elect_one_sync() && (blockIdx.x == 0) && (blockIdx.y == 0) && (blockIdx.z == 0)) {
printf("[PAGED-K] tma_K_transaction_bytes (per CTA tile)=%d\n", tma_K_transaction_bytes);
printf("[PAGED-V] tma_V_transaction_bytes=%d, k_tile_count=%d, BMM1_DMA_Stage=%d, BMM2_DMA_Stage=%d\n", tma_V_transaction_bytes, k_tile_count, BMM1_DMA_Stage, BMM2_DMA_Stage);
}*/
// details of how the phase bit works is in DMA_Q_warp (in tgv_gqa.cuh)
int tma_bmm1_empty_barrier_phase_bit = 1;
int tma_bmm2_empty_barrier_phase_bit = 1;
int bmm1_k_tile = 0;
int bmm2_k_tile = 0;
bool constexpr Print = false;
// sPageIdx layout: ((NumPagePerCTATile, Tiles_Per_Pi_Stage), Page_Idx_Stage). For tile t (CTA-tile idx):
// t_in_stage = t % Tiles_Per_Pi_Stage; pi_slot = (t / Tiles_Per_Pi_Stage) % Page_Idx_Stage;
// page_idx_smem_ptr = &sPageIdx(make_coord(0, t_in_stage), pi_slot).
Tensor sPageIdx = shared_storage.tensor_sPageIdx();
// gPageTable view: per-CTA slice of the page table, then split mode 0 by NumPagePerCTATile so the
// (page-within-CTA-tile, CTA-tile-idx) split is explicit. BS is sliced upfront -- it's CTA-constant.
// Tail-padded host allocation lets us read up to Tiles_Per_Pi_Stage tiles past kvL_idx_end without OOB.
Tensor gPageTable = logical_divide(mPageTable(_, work_tile_info.BS_idx), Shape<Int<NumPagePerCTATile>>{}); // ((NumPagePerCTATile, kvL/CTA_kvL),)
// SOL order: K0, (K1, V0), (K2, V1), ...
//
// Page-idx fetch pipeline. The page-idx smem buffer is double-buffered (Page_Idx_Stage=2 slots) because
// K leads V by 1 tile -- around a pi-stage boundary K reads slot S%2 (stage S) while V is finishing the
// last tile of slot (S-1)%2 (stage S-1), so the two slots must coexist. Synchronization is thread-local
// cp_async_fence/cp_async_wait<0> + __syncwarp -- no cross-warp barrier.
// * Prolog: issue stage 0 -> slot 0, fence, wait, syncwarp, then K0 TMA.
// * Main loop K_t_in_stage == 0: drain the previously-issued cp.async (stage K_pi_stage's data). At this
// point at most one cp.async group is outstanding so wait<0> is correct.
// * Main loop K_t_in_stage == 1: pre-issue stage K_pi_stage+1 -> slot (K_pi_stage+1)%2 if it has tiles
// in this CTA's range. V crossed into pi-stage K_pi_stage at the previous iteration, so slot
// (K_pi_stage+1)%2 = (K_pi_stage-1)%2 is no longer being ld.shared'd and is safe to overwrite. The
// consumer's wait at the next pi-stage boundary (Tiles_Per_Pi_Stage-1 tiles later) hides the RTT.
// Prolog: stage 0 must be in slot 0 before K0 reads it.
{
issue_pi_stage_cp_async<Tiles_Per_Pi_Stage, Num_Page_Idx_Per_Stage, Page_Idx_Stage>(
gPageTable, sPageIdx, work_tile_info.kvL_idx_start, /*pi_stage_idx=*/0);
cp_async_fence();
cp_async_wait<0>();
__syncwarp();
// K0 prolog (no V yet).
TMA_copy_paged<decltype(tAgK), decltype(tAsK), TmaAtomK, 'K', Print, BMM1_DMA_Stage, NumPagePerCTATile>(tAgK, tAsK, bmm1_k_tile, &sPageIdx(make_coord(0, 0), 0), tma_bmm1_empty_barrier_phase_bit, tma_K_transaction_bytes, tma_atom_K, shared_storage.tma_bmm1_full_barrier, shared_storage.tma_bmm1_empty_barrier);
bmm1_k_tile++;
}
for (; bmm1_k_tile < k_tile_count; bmm1_k_tile++, bmm2_k_tile++) {
int K_t_in_stage = bmm1_k_tile % Tiles_Per_Pi_Stage;
int K_pi_stage = bmm1_k_tile / Tiles_Per_Pi_Stage;
int K_pi_slot = K_pi_stage % Page_Idx_Stage;
// Drain stage K_pi_stage's cp.async before K reads its slot.
if (K_t_in_stage == 0) {
cp_async_wait<0>();
__syncwarp();
}
// Pre-issue stage K_pi_stage+1 if there are still tiles in this CTA's range to consume from it.
if (K_t_in_stage == 1) {
int next_pi_stage = K_pi_stage + 1;
int next_tile_start = work_tile_info.kvL_idx_start + next_pi_stage * Tiles_Per_Pi_Stage;
if (next_tile_start < work_tile_info.kvL_idx_end) {
issue_pi_stage_cp_async<Tiles_Per_Pi_Stage, Num_Page_Idx_Per_Stage, Page_Idx_Stage>(
gPageTable, sPageIdx, work_tile_info.kvL_idx_start, next_pi_stage);
cp_async_fence();
}
}
TMA_copy_paged<decltype(tAgK), decltype(tAsK), TmaAtomK, 'K', Print, BMM1_DMA_Stage, NumPagePerCTATile>(tAgK, tAsK, bmm1_k_tile, &sPageIdx(make_coord(0, K_t_in_stage), K_pi_slot), tma_bmm1_empty_barrier_phase_bit, tma_K_transaction_bytes, tma_atom_K, shared_storage.tma_bmm1_full_barrier, shared_storage.tma_bmm1_empty_barrier);
int V_t_in_stage = bmm2_k_tile % Tiles_Per_Pi_Stage;
int V_pi_slot = (bmm2_k_tile / Tiles_Per_Pi_Stage) % Page_Idx_Stage;
TMA_copy_paged<decltype(tAgV), decltype(tAsV), TmaAtomV, 'V', Print, BMM2_DMA_Stage, NumPagePerCTATile>(tAgV, tAsV, bmm2_k_tile, &sPageIdx(make_coord(0, V_t_in_stage), V_pi_slot), tma_bmm2_empty_barrier_phase_bit, tma_V_transaction_bytes, tma_atom_V, shared_storage.tmasoftmax_bmm2_full_barrier, shared_storage.tma_bmm2_empty_barrier);
}
// V epilog (last V). Its slot is already populated by an earlier cp.async.
{
int V_t_in_stage = bmm2_k_tile % Tiles_Per_Pi_Stage;
int V_pi_slot = (bmm2_k_tile / Tiles_Per_Pi_Stage) % Page_Idx_Stage;
TMA_copy_paged<decltype(tAgV), decltype(tAsV), TmaAtomV, 'V', Print, BMM2_DMA_Stage, NumPagePerCTATile>(tAgV, tAsV, bmm2_k_tile, &sPageIdx(make_coord(0, V_t_in_stage), V_pi_slot), tma_bmm2_empty_barrier_phase_bit, tma_V_transaction_bytes, tma_atom_V, shared_storage.tmasoftmax_bmm2_full_barrier, shared_storage.tma_bmm2_empty_barrier);
}
cutlass::arch::launch_dependent_grids();
}
// mK has shape (Page_Size, dH, num_pages, kvH)
// mQ has shape ((qHLocal, qL), dH, kvH, BS)
// mV has shape (dH, Page_Size, num_pages, kvH)
// mO has shape (dH, (qHLocal, qL), kvH, BS)
// mSink has shape ((qHLocal, qL), kvH)
// mSeqLens has shape (BS)
// mPageTable has shape (kvL/Page_Size, BS), entry [p, bs] = physical page id of batch bs's per-batch page p
// (BS is folded into num_pages on the K/V side; mPageTable supplies the (bs, kvL) -> page mapping.)
template <
class SharedStorage,
class KTensor, class QTensor, class VTensor, class OTensor, class SinkTensor,
class SeqLensTensor, class PageTableTensor,
class TmaAtomK, class TmaAtomQ, class TmaAtomV,
class TiledBMM1, class TiledBMM2,
class TypeAcc,
int CTA_qHLocal, int CTA_qL, int CTA_kvL, int CTA_dH,
int Page_Size,
int BMM1_DMA_Stage, int BMM2_DMA_Stage,
int Page_Idx_Stage, int Num_Page_Idx_Per_Stage,
int MaxSplits, int NumReductionCTA,
bool NoSink>
__maxnreg__(128)
__global__ void
gqa_paged_device(
KTensor mK,
QTensor mQ,
VTensor mV,
OTensor mO,
SinkTensor mSink,
SeqLensTensor mSeqLens,
PageTableTensor mPageTable, // (kvL/Page_Size, BS); underlying allocation tail-padded by Num_Page_Idx_Per_Stage ints
CUTE_GRID_CONSTANT TmaAtomK const tma_atom_K,
CUTE_GRID_CONSTANT TmaAtomQ const tma_atom_Q,
CUTE_GRID_CONSTANT TmaAtomV const tma_atom_V,
TiledBMM1 tiled_bmm1,
TiledBMM2 tiled_bmm2,
float softmax_scale_log2,
int sliding_window_size,
int pdl_count
) {
// Allocate SMEM
extern __shared__ char shared_memory[];
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(shared_memory);
// WorkTileInfo, for non persistent static scheduler, cta id is the work tile info
// since loading the seq_lens is on the critical path of the prolog, we want to start it as soon as possible
// mK shape: (Page_Size, CTA_dH, Num_Page_Global, kvH) -- BS is folded into Num_Page_Global; kvH is mode 3.
int kvH = shape<3>(mK);
int BS_idx = blockIdx.x / kvH;
// only thread 0 issues cp.async to load the seq_len -- keep this as the first thing on the critical path.
if (threadIdx.x == 0) {
cp_async(&mSeqLens(BS_idx), &shared_storage.seq_len);
}
int warp_idx = cutlass::canonical_warp_idx_sync();
// barrier initialization, warp 0 does initialization
if (warp_idx == 0) {
// transaction barrier because tma arrive on it, 6 thread arrive: one for DMA_Q warp, one for DMA_KV (K fetch) warp, and 4 for softmax warp (tcgen05.ld Acc1 is done).
// For paged K, DMA_KV still arrives ONCE per stage (via set_barrier_transaction_bytes once with total bytes for all pages).
cutlass::arch::detail::initialize_barrier_array_aligned<cutlass::arch::ClusterTransactionBarrier, BMM1_DMA_Stage>(shared_storage.tma_bmm1_full_barrier, /* arrival count */ 6);
// 1 thread (BMM1) arrive to signal DMA_Q and DMA_KV (K fetch) warp
cutlass::arch::detail::initialize_barrier_array_aligned<cutlass::arch::ClusterBarrier, BMM1_DMA_Stage>(shared_storage.tma_bmm1_empty_barrier, /* arrival count */ 1);
// transaction barrier because tma arrive on it, 5 thread arrive: one for DMA_KV (V fetch) warp and 4 for softmax warp (S/P store)
cutlass::arch::detail::initialize_barrier_array_aligned<cutlass::arch::ClusterTransactionBarrier, BMM2_DMA_Stage>(shared_storage.tmasoftmax_bmm2_full_barrier, /* arrival count */ 5);
// 1 thread (BMM2) arrive to signal DMA_KV (V fetch) and softmax warp (P store)
cutlass::arch::detail::initialize_barrier_array_aligned<cutlass::arch::ClusterBarrier, BMM2_DMA_Stage>(shared_storage.tma_bmm2_empty_barrier, /* arrival count */ 1);
// 1 thread (BMM1) arrive to signal softmax
cutlass::arch::detail::initialize_barrier_array_aligned<cutlass::arch::ClusterBarrier, 1>(&shared_storage.bmm1_softmax_full_barrier, /* arrival count */ 1);
// 1 thread (BMM2) arrive to signal epilog
cutlass::arch::detail::initialize_barrier_array_aligned<cutlass::arch::ClusterBarrier, 1>(&shared_storage.bmm2_epilog_full_barrier, /* arrival count */ 1);
// 32 (mma) + 128 (epilog) to signal tmem allocation/deallocation result
cutlass::arch::detail::initialize_barrier_array_aligned<cutlass::arch::ClusterBarrier, 1>(&shared_storage.tmem_allocation_result_barrier, /* arrival count */ 32 + 128);
// 1 thread (epilog) arrive to signal maxsum
cutlass::arch::detail::initialize_barrier_array_aligned<cutlass::arch::ClusterTransactionBarrier, 1>(&shared_storage.maxsum_mailbox_full_barrier, /* arrival count */ 1);
// 1 thread (epilog) arrive to signal acc2
cutlass::arch::detail::initialize_barrier_array_aligned<cutlass::arch::ClusterTransactionBarrier, 1>(&shared_storage.acc2_mailbox_full_barrier, /* arrival count */ 1);
// No page-idx full/empty barriers: DMA_KV fetches its own page indices via cp.async with thread-local
// cp_async_fence/wait, so the cross-warp transaction-barrier handshake is gone.
}
// syncing all threads (128) within 4 epilog warps
cutlass::arch::NamedBarrier epilog_barrier(128, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
// barrier initialization needs to be visible to all warps
// defer it as late as possible to allow some thread divergence in prolog
cutlass::arch::fence_barrier_init();
#if 0
// this will have a membar.gpu to ensure dsmem write visibility within the entire cluster, because there isn't a membar.cluster
// membar.gpu is 0.2us
cluster_sync();
#else
// the alternative is to use proper fences
// at the ptx level, fence.mbarrier_init.release.cluster act as a release fence (in cluster scope) for mbarrier init op
cutlass::arch::fence_barrier_init();
// thread 0 waits for its previously issued cp.async to complete
// here we overlap the cp.async with the barrier initialization as much as possible
if (threadIdx.x == 0) {
cp_async_fence();
cp_async_wait<0>();
}
// the cluster sync serves two purposes:
// 1. it waits for all threads in the cluster to see the barrier initialization
// 2. it waits for all threads in the cta to see the cp.async result in smem
cluster_arrive_relaxed();
cluster_wait();
#endif
// bid.y includes rasterization of qHLocal and qL
int qHLocal = shape<0>(shape<0>(mQ));
int num_qHLocal = cutlass::ceil_div(qHLocal, CTA_qHLocal);
// bid.z is the split id for kvL split, try to evenly distribute the kvL blocks to every CTA
int rank = blockIdx.z;
int seq_len = shared_storage.seq_len;
// Sliding window optimization: when enabled, we only process tokens in range
// [seq_len - sliding_window_size, seq_len). To simplify tile distribution,
// we align the skip boundary to CTA_kvL tiles.
int workload_seq_len = seq_len;
int seq_len_skip_offset = 0;
if (sliding_window_size > 0 && sliding_window_size < seq_len) {
int unaligned_skip = seq_len - sliding_window_size;
seq_len_skip_offset = (unaligned_skip / CTA_kvL) * CTA_kvL;
workload_seq_len = seq_len - seq_len_skip_offset;
}
int NumKVTiles = cutlass::ceil_div(workload_seq_len, CTA_kvL);
int NumSplits = cute::min(MaxSplits, NumKVTiles);
int kvL_tile_count_per_cta = NumKVTiles / MaxSplits;
int kvL_tile_count_remainder = NumKVTiles % MaxSplits;
int kvL_tile_count = kvL_tile_count_per_cta + (rank < kvL_tile_count_remainder ? 1 : 0);
int kvL_tile_count_skip_offset = seq_len_skip_offset / CTA_kvL;
int kvL_tile_count_start = rank * kvL_tile_count_per_cta + (rank < kvL_tile_count_remainder ? rank : kvL_tile_count_remainder) + kvL_tile_count_skip_offset;
int kvL_tile_count_end = kvL_tile_count_start + kvL_tile_count;
WorkTileInfo work_tile_info {
.BS_idx = (int32_t)BS_idx,
.kvH_idx = (int32_t)(blockIdx.x % kvH),
.kvL_idx_start = (int32_t)kvL_tile_count_start,
.kvL_idx_end = (int32_t)kvL_tile_count_end,
.dH_idx = 0, // no dH tiling for bmm2
.qHLocal_idx = (int32_t)(blockIdx.y % num_qHLocal),
.qL_idx = (int32_t)(blockIdx.y / num_qHLocal),
.is_valid_tile = (kvL_tile_count > 0)
};
if (warp_idx == 0) {
gqa::DMA_Q_warp<SharedStorage, WorkTileInfo, decltype(mQ), TmaAtomQ, TiledBMM1, CTA_qHLocal, CTA_qL, CTA_dH>(shared_storage, work_tile_info, mQ, &tma_atom_Q, tiled_bmm1);
}
else if (warp_idx == 1) {
DMA_KV_warp<SharedStorage, WorkTileInfo, decltype(mK), decltype(mV), decltype(mPageTable), TmaAtomK, TmaAtomV, CTA_kvL, CTA_dH, Page_Size, Page_Idx_Stage, Num_Page_Idx_Per_Stage>(shared_storage, work_tile_info, mK, mV, mPageTable, &tma_atom_K, &tma_atom_V);
}
else if (warp_idx == 2) {
gqa::MMA_warp<SharedStorage, WorkTileInfo, decltype(mO), TiledBMM1, TiledBMM2, CTA_qHLocal, CTA_qL, CTA_kvL, CTA_dH>(shared_storage, work_tile_info, mO, tiled_bmm1, tiled_bmm2);
}
// warp_idx == 3 is unused (the previous Read_Page_Idx warp was folded into DMA_KV).
else if (warp_idx >= 4) {
// epilog tid is from 128 to 255, need to offset by -128 when getting the per thread slice
int tid = threadIdx.x - 128;
// EPILOG_warp only uses mK for shape(mK) (to build a kvL coord predicate). mK on the paged kernel has shape
// (Page_Size, CTA_dH, Num_Page_Global, kvH) with BS folded into Num_Page_Global, so we build a coord-only
// identity tensor with the original (kvL, dH, kvH, BS) shape. BS comes straight from mO (rank 4, mode 3).
// kvL is recovered from mPageTable shape (kvL/Page_Size, BS); Page_Size is a compile-time constant so this
// is a constant multiply (no integer divide), unlike deriving from shape<2>(mK)/BS which costs ~3%.
int dH = shape<1>(mK);
int BS = shape<3>(mO);
int kvL = static_cast<int>(shape<0>(mPageTable)) * Page_Size;
auto mK_coord = make_identity_tensor(make_shape(kvL, dH, kvH, BS));
// warp_idx - 4 because epilog warp group starts from warp 4
gqa::EPILOG_warp<SharedStorage, WorkTileInfo, decltype(mK_coord), decltype(mO), decltype(mSink), TiledBMM1, TiledBMM2, CTA_qHLocal, CTA_qL, CTA_kvL, CTA_dH, NumReductionCTA, NoSink>(shared_storage, work_tile_info, mK_coord, mO, mSink, seq_len, tiled_bmm1, tiled_bmm2, softmax_scale_log2, sliding_window_size, epilog_barrier, NumSplits, tid, warp_idx - 4, rank);
}
__syncthreads();
}
// KV has shape (num_pages_total, 2, Page_Size, kvH, dH); KV(_,0,...) is K, KV(_,1,...) is V; BS is folded into num_pages_total
// Q has shape ((qHLocal, qL), dH, kvH, BS)
// O has shape (dH, (qHLocal, qL), kvH, BS)
// sinks has shape (qHLocal * kvH), i.e. one sink per q head, when device_ptr_sinks is nullptr, it's disabled
// seq_lens has shape (BS); kvL is max_seq_len, seq_lens[bs] is the actual seq len for batch bs
// page_table has shape (kvL/Page_Size, BS)
// sliding_window_size is the size of the sliding window, when it's 0, it's disabled
template<
class TypeQKV, class TypeO, class TypeAcc,
int CTA_qHLocal, int CTA_qL, int CTA_kvL, int CTA_dH,
int Page_Size,
int BMM1_DMA_Stage, int BMM2_DMA_Stage,
int Page_Idx_Stage, int Num_Page_Idx_Per_Stage,
int MaxSplits, int NumReductionCTA>
void gqa_paged_host(
TypeQKV* device_ptr_KV,
TypeQKV* device_ptr_Q,
TypeO* device_ptr_O,
TypeAcc* device_ptr_sinks,
int* seq_lens,
int* device_ptr_page_table, // (kvL/Page_Size, BS); underlying allocation must be tail-padded by Num_Page_Idx_Per_Stage ints
int kvH, int qHLocal, int qL, int kvL, int dH, int BS,
int stride_KV_pages, int stride_KV_KV, int stride_KV_ps, int stride_KV_kvH, int stride_KV_dH,
int stride_Q_kvH, int stride_Q_qHLocal, int stride_Q_qL, int stride_Q_dH, int stride_Q_BS,
int stride_O_kvH, int stride_O_qHLocal, int stride_O_qL, int stride_O_dH, int stride_O_BS,
int stride_PT_p, int stride_PT_BS,
float softmax_scale,
int sliding_window_size,
bool pdl, int pdl_count = -1,
cudaStream_t stream = 0
) {
assert(kvL % Page_Size == 0);
int num_pages_total = BS * (kvL / Page_Size);
// Reconstruct the combined KV gmem tensor exactly as the harness laid it out: shape
// (num_pages_total, 2, Page_Size, kvH, dH) with the 5 strides supplied by the caller.
auto layout_KV = make_layout(
make_shape(num_pages_total, Int<2>{}, Int<Page_Size>{}, kvH, dH),
make_stride(stride_KV_pages, stride_KV_KV, stride_KV_ps, stride_KV_kvH, stride_KV_dH));
Tensor mKV = make_tensor(make_gmem_ptr(device_ptr_KV), layout_KV); // (num_pages_total, 2, Page_Size, kvH, dH)
// Slice on the K/V mode (=1). Each slice has shape (num_pages_total, Page_Size, kvH, dH); the kernel expects the
// modes in MMA order, so we permute via select<...>: K wants (Page_Size, dH, num_pages_total, kvH) -> indices
// (1,3,0,2); V wants (dH, Page_Size, num_pages_total, kvH) -> indices (3,1,0,2).
Tensor mKV_K = mKV(_, 0, _, _, _);
Tensor mKV_V = mKV(_, 1, _, _, _);
Tensor mK = make_tensor(mKV_K.data(), select<1, 3, 0, 2>(mKV_K.layout())); // (Page_Size, dH, num_pages_total, kvH)
Tensor mV = make_tensor(mKV_V.data(), select<3, 1, 0, 2>(mKV_V.layout())); // (dH, Page_Size, num_pages_total, kvH)
Layout layout_Q = gqa::make_layout_Q(kvH, qHLocal, qL, dH, BS, stride_Q_kvH, stride_Q_qHLocal, stride_Q_qL, stride_Q_dH, stride_Q_BS);
Layout layout_O = gqa::make_layout_O(kvH, qHLocal, qL, dH, BS, stride_O_kvH, stride_O_qHLocal, stride_O_qL, stride_O_dH, stride_O_BS);
Layout layout_sinks = gqa::make_layout_sinks(qHLocal, qL, kvH);
// Page table tensor as the harness owns it: rank-2 shape (kvL/Page_Size, BS), strides supplied by the caller.
// DMA_KV warp on the device side partitions mode-0 via logical_divide(_, Shape<NumPagePerCTATile>) into
// ((NumPagePerCTATile, MaxNumKVTiles), BS) -- the (page-within-CTA-tile, CTA-tile-idx, batch) view it actually
// indexes. Keeping the gmem-side layout flat per batch lets the harness pass any contiguous or page-strided table.
// (DMA_KV uses 4-byte cp.async, so no cp.async.bulk-style 16B alignment requirements on bytes/base/stride.)
auto layout_PageTable = make_layout(
make_shape(kvL / Page_Size, BS),
make_stride(stride_PT_p, stride_PT_BS));
Tensor mPageTable = make_tensor(make_gmem_ptr(device_ptr_page_table), layout_PageTable); // (kvL/Page_Size, BS)
// how we handle oob:
// oob for K, Q, V are handled by TMA
// oob for O is explicitly handled by predicate in the epilog since it uses simple st.global epilog
// we partition kvL with tile size of CTA_kvL, and we evenly distribute the kvL blocks to MaxSplits number of cta in the cluster
assert(NumReductionCTA <= MaxSplits);
static_assert(((CTA_qHLocal * CTA_qL) % NumReductionCTA) == 0, "each reduction cta must have even number of q tokens");
// mK and mV are constructed above by slicing+permuting the combined mKV tensor.
Tensor mQ = make_tensor(make_gmem_ptr(device_ptr_Q), layout_Q); // ((qHLocal, qL), dH, kvH, BS)
Tensor mO = make_tensor(make_gmem_ptr(device_ptr_O), layout_O); // (dH, (qHLocal, qL), kvH, BS)
Tensor mSink = make_tensor(make_gmem_ptr(device_ptr_sinks), layout_sinks); // ((qHLocal, qL), kvH)
Tensor mSeqLens = make_tensor(make_gmem_ptr(seq_lens), make_layout(make_shape(BS))); // (BS)
static_assert(CTA_kvL == 128, "BMM1's MMA_M needs to be 128 for tcgen05.ld->softmax");
static_assert(((CTA_qHLocal * CTA_qL) % 8) == 0, "BMM1's MMA_N needs to be divisible by 8 for tcgen05.mma");
assert(dH == CTA_dH); // bmm1 only has 1 kblock (i.e. 1 Q tile), bmm2 deal with all dH for now, in the foreseable future this is the hardest constraint to lift
static_assert((CTA_dH == 128) || (CTA_dH == 64), "BMM2's MMA_M needs to be at either 128 or 64 for tcgen05.ld->correction");
// we swap AB so bmm1 is K (CTA_kvL, CTA_dH) x Q (CTA_dH, CTA_qHLocal * CTA_qL)
// both Q and K are dH (K in gemm terminology) major
// M = CTA_kvL, N = CTA_qHLocal * CTA_qL, K = CTA_dH
TiledMMA tiled_bmm1 = cutlass::gemm::collective::detail::sm100_make_1sm_trivial_tiled_mma<
TypeQKV, TypeQKV, TypeAcc, // Mma's A, B, and Accumulator types
Shape<Int<CTA_kvL>, Int<CTA_qHLocal * CTA_qL>, Int<CTA_dH>>, // TileShape_MNK
Shape<_1, _1, _1>, // ClusterShape_MNK
cute::UMMA::Major::K, cute::UMMA::Major::K>();
// we swap AB for bmm2 as well, V (dH, CTA_kvL) x P (CTA_kvL, CTA_qHLocal * CTA_qL)
// V is dH (M in gemm terminology) major, P is CTA_kvL (K in gemm terminology) major in smem after each thread writes P from rmem to smem
// M = CTA_dH, N = CTA_qHLocal * CTA_qL, K = CTA_kvL
TiledMMA tiled_bmm2 = cutlass::gemm::collective::detail::sm100_make_1sm_trivial_tiled_mma<
TypeQKV, TypeQKV, TypeAcc, // Mma's A, B, and Accumulator types
Shape<Int<CTA_dH>, Int<CTA_qHLocal * CTA_qL>, Int<CTA_kvL>>, // TileShape_MNK
Shape<_1, _1, _1>, // ClusterShape_MNK
cute::UMMA::Major::MN, cute::UMMA::Major::K>();
// Pre-partitioned smem Tile Shape to post-partitioned smem tile shape ((Mma_M, Mma_K), NumMma_M, NumMma_K, DMA_Stage)
auto shape_K = make_shape(Int<CTA_kvL>{}, Int<CTA_dH>{}, Int<BMM1_DMA_Stage>{});
auto shape_Q = make_shape(make_shape(Int<CTA_qHLocal>{}, Int<CTA_qL>{}), Int<CTA_dH>{}, Int<1>{});
auto shape_S = make_shape(Int<CTA_kvL>{}, make_shape(Int<CTA_qHLocal>{}, Int<CTA_qL>{}), Int<1>{});
auto mma_shape_K = partition_shape_A(tiled_bmm1, shape_K);
auto mma_shape_Q = partition_shape_B(tiled_bmm1, shape_Q);
auto shape_V = make_shape(Int<CTA_dH>{}, Int<CTA_kvL>{}, Int<BMM2_DMA_Stage>{});
auto shape_P = select<1, 0, 2>(shape_S); // just a permutation of shape_S
auto mma_shape_V = partition_shape_A(tiled_bmm2, shape_V);
// choose the swizzle atom for K, Q, S, V and P
auto SmemLayoutAtomK = cutlass::gemm::collective::detail::sm100_smem_selector<
cute::UMMA::Major::K, // gmem layout of K
TypeQKV, // data type of K
decltype(shape<0>(shape_K)), decltype(shape<1>(shape_K))>(); // tile size of K
auto SmemLayoutAtomQ = cutlass::gemm::collective::detail::sm100_smem_selector<
cute::UMMA::Major::K, // gmem layout of Q
TypeQKV, // data type of Q
decltype(shape<0>(shape_Q)), decltype(shape<1>(shape_Q))>(); // tile size of Q
// for bmm1 tcgen05.ld, each register is holding a row of S (CTA_kvL is mapped to the thread dimension), if we do
// st.shared from rmem to smem, to avoid bank conflict, we need to put T0V0, T1V0, T2V0, ... T31V0 contiguously in smem.
// then the smem layout of S is M (CTA_kvL) major, so we choose MN major swizzle atom
auto SmemLayoutAtomS = UMMA::Layout_MN_SW128_Atom<TypeQKV>{};
auto SmemLayoutAtomV = cutlass::gemm::collective::detail::sm100_smem_selector<
cute::UMMA::Major::MN, // gmem layout of V
TypeQKV, // data type of V
decltype(shape<0>(shape_V)), decltype(shape<1>(shape_V))>(); // tile size of V
// swizzle atom for P should be the transpose of the swizzle atom for S, because they literally represent the same tensor just with different dimension order (aka a pytorch transpose)
auto SmemLayoutAtomP = UMMA::Layout_K_SW128_Atom<TypeQKV>{};
// finally construct the smem layout for tile K, Q, V, S, and P
auto sK_layout = UMMA::tile_to_mma_shape(SmemLayoutAtomK, mma_shape_K); // ((Mma_M, Mma_K), NumMma_M, NumMma_K, BMM1_DMA_Stage)
// paged K smem layout: same memory as sK_layout, viewed as (Page_Size, CTA_dH, NumPagePerCTATile, BMM1_DMA_Stage).
// tile_to_shape uses Step<_1, _3, _2, _4> to stack SmemLayoutAtomK the same way as sK_layout,
// i.e. first along CTA_kvL/(Page_Size, NumPagePerCTATile), then along CTA_dH, finally along BMM1_DMA_Stage
static_assert(CTA_kvL % Page_Size == 0, "Page_Size must divide CTA_kvL");
constexpr int NumPagePerCTATile = CTA_kvL / Page_Size;
auto sK_paged_layout = tile_to_shape(SmemLayoutAtomK,
make_shape(Int<Page_Size>{}, Int<CTA_dH>{}, Int<NumPagePerCTATile>{}, Int<BMM1_DMA_Stage>{}),
Step<_1, _3, _2, _4>{});
auto sQ_layout = UMMA::tile_to_mma_shape(SmemLayoutAtomQ, mma_shape_Q); // ((Mma_N, Mma_K), NumMma_N, NumMma_K, 1)
auto sV_layout = UMMA::tile_to_mma_shape(SmemLayoutAtomV, mma_shape_V); // ((Mma_M, Mma_K), NumMma_M, NumMma_K, BMM2_DMA_Stage)
// paged V smem layout: same memory as sV_layout, viewed as (CTA_dH, Page_Size, NumPagePerCTATile, BMM2_DMA_Stage).
// V is MN-major: atom is (atom_M=CTA_dH, atom_K) -> stacking inside the atom covers all of CTA_dH; multiple atoms tile
// along K (kvL). Splitting kvL into Page_Size + NumPage just renames the K-iter modes; default LayoutLeft step works
// because mode order (M -> K_inner -> K_outer -> Stage) matches the natural sV_layout stride sequence.
auto sV_paged_layout = tile_to_shape(SmemLayoutAtomV,
make_shape(Int<CTA_dH>{}, Int<Page_Size>{}, Int<NumPagePerCTATile>{}, Int<BMM2_DMA_Stage>{}));
// The paged and MMA-partitioned views of K/V must alias the same smem buffer -> cosize must match.
static_assert(cute::cosize_v<decltype(sK_paged_layout)> == cute::cosize_v<decltype(sK_layout)>,
"sK_paged_layout and sK_layout must alias the same smem buffer (cosize must match)");
static_assert(cute::cosize_v<decltype(sV_paged_layout)> == cute::cosize_v<decltype(sV_layout)>,
"sV_paged_layout and sV_layout must alias the same smem buffer (cosize must match)");
// S and P use tile_to_shape as we do the mma partition in the kernel later
auto sS_layout = tile_to_shape(SmemLayoutAtomS, shape_S, Step<_1, _2, _3>{}); // (CTA_kvL, (CTA_qHLocal, CTA_qL), 1)
auto sP_layout = tile_to_shape(SmemLayoutAtomP, shape_P, Step<_2, _1, _3>{}); // ((CTA_qHLocal, CTA_qL), CTA_kvL, 1)
auto sAcc1_layout = make_layout(shape_S); // (CTA_kvL, (CTA_qHLocal, CTA_qL), 1)
// for storing fmax and fsum warp reduce partial results
int constexpr NumEpilogWarps = 4;
// NumEpilogWarps contiguous because we often ld.shared all NumEpilogWarps from 1/32 threads, this has best vectorization
auto sWarpReduce_layout = make_layout(make_shape(Int<NumEpilogWarps>{}, make_shape(Int<CTA_qHLocal>{}, Int<CTA_qL>{}))); // (NumEpilogWarps, (CTA_qHLocal, CTA_qL))
// MaxSplits contiguous because we often ld.shared all MaxSplits from 1/32 threads, this has best vectorization
auto sMSMailbox_Layout = make_layout(make_shape(Int<MaxSplits>{}, Int<CTA_qHLocal * CTA_qL / NumReductionCTA>{})); // (MaxSplits, CTA_qHLocal * CTA_qL / NumReductionCTA)
// default layout is CTA_dH contiguous to maximize st.async/ld.shared bw
auto sAcc2Mailbox_layout = make_layout(make_shape(Int<CTA_dH>{}, Int<CTA_qHLocal * CTA_qL / NumReductionCTA>{}, Int<MaxSplits>{})); // (CTA_dH, CTA_qHLocal * CTA_qL / NumReductionCTA, MaxSplits)
auto sSinks_layout = make_layout(Int<CTA_qHLocal * CTA_qL / NumReductionCTA>{}); // (CTA_qHLocal * CTA_qL / NumReductionCTA)
// DMA_KV's page-idx staging buffer. Page_Idx_Stage and Num_Page_Idx_Per_Stage are configured by the host
// harness (template params). Single static_assert block lives here so warp functions don't repeat them.
static_assert(Num_Page_Idx_Per_Stage % NumPagePerCTATile == 0,
"Num_Page_Idx_Per_Stage must be a multiple of CTA_kvL/Page_Size so a DMA stage's pages live in one pi stage");
// Page_Idx_Stage must be exactly 2: K leads V by 1 tile, so around a pi-stage boundary K reads slot S%2
// (stage S) while V finishes the last tile of slot (S-1)%2 (stage S-1). =1 would alias these slots; >2
// wastes smem because at most 2 pi-stage groups are live at once (the one being consumed + the pre-issued).
static_assert(Page_Idx_Stage == 2, "Page_Idx_Stage must be 2 for the K-leads-V-by-1 page-idx pipeline");
constexpr int Tiles_Per_Pi_Stage_Host = Num_Page_Idx_Per_Stage / NumPagePerCTATile;
// DMA_KV's folded page-idx pipeline issues stage S's cp.async at K_t_in_stage==1 of stage S-1, so each
// pi-stage must hold at least 2 CTA tiles. Choose Num_Page_Idx_Per_Stage >= 2 * NumPagePerCTATile.
static_assert(Tiles_Per_Pi_Stage_Host >= 2,
"Tiles_Per_Pi_Stage (= Num_Page_Idx_Per_Stage / NumPagePerCTATile) must be >= 2 for the folded page-idx pipeline");
// Hierarchical layout ((NumPagePerCTATile, Tiles_Per_Pi_Stage), Page_Idx_Stage), default LayoutLeft so
// mode 0 is fully contiguous (NumPagePerCTATile innermost). Letting cute carry the (p, t) split keeps the
// producer's inner write as sPageIdx(make_coord(p, t), stage_idx) without manual i/N + i%N arithmetic, and
// gives the consumer a contiguous NumPagePerCTATile-int slice via &sPageIdx(make_coord(0, t), stage_idx).
auto sPageIdx_layout = make_layout(make_shape(make_shape(Int<NumPagePerCTATile>{}, Int<Tiles_Per_Pi_Stage_Host>{}), Int<Page_Idx_Stage>{}));
// Now we can find the SMEM allocation size
using SMEMStorage = SharedStorage<TypeQKV, TypeAcc,
decltype(sK_layout), decltype(sK_paged_layout),
decltype(sQ_layout),
decltype(sV_layout), decltype(sV_paged_layout),
decltype(sS_layout), decltype(sP_layout),
decltype(sWarpReduce_layout), decltype(sMSMailbox_Layout), decltype(sAcc1_layout),
decltype(sAcc2Mailbox_layout), decltype(sSinks_layout),
decltype(sPageIdx_layout),
BMM1_DMA_Stage, BMM2_DMA_Stage, Page_Idx_Stage>;
static_assert(BMM1_DMA_Stage >= BMM2_DMA_Stage, "otherwise you are wasting BMM2 stage because BMM1 TMA issue will block BMM2 TMA due to insufficient BMM1 stages");
// create TMA descriptors for K, Q, V matrices
// K TMA box is (Page_Size, CTA_dH) -- one page per TMA copy. The per-page SMEM destination layout points into the
// same memory as a single page slot inside sK_layout/sK_paged_layout.
Copy_Atom tma_atom_K = make_tma_atom(
SM90_TMA_LOAD{}, // TMA Load Op, sm100 reuses sm90 tma atom
mK, // Source GMEM tensor
take<0, 2>(sK_paged_layout), // Destination SMEM layout for 1 page = 1 TMA box, (Page_Size, CTA_dH)
make_shape(Int<Page_Size>{}, Int<CTA_dH>{}) // TMA box shape
);
Tensor mK_tma = tma_atom_K.get_tma_tensor(shape(mK)); // (Page_Size, dH, num_pages_total, kvH)
Copy_Atom tma_atom_Q = make_tma_atom(
SM90_TMA_LOAD{}, // TMA Load Op, sm100 reuses sm90 tma atom
mQ, // Source GMEM tensor
// sQ_layout(_,_,_,Int<0>{}) doesn't work under some corner cases (composedlayout indexing), so we use
// the take method which is also correct.
take<0, 3>(sQ_layout), // Destination SMEM layout for 1 DMA_Stage, ((Mma_N, Mma_K), NumMma_N, NumMma_K)
make_shape(get<0>(shape_Q), get<1>(shape_Q)) // TMA box shape
);
Tensor mQ_tma = tma_atom_Q.get_tma_tensor(shape(mQ)); // ((qHLocal, qL), dH, kvH, BS)
// V TMA box is (CTA_dH, Page_Size) -- one page per TMA copy. Per-page SMEM destination layout points into the
// same memory as a single page slot inside sV_layout/sV_paged_layout.
Copy_Atom tma_atom_V = make_tma_atom(
SM90_TMA_LOAD{}, // TMA Load Op, sm100 reuses sm90 tma atom
mV, // Source GMEM tensor
take<0, 2>(sV_paged_layout), // Destination SMEM layout for 1 page = 1 TMA box, (CTA_dH, Page_Size)
make_shape(Int<CTA_dH>{}, Int<Page_Size>{}) // TMA box shape
);
Tensor mV_tma = tma_atom_V.get_tma_tensor(shape(mV)); // (dH, Page_Size, num_pages_total, kvH)
int smemBytes = sizeof(SMEMStorage);
// invoke the kernel
cudaLaunchConfig_t config;
cudaLaunchAttribute attrs[2];
// bid.x: kvH * BS, bid.y: qHLocal * qL, bid.z: kvL
uint32_t Cluster_Size = cute::max(MaxSplits, NumReductionCTA);
config.gridDim = dim3{
(uint32_t)kvH * BS,
(uint32_t)cutlass::ceil_div(qHLocal, CTA_qHLocal) * cutlass::ceil_div(qL, CTA_qL),
Cluster_Size};
config.blockDim = 256; // 8 warps
config.dynamicSmemBytes = smemBytes;
config.stream = stream;
attrs[0].id = cudaLaunchAttributeClusterDimension;
attrs[0].val.clusterDim = {1, 1, Cluster_Size};
attrs[1].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[1].val.programmaticStreamSerializationAllowed = 1;
config.attrs = attrs;
config.numAttrs = pdl ? 2 : 1;
if (device_ptr_sinks != nullptr) {
auto *kernel_instance =
&gqa_paged_device<SMEMStorage,
decltype(mK_tma), decltype(mQ_tma), decltype(mV_tma), decltype(mO), decltype(mSink),
decltype(mSeqLens), decltype(mPageTable),
decltype(tma_atom_K), decltype(tma_atom_Q), decltype(tma_atom_V),
decltype(tiled_bmm1), decltype(tiled_bmm2),
TypeAcc,
CTA_qHLocal, CTA_qL, CTA_kvL, CTA_dH,
Page_Size,
BMM1_DMA_Stage, BMM2_DMA_Stage,
Page_Idx_Stage, Num_Page_Idx_Per_Stage,
MaxSplits, NumReductionCTA,
false>;
gpuErrChk(cudaFuncSetAttribute(*kernel_instance, cudaFuncAttributeMaxDynamicSharedMemorySize, smemBytes));
// portable max cluster size is 8, but sm100a supports 16, need explicit opt in
gpuErrChk(cudaFuncSetAttribute(*kernel_instance, cudaFuncAttributeNonPortableClusterSizeAllowed, 1));
gpuErrChk(cudaLaunchKernelEx(&config, kernel_instance, mK_tma, mQ_tma, mV_tma, mO, mSink,
mSeqLens, mPageTable,
tma_atom_K, tma_atom_Q, tma_atom_V,
tiled_bmm1, tiled_bmm2,
softmax_scale * Log2_E, sliding_window_size, pdl_count));
}
else {
auto *kernel_instance =
&gqa_paged_device<SMEMStorage,
decltype(mK_tma), decltype(mQ_tma), decltype(mV_tma), decltype(mO), decltype(mSink),
decltype(mSeqLens), decltype(mPageTable),
decltype(tma_atom_K), decltype(tma_atom_Q), decltype(tma_atom_V),
decltype(tiled_bmm1), decltype(tiled_bmm2),
TypeAcc,
CTA_qHLocal, CTA_qL, CTA_kvL, CTA_dH,
Page_Size,
BMM1_DMA_Stage, BMM2_DMA_Stage,
Page_Idx_Stage, Num_Page_Idx_Per_Stage,
MaxSplits, NumReductionCTA,
true>;
gpuErrChk(cudaFuncSetAttribute(*kernel_instance, cudaFuncAttributeMaxDynamicSharedMemorySize, smemBytes));
// portable max cluster size is 8, but sm100a supports 16, need explicit opt in
gpuErrChk(cudaFuncSetAttribute(*kernel_instance, cudaFuncAttributeNonPortableClusterSizeAllowed, 1));
gpuErrChk(cudaLaunchKernelEx(&config, kernel_instance, mK_tma, mQ_tma, mV_tma, mO, mSink,
mSeqLens, mPageTable,
tma_atom_K, tma_atom_Q, tma_atom_V,
tiled_bmm1, tiled_bmm2,
softmax_scale * Log2_E, sliding_window_size, pdl_count));
}
}
} // namespace gqa_paged
} // namespace TGV

View File

@@ -1207,6 +1207,10 @@ struct SM100_MMA_S8_2x1SM_SS_SPARSE
}
};
template <class a_type, class b_type, class c_type, int M, int N,
UMMA::Major a_major, UMMA::Major b_major,
UMMA::ScaleIn a_neg = UMMA::ScaleIn::One,
UMMA::ScaleIn b_neg = UMMA::ScaleIn::One>
struct SM100_MMA_F8F6F4_SS
{
using DRegisters = void;
@@ -1452,6 +1456,10 @@ struct SM100_MMA_MXF8F6F4_SS_SPARSE
}
};
template <class a_type, class b_type, class c_type, int M, int N,
UMMA::Major a_major, UMMA::Major b_major,
UMMA::ScaleIn a_neg = UMMA::ScaleIn::One,
UMMA::ScaleIn b_neg = UMMA::ScaleIn::One>
struct SM100_MMA_F8F6F4_2x1SM_SS
{
using DRegisters = void;

View File

@@ -3327,12 +3327,9 @@ struct MMA_Traits<SM100_MMA_S8_2x1SM_SS_SPARSE<a_type, b_type, c_type,
template <class a_type, class b_type, class c_type,
int M, int N, UMMA::Major a_major, UMMA::Major b_major,
UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg>
struct MMA_Traits<SM100_MMA_F8F6F4_SS, a_type, b_type, c_type,
cute::C<M>, cute::C<N>,
cute::integral_constant<UMMA::Major, a_major>,
cute::integral_constant<UMMA::Major, b_major>,
cute::integral_constant<UMMA::ScaleIn, a_neg>,
cute::integral_constant<UMMA::ScaleIn, b_neg>>
struct MMA_Traits<SM100_MMA_F8F6F4_SS<a_type, b_type, c_type,
M, N, a_major, b_major,
a_neg, b_neg>>
{
using ValTypeD = c_type;
using ValTypeA = a_type;
@@ -3390,7 +3387,9 @@ struct MMA_Traits<SM100_MMA_F8F6F4_SS, a_type, b_type, c_type,
uint32_t tmem_c = raw_pointer_cast(D.data());
uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_);
SM100_MMA_F8F6F4_SS::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc);
SM100_MMA_F8F6F4_SS<a_type, b_type, c_type,
M, N, a_major, b_major,
a_neg, b_neg>::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc);
}
};
@@ -3745,12 +3744,9 @@ struct MMA_Traits<SM100_MMA_F8F6F4_SS_SPARSE<a_type, b_type, c_type,
template <class a_type, class b_type, class c_type,
int M, int N, UMMA::Major a_major, UMMA::Major b_major,
UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg>
struct MMA_Traits<SM100_MMA_F8F6F4_2x1SM_SS, a_type, b_type, c_type,
cute::C<M>, cute::C<N>,
cute::integral_constant<UMMA::Major, a_major>,
cute::integral_constant<UMMA::Major, b_major>,
cute::integral_constant<UMMA::ScaleIn, a_neg>,
cute::integral_constant<UMMA::ScaleIn, b_neg>>
struct MMA_Traits<SM100_MMA_F8F6F4_2x1SM_SS<a_type, b_type, c_type,
M, N, a_major, b_major,
a_neg, b_neg>>
{
using ValTypeD = c_type;
using ValTypeA = a_type;
@@ -3808,7 +3804,9 @@ struct MMA_Traits<SM100_MMA_F8F6F4_2x1SM_SS, a_type, b_type, c_type,
uint32_t tmem_c = raw_pointer_cast(D.data());
uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_);
SM100_MMA_F8F6F4_2x1SM_SS::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc);
SM100_MMA_F8F6F4_2x1SM_SS<a_type, b_type, c_type,
M, N, a_major, b_major,
a_neg, b_neg>::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc);
}
};

View File

@@ -57,6 +57,14 @@
# endif // (__CUDA_ARCH__ >= 900)
#endif // defined(__CUDA_ARCH__)
#if defined(__CUDA_ARCH__)
# if (__CUDA_ARCH__ >= 1000)
# if (__CUDACC_VER_MAJOR__ > 13) || ((__CUDACC_VER_MAJOR__ >= 13) && (__CUDACC_VER_MINOR__ >= 2))
# define CUDA_PTX_FP8_BF16_CVT_ENABLED 1
# endif // (__CUDACC_VER_MAJOR__ > 13) || ((__CUDACC_VER_MAJOR__ >= 13) && (__CUDACC_VER_MINOR__ >= 2))
# endif // (__CUDA_ARCH__ >= 1000)
#endif // defined(__CUDA_ARCH__)
#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) ||\
defined(CUTLASS_ARCH_MMA_SM103A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM110A_ENABLED) ||\

View File

@@ -339,18 +339,16 @@ sm100_make_1sm_trivial_tiled_mma() {
) {
return make_tiled_mma(
cute::MMA_Traits<
cute::SM100_MMA_F8F6F4_SS,
cute::SM100_MMA_F8F6F4_SS<
ElementAMma,
ElementBMma,
ElementAMmaccumulator,
cute::C<M>,
cute::C<N>,
cute::integral_constant<UMMA::Major, UmmaMajorA>,
cute::integral_constant<UMMA::Major, UmmaMajorB>,
cute::integral_constant<UMMA::ScaleIn, ANeg>,
cute::integral_constant<UMMA::ScaleIn, BNeg>
>{}
M,
N,
UmmaMajorA,
UmmaMajorB,
ANeg,
BNeg>{}
);
}
else {
@@ -407,18 +405,16 @@ sm100_make_2sm_trivial_tiled_mma() {
) {
return make_tiled_mma(
cute::MMA_Traits<
cute::SM100_MMA_F8F6F4_2x1SM_SS,
cute::SM100_MMA_F8F6F4_2x1SM_SS<
ElementAMma,
ElementBMma,
ElementAMmaccumulator,
cute::C<M>,
cute::C<N>,
cute::integral_constant<UMMA::Major, UmmaMajorA>,
cute::integral_constant<UMMA::Major, UmmaMajorB>,
cute::integral_constant<UMMA::ScaleIn, ANeg>,
cute::integral_constant<UMMA::ScaleIn, BNeg>
>{}
M,
N,
UmmaMajorA,
UmmaMajorB,
ANeg,
BNeg>{}
);
}
@@ -739,17 +735,16 @@ sm100_make_trivial_mixed_input_tiled_mma() {
}
if constexpr (cute::is_same_v<ElementBMma, cutlass::float_e4m3_t>) {
return make_tiled_mma(
cute::MMA_Traits<
cute::SM100_MMA_F8F6F4_SS,
cute::SM100_MMA_F8F6F4_SS<
ElementAMma,
ElementBMma,
ElementAccumulator,
cute::C<M>,
cute::C<N>,
cute::integral_constant<UMMA::Major, UmmaMajorA>,
cute::integral_constant<UMMA::Major, UmmaMajorB>,
cute::integral_constant<UMMA::ScaleIn, cute::UMMA::ScaleIn::One>,
cute::integral_constant<UMMA::ScaleIn, cute::UMMA::ScaleIn::One>>{});
M,
N,
UmmaMajorA,
UmmaMajorB,
cute::UMMA::ScaleIn::One,
cute::UMMA::ScaleIn::One>{});
}
}
}

View File

@@ -1820,14 +1820,21 @@ struct NumericArrayConverter<cutlass::bfloat16_t, cutlass::float_e4m3_t, 2, Roun
#if defined(CUDA_PTX_FP8_CVT_ENABLED)
uint32_t res_half;
uint16_t const& src_packed = reinterpret_cast<uint16_t const&>(source);
asm volatile( \
"{\n" \
"cvt.rn.f16x2.e4m3x2 %0, %1;\n" \
"}\n" : "=r"(res_half): "h"(src_packed));
float2 res_float = __half22float2(reinterpret_cast<__half2 &>(res_half));
NumericArrayConverter<cutlass::bfloat16_t, float, 2, Round> converter;
return converter(reinterpret_cast<Array<float, 2> const&>(res_float));
#if defined(CUDA_PTX_FP8_BF16_CVT_ENABLED)
asm volatile( \
"{\n" \
"cvt.rn.bf16x2.e4m3x2 %0, %1;\n" \
"}\n" : "=r"(res_half): "h"(src_packed));
return reinterpret_cast<result_type const &>(res_half);
#else
asm volatile( \
"{\n" \
"cvt.rn.f16x2.e4m3x2 %0, %1;\n" \
"}\n" : "=r"(res_half): "h"(src_packed));
float2 res_float = __half22float2(reinterpret_cast<__half2 &>(res_half));
NumericArrayConverter<cutlass::bfloat16_t, float, 2, Round> converter;
return converter(reinterpret_cast<Array<float, 2> const&>(res_float));
#endif
#else
result_type result;
NumericConverter<result_element, source_element, Round> converter;
@@ -2961,19 +2968,31 @@ struct NumericArrayConverterPacked4Element<cutlass::bfloat16_t, cutlass::float_e
static result_type convert(source_type const & source) {
#if defined(CUDA_PTX_FP8_CVT_ENABLED)
// Convert f8 to float
NumericArrayConverterPacked4Element<float, source_element, Round> src2float;
Array<float, 4> tmp_floats = src2float(source);
// Convert float to bf16
result_type out;
Array<float, 2>* packed_tmp = reinterpret_cast<Array<float, 2>*>(&tmp_floats);
Array<result_element, 2>* packed_out = reinterpret_cast<Array<result_element, 2>*>(&out);
NumericArrayConverter<result_element, float, 2, Round> float2result;
packed_out[0] = float2result(packed_tmp[0]);
packed_out[1] = float2result(packed_tmp[1]);
#if defined(CUDA_PTX_FP8_BF16_CVT_ENABLED)
uint32_t const& src_packed = reinterpret_cast<uint32_t const&>(source);
Array<uint32_t, 2>& out_packed = reinterpret_cast<Array<uint32_t, 2>&>(out);
asm volatile("{\n"
".reg .b16 b0, b1;\n"
"mov.b32 {b0, b1}, %2;\n"
"cvt.rn.bf16x2.e4m3x2 %0, b0;\n"
"cvt.rn.bf16x2.e4m3x2 %1, b1;\n"
"}\n"
: "=r"(out_packed[0]), "=r"(out_packed[1])
: "r"(src_packed));
#else
// Convert f8 to float
NumericArrayConverterPacked4Element<float, source_element, Round> src2float;
Array<float, 4> tmp_floats = src2float(source);
return out;
// Convert float to bf16
Array<float, 2>* packed_tmp = reinterpret_cast<Array<float, 2>*>(&tmp_floats);
Array<result_element, 2>* packed_out = reinterpret_cast<Array<result_element, 2>*>(&out);
NumericArrayConverter<result_element, float, 2, Round> float2result;
packed_out[0] = float2result(packed_tmp[0]);
packed_out[1] = float2result(packed_tmp[1]);
#endif
return out;
#else
result_type result;
NumericConverter<result_element, source_element, Round> converter;

View File

@@ -36,7 +36,7 @@
#define CUTLASS_MAJOR 4
#define CUTLASS_MINOR 5
#define CUTLASS_PATCH 0
#define CUTLASS_PATCH 1
#ifdef CUTLASS_VERSIONS_GENERATED
#include "cutlass/version_extended.h"

View File

@@ -23,3 +23,5 @@ CuTe DSL
Compile with TVM FFI <cute_dsl_general/compile_with_tvm_ffi.rst>
Ahead-of-Time (AOT) Compilation <cute_dsl_general/dsl_ahead_of_time_compilation.rst>
Talks and Presentations <cute_dsl_general/resources.rst>
Naming Conventions <cute_dsl_general/naming_conventions.rst>
MMA Programming Guides <mma_docs/intro.rst>

View File

@@ -83,6 +83,9 @@ an elementwise lambda function can be passed in as the ``epilogue_op`` argument.
Refer to the `Blackwell dense GEMM example <https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py>`__ for a complete example.
.. note::
For the per-thread/partition naming convention used above (``tTR_rAcc``, ``tTR_rC``, and related tokens such as ``tAgA``, ``bSG_sC``, ``tQgQ_qdl``, …), see the :ref:`cute_dsl_naming_conventions`.
Type safety
-----------

View File

@@ -0,0 +1,253 @@
.. _cute_dsl_naming_conventions:
CuTe DSL Naming Conventions
===========================
This page summarizes the Hungarian-style naming conventions used for identifiers across the DSL examples and epilogue helpers: tensor partitions, per-thread copy-partitioners, copy atoms, and the axis-order suffixes that encode tensor layouts. It is meant as a lookup reference while reading example code — not as a style rule enforced on new code.
Memory/space scopes
-------------------
- ``g``: Global memory view (GMEM), e.g., ``gB_nkl``, ``tTR_gC``
- ``s``: Shared memory view (SMEM), e.g., ``sA``, ``tRS_sC``, ``bSG_sC``
- ``r``: Register view (RMEM), e.g., ``tTR_rAcc``, ``tRS_rC``
- ``t``: Tensor-memory view (TMEM), used for any TMEM-resident fragment or layout regardless of role. The classical case is the accumulator (``tCtAcc``, ``tTR_tAcc``). The same scope letter also appears for non-accumulator TMEM tensors such as ``tCtE``, ``tCtState``, ``tCtQState``, ``tCtShared``. Read the operand suffix to distinguish the role from the memory scope.
Per-thread/partitioned views and families
-----------------------------------------
- ``tA…`` / ``tB…``: TMA load path for A/B
- ``tAgA`` / ``tAsA``: per-thread partitioned global/shared A for TMA load
- ``tBgB`` / ``tBsB``: per-thread partitioned global/shared B for TMA load
- NVFP4/FP8 scale factors mirror this: ``tAgSFA`` / ``tAsSFA``, ``tBgSFB`` / ``tBsSFB``
- ``tC…``: Compute/epilogue path for C/Acc
- ``tCgA`` / ``tCgB`` / ``tCgC``: per-thread partitions used by MMA/epilogue (derived from global tensors)
- ``tCrA`` / ``tCrB``: per-thread fragments used by MMA (derived from SMEM A/B)
- ``tCtAcc``: per-thread accumulator fragment/layout in TMEM
- Additional ``tC*`` tensors follow the same schema for kernels that carry more than the classical A/B/C/Acc operands (see Operands and roles below): e.g. ``tCtState`` / ``tCtQState`` / ``tCtShared`` (gated-delta-net recurrent state in TMEM), ``tCrValpha`` / ``tCrVbeta`` / ``tCrVbias`` (EVT/EFC broadcast vectors in registers), ``tCtAccInter`` / ``tCtAccIntra`` (hierarchical accumulators)
.. {$nv-internal-release begin}
- Sparse GEMM additionally defines ``tCtE`` for the sparsity metadata tensor in TMEM (sm_140 / Feynman sparse GEMM, not yet released)
.. {$nv-internal-release end}
- ``tTM…``: Per-thread TMEM tiled-copy partitions used by FMHA/attention kernels (e.g. ``tTMrO`` as the register-side view of a TMEM load partitioned through ``thr_tmem_load``)
- Attention/MLA path families (``tQ…``, ``tK…``, ``tV…``, ``tP…``, ``tO…``): same schema as ``tA…`` / ``tB…`` / ``tC…`` but specialised to the Q/K/V/P/O operands of attention kernels, e.g.:
- ``tQsQ`` / ``tQgQ_qdl``: per-thread SMEM / GMEM partitions of Q for TMA load
- ``tKrK`` / ``tVrV``: per-thread register fragments for K / V
- ``tOtO`` / ``tOrO``: per-thread TMEM / register views of the attention output accumulator O
- ``tPrP``: per-thread register fragment for the softmax probability matrix P
Data-movement copy paths
------------------------
- ``tTR_*``: TMEM → Register (T2R)
- ``tTR_tAcc``: TMEM accumulator source for T2R
- ``tTR_rAcc``: Register destination for T2R
- ``tTR_gC``: When not using TMA store, Register → Global C destination partition
- ``tRS_*``: Register → Shared (R2S)
- ``tRS_rC``: Register source (C dtype)
- ``tRS_sC``: Shared destination
- ``bSG_*``: Thread(b)lock partition for Shared → Global via TMA store
- ``bSG_sC``: Shared source for TMA store
- ``bSG_gC``: Global destination for TMA store
- Also used for accumulator in some flows: ``bSG_sAcc``, ``bSG_gAcc``
- The same schema extends to additional store operands: ``bSG_sD`` / ``bSG_gD``, ``bSG_sP`` / ``bSG_gP``, ``bSG_sY`` / ``bSG_gY``
- ``bGS_*``: Thread(b)lock partition for Global → Shared via TMA **load** (the load-path mirror of ``bSG_*``)
- ``bGS_gC`` / ``bGS_sC``: Global source / Shared destination for TMA load of C-like operands (seen in EFC row/column broadcast prologues)
- ``simt_atom``: SIMT copy path used when TMA store is disabled (Register → Global)
- Generic SIMT / tiled copy atoms ``<src>2<dst>_atom[_suffix]`` name the copy direction between two memory scopes:
- ``s2r_atom_*``: Shared → Register atom used in specialised epilogues and attention loads (e.g. ``s2r_atom_delta``, ``s2r_atom_cumsum``, ``s2r_atom_d`` in Mamba2 SSD)
- ``r2s_atom``: Register → Shared atom
- ``t2r_atom`` / ``r2t_atom``: Tensor memory ↔ Register atoms (paired with ``thr_tmem_load`` / ``thr_tmem_store``)
- ``s2s_atom``: Shared → Shared atom (reshape/remap without register spill)
- ``s2t``: Shared → Tensor memory atom
.. {$nv-internal-release begin}
- ``sp2t_copy_op_*``: Sparse source → Tensor memory copy op (sm_140 / Feynman sparse GEMM, not yet released: e.g. ``Sp2TAsACopyOp``, ``Sp2TAsECopyOp``)
.. {$nv-internal-release end}
- Custom ``autovec_copy`` paths appear where the DSL auto-vectorises a bespoke layout
Operands and roles
------------------
- ``A``, ``B``, ``C``: GEMM operands
- ``Acc``: Accumulator (TMEM/Register paths). Hierarchical MMA kernels split this into ``AccInter`` / ``AccIntra`` for the inter-/intra-CTA accumulator halves
- Classical extra outputs / intermediates: ``D`` (additional output), ``Y`` (fused output), ``SFA`` / ``SFB`` (per-operand scale-factor arrays for NVFP4/FP8), ``SF`` (generic scale factor)
- Attention / MLA operand letters (Q/K/V/P/O schema):
- ``Q`` (query), ``K`` (key), ``V`` (value), ``P`` (softmax probability / score matrix), ``O`` (attention output)
- Variants: ``Kt`` / ``Vt`` for the transposed view of K/V, ``Qi`` / ``Ki`` / ``Vi`` for per-iteration slices, ``QK`` / ``PV`` / ``QKV`` where a single fragment spans multiple operands of the two back-to-back matmuls
- Mamba / recurrent-state letters: ``Delta`` / ``DeltaA`` (time-step and A-decay), ``State`` / ``QState`` / ``Shared`` (gated-delta-net recurrent state tensors), ``Cumsumlog`` / ``Cumprod`` (running reductions), ``Gate``, ``DecayV``
.. {$nv-internal-release begin}
- Sparse-GEMM letters (sm_140 / Feynman, not yet released): ``E`` (sparsity metadata tensor in TMEM; paired with ``sp2t_*`` copy ops)
.. {$nv-internal-release end}
- EVT / EFC broadcast vectors: ``Valpha`` / ``Vbeta`` (alpha/beta scalars broadcast as vectors), ``Vbias`` (bias vector), ``Ainv`` (inverse of A for fused solvers)
.. {$nv-internal-release begin}
- LUT-based block-scaled GEMM letter (Rubin, not yet released): ``LutB`` (look-up-table operand)
.. {$nv-internal-release end}
- Communication operands (multi-CTA / multicast flows): ``CommInMC`` / ``CommOutMC`` (multicast in/out), ``CommOutUC`` (unicast out)
- Head-dimension variants: ``Dv`` (value head dimension when distinct from Q/K dim), ``Nv`` (number of value heads)
Axis-order suffixes
-------------------
- Suffix encodes axis order of the view (lowercase letters each stand for one tensor mode):
- GEMM layouts use ``m``/``n``/``k``/``l``:
- ``_mnl``, ``_nkl``, ``_mkl``, … map to (M, N, K, L) ordering
- Example: ``gB_nkl`` is B with axes (N, K, L); ``gC_mnl`` is C with (M, N, L)
- Attention / FMHA layouts use ``q``/``k``/``d``/``l`` (sequence-Q, sequence-K, head-dim, batch):
- ``mQ_qdl``: Q tensor with axes (SeqQ, HeadDim, Batch)
- ``mK_kdl``: K tensor with axes (SeqK, HeadDim, Batch)
- ``mV_dkl``: V tensor with axes (HeadDim, SeqK, Batch) — the ``d``-first order reflects the V-transpose that makes the second matmul (P·V) a standard row-major ``MxK·KxN``
- Lower-rank 2D slices drop the batch letter: ``_mn``, ``_mk``, ``_nk``
- Internally, CuTe layouts also expose grouped modes like ``MMA_M/N/K``, ``EPI_M/N``, ``RestM/N/K/L``, ``STAGE``, etc. (these are typically implementation details not directly used in example code).
Reading compound tokens
-----------------------
- From left to right: ``[t|b][A|B|C|Q|K|V|P|O|TR|RS|SG|GS|TM]_[g|s|r|t][Operand/Role][AxisSuffix?]``
- ``t`` = per-thread/partitioned view; ``b`` = block/threadblock partition context
- family/path letters:
- Operand-based: ``A`` / ``B`` / ``C`` (GEMM), ``Q`` / ``K`` / ``V`` / ``P`` / ``O`` (attention)
- Direction-based: ``TR`` (TMEM → Register), ``RS`` (Register → Shared), ``SG`` (Shared → Global, store), ``GS`` (Global → Shared, load), ``TM`` (TMEM tiled-copy partition), ``R2G`` / ``S2R`` / ``T2R`` / ``R2T`` convenience aliases
- memory = ``g``/``s``/``r``/``t``
- operand/role = ``A``/``B``/``C``/``Acc``/``SFA``/``SFB``/``Q``/``K``/``V``/``P``/``O``/``E``/``State``/…
- axis suffix = ``_mnl``, ``_nkl``, ``_qdl``, ``_kdl``, ``_dkl``, ``_mn``, … when applicable
- Per-thread-partitioner objects follow a parallel ``thr_*`` vocabulary, grouped by role:
- MMA partitioner: ``thr_mma``
- Tiled-copy direction variants ``thr_copy_<src>2<dst>``: ``thr_copy_g2s``, ``thr_copy_s2r``, ``thr_copy_t2r``, ``thr_copy_r2s``, ``thr_copy_r2t``, ``thr_copy_s2t``
- Role-qualified copy variants: ``thr_copy_sfa``, ``thr_copy_sfb``, ``thr_copy_load``, ``thr_copy_beta_g2s``
- MMA variants for multi-matmul kernels: ``thr_mma_qk``, ``thr_mma_pv``, ``thr_mma_kv``, ``thr_mma_qkv``, ``thr_mma_intra1`` / ``thr_mma_intra2``, ``thr_mma_leader_cta``, ``thr_mma_sfb``
- TMEM access partitioners: ``thr_tmem_load``, ``thr_tmem_store`` (with ``_stats`` / ``_vec`` suffix variants)
The tensor produced by ``thr_foo.partition_S(X)`` or ``.partition_D(X)`` is then named by the ``[t|b]FamilyPrefix_*`` convention above.
Concrete references
-------------------
Open these files in the repository to see each pattern in context:
- TMA load partitions for A/B:
- ``tAgA``, ``tAsA``, ``tBgB``, ``tBsB``
- ``CuTeDSL/cute/blackwell/kernel/dense_gemm/dense_gemm.py`` (around TMA partition of A/B)
- Accumulator fragment in TMEM:
- ``tCtAcc``
- ``CuTeDSL/cute/blackwell/kernel/dense_gemm/dense_gemm.py`` (accumulator creation and use)
- TMEM → Register (T2R):
- ``tTR_tAcc``, ``tTR_rAcc``, ``tTR_gC``
- ``CuTeDSL/cute/blackwell/kernel/dense_gemm/dense_gemm.py`` (``epilog_tmem_copy_and_partition``)
- Register → Shared (R2S):
- ``tRS_rC``, ``tRS_sC``
- ``CuTeDSL/cute/blackwell/kernel/mixed_input_gemm/mixed_input_gemm.py`` (``epilog_smem_copy_and_partition``)
- Shared → Global via TMA store:
- ``bSG_sC``, ``bSG_gC``
- ``CuTeDSL/cute/blackwell/kernel/blockscaled_gemm/dense_blockscaled_gemm_persistent.py`` (``epilog_gmem_copy_and_partition``)
- NVFP4/FP8 scale factors:
- ``tAgSFA``/``tAsSFA``, ``tBgSFB``/``tBsSFB``
- ``CuTeDSL/cute/blackwell/tutorial/tutorial_gemm/nvfp4_gemm_0.py`` (scale factor partition and usage)
- Additional examples across ``examples/``:
- Register → Global helper naming in MLA: ``tR2G_rO_src``, ``tR2G_rO_dst``
- ``CuTeDSL/cute/blackwell/kernel/attention/mla/mla_decode_fp16.py`` (output store section)
- Shared → Register SIMT atoms in Mamba2 SSD: ``s2r_atom_delta``, ``s2r_atom_cumsum``, ``s2r_atom_d``
- ``CuTeDSL/cute/blackwell/kernel/attention/mamba2_ssd/mamba2_ssd.py`` (SMEM load paths for delta and D)
- ``thr_*`` slices for partitioning per-thread work: ``thr_mma``, ``thr_copy_t2r``, ``thr_copy_r2s``, etc.
- ``CuTeDSL/cute/blackwell/kernel/dense_gemm/dense_gemm.py`` (``thr_mma``, ``thr_copy_t2r``, ``thr_copy_r2s``)
- Axis-order suffix examples:
- ``gB_nkl``, ``gC_mnl``
- ``CuTeDSL/cute/blackwell/kernel/dense_gemm/dense_gemm.py`` (global tensor tiling and partitioning)
- Global → Shared (TMA load) block partition ``bGS_*``:
- ``bGS_gC``, ``bGS_sC``
- ``CuTeDSL/cute/blackwell/efc/common_efc.py`` (row/column broadcast prologue building the C-like input for EVT)
- Attention Q/K/V/P/O families and ``_qdl`` / ``_kdl`` / ``_dkl`` axis suffixes:
- ``tQsQ``, ``tQgQ_qdl``, ``mK_kdl``, ``mV_dkl``
- ``CuTeDSL/cute/hopper/kernel/attention/fmha.py`` (Q/K/V TMA partitions)
- ``tOtO``, ``tOrO``, ``tPrP``
- ``CuTeDSL/cute/blackwell/tutorial/tutorial_fmha/fmha_0.py`` (output and softmax fragments)
- ``tKrK``, ``tVrV``
- ``CuTeDSL/cute/blackwell/kernel/attention/mixed_input_fmha/mixed_input_fmha_decode.py`` (mixed-input K/V register fragments)
- TMEM tiled-copy ``tTM*`` family and the generalised ``<src>2<dst>_atom`` naming:
- ``tTMrO`` driven by ``thr_tmem_load``
- ``CuTeDSL/cute/blackwell/tutorial/tutorial_fmha/fmha_0.py``
- Recurrent-state operands (``State`` / ``QState`` / ``Shared``) in TMEM:
- ``tCtState``, ``tCtQState``, ``tCtShared``
- ``CuTeDSL/cute/blackwell/kernel/attention/gated_delta_net/gated_delta_net_chunked.py``
.. {$nv-internal-release begin}
- Sparse-metadata operand ``E`` and ``sp2t_*`` copy ops (sm_140 / Feynman, not yet released):
- ``tCtE``, ``sp2t_copy_op_A``, ``sp2t_copy_op_E``
- ``CuTeDSL/internal/feynman/sm140_sparse_gemm.py`` and ``sm140_sparse_gemm_temporal_split_k.py``
- LUT-based block-scaled GEMM operand ``LutB`` (Rubin, not yet released):
- ``CuTeDSL/cute/rubin/kernel/blockscaled_gemm/dense_blockscaled_gemm_lut.py``
- ``CuTeDSL/cute_ext/rubin/dense_gemm_lutb.py``
.. {$nv-internal-release end}
- Richer ``thr_*`` and ``thr_copy_*`` / ``thr_mma_*`` / ``thr_tmem_*`` partitioner taxonomy:
- ``thr_copy_g2s``, ``thr_copy_s2r``, ``thr_copy_s2t``, ``thr_copy_r2t``, ``thr_mma_qk``, ``thr_mma_pv``, ``thr_tmem_load``, ``thr_tmem_store``
- The attention and Mamba2 examples above are the densest references; any ``fmha_*.py`` or ``mamba2_ssd.py`` file will show the full vocabulary in use

View File

@@ -0,0 +1,11 @@
Architecture-specific MMA Programming Guides
=============================================
This section contains architecture-specific MMA programming guides.
.. toctree::
:maxdepth: 2
wmma_programming
wgmma_programming
tcgen05_programming

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,987 @@
.. _wgmma_programming:
Warpgroup MMA Programming Guide
================================
Hopper (SM90a) introduced the **warpgroup-level MMA** PTX instruction family
``wgmma.mma_async.sync.aligned``. A warpgroup (128 threads / 4 warps)
cooperates on one asynchronous ``D = A * B + C`` matrix multiply-accumulate.
Key architectural characteristics:
* **Warpgroup scope:** One MMA is issued collectively by a 128-thread
warpgroup rather than by a single warp.
* **Asynchronous issue model:** WGMMA instructions are ordered with
``cute.nvgpu.warpgroup.fence()``, ``commit_group()``, and ``wait_group()``.
* **Descriptor-based operand path:** Operand B is sourced from staged shared
memory. Operand A can be sourced either from shared memory descriptors or
from registers via ``OperandSource``.
* **Register accumulator:** The accumulator lives in RMEM and serves as both
the input C and output D of ``cute.gemm()``.
* **Architecture-specific operand layouts:** F16/BF16 supports K-major and
MN-major dense layouts when A comes from SMEM. FP8 and INT8 variants are
K-major only.
The dense DSL op classes currently exposed are ``MmaF16BF16Op`` (F16/BF16),
``MmaF8Op`` (FP8 E4M3/E5M2), and ``MmaI8Op`` (INT8/UINT8); see
`Setting up the TiledMMA, MMA Ops`_ for their full constructor parameters,
instruction K extents, and major-mode constraints.
This guide outlines the CuTe Python DSL programming model for WGMMA kernels:
stage operands in SMEM, build fragment descriptors, launch asynchronous
warpgroup MMAs, and stage the RMEM accumulator back to GMEM in the epilogue.
.. contents:: **Contents**
:local:
:depth: 2
Global Memory (GMEM) to MMA data flow overview
----------------------------------------------
WGMMA instructions require us to stage B input operands in Shared Memory (SMEM),
while A input operands can be sourced from either SMEM or registers (RMEM).
SMEM operands are read asynchronously by the hardware via SMEM descriptors.
The accumulator is always kept in registers (RMEM) of the warpgroup.
The diagram below traces the full data flow of a WGMMA GEMM kernel, for the most
common case where A and B matrices are stored in GMEM and both are staged through
SMEM (``a_src=SMEM``), and the output matrix --accumulated in RMEM-- is written
back to GMEM through an SMEM staging buffer.
There are 3 parallel tracks where each has 2 sub-tracks. Three parallel tracks are
for operands A, B, and C/D, respectively. The two sub-tracks are for copying data between different memory
spaces and for MMA execution.
- **Operand A** (and symmetrically **Operand B**):
- First, we need to create SMEM tensors for A and B matrices: ``sA`` and ``sB``. These
tensors are physically allocated tensors that are the destination of TMA copy and
the source operands for the WGMMA instructions.
- Next the **data copy flow** creates the tensor views for copying data from GMEM to SMEM.
It starts with ``mA`` tensor that represents the matrix A in global memory.
Then ``mA````local_tile````gA`` operation creates the local tile view of A that is the
slice of A matrix needed to compute the given CTA's output tile.
Then ``tma_partition(tma, sA, gA)`` produces TMA views ``tAsA``, ``tAgA``,
and the loop copies tiles from GMEM into SMEM via ``copy(tma, tAgA[k], tAsA[stage])``.
- In parallel, the **MMA flow** turns the SMEM tensor into an iterable tensor of SMEM descriptors
for the WGMMA instructions. ``sA`` (the same shared-memory allocation written by TMA)
``partition_A````tCsA`` (MMA-partitioned SMEM view)
``make_fragment_A````tCrA`` (SMEM descriptor passed to ``cute.gemm()``).
Note that the SMEM descriptor is a view created from the SMEM tensor that is
interpretable by the WGMMA instructions.
- **Accumulator C/D**:
- **RMEM accumulator flow** (gemm input/output): ``partition_C(gC)````tCgC``
``make_rmem_tensor(tCgC.shape)````acc``, which serves as both the accumulator
input (C) and output (D) of ``cute.gemm()`` (and the WGMMA instruction).
- **Output flow** (RMEM → SMEM → GMEM): After the main loop, the accumulator is
type-converted and copied from registers to SMEM via ``stmatrix`` (R2S copy),
then stored to global memory via TMA store (S2G copy):
``mC````local_tile````gC````partition_C````tCgC`` on the destination side,
and ``tRS_rAcc``/``tRS_sD`` / ``bSG_sD``/``bSG_gD`` views drive the two copy stages.
.. code-block:: text
Operand A Dataflow Path Operand B Dataflow Path Accumulator C/D Dataflow Path
─────────────────────── ─────────────────────── ─────────────────────────────
mA: (M, K) [GMEM] mB: (N, K) [GMEM] ┌──── RMEM ──────────┐
│ │ │ make_rmem_tensor() │
│ local_tile(mA, cta_tiler, coord) │ local_tile(mB, cta_tiler, coord) │ acc: accumulator │
▼ ▼ └───────┬────────────┘
gA: (BM, BK, k) [GMEM] gB: (BN, BK, k) [GMEM] │
│ │ acc:(MMA,MMA_M,MMA_N) [RMEM]
│ ┌──── SMEM ─────────┐ │ ┌──── SMEM ─────────┐ │
│ │ sA = alloc(layout)│ │ │ sB = alloc(layout)│ │ mC: (M, N) [GMEM]
│ └──┬────────┬───────┘ │ └──┬────────┬───────┘ │ │
│ │ │ │ │ │ │ │ local_tile
│ │ thr_mma.partition_A(sA) │ │ thr_mma.partition_B(sB) │ ▼
│ │ ▼ │ │ ▼ │ gC: (BM, BN) [GMEM]
│ │ tCsA:(MMA,MMA_M, │ │ tCsB:(MMA,MMA_N, │ │ partition_C
│ │ MMA_K,PIPE) [SMEM] │ │ MMA_K,PIPE) [SMEM] │ ▼
│ │ │ │ │ │ │ tCgC:(MMA,MMA_M,
│ │ make_fragment_A(tCsA) │ │ make_fragment_B(tCsB) │ MMA_N)
│ │ ▼ │ │ ▼ │ [GMEM] (epi dest)
│ │ tCrA:(MMA,MMA_M, │ │ tCrB:(MMA,MMA_N, │ │
│ │ MMA_K,PIPE) │ │ MMA_K,PIPE) │ │
│ │ [SMEM descriptors] │ │ [SMEM descriptors] │ │
│ │ └─────────────┐ │ │ └─────────────┐ │ │
╰─────┤ │ ╰─────┤ │ │ │
▼ │ ▼ │ │ │
tma_partition(tma, │ tma_partition(tma, │ │ │
sA, gA) │ sB, gB) │ │ │
→ tAsA, tAgA │ → tBsB, tBgB │ │ │
▼ │ ▼ │ │ │
┌───┴────────────────────┐ │ ┌──────┴─────────────────┐│ │ │
│ TMA copy loop (A path):│ │ │ TMA copy loop (B path):││ │ │
│ copy(tma, tAgA[k], │ │ │ copy(tma, tBgB[k], ││ │ │
│ tAsA[stage]) │ │ │ tBsB[stage]) ││ │ │
┌─▶│ (writes into sA; │ │ ┌──▶│ (writes into sB; ││ │ │
│ │ tCrA reads same sA) │ │ │ │ tCrB reads same sB) ││ │ │
│ │ repeat for next k/stage│ │ │ │ repeat for next k/stage││ │ │
│ └────────────────────────┘ │ │ └────────────────────────┘│ │ │
│ │ │ │ │ │ │ │
└────────┘ ▼ └─────────┘ ▼ ▼ │
└───────┬───────────────────────────────┴───────────────────┘ │
│ │
▼ │
┌──────────────────────────────────────────────┐ │
│ GEMM Loop: | │
│ warpgroup.fence() │ │
│ cute.gemm(tiled_mma, │ │
│ acc, D (output, RMEM), │ │
┌──▶ │ tCrA[stage], A (SMEM desc -> sA), │ │
│ │ tCrB[stage], B (SMEM desc -> sB), │ │
│ │ acc) C (accumulator, RMEM) │ │
│ │ warpgroup.commit_group() │ │
│ │ warpgroup.wait_group(n) │ │
│ └──────────────────────────────────────────────┘ │
│ │ │ │
└───────┘ | │
▼ │
Epilogue: │
tRS_rAcc = retile(acc) │
tRS_rD = type_convert(tRS_rAcc) │
│ │
▼ │
R2S: copy(tiled_copy_r2s, tRS_rD, tRS_sD) │
[RMEM → SMEM via stmatrix] │
│ │
▼ │
sC = alloc(epi_layout) [SMEM] │
bSG_sD, bSG_gD = tma_partition(tma_c, sC, gC) ◀───────────────────┘
S2G: copy(tma_c, bSG_sD[stage], bSG_gD[coord])
[SMEM → GMEM via TMA store]
**Naming convention:**
* cta_tiler = (BM, BN, BK) = CTA-wide tiler dimensions
* ``mX`` = a global tensor, e.g., (M, K) for A
* ``gX`` = CTA-tiled GMEM slice, e.g., (BM, BK, k) for A
* ``sX`` = SMEM allocation, e.g., (BM, BK, PIPE) for A
* ``tAsA``/``tBsB`` = TMA-partitioned SMEM views
* ``tAgA``/``tBgB`` = TMA-partitioned GMEM views
* ``tCsX`` = MMA-partitioned SMEM view, e.g., (MMA, MMA_M, MMA_K, PIPE) for A
* ``tCrX`` = SMEM descriptor fragment, e.g., (MMA, MMA_M, MMA_K, PIPE) for A
* ``acc`` = RMEM accumulator, (MMA, MMA_M, MMA_N)
* ``tCgC`` = MMA-partitioned GMEM, (MMA, MMA_M, MMA_N)
* ``tRS_rAcc``/``tRS_sD`` = epilogue retile views for R2S (RMEM → SMEM) copy
* ``bSG_sD``/``bSG_gD`` = TMA-partitioned SMEM/GMEM views for epilogue store
* MMA = warpgroup atom thread-value layout; MMA_M/MMA_N/MMA_K = repeat counts
(e.g., BM/inst_M), k = outer K-tiles, PIPE = pipeline stages
Setting up the TiledMMA, MMA Ops
---------------------------------
As shown in the data flow overview, CuTe DSL provides many utilities to tile/partition
the global memory tensors, and create fragment views of SMEM tensors for MMA instructions.
To utilize these functions, we need to setup the TiledMMA, MMA Ops first.
Creating a WGMMA Op
~~~~~~~~~~~~~~~~~~~~
A WGMMA op describes the hardware instruction to use, it has parameters like
data types, instruction shape, operand A source (SMEM or RMEM),
and operand major modes.
.. code-block:: python
import cutlass
import cutlass.cute as cute
from cutlass.cute.nvgpu import OperandMajorMode
import cutlass.cute.nvgpu.warpgroup as warpgroup
op = warpgroup.MmaF16BF16Op(
cutlass.Float16, # A/B element type
cutlass.Float32, # accumulator type
(64, 128, 16), # instruction shape (M, N, K)
warpgroup.OperandSource.SMEM, # A operand from shared memory
OperandMajorMode.K, # A is K-major
OperandMajorMode.K, # B is K-major
)
The key parameters are:
- **Instruction shape** ``(M, N, K)``: determines the size of one hardware MMA
instruction. WGMMA requires ``M = 64`` and ``8 <= N <= 256`` in steps of 8.
K is fixed by the op class (16 for F16/BF16, 32 for FP8 and INT8).
- **OperandSource**: ``SMEM`` reads A from a shared memory descriptor; ``RMEM``
reads A directly from registers.
- **OperandMajorMode**: ``K`` for K-major (default), ``MN`` for transposed layout.
F16/BF16 supports both K-major and MN-major for A and B when ``a_src=SMEM``;
when ``a_src=RMEM``, only B can be transposed. FP8 and INT8 are K-major only.
CuTe DSL provides implementation of the following WGMMA ops:
.. list-table:: WGMMA ops
:header-rows: 1
:widths: 30 24 46
* - PTX name
- Python class
- Constructor parameters
* - ``wgmma.mma_async.m64n{N}k16.{acc}.f16.f16`` / ``.bf16.bf16``
- ``warpgroup.MmaF16BF16Op``
- ``ab_dtype, acc_dtype, instruction_shape, a_src, a_major_mode, b_major_mode``
* - ``wgmma.mma_async.m64n{N}k32.{acc}.{e4m3|e5m2}.{e4m3|e5m2}``
- ``warpgroup.MmaF8Op``
- ``a_dtype, b_dtype, acc_dtype, instruction_shape, a_src, a_major_mode, b_major_mode``
* - ``wgmma.mma_async.m64n{N}k32.s32.{s8|u8}.{s8|u8}``
- ``warpgroup.MmaI8Op``
- ``a_dtype, b_dtype, acc_dtype, instruction_shape, a_src, a_major_mode, b_major_mode``
Creating a Tiled MMA
~~~~~~~~~~~~~~~~~~~~~
A ``TiledMma`` tiles the WGMMA atom across the CTA tile. You can pass the op
directly or create an explicit atom first.
.. code-block:: python
# Option 1: directly from op (common shorthand)
tiled_mma = cute.make_tiled_mma(op)
# Option 2: explicit atom creation
atom = cute.make_mma_atom(op)
tiled_mma = cute.make_tiled_mma(atom)
Spatial tiling with a repeat count
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
A repeat tuple ``(M_rep, N_rep, K_rep)`` replicates the WGMMA atom across the
M, N, and K dimensions, producing a larger tiled MMA that covers a bigger CTA
tile with a single ``cute.gemm`` call. Each entry in the repeat tuple
corresponds to one **warpgroup** (128 threads / 4 warps), so ``(2, 1, 1)``
uses two warpgroups — the standard configuration for large Hopper tiles:
.. code-block:: python
atom = cute.make_mma_atom(op) # op shape: (64, 128, 16)
tiled_mma = cute.make_tiled_mma(
atom,
atom_layout_mnk=(2, 1, 1), # 2 warpgroups in M
)
.. code-block:: text
WGMMA Atom make_tiled_mma(atom, (2, 1, 1))
+---------------+ +----------------+
| | | | ^
| 64 x 128 | | Atom (0,0,0) | |
| x 16 | --(2,1,1)--> | 64 x 128 | | 2 x M_atom
| | repeat | x 16 | | = 128
| | | [Warpgroup 0] | |
+---------------+ +----------------+ |
| | |
| Atom (1,0,0) | |
| 64 x 128 | |
| x 16 | |
| [Warpgroup 1] | v
+----------------+
<-- N_atom = 128 -->
K unchanged = 16
The Hopper dense GEMM examples
(``examples/cute/hopper/kernel/dense_gemm/dense_gemm.py``) use this pattern.
The helper ``sm90_utils.make_trivial_tiled_mma(...)`` selects the repeat count
automatically:
- ``atom_layout_mnk = (2, 1, 1)`` when both ``tile_M > 64`` and
``tile_N > 128`` (two warpgroups reduce register pressure).
- ``atom_layout_mnk = (1, 1, 1)`` otherwise (a single warpgroup suffices).
.. code-block:: python
import cutlass.utils.hopper_helpers as sm90_utils
tiled_mma = sm90_utils.make_trivial_tiled_mma(
a_dtype,
b_dtype,
a_major_mode,
b_major_mode,
acc_dtype,
atom_layout_mnk=(2, 1, 1),
tiler_mn=(64, 128), # atom instruction shape (M, N)
)
Custom tile permutation with ``permutation_mnk``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
``make_tiled_mma`` also accepts an optional ``permutation_mnk`` argument that
controls how the tiled atom footprint is laid out across M, N, and K. At a
high level:
- ``atom_layout_mnk`` tells CuTe how many atoms (warpgroups) to replicate.
- ``permutation_mnk`` tells CuTe how the final tiled footprint is ordered.
``permutation_mnk`` is a tuple of layouts or integers that represent the
tile size and ordering of values along each dimension. When a mode's
permutation size is larger than the atom layout's natural coverage
(``atom_layout x inst_shape``), each warpgroup receives additional values
to fill the extended region — the warpgroup count stays the same, but each
warpgroup holds more data.
.. code-block:: python
atom = cute.make_mma_atom(op) # op shape: (64, 128, 16)
tiled_mma = cute.make_tiled_mma(
atom,
atom_layout_mnk=(2, 1, 1),
permutation_mnk=(128, 256, 16), # extend N from 128 to 256
)
.. code-block:: text
Without permutation — natural atom coverage (M = 128, N = 128):
C tile (M=128, N=128)
+----------------+
| | ^
| [Warpgroup 0] | |
| 64 x 128 | | 2 x inst_M
| | | = 128
+----------------+ |
| | |
| [Warpgroup 1] | |
| 64 x 128 | |
| | v
+----------------+
<--- N = 128 --->
(each warpgroup owns one (64, 128) atom)
With permutation_mnk = (128, 256, 16) — N extended to 256:
C tile (M=128, N=256)
+----------------+----------------+
| | | ^ N = 128 → 256:
| [Warpgroup 0] | [Warpgroup 0] | | atom pattern repeats
| 64 x 128 | 64 x 128 | | along N. Each warpgroup
| | | | now holds 2x the values
+----------------+----------------+ | along N (same threads,
| | | | more data).
| [Warpgroup 1] | [Warpgroup 1] | |
| 64 x 128 | 64 x 128 | |
| | | v
+----------------+----------------+
<------------ N = 256 ------------>
| atom coverage | value repeat |
**Why WGMMA typically does not need permutation_mnk:** The WGMMA
instruction already has a large N dimension (64, 128, or 256), so the natural
atom coverage is wide enough that no permutation is needed to align with SMEM
swizzle widths. The Hopper
dense GEMM examples (``dense_gemm.py``, ``dense_gemm_persistent.py``) use
``atom_layout_mnk`` alone without ``permutation_mnk``.
When ``permutation_mnk`` is not provided (default), the tile ordering is
sequential and no permutation is applied.
Partitioning Tensors
---------------------
Before computing, partition the CTA-tiled tensors according to the
tiled MMA layout. WGMMA partitioning is **warpgroup-oriented**: each
warpgroup (128 threads / 4 warps) receives its own slice of the CTA
tile, sized to match the SMEM descriptors and register accumulators
that the WGMMA instruction expects.
**2-warpgroup example**
``GEMM (M, N, K) = (512, 768, 256)``, ``tile_shape_mnk = (128, 256, 64)``,
F16 WGMMA atom = (64, 256, 16), ``atom_layout_mnk = (2, 1, 1)``,
``num_stages = 4``, 2 warpgroups = 256 threads.
Global matrices:
.. code-block:: text
mA: (M, K) = (512, 256) mB: (N, K) = (768, 256) mC: (M, N) = (512, 768)
K=256 K=256 N=768
|<--------->| |<--------->| |<----------------->|
+-----------+ +-----------+ +---+---+---+-------+
| | ^ | | ^ | | | | | ^
| mA | | M=512 | mB | | N=768 | | | | | | M=512
| | v | | v | | | | | v
+-----------+ +-----------+ +---+---+---+-------+
Tiling with ``tile_shape_mnk = (BM, BN, BK) = (128, 256, 64)`` gives
M/BM = 4 tiles, N/BN = 3 tiles, K/BK = 4 tiles:
.. code-block:: text
mA tiled into (M/BM x K/BK) mB tiled into (N/BN x K/BK) mC tiled into (M/BM x N/BN)
= (4 x 4) blocks = (3 x 4) blocks = (4 x 3) blocks
BK=64 x4 BK=64 x4 BN=256 x3
|<--->| |<--->| |<------>|
+-----+-----+-----+-----+ +-----+-----+-----+-----+ +--------+--------+--------+
| | | | | ^ | | | | | ^ | (0,0) | (0,1) | (0,2) | ^
| | | | | |128 | | | | | |256 | | | | |128
+-----+-----+-----+-----+ v +-----+-----+-----+-----+ v +--------+--------+--------+ v
| | | | | ^ | | | | | ^ | (1,0) | (1,1) | (1,2) | ^
| | | | | |128 | | | | | |256 | | | | |128
+-----+-----+-----+-----+ v +-----+-----+-----+-----+ v +--------+--------+--------+ v
| | | | | | | | | | | (2,0) | (2,1) | (2,2) |
+-----+-----+-----+-----+ +-----+-----+-----+-----+ +--------+--------+--------+
| | | | | | (3,0) | (3,1) | (3,2) |
+-----+-----+-----+-----+ +--------+--------+--------+
Each CTA picks one (M-tile, N-tile) coordinate.
For example, CTA at ``tile_coord = (1, 0, :)``.
After ``local_tile`` — one CTA's tile (``k = K/BK = 256/64 = 4``):
.. code-block:: text
gA: (BM, BK, k) = (128, 64, 4) gB: (BN, BK, k) = (256, 64, 4) gC: (BM, BN) = (128, 256)
BK=64 BK=64 BN=256
|<----->| |<----->| |<--------->|
+-------+-- +-------+-- +-----------+
| |.. | |.. | | ^
BM= | gA | k=4 BN= | gB | k=4 BM= | gC | | 128
128 | | 256 | | 128 | | v
+-------+ +-------+ +-----------+
SMEM tensors ``sA`` and ``sB`` include a pipeline staging dimension:
.. code-block:: text
sA: (BM, BK, PIPE) = (128, 64, 4) sB: (BN, BK, PIPE) = (256, 64, 4)
``get_slice(warp_group_thread_layout(warp_group_idx))`` — each
warpgroup receives its slice of the tiled MMA footprint.
With ``atom_layout_mnk = (2, 1, 1)`` and inst shape ``(64, 256, 16)``,
the tiled MMA covers ``(2x64, 1x256, 16) = (128, 256, 16)`` which
exactly matches the CTA tile in M and N. Each warpgroup owns one
64-row slice of M:
.. code-block:: text
sA (one pipeline stage, BM=128, BK=64):
Warpgroup 0's slice Warpgroup 1's slice
inst_K inst_K inst_K inst_K
=16 =16 =16 =16
|<--->|<--->|<--->|<--->| |<--->|<--->|<--->|<--->|
+-----+-----+-----+-----+ ^ +-----+-----+-----+-----+ ^
| 0 | 1 | 2 | 3 | |64 | 0 | 1 | 2 | 3 | |64
+-----+-----+-----+-----+ v +-----+-----+-----+-----+ v
|<-- MMA_K = BK/inst_K = 4 -->| |<-- MMA_K = 4 ---------->|
MMA_M = 64/64 = 1 MMA_M = 64/64 = 1
gC (BM=128, BN=256):
+---------------------------+ ^
| Warpgroup 0: 64 x 256 | | 64
| | |
+---------------------------+ v
| Warpgroup 1: 64 x 256 | ^
| | | 64
+---------------------------+ v
<--------- N = 256 -------->
MMA_M = 64/64 = 1, MMA_N = 256/256 = 1
After partition (per warpgroup):
- ``tCsA: (MMA, MMA_M, MMA_K, PIPE) = (MMA, 1, 4, 4)`` — MMA_M = BM / (atom_M x inst_M) = 128 / (2x64) = 1, MMA_K = BK / inst_K = 64 / 16 = 4
- ``tCsB: (MMA, MMA_N, MMA_K, PIPE) = (MMA, 1, 4, 4)`` — MMA_N = BN / (atom_N x inst_N) = 256 / (1x256) = 1, MMA_K = 4
- ``tCgC: (MMA, MMA_M, MMA_N) = (MMA, 1, 1)`` — MMA_M = 1, MMA_N = 1
The first mode ``MMA`` contains the atom's **thread x value** layout — it
encodes which registers within a warpgroup hold which matrix elements.
The remaining modes are repeat counts that tile the atom across the
full CTA tile.
.. note:: Because the WGMMA instruction shape is large (64 x {64..256}),
the tiled MMA footprint typically covers the entire CTA tile in M and N
with just one or two warpgroups. This means MMA_M and MMA_N are often 1.
The MMA_K dimension is where the repeat count is non-trivial (BK / inst_K
iterations per pipeline stage).
**1-warpgroup example (contrast)**
For a smaller tile ``(128, 128, 64)`` with ``atom_layout_mnk = (1, 1, 1)``,
inst shape ``(64, 128, 16)``, and ``num_stages = 4``,
the tiled MMA covers only ``(64, 128, 16)``.
Now a single warpgroup must iterate over two atom-blocks along M:
- ``tCsA: (MMA, MMA_M, MMA_K, PIPE) = (MMA, 2, 4, 4)`` — MMA_M = 128 / (1x64) = 2
- ``tCsB: (MMA, MMA_N, MMA_K, PIPE) = (MMA, 1, 4, 4)`` — MMA_N = 128 / (1x128) = 1
- ``tCgC: (MMA, MMA_M, MMA_N) = (MMA, 2, 1)``
.. code-block:: python
# Based on examples/cute/hopper/kernel/dense_gemm/dense_gemm.py
@cute.kernel
def kernel(tiled_mma: cute.TiledMma, ...):
tidx, _, _ = cute.arch.thread_idx()
# CTA-tiled global tensors
gA_mkl = cute.local_tile(
mA_mkl, tile_shape_mnk, tile_coord_mnkl, proj=(1, None, 1)
)
gB_nkl = cute.local_tile(
mB_nkl, tile_shape_mnk, tile_coord_mnkl, proj=(None, 1, 1)
)
gC_mnl = cute.local_tile(
mC_mnl, tile_shape_mnk, tile_coord_mnkl, proj=(1, 1, None)
)
# Warpgroup-oriented slicing (128 threads per warpgroup)
warp_group_idx = cute.arch.make_warp_uniform(
tidx // num_threads_per_warp_group # 128
)
warp_group_thread_layout = cute.make_layout(
mma_warp_groups, # e.g. 2
stride=num_threads_per_warp_group, # 128
)
thr_mma = tiled_mma.get_slice(
warp_group_thread_layout(warp_group_idx)
)
# Partition C from global
tCgC = thr_mma.partition_C(gC_mnl) # (MMA, MMA_M, MMA_N)
# Partition A/B from staged SMEM
tCsA = thr_mma.partition_A(sA) # (MMA, MMA_M, MMA_K, PIPE)
tCsB = thr_mma.partition_B(sB) # (MMA, MMA_N, MMA_K, PIPE)
Pre and Post-Conditions for Partitioning
-----------------------------------------
* The inputs of ``partition_A``, ``partition_B``, and ``partition_C`` should be
at least rank-2 tensors.
* The output layout is constrained by the selected MMA atom:
- For A, the output has layout ``(MMA, MMA_M, MMA_K, ...)``.
- For B, the output has layout ``(MMA, MMA_N, MMA_K, ...)``.
- For C, the output has layout ``(MMA, MMA_M, MMA_N, ...)``.
* Partitioning reasons about layout, not memory space or element type.
When ``a_src=OperandSource.RMEM``, the same tiled MMA shape still
determines the logical A footprint, but A is materialized as a register
fragment rather than a shared-memory descriptor.
Making Fragments
-----------------
Fragments are the tensors that the WGMMA instruction operates on. For dense
WGMMA:
- **Fragment A**: an SMEM descriptor when ``a_src=OperandSource.SMEM``, or an
RMEM register fragment when ``a_src=OperandSource.RMEM``.
- **Fragment B**: an SMEM descriptor pointing into staged shared memory buffers.
- **Fragment C (accumulator)**: an RMEM tensor that serves as both the input C
and output D of ``cute.gemm()``.
WGMMA fragments for A and B are **SMEM descriptors** — the hardware reads
directly from shared memory. There is no explicit SMEM → RMEM copy step for
operands A and B. The accumulator, however, still lives in per-thread
registers (RMEM).
Creating fragment descriptors and accumulator fragments
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Fragment creation has two parts:
**1. A and B fragment descriptors**
``make_fragment_A`` and ``make_fragment_B`` take the MMA-partitioned SMEM
views (``tCsA`` / ``tCsB``) and produce descriptor tensors that the WGMMA
instruction consumes. Each descriptor points to one tile within a pipeline
stage in shared memory.
.. code-block:: python
# MMA-partitioned SMEM views (see "Partitioning Tensors")
tCsA = thr_mma.partition_A(sA) # (MMA, MMA_M, MMA_K, PIPE)
tCsB = thr_mma.partition_B(sB) # (MMA, MMA_N, MMA_K, PIPE)
# SMEM descriptor fragments consumed by cute.gemm()
tCrA = tiled_mma.make_fragment_A(tCsA) # (MMA, MMA_M, MMA_K, PIPE)
tCrB = tiled_mma.make_fragment_B(tCsB) # (MMA, MMA_N, MMA_K, PIPE)
Continuing the 2-warpgroup example from `Partitioning Tensors`_
(F16 atom = (64, 256, 16), ``tile_shape_mnk = (128, 256, 64)``,
``atom_layout_mnk = (2, 1, 1)``, ``num_stages = 4``):
.. code-block:: text
tCsA: (MMA, MMA_M=1, MMA_K=4, PIPE=4)
tCsB: (MMA, MMA_N=1, MMA_K=4, PIPE=4)
make_fragment_A(tCsA) -> tCrA: (MMA, 1, 4, 4)
make_fragment_B(tCsB) -> tCrB: (MMA, 1, 4, 4)
Each element of tCrA/tCrB is an SMEM descriptor — one per
(MMA_K, PIPE) pair. The hardware reads SMEM directly via the
descriptor; no explicit SMEM -> RMEM load is needed.
tCrA per warpgroup (4 pipeline stages, 4 K-blocks each):
|<-- MMA_K = BK/inst_K = 4 -->|
stage 0: +------+------+------+------+
| k=0 | k=1 | k=2 | k=3 | inst_M=64 (MMA_M=1)
+------+------+------+------+
stage 1: +------+------+------+------+
| k=0 | k=1 | k=2 | k=3 | inst_M=64
+------+------+------+------+
stage 2: +------+------+------+------+
| k=0 | k=1 | k=2 | k=3 | inst_M=64
+------+------+------+------+
stage 3: +------+------+------+------+
| k=0 | k=1 | k=2 | k=3 | inst_M=64
+------+------+------+------+
Similarly for tCrB with shape (MMA, MMA_N=1, MMA_K=4, PIPE=4).
.. note:: WGMMA fragments for A and B are SMEM descriptors — the hardware
reads SMEM directly, so there is no ``ldmatrix`` retiling step required
before ``cute.gemm()``.
**When A comes from registers (``OperandSource.RMEM``)**
In fused kernels, the output of one MMA can become the A operand of the
next. The second ``TiledMma`` is created with
``a_src=OperandSource.RMEM``, and ``make_fragment_A`` is **not** used.
Instead:
1. The accumulator's C layout ``(MMA, MMA_M, MMA_N)`` is converted to the
A layout ``(MMA, MMA_M, MMA_K)`` expected by the second ``TiledMma``.
2. The accumulator values are type-converted and stored into an RMEM tensor
with the A layout.
3. The resulting RMEM tensor is passed directly to ``cute.gemm()`` as the A
operand — no SMEM descriptor is involved.
See the Hopper FMHA example (``examples/cute/hopper/kernel/attention/fmha.py``) for the complete pattern.
**2. C fragment (accumulator)**
The accumulator lives in per-thread registers (RMEM). Its shape is derived
from the partitioned C layout. The accumulator starts at zero before the K
loop and is updated in-place by each ``cute.gemm()`` call.
.. code-block:: python
# Partition C from global (see "Partitioning Tensors")
tCgC = thr_mma.partition_C(gC_mnl) # (MMA, MMA_M, MMA_N)
# Allocate RMEM accumulator with the same shape
acc_shape = tCgC.shape
acc = cute.make_rmem_tensor(acc_shape, cutlass.Float32)
acc.fill(0.0)
For the same running example:
.. code-block:: text
tCgC: (MMA, MMA_M=1, MMA_N=1)
make_rmem_tensor(tCgC.shape, Float32) -> acc: (MMA, 1, 1)
The accumulator stays in RMEM for the entire main loop.
cute.gemm() reads A/B from SMEM descriptors and accumulates into acc.
+-----------------------------------+
| acc: (MMA, 1, 1) in RMEM |
| 64 x 256 elements per warpgroup |
| Float32 |
+-----------------------------------+
Creating SMEM layouts for A and B
----------------------------------
The SMEM layouts define how A and B tiles are staged in shared memory,
including swizzling for bank-conflict-free descriptor access. The helper
functions in ``cutlass.utils.hopper_helpers`` handle the details.
**Host side** (``@cute.jit``):
.. code-block:: python
import cutlass.utils.hopper_helpers as sm90_utils
# Create SMEM layouts (includes swizzle + staging)
a_smem_layout = sm90_utils.make_smem_layout_a(
a_layout, # LayoutEnum — row-major or col-major
tile_shape_mnk, # CTA tile (M, N, K)
a_dtype, # element type (e.g. Float16)
num_stages, # pipeline depth
)
b_smem_layout = sm90_utils.make_smem_layout_b(
b_layout,
tile_shape_mnk,
b_dtype,
num_stages,
)
epi_smem_layout = sm90_utils.make_smem_layout_epi(
c_dtype,
c_layout,
epi_tile,
epi_stage,
)
``make_smem_layout_a`` and ``make_smem_layout_b`` are convenience helpers that
build a complete, staged SMEM layout in four steps:
1. **Extract the operand tile shape.** For A the ``(M, K)`` portion of
``tile_shape_mnk`` is kept via ``cute.slice_``; for B the ``(N, K)``
portion.
2. **Determine the major mode.** The major mode (K-major or MN-major) is read
from the layout enum (``a_layout.is_k_major_a()``). The major-mode
dimension size is used for swizzle selection.
3. **Select and materialise the swizzle atom.** A heuristic
(``get_smem_layout_atom``) picks the widest swizzle whose contiguous
size (in bits) evenly divides the major-mode dimension:
+------------+-----------------+
| Swizzle | Contiguous bits |
+============+=================+
| SW128 | 1024 (128 B) |
+------------+-----------------+
| SW64 | 512 (64 B) |
+------------+-----------------+
| SW32 | 256 (32 B) |
+------------+-----------------+
| Interleave | 128 (16 B) |
+------------+-----------------+
``make_smem_layout_atom`` then combines the chosen swizzle with a compact
outer layout into a ``ComposedLayout(swizzle, outer)``.
4. **Tile to the operand shape and append the staging dimension.**
``cute.tile_to_shape`` broadcasts the atom to the full ``(M_or_N, K)``
shape with ``num_stages`` appended. The ``order`` argument controls which
dimension is contiguous: ``(0, 1, 2)`` for K-major (K innermost),
``(1, 0, 2)`` for MN-major (MN innermost).
For the running F16 example (``tile_shape_mnk = (128, 256, 64)``,
``num_stages = 4``, K-major A, K-major B):
.. code-block:: text
A operand (K-major, tile = (M=128, K=64)):
major_mode_size = 64
64 * 16 bits = 1024 bits → SW128
atom = make_smem_layout_atom(K_SW128, Float16)
tile_to_shape(atom, (128, 64, 4), order=(0,1,2))
-> a_smem_layout: ComposedLayout with shape (128, 64, 4)
B operand (K-major, tile = (N=256, K=64)):
major_mode_size = 64
64 * 16 bits = 1024 bits → SW128
atom = make_smem_layout_atom(K_SW128, Float16)
tile_to_shape(atom, (256, 64, 4), order=(0,1,2))
-> b_smem_layout: ComposedLayout with shape (256, 64, 4)
**Kernel side** (``@cute.kernel``):
The layout and swizzle are passed to shared-memory allocation. The result
is a ``ComposedLayout`` whose ``.outer`` is the logical layout and ``.inner``
is the swizzle:
.. code-block:: python
# Based on examples/cute/hopper/kernel/dense_gemm/dense_gemm.py
sA = storage.sA.get_tensor(
a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner
)
sB = storage.sB.get_tensor(
b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner
)
After allocation:
- ``sA`` has shape ``(BM, BK, PIPE) = (128, 64, 4)``.
- ``sB`` has shape ``(BN, BK, PIPE) = (256, 64, 4)``.
These are the staged SMEM tensors consumed by ``partition_A`` /
``partition_B`` and ``make_fragment_A`` / ``make_fragment_B``
(see `Making Fragments`_).
.. note:: If you need finer control, you can build layout atoms directly with
``cute.nvgpu.warpgroup.make_smem_layout_atom(...)`` and compose the final
SMEM layout manually via ``cute.tile_to_shape``.
Executing the GEMM (Main Loop)
-------------------------------
The main loop iterates over K-tiles. The WGMMA-specific part of each
iteration is the **fence / gemm / commit / wait** sequence:
.. code-block:: python
acc.fill(0.0)
tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True)
for k_tile in cutlass.range(k_pipe_mmas, k_tile_cnt, 1, unroll=1):
# ... wait for TMA load (pipeline details in dense_gemm.py) ...
cute.nvgpu.warpgroup.fence()
tile_crd = (None, None, None, consumer_read.index)
cute.gemm(tiled_mma, acc, tCrA[tile_crd], tCrB[tile_crd], acc)
cute.nvgpu.warpgroup.commit_group()
cute.nvgpu.warpgroup.wait_group(k_pipe_mmas)
# ... release buffer & advance pipeline (see dense_gemm.py) ...
cute.nvgpu.warpgroup.wait_group(0)
Key points:
- ``fence()`` orders prior SMEM writes before WGMMA issue.
- ``commit_group()`` publishes queued WGMMA instructions as a group.
- ``wait_group(n)`` waits until at most ``n`` groups remain in flight.
``wait_group(0)`` after the loop drains all work before the epilogue.
- ``Field.ACCUMULATE````True`` accumulates (``D += A*B``),
``False`` overwrites (``D = A*B``). The dense GEMM sets ``True`` and
zero-fills ``acc`` so the first iteration computes ``0 + A*B``.
Complete Workflow
------------------
Putting it all together, a typical Hopper WGMMA GEMM has this structure.
The MMA-relevant steps are highlighted; see ``dense_gemm.py`` for the full
kernel including TMA, pipeline, and epilogue details.
.. code-block:: python
import cutlass
import cutlass.cute as cute
from cutlass.cute.nvgpu import OperandMajorMode
import cutlass.cute.nvgpu.warpgroup as warpgroup
import cutlass.utils.hopper_helpers as sm90_utils
# --- Host side (@cute.jit) ---
# 1. MMA op + tiled MMA
op = warpgroup.MmaF16BF16Op(
cutlass.Float16, cutlass.Float32, (64, 128, 16),
warpgroup.OperandSource.SMEM, OperandMajorMode.K, OperandMajorMode.K,
)
tiled_mma = cute.make_tiled_mma(op)
# 2. SMEM layouts
a_smem_layout = sm90_utils.make_smem_layout_a(a_layout, tile_shape_mnk, a_dtype, num_stages)
b_smem_layout = sm90_utils.make_smem_layout_b(b_layout, tile_shape_mnk, b_dtype, num_stages)
# 3. TMA copy atoms + kernel launch (see dense_gemm.py)
.. code-block:: python
# --- Kernel side (@cute.kernel) ---
# 4. Allocate SMEM
smem = cutlass.utils.SmemAllocator()
storage = smem.allocate(SharedStorage)
sA = storage.sA.get_tensor(
a_smem_layout.outer, swizzle=a_smem_layout.inner) # (BM, BK, PIPE)
sB = storage.sB.get_tensor(
b_smem_layout.outer, swizzle=b_smem_layout.inner) # (BN, BK, PIPE)
# 5. CTA-tiled global tensors
gA_mkl = cute.local_tile(mA_mkl, tile_shape_mnk, tile_coord, proj=(1, None, 1))
gB_nkl = cute.local_tile(mB_nkl, tile_shape_mnk, tile_coord, proj=(None, 1, 1))
gC_mnl = cute.local_tile(mC_mnl, tile_shape_mnk, tile_coord, proj=(1, 1, None))
# 6. Warpgroup slice, partition & make fragments
warp_group_idx = cute.arch.make_warp_uniform(tidx // num_threads_per_warp_group)
warp_group_thread_layout = cute.make_layout(mma_warp_groups, stride=num_threads_per_warp_group)
thr_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx))
tCsA = thr_mma.partition_A(sA) # (MMA, MMA_M, MMA_K, PIPE)
tCsB = thr_mma.partition_B(sB) # (MMA, MMA_N, MMA_K, PIPE)
tCrA = tiled_mma.make_fragment_A(tCsA) # SMEM descriptor
tCrB = tiled_mma.make_fragment_B(tCsB) # SMEM descriptor
tCgC = thr_mma.partition_C(gC_mnl) # (MMA, MMA_M, MMA_N)
acc = cute.make_rmem_tensor(tCgC.shape, acc_dtype)
# 7. TMA pipeline setup + prefetch (see dense_gemm.py)
# 8. Main loop — fence / gemm / commit / wait
acc.fill(0.0)
tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True)
for k_tile in cutlass.range(k_pipe_mmas, k_tile_cnt, 1, unroll=1):
# ... wait for TMA load ...
cute.nvgpu.warpgroup.fence()
tile_crd = (None, None, None, consumer_read.index)
cute.gemm(tiled_mma, acc, tCrA[tile_crd], tCrB[tile_crd], acc)
cute.nvgpu.warpgroup.commit_group()
cute.nvgpu.warpgroup.wait_group(k_pipe_mmas)
# ... release buffer, advance pipeline ...
cute.nvgpu.warpgroup.wait_group(0)
# 9. Epilogue: RMEM → SMEM (stmatrix) → GMEM (TMA store)
# ... (see dense_gemm.py)
.. Beyond Simple Dense MMAs
.. ------------------------
.. The current Python DSL coverage for warpgroup MMA is centered on the three
.. dense ops above. PTX also defines additional WGMMA instruction families that
.. do **not** yet have DSL op classes. These are tracked in the source at
.. ``cutlass/cute/nvgpu/warpgroup/mma.py`` (marked ``✗`` in the instruction
.. table).
.. **Structured-sparse WGMMA** (``wgmma.mma_async.sp``)
.. 2:4 structured sparsity in operand A: out of every 4 consecutive K-elements,
.. exactly 2 are non-zero. The instruction K is **doubled** relative to the
.. dense counterpart (e.g. ``m64nNk32`` for F16/BF16 vs ``m64nNk16`` dense)
.. because A is stored in compressed form. Supported data types include
.. F16/BF16, TF32, FP8, and INT8.
.. Compared to the dense workflow, a sparse kernel would add:
.. - A **compressed A tensor** storing only the non-zero values (half the
.. K-elements), and a **metadata tensor E** encoding which 2 of 4 positions
.. are non-zero.
.. - Extra SMEM layouts, TMA loads, and allocations for both the compressed A
.. and the metadata E.
.. - A metadata staging step each K-tile (SMEM to the MMA instruction).
.. Once DSL support is added, the same fence/commit/wait workflow described in
.. this guide applies, with the additional metadata operand.
.. **Dense TF32 WGMMA** (``m64nNk8``)
.. TF32 (19-bit truncated FP32) inputs with FP32 accumulator. The instruction
.. K = 8 is smaller than F16's K = 16, so MMA_K repeat counts are larger for
.. the same BK tile size. Otherwise the workflow is identical to the dense
.. F16/BF16 path — the same SMEM layout, descriptor, and fence/commit/wait
.. pattern applies.
.. **Dense B1 WGMMA** (``m64nNk256``)
.. 1-bit (binary) inputs with INT32 accumulator. The very large instruction
.. K = 256 means each atom consumes 256 bits along K per operand, resulting in
.. small MMA_K repeat counts. This is a niche instruction for binary neural
.. networks.
See also:
- Dense GEMM example: ``examples/cute/hopper/kernel/dense_gemm/dense_gemm.py``
- Persistent GEMM example: ``examples/cute/hopper/kernel/dense_gemm/dense_gemm_persistent.py``
- FMHA example (RMEM A path): ``examples/cute/hopper/kernel/attention/fmha.py``
- Helper utilities: ``cutlass.utils.hopper_helpers``

File diff suppressed because it is too large Load Diff

View File

@@ -720,7 +720,9 @@ class DSLPreprocessor(ast.NodeTransformer):
offset = len(all_args) - len(func_ast.args.defaults)
for i, default_node in enumerate(func_ast.args.defaults):
ast_defaults[all_args[offset + i].arg] = default_node
for kwarg, kw_default in zip(func_ast.args.kwonlyargs, func_ast.args.kw_defaults):
for kwarg, kw_default in zip(
func_ast.args.kwonlyargs, func_ast.args.kw_defaults
):
if kw_default is not None:
ast_defaults[kwarg.arg] = kw_default
for param_name, default_val in params_with_defaults.items():

View File

@@ -1865,7 +1865,7 @@ class BaseDSL(metaclass=DSLSingletonMeta):
sources = set(x.value for x in link_libraries_attributes)
link_libraries = (
link_libraries
+ ("," if len(link_libraries) > 0 else "")
+ ("," if link_libraries and len(sources) > 0 else "")
+ ",".join(sources)
)
self.compile_options.options[LinkLibraries] = LinkLibraries(

View File

@@ -88,6 +88,11 @@ def _get_gpu_arch_info(major: int, minor: int) -> tuple[str, str, list[str]]:
"sm_120a",
["sm_120a"],
), # RTX PRO 6000 / RTX 50 Series
(12, 1): (
"Blackwell",
"sm_121a",
["sm_121a"],
), # DGX Spark
}
return gpu_arch_map.get(
(major, minor), ("Unknown", f"sm_{major}{minor}", [f"sm_{major}{minor}"])

View File

@@ -3330,7 +3330,12 @@ def filter_zeros(
if not isinstance(input, (Layout, Tensor)):
raise TypeError(f"Expected layout or tensor as input, but got {type(input)=}")
if isinstance(input, Tensor):
input = input.value
return _op_wrapper(
partial(_cute_ir.filter_zeros, target_profile=target_profile),
input,
loc=loc,
ip=ip,
)
return _cute_ir.filter_zeros(input, target_profile=target_profile, loc=loc, ip=ip)
@@ -3388,7 +3393,7 @@ def filter(
input.inner, input.offset, filter(input.outer, loc=loc, ip=ip)
)
elif isinstance(input, _Tensor):
return _cute_ir.filter(input.value, loc=loc, ip=ip)
return _op_wrapper(_cute_ir.filter, input, loc=loc, ip=ip)
else:
return _cute_ir.filter(input, loc=loc, ip=ip)
@@ -5020,10 +5025,9 @@ def local_partition(
raise NotImplementedError(
f"Index value should be 32-bit or smaller integer type, but got {index_val.type}"
)
return _cute_ir.local_partition(
input=target.value,
tiler=dice(tiler, proj),
index=index_val,
return _op_wrapper(
partial(_cute_ir.local_partition, tiler=dice(tiler, proj), index=index_val),
target,
loc=loc,
ip=ip,
)
@@ -5114,11 +5118,9 @@ def local_tile(
proj_val = _pack_coord(proj, loc=loc, ip=ip)
proj = proj_val.type.attribute
return _cute_ir.local_tile(
input=input.value,
tile=tiler_val,
coord=coord_val,
proj=proj,
return _op_wrapper(
partial(_cute_ir.local_tile, tile=tiler_val, coord=coord_val, proj=proj),
input,
loc=loc,
ip=ip,
)

View File

@@ -21,6 +21,9 @@ __all__ = [
"MmaFP8Op",
"MmaMXF4Op",
"MmaMXF4NVF4Op",
"MmaMXF8Op",
"MmaMXF8F6F4Op",
"MXF8F6F4_SUPPORTED_PAIRS",
# copy.py
"LdMatrix8x8x16bOp",
"LdMatrix16x8x8bOp",

View File

@@ -224,7 +224,9 @@ class MmaSM120BlockScaledOp(MmaOp):
admissible_archs = [
Arch.sm_120a,
Arch.sm_120f,
Arch.sm_121a,
Arch.sm_121f,
]
def __post_init__(self) -> None:
@@ -239,29 +241,44 @@ class MmaSM120BlockScaledOp(MmaOp):
"CUTE_DSL_ARCH set to sm_120a or sm_121a",
suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
)
if self.ab_dtype != Float4E2M1FN:
# (ab_dtype, shape_mnk) consistency: FP4 uses (16,8,64); FP8 uses (16,8,32).
if self.ab_dtype == Float4E2M1FN:
if self.shape_mnk != (16, 8, 64):
raise OpError(
self,
"expects the 'shape_mnk' Op parameter to be (16,8,64) for Float4E2M1FN",
)
elif self.ab_dtype in (Float8E4M3FN, Float8E5M2):
if self.shape_mnk != (16, 8, 32):
raise OpError(
self,
"expects the 'shape_mnk' Op parameter to be (16,8,32) for Float8E4M3FN/Float8E5M2",
)
else:
raise OpError(
self,
"expects the 'ab_dtype' Op parameter to be Float4E2M1FN",
"expects the 'ab_dtype' Op parameter to be Float4E2M1FN, Float8E4M3FN, or Float8E5M2",
)
if self.acc_dtype != Float32:
raise OpError(
self,
"expects the 'acc_dtype' Op parameter to be Float32",
)
if self.shape_mnk != (16, 8, 64):
raise OpError(
self,
"expects the 'shape_mnk' Op parameter to be (16,8,64)",
)
if self.sf_vec_size == 16:
# vec_size=16 is only valid for FP4 (NVFP4) with E4M3 scale.
if self.ab_dtype != Float4E2M1FN:
raise OpError(
self,
"expects the 'sf_vec_size' Op parameter to be 32 for Float8E4M3FN/Float8E5M2",
)
if self.sf_type != Float8E4M3FN:
raise OpError(
self,
"expects the 'sf_type' Op parameter to be Float8E4M3FN",
)
elif self.sf_vec_size == 32:
# vec_size=32 path uses UE8M0 scale for both FP4 (MXF4) and FP8 (MXF8).
if self.sf_type != Float8E8M0FNU:
raise OpError(
self,
@@ -275,7 +292,7 @@ class MmaSM120BlockScaledOp(MmaOp):
def __str__(self) -> str:
return (
"warp-level MXF4/MXF4NVF4 MMA Operation"
"warp-level MXF4/MXF4NVF4/MXF8 MMA Operation"
+ f"\n A/B data type = {self.ab_dtype}"
+ f"\n Accumulator data type = {self.acc_dtype}"
+ f"\n Instruction shape MNK = {self.shape_mnk}"
@@ -474,3 +491,214 @@ class MmaMXF4NVF4Op(MmaSM120BlockScaledOp):
class MmaMXF4NVF4Trait(MmaBlockScaledTrait):
pass
#
# MXF8 MMA
#
@dataclass(frozen=True)
class MmaMXF8Op(MmaSM120BlockScaledOp):
"""
MXF8 warp-level MMA Operation.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-mma>`__.
This Operation covers the instructions using the ``.e4m3`` / ``.e5m2`` qualifiers for the input operands.
.kind = {.kind::mxf8};
.scale_vec_size = {.scale_vec::1X};
.stype = {.ue8m0};
"""
descriptive_name = "warp-level MXF8 MMA Operation"
def __init__(
self,
ab_dtype: Type[Numeric],
acc_dtype: Type[Numeric],
sf_type: Type[Numeric],
) -> None:
super().__init__(
ab_dtype,
acc_dtype,
(16, 8, 32),
sf_type,
32,
)
def _make_trait(
self,
*,
loc: Optional[ir.Location] = None,
ip: Optional[ir.InsertionPoint] = None,
**kwargs: Any,
) -> "MmaMXF8Trait":
shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
ty = _cute_nvgpu_ir.MmaAtomSM120BlockScaledType.get(
shape_mnk.type.attribute,
32,
False,
self.ab_dtype.mlir_type,
self.ab_dtype.mlir_type,
self.acc_dtype.mlir_type,
self.sf_type.mlir_type,
)
return MmaMXF8Trait(make_atom(ty, loc=loc, ip=ip))
class MmaMXF8Trait(MmaBlockScaledTrait):
pass
#
# MXF8F6F4 mixed-precision MMA (independent A/B dtypes)
#
MXF8F6F4_SUPPORTED_PAIRS = frozenset(
{
(Float4E2M1FN, Float8E4M3FN),
(Float4E2M1FN, Float8E5M2),
(Float8E4M3FN, Float4E2M1FN),
(Float8E5M2, Float4E2M1FN),
}
)
@dataclass(frozen=True)
class MmaMXF8F6F4Op(MmaOp):
"""
SM120 MXF8F6F4 mixed-precision warp-level block-scaled MMA Operation.
Covers the PTX instructions using independent ``.<a_type>.<b_type>``
qualifiers (one of e2m1.e4m3, e2m1.e5m2, e4m3.e2m1, e5m2.e2m1):
.kind = {.kind::mxf8f6f4};
.scale_vec_size = {.scale_vec::1X};
.stype = {.ue8m0};
A and B operand dtypes are independent. Same-dtype FP4/FP4 and FP8/FP8
paths remain on ``MmaMXF4Op`` / ``MmaMXF4NVF4Op`` / ``MmaMXF8Op``
respectively. Same-width mixed-FP8 (E4M3 + E5M2) and FP6 mixed pairs
are not supported.
"""
a_dtype: Type[Numeric]
b_dtype: Type[Numeric]
acc_dtype: Type[Numeric]
sf_type: Type[Numeric]
descriptive_name = "warp-level MXF8F6F4 mixed-precision MMA Operation"
shape_mnk = (16, 8, 32)
sf_vec_size = 32
use_sf_layout_TV = False
admissible_archs = [
Arch.sm_120a,
Arch.sm_121a,
]
def __post_init__(self) -> None:
# Verify arch
arch = BaseDSL._get_dsl().get_arch_enum()
if arch not in self.admissible_archs:
raise OpError(
self,
f"expects arch to be one of {self.admissible_archs}, but got {arch}",
suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
)
if self.acc_dtype != Float32:
raise OpError(
self,
"expects the 'acc_dtype' Op parameter to be Float32",
)
if self.sf_type != Float8E8M0FNU:
raise OpError(
self,
"expects the 'sf_type' Op parameter to be Float8E8M0FNU",
)
# Reject same-dtype pairs explicitly (route to dedicated ops).
if self.a_dtype == self.b_dtype:
if self.a_dtype == Float4E2M1FN:
raise OpError(
self,
"same-dtype Float4E2M1FN/Float4E2M1FN is not supported by MmaMXF8F6F4Op; "
"use MmaMXF4Op (sf_vec_size=32) or MmaMXF4NVF4Op (sf_vec_size=16) instead",
)
if self.a_dtype in (Float8E4M3FN, Float8E5M2):
raise OpError(
self,
"same-dtype FP8/FP8 is not supported by MmaMXF8F6F4Op; "
"use MmaMXF8Op instead",
)
# Reject same-width mixed-FP8 (E4M3 + E5M2) explicitly.
fp8_dtypes = (Float8E4M3FN, Float8E5M2)
if self.a_dtype in fp8_dtypes and self.b_dtype in fp8_dtypes:
raise OpError(
self,
"same-width mixed-FP8 (Float8E4M3FN + Float8E5M2) is not supported; "
"supported MXF8F6F4 pairs are (Float4E2M1FN x Float8E4M3FN/Float8E5M2) "
"and the reverse",
)
# Final allow-list check (catches FP6 and any other unsupported dtype).
if (self.a_dtype, self.b_dtype) not in MXF8F6F4_SUPPORTED_PAIRS:
raise OpError(
self,
f"unsupported (a_dtype, b_dtype) = ({self.a_dtype}, {self.b_dtype}) "
f"for MmaMXF8F6F4Op; supported pairs are "
f"{sorted(repr(p) for p in MXF8F6F4_SUPPORTED_PAIRS)}. "
f"FP6 mixed pairs are not supported.",
)
def __str__(self) -> str:
return (
"warp-level MXF8F6F4 mixed-precision MMA Operation"
+ f"\n A data type = {self.a_dtype}"
+ f"\n B data type = {self.b_dtype}"
+ f"\n Accumulator data type = {self.acc_dtype}"
+ f"\n Instruction shape MNK = {self.shape_mnk}"
+ f"\n Vector size = {self.sf_vec_size}"
+ f"\n SF data type = {self.sf_type}"
)
def _verify_fragment_A(
self,
input: Tensor,
*,
loc: Optional[ir.Location] = None,
ip: Optional[ir.InsertionPoint] = None,
) -> None:
pass
def _verify_fragment_B(
self,
input: Tensor,
*,
loc: Optional[ir.Location] = None,
ip: Optional[ir.InsertionPoint] = None,
) -> None:
pass
def _make_trait(
self,
*,
loc: Optional[ir.Location] = None,
ip: Optional[ir.InsertionPoint] = None,
**kwargs: Any,
) -> "MmaMXF8F6F4Trait":
shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
ty = _cute_nvgpu_ir.MmaAtomSM120BlockScaledType.get(
shape_mnk.type.attribute,
self.sf_vec_size,
self.use_sf_layout_TV,
self.a_dtype.mlir_type,
self.b_dtype.mlir_type,
self.acc_dtype.mlir_type,
self.sf_type.mlir_type,
)
return MmaMXF8F6F4Trait(make_atom(ty, loc=loc, ip=ip))
class MmaMXF8F6F4Trait(MmaBlockScaledTrait):
pass

View File

@@ -21,7 +21,11 @@ from jax._src.interpreters import batching
from .compile import get_or_compile_kernel, build_function_spec
from .types import cutlass_to_jax_layout_order, default_tensor_spec, TensorSpec
from .types import (
cutlass_to_jax_layout_order,
default_tensor_spec,
TensorSpec,
)
from .ffi import get_cutlass_call_ffi_name, is_ffi_registered, register_ffi
@@ -77,8 +81,10 @@ def cutlass_call(
objects with ``.shape`` and ``.dtype`` attributes) describing each
output buffer.
input_spec: A :class:`TensorSpec` or list thereof providing
layout/mode/divisibility hints for input tensors. ``None`` infers
defaults from each array.
layout/mode/divisibility hints for input tensors. ``None`` infers
defaults from each array. A ``TensorSpec`` with ``layout=None`` uses
and constrains row-major physical layout; use ``mode`` to remap
physical dimensions to the kernel's logical modes.
output_spec: Same as *input_spec* but applied to output tensors.
input_output_aliases: ``{input_index: output_index}`` mapping that
allows an input buffer to alias an output, avoiding an extra copy.
@@ -308,7 +314,9 @@ def cutlass_call_inner_p_impl(
call_name = get_cutlass_call_ffi_name(allow_cuda_graph)
# Convert layout from CuTeDSL to JAX order as ffi_call expects this.
# Convert explicit layout constraints from CuTeDSL to JAX order. ``None`` is
# passed through intentionally: jax.ffi.ffi_call treats it as default
# row-major layout.
input_layouts = [cutlass_to_jax_layout_order(s.layout) for s in input_spec_flat]
output_layouts = [cutlass_to_jax_layout_order(s.layout) for s in output_spec_flat]

View File

@@ -15,12 +15,18 @@ import jax.numpy as jnp
import cutlass.cute as cute
from cutlass.cutlass_dsl import dsl_user_op
from typing import Optional
from typing import Optional, Sequence
from cutlass._mlir import ir
def reorder_modes(src: str, target: str) -> tuple[int, ...]:
"""Computes the mode given a source and target order."""
def reorder_modes(src: Sequence[str], target: Sequence[str]) -> tuple[int, ...]:
"""Compute a ``TensorSpec.mode`` from physical input order to kernel order.
``src`` names the JAX array's physical dimension order. ``target`` names the
logical mode order that the CuTe kernel expects. The returned tuple can be
passed as ``TensorSpec(mode=...)`` while leaving ``layout`` at its default
row-major value when the JAX buffer is physically row-major.
"""
src = tuple(src)
target = tuple(target)
src_map = {}
@@ -29,52 +35,64 @@ def reorder_modes(src: str, target: str) -> tuple[int, ...]:
return tuple([src_map[d] for d in target])
def gemm_a_major(d: str):
"""Returns order for A tensor major mode."""
def gemm_a_major(d: str) -> str:
"""Return the physical JAX dimension order for an A tensor major mode.
The returned string is not the kernel's canonical logical order. Use
:func:`gemm_a_mode` to map this physical order to kernel logical ``mkl``.
"""
return {"k": "lmk", "m": "lkm"}[d]
def gemm_a_mode(d: str) -> tuple[int, ...]:
"""Returns mode for A tensor major mode."""
"""Return ``TensorSpec.mode`` for A, mapping physical order to logical ``mkl``."""
return reorder_modes(gemm_a_major(d), "mkl")
def gemm_b_major(d: str):
"""Returns order for B tensor major mode."""
def gemm_b_major(d: str) -> str:
"""Return the physical JAX dimension order for a B tensor major mode.
The returned string is not the kernel's canonical logical order. Use
:func:`gemm_b_mode` to map this physical order to kernel logical ``nkl``.
"""
return {"k": "lnk", "n": "lkn"}[d]
def gemm_b_mode(d: str) -> tuple[int, ...]:
"""Returns mode for B tensor major mode."""
"""Return ``TensorSpec.mode`` for B, mapping physical order to logical ``nkl``."""
return reorder_modes(gemm_b_major(d), "nkl")
def gemm_c_major(d: str):
"""Returns order for C tensor major mode."""
def gemm_c_major(d: str) -> str:
"""Return the physical JAX dimension order for a C/D tensor major mode.
The returned string is not the kernel's canonical logical order. Use
:func:`gemm_c_mode` to map this physical order to kernel logical ``mnl``.
"""
return {"n": "lmn", "m": "lnm"}[d]
def gemm_c_mode(d: str) -> tuple[int, ...]:
"""Returns mode for C tensor major mode."""
"""Return ``TensorSpec.mode`` for C/D, mapping physical order to logical ``mnl``."""
return reorder_modes(gemm_c_major(d), "mnl")
def gemm_a_shape(l, m, k, major) -> tuple[int, ...]:
"""Returns shape for A tensor given major mode."""
def gemm_a_shape(l: int, m: int, k: int, major: str) -> tuple[int, ...]:
"""Return the physical row-major JAX shape for A with the requested major mode."""
assert major in ("k", "m")
shape = (l, m, k) if major == "k" else (l, k, m)
return shape
def gemm_b_shape(l, n, k, major) -> tuple[int, ...]:
"""Returns shape for B tensor given major mode."""
def gemm_b_shape(l: int, n: int, k: int, major: str) -> tuple[int, ...]:
"""Return the physical row-major JAX shape for B with the requested major mode."""
assert major in ("k", "n")
shape = (l, n, k) if major == "k" else (l, k, n)
return shape
def gemm_c_shape(l, m, n, major) -> tuple[int, ...]:
"""Returns shape for C tensor given major mode."""
def gemm_c_shape(l: int, m: int, n: int, major: str) -> tuple[int, ...]:
"""Return the physical row-major JAX shape for C/D with the requested major mode."""
assert major in ("m", "n")
shape = (l, m, n) if major == "n" else (l, n, m)
return shape

View File

@@ -9,7 +9,7 @@
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from typing import Optional, Sequence
from typing import Any, Optional, Sequence
from dataclasses import dataclass, field
@@ -18,6 +18,7 @@ import jax.numpy as jnp
import cutlass
import cutlass.cute as cute
from cutlass.cute.core import IntValue
from cutlass.cute.runtime import from_dlpack as _from_dlpack
from cutlass.cute import AddressSpace
from cutlass._mlir import ir
@@ -58,35 +59,69 @@ DEFAULT_CUTLASS_DEVICE_BUFFER_ALIGNMENT = 256
class TensorSpec:
"""Specifies the layout and metadata for a JAX array passed to a CuTe kernel.
TensorSpec controls how a JAX array's dimensions are mapped to a cute.Tensor
during jit lowering, including stride ordering, mode permutation, and whether
shapes/strides are compiled as static constants.
TensorSpec controls how a JAX array's input dimensions are mapped to a
``cute.Tensor`` during jit lowering, including compact stride ordering,
mode permutation, and whether shapes/strides are compiled as static
constants. The JAX bridge models tensors as compact layouts: runtime
strides are derived from runtime shapes using ``layout`` order rather than
loaded from a strided view descriptor.
A useful way to choose a spec is to separate physical storage from logical
kernel modes:
1. First choose the public JAX array shape and its compact physical memory
order. If the buffer is a standard row-major JAX array, leave
``layout=None``. ``cutlass_call`` will constrain the FFI operand/result
to row-major physical layout, matching the CuTe tensor strides that are
built from the default.
2. Then use ``mode`` only when the kernel should see those input dimensions
in a different logical order. ``mode`` is applied after the compact
layout is built; it is not a request for JAX/XLA to transpose data.
For example, a row-major JAX buffer shaped ``(expert_count, N, K)`` can be
presented to a kernel expecting logical ``(N, K, expert_count)`` with
``TensorSpec(mode=(1, 2, 0))``. No explicit ``layout`` is needed because the
physical buffer is still ordinary row-major, and the FFI call will be
constrained accordingly. Use ``layout`` only when the compact physical
stride order itself differs from the default row-major order, such as a
column-major compact buffer.
Attributes:
layout: A minor-to-major stride ordering in CuTeDSL convention. ``layout[i]``
gives the stride rank of dimension ``i``, where rank 0 means the smallest
(innermost) stride. For example, row-major order for a 3-D tensor is
``(2, 1, 0)``. If ``None``, row-major is assumed. Use
:func:`jax_to_cutlass_layout_order` to convert from JAX's major-to-minor
convention.
mode: A permutation that maps the stride-ordered dimensions to the mode
positions of the resulting ``cute.Layout``. For example, ``mode=(2, 0, 1)``
reorders an ``(M, K, L)`` layout into ``(K, L, M)`` mode order inside the
kernel. If ``None``, modes match the natural dimension order ``(0, 1, ..., N-1)``.
gives the compact physical stride rank of input dimension ``i``,
where rank 0 means the smallest (innermost) stride. For example,
row-major order for a 3-D tensor is ``(2, 1, 0)``. If ``None``,
row-major is assumed. Use :func:`jax_to_cutlass_layout_order` to
convert from JAX's major-to-minor convention. ``layout`` does not
change which logical mode a dimension represents; combine it with
``mode`` when physical order and kernel-logical order differ.
mode: A permutation applied after the compact layout is constructed. It
selects input dimensions into the mode positions seen by the kernel.
For example, ``mode=(2, 0, 1)`` presents an input shaped
``(M, K, L)`` to the kernel as logical ``(L, M, K)``. If ``None``,
modes match the natural input-dimension order ``(0, 1, ..., N-1)``.
``mode`` changes the tensor layout object seen by CuTe code but
does not materialize a transpose or change the underlying buffer.
static: If ``True``, shapes and strides are compiled as static ``constexpr``
values, which may enable additional compiler optimisations. Kernels that
do not support static shapes will raise a compile error. Must be ``False``
when any dimension is symbolic (e.g. under ``jax.export``).
ptr_assumed_align: Assumed byte alignment of the tensor's data pointer.
Overrides the default of 256 bytes. Rarely needs to change.
divisibility: Optional per-mode divisibility hints. If a single int is passed
divisibility will be applied to the leading (stride=1) dimension only.
divisibility: Optional divisibility hints for input dimensions, in the
same order as the JAX array shape and before any ``mode`` reordering.
Positive hints constrain dynamic shape values and are propagated
through compact stride construction: a stride inherits the product
of the divisibilities for dimensions with lower stride rank.
Positive explicit hints take precedence over inferred concrete
extents. If a single int is passed, it is applied to the leading
compact dimension only, where ``layout[i] == 0``.
"""
# Minor-to-major stride ordering in CuTeDSL convention (layout[i] = stride rank
# of dimension i, 0 = innermost). Defaults to row-major if None.
layout: tuple[int, ...] | None = field(metadata=dict(static=True), default=None)
# Permutation from stride-ordered dimensions to cute.Layout mode positions.
# Permutation from input dimensions to cute.Layout mode positions.
# Defaults to identity (0, 1, ..., N-1) if None.
mode: tuple[int, ...] | None = field(metadata=dict(static=True), default=None)
# If True, shapes and strides are embedded as compile-time constants.
@@ -96,7 +131,7 @@ class TensorSpec:
ptr_assumed_align: int = field(
metadata=dict(static=True), default=DEFAULT_CUTLASS_DEVICE_BUFFER_ALIGNMENT
)
# Per-mode divisibility hints.
# Per-input-dimension divisibility hints, before mode reordering.
divisibility: tuple[int | None, ...] | int | None = field(
metadata=dict(static=True), default=None
)
@@ -128,9 +163,10 @@ def row_major_layout(shaped):
def default_tensor_mode(shaped):
"""Returns the identity mode permutation for an N-dimensional tensor.
The mode permutation maps stride-ordered dimensions to ``cute.Layout`` mode
positions. The default identity ``(0, 1, ..., N-1)`` leaves the mode order
unchanged relative to the dimension order.
The mode permutation maps JAX input dimensions to ``cute.Layout`` mode
positions after the compact layout has been constructed. The default
identity ``(0, 1, ..., N-1)`` leaves the mode order unchanged relative to
the JAX shape order.
Args:
shaped: An object with a ``.shape`` attribute, or a shape tuple/sequence.
@@ -151,12 +187,22 @@ def default_tensor_spec(shaped) -> TensorSpec:
TensorSpec(layout=(N-1, ..., 1, 0), mode=(0, 1, ..., N-1), divisibility=(D0, D1, ... DN-1))
This is appropriate for standard row-major (C-contiguous) JAX arrays that
do not require dimension reordering inside the kernel.
do not require dimension reordering inside the kernel. The resulting JAX
CuTe tensor is treated as compact: strides are derived from shapes using the
row-major layout order.
Divisibility hints are inferred only for concrete integer dimensions.
Symbolic dimensions always produce ``None`` for their slot; pass an
explicit ``TensorSpec`` with ``divisibility`` set if you need alignment
hints for symbolic shapes.
If the JAX buffer is row-major but the kernel expects a different logical
mode order, use an explicit :class:`TensorSpec` with ``mode`` set and leave
``layout`` unset. ``cutlass_call`` still constrains the FFI buffer to
row-major layout in this case. For example, ``TensorSpec(mode=(1, 2, 0))``
maps a physical ``(L, M, K)`` row-major input to a logical ``(M, K, L)``
tensor.
Divisibility hints are inferred only for concrete integer input dimensions.
Symbolic dimensions always produce ``None`` for their slot; pass an explicit
``TensorSpec`` with ``divisibility`` set if you need alignment hints for
symbolic shapes or want a weaker explicit constraint than the concrete
extent.
Args:
shaped: An object with a ``.shape`` attribute, or a shape tuple/sequence.
@@ -179,11 +225,12 @@ def default_tensor_spec(shaped) -> TensorSpec:
def _expand_divisibility(
divisibility, order: tuple[int, ...], ndim: int
) -> tuple[int | None, ...] | None:
"""Expand a divisibility spec to a full per-dimension tuple.
"""Expand a divisibility spec to a full per-input-dimension tuple.
A bare ``int`` is placed at the leading-dimension slot (where
``order[i] == 0``, i.e. stride == 1) and ``None`` everywhere else.
A tuple is returned unchanged. ``None`` returns ``None``.
A tuple is already in JAX input-dimension order and is returned unchanged.
``None`` returns ``None``.
"""
if divisibility is None or isinstance(divisibility, tuple):
return divisibility
@@ -268,7 +315,20 @@ def from_dlpack(array, assumed_align: int = DEFAULT_CUTLASS_DEVICE_BUFFER_ALIGNM
return _from_dlpack(array, assumed_align=assumed_align)
def _validate_permutation(name: str, perm, shape):
def _assume_divisible_int(
value: Any,
divby: int,
*,
loc: ir.Location | None = None,
ip: ir.InsertionPoint | None = None,
) -> Any:
"""Attach a divisibility assumption to an integer value without narrowing it."""
if divby <= 1:
return value
return cute.assume(IntValue(value, loc=loc, ip=ip), divby=divby, loc=loc, ip=ip)
def _validate_permutation(name: str, perm: Sequence[int], shape: Sequence[Any]) -> None:
if len(perm) != len(shape):
raise ValueError(f"{name} must be same length as shape", perm, shape)
for s in perm:
@@ -292,9 +352,11 @@ class JaxArray:
can be concrete or symbolic in the case of jax.export.
3. mem_space: The memory space of the tensor. Defaults to gmem.
4. assumed_align: The alignment of the tensor. Defaults to XLA alignment.
5. order: Specifies the order of the shape to determine strides.
6. mode: Specifies how to map ordered elements to the modes od a cute.Layout.
5. order: Specifies the compact physical stride order of the shape.
6. mode: Specifies how to map input dimensions to the logical modes seen by
the kernel after the compact layout is constructed.
7. static: If True, tensor shapes and strides are compiled statically.
8. divisibility: Optional divisibility hints in input-dimension order.
"""
def __init__(
@@ -381,6 +443,21 @@ class JaxArrayValue(JaxArray):
ip: Optional[ir.InsertionPoint] = None,
):
i32 = ir.IntegerType.get_signless(32)
# Track the divisibility available for each input dimension. Explicit
# hints win; otherwise concrete dimensions contribute their known extent.
dim_divisibility = None
if self.divisibility is not None:
dim_divisibility = []
for div_spec, static_s in zip(self.divisibility, self.shape):
if div_spec is not None and div_spec > 0:
dim_divisibility.append(div_spec)
elif isinstance(static_s, int):
dim_divisibility.append(static_s)
else:
dim_divisibility.append(1)
dim_divisibility = tuple(dim_divisibility)
pairs = sorted(zip(shape, order), key=lambda x: x[1])
# Compute strides for each element in order.
@@ -395,28 +472,29 @@ class JaxArrayValue(JaxArray):
for i in range(len(shape)):
strides_ordered.append(strides[order[i]])
if dim_divisibility is not None:
# A compact stride is the product of all dimensions with a lower
# stride order, so it inherits the product of their divisibility.
stride_divisibility = []
for dim_order in order:
divby = 1
for other_dim, other_order in enumerate(order):
if other_order < dim_order:
divby *= dim_divisibility[other_dim]
stride_divisibility.append(divby)
strides_ordered = [
_assume_divisible_int(s, divby, loc=loc, ip=ip)
for s, divby in zip(strides_ordered, stride_divisibility)
]
# Shapes are expected to be int32 so truncate to that before creating layout
shape_i32 = tuple(arith.trunci(i32, s) for s in shape)
# Apply per-mode divisibility assumptions so the compiler can exploit alignment.
if self.divisibility is not None:
assumed = []
for s32, div_spec, static_s in zip(
shape_i32, self.divisibility, self.shape
):
if isinstance(static_s, int):
# Pure static shape is known even though a dynamic shape is
# used. We can assume the exact shape here. We keep the shape
# as a dynamic value to avoid breaking code that may expect
# a dynamic value.
assumed.append(cute.assume(s32, divby=static_s))
elif div_spec is not None:
# Using a dynamic value so apply the div_spec if its provided.
assumed.append(cute.assume(s32, divby=div_spec))
else:
# No divisibility specification for this shape
assumed.append(s32)
shape_i32 = tuple(assumed)
if dim_divisibility is not None:
shape_i32 = tuple(
_assume_divisible_int(s, divby, loc=loc, ip=ip)
for s, divby in zip(shape_i32, dim_divisibility)
)
return cute.make_layout(shape_i32, stride=tuple(strides_ordered))

View File

@@ -84,6 +84,8 @@ from .tmem_allocator import (
from .layout import LayoutEnum
from .block import block_copy
from .mixed_input_helpers import (
TransformMode,
scale_tma_partition,
@@ -176,6 +178,7 @@ __all__ = [
"sm90",
"sm100",
"gemm",
"block_copy",
"ClcDynamicPersistentTileSchedulerParams",
"ClcDynamicPersistentTileScheduler",
"print_latex",

View File

@@ -612,7 +612,7 @@ def get_tmem_load_op(
def get_smem_layout_atom_ab(
major_mode: OperandMajorMode,
element_type: Type[Numeric],
smem_shape_mn_k: Tuple[int, int],
smem_shape_mn_k: cute.Tile,
*,
loc: Optional[ir.Location] = None,
ip: Optional[ir.InsertionPoint] = None,
@@ -625,13 +625,16 @@ def get_smem_layout_atom_ab(
:param element_type: The element type for the SMEM tensor.
:type element_type: Type[Numeric]
:param smem_shape_mn_k: The shape of the SMEM tensor.
:type smem_shape_mn_k: Tuple[int, int]
:type smem_shape_mn_k: cute.Tile
:return: The SMEM layout atom kind
:rtype: cutlass.cute.nvgpu.tcgen05.SmemLayoutAtomKind
"""
is_k_major = major_mode == OperandMajorMode.K
major_mode_size = smem_shape_mn_k[1] if is_k_major else smem_shape_mn_k[0]
major_mode_size = (
cute.size(smem_shape_mn_k, mode=[1])
if is_k_major
else cute.size(smem_shape_mn_k, mode=[0])
)
assert major_mode_size % 8 == 0
sw128_num_contiguous_bits = 1024
sw64_num_contiguous_bits = 512
@@ -711,6 +714,7 @@ def make_smem_layout(
cute.append(smem_tile_shape, num_stages),
order=(0, 1, 2) if is_k_major else (1, 0, 2),
)
return cute.coalesce(smem_layout, target_profile=(1, 1, 1), loc=loc, ip=ip)
@@ -1956,12 +1960,35 @@ def thrfrg_SFA(
"""Thread-fragment scale factor A tensor for SM120 block-scaled MMA.
Implements the ThrFrg partitioning for scale factor A according to the
corresponding C++ code.
corresponding C++ code in cutlass/include/cute/atom/mma_traits_sm120.hpp:
SFALayout for SM120 MXF4 16x8x64 uses K=64, SM120 MXF8F6F4 16x8x32 uses
K=32; the stride pattern ``((_8,_0,_1), _16)`` is shared.
"""
assert cute.rank(sfa_tensor) >= 2
atom_shape_mnk = tiled_mma.shape_mnk
atom_sfa_layout = cute.make_layout(shape=((2, 2, 8), 64), stride=((8, 0, 1), 16))
# K-dim of the warp-MMA atom: FP4 -> 64, FP8 -> 32 (per mma_traits_sm120.hpp).
# For FP8 (atom_K=32) where mma_nsf=1, wrap K in a 2-tuple ``(atom_K, 1)``
# so the layout's K mode keeps its 2D structure and the resulting fragment
# has the same rank as the FP4 path. For FP4 (atom_K=64) the original 1D
# layout already produces a 2D K decomposition through SMEM-layout
# composition, so we keep the original shape.
atom_K = atom_shape_mnk[2]
if atom_K == 32:
atom_sfa_layout = cute.make_layout(
shape=((2, 2, 8), (atom_K, 1)),
stride=((8, 0, 1), (16, 0)),
)
elif atom_K == 64:
atom_sfa_layout = cute.make_layout(
shape=((2, 2, 8), atom_K),
stride=((8, 0, 1), 16),
)
else:
raise ValueError(
f"thrfrg_SFA: unsupported atom_K={atom_K}; SM120 block-scaled atoms "
f"use atom_K=32 (mxf8/mxf8f6f4) or atom_K=64 (mxf4/mxf4nvf4)"
)
permutation_mnk = tiled_mma.permutation_mnk
thr_layout_vmnk = tiled_mma.thr_layout_vmnk
@@ -2000,12 +2027,32 @@ def thrfrg_SFB(
"""Thread-fragment scale factor B tensor for SM120 block-scaled MMA.
Implements the ThrFrg partitioning for scale factor B according to the
corresponding C++ code.
corresponding C++ code in cutlass/include/cute/atom/mma_traits_sm120.hpp:
SFBLayout for SM120 MXF4 16x8x64 uses K=64, SM120 MXF8F6F4 16x8x32 uses
K=32; the stride pattern ``((_0,_1), _8)`` is shared.
"""
assert cute.rank(sfb_tensor) >= 2
atom_shape_mnk = tiled_mma.shape_mnk
atom_sfb_layout = cute.make_layout(shape=((4, 8), 64), stride=((0, 1), 8))
# K-dim of the warp-MMA atom: FP4 -> 64, FP8 -> 32 (per mma_traits_sm120.hpp).
# See :func:`thrfrg_SFA` for the rationale behind the FP8-only
# ``(atom_K, 1)`` wrapping.
atom_K = atom_shape_mnk[2]
if atom_K == 32:
atom_sfb_layout = cute.make_layout(
shape=((4, 8), (atom_K, 1)),
stride=((0, 1), (8, 0)),
)
elif atom_K == 64:
atom_sfb_layout = cute.make_layout(
shape=((4, 8), atom_K),
stride=((0, 1), 8),
)
else:
raise ValueError(
f"thrfrg_SFB: unsupported atom_K={atom_K}; SM120 block-scaled atoms "
f"use atom_K=32 (mxf8/mxf8f6f4) or atom_K=64 (mxf4/mxf4nvf4)"
)
permutation_mnk = tiled_mma.permutation_mnk
thr_layout_vmnk = tiled_mma.thr_layout_vmnk

View File

@@ -0,0 +1,248 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from cutlass.cutlass_dsl import dsl_user_op, CuTeDSL
from cutlass.cute.typing import Tensor
from cutlass.cute.core import make_layout, filter_zeros
from cutlass.cute.atom import TiledCopy
from cutlass.cute.algorithm import copy
from cutlass.cute.nvgpu import tcgen05
from cutlass.cute.nvgpu.cpasync.copy import (
TmaCopyOp,
CopyBulkTensorTileG2SOp,
CopyBulkTensorTileG2SMulticastOp,
)
from cutlass.cute.nvgpu.cpasync.helpers import tma_partition
from cutlass.cute.nvgpu.tcgen05.copy import _S2TCopyBase
from typing import Any, Optional
from cutlass._mlir import ir
def _check_required_args(
required_args: list[str], kwargs: dict, condition: bool = True
) -> None:
if not condition:
return
for arg in required_args:
if arg not in kwargs:
raise ValueError(f"Argument {arg} is required.")
def _tma_copy_impl(
tiled_copy: TiledCopy,
src: Tensor,
dst: Tensor,
*,
loc: Optional[ir.Location] = None,
ip: Optional[ir.InsertionPoint] = None,
**kwargs: Any,
) -> None:
"""Internal implementation for TMA-based block-level copy."""
#
# Handle tma_multicast argument
#
if "tma_multicast" in kwargs:
if not isinstance(
tiled_copy.op,
(
CopyBulkTensorTileG2SOp,
),
):
raise ValueError(
"block_copy with tma_multicast expects a non-multicast G2S TMA copy atom "
"(CopyBulkTensorTileG2SOp) for compiler-driven multicast"
)
# Mark as coming from block API
kwargs["tma_multicast"]["from_block_api"] = True
#
# Check if required arguments are provided
#
is_bar_ptr_required = isinstance(
tiled_copy.op,
(
CopyBulkTensorTileG2SOp,
CopyBulkTensorTileG2SMulticastOp,
),
)
_check_required_args(["tma_bar_ptr"], kwargs, is_bar_ptr_required)
#
# TMA bulk tensor copies: partition via tma_partition
#
is_g2s = isinstance(
tiled_copy.op,
(
CopyBulkTensorTileG2SOp,
),
)
stensor = dst if is_g2s else src
gtensor = src if is_g2s else dst
cta_coord = 0
cta_layout = make_layout(1, loc=loc, ip=ip)
s_ptn, g_ptn = tma_partition(
tiled_copy, cta_coord, cta_layout, stensor, gtensor, loc=loc, ip=ip
)
s_ptn = filter_zeros(s_ptn)
g_ptn = filter_zeros(g_ptn)
src_arg = g_ptn if is_g2s else s_ptn
dst_arg = s_ptn if is_g2s else g_ptn
return copy(tiled_copy, src_arg, dst_arg, loc=loc, ip=ip, **kwargs)
def _utccp_copy_impl(
tiled_copy: TiledCopy,
src: Tensor,
dst: Tensor,
*,
loc: Optional[ir.Location] = None,
ip: Optional[ir.InsertionPoint] = None,
**kwargs: Any,
) -> None:
"""Internal implementation for S2T (SMEM to TMEM) copy operations.
This function abstracts the S2T copy pattern which involves:
1. Filtering zeros from src (smem) and dst (tmem) tensors
2. Creating a tiled copy using make_s2t_copy
3. Partitioning source and destination
4. Getting the SMEM descriptor tensor
5. Executing the copy
:param tiled_copy: The tiled copy for S2T operations.
:type tiled_copy: TiledCopy
:param src: The source tensor in shared memory.
:type src: Tensor
:param dst: The destination tensor in TMEM.
:type dst: Tensor
"""
# Filter zeros from src (smem) and dst (tmem) tensors
src_compact = filter_zeros(src)
dst_compact = filter_zeros(dst)
# S2T has a single thread slice; election handled automatically in lowering
thr_copy = tiled_copy.get_slice(0)
# Partition source and destination
src_partitioned = thr_copy.partition_S(src_compact, loc=loc, ip=ip)
dst_partitioned = thr_copy.partition_D(dst_compact, loc=loc, ip=ip)
# Get SMEM descriptor tensor for the source
smem_desc_tensor = tcgen05.get_s2t_smem_desc_tensor(
tiled_copy, src_partitioned, loc=loc, ip=ip
)
# Execute the copy
return copy(tiled_copy, smem_desc_tensor, dst_partitioned, loc=loc, ip=ip, **kwargs)
@dsl_user_op
@CuTeDSL.jit
def block_copy(
tiled_copy: TiledCopy,
src: Tensor,
dst: Tensor,
*,
loc: Optional[ir.Location] = None,
ip: Optional[ir.InsertionPoint] = None,
**kwargs: Any,
) -> None:
"""Performs a block-level copy operation.
This function adds an abstraction layer over the `cute.copy` usage model by
allowing operands with layouts shaped like tiles to be passed directly. This
removes the need to manually partition. The API is designed to support multiple
copy kinds; currently TMA-based copies and S2T (SMEM to TMEM) copies are supported.
**TMA copy requirements**:
When using TMA-based tiled copies, the ``src`` and ``dst`` tensors must have
their first mode representing the TMATile, i.e. tensors shaped as ``(TMATile, Rest...)``.
For a rank-2 tensor with logical layout (e.g., ``(TILE_M, TILE_N)``), call
``group_modes(tensor, 0, 2)`` before passing it to this function.
**TMA multicast support**:
For TMA-based copies that enable compiler-driven multicast in a 2D cluster, pass the
``tma_multicast`` argument as a dict with the following keys:
- ``cluster_shape``: a tuple of 2 integers ``(cluster_m, cluster_n)``
representing the **2D cluster shape**.
- ``multicast_dim``: either ``"M"`` or ``"N"`` indicating which
cluster dimension the multicast happens along.
- ``use_2cta_mma_inst`` (optional): a ``bool`` indicating whether to
use 2CTA MMA instructions when the loaded data is consumed by MMA.
Defaults to ``False`` when omitted.
**S2T (SMEM to TMEM) copy**:
When using S2T copy operations (e.g., ``tcgen05.Cp4x32x128bOp``), the function
automatically handles the filtering, partitioning, and SMEM descriptor creation.
Pass a copy atom created with ``cute.make_copy_atom(tcgen05.Cp*Op(...), dtype)``
along with source (SMEM) and destination (TMEM) tensors.
Examples:
.. code-block:: python
# 1) TMA load without compiler-driven multicast
# Note: group_modes is called to make the first mode TMATile
block_copy(tma_atom_a, group_modes(tCgA_, 0, 2), group_modes(tCsA_, 0, 2),
tma_bar_ptr=tma_bar_ptr)
# 2) TMA load with compiler-driven multicast along M in a (4,2) cluster
block_copy(
tma_atom_a,
group_modes(tCgA_, 0, 2),
group_modes(tCsA_, 0, 2),
tma_multicast={
"cluster_shape": (4, 2),
"multicast_dim": "M",
"use_2cta_mma_inst": True,
},
tma_bar_ptr=tma_bar_ptr,
)
# 3) TMA store
# Note that `tma_bar_ptr` and CTA params (`cta_coord` and `cta_layout`)
# are not needed for TMA store
block_copy(tma_atom_c, group_modes(tCsC_, 0, 2), group_modes(tCgC_, 0, 2))
# 4) S2T copy (SMEM to TMEM)
copy_atom_s2t = cute.make_copy_atom(
tcgen05.Cp4x32x128bOp(tcgen05.CtaGroup.ONE), sf_dtype
)
block_copy(copy_atom_s2t, tCsSF, tCtSF)
:param tiled_copy: The tiled_copy or copy_atom of the current copy operation.
:type tiled_copy: TiledCopy
:param src: The source tensor.
:type src: Tensor
:param dst: The destination tensor.
:type dst: Tensor
:param tma_multicast: Optional dict for TMA multicast configuration with keys
``cluster_shape``, ``multicast_dim``, and optionally
``use_2cta_mma_inst``.
:type tma_multicast: dict, optional
"""
import cutlass # local import to avoid circular import at module load time
if cutlass.const_expr(isinstance(tiled_copy.op, TmaCopyOp)):
return _tma_copy_impl(tiled_copy, src, dst, loc=loc, ip=ip, **kwargs)
elif cutlass.const_expr(isinstance(tiled_copy.op, _S2TCopyBase)):
return _utccp_copy_impl(tiled_copy, src, dst, loc=loc, ip=ip, **kwargs)
else:
raise NotImplementedError(
f"Copy op {type(tiled_copy.op).__name__} is not supported yet."
)

View File

@@ -1,3 +1,3 @@
# Use `pip install -r requirements-cu13.txt` with the present file to install a
# wheel consistent with the present state of the github repository
nvidia-cutlass-dsl[cu13]==4.4.2
nvidia-cutlass-dsl[cu13]==4.5.1

View File

@@ -1,3 +1,3 @@
# Use `pip install -r requirements.txt` with the present file to install a
# wheel consistent with the present state of the github repository
nvidia-cutlass-dsl==4.4.2
nvidia-cutlass-dsl==4.5.1

View File

@@ -133,7 +133,7 @@ def get_option_registry():
this._option_registry = OptionRegistry(device_cc())
return this._option_registry
this.__version__ = '4.5.0'
this.__version__ = '4.5.1'
from cutlass_cppgen.backend import create_memory_pool
from cutlass_cppgen.emit.pytorch import pytorch

View File

@@ -51,7 +51,7 @@ setup_pycute.perform_setup()
setup(
name='cutlass_cppgen',
version='4.5.0',
version='4.5.1',
description='CUTLASS Pythonic Interface',
package_dir={'': '.'},
packages=[

View File

@@ -36,7 +36,7 @@ from setuptools import setup
def perform_setup():
setup(
name='cutlass_library',
version='4.5.0',
version='4.5.1',
description='CUTLASS library generation scripts',
packages=['cutlass_library']
)

View File

@@ -36,7 +36,7 @@ from setuptools import setup
def perform_setup():
setup(
name='pycute',
version='4.5.0',
version='4.5.1',
description='Python implementation of CuTe',
packages=['pycute'],
)

View File

@@ -554,6 +554,112 @@ struct RandomUniformFunc {
}
};
/// Computes an exponent-uniform random distribution for UE8M0 scale factors.
template <>
struct RandomUniformFunc<float_ue8m0_t> {
using Element = float_ue8m0_t;
using FloatType = float;
/// Parameters structure
struct Params {
//
// Data members
//
uint64_t seed;
int exp_min;
int exp_range;
int int_scale; ///< Retained for Params compatibility; exponent is integral.
double pnan;
int exclude_zero; ///< Retained for Params compatibility; unused for UE8M0.
/// Default ctor
CUTLASS_HOST_DEVICE
Params() { }
//
// Methods
//
CUTLASS_HOST_DEVICE
static int closest_log2_exp(FloatType value) {
using CUTLASS_CMATH_NAMESPACE :: log2;
using CUTLASS_CMATH_NAMESPACE :: nearbyint;
// UE8M0 scale factors are strictly positive. Keep invalid lower bounds
// finite so callers using the generic [0, max] default do not produce NaN.
FloatType min_scale = FloatType(Element::bitcast(0x01));
FloatType positive_value = value > FloatType(0) ? value : min_scale;
return int(nearbyint(log2(positive_value)));
}
/// Construction of uniform RNG functor.
Params(
uint64_t seed_ = 0,
FloatType max_ = FloatType(1),
FloatType min_ = FloatType(0),
int int_scale_ = -1,
double pnan_ = 0,
int exclude_zero_ = -1
):
seed(seed_),
exp_min(closest_log2_exp(min_)),
exp_range(closest_log2_exp(max_) - closest_log2_exp(min_)),
int_scale(int_scale_),
pnan(pnan_),
exclude_zero(exclude_zero_) {
}
};
//
// Data members
//
/// Parameters object
Params params;
/// RNG state object
curandState_t rng_state;
//
// Methods
//
/// Device-side initialization of RNG
CUTLASS_DEVICE
RandomUniformFunc(Params const &params): params(params) {
uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x;
curand_init(params.seed, gtid, 0, &rng_state);
}
/// Compute random value and update RNG state
CUTLASS_DEVICE
Element operator()() {
// Draw random float in [0.0, 1.0] to determine if element should be NaN.
if constexpr (std::numeric_limits<Element>::has_quiet_NaN) {
if (params.pnan > 0 && (curand_uniform(&rng_state) < (params.pnan))) {
return Element(NAN);
}
}
using CUTLASS_CMATH_NAMESPACE :: pow;
FloatType rnd = random_uniform_float<FloatType>(&rng_state);
int exponent_count = params.exp_range + 1;
int exponent_offset = int(rnd * FloatType(exponent_count));
exponent_offset = exponent_offset < exponent_count ? exponent_offset : exponent_count - 1;
FloatType exp = FloatType(params.exp_min + exponent_offset);
FloatType sf = FloatType(pow(FloatType(2), exp));
return Element(sf);
}
};
/// Computes a random Gaussian distribution
template <typename Real>
struct RandomUniformFunc<complex<Real>> {
@@ -763,6 +869,16 @@ struct TensorFillRandomUniformFunc {
}
};
template <typename Element>
struct UniformDistributionValueType {
using Type = typename RealType<Element>::Type;
};
template <>
struct UniformDistributionValueType<float_ue8m0_t> {
using Type = float;
};
} // namespace detail
///////////////////////////////////////////////////////////////////////////////////////////////////
@@ -774,8 +890,10 @@ template <
void TensorFillRandomUniform(
TensorView<Element, Layout> view, ///< destination tensor
uint64_t seed, ///< seed for RNG
typename RealType<Element>::Type max = Element(1), ///< upper bound of distribution
typename RealType<Element>::Type min = Element(0), ///< lower bound for distribution
typename detail::UniformDistributionValueType<Element>::Type max =
typename detail::UniformDistributionValueType<Element>::Type(1), ///< upper bound of distribution
typename detail::UniformDistributionValueType<Element>::Type min =
typename detail::UniformDistributionValueType<Element>::Type(0), ///< lower bound for distribution
int bits = -1, ///< If non-negative, specifies number of fractional bits that
/// are not truncated to zero. Permits reducing precision of
/// data.
@@ -805,8 +923,8 @@ void BlockFillRandomUniform(
Element *ptr,
size_t capacity,
uint64_t seed, ///< seed for RNG
typename RealType<Element>::Type max, ///< upper bound of distribution
typename RealType<Element>::Type min, ///< lower bound for distribution
typename detail::UniformDistributionValueType<Element>::Type max, ///< upper bound of distribution
typename detail::UniformDistributionValueType<Element>::Type min, ///< lower bound for distribution
int bits = -1, ///< If non-negative, specifies number of fractional bits that
/// are not truncated to zero. Permits reducing precision of
/// data.
@@ -1768,6 +1886,7 @@ void TensorFillRandom(
) {
using Real = typename RealType<Element>::Type;
using UniformReal = typename detail::UniformDistributionValueType<Element>::Type;
if (dist.kind == Distribution::Gaussian) {
TensorFillRandomGaussian<Element, Layout>(
@@ -1782,8 +1901,8 @@ void TensorFillRandom(
TensorFillRandomUniform<Element, Layout>(
view,
seed,
static_cast<Real>(dist.uniform.max),
static_cast<Real>(dist.uniform.min),
static_cast<UniformReal>(dist.uniform.max),
static_cast<UniformReal>(dist.uniform.min),
dist.int_scale,
dist.uniform.pnan,
exclude_zero,
@@ -1830,6 +1949,7 @@ void BlockFillRandom(
cudaStream_t stream = nullptr) {
using Real = typename RealType<Element>::Type;
using UniformReal = typename detail::UniformDistributionValueType<Element>::Type;
if (dist.kind == Distribution::Gaussian) {
BlockFillRandomGaussian<Element>(
@@ -1846,8 +1966,8 @@ void BlockFillRandom(
ptr,
capacity,
seed,
static_cast<Real>(dist.uniform.max),
static_cast<Real>(dist.uniform.min),
static_cast<UniformReal>(dist.uniform.max),
static_cast<UniformReal>(dist.uniform.min),
dist.int_scale,
dist.uniform.pnan,
stream);

View File

@@ -658,6 +658,74 @@ public:
}
};
/// Computes an exponent-uniform random distribution for UE8M0 scale factors.
template <>
struct RandomUniformFunc<float_ue8m0_t> {
using Element = float_ue8m0_t;
uint64_t seed;
int exp_min;
int exp_range;
int int_scale; ///< Retained for Params compatibility; exponent is integral.
double pnan;
private:
using engine_type = std::mt19937;
public:
engine_type bernoulli_rnd;
std::bernoulli_distribution bernoulli_dist;
bool exclude_zero; ///< Retained for Params compatibility; unused for UE8M0.
static int closest_log2_exp(double value) {
// UE8M0 scale factors are strictly positive. Keep invalid lower bounds
// finite so callers using the generic [0, max] default do not produce NaN.
double min_scale = double(Element::bitcast(0x01));
double positive_value = value > 0.0 ? value : min_scale;
return int(std::nearbyint(std::log2(positive_value)));
}
RandomUniformFunc(
uint64_t seed_ = 0,
double max = 1,
double min_ = 0,
int int_scale_ = -1,
double pnan_ = 0,
bool exclude_zero_ = false
):
seed(seed_),
exp_min(closest_log2_exp(min_)),
exp_range(closest_log2_exp(max) - closest_log2_exp(min_)),
int_scale(int_scale_),
pnan(pnan_),
bernoulli_rnd{static_cast<engine_type::result_type>(seed_)},
bernoulli_dist(pnan_),
exclude_zero(exclude_zero_)
{
std::srand((unsigned)seed);
}
/// Compute random value and update RNG state
Element operator()() {
// Sample from NaN distribution.
if constexpr (std::numeric_limits<Element>::has_quiet_NaN) {
if (pnan > 0 && bernoulli_dist(bernoulli_rnd)) {
return Element(NAN);
}
}
double rnd = double(std::rand()) / double(RAND_MAX);
int exponent_count = exp_range + 1;
int exponent_offset = int(rnd * double(exponent_count));
exponent_offset = exponent_offset < exponent_count ? exponent_offset : exponent_count - 1;
double sf = std::pow(2.0, double(exp_min + exponent_offset));
return Element(sf);
}
};
/// Partial specialization for initializing a complex value.
template <typename Element>
struct RandomUniformFunc<complex<Element> > {