mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-23 06:15:03 +00:00
v4.5.1 update. (#3237)
This commit is contained in:
24
CHANGELOG.md
24
CHANGELOG.md
@@ -2,6 +2,28 @@
|
||||
|
||||
# CUTLASS 4.x
|
||||
|
||||
## [4.5.1](https://github.com/NVIDIA/cutlass/releases/tag/v4.5.1) (2026-05-15)
|
||||
|
||||
### CuTe DSL
|
||||
* Bug fixing and improvements
|
||||
- Fixed following issues:
|
||||
https://github.com/NVIDIA/cutlass/issues/3219
|
||||
https://github.com/NVIDIA/cutlass/issues/3218
|
||||
https://github.com/NVIDIA/cutlass/issues/3212
|
||||
https://github.com/NVIDIA/cutlass/issues/3210
|
||||
https://github.com/NVIDIA/cutlass/issues/3208
|
||||
https://github.com/NVIDIA/cutlass/issues/3201
|
||||
https://github.com/NVIDIA/cutlass/issues/3227
|
||||
- Fixed Jax int64 stride divisibility issue
|
||||
- Fixed issues for SM120 blockscaled MMAs
|
||||
- added missing MXFP8MMAOP and MXF8F6F4MMAOP for sm120.
|
||||
|
||||
### CUTLASS C++
|
||||
* Fix SM100 F8F6F4 SS MMA (1SM and 2SM) traits to use typed op templates.
|
||||
* Add UE8M0 (uniform exponent distribution) initialization support in tensor fill utilities.
|
||||
* Add `cvt.rn.bf16x2.e4m3x2` conversion instruction support to `numeric_conversion.h`.
|
||||
* Update [example 93](https://github.com/NVIDIA/cutlass/tree/main/examples/93_blackwell_low_latency_gqa) with paged KV cache support for Blackwell low-latency GQA.
|
||||
|
||||
## [4.5.0](https://github.com/NVIDIA/cutlass/releases/tag/v4.5.0) (2026-05-01)
|
||||
|
||||
### CuTe DSL
|
||||
@@ -20,7 +42,7 @@
|
||||
- Improved source code correlation for profiling/debugging
|
||||
- Fixed an aarch64 segfault issue with tvm-ffi
|
||||
- Re-organization for CuTe DSL examples/tutorials for better discoverability
|
||||
|
||||
|
||||
* More examples of authorizing peak-performance kernels
|
||||
- MOE examles
|
||||
- A new style of grouped-gemm that aligns to torch's grouped_mm and scaled_groued_mm interface.
|
||||
|
||||
21
README.md
21
README.md
@@ -1,9 +1,9 @@
|
||||

|
||||
# Overview
|
||||
|
||||
# CUTLASS 4.5.0
|
||||
# CUTLASS 4.5.1
|
||||
|
||||
_CUTLASS 4.5.0 - May 2026_
|
||||
_CUTLASS 4.5.1 - May 2026_
|
||||
|
||||
CUTLASS is a collection of abstractions for implementing high-performance matrix-matrix multiplication (GEMM)
|
||||
and related computations at all levels and scales within CUDA. It incorporates strategies for
|
||||
@@ -61,6 +61,17 @@ To get started quickly - please refer :
|
||||
- Improved source code correlation for profiling/debugging
|
||||
- Fixed an aarch64 segfault issue with tvm-ffi
|
||||
- Re-organization for CuTe DSL examples/tutorials for better discoverability
|
||||
- Fixed following issues:
|
||||
https://github.com/NVIDIA/cutlass/issues/3219
|
||||
https://github.com/NVIDIA/cutlass/issues/3218
|
||||
https://github.com/NVIDIA/cutlass/issues/3212
|
||||
https://github.com/NVIDIA/cutlass/issues/3210
|
||||
https://github.com/NVIDIA/cutlass/issues/3208
|
||||
https://github.com/NVIDIA/cutlass/issues/3201
|
||||
https://github.com/NVIDIA/cutlass/issues/3227
|
||||
- Fixed Jax int64 stride divisibility issue
|
||||
- Fixed issues for SM120 blockscaled MMAs
|
||||
- added missing MXFP8MMAOP and MXF8F6F4MMAOP for sm120.
|
||||
|
||||
* More examples of authorizing peak-performance kernels
|
||||
- MOE examles
|
||||
@@ -90,11 +101,13 @@ To get started quickly - please refer :
|
||||
* Add [example 95](https://github.com/NVIDIA/cutlass/tree/main/examples/95_blackwell_gemm_green_context) to support green context SM partition
|
||||
- Enables launching GEMM on stream with partial SM allocation.
|
||||
* Add [Snake](https://github.com/NVIDIA/cutlass/blob/main/test/unit/epilogue/thread/activation.cu#L409) activation functor for EVT.
|
||||
* Fix SM100 F8F6F4 SS MMA (1SM and 2SM) traits to use typed op templates.
|
||||
* Add UE8M0 (uniform exponent distribution) initialization support in tensor fill utilities.
|
||||
* Add `cvt.rn.bf16x2.e4m3x2` conversion instruction support to `numeric_conversion.h`.
|
||||
* Update [example 93](https://github.com/NVIDIA/cutlass/tree/main/examples/93_blackwell_low_latency_gqa) with paged KV cache support for Blackwell low-latency GQA.
|
||||
* Fix some kernel issues:
|
||||
- Fix l2_capacity=0 handling in Blackwell SM100/SM120 kernel templates
|
||||
- Fix CUTLASS clang build issues
|
||||
- Fix atomicCAS read-modify-write loop in `ConstSubbyteReference`
|
||||
- Replace `__nv_atomic_load_n` with `volatile` for CUDA 11.4 compatibility in subbyte reference
|
||||
- Remove `PipelineStorage` shadowing in SM100 complex epilogue
|
||||
- Fix build issue in SM90 epilogue fusion visitor TMA warpspecialized
|
||||
* Fix some profiler issues:
|
||||
|
||||
@@ -75,19 +75,17 @@ CUTE_HOST_DEVICE constexpr
|
||||
auto
|
||||
to_tiled_mma_sm100_ts(
|
||||
TiledMMA<MMA_Atom<
|
||||
MMA_Traits<SM100_MMA_F8F6F4_SS, a_type, b_type, c_type,
|
||||
cute::C<M>, cute::C<N>,
|
||||
cute::integral_constant<UMMA::Major, a_major>,
|
||||
cute::integral_constant<UMMA::Major, b_major>,
|
||||
cute::integral_constant<UMMA::ScaleIn, a_neg>,
|
||||
cute::integral_constant<UMMA::ScaleIn, b_neg>>,
|
||||
SM100_MMA_F8F6F4_SS<a_type, b_type, c_type,
|
||||
M, N,
|
||||
a_major, b_major,
|
||||
a_neg, b_neg>,
|
||||
TAs...>, TMs...>) {
|
||||
|
||||
return TiledMMA<MMA_Atom<
|
||||
MMA_Traits<SM100_MMA_F8F6F4_TS<a_type, b_type, c_type,
|
||||
SM100_MMA_F8F6F4_TS<a_type, b_type, c_type,
|
||||
M, N,
|
||||
a_major, b_major,
|
||||
a_neg, b_neg, UMMA::Saturate::False>>,
|
||||
a_neg, b_neg, UMMA::Saturate::False>,
|
||||
TAs...>, TMs...>{};
|
||||
}
|
||||
|
||||
|
||||
@@ -31,6 +31,9 @@ if (NOT MSVC AND CUTLASS_NVCC_ARCHS MATCHES "100a|100f|103a|103f")
|
||||
cutlass_example_add_executable(
|
||||
93_blackwell_low_latency_gqa
|
||||
tgv_gqa.cu
|
||||
common.cuh
|
||||
tgv_gqa.cuh
|
||||
tgv_gqa_paged.cuh
|
||||
)
|
||||
|
||||
endif()
|
||||
|
||||
140
examples/93_blackwell_low_latency_gqa/common.cuh
Normal file
140
examples/93_blackwell_low_latency_gqa/common.cuh
Normal file
@@ -0,0 +1,140 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cuda_runtime.h"
|
||||
|
||||
#include <cutlass/cutlass.h>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <cute/arch/copy_sm90_desc.hpp>
|
||||
|
||||
#include <iostream>
|
||||
#include <type_traits>
|
||||
|
||||
#ifndef gpuErrChk
|
||||
#define gpuErrChk(ans) { gpuAssert2((ans), __FILE__, __LINE__); }
|
||||
inline void gpuAssert2(cudaError_t code, const char *file, int line, bool abort=true) {
|
||||
if (code != cudaSuccess) {
|
||||
fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
|
||||
if (abort) exit(code);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
namespace TGV {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
// Store value to remote shared memory in the cluster
|
||||
CUTE_DEVICE void
|
||||
store_shared_remote_f32(float value, uint32_t dsmem_addr, uint32_t remote_barrier_addr) {
|
||||
asm volatile("st.async.shared::cluster.mbarrier::complete_tx::bytes.f32 [%0], %1, [%2];"
|
||||
: : "r"(dsmem_addr), "f"(value), "r"(remote_barrier_addr));
|
||||
}
|
||||
|
||||
// given a smem tensor, return the dsmem tensor for the given rank, the tensor addr is in smem addr space (not generic addr space)
|
||||
template <class Tensor>
|
||||
CUTE_DEVICE auto
|
||||
get_dsmem_tensor(Tensor tensor, int rank) {
|
||||
using T = typename decltype(tensor)::value_type;
|
||||
// tensor.data().get() is the smem addr in the generic addr space, in the generic addr space a region is reserved for smem
|
||||
// doing ld/st to this region of the generic addr space will be converted into ld.shared/st.shared to the smem addr space by the compiler
|
||||
// the mapa (and many inline ptx) instruction's input and output addr are in the smem/dsmem addr space, so we need to explicitly convert from generic to shared addr space
|
||||
uint32_t smem_addr = __cvta_generic_to_shared(tensor.data().get()); // smem addr space
|
||||
// mapa to get the dsmem addr of this tensor in another CTA
|
||||
uint32_t dsmem_addr = set_block_rank(smem_addr, rank); // smem addr space
|
||||
return make_tensor(make_smem_ptr((T*)dsmem_addr), tensor.layout());
|
||||
}
|
||||
|
||||
// copied from SM100::TMEM::LOAD::copy_unpack cutlass/include/cute/atom/copy_traits_sm100.hpp
|
||||
// what it does is given a tmem address, load the data into rmem tensor with the given tcgen05.ld copy op
|
||||
template <
|
||||
class CopyOp,
|
||||
class TD, class DLayout>
|
||||
CUTLASS_DEVICE void
|
||||
tmem_load(
|
||||
uint32_t tmem_addr,
|
||||
Tensor<TD,DLayout>& dst
|
||||
) {
|
||||
static_assert(is_rmem<TD>::value, "Expected RMEM dst.");
|
||||
|
||||
using RegTypeDst = typename remove_extent<typename CopyOp::DRegisters>::type;
|
||||
Tensor rD = recast<RegTypeDst>(dst);
|
||||
|
||||
constexpr int RegNumDst = extent<typename CopyOp::DRegisters>::value;
|
||||
CUTE_STATIC_ASSERT_V(size(rD) == Int<RegNumDst>{},
|
||||
"The tcgen05.ld CopyOp's size does not match the destination tensor size.");
|
||||
|
||||
detail::explode(CopyOp::copy,
|
||||
&tmem_addr, seq<0>{},
|
||||
rD, make_seq<RegNumDst>{});
|
||||
}
|
||||
|
||||
// copied from SM100::TMEM::STORE::copy_unpack cutlass/include/cute/atom/copy_traits_sm100.hpp
|
||||
// what it does is given a tmem address, store the data in rmem tensor to the tmem address with the given tcgen05.st copy op
|
||||
template <
|
||||
class CopyOp,
|
||||
class TS, class SLayout>
|
||||
CUTLASS_DEVICE void
|
||||
tmem_store(
|
||||
Tensor<TS,SLayout>& src,
|
||||
uint32_t tmem_addr
|
||||
) {
|
||||
static_assert(is_rmem<TS>::value, "Expected RMEM src.");
|
||||
|
||||
using RegTypeSrc = typename remove_extent<typename CopyOp::SRegisters>::type;
|
||||
Tensor rS = recast<RegTypeSrc>(src);
|
||||
|
||||
constexpr int RegNumSrc = extent<typename CopyOp::SRegisters>::value;
|
||||
CUTE_STATIC_ASSERT_V(size(rS) == Int<RegNumSrc>{},
|
||||
"The tcgen05.st CopyOp's size does not match the source tensor size.");
|
||||
|
||||
detail::explode(CopyOp::copy,
|
||||
rS, make_seq<RegNumSrc>{},
|
||||
&tmem_addr, seq<0>{});
|
||||
}
|
||||
|
||||
// issue cp.async to load 4 bytes (one int) from gmem to smem
|
||||
CUTLASS_DEVICE void
|
||||
cp_async(
|
||||
int* gmem_addr,
|
||||
int* smem_addr
|
||||
) {
|
||||
uint32_t smem_int_ptr = cute::cast_smem_ptr_to_uint(smem_addr);
|
||||
asm volatile("cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n"
|
||||
:: "r"(smem_int_ptr),
|
||||
"l"(gmem_addr),
|
||||
"n"(sizeof(int)));
|
||||
}
|
||||
|
||||
} // namespace TGV
|
||||
@@ -1,6 +1,18 @@
|
||||
# Blackwell Low Latency GQA
|
||||
|
||||
This example introduces TGV GQA, a CuTe C++-based Blackwell kernel optimized for low latency (low batch) generation phase GQA.
|
||||
The example ships two variants:
|
||||
|
||||
- `tgv_gqa.cuh` — contiguous KV cache (default, `--mode 0`).
|
||||
- `tgv_gqa_paged.cuh` — paged KV cache (`--mode 1`). Layout matches a typical paged-attention serving runtime: a combined KV
|
||||
buffer of shape `(num_pages_total, 2, Page_Size, kvH, dH)` (BS folded into `num_pages_total`, mode-1 selects K vs V) plus a
|
||||
`(kvL/Page_Size, BS)` page table that maps `(bs, per_batch_page_idx) -> physical page id`. The example harness builds the
|
||||
page table host-side using a Fisher-Yates shuffle to stress non-contiguous mappings; replacing that with another placement
|
||||
policy needs no kernel changes.
|
||||
|
||||
`common.cuh` holds the shared inline PTX wrappers used by both variants (`cp_async`, `tmem_load`/`tmem_store`,
|
||||
`store_shared_remote_f32`, `get_dsmem_tensor`).
|
||||
|
||||
To compile and run this example:
|
||||
```bash
|
||||
# in cutlass top level directory
|
||||
@@ -8,7 +20,10 @@ mkdir build && cd build
|
||||
cmake .. -DCUTLASS_NVCC_ARCHS=100a -DCUTLASS_ENABLE_TESTS=OFF -DCUTLASS_ENABLE_EXAMPLES=ON -DCUTLASS_ENABLE_LIBRARY=OFF
|
||||
cd examples/93_blackwell_low_latency_gqa
|
||||
make
|
||||
# contiguous KV cache (default)
|
||||
./93_blackwell_low_latency_gqa --kvL 8192 --kvH 8 --qH 64 --BS 1
|
||||
# paged KV cache
|
||||
./93_blackwell_low_latency_gqa --kvL 8192 --kvH 8 --qH 64 --BS 1 --mode 1
|
||||
```
|
||||
|
||||
Supported configs are:
|
||||
@@ -20,11 +35,11 @@ Supported configs are:
|
||||
- Flash decoding, configurable number of splits
|
||||
- Cluster reduction with configurable number of reduction cta
|
||||
- Attention sink and sliding window
|
||||
- Paged KV cache (`--mode 1`)
|
||||
|
||||
Unsupported features are:
|
||||
- Persistent schedule
|
||||
- MTP
|
||||
- Paged KV cache
|
||||
|
||||
## Kernel Design
|
||||
|
||||
|
||||
@@ -40,8 +40,14 @@
|
||||
kvL is max_seq_len, seq_lens[BS] is the actual seq len for each batch
|
||||
sinks has shape (qHLocal * kvH), i.e. one sink per q head
|
||||
|
||||
--mode 1 enables the paged variant:
|
||||
Combined KV cache shape (num_pages_total, 2, Page_Size, kvH, dH), with BS folded into num_pages_total
|
||||
and KV(_,0,_,_,_) = K, KV(_,1,_,_,_) = V.
|
||||
page_table shape (kvL/Page_Size, BS), entry [p, bs] = physical page id of batch bs's per-batch page p.
|
||||
|
||||
Example usage:
|
||||
$ ./examples/93_blackwell_low_latency_gqa --kvL 8192 --kvH 8 --qH 64 --BS 1
|
||||
$ ./examples/93_blackwell_low_latency_gqa --kvL 8192 --kvH 8 --qH 64 --BS 1 --mode 1
|
||||
*/
|
||||
|
||||
// Standard library includes
|
||||
@@ -51,6 +57,8 @@
|
||||
#include <cmath>
|
||||
#include <iostream>
|
||||
#include <ctime>
|
||||
#include <numeric>
|
||||
#include <utility>
|
||||
#include <getopt.h>
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
@@ -68,7 +76,9 @@
|
||||
// CuTe includes
|
||||
#include <cute/tensor.hpp> // CuTe tensor implementation
|
||||
|
||||
#include "common.cuh"
|
||||
#include "tgv_gqa.cuh"
|
||||
#include "tgv_gqa_paged.cuh"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
@@ -411,9 +421,23 @@ struct ProblemStride {
|
||||
int stride_O_qL;
|
||||
int stride_O_dH;
|
||||
int stride_O_BS;
|
||||
|
||||
// Combined KV (paged mode) layout: (num_pages_total, 2, Page_Size, kvH, dH).
|
||||
// BS is folded into num_pages_total = BS * kvL / Page_Size (no explicit BS mode).
|
||||
// Mode 1 is the K/V selector; harness picks dH-innermost packed strides below.
|
||||
int stride_KV_pages;
|
||||
int stride_KV_KV;
|
||||
int stride_KV_ps;
|
||||
int stride_KV_kvH;
|
||||
int stride_KV_dH;
|
||||
|
||||
// Page table (paged mode) layout: (kvL/Page_Size, BS), entry [p, bs] = physical page id of batch bs's per-batch
|
||||
// page p. Mode 0 is innermost in memory (per-batch page id, stride 1); mode 1 advances by pages_per_batch across batches.
|
||||
int stride_PT_p;
|
||||
int stride_PT_BS;
|
||||
};
|
||||
|
||||
ProblemStride make_gqa_stride(int kvH, int qHLocal, int qL, int kvL, int dH, int BS) {
|
||||
ProblemStride make_gqa_stride(int kvH, int qHLocal, int qL, int kvL, int dH, int BS, int Page_Size) {
|
||||
ProblemStride stride;
|
||||
|
||||
// Q shape ((qHLocal, qL), dH, kvH, BS), where dH is contiguous
|
||||
@@ -446,6 +470,20 @@ ProblemStride make_gqa_stride(int kvH, int qHLocal, int qL, int kvL, int dH, int
|
||||
stride.stride_O_dH = 1;
|
||||
stride.stride_O_BS = kvH * qHLocal * dH * qL;
|
||||
|
||||
// Combined KV (paged) shape (num_pages_total, 2, Page_Size, kvH, dH), dH innermost contiguous.
|
||||
// BS is folded into num_pages_total; the (bs_idx, per_batch_page_idx) -> physical page mapping lives in the
|
||||
// page_table tensor (see stride_PT_*). Strides slowest -> fastest:
|
||||
// num_pages_total, KV (K/V selector), Page_Size, kvH, dH.
|
||||
stride.stride_KV_dH = 1;
|
||||
stride.stride_KV_kvH = dH;
|
||||
stride.stride_KV_ps = kvH * dH;
|
||||
stride.stride_KV_KV = Page_Size * kvH * dH;
|
||||
stride.stride_KV_pages = 2 * Page_Size * kvH * dH;
|
||||
|
||||
// Page table shape (kvL/Page_Size, BS); per-batch page idx is mode 0 (contiguous, stride 1), batch is mode 1 (advances by pages_per_batch).
|
||||
stride.stride_PT_p = 1;
|
||||
stride.stride_PT_BS = kvL / Page_Size;
|
||||
|
||||
return stride;
|
||||
}
|
||||
|
||||
@@ -459,17 +497,25 @@ public:
|
||||
static constexpr int CTA_qL = 1;
|
||||
static constexpr int CTA_kvL = 128;
|
||||
static constexpr int CTA_dH = 64;
|
||||
// Page_Size only used by gqa_paged (mode 1). Page_Size must divide CTA_kvL; CTA_kvL/Page_Size = pages per CTA tile.
|
||||
static constexpr int Page_Size = 32;
|
||||
static constexpr int BMM1_DMA_Stage = 3;
|
||||
static constexpr int BMM2_DMA_Stage = 3;
|
||||
// Page-idx staging (mode 1 only). Num_Page_Idx_Per_Stage must be a multiple of CTA_kvL/Page_Size
|
||||
// so a DMA stage's pages live in one pi stage. Page_Idx_Stage = pipeline depth on the page-idx side.
|
||||
static constexpr int Page_Idx_Stage = 2;
|
||||
static constexpr int Num_Page_Idx_Per_Stage = 8 * (CTA_kvL / Page_Size);
|
||||
static constexpr int MaxSplits = 8;
|
||||
static constexpr int NumReductionCTA = 8;
|
||||
static constexpr bool NoSink = true;
|
||||
static constexpr bool VarSeqLens = false;
|
||||
|
||||
private:
|
||||
int kvH_, qHLocal_, qL_, kvL_, dH_, BS_;
|
||||
float softmax_scale_;
|
||||
ProblemStride stride_;
|
||||
int sliding_window_size_;
|
||||
int mode_; // 0: gqa, 1: gqa_paged
|
||||
|
||||
// Host vectors
|
||||
thrust::host_vector<TypeQKV> host_Q_;
|
||||
@@ -480,17 +526,82 @@ private:
|
||||
thrust::host_vector<int> host_seq_lens_;
|
||||
thrust::host_vector<TypeAcc> host_sinks_;
|
||||
|
||||
// Device vectors
|
||||
// Device vectors
|
||||
thrust::device_vector<TypeQKV> device_Q_;
|
||||
thrust::device_vector<TypeQKV> device_K_;
|
||||
thrust::device_vector<TypeQKV> device_V_;
|
||||
// Combined KV cache for paged mode (mode_ == 1).
|
||||
// Layout: (num_pages_total, 2, Page_Size, kvH, dH) with dH innermost; BS is folded into num_pages_total.
|
||||
// The (bs, per_batch_page_idx) -> physical page mapping is held in device_page_table_ (built by
|
||||
// build_random_page_table: a Fisher-Yates shuffle of [0, num_pages_total) sliced per batch).
|
||||
// KV(_,0,_,_,_) is K, KV(_,1,_,_,_) is V.
|
||||
thrust::device_vector<TypeQKV> device_KV_;
|
||||
// Page table for paged mode. Logical shape (kvL/Page_Size, BS); entry [p, bs] is the physical page id for
|
||||
// batch bs and per-batch page index p. Padded by Num_Page_Idx_Per_Stage ints at the tail so the device's
|
||||
// last-pi-stage cp.async never reads past the allocation.
|
||||
thrust::device_vector<int> device_page_table_;
|
||||
thrust::device_vector<TypeO> device_O_;
|
||||
thrust::device_vector<int> device_seq_lens_;
|
||||
thrust::device_vector<TypeAcc> device_sinks_;
|
||||
|
||||
// Build the host-side page table: tail-padded buffer with shape (kvL/Page_Size, BS) and contents = a random
|
||||
// per-batch slice of a Fisher-Yates shuffle of [0, num_pages_total). Each batch's slice length is
|
||||
// seq_len[bs]/Page_Size, so total pages assigned = sum(seq_len_pages) <= num_pages_total. Tail padding
|
||||
// (Num_Page_Idx_Per_Stage ints) covers the device's last-pi-stage cp.async OOB read.
|
||||
thrust::host_vector<int> build_random_page_table(int pages_per_batch, int num_pages_total) {
|
||||
thrust::host_vector<int> host_page_table(num_pages_total + Num_Page_Idx_Per_Stage, 0);
|
||||
auto host_tensor_page_table = make_tensor(host_page_table.data(),
|
||||
make_layout(make_shape(pages_per_batch, BS_),
|
||||
make_stride(stride_.stride_PT_p, stride_.stride_PT_BS)));
|
||||
std::vector<int> perm(num_pages_total);
|
||||
std::iota(perm.begin(), perm.end(), 0);
|
||||
for (int i = num_pages_total - 1; i > 0; --i) {
|
||||
std::swap(perm[i], perm[rand() % (i + 1)]);
|
||||
}
|
||||
int perm_offset = 0;
|
||||
for (int bs = 0; bs < BS_; ++bs) {
|
||||
// ceil_div: a partial tail page (seq_len % Page_Size != 0) still needs a physical page assigned
|
||||
// so the kernel's address-by-page lookup for positions [floor(seq_len/Page_Size)*Page_Size, seq_len)
|
||||
// hits real packed K/V data. Positions past seq_len are masked by the kernel and don't contribute.
|
||||
int seq_len_pages = cutlass::ceil_div(host_seq_lens_[bs], Page_Size);
|
||||
for (int p = 0; p < seq_len_pages; ++p) {
|
||||
host_tensor_page_table(p, bs) = perm[perm_offset + p];
|
||||
}
|
||||
perm_offset += seq_len_pages;
|
||||
}
|
||||
return host_page_table;
|
||||
}
|
||||
|
||||
// Pack the flat per-batch K/V buffers into the combined paged KV layout, using the page table for the
|
||||
// (bs, per_batch_page) -> physical page mapping. Bounded by seq_len_pages per batch so out-of-seq-len entries
|
||||
// (which the page table doesn't populate) aren't dereferenced.
|
||||
template <class HostTensorPageTable, class HostTensorK, class HostTensorV, class HostTensorKV>
|
||||
void pack_combined_kv(HostTensorPageTable const& host_tensor_page_table,
|
||||
HostTensorK const& host_tensor_K, HostTensorV const& host_tensor_V,
|
||||
HostTensorKV& host_tensor_KV) {
|
||||
for (int bs = 0; bs < BS_; ++bs) {
|
||||
// ceil_div pages so the partial tail page (when seq_len isn't page-aligned) is packed.
|
||||
// We pack the full Page_Size for the tail page; positions past seq_len carry whatever
|
||||
// initialize_tensor wrote into the K/V buffers, but the kernel masks those out.
|
||||
int seq_len_pages = cutlass::ceil_div(host_seq_lens_[bs], Page_Size);
|
||||
for (int p = 0; p < seq_len_pages; ++p) {
|
||||
int global_page = host_tensor_page_table(p, bs);
|
||||
for (int ps = 0; ps < Page_Size; ++ps) {
|
||||
int kvl = p * Page_Size + ps;
|
||||
for (int kvh = 0; kvh < kvH_; ++kvh) {
|
||||
for (int dh = 0; dh < dH_; ++dh) {
|
||||
host_tensor_KV(global_page, 0, ps, kvh, dh) = host_tensor_K(kvl, dh, kvh, bs);
|
||||
host_tensor_KV(global_page, 1, ps, kvh, dh) = host_tensor_V(dh, kvl, kvh, bs);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
GQATester(int kvH, int qH, int qL, int kvL, int dH, int BS, float softmax_scale, int sliding_window_size) :
|
||||
kvH_(kvH), qHLocal_(qH / kvH), qL_(qL), kvL_(kvL), dH_(dH), BS_(BS), softmax_scale_(softmax_scale), sliding_window_size_(sliding_window_size) {
|
||||
GQATester(int kvH, int qH, int qL, int kvL, int dH, int BS, float softmax_scale, int sliding_window_size, int mode = 0) :
|
||||
kvH_(kvH), qHLocal_(qH / kvH), qL_(qL), kvL_(kvL), dH_(dH), BS_(BS), softmax_scale_(softmax_scale), sliding_window_size_(sliding_window_size), mode_(mode) {
|
||||
assert(sliding_window_size_ >= 0);
|
||||
// Allocate host memory
|
||||
host_Q_.resize(kvH_ * qHLocal_ * qL_ * dH_ * BS_);
|
||||
@@ -501,7 +612,7 @@ public:
|
||||
host_seq_lens_.resize(BS_);
|
||||
host_sinks_.resize(qHLocal_ * kvH_); // one sink per q head
|
||||
|
||||
stride_ = make_gqa_stride(kvH_, qHLocal_, qL_, kvL_, dH_, BS_);
|
||||
stride_ = make_gqa_stride(kvH_, qHLocal_, qL_, kvL_, dH_, BS_, Page_Size);
|
||||
|
||||
// Create host CuTe tensors for initialization
|
||||
auto host_tensor_Q = make_tensor(host_Q_.data(), TGV::gqa::make_layout_Q(kvH_, qHLocal_, qL_, dH_, BS_, stride_.stride_Q_kvH, stride_.stride_Q_qHLocal, stride_.stride_Q_qL, stride_.stride_Q_dH, stride_.stride_Q_BS));
|
||||
@@ -513,10 +624,8 @@ public:
|
||||
initialize_tensor(host_tensor_Q);
|
||||
initialize_tensor(host_tensor_K);
|
||||
initialize_tensor(host_tensor_V);
|
||||
// have batch size matching kvL (i.e. max seq len) for now
|
||||
bool test_var_seq_lens = false;
|
||||
for (int i = 0; i < BS_; ++i) {
|
||||
if (test_var_seq_lens) {
|
||||
if (VarSeqLens) {
|
||||
host_seq_lens_[i] = rand() % kvL_ + 1;
|
||||
}
|
||||
else { // all the batch have the same seq len
|
||||
@@ -535,27 +644,81 @@ public:
|
||||
device_seq_lens_ = host_seq_lens_;
|
||||
device_sinks_ = host_sinks_;
|
||||
|
||||
// For paged mode, build the page_table and combined KV tensor on device. The harness owns the layouts:
|
||||
// stride_KV_* and stride_PT_* were computed in make_gqa_stride above and are used both here (host pack)
|
||||
// and downstream (passed into gqa_paged_host as stride args). Combined KV shape: (num_pages_total, 2,
|
||||
// Page_Size, kvH, dH), BS folded into num_pages_total. We populate the page_table via
|
||||
// build_random_page_table (Fisher-Yates shuffle of [0, num_pages_total) sliced per batch) to stress
|
||||
// non-contiguous mappings.
|
||||
if (mode_ == 1) {
|
||||
assert(kvL_ % Page_Size == 0);
|
||||
assert(kvL_ % CTA_kvL == 0);
|
||||
int pages_per_batch = kvL_ / Page_Size;
|
||||
int num_pages_total = BS_ * pages_per_batch;
|
||||
|
||||
auto host_page_table = build_random_page_table(pages_per_batch, num_pages_total);
|
||||
auto host_tensor_page_table = make_tensor(host_page_table.data(),
|
||||
make_layout(make_shape(pages_per_batch, BS_),
|
||||
make_stride(stride_.stride_PT_p, stride_.stride_PT_BS)));
|
||||
|
||||
thrust::host_vector<TypeQKV> host_KV(num_pages_total * stride_.stride_KV_pages);
|
||||
auto host_tensor_KV = make_tensor(host_KV.data(),
|
||||
make_layout(make_shape(num_pages_total, 2, Page_Size, kvH_, dH_),
|
||||
make_stride(stride_.stride_KV_pages, stride_.stride_KV_KV,
|
||||
stride_.stride_KV_ps, stride_.stride_KV_kvH, stride_.stride_KV_dH)));
|
||||
pack_combined_kv(host_tensor_page_table, host_tensor_K, host_tensor_V, host_tensor_KV);
|
||||
|
||||
device_KV_ = host_KV;
|
||||
device_page_table_ = host_page_table;
|
||||
}
|
||||
|
||||
gpuErrChk(cudaDeviceSynchronize());
|
||||
}
|
||||
|
||||
void run_kernel(bool pdl, int pdl_count = -1, cudaStream_t stream = 0) {
|
||||
TGV::gqa::gqa_host<
|
||||
TypeQKV, TypeO, TypeAcc,
|
||||
CTA_qHLocal, CTA_qL, CTA_kvL, CTA_dH,
|
||||
BMM1_DMA_Stage, BMM2_DMA_Stage,
|
||||
MaxSplits,
|
||||
NumReductionCTA>(
|
||||
device_K_.data().get(), device_Q_.data().get(), device_V_.data().get(), device_O_.data().get(),
|
||||
device_seq_lens_.data().get(),
|
||||
NoSink ? nullptr : device_sinks_.data().get(),
|
||||
kvH_, qHLocal_, qL_, kvL_, dH_, BS_,
|
||||
stride_.stride_K_kvH, stride_.stride_K_kvL, stride_.stride_K_dH, stride_.stride_K_BS,
|
||||
stride_.stride_Q_kvH, stride_.stride_Q_qHLocal, stride_.stride_Q_qL, stride_.stride_Q_dH, stride_.stride_Q_BS,
|
||||
stride_.stride_V_kvH, stride_.stride_V_kvL, stride_.stride_V_dH, stride_.stride_V_BS,
|
||||
stride_.stride_O_kvH, stride_.stride_O_qHLocal, stride_.stride_O_qL, stride_.stride_O_dH, stride_.stride_O_BS,
|
||||
softmax_scale_,
|
||||
sliding_window_size_,
|
||||
pdl, pdl_count, stream);
|
||||
if (mode_ == 0) {
|
||||
TGV::gqa::gqa_host<
|
||||
TypeQKV, TypeO, TypeAcc,
|
||||
CTA_qHLocal, CTA_qL, CTA_kvL, CTA_dH,
|
||||
BMM1_DMA_Stage, BMM2_DMA_Stage,
|
||||
MaxSplits,
|
||||
NumReductionCTA>(
|
||||
device_K_.data().get(), device_Q_.data().get(), device_V_.data().get(), device_O_.data().get(),
|
||||
device_seq_lens_.data().get(),
|
||||
NoSink ? nullptr : device_sinks_.data().get(),
|
||||
kvH_, qHLocal_, qL_, kvL_, dH_, BS_,
|
||||
stride_.stride_K_kvH, stride_.stride_K_kvL, stride_.stride_K_dH, stride_.stride_K_BS,
|
||||
stride_.stride_Q_kvH, stride_.stride_Q_qHLocal, stride_.stride_Q_qL, stride_.stride_Q_dH, stride_.stride_Q_BS,
|
||||
stride_.stride_V_kvH, stride_.stride_V_kvL, stride_.stride_V_dH, stride_.stride_V_BS,
|
||||
stride_.stride_O_kvH, stride_.stride_O_qHLocal, stride_.stride_O_qL, stride_.stride_O_dH, stride_.stride_O_BS,
|
||||
softmax_scale_,
|
||||
sliding_window_size_,
|
||||
pdl, pdl_count, stream);
|
||||
}
|
||||
else {
|
||||
TGV::gqa_paged::gqa_paged_host<
|
||||
TypeQKV, TypeO, TypeAcc,
|
||||
CTA_qHLocal, CTA_qL, CTA_kvL, CTA_dH,
|
||||
Page_Size,
|
||||
BMM1_DMA_Stage, BMM2_DMA_Stage,
|
||||
Page_Idx_Stage, Num_Page_Idx_Per_Stage,
|
||||
MaxSplits,
|
||||
NumReductionCTA>(
|
||||
device_KV_.data().get(),
|
||||
device_Q_.data().get(),
|
||||
device_O_.data().get(),
|
||||
NoSink ? nullptr : device_sinks_.data().get(),
|
||||
device_seq_lens_.data().get(),
|
||||
device_page_table_.data().get(),
|
||||
kvH_, qHLocal_, qL_, kvL_, dH_, BS_,
|
||||
stride_.stride_KV_pages, stride_.stride_KV_KV, stride_.stride_KV_ps, stride_.stride_KV_kvH, stride_.stride_KV_dH,
|
||||
stride_.stride_Q_kvH, stride_.stride_Q_qHLocal, stride_.stride_Q_qL, stride_.stride_Q_dH, stride_.stride_Q_BS,
|
||||
stride_.stride_O_kvH, stride_.stride_O_qHLocal, stride_.stride_O_qL, stride_.stride_O_dH, stride_.stride_O_BS,
|
||||
stride_.stride_PT_p, stride_.stride_PT_BS,
|
||||
softmax_scale_,
|
||||
sliding_window_size_,
|
||||
pdl, pdl_count, stream);
|
||||
}
|
||||
}
|
||||
|
||||
bool verify() {
|
||||
@@ -604,16 +767,17 @@ public:
|
||||
|
||||
};
|
||||
|
||||
void benchmark_gqa(int kvH, int qH, int qL, int kvL, int dH, int BS, float softmax_scale, int sliding_window_size, bool pdl, int pdl_count, int num_testers = 4, int bench_iters = 100) {
|
||||
void benchmark_gqa(int kvH, int qH, int qL, int kvL, int dH, int BS, float softmax_scale, int sliding_window_size, int mode, bool pdl, int pdl_count, int num_testers = 4, int bench_iters = 100) {
|
||||
std::cout << "=== GQA Benchmark ===" << std::endl;
|
||||
std::cout << "Problem size: kvH=" << kvH << ", qH=" << qH << ", qL=" << qL << ", kvL=" << kvL << ", dH=" << dH << ", BS=" << BS << ", sliding_window_size=" << sliding_window_size << std::endl;
|
||||
std::cout << "Mode: " << mode << " (" << (mode == 0 ? "gqa" : "gqa_paged") << ")" << std::endl;
|
||||
std::cout << "Number of testers (L2 thrashing): " << num_testers << std::endl;
|
||||
std::cout << "Benchmark iterations: " << bench_iters << std::endl;
|
||||
|
||||
// Create multiple tester instances to thrash L2 cache
|
||||
std::vector<std::unique_ptr<GQATester>> testers;
|
||||
for (int i = 0; i < num_testers; ++i) {
|
||||
testers.push_back(std::make_unique<GQATester>(kvH, qH, qL, kvL, dH, BS, softmax_scale, sliding_window_size));
|
||||
testers.push_back(std::make_unique<GQATester>(kvH, qH, qL, kvL, dH, BS, softmax_scale, sliding_window_size, mode));
|
||||
}
|
||||
std::cout << "Created " << num_testers << " GQATester instances" << std::endl;
|
||||
|
||||
@@ -708,6 +872,8 @@ int main(int argc, char* argv[]) {
|
||||
bool pdl = false;
|
||||
// don't support it yet
|
||||
int pdl_count = -1;
|
||||
// 0: gqa (contiguous KV cache), 1: gqa_paged (paged KV cache)
|
||||
int mode = 0;
|
||||
|
||||
// arg parsing
|
||||
while (1) {
|
||||
@@ -718,6 +884,7 @@ int main(int argc, char* argv[]) {
|
||||
{"qL", required_argument, 0, 0},
|
||||
{"BS", required_argument, 0, 0},
|
||||
{"sliding_window_size", required_argument, 0, 0},
|
||||
{"mode", required_argument, 0, 0},
|
||||
{0, 0, 0, 0} // denote end of array
|
||||
};
|
||||
|
||||
@@ -737,14 +904,17 @@ int main(int argc, char* argv[]) {
|
||||
else if (option_index == 3) qL = atoi(optarg);
|
||||
else if (option_index == 4) BS = atoi(optarg);
|
||||
else if (option_index == 5) sliding_window_size = atoi(optarg);
|
||||
else if (option_index == 6) mode = atoi(optarg);
|
||||
break;
|
||||
default: assert(false);
|
||||
}
|
||||
}
|
||||
|
||||
GQATester tester(kvH, qH, qL, kvL, dH, BS, softmax_scale, sliding_window_size);
|
||||
GQATester tester(kvH, qH, qL, kvL, dH, BS, softmax_scale, sliding_window_size, mode);
|
||||
bool success = tester.verify();
|
||||
std::cout << "Correctness test " << (success ? "PASSED" : "FAILED") << std::endl;
|
||||
std::cout << "Correctness test"
|
||||
<< " mode=" << mode
|
||||
<< " " << (success ? "PASSED" : "FAILED") << std::endl;
|
||||
|
||||
benchmark_gqa(kvH, qH, qL, kvL, dH, BS, softmax_scale, sliding_window_size, pdl, pdl_count, 100, 1000);
|
||||
benchmark_gqa(kvH, qH, qL, kvL, dH, BS, softmax_scale, sliding_window_size, mode, pdl, pdl_count, 100, 1000);
|
||||
}
|
||||
|
||||
@@ -50,13 +50,7 @@
|
||||
#include <cute/arch/tmem_allocator_sm100.hpp> // TMEM allocator for SM100
|
||||
#include <cute/arch/copy_sm90_desc.hpp>
|
||||
|
||||
#define gpuErrChk(ans) { gpuAssert2((ans), __FILE__, __LINE__); }
|
||||
inline void gpuAssert2(cudaError_t code, const char *file, int line, bool abort=true) {
|
||||
if (code != cudaSuccess) {
|
||||
fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
|
||||
if (abort) exit(code);
|
||||
}
|
||||
}
|
||||
#include "common.cuh"
|
||||
|
||||
namespace TGV {
|
||||
namespace gqa {
|
||||
@@ -93,27 +87,6 @@ acc = tl.dot(p.to(q.dtype), v) # [BLOCK_qL, BLOCK_dH]
|
||||
tl.store(acc_ptrs, acc)
|
||||
*/
|
||||
|
||||
// Store value to remote shared memory in the cluster
|
||||
CUTE_DEVICE void
|
||||
store_shared_remote_f32(float value, uint32_t dsmem_addr, uint32_t remote_barrier_addr) {
|
||||
asm volatile("st.async.shared::cluster.mbarrier::complete_tx::bytes.f32 [%0], %1, [%2];"
|
||||
: : "r"(dsmem_addr), "f"(value), "r"(remote_barrier_addr));
|
||||
}
|
||||
|
||||
// given a smem tensor, return the dsmem tensor for the given rank, the tensor addr is in smem addr space (not generic addr space)
|
||||
template <class Tensor>
|
||||
CUTE_DEVICE auto
|
||||
get_dsmem_tensor(Tensor tensor, int rank) {
|
||||
using T = typename decltype(tensor)::value_type;
|
||||
// tensor.data().get() is the smem addr in the generic addr space, in the generic addr space a region is reserved for smem
|
||||
// doing ld/st to this region of the generic addr space will be converted into ld.shared/st.shared to the smem addr space by the compiler
|
||||
// the mapa (and many inline ptx) instruction's input and output addr are in the smem/dsmem addr space, so we need to explicitly convert from generic to shared addr space
|
||||
uint32_t smem_addr = __cvta_generic_to_shared(tensor.data().get()); // smem addr space
|
||||
// mapa to get the dsmem addr of this tensor in another CTA
|
||||
uint32_t dsmem_addr = set_block_rank(smem_addr, rank); // smem addr space
|
||||
return make_tensor(make_smem_ptr((T*)dsmem_addr), tensor.layout());
|
||||
}
|
||||
|
||||
// Helper methods to create layouts
|
||||
// K always has the shape (kvL, dH, kvH, BS)
|
||||
// kvH has to be the last dim because we do mma partitioning to the first two dims (M, K) in gemm terminology
|
||||
@@ -243,6 +216,8 @@ struct SharedStorage {
|
||||
alignas(16) cute::uint64_t bmm1_softmax_full_barrier; // Barrier between BMM1 and softmax, BMM1 tells softmax the tile is ready/full, softmax can start consuming it
|
||||
alignas(16) cute::uint64_t bmm2_epilog_full_barrier; // Barrier between BMM2 and epilog, BMM2 tells epilog the tile is ready/full, epilog can start consuming it
|
||||
|
||||
alignas(16) cute::uint64_t tmem_allocation_result_barrier; // Barrier between MMA and epilog, sync tmem allocation/deallocation status between MMA and epilogue warps within CTA
|
||||
|
||||
// for cluster reduction
|
||||
alignas(16) cute::uint64_t maxsum_mailbox_full_barrier; // barrier indicating the st.async of fmax and fsum are done
|
||||
alignas(16) cute::uint64_t acc2_mailbox_full_barrier; // barrier indicating the st.async of acc2 are done
|
||||
@@ -512,67 +487,6 @@ cta_reduce_transposed(
|
||||
return acc;
|
||||
}
|
||||
|
||||
// copied from SM100::TMEM::LOAD::copy_unpack cutlass/include/cute/atom/copy_traits_sm100.hpp
|
||||
// what it does is given a tmem address, load the data into rmem tensor with the given tcgen05.ld copy op
|
||||
template <
|
||||
class CopyOp,
|
||||
class TD, class DLayout>
|
||||
CUTLASS_DEVICE void
|
||||
tmem_load(
|
||||
uint32_t tmem_addr,
|
||||
Tensor<TD,DLayout>& dst
|
||||
) {
|
||||
static_assert(is_rmem<TD>::value, "Expected RMEM dst.");
|
||||
|
||||
using RegTypeDst = typename remove_extent<typename CopyOp::DRegisters>::type;
|
||||
Tensor rD = recast<RegTypeDst>(dst);
|
||||
|
||||
constexpr int RegNumDst = extent<typename CopyOp::DRegisters>::value;
|
||||
CUTE_STATIC_ASSERT_V(size(rD) == Int<RegNumDst>{},
|
||||
"The tcgen05.ld CopyOp's size does not match the destination tensor size.");
|
||||
|
||||
detail::explode(CopyOp::copy,
|
||||
&tmem_addr, seq<0>{},
|
||||
rD, make_seq<RegNumDst>{});
|
||||
}
|
||||
|
||||
// copied from SM100::TMEM::STORE::copy_unpack cutlass/include/cute/atom/copy_traits_sm100.hpp
|
||||
// what it does is given a tmem address, store the data in rmem tensor to the tmem address with the given tcgen05.st copy op
|
||||
template <
|
||||
class CopyOp,
|
||||
class TS, class SLayout>
|
||||
CUTLASS_DEVICE void
|
||||
tmem_store(
|
||||
Tensor<TS,SLayout>& src,
|
||||
uint32_t tmem_addr
|
||||
) {
|
||||
static_assert(is_rmem<TS>::value, "Expected RMEM src.");
|
||||
|
||||
using RegTypeSrc = typename remove_extent<typename CopyOp::SRegisters>::type;
|
||||
Tensor rS = recast<RegTypeSrc>(src);
|
||||
|
||||
constexpr int RegNumSrc = extent<typename CopyOp::SRegisters>::value;
|
||||
CUTE_STATIC_ASSERT_V(size(rS) == Int<RegNumSrc>{},
|
||||
"The tcgen05.st CopyOp's size does not match the source tensor size.");
|
||||
|
||||
detail::explode(CopyOp::copy,
|
||||
rS, make_seq<RegNumSrc>{},
|
||||
&tmem_addr, seq<0>{});
|
||||
}
|
||||
|
||||
// issue cp.async
|
||||
CUTLASS_DEVICE void
|
||||
cp_async(
|
||||
int* gmem_addr,
|
||||
int* smem_addr
|
||||
) {
|
||||
uint32_t smem_int_ptr = cute::cast_smem_ptr_to_uint(smem_addr);
|
||||
asm volatile("cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n"
|
||||
:: "r"(smem_int_ptr),
|
||||
"l"(gmem_addr),
|
||||
"n"(sizeof(int)));
|
||||
}
|
||||
|
||||
// mapping between thread id (T) -> dH (row index of Acc2)
|
||||
template <int CTA_dH>
|
||||
CUTLASS_DEVICE auto
|
||||
@@ -831,14 +745,13 @@ template <
|
||||
class TiledBMM1,
|
||||
class TiledBMM2,
|
||||
int CTA_qHLocal, int CTA_qL, int CTA_kvL, int CTA_dH>
|
||||
CUTLASS_DEVICE void
|
||||
CUTLASS_DEVICE void
|
||||
MMA_warp(
|
||||
SharedStorage& shared_storage,
|
||||
WorkTileInfo work_tile_info,
|
||||
OTensor mO,
|
||||
TiledBMM1 tiled_bmm1,
|
||||
TiledBMM2 tiled_bmm2,
|
||||
cutlass::arch::NamedBarrier& tmem_allocation_barrier
|
||||
TiledBMM2 tiled_bmm2
|
||||
) {
|
||||
if (!work_tile_info.is_valid()) {
|
||||
// we don't allocate tmem for invalid tiles but we still need to relinquish the allocation lock
|
||||
@@ -900,7 +813,7 @@ MMA_warp(
|
||||
tmem_allocator.allocate(Acc1_col_max, &shared_storage.bmm1_tmem_base_ptr);
|
||||
tmem_allocator.allocate(Acc2_col_max, &shared_storage.bmm2_tmem_base_ptr);
|
||||
// notify epilog warp that tmem allocation is complete
|
||||
tmem_allocation_barrier.arrive();
|
||||
arrive_barrier(shared_storage.tmem_allocation_result_barrier);
|
||||
|
||||
// relinquish early so that prefetch cta can be launched
|
||||
tmem_allocator.release_allocation_lock();
|
||||
@@ -963,7 +876,10 @@ MMA_warp(
|
||||
MMA_gemm<decltype(tCrV), decltype(tCrP), decltype(tCtAcc2), TiledBMM2, '2', Print>(tCrV, tCrP, tCtAcc2, tiled_bmm2, bmm2_stage_idx, tma_bmm2_full_barrier_phase_bit, bmm2_accumulate, shared_storage.tmasoftmax_bmm2_full_barrier, shared_storage.tma_bmm2_empty_barrier, shared_storage.bmm2_epilog_full_barrier);
|
||||
|
||||
// wait for tmem deallocation signal from epilog warp
|
||||
tmem_allocation_barrier.arrive_and_wait();
|
||||
arrive_barrier(shared_storage.tmem_allocation_result_barrier);
|
||||
// initial phase bit = 1 since it's already flipped once for tmem allocation
|
||||
// it will flip to 0 when tmem can be deallocated, so we wait for old phase bit of 1
|
||||
wait_barrier(shared_storage.tmem_allocation_result_barrier, 1);
|
||||
|
||||
// deallocate TMEM
|
||||
tmem_allocator.free(shared_storage.bmm1_tmem_base_ptr, Acc1_col_max);
|
||||
@@ -982,7 +898,7 @@ template <
|
||||
int CTA_qHLocal, int CTA_qL, int CTA_kvL, int CTA_dH,
|
||||
int NumReductionCTA,
|
||||
bool NoSink>
|
||||
CUTLASS_DEVICE void
|
||||
CUTLASS_DEVICE void
|
||||
EPILOG_warp(
|
||||
SharedStorage& shared_storage,
|
||||
WorkTileInfo work_tile_info,
|
||||
@@ -994,7 +910,6 @@ EPILOG_warp(
|
||||
TiledBMM2 tiled_bmm2,
|
||||
float softmax_scale_log2,
|
||||
int sliding_window_size,
|
||||
cutlass::arch::NamedBarrier& tmem_allocation_barrier,
|
||||
cutlass::arch::NamedBarrier& epilog_barrier,
|
||||
int NumSplits,
|
||||
int tid, // tid local to epilog warp
|
||||
@@ -1024,7 +939,9 @@ EPILOG_warp(
|
||||
|
||||
// wait for tmem allocation in mma warp to complete, only do the wait for valid tiles
|
||||
if (work_tile_info.is_valid()) {
|
||||
tmem_allocation_barrier.arrive_and_wait();
|
||||
arrive_barrier(shared_storage.tmem_allocation_result_barrier);
|
||||
// initial phase bit = 0, it will flip to 1 when tmem is allocated, so we wait for old phase bit of 0
|
||||
wait_barrier(shared_storage.tmem_allocation_result_barrier, 0);
|
||||
}
|
||||
|
||||
// update tmem base ptr of the accumulator tensor
|
||||
@@ -1344,11 +1261,15 @@ EPILOG_warp(
|
||||
static_assert(MaxSplits <= 32, "we can use 1 warp to initialize mailbox");
|
||||
// initialize mailbox tensor for fmax and fsum, when NumSplits < MaxSplits, we need to init those value to -inf and 0
|
||||
// because we do reduction on the full tensor (of size MaxSplits) not just the valid splits
|
||||
// there is no need to init sAcc2 because it will be scaled with beta which will be 0 for invalid splits
|
||||
if (tid < MaxSplits) {
|
||||
fill(sFmaxMailbox(tid, _), -cutlass::platform::numeric_limits<TypeAcc>::infinity());
|
||||
clear(sFsumMailbox(tid, _));
|
||||
}
|
||||
// we also need to clear out acc2 mailbox for invalid splits, because acc2 value could be nan
|
||||
// nan * 0 (beta) = nan, we still need to clear acc2
|
||||
if (tid < CTA_dH) {
|
||||
clear(sAcc2Mailbox(tid, _, _));
|
||||
}
|
||||
|
||||
// ensure initialized smem is visible to the entire cluster
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
@@ -1606,7 +1527,7 @@ EPILOG_warp(
|
||||
}
|
||||
|
||||
// signal the mma warp tcgen05.ld of bmm2 is done, can start deallocate all tmem
|
||||
tmem_allocation_barrier.arrive();
|
||||
arrive_barrier(shared_storage.tmem_allocation_result_barrier);
|
||||
}
|
||||
|
||||
// only NumReductionCTA number of reduction ctas will do the reduction
|
||||
@@ -1734,15 +1655,16 @@ EPILOG_warp(
|
||||
}*/
|
||||
}
|
||||
|
||||
// K has shape (kvL, dH, kvH, BS)
|
||||
// Q has shape ((qHLocal, qL), dH, kvH, BS)
|
||||
// V has shape (dH, kvL, kvH, BS)
|
||||
// O has shape (dH, (qHLocal, qL), kvH, BS)
|
||||
// sinks has shape ((qHLocal, qL), kvH)
|
||||
// seq_len has shape (BS)
|
||||
// mK has shape (kvL, dH, kvH, BS)
|
||||
// mQ has shape ((qHLocal, qL), dH, kvH, BS)
|
||||
// mV has shape (dH, kvL, kvH, BS)
|
||||
// mO has shape (dH, (qHLocal, qL), kvH, BS)
|
||||
// mSink has shape ((qHLocal, qL), kvH)
|
||||
// mSeqLens has shape (BS)
|
||||
template <
|
||||
class SharedStorage,
|
||||
class KTensor, class QTensor, class VTensor, class OTensor, class SinkTensor,
|
||||
class SeqLensTensor,
|
||||
class TmaAtomK, class TmaAtomQ, class TmaAtomV,
|
||||
class TiledBMM1, class TiledBMM2,
|
||||
class TypeAcc,
|
||||
@@ -1758,7 +1680,7 @@ gqa_device(
|
||||
VTensor mV,
|
||||
OTensor mO,
|
||||
SinkTensor mSink,
|
||||
int* seq_lens,
|
||||
SeqLensTensor mSeqLens,
|
||||
CUTE_GRID_CONSTANT TmaAtomK const tma_atom_K,
|
||||
CUTE_GRID_CONSTANT TmaAtomQ const tma_atom_Q,
|
||||
CUTE_GRID_CONSTANT TmaAtomV const tma_atom_V,
|
||||
@@ -1782,7 +1704,7 @@ gqa_device(
|
||||
int BS_idx = blockIdx.x / kvH;
|
||||
// only thread 0 issues cp.async to load the seq_len
|
||||
if (threadIdx.x == 0) {
|
||||
cp_async(&seq_lens[BS_idx], &shared_storage.seq_len);
|
||||
cp_async(&mSeqLens(BS_idx), &shared_storage.seq_len);
|
||||
}
|
||||
|
||||
//if (threadIdx.x == 0) {
|
||||
@@ -1805,15 +1727,13 @@ gqa_device(
|
||||
cutlass::arch::detail::initialize_barrier_array_aligned<cutlass::arch::ClusterBarrier, 1>(&shared_storage.bmm1_softmax_full_barrier, /* arrival count */ 1);
|
||||
// 1 thread (BMM2) arrive to signal epilog
|
||||
cutlass::arch::detail::initialize_barrier_array_aligned<cutlass::arch::ClusterBarrier, 1>(&shared_storage.bmm2_epilog_full_barrier, /* arrival count */ 1);
|
||||
// 32 (mma) + 128 (epilog) to signal tmem allocation/deallocation result
|
||||
cutlass::arch::detail::initialize_barrier_array_aligned<cutlass::arch::ClusterBarrier, 1>(&shared_storage.tmem_allocation_result_barrier, /* arrival count */ 32 + 128);
|
||||
// 1 thread (epilog) arrive to signal maxsum
|
||||
cutlass::arch::detail::initialize_barrier_array_aligned<cutlass::arch::ClusterTransactionBarrier, 1>(&shared_storage.maxsum_mailbox_full_barrier, /* arrival count */ 1);
|
||||
// 1 thread (epilog) arrive to signal acc2
|
||||
cutlass::arch::detail::initialize_barrier_array_aligned<cutlass::arch::ClusterTransactionBarrier, 1>(&shared_storage.acc2_mailbox_full_barrier, /* arrival count */ 1);
|
||||
}
|
||||
// Sync tmem allocation status between MMA and softmax/epilogue warps within CTA
|
||||
// 32 threads (mma) + 128 threads (epilog) to sync
|
||||
// also used for tmem deallocation between epilog warps and mma warps within CTA
|
||||
cutlass::arch::NamedBarrier tmem_allocation_barrier(32 + 128, cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier);
|
||||
// syncing all threads (128) within 4 epilog warps
|
||||
cutlass::arch::NamedBarrier epilog_barrier(128, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
|
||||
|
||||
@@ -1913,13 +1833,13 @@ gqa_device(
|
||||
DMA_KV_warp<SharedStorage, WorkTileInfo, decltype(mK), decltype(mV), TmaAtomK, TmaAtomV, TiledBMM1, TiledBMM2, CTA_kvL, CTA_dH>(shared_storage, work_tile_info, mK, mV, &tma_atom_K, &tma_atom_V, tiled_bmm1, tiled_bmm2);
|
||||
}
|
||||
else if (warp_idx == 2) {
|
||||
MMA_warp<SharedStorage, WorkTileInfo, decltype(mO), TiledBMM1, TiledBMM2, CTA_qHLocal, CTA_qL, CTA_kvL, CTA_dH>(shared_storage, work_tile_info, mO, tiled_bmm1, tiled_bmm2, tmem_allocation_barrier);
|
||||
}
|
||||
MMA_warp<SharedStorage, WorkTileInfo, decltype(mO), TiledBMM1, TiledBMM2, CTA_qHLocal, CTA_qL, CTA_kvL, CTA_dH>(shared_storage, work_tile_info, mO, tiled_bmm1, tiled_bmm2);
|
||||
}
|
||||
else if (warp_idx >= 4) {
|
||||
// epilog tid is from 128 to 255, need to offset by -128 when getting the per thread slice
|
||||
int tid = threadIdx.x - 128;
|
||||
// warp_idx - 4 because epilog warp group starts from warp 4
|
||||
EPILOG_warp<SharedStorage, WorkTileInfo, decltype(mK), decltype(mO), decltype(mSink), TiledBMM1, TiledBMM2, CTA_qHLocal, CTA_qL, CTA_kvL, CTA_dH, NumReductionCTA, NoSink>(shared_storage, work_tile_info, mK, mO, mSink, seq_len, tiled_bmm1, tiled_bmm2, softmax_scale_log2, sliding_window_size, tmem_allocation_barrier, epilog_barrier, NumSplits, tid, warp_idx - 4, rank);
|
||||
EPILOG_warp<SharedStorage, WorkTileInfo, decltype(mK), decltype(mO), decltype(mSink), TiledBMM1, TiledBMM2, CTA_qHLocal, CTA_qL, CTA_kvL, CTA_dH, NumReductionCTA, NoSink>(shared_storage, work_tile_info, mK, mO, mSink, seq_len, tiled_bmm1, tiled_bmm2, softmax_scale_log2, sliding_window_size, epilog_barrier, NumSplits, tid, warp_idx - 4, rank);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
@@ -1980,6 +1900,7 @@ void gqa_host(
|
||||
Tensor mV = make_tensor(make_gmem_ptr(device_ptr_V), layout_V); // (dH, kvL, kvH, BS)
|
||||
Tensor mO = make_tensor(make_gmem_ptr(device_ptr_O), layout_O); // (dH, (qHLocal, qL), kvH, BS)
|
||||
Tensor mSink = make_tensor(make_gmem_ptr(device_ptr_sinks), layout_sinks); // ((qHLocal, qL), kvH)
|
||||
Tensor mSeqLens = make_tensor(make_gmem_ptr(seq_lens), make_layout(make_shape(BS))); // (BS)
|
||||
|
||||
//printf("mK: "); print(mK); printf("\n");
|
||||
//printf("mQ: "); print(mQ); printf("\n");
|
||||
@@ -2174,6 +2095,7 @@ void gqa_host(
|
||||
if (device_ptr_sinks != nullptr) {
|
||||
auto *kernel_instance =
|
||||
&gqa_device<SMEMStorage, decltype(mK_tma), decltype(mQ_tma), decltype(mV_tma), decltype(mO), decltype(mSink),
|
||||
decltype(mSeqLens),
|
||||
decltype(tma_atom_K), decltype(tma_atom_Q), decltype(tma_atom_V),
|
||||
decltype(tiled_bmm1), decltype(tiled_bmm2),
|
||||
TypeAcc,
|
||||
@@ -2185,14 +2107,15 @@ void gqa_host(
|
||||
// portable max cluster size is 8, but sm100a supports 16, need explicit opt in
|
||||
gpuErrChk(cudaFuncSetAttribute(*kernel_instance, cudaFuncAttributeNonPortableClusterSizeAllowed, 1));
|
||||
gpuErrChk(cudaLaunchKernelEx(&config, kernel_instance, mK_tma, mQ_tma, mV_tma, mO, mSink,
|
||||
seq_lens,
|
||||
mSeqLens,
|
||||
tma_atom_K, tma_atom_Q, tma_atom_V,
|
||||
tiled_bmm1, tiled_bmm2,
|
||||
softmax_scale * Log2_E, sliding_window_size, pdl_count));
|
||||
}
|
||||
}
|
||||
else {
|
||||
auto *kernel_instance =
|
||||
&gqa_device<SMEMStorage, decltype(mK_tma), decltype(mQ_tma), decltype(mV_tma), decltype(mO), decltype(mSink),
|
||||
decltype(mSeqLens),
|
||||
decltype(tma_atom_K), decltype(tma_atom_Q), decltype(tma_atom_V),
|
||||
decltype(tiled_bmm1), decltype(tiled_bmm2),
|
||||
TypeAcc,
|
||||
@@ -2204,7 +2127,7 @@ void gqa_host(
|
||||
// portable max cluster size is 8, but sm100a supports 16, need explicit opt in
|
||||
gpuErrChk(cudaFuncSetAttribute(*kernel_instance, cudaFuncAttributeNonPortableClusterSizeAllowed, 1));
|
||||
gpuErrChk(cudaLaunchKernelEx(&config, kernel_instance, mK_tma, mQ_tma, mV_tma, mO, mSink,
|
||||
seq_lens,
|
||||
mSeqLens,
|
||||
tma_atom_K, tma_atom_Q, tma_atom_V,
|
||||
tiled_bmm1, tiled_bmm2,
|
||||
softmax_scale * Log2_E, sliding_window_size, pdl_count));
|
||||
|
||||
867
examples/93_blackwell_low_latency_gqa/tgv_gqa_paged.cuh
Normal file
867
examples/93_blackwell_low_latency_gqa/tgv_gqa_paged.cuh
Normal file
@@ -0,0 +1,867 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cassert>
|
||||
#include <cstdint>
|
||||
#include <iostream>
|
||||
#include <cstdio>
|
||||
#include <cmath>
|
||||
|
||||
// Cutlass includes
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
#include <cutlass/arch/barrier.h>
|
||||
#include <cutlass/gemm/collective/builders/sm100_common.inl> // mma/smem selector, umma::major
|
||||
#include <cutlass/numeric_conversion.h>
|
||||
#include <cutlass/arch/grid_dependency_control.h>
|
||||
|
||||
// CuTe includes
|
||||
#include <cute/tensor.hpp> // CuTe tensor implementation
|
||||
#include <cute/arch/tmem_allocator_sm100.hpp> // TMEM allocator for SM100
|
||||
#include <cute/arch/copy_sm90_desc.hpp>
|
||||
#include <cute/arch/copy_sm90_tma.hpp>
|
||||
|
||||
#include "common.cuh"
|
||||
#include "tgv_gqa.cuh" // reuse layout helpers, WorkTileInfo, TMA_copy, MMA_gemm, reductions, DMA_Q/MMA/EPILOG warps
|
||||
|
||||
// Grouped Query Attention (paged KV cache) with dual BMM + online softmax. 7 warps: 1 DMA_Q, 1 DMA_KV, 1 MMA,
|
||||
// 4 EPILOG. Warp 3 is unused.
|
||||
// Warp 0 (DMA_Q): Loads Q via TMA, single-stage. Reused from gqa namespace (gqa::DMA_Q_warp).
|
||||
// Warp 1 (DMA_KV): Loads K then V via TMA, one cp.async.bulk per page. Issues its own cp.async of per-CTA-tile
|
||||
// page indices into a smem staging buffer (lane-distributed, thread-local cp_async fence/wait).
|
||||
// Defined locally in gqa_paged namespace.
|
||||
// Warp 2 (MMA): Performs BMM1 (K@Q) and BMM2 (V@P). Reused from gqa namespace (gqa::MMA_warp).
|
||||
// Warps 4-7 (EPILOG): Softmax partial max/sum warp reduction, cluster wide max/sum reduction, final flash-decode
|
||||
// output with attention sink support. Reused from gqa namespace (gqa::EPILOG_warp).
|
||||
// WorkTileInfo: Reused from gqa namespace. Attention-specific fields: BS_idx, kvH_idx, kvL_idx_start/end, dH_idx, qHLocal_idx, qL_idx.
|
||||
// SharedStorage: Defined locally. Inherits from gqa::SharedStorage and adds the paged-only pieces: PageIdx smem
|
||||
// buffer (double-buffered because K and V are offset by 1 tile -- around pi-stage boundaries they read different
|
||||
// slots) and paged views (tensor_sK_paged / tensor_sV_paged / tensor_sPageIdx) over the inherited K/V smem buffers.
|
||||
|
||||
namespace TGV {
|
||||
namespace gqa_paged {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
// Symbols reused identically from TGV::gqa -- see tgv_gqa.cuh for definitions.
|
||||
// Log2_E and WorkTileInfo are pulled in here; everything else is referenced explicitly via gqa:: at call sites.
|
||||
using TGV::gqa::Log2_E;
|
||||
using TGV::gqa::WorkTileInfo;
|
||||
|
||||
// The (bs_idx, per_batch_page_idx) -> physical page id mapping lives in a gmem page_table tensor of shape
|
||||
// (kvL/Page_Size, BS), seeded by the host harness and fetched at runtime by DMA_KV_warp. To install a
|
||||
// new placement policy, populate the page_table differently host-side; the kernel does not assume any structure.
|
||||
// The shared memory buffers for Q, K, V matrices.
|
||||
template <
|
||||
class TypeQKV, // Tensor Q/K/V data type
|
||||
class TypeAcc, // Tensor Acc data type
|
||||
class KSmemLayout, // ((Mma_M, Mma_K), NumMma_M, NumMma_K, BMM1_DMA_Stage)
|
||||
class KPagedSmemLayout,// (Page_Size, CTA_dH, NumPagePerCTATile, BMM1_DMA_Stage), same memory as KSmemLayout
|
||||
class QSmemLayout, // ((Mma_N, Mma_K), NumMma_N, NumMma_K, 1)
|
||||
class VSmemLayout, // ((Mma_M, Mma_K), NumMma_M, NumMma_K, BMM2_DMA_Stage)
|
||||
class VPagedSmemLayout,// (CTA_dH, Page_Size, NumPagePerCTATile, BMM2_DMA_Stage), same memory as VSmemLayout
|
||||
class SSmemLayout, // (CTA_kvL, (CTA_qHLocal, CTA_qL), 1) aka C matrix (M, N, 1) for bmm1
|
||||
class PSmemLayout, // ((CTA_qHLocal, CTA_qL), CTA_kvL, 1) aka B matrix (N, K, 1) for bmm2
|
||||
class WRSmemLayout, // (NumEpilogWarps, (CTA_qHLocal, CTA_qL)), WR stands for warp reduce
|
||||
class MSMailboxSmemLayout,// (MaxSplits, CTA_qHLocal * CTA_qL / NumReductionCTA), MS stands max and sum
|
||||
class Acc1SmemLayout, // (CTA_kvL, (CTA_qHLocal, CTA_qL), 1)
|
||||
class Acc2MailboxSmemLayout, // (CTA_dH, CTA_qHLocal * CTA_qL / NumReductionCTA, MaxSplits)
|
||||
class SinksSmemLayout, // (CTA_qHLocal * CTA_qL / NumReductionCTA)
|
||||
class PageIdxSmemLayout, // ((NumPagePerCTATile, Tiles_Per_Pi_Stage), Page_Idx_Stage), int32 page indices staged by DMA_KV warp
|
||||
int BMM1_DMA_Stage,
|
||||
int BMM2_DMA_Stage,
|
||||
int Page_Idx_Stage>
|
||||
// Paged kernel adds two pieces on top of the plain gqa SharedStorage:
|
||||
// - PageIdx smem buffer (DMA_KV staging, int32 page indices). Double-buffered: K and V are 1 tile apart, so
|
||||
// around a pi-stage boundary they read different slots. DMA_KV writes the next slot via lane-distributed
|
||||
// cp.async at K_t_in_stage==1 (when V has crossed into the current pi-stage and the next slot is free).
|
||||
// - paged views of the inherited K/V smem buffers (same memory, different layout)
|
||||
struct SharedStorage : TGV::gqa::SharedStorage<
|
||||
TypeQKV, TypeAcc,
|
||||
KSmemLayout, QSmemLayout, VSmemLayout, SSmemLayout, PSmemLayout,
|
||||
WRSmemLayout, MSMailboxSmemLayout, Acc1SmemLayout,
|
||||
Acc2MailboxSmemLayout, SinksSmemLayout,
|
||||
BMM1_DMA_Stage, BMM2_DMA_Stage> {
|
||||
// DMA_KV staging buffer for int32 page indices, layout ((NumPagePerCTATile, Tiles_Per_Pi_Stage), Page_Idx_Stage)
|
||||
alignas(128) cute::ArrayEngine<int, cute::cosize_v<PageIdxSmemLayout>> PageIdx;
|
||||
|
||||
// alternative paged view of the same K smem buffer used by TMA: (Page_Size, CTA_dH, NumPagePerCTATile, BMM1_DMA_Stage)
|
||||
CUTE_DEVICE constexpr auto tensor_sK_paged() { return make_tensor(make_smem_ptr(this->K.begin()), KPagedSmemLayout{}); }
|
||||
// alternative paged view of the same V smem buffer used by TMA: (CTA_dH, Page_Size, NumPagePerCTATile, BMM2_DMA_Stage)
|
||||
CUTE_DEVICE constexpr auto tensor_sV_paged() { return make_tensor(make_smem_ptr(this->V.begin()), VPagedSmemLayout{}); }
|
||||
CUTE_DEVICE constexpr auto tensor_sPageIdx() { return make_tensor(make_smem_ptr(PageIdx.begin()), PageIdxSmemLayout{}); }
|
||||
};
|
||||
|
||||
// paged TMA copy: like gqa::TMA_copy, but issues NumPagePerCTATile back-to-back copies that share one full barrier.
|
||||
// empty/full barrier semantics still operate at CTA-tile granularity -- one slot covers the whole (CTA_kvL, CTA_dH) tile,
|
||||
// transaction bytes counts all pages, and a single set_barrier_transaction_bytes arrives once per stage.
|
||||
//
|
||||
// Page indices are read from smem -- DMA_KV stages them from the gmem page_table via lane-distributed
|
||||
// cp.async and a thread-local cp_async_fence/wait. The caller passes a smem pointer to the
|
||||
// NumPagePerCTATile page indices owned by this tile; this function performs an ld.shared per page and uses the
|
||||
// resulting global page id as the gmem coordinate for the TMA copy.
|
||||
template <
|
||||
class GTensor,
|
||||
class STensor,
|
||||
class TmaAtom,
|
||||
char Name,
|
||||
bool Print,
|
||||
int DMA_Stage,
|
||||
int NumPagePerCTATile>
|
||||
CUTLASS_DEVICE void
|
||||
TMA_copy_paged(
|
||||
GTensor gTensor, // ((TMA, NumTma_K), Num_Page_Global)
|
||||
STensor sTensor, // ((TMA, NumTma_K), NumPagePerCTATile, DMA_Stage)
|
||||
int k_tile,
|
||||
int const* page_idx_smem, // pointer to NumPagePerCTATile contiguous int page indices in smem (this tile's slice)
|
||||
int& tma_mma_empty_barrier_phase_bit,
|
||||
int tma_transaction_bytes, // total bytes for one CTA tile (NumPagePerCTATile pages)
|
||||
TmaAtom const* tma_atom,
|
||||
cute::uint64_t* tma_mma_full_barrier,
|
||||
cute::uint64_t* tma_mma_empty_barrier
|
||||
) {
|
||||
// wait for the smem slot to be empty before issuing pages for the next CTA tile
|
||||
wait_barrier(tma_mma_empty_barrier[k_tile % DMA_Stage], tma_mma_empty_barrier_phase_bit);
|
||||
|
||||
if constexpr (Print) {
|
||||
if (elect_one_sync() && (blockIdx.x == 0) && (blockIdx.y == 0) && (blockIdx.z == 0)) {
|
||||
printf("[DMA_%c] barrier empty, kblock %d (paged, %d pages)\n", Name, k_tile, NumPagePerCTATile);
|
||||
}
|
||||
}
|
||||
|
||||
if (elect_one_sync()) {
|
||||
// single set_barrier_transaction_bytes accounts for all pages; later cute::copy calls don't add arrivals.
|
||||
set_barrier_transaction_bytes(tma_mma_full_barrier[k_tile % DMA_Stage], tma_transaction_bytes);
|
||||
CUTE_UNROLL
|
||||
for (int p = 0; p < NumPagePerCTATile; p++) {
|
||||
// per-page lookup: load global page id from smem and feed it as the gmem coordinate for one TMA copy.
|
||||
int global_page = page_idx_smem[p];
|
||||
copy(tma_atom->with(tma_mma_full_barrier[k_tile % DMA_Stage]),
|
||||
gTensor(_, global_page),
|
||||
sTensor(_, p, k_tile % DMA_Stage));
|
||||
}
|
||||
}
|
||||
|
||||
if ((k_tile % DMA_Stage) == (DMA_Stage - 1)) {
|
||||
tma_mma_empty_barrier_phase_bit ^= 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Lane-distributed cp.async of one pi-stage's worth (Num_Page_Idx_Per_Stage ints) of page indices from gmem
|
||||
// page table to a smem slot. Each lane loads ceil(Num_Page_Idx_Per_Stage/32) ints; the LSU coalesces them.
|
||||
// Caller pre-slices the BS mode of the page table (it is CTA-constant) and is responsible for
|
||||
// cp_async_fence/wait + __syncwarp ordering after this returns.
|
||||
template <
|
||||
int Tiles_Per_Pi_Stage, int Num_Page_Idx_Per_Stage, int Page_Idx_Stage,
|
||||
class GPageTableTensor, class SPageIdxTensor>
|
||||
CUTLASS_DEVICE void
|
||||
issue_pi_stage_cp_async(
|
||||
GPageTableTensor gPageTable, // ((NumPagePerCTATile, kvL/CTA_kvL),), int gmem -- BS already sliced
|
||||
SPageIdxTensor sPageIdx, // ((NumPagePerCTATile, Tiles_Per_Pi_Stage), Page_Idx_Stage), int smem
|
||||
int kvL_idx_start, int pi_stage_idx) {
|
||||
int tile_start = kvL_idx_start + pi_stage_idx * Tiles_Per_Pi_Stage;
|
||||
int slot = pi_stage_idx % Page_Idx_Stage;
|
||||
// Shift gPageTable's data pointer to the first page of this pi-stage; flat-index `i` walks consecutive
|
||||
// (page-within-CTA-tile, CTA-tile) pairs in column-major order across Num_Page_Idx_Per_Stage ints.
|
||||
auto gp = domain_offset(make_coord(make_coord(0, tile_start)), gPageTable);
|
||||
// Use threadIdx.x-derived lane (cheap: threadIdx.x is already live in a register from earlier in the kernel).
|
||||
int lane = cutlass::canonical_lane_idx();
|
||||
CUTE_UNROLL
|
||||
for (int i = lane; i < Num_Page_Idx_Per_Stage; i += 32) {
|
||||
cp_async(&gp(i), &sPageIdx(i, slot));
|
||||
}
|
||||
}
|
||||
|
||||
// Paged DMA_KV warp: TMA-loads K then V, one cp.async.bulk per page. Page indices are staged into smem by this
|
||||
// same warp via lane-distributed cp.async (no separate Read_Page_Idx warp, no transaction-barrier handshake).
|
||||
// See the in-body comment ahead of the prolog for the page-idx fetch pipeline. Tiles_Per_Pi_Stage >= 2 is
|
||||
// required so the K_t_in_stage==1 issue hook exists.
|
||||
template <
|
||||
class SharedStorage,
|
||||
class WorkTileInfo,
|
||||
class KTensor,
|
||||
class VTensor,
|
||||
class PageTableTensor,
|
||||
class TmaAtomK,
|
||||
class TmaAtomV,
|
||||
int CTA_kvL, int CTA_dH, int Page_Size,
|
||||
int Page_Idx_Stage, int Num_Page_Idx_Per_Stage>
|
||||
CUTLASS_DEVICE void
|
||||
DMA_KV_warp(
|
||||
SharedStorage& shared_storage,
|
||||
WorkTileInfo work_tile_info,
|
||||
KTensor mK,
|
||||
VTensor mV,
|
||||
PageTableTensor mPageTable, // (kvL/Page_Size, BS), int gmem; underlying allocation tail-padded by Num_Page_Idx_Per_Stage ints. Mode-0 is partitioned by NumPagePerCTATile to access per-CTA-tile slices.
|
||||
// when passing tma descriptor as function argument, it has to be pass by pointer/reference, if pass by value, it will live on local memory (i.e. the stack)
|
||||
// and the tma unit cannot access the local memory, (even if it can, the local memory is strided by thread id, the content for each thread is strided)
|
||||
TmaAtomK const* tma_atom_K,
|
||||
TmaAtomV const* tma_atom_V) {
|
||||
|
||||
if (!work_tile_info.is_valid()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// CTA_kvL % Page_Size == 0, Num_Page_Idx_Per_Stage % NumPagePerCTATile == 0, Page_Idx_Stage == 2,
|
||||
// and Tiles_Per_Pi_Stage >= 2 are all checked in gqa_paged_host.
|
||||
constexpr int NumPagePerCTATile = CTA_kvL / Page_Size;
|
||||
constexpr int Tiles_Per_Pi_Stage = Num_Page_Idx_Per_Stage / NumPagePerCTATile;
|
||||
|
||||
// setup code for K tensor
|
||||
// paged smem view of K, NOT mma partitioned, used purely for TMA: (Page_Size, CTA_dH, NumPagePerCTATile, BMM1_DMA_Stage)
|
||||
// the same smem buffer is also accessible via tensor_sK() in MMA-partitioned form for the MMA warp.
|
||||
Tensor sK_paged = shared_storage.tensor_sK_paged(); // (Page_Size, CTA_dH, NumPagePerCTATile, BMM1_DMA_Stage)
|
||||
// mK has shape (Page_Size, dH, Num_Page_Global, kvH); BS is folded into Num_Page_Global so there is no BS mode.
|
||||
// local_tile with a static (Page_Size, CTA_dH) tile_shape materializes the leading two modes statically (TMA
|
||||
// partition requires static mode-0 sizes); dH==CTA_dH and Page_Size already matches so the divisions collapse.
|
||||
// The page selection is done at TMA-issue time via the smem page_idx slice (staged by DMA_KV's own cp.async),
|
||||
// not by indexing on a BS mode here.
|
||||
Tensor gK = local_tile(mK, make_shape(Int<Page_Size>{}, Int<CTA_dH>{}),
|
||||
make_coord(0, 0, _, work_tile_info.kvH_idx)); // (Page_Size, CTA_dH, Num_Page_Global)
|
||||
|
||||
// group modes [0,2) on both sK_paged and gK so the TMA box (Page_Size, CTA_dH) is mode 0; outer modes are
|
||||
// (NumPagePerCTATile, BMM1_DMA_Stage) for smem and Num_Page_Global for gmem.
|
||||
auto [tAgK, tAsK] = tma_partition(*tma_atom_K,
|
||||
Int<0>{}, // cta_coord: 1x1 cluster
|
||||
Layout<_1>{}, // cta_layout: CTA coord -> logical multicast id, no multicast, just identity layout
|
||||
group_modes<0, 2>(sK_paged), group_modes<0, 2>(gK));
|
||||
// tAsK: ((TMA, NumTma_K), NumPagePerCTATile, BMM1_DMA_Stage) -- 3 modes
|
||||
// tAgK: ((TMA, NumTma_K), Num_Page_Global) -- 2 modes
|
||||
// the shape of the TMA box is (Page_Size, CTA_dH)
|
||||
|
||||
/*if (elect_one_sync() && (blockIdx.x == 0) && (blockIdx.y == 0) && (blockIdx.z == 0)) {
|
||||
printf("[PAGED-K] sK_paged:\t"); print(sK_paged); printf("\n"); // (Page_Size, CTA_dH, NumPagePerCTATile, BMM1_DMA_Stage)
|
||||
printf("[PAGED-K] gK:\t"); print(gK); printf("\n"); // (Page_Size, CTA_dH, Num_Page_Global)
|
||||
printf("[PAGED-K] tAgK:\t"); print(tAgK); printf("\n"); // ((TMA, NumTma_K), Num_Page_Global)
|
||||
printf("[PAGED-K] tAsK:\t"); print(tAsK); printf("\n"); // ((TMA, NumTma_K), NumPagePerCTATile, BMM1_DMA_Stage)
|
||||
}*/
|
||||
|
||||
// setup code for V tensor
|
||||
// paged smem view of V, NOT mma partitioned, used purely for TMA: (CTA_dH, Page_Size, NumPagePerCTATile, BMM2_DMA_Stage)
|
||||
// the same smem buffer is also accessible via tensor_sV() in MMA-partitioned form for the MMA warp.
|
||||
Tensor sV_paged = shared_storage.tensor_sV_paged(); // (CTA_dH, Page_Size, NumPagePerCTATile, BMM2_DMA_Stage)
|
||||
// mV has shape (dH, Page_Size, Num_Page_Global, kvH); BS is folded into Num_Page_Global so there is no BS mode.
|
||||
Tensor gV = local_tile(mV, make_shape(Int<CTA_dH>{}, Int<Page_Size>{}),
|
||||
make_coord(0, 0, _, work_tile_info.kvH_idx)); // (CTA_dH, Page_Size, Num_Page_Global)
|
||||
|
||||
// group modes [0,2) on both sV_paged and gV so the TMA box (CTA_dH, Page_Size) is mode 0; outer modes are
|
||||
// (NumPagePerCTATile, BMM2_DMA_Stage) for smem and Num_Page_Global for gmem.
|
||||
auto [tAgV, tAsV] = tma_partition(*tma_atom_V,
|
||||
Int<0>{}, // cta_coord: 1x1 cluster
|
||||
Layout<_1>{}, // cta_layout: CTA coord -> logical multicast id, no multicast, just identity layout
|
||||
group_modes<0, 2>(sV_paged), group_modes<0, 2>(gV));
|
||||
// tAsV: ((TMA, NumTma_K), NumPagePerCTATile, BMM2_DMA_Stage) -- 3 modes
|
||||
// tAgV: ((TMA, NumTma_K), Num_Page_Global) -- 2 modes
|
||||
// the shape of the TMA box is (CTA_dH, Page_Size)
|
||||
|
||||
/*if (elect_one_sync() && (blockIdx.x == 0) && (blockIdx.y == 0) && (blockIdx.z == 0)) {
|
||||
printf("[PAGED-V] sV_paged:\t"); print(sV_paged); printf("\n"); // (CTA_dH, Page_Size, NumPagePerCTATile, BMM2_DMA_Stage)
|
||||
printf("[PAGED-V] gV:\t"); print(gV); printf("\n"); // (CTA_dH, Page_Size, Num_Page_Global)
|
||||
printf("[PAGED-V] tAgV:\t"); print(tAgV); printf("\n"); // ((TMA, NumTma_K), Num_Page_Global)
|
||||
printf("[PAGED-V] tAsV:\t"); print(tAsV); printf("\n"); // ((TMA, NumTma_K), NumPagePerCTATile, BMM2_DMA_Stage)
|
||||
}*/
|
||||
|
||||
// total K/V bytes per stage slot = all NumPagePerCTATile pages combined (they share one full barrier).
|
||||
// Slice tAsK/tAsV at one stage to get all pages in that stage; size_in_bytes is implicit via sizeof(tensor_like).
|
||||
int tma_K_transaction_bytes = sizeof(make_tensor_like(tAsK(_, _, 0)));
|
||||
int tma_V_transaction_bytes = sizeof(make_tensor_like(tAsV(_, _, 0)));
|
||||
int k_tile_count = work_tile_info.kvL_idx_end - work_tile_info.kvL_idx_start;
|
||||
// BMM1_DMA_Stage = mode 3 of rank-4 sK_paged (Page_Size, CTA_dH, NumPage, Stage)
|
||||
// BMM2_DMA_Stage = mode 3 of rank-4 sV_paged (CTA_dH, Page_Size, NumPage, Stage)
|
||||
int constexpr BMM1_DMA_Stage = size<3>(sK_paged);
|
||||
int constexpr BMM2_DMA_Stage = size<3>(sV_paged);
|
||||
|
||||
/*if (elect_one_sync() && (blockIdx.x == 0) && (blockIdx.y == 0) && (blockIdx.z == 0)) {
|
||||
printf("[PAGED-K] tma_K_transaction_bytes (per CTA tile)=%d\n", tma_K_transaction_bytes);
|
||||
printf("[PAGED-V] tma_V_transaction_bytes=%d, k_tile_count=%d, BMM1_DMA_Stage=%d, BMM2_DMA_Stage=%d\n", tma_V_transaction_bytes, k_tile_count, BMM1_DMA_Stage, BMM2_DMA_Stage);
|
||||
}*/
|
||||
|
||||
// details of how the phase bit works is in DMA_Q_warp (in tgv_gqa.cuh)
|
||||
int tma_bmm1_empty_barrier_phase_bit = 1;
|
||||
int tma_bmm2_empty_barrier_phase_bit = 1;
|
||||
|
||||
int bmm1_k_tile = 0;
|
||||
int bmm2_k_tile = 0;
|
||||
|
||||
bool constexpr Print = false;
|
||||
|
||||
// sPageIdx layout: ((NumPagePerCTATile, Tiles_Per_Pi_Stage), Page_Idx_Stage). For tile t (CTA-tile idx):
|
||||
// t_in_stage = t % Tiles_Per_Pi_Stage; pi_slot = (t / Tiles_Per_Pi_Stage) % Page_Idx_Stage;
|
||||
// page_idx_smem_ptr = &sPageIdx(make_coord(0, t_in_stage), pi_slot).
|
||||
Tensor sPageIdx = shared_storage.tensor_sPageIdx();
|
||||
|
||||
// gPageTable view: per-CTA slice of the page table, then split mode 0 by NumPagePerCTATile so the
|
||||
// (page-within-CTA-tile, CTA-tile-idx) split is explicit. BS is sliced upfront -- it's CTA-constant.
|
||||
// Tail-padded host allocation lets us read up to Tiles_Per_Pi_Stage tiles past kvL_idx_end without OOB.
|
||||
Tensor gPageTable = logical_divide(mPageTable(_, work_tile_info.BS_idx), Shape<Int<NumPagePerCTATile>>{}); // ((NumPagePerCTATile, kvL/CTA_kvL),)
|
||||
|
||||
// SOL order: K0, (K1, V0), (K2, V1), ...
|
||||
//
|
||||
// Page-idx fetch pipeline. The page-idx smem buffer is double-buffered (Page_Idx_Stage=2 slots) because
|
||||
// K leads V by 1 tile -- around a pi-stage boundary K reads slot S%2 (stage S) while V is finishing the
|
||||
// last tile of slot (S-1)%2 (stage S-1), so the two slots must coexist. Synchronization is thread-local
|
||||
// cp_async_fence/cp_async_wait<0> + __syncwarp -- no cross-warp barrier.
|
||||
// * Prolog: issue stage 0 -> slot 0, fence, wait, syncwarp, then K0 TMA.
|
||||
// * Main loop K_t_in_stage == 0: drain the previously-issued cp.async (stage K_pi_stage's data). At this
|
||||
// point at most one cp.async group is outstanding so wait<0> is correct.
|
||||
// * Main loop K_t_in_stage == 1: pre-issue stage K_pi_stage+1 -> slot (K_pi_stage+1)%2 if it has tiles
|
||||
// in this CTA's range. V crossed into pi-stage K_pi_stage at the previous iteration, so slot
|
||||
// (K_pi_stage+1)%2 = (K_pi_stage-1)%2 is no longer being ld.shared'd and is safe to overwrite. The
|
||||
// consumer's wait at the next pi-stage boundary (Tiles_Per_Pi_Stage-1 tiles later) hides the RTT.
|
||||
|
||||
// Prolog: stage 0 must be in slot 0 before K0 reads it.
|
||||
{
|
||||
issue_pi_stage_cp_async<Tiles_Per_Pi_Stage, Num_Page_Idx_Per_Stage, Page_Idx_Stage>(
|
||||
gPageTable, sPageIdx, work_tile_info.kvL_idx_start, /*pi_stage_idx=*/0);
|
||||
cp_async_fence();
|
||||
cp_async_wait<0>();
|
||||
__syncwarp();
|
||||
|
||||
// K0 prolog (no V yet).
|
||||
TMA_copy_paged<decltype(tAgK), decltype(tAsK), TmaAtomK, 'K', Print, BMM1_DMA_Stage, NumPagePerCTATile>(tAgK, tAsK, bmm1_k_tile, &sPageIdx(make_coord(0, 0), 0), tma_bmm1_empty_barrier_phase_bit, tma_K_transaction_bytes, tma_atom_K, shared_storage.tma_bmm1_full_barrier, shared_storage.tma_bmm1_empty_barrier);
|
||||
bmm1_k_tile++;
|
||||
}
|
||||
|
||||
for (; bmm1_k_tile < k_tile_count; bmm1_k_tile++, bmm2_k_tile++) {
|
||||
int K_t_in_stage = bmm1_k_tile % Tiles_Per_Pi_Stage;
|
||||
int K_pi_stage = bmm1_k_tile / Tiles_Per_Pi_Stage;
|
||||
int K_pi_slot = K_pi_stage % Page_Idx_Stage;
|
||||
|
||||
// Drain stage K_pi_stage's cp.async before K reads its slot.
|
||||
if (K_t_in_stage == 0) {
|
||||
cp_async_wait<0>();
|
||||
__syncwarp();
|
||||
}
|
||||
|
||||
// Pre-issue stage K_pi_stage+1 if there are still tiles in this CTA's range to consume from it.
|
||||
if (K_t_in_stage == 1) {
|
||||
int next_pi_stage = K_pi_stage + 1;
|
||||
int next_tile_start = work_tile_info.kvL_idx_start + next_pi_stage * Tiles_Per_Pi_Stage;
|
||||
if (next_tile_start < work_tile_info.kvL_idx_end) {
|
||||
issue_pi_stage_cp_async<Tiles_Per_Pi_Stage, Num_Page_Idx_Per_Stage, Page_Idx_Stage>(
|
||||
gPageTable, sPageIdx, work_tile_info.kvL_idx_start, next_pi_stage);
|
||||
cp_async_fence();
|
||||
}
|
||||
}
|
||||
|
||||
TMA_copy_paged<decltype(tAgK), decltype(tAsK), TmaAtomK, 'K', Print, BMM1_DMA_Stage, NumPagePerCTATile>(tAgK, tAsK, bmm1_k_tile, &sPageIdx(make_coord(0, K_t_in_stage), K_pi_slot), tma_bmm1_empty_barrier_phase_bit, tma_K_transaction_bytes, tma_atom_K, shared_storage.tma_bmm1_full_barrier, shared_storage.tma_bmm1_empty_barrier);
|
||||
|
||||
int V_t_in_stage = bmm2_k_tile % Tiles_Per_Pi_Stage;
|
||||
int V_pi_slot = (bmm2_k_tile / Tiles_Per_Pi_Stage) % Page_Idx_Stage;
|
||||
TMA_copy_paged<decltype(tAgV), decltype(tAsV), TmaAtomV, 'V', Print, BMM2_DMA_Stage, NumPagePerCTATile>(tAgV, tAsV, bmm2_k_tile, &sPageIdx(make_coord(0, V_t_in_stage), V_pi_slot), tma_bmm2_empty_barrier_phase_bit, tma_V_transaction_bytes, tma_atom_V, shared_storage.tmasoftmax_bmm2_full_barrier, shared_storage.tma_bmm2_empty_barrier);
|
||||
}
|
||||
|
||||
// V epilog (last V). Its slot is already populated by an earlier cp.async.
|
||||
{
|
||||
int V_t_in_stage = bmm2_k_tile % Tiles_Per_Pi_Stage;
|
||||
int V_pi_slot = (bmm2_k_tile / Tiles_Per_Pi_Stage) % Page_Idx_Stage;
|
||||
TMA_copy_paged<decltype(tAgV), decltype(tAsV), TmaAtomV, 'V', Print, BMM2_DMA_Stage, NumPagePerCTATile>(tAgV, tAsV, bmm2_k_tile, &sPageIdx(make_coord(0, V_t_in_stage), V_pi_slot), tma_bmm2_empty_barrier_phase_bit, tma_V_transaction_bytes, tma_atom_V, shared_storage.tmasoftmax_bmm2_full_barrier, shared_storage.tma_bmm2_empty_barrier);
|
||||
}
|
||||
|
||||
cutlass::arch::launch_dependent_grids();
|
||||
}
|
||||
|
||||
// mK has shape (Page_Size, dH, num_pages, kvH)
|
||||
// mQ has shape ((qHLocal, qL), dH, kvH, BS)
|
||||
// mV has shape (dH, Page_Size, num_pages, kvH)
|
||||
// mO has shape (dH, (qHLocal, qL), kvH, BS)
|
||||
// mSink has shape ((qHLocal, qL), kvH)
|
||||
// mSeqLens has shape (BS)
|
||||
// mPageTable has shape (kvL/Page_Size, BS), entry [p, bs] = physical page id of batch bs's per-batch page p
|
||||
// (BS is folded into num_pages on the K/V side; mPageTable supplies the (bs, kvL) -> page mapping.)
|
||||
template <
|
||||
class SharedStorage,
|
||||
class KTensor, class QTensor, class VTensor, class OTensor, class SinkTensor,
|
||||
class SeqLensTensor, class PageTableTensor,
|
||||
class TmaAtomK, class TmaAtomQ, class TmaAtomV,
|
||||
class TiledBMM1, class TiledBMM2,
|
||||
class TypeAcc,
|
||||
int CTA_qHLocal, int CTA_qL, int CTA_kvL, int CTA_dH,
|
||||
int Page_Size,
|
||||
int BMM1_DMA_Stage, int BMM2_DMA_Stage,
|
||||
int Page_Idx_Stage, int Num_Page_Idx_Per_Stage,
|
||||
int MaxSplits, int NumReductionCTA,
|
||||
bool NoSink>
|
||||
__maxnreg__(128)
|
||||
__global__ void
|
||||
gqa_paged_device(
|
||||
KTensor mK,
|
||||
QTensor mQ,
|
||||
VTensor mV,
|
||||
OTensor mO,
|
||||
SinkTensor mSink,
|
||||
SeqLensTensor mSeqLens,
|
||||
PageTableTensor mPageTable, // (kvL/Page_Size, BS); underlying allocation tail-padded by Num_Page_Idx_Per_Stage ints
|
||||
CUTE_GRID_CONSTANT TmaAtomK const tma_atom_K,
|
||||
CUTE_GRID_CONSTANT TmaAtomQ const tma_atom_Q,
|
||||
CUTE_GRID_CONSTANT TmaAtomV const tma_atom_V,
|
||||
TiledBMM1 tiled_bmm1,
|
||||
TiledBMM2 tiled_bmm2,
|
||||
float softmax_scale_log2,
|
||||
int sliding_window_size,
|
||||
int pdl_count
|
||||
) {
|
||||
// Allocate SMEM
|
||||
extern __shared__ char shared_memory[];
|
||||
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(shared_memory);
|
||||
|
||||
// WorkTileInfo, for non persistent static scheduler, cta id is the work tile info
|
||||
// since loading the seq_lens is on the critical path of the prolog, we want to start it as soon as possible
|
||||
// mK shape: (Page_Size, CTA_dH, Num_Page_Global, kvH) -- BS is folded into Num_Page_Global; kvH is mode 3.
|
||||
int kvH = shape<3>(mK);
|
||||
int BS_idx = blockIdx.x / kvH;
|
||||
// only thread 0 issues cp.async to load the seq_len -- keep this as the first thing on the critical path.
|
||||
if (threadIdx.x == 0) {
|
||||
cp_async(&mSeqLens(BS_idx), &shared_storage.seq_len);
|
||||
}
|
||||
|
||||
int warp_idx = cutlass::canonical_warp_idx_sync();
|
||||
|
||||
// barrier initialization, warp 0 does initialization
|
||||
if (warp_idx == 0) {
|
||||
// transaction barrier because tma arrive on it, 6 thread arrive: one for DMA_Q warp, one for DMA_KV (K fetch) warp, and 4 for softmax warp (tcgen05.ld Acc1 is done).
|
||||
// For paged K, DMA_KV still arrives ONCE per stage (via set_barrier_transaction_bytes once with total bytes for all pages).
|
||||
cutlass::arch::detail::initialize_barrier_array_aligned<cutlass::arch::ClusterTransactionBarrier, BMM1_DMA_Stage>(shared_storage.tma_bmm1_full_barrier, /* arrival count */ 6);
|
||||
// 1 thread (BMM1) arrive to signal DMA_Q and DMA_KV (K fetch) warp
|
||||
cutlass::arch::detail::initialize_barrier_array_aligned<cutlass::arch::ClusterBarrier, BMM1_DMA_Stage>(shared_storage.tma_bmm1_empty_barrier, /* arrival count */ 1);
|
||||
// transaction barrier because tma arrive on it, 5 thread arrive: one for DMA_KV (V fetch) warp and 4 for softmax warp (S/P store)
|
||||
cutlass::arch::detail::initialize_barrier_array_aligned<cutlass::arch::ClusterTransactionBarrier, BMM2_DMA_Stage>(shared_storage.tmasoftmax_bmm2_full_barrier, /* arrival count */ 5);
|
||||
// 1 thread (BMM2) arrive to signal DMA_KV (V fetch) and softmax warp (P store)
|
||||
cutlass::arch::detail::initialize_barrier_array_aligned<cutlass::arch::ClusterBarrier, BMM2_DMA_Stage>(shared_storage.tma_bmm2_empty_barrier, /* arrival count */ 1);
|
||||
// 1 thread (BMM1) arrive to signal softmax
|
||||
cutlass::arch::detail::initialize_barrier_array_aligned<cutlass::arch::ClusterBarrier, 1>(&shared_storage.bmm1_softmax_full_barrier, /* arrival count */ 1);
|
||||
// 1 thread (BMM2) arrive to signal epilog
|
||||
cutlass::arch::detail::initialize_barrier_array_aligned<cutlass::arch::ClusterBarrier, 1>(&shared_storage.bmm2_epilog_full_barrier, /* arrival count */ 1);
|
||||
// 32 (mma) + 128 (epilog) to signal tmem allocation/deallocation result
|
||||
cutlass::arch::detail::initialize_barrier_array_aligned<cutlass::arch::ClusterBarrier, 1>(&shared_storage.tmem_allocation_result_barrier, /* arrival count */ 32 + 128);
|
||||
// 1 thread (epilog) arrive to signal maxsum
|
||||
cutlass::arch::detail::initialize_barrier_array_aligned<cutlass::arch::ClusterTransactionBarrier, 1>(&shared_storage.maxsum_mailbox_full_barrier, /* arrival count */ 1);
|
||||
// 1 thread (epilog) arrive to signal acc2
|
||||
cutlass::arch::detail::initialize_barrier_array_aligned<cutlass::arch::ClusterTransactionBarrier, 1>(&shared_storage.acc2_mailbox_full_barrier, /* arrival count */ 1);
|
||||
// No page-idx full/empty barriers: DMA_KV fetches its own page indices via cp.async with thread-local
|
||||
// cp_async_fence/wait, so the cross-warp transaction-barrier handshake is gone.
|
||||
}
|
||||
// syncing all threads (128) within 4 epilog warps
|
||||
cutlass::arch::NamedBarrier epilog_barrier(128, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
|
||||
|
||||
// barrier initialization needs to be visible to all warps
|
||||
// defer it as late as possible to allow some thread divergence in prolog
|
||||
cutlass::arch::fence_barrier_init();
|
||||
#if 0
|
||||
// this will have a membar.gpu to ensure dsmem write visibility within the entire cluster, because there isn't a membar.cluster
|
||||
// membar.gpu is 0.2us
|
||||
cluster_sync();
|
||||
#else
|
||||
// the alternative is to use proper fences
|
||||
// at the ptx level, fence.mbarrier_init.release.cluster act as a release fence (in cluster scope) for mbarrier init op
|
||||
cutlass::arch::fence_barrier_init();
|
||||
// thread 0 waits for its previously issued cp.async to complete
|
||||
// here we overlap the cp.async with the barrier initialization as much as possible
|
||||
if (threadIdx.x == 0) {
|
||||
cp_async_fence();
|
||||
cp_async_wait<0>();
|
||||
}
|
||||
|
||||
// the cluster sync serves two purposes:
|
||||
// 1. it waits for all threads in the cluster to see the barrier initialization
|
||||
// 2. it waits for all threads in the cta to see the cp.async result in smem
|
||||
cluster_arrive_relaxed();
|
||||
cluster_wait();
|
||||
#endif
|
||||
|
||||
// bid.y includes rasterization of qHLocal and qL
|
||||
int qHLocal = shape<0>(shape<0>(mQ));
|
||||
int num_qHLocal = cutlass::ceil_div(qHLocal, CTA_qHLocal);
|
||||
// bid.z is the split id for kvL split, try to evenly distribute the kvL blocks to every CTA
|
||||
int rank = blockIdx.z;
|
||||
int seq_len = shared_storage.seq_len;
|
||||
// Sliding window optimization: when enabled, we only process tokens in range
|
||||
// [seq_len - sliding_window_size, seq_len). To simplify tile distribution,
|
||||
// we align the skip boundary to CTA_kvL tiles.
|
||||
int workload_seq_len = seq_len;
|
||||
int seq_len_skip_offset = 0;
|
||||
if (sliding_window_size > 0 && sliding_window_size < seq_len) {
|
||||
int unaligned_skip = seq_len - sliding_window_size;
|
||||
seq_len_skip_offset = (unaligned_skip / CTA_kvL) * CTA_kvL;
|
||||
workload_seq_len = seq_len - seq_len_skip_offset;
|
||||
}
|
||||
int NumKVTiles = cutlass::ceil_div(workload_seq_len, CTA_kvL);
|
||||
int NumSplits = cute::min(MaxSplits, NumKVTiles);
|
||||
int kvL_tile_count_per_cta = NumKVTiles / MaxSplits;
|
||||
int kvL_tile_count_remainder = NumKVTiles % MaxSplits;
|
||||
int kvL_tile_count = kvL_tile_count_per_cta + (rank < kvL_tile_count_remainder ? 1 : 0);
|
||||
int kvL_tile_count_skip_offset = seq_len_skip_offset / CTA_kvL;
|
||||
int kvL_tile_count_start = rank * kvL_tile_count_per_cta + (rank < kvL_tile_count_remainder ? rank : kvL_tile_count_remainder) + kvL_tile_count_skip_offset;
|
||||
int kvL_tile_count_end = kvL_tile_count_start + kvL_tile_count;
|
||||
WorkTileInfo work_tile_info {
|
||||
.BS_idx = (int32_t)BS_idx,
|
||||
.kvH_idx = (int32_t)(blockIdx.x % kvH),
|
||||
.kvL_idx_start = (int32_t)kvL_tile_count_start,
|
||||
.kvL_idx_end = (int32_t)kvL_tile_count_end,
|
||||
.dH_idx = 0, // no dH tiling for bmm2
|
||||
.qHLocal_idx = (int32_t)(blockIdx.y % num_qHLocal),
|
||||
.qL_idx = (int32_t)(blockIdx.y / num_qHLocal),
|
||||
.is_valid_tile = (kvL_tile_count > 0)
|
||||
};
|
||||
|
||||
if (warp_idx == 0) {
|
||||
gqa::DMA_Q_warp<SharedStorage, WorkTileInfo, decltype(mQ), TmaAtomQ, TiledBMM1, CTA_qHLocal, CTA_qL, CTA_dH>(shared_storage, work_tile_info, mQ, &tma_atom_Q, tiled_bmm1);
|
||||
}
|
||||
else if (warp_idx == 1) {
|
||||
DMA_KV_warp<SharedStorage, WorkTileInfo, decltype(mK), decltype(mV), decltype(mPageTable), TmaAtomK, TmaAtomV, CTA_kvL, CTA_dH, Page_Size, Page_Idx_Stage, Num_Page_Idx_Per_Stage>(shared_storage, work_tile_info, mK, mV, mPageTable, &tma_atom_K, &tma_atom_V);
|
||||
}
|
||||
else if (warp_idx == 2) {
|
||||
gqa::MMA_warp<SharedStorage, WorkTileInfo, decltype(mO), TiledBMM1, TiledBMM2, CTA_qHLocal, CTA_qL, CTA_kvL, CTA_dH>(shared_storage, work_tile_info, mO, tiled_bmm1, tiled_bmm2);
|
||||
}
|
||||
// warp_idx == 3 is unused (the previous Read_Page_Idx warp was folded into DMA_KV).
|
||||
else if (warp_idx >= 4) {
|
||||
// epilog tid is from 128 to 255, need to offset by -128 when getting the per thread slice
|
||||
int tid = threadIdx.x - 128;
|
||||
// EPILOG_warp only uses mK for shape(mK) (to build a kvL coord predicate). mK on the paged kernel has shape
|
||||
// (Page_Size, CTA_dH, Num_Page_Global, kvH) with BS folded into Num_Page_Global, so we build a coord-only
|
||||
// identity tensor with the original (kvL, dH, kvH, BS) shape. BS comes straight from mO (rank 4, mode 3).
|
||||
// kvL is recovered from mPageTable shape (kvL/Page_Size, BS); Page_Size is a compile-time constant so this
|
||||
// is a constant multiply (no integer divide), unlike deriving from shape<2>(mK)/BS which costs ~3%.
|
||||
int dH = shape<1>(mK);
|
||||
int BS = shape<3>(mO);
|
||||
int kvL = static_cast<int>(shape<0>(mPageTable)) * Page_Size;
|
||||
auto mK_coord = make_identity_tensor(make_shape(kvL, dH, kvH, BS));
|
||||
// warp_idx - 4 because epilog warp group starts from warp 4
|
||||
gqa::EPILOG_warp<SharedStorage, WorkTileInfo, decltype(mK_coord), decltype(mO), decltype(mSink), TiledBMM1, TiledBMM2, CTA_qHLocal, CTA_qL, CTA_kvL, CTA_dH, NumReductionCTA, NoSink>(shared_storage, work_tile_info, mK_coord, mO, mSink, seq_len, tiled_bmm1, tiled_bmm2, softmax_scale_log2, sliding_window_size, epilog_barrier, NumSplits, tid, warp_idx - 4, rank);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// KV has shape (num_pages_total, 2, Page_Size, kvH, dH); KV(_,0,...) is K, KV(_,1,...) is V; BS is folded into num_pages_total
|
||||
// Q has shape ((qHLocal, qL), dH, kvH, BS)
|
||||
// O has shape (dH, (qHLocal, qL), kvH, BS)
|
||||
// sinks has shape (qHLocal * kvH), i.e. one sink per q head, when device_ptr_sinks is nullptr, it's disabled
|
||||
// seq_lens has shape (BS); kvL is max_seq_len, seq_lens[bs] is the actual seq len for batch bs
|
||||
// page_table has shape (kvL/Page_Size, BS)
|
||||
// sliding_window_size is the size of the sliding window, when it's 0, it's disabled
|
||||
template<
|
||||
class TypeQKV, class TypeO, class TypeAcc,
|
||||
int CTA_qHLocal, int CTA_qL, int CTA_kvL, int CTA_dH,
|
||||
int Page_Size,
|
||||
int BMM1_DMA_Stage, int BMM2_DMA_Stage,
|
||||
int Page_Idx_Stage, int Num_Page_Idx_Per_Stage,
|
||||
int MaxSplits, int NumReductionCTA>
|
||||
void gqa_paged_host(
|
||||
TypeQKV* device_ptr_KV,
|
||||
TypeQKV* device_ptr_Q,
|
||||
TypeO* device_ptr_O,
|
||||
TypeAcc* device_ptr_sinks,
|
||||
int* seq_lens,
|
||||
int* device_ptr_page_table, // (kvL/Page_Size, BS); underlying allocation must be tail-padded by Num_Page_Idx_Per_Stage ints
|
||||
int kvH, int qHLocal, int qL, int kvL, int dH, int BS,
|
||||
int stride_KV_pages, int stride_KV_KV, int stride_KV_ps, int stride_KV_kvH, int stride_KV_dH,
|
||||
int stride_Q_kvH, int stride_Q_qHLocal, int stride_Q_qL, int stride_Q_dH, int stride_Q_BS,
|
||||
int stride_O_kvH, int stride_O_qHLocal, int stride_O_qL, int stride_O_dH, int stride_O_BS,
|
||||
int stride_PT_p, int stride_PT_BS,
|
||||
float softmax_scale,
|
||||
int sliding_window_size,
|
||||
bool pdl, int pdl_count = -1,
|
||||
cudaStream_t stream = 0
|
||||
) {
|
||||
assert(kvL % Page_Size == 0);
|
||||
int num_pages_total = BS * (kvL / Page_Size);
|
||||
|
||||
// Reconstruct the combined KV gmem tensor exactly as the harness laid it out: shape
|
||||
// (num_pages_total, 2, Page_Size, kvH, dH) with the 5 strides supplied by the caller.
|
||||
auto layout_KV = make_layout(
|
||||
make_shape(num_pages_total, Int<2>{}, Int<Page_Size>{}, kvH, dH),
|
||||
make_stride(stride_KV_pages, stride_KV_KV, stride_KV_ps, stride_KV_kvH, stride_KV_dH));
|
||||
Tensor mKV = make_tensor(make_gmem_ptr(device_ptr_KV), layout_KV); // (num_pages_total, 2, Page_Size, kvH, dH)
|
||||
|
||||
// Slice on the K/V mode (=1). Each slice has shape (num_pages_total, Page_Size, kvH, dH); the kernel expects the
|
||||
// modes in MMA order, so we permute via select<...>: K wants (Page_Size, dH, num_pages_total, kvH) -> indices
|
||||
// (1,3,0,2); V wants (dH, Page_Size, num_pages_total, kvH) -> indices (3,1,0,2).
|
||||
Tensor mKV_K = mKV(_, 0, _, _, _);
|
||||
Tensor mKV_V = mKV(_, 1, _, _, _);
|
||||
Tensor mK = make_tensor(mKV_K.data(), select<1, 3, 0, 2>(mKV_K.layout())); // (Page_Size, dH, num_pages_total, kvH)
|
||||
Tensor mV = make_tensor(mKV_V.data(), select<3, 1, 0, 2>(mKV_V.layout())); // (dH, Page_Size, num_pages_total, kvH)
|
||||
|
||||
Layout layout_Q = gqa::make_layout_Q(kvH, qHLocal, qL, dH, BS, stride_Q_kvH, stride_Q_qHLocal, stride_Q_qL, stride_Q_dH, stride_Q_BS);
|
||||
Layout layout_O = gqa::make_layout_O(kvH, qHLocal, qL, dH, BS, stride_O_kvH, stride_O_qHLocal, stride_O_qL, stride_O_dH, stride_O_BS);
|
||||
Layout layout_sinks = gqa::make_layout_sinks(qHLocal, qL, kvH);
|
||||
|
||||
// Page table tensor as the harness owns it: rank-2 shape (kvL/Page_Size, BS), strides supplied by the caller.
|
||||
// DMA_KV warp on the device side partitions mode-0 via logical_divide(_, Shape<NumPagePerCTATile>) into
|
||||
// ((NumPagePerCTATile, MaxNumKVTiles), BS) -- the (page-within-CTA-tile, CTA-tile-idx, batch) view it actually
|
||||
// indexes. Keeping the gmem-side layout flat per batch lets the harness pass any contiguous or page-strided table.
|
||||
// (DMA_KV uses 4-byte cp.async, so no cp.async.bulk-style 16B alignment requirements on bytes/base/stride.)
|
||||
auto layout_PageTable = make_layout(
|
||||
make_shape(kvL / Page_Size, BS),
|
||||
make_stride(stride_PT_p, stride_PT_BS));
|
||||
Tensor mPageTable = make_tensor(make_gmem_ptr(device_ptr_page_table), layout_PageTable); // (kvL/Page_Size, BS)
|
||||
|
||||
// how we handle oob:
|
||||
// oob for K, Q, V are handled by TMA
|
||||
// oob for O is explicitly handled by predicate in the epilog since it uses simple st.global epilog
|
||||
// we partition kvL with tile size of CTA_kvL, and we evenly distribute the kvL blocks to MaxSplits number of cta in the cluster
|
||||
assert(NumReductionCTA <= MaxSplits);
|
||||
static_assert(((CTA_qHLocal * CTA_qL) % NumReductionCTA) == 0, "each reduction cta must have even number of q tokens");
|
||||
|
||||
// mK and mV are constructed above by slicing+permuting the combined mKV tensor.
|
||||
Tensor mQ = make_tensor(make_gmem_ptr(device_ptr_Q), layout_Q); // ((qHLocal, qL), dH, kvH, BS)
|
||||
Tensor mO = make_tensor(make_gmem_ptr(device_ptr_O), layout_O); // (dH, (qHLocal, qL), kvH, BS)
|
||||
Tensor mSink = make_tensor(make_gmem_ptr(device_ptr_sinks), layout_sinks); // ((qHLocal, qL), kvH)
|
||||
Tensor mSeqLens = make_tensor(make_gmem_ptr(seq_lens), make_layout(make_shape(BS))); // (BS)
|
||||
|
||||
static_assert(CTA_kvL == 128, "BMM1's MMA_M needs to be 128 for tcgen05.ld->softmax");
|
||||
static_assert(((CTA_qHLocal * CTA_qL) % 8) == 0, "BMM1's MMA_N needs to be divisible by 8 for tcgen05.mma");
|
||||
assert(dH == CTA_dH); // bmm1 only has 1 kblock (i.e. 1 Q tile), bmm2 deal with all dH for now, in the foreseable future this is the hardest constraint to lift
|
||||
static_assert((CTA_dH == 128) || (CTA_dH == 64), "BMM2's MMA_M needs to be at either 128 or 64 for tcgen05.ld->correction");
|
||||
// we swap AB so bmm1 is K (CTA_kvL, CTA_dH) x Q (CTA_dH, CTA_qHLocal * CTA_qL)
|
||||
// both Q and K are dH (K in gemm terminology) major
|
||||
// M = CTA_kvL, N = CTA_qHLocal * CTA_qL, K = CTA_dH
|
||||
TiledMMA tiled_bmm1 = cutlass::gemm::collective::detail::sm100_make_1sm_trivial_tiled_mma<
|
||||
TypeQKV, TypeQKV, TypeAcc, // Mma's A, B, and Accumulator types
|
||||
Shape<Int<CTA_kvL>, Int<CTA_qHLocal * CTA_qL>, Int<CTA_dH>>, // TileShape_MNK
|
||||
Shape<_1, _1, _1>, // ClusterShape_MNK
|
||||
cute::UMMA::Major::K, cute::UMMA::Major::K>();
|
||||
|
||||
// we swap AB for bmm2 as well, V (dH, CTA_kvL) x P (CTA_kvL, CTA_qHLocal * CTA_qL)
|
||||
// V is dH (M in gemm terminology) major, P is CTA_kvL (K in gemm terminology) major in smem after each thread writes P from rmem to smem
|
||||
// M = CTA_dH, N = CTA_qHLocal * CTA_qL, K = CTA_kvL
|
||||
TiledMMA tiled_bmm2 = cutlass::gemm::collective::detail::sm100_make_1sm_trivial_tiled_mma<
|
||||
TypeQKV, TypeQKV, TypeAcc, // Mma's A, B, and Accumulator types
|
||||
Shape<Int<CTA_dH>, Int<CTA_qHLocal * CTA_qL>, Int<CTA_kvL>>, // TileShape_MNK
|
||||
Shape<_1, _1, _1>, // ClusterShape_MNK
|
||||
cute::UMMA::Major::MN, cute::UMMA::Major::K>();
|
||||
|
||||
// Pre-partitioned smem Tile Shape to post-partitioned smem tile shape ((Mma_M, Mma_K), NumMma_M, NumMma_K, DMA_Stage)
|
||||
auto shape_K = make_shape(Int<CTA_kvL>{}, Int<CTA_dH>{}, Int<BMM1_DMA_Stage>{});
|
||||
auto shape_Q = make_shape(make_shape(Int<CTA_qHLocal>{}, Int<CTA_qL>{}), Int<CTA_dH>{}, Int<1>{});
|
||||
auto shape_S = make_shape(Int<CTA_kvL>{}, make_shape(Int<CTA_qHLocal>{}, Int<CTA_qL>{}), Int<1>{});
|
||||
auto mma_shape_K = partition_shape_A(tiled_bmm1, shape_K);
|
||||
auto mma_shape_Q = partition_shape_B(tiled_bmm1, shape_Q);
|
||||
|
||||
auto shape_V = make_shape(Int<CTA_dH>{}, Int<CTA_kvL>{}, Int<BMM2_DMA_Stage>{});
|
||||
auto shape_P = select<1, 0, 2>(shape_S); // just a permutation of shape_S
|
||||
auto mma_shape_V = partition_shape_A(tiled_bmm2, shape_V);
|
||||
|
||||
// choose the swizzle atom for K, Q, S, V and P
|
||||
auto SmemLayoutAtomK = cutlass::gemm::collective::detail::sm100_smem_selector<
|
||||
cute::UMMA::Major::K, // gmem layout of K
|
||||
TypeQKV, // data type of K
|
||||
decltype(shape<0>(shape_K)), decltype(shape<1>(shape_K))>(); // tile size of K
|
||||
auto SmemLayoutAtomQ = cutlass::gemm::collective::detail::sm100_smem_selector<
|
||||
cute::UMMA::Major::K, // gmem layout of Q
|
||||
TypeQKV, // data type of Q
|
||||
decltype(shape<0>(shape_Q)), decltype(shape<1>(shape_Q))>(); // tile size of Q
|
||||
// for bmm1 tcgen05.ld, each register is holding a row of S (CTA_kvL is mapped to the thread dimension), if we do
|
||||
// st.shared from rmem to smem, to avoid bank conflict, we need to put T0V0, T1V0, T2V0, ... T31V0 contiguously in smem.
|
||||
// then the smem layout of S is M (CTA_kvL) major, so we choose MN major swizzle atom
|
||||
auto SmemLayoutAtomS = UMMA::Layout_MN_SW128_Atom<TypeQKV>{};
|
||||
|
||||
auto SmemLayoutAtomV = cutlass::gemm::collective::detail::sm100_smem_selector<
|
||||
cute::UMMA::Major::MN, // gmem layout of V
|
||||
TypeQKV, // data type of V
|
||||
decltype(shape<0>(shape_V)), decltype(shape<1>(shape_V))>(); // tile size of V
|
||||
// swizzle atom for P should be the transpose of the swizzle atom for S, because they literally represent the same tensor just with different dimension order (aka a pytorch transpose)
|
||||
auto SmemLayoutAtomP = UMMA::Layout_K_SW128_Atom<TypeQKV>{};
|
||||
|
||||
// finally construct the smem layout for tile K, Q, V, S, and P
|
||||
auto sK_layout = UMMA::tile_to_mma_shape(SmemLayoutAtomK, mma_shape_K); // ((Mma_M, Mma_K), NumMma_M, NumMma_K, BMM1_DMA_Stage)
|
||||
// paged K smem layout: same memory as sK_layout, viewed as (Page_Size, CTA_dH, NumPagePerCTATile, BMM1_DMA_Stage).
|
||||
// tile_to_shape uses Step<_1, _3, _2, _4> to stack SmemLayoutAtomK the same way as sK_layout,
|
||||
// i.e. first along CTA_kvL/(Page_Size, NumPagePerCTATile), then along CTA_dH, finally along BMM1_DMA_Stage
|
||||
static_assert(CTA_kvL % Page_Size == 0, "Page_Size must divide CTA_kvL");
|
||||
constexpr int NumPagePerCTATile = CTA_kvL / Page_Size;
|
||||
auto sK_paged_layout = tile_to_shape(SmemLayoutAtomK,
|
||||
make_shape(Int<Page_Size>{}, Int<CTA_dH>{}, Int<NumPagePerCTATile>{}, Int<BMM1_DMA_Stage>{}),
|
||||
Step<_1, _3, _2, _4>{});
|
||||
|
||||
auto sQ_layout = UMMA::tile_to_mma_shape(SmemLayoutAtomQ, mma_shape_Q); // ((Mma_N, Mma_K), NumMma_N, NumMma_K, 1)
|
||||
auto sV_layout = UMMA::tile_to_mma_shape(SmemLayoutAtomV, mma_shape_V); // ((Mma_M, Mma_K), NumMma_M, NumMma_K, BMM2_DMA_Stage)
|
||||
// paged V smem layout: same memory as sV_layout, viewed as (CTA_dH, Page_Size, NumPagePerCTATile, BMM2_DMA_Stage).
|
||||
// V is MN-major: atom is (atom_M=CTA_dH, atom_K) -> stacking inside the atom covers all of CTA_dH; multiple atoms tile
|
||||
// along K (kvL). Splitting kvL into Page_Size + NumPage just renames the K-iter modes; default LayoutLeft step works
|
||||
// because mode order (M -> K_inner -> K_outer -> Stage) matches the natural sV_layout stride sequence.
|
||||
auto sV_paged_layout = tile_to_shape(SmemLayoutAtomV,
|
||||
make_shape(Int<CTA_dH>{}, Int<Page_Size>{}, Int<NumPagePerCTATile>{}, Int<BMM2_DMA_Stage>{}));
|
||||
|
||||
// The paged and MMA-partitioned views of K/V must alias the same smem buffer -> cosize must match.
|
||||
static_assert(cute::cosize_v<decltype(sK_paged_layout)> == cute::cosize_v<decltype(sK_layout)>,
|
||||
"sK_paged_layout and sK_layout must alias the same smem buffer (cosize must match)");
|
||||
static_assert(cute::cosize_v<decltype(sV_paged_layout)> == cute::cosize_v<decltype(sV_layout)>,
|
||||
"sV_paged_layout and sV_layout must alias the same smem buffer (cosize must match)");
|
||||
// S and P use tile_to_shape as we do the mma partition in the kernel later
|
||||
auto sS_layout = tile_to_shape(SmemLayoutAtomS, shape_S, Step<_1, _2, _3>{}); // (CTA_kvL, (CTA_qHLocal, CTA_qL), 1)
|
||||
auto sP_layout = tile_to_shape(SmemLayoutAtomP, shape_P, Step<_2, _1, _3>{}); // ((CTA_qHLocal, CTA_qL), CTA_kvL, 1)
|
||||
auto sAcc1_layout = make_layout(shape_S); // (CTA_kvL, (CTA_qHLocal, CTA_qL), 1)
|
||||
// for storing fmax and fsum warp reduce partial results
|
||||
int constexpr NumEpilogWarps = 4;
|
||||
// NumEpilogWarps contiguous because we often ld.shared all NumEpilogWarps from 1/32 threads, this has best vectorization
|
||||
auto sWarpReduce_layout = make_layout(make_shape(Int<NumEpilogWarps>{}, make_shape(Int<CTA_qHLocal>{}, Int<CTA_qL>{}))); // (NumEpilogWarps, (CTA_qHLocal, CTA_qL))
|
||||
// MaxSplits contiguous because we often ld.shared all MaxSplits from 1/32 threads, this has best vectorization
|
||||
auto sMSMailbox_Layout = make_layout(make_shape(Int<MaxSplits>{}, Int<CTA_qHLocal * CTA_qL / NumReductionCTA>{})); // (MaxSplits, CTA_qHLocal * CTA_qL / NumReductionCTA)
|
||||
// default layout is CTA_dH contiguous to maximize st.async/ld.shared bw
|
||||
auto sAcc2Mailbox_layout = make_layout(make_shape(Int<CTA_dH>{}, Int<CTA_qHLocal * CTA_qL / NumReductionCTA>{}, Int<MaxSplits>{})); // (CTA_dH, CTA_qHLocal * CTA_qL / NumReductionCTA, MaxSplits)
|
||||
auto sSinks_layout = make_layout(Int<CTA_qHLocal * CTA_qL / NumReductionCTA>{}); // (CTA_qHLocal * CTA_qL / NumReductionCTA)
|
||||
// DMA_KV's page-idx staging buffer. Page_Idx_Stage and Num_Page_Idx_Per_Stage are configured by the host
|
||||
// harness (template params). Single static_assert block lives here so warp functions don't repeat them.
|
||||
static_assert(Num_Page_Idx_Per_Stage % NumPagePerCTATile == 0,
|
||||
"Num_Page_Idx_Per_Stage must be a multiple of CTA_kvL/Page_Size so a DMA stage's pages live in one pi stage");
|
||||
// Page_Idx_Stage must be exactly 2: K leads V by 1 tile, so around a pi-stage boundary K reads slot S%2
|
||||
// (stage S) while V finishes the last tile of slot (S-1)%2 (stage S-1). =1 would alias these slots; >2
|
||||
// wastes smem because at most 2 pi-stage groups are live at once (the one being consumed + the pre-issued).
|
||||
static_assert(Page_Idx_Stage == 2, "Page_Idx_Stage must be 2 for the K-leads-V-by-1 page-idx pipeline");
|
||||
constexpr int Tiles_Per_Pi_Stage_Host = Num_Page_Idx_Per_Stage / NumPagePerCTATile;
|
||||
// DMA_KV's folded page-idx pipeline issues stage S's cp.async at K_t_in_stage==1 of stage S-1, so each
|
||||
// pi-stage must hold at least 2 CTA tiles. Choose Num_Page_Idx_Per_Stage >= 2 * NumPagePerCTATile.
|
||||
static_assert(Tiles_Per_Pi_Stage_Host >= 2,
|
||||
"Tiles_Per_Pi_Stage (= Num_Page_Idx_Per_Stage / NumPagePerCTATile) must be >= 2 for the folded page-idx pipeline");
|
||||
// Hierarchical layout ((NumPagePerCTATile, Tiles_Per_Pi_Stage), Page_Idx_Stage), default LayoutLeft so
|
||||
// mode 0 is fully contiguous (NumPagePerCTATile innermost). Letting cute carry the (p, t) split keeps the
|
||||
// producer's inner write as sPageIdx(make_coord(p, t), stage_idx) without manual i/N + i%N arithmetic, and
|
||||
// gives the consumer a contiguous NumPagePerCTATile-int slice via &sPageIdx(make_coord(0, t), stage_idx).
|
||||
auto sPageIdx_layout = make_layout(make_shape(make_shape(Int<NumPagePerCTATile>{}, Int<Tiles_Per_Pi_Stage_Host>{}), Int<Page_Idx_Stage>{}));
|
||||
|
||||
// Now we can find the SMEM allocation size
|
||||
using SMEMStorage = SharedStorage<TypeQKV, TypeAcc,
|
||||
decltype(sK_layout), decltype(sK_paged_layout),
|
||||
decltype(sQ_layout),
|
||||
decltype(sV_layout), decltype(sV_paged_layout),
|
||||
decltype(sS_layout), decltype(sP_layout),
|
||||
decltype(sWarpReduce_layout), decltype(sMSMailbox_Layout), decltype(sAcc1_layout),
|
||||
decltype(sAcc2Mailbox_layout), decltype(sSinks_layout),
|
||||
decltype(sPageIdx_layout),
|
||||
BMM1_DMA_Stage, BMM2_DMA_Stage, Page_Idx_Stage>;
|
||||
|
||||
static_assert(BMM1_DMA_Stage >= BMM2_DMA_Stage, "otherwise you are wasting BMM2 stage because BMM1 TMA issue will block BMM2 TMA due to insufficient BMM1 stages");
|
||||
|
||||
// create TMA descriptors for K, Q, V matrices
|
||||
// K TMA box is (Page_Size, CTA_dH) -- one page per TMA copy. The per-page SMEM destination layout points into the
|
||||
// same memory as a single page slot inside sK_layout/sK_paged_layout.
|
||||
Copy_Atom tma_atom_K = make_tma_atom(
|
||||
SM90_TMA_LOAD{}, // TMA Load Op, sm100 reuses sm90 tma atom
|
||||
mK, // Source GMEM tensor
|
||||
take<0, 2>(sK_paged_layout), // Destination SMEM layout for 1 page = 1 TMA box, (Page_Size, CTA_dH)
|
||||
make_shape(Int<Page_Size>{}, Int<CTA_dH>{}) // TMA box shape
|
||||
);
|
||||
Tensor mK_tma = tma_atom_K.get_tma_tensor(shape(mK)); // (Page_Size, dH, num_pages_total, kvH)
|
||||
|
||||
Copy_Atom tma_atom_Q = make_tma_atom(
|
||||
SM90_TMA_LOAD{}, // TMA Load Op, sm100 reuses sm90 tma atom
|
||||
mQ, // Source GMEM tensor
|
||||
// sQ_layout(_,_,_,Int<0>{}) doesn't work under some corner cases (composedlayout indexing), so we use
|
||||
// the take method which is also correct.
|
||||
take<0, 3>(sQ_layout), // Destination SMEM layout for 1 DMA_Stage, ((Mma_N, Mma_K), NumMma_N, NumMma_K)
|
||||
make_shape(get<0>(shape_Q), get<1>(shape_Q)) // TMA box shape
|
||||
);
|
||||
Tensor mQ_tma = tma_atom_Q.get_tma_tensor(shape(mQ)); // ((qHLocal, qL), dH, kvH, BS)
|
||||
|
||||
// V TMA box is (CTA_dH, Page_Size) -- one page per TMA copy. Per-page SMEM destination layout points into the
|
||||
// same memory as a single page slot inside sV_layout/sV_paged_layout.
|
||||
Copy_Atom tma_atom_V = make_tma_atom(
|
||||
SM90_TMA_LOAD{}, // TMA Load Op, sm100 reuses sm90 tma atom
|
||||
mV, // Source GMEM tensor
|
||||
take<0, 2>(sV_paged_layout), // Destination SMEM layout for 1 page = 1 TMA box, (CTA_dH, Page_Size)
|
||||
make_shape(Int<CTA_dH>{}, Int<Page_Size>{}) // TMA box shape
|
||||
);
|
||||
Tensor mV_tma = tma_atom_V.get_tma_tensor(shape(mV)); // (dH, Page_Size, num_pages_total, kvH)
|
||||
|
||||
int smemBytes = sizeof(SMEMStorage);
|
||||
|
||||
// invoke the kernel
|
||||
cudaLaunchConfig_t config;
|
||||
cudaLaunchAttribute attrs[2];
|
||||
// bid.x: kvH * BS, bid.y: qHLocal * qL, bid.z: kvL
|
||||
uint32_t Cluster_Size = cute::max(MaxSplits, NumReductionCTA);
|
||||
config.gridDim = dim3{
|
||||
(uint32_t)kvH * BS,
|
||||
(uint32_t)cutlass::ceil_div(qHLocal, CTA_qHLocal) * cutlass::ceil_div(qL, CTA_qL),
|
||||
Cluster_Size};
|
||||
config.blockDim = 256; // 8 warps
|
||||
config.dynamicSmemBytes = smemBytes;
|
||||
config.stream = stream;
|
||||
attrs[0].id = cudaLaunchAttributeClusterDimension;
|
||||
attrs[0].val.clusterDim = {1, 1, Cluster_Size};
|
||||
attrs[1].id = cudaLaunchAttributeProgrammaticStreamSerialization;
|
||||
attrs[1].val.programmaticStreamSerializationAllowed = 1;
|
||||
config.attrs = attrs;
|
||||
config.numAttrs = pdl ? 2 : 1;
|
||||
|
||||
if (device_ptr_sinks != nullptr) {
|
||||
auto *kernel_instance =
|
||||
&gqa_paged_device<SMEMStorage,
|
||||
decltype(mK_tma), decltype(mQ_tma), decltype(mV_tma), decltype(mO), decltype(mSink),
|
||||
decltype(mSeqLens), decltype(mPageTable),
|
||||
decltype(tma_atom_K), decltype(tma_atom_Q), decltype(tma_atom_V),
|
||||
decltype(tiled_bmm1), decltype(tiled_bmm2),
|
||||
TypeAcc,
|
||||
CTA_qHLocal, CTA_qL, CTA_kvL, CTA_dH,
|
||||
Page_Size,
|
||||
BMM1_DMA_Stage, BMM2_DMA_Stage,
|
||||
Page_Idx_Stage, Num_Page_Idx_Per_Stage,
|
||||
MaxSplits, NumReductionCTA,
|
||||
false>;
|
||||
gpuErrChk(cudaFuncSetAttribute(*kernel_instance, cudaFuncAttributeMaxDynamicSharedMemorySize, smemBytes));
|
||||
// portable max cluster size is 8, but sm100a supports 16, need explicit opt in
|
||||
gpuErrChk(cudaFuncSetAttribute(*kernel_instance, cudaFuncAttributeNonPortableClusterSizeAllowed, 1));
|
||||
gpuErrChk(cudaLaunchKernelEx(&config, kernel_instance, mK_tma, mQ_tma, mV_tma, mO, mSink,
|
||||
mSeqLens, mPageTable,
|
||||
tma_atom_K, tma_atom_Q, tma_atom_V,
|
||||
tiled_bmm1, tiled_bmm2,
|
||||
softmax_scale * Log2_E, sliding_window_size, pdl_count));
|
||||
}
|
||||
else {
|
||||
auto *kernel_instance =
|
||||
&gqa_paged_device<SMEMStorage,
|
||||
decltype(mK_tma), decltype(mQ_tma), decltype(mV_tma), decltype(mO), decltype(mSink),
|
||||
decltype(mSeqLens), decltype(mPageTable),
|
||||
decltype(tma_atom_K), decltype(tma_atom_Q), decltype(tma_atom_V),
|
||||
decltype(tiled_bmm1), decltype(tiled_bmm2),
|
||||
TypeAcc,
|
||||
CTA_qHLocal, CTA_qL, CTA_kvL, CTA_dH,
|
||||
Page_Size,
|
||||
BMM1_DMA_Stage, BMM2_DMA_Stage,
|
||||
Page_Idx_Stage, Num_Page_Idx_Per_Stage,
|
||||
MaxSplits, NumReductionCTA,
|
||||
true>;
|
||||
gpuErrChk(cudaFuncSetAttribute(*kernel_instance, cudaFuncAttributeMaxDynamicSharedMemorySize, smemBytes));
|
||||
// portable max cluster size is 8, but sm100a supports 16, need explicit opt in
|
||||
gpuErrChk(cudaFuncSetAttribute(*kernel_instance, cudaFuncAttributeNonPortableClusterSizeAllowed, 1));
|
||||
gpuErrChk(cudaLaunchKernelEx(&config, kernel_instance, mK_tma, mQ_tma, mV_tma, mO, mSink,
|
||||
mSeqLens, mPageTable,
|
||||
tma_atom_K, tma_atom_Q, tma_atom_V,
|
||||
tiled_bmm1, tiled_bmm2,
|
||||
softmax_scale * Log2_E, sliding_window_size, pdl_count));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace gqa_paged
|
||||
} // namespace TGV
|
||||
@@ -1207,6 +1207,10 @@ struct SM100_MMA_S8_2x1SM_SS_SPARSE
|
||||
}
|
||||
};
|
||||
|
||||
template <class a_type, class b_type, class c_type, int M, int N,
|
||||
UMMA::Major a_major, UMMA::Major b_major,
|
||||
UMMA::ScaleIn a_neg = UMMA::ScaleIn::One,
|
||||
UMMA::ScaleIn b_neg = UMMA::ScaleIn::One>
|
||||
struct SM100_MMA_F8F6F4_SS
|
||||
{
|
||||
using DRegisters = void;
|
||||
@@ -1452,6 +1456,10 @@ struct SM100_MMA_MXF8F6F4_SS_SPARSE
|
||||
}
|
||||
};
|
||||
|
||||
template <class a_type, class b_type, class c_type, int M, int N,
|
||||
UMMA::Major a_major, UMMA::Major b_major,
|
||||
UMMA::ScaleIn a_neg = UMMA::ScaleIn::One,
|
||||
UMMA::ScaleIn b_neg = UMMA::ScaleIn::One>
|
||||
struct SM100_MMA_F8F6F4_2x1SM_SS
|
||||
{
|
||||
using DRegisters = void;
|
||||
|
||||
@@ -3327,12 +3327,9 @@ struct MMA_Traits<SM100_MMA_S8_2x1SM_SS_SPARSE<a_type, b_type, c_type,
|
||||
template <class a_type, class b_type, class c_type,
|
||||
int M, int N, UMMA::Major a_major, UMMA::Major b_major,
|
||||
UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg>
|
||||
struct MMA_Traits<SM100_MMA_F8F6F4_SS, a_type, b_type, c_type,
|
||||
cute::C<M>, cute::C<N>,
|
||||
cute::integral_constant<UMMA::Major, a_major>,
|
||||
cute::integral_constant<UMMA::Major, b_major>,
|
||||
cute::integral_constant<UMMA::ScaleIn, a_neg>,
|
||||
cute::integral_constant<UMMA::ScaleIn, b_neg>>
|
||||
struct MMA_Traits<SM100_MMA_F8F6F4_SS<a_type, b_type, c_type,
|
||||
M, N, a_major, b_major,
|
||||
a_neg, b_neg>>
|
||||
{
|
||||
using ValTypeD = c_type;
|
||||
using ValTypeA = a_type;
|
||||
@@ -3390,7 +3387,9 @@ struct MMA_Traits<SM100_MMA_F8F6F4_SS, a_type, b_type, c_type,
|
||||
uint32_t tmem_c = raw_pointer_cast(D.data());
|
||||
uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_);
|
||||
|
||||
SM100_MMA_F8F6F4_SS::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc);
|
||||
SM100_MMA_F8F6F4_SS<a_type, b_type, c_type,
|
||||
M, N, a_major, b_major,
|
||||
a_neg, b_neg>::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -3745,12 +3744,9 @@ struct MMA_Traits<SM100_MMA_F8F6F4_SS_SPARSE<a_type, b_type, c_type,
|
||||
template <class a_type, class b_type, class c_type,
|
||||
int M, int N, UMMA::Major a_major, UMMA::Major b_major,
|
||||
UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg>
|
||||
struct MMA_Traits<SM100_MMA_F8F6F4_2x1SM_SS, a_type, b_type, c_type,
|
||||
cute::C<M>, cute::C<N>,
|
||||
cute::integral_constant<UMMA::Major, a_major>,
|
||||
cute::integral_constant<UMMA::Major, b_major>,
|
||||
cute::integral_constant<UMMA::ScaleIn, a_neg>,
|
||||
cute::integral_constant<UMMA::ScaleIn, b_neg>>
|
||||
struct MMA_Traits<SM100_MMA_F8F6F4_2x1SM_SS<a_type, b_type, c_type,
|
||||
M, N, a_major, b_major,
|
||||
a_neg, b_neg>>
|
||||
{
|
||||
using ValTypeD = c_type;
|
||||
using ValTypeA = a_type;
|
||||
@@ -3808,7 +3804,9 @@ struct MMA_Traits<SM100_MMA_F8F6F4_2x1SM_SS, a_type, b_type, c_type,
|
||||
uint32_t tmem_c = raw_pointer_cast(D.data());
|
||||
uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_);
|
||||
|
||||
SM100_MMA_F8F6F4_2x1SM_SS::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc);
|
||||
SM100_MMA_F8F6F4_2x1SM_SS<a_type, b_type, c_type,
|
||||
M, N, a_major, b_major,
|
||||
a_neg, b_neg>::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -57,6 +57,14 @@
|
||||
# endif // (__CUDA_ARCH__ >= 900)
|
||||
#endif // defined(__CUDA_ARCH__)
|
||||
|
||||
#if defined(__CUDA_ARCH__)
|
||||
# if (__CUDA_ARCH__ >= 1000)
|
||||
# if (__CUDACC_VER_MAJOR__ > 13) || ((__CUDACC_VER_MAJOR__ >= 13) && (__CUDACC_VER_MINOR__ >= 2))
|
||||
# define CUDA_PTX_FP8_BF16_CVT_ENABLED 1
|
||||
# endif // (__CUDACC_VER_MAJOR__ > 13) || ((__CUDACC_VER_MAJOR__ >= 13) && (__CUDACC_VER_MINOR__ >= 2))
|
||||
# endif // (__CUDA_ARCH__ >= 1000)
|
||||
#endif // defined(__CUDA_ARCH__)
|
||||
|
||||
|
||||
#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) ||\
|
||||
defined(CUTLASS_ARCH_MMA_SM103A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM110A_ENABLED) ||\
|
||||
|
||||
@@ -339,18 +339,16 @@ sm100_make_1sm_trivial_tiled_mma() {
|
||||
) {
|
||||
|
||||
return make_tiled_mma(
|
||||
cute::MMA_Traits<
|
||||
cute::SM100_MMA_F8F6F4_SS,
|
||||
cute::SM100_MMA_F8F6F4_SS<
|
||||
ElementAMma,
|
||||
ElementBMma,
|
||||
ElementAMmaccumulator,
|
||||
cute::C<M>,
|
||||
cute::C<N>,
|
||||
cute::integral_constant<UMMA::Major, UmmaMajorA>,
|
||||
cute::integral_constant<UMMA::Major, UmmaMajorB>,
|
||||
cute::integral_constant<UMMA::ScaleIn, ANeg>,
|
||||
cute::integral_constant<UMMA::ScaleIn, BNeg>
|
||||
>{}
|
||||
M,
|
||||
N,
|
||||
UmmaMajorA,
|
||||
UmmaMajorB,
|
||||
ANeg,
|
||||
BNeg>{}
|
||||
);
|
||||
}
|
||||
else {
|
||||
@@ -407,18 +405,16 @@ sm100_make_2sm_trivial_tiled_mma() {
|
||||
) {
|
||||
|
||||
return make_tiled_mma(
|
||||
cute::MMA_Traits<
|
||||
cute::SM100_MMA_F8F6F4_2x1SM_SS,
|
||||
cute::SM100_MMA_F8F6F4_2x1SM_SS<
|
||||
ElementAMma,
|
||||
ElementBMma,
|
||||
ElementAMmaccumulator,
|
||||
cute::C<M>,
|
||||
cute::C<N>,
|
||||
cute::integral_constant<UMMA::Major, UmmaMajorA>,
|
||||
cute::integral_constant<UMMA::Major, UmmaMajorB>,
|
||||
cute::integral_constant<UMMA::ScaleIn, ANeg>,
|
||||
cute::integral_constant<UMMA::ScaleIn, BNeg>
|
||||
>{}
|
||||
M,
|
||||
N,
|
||||
UmmaMajorA,
|
||||
UmmaMajorB,
|
||||
ANeg,
|
||||
BNeg>{}
|
||||
);
|
||||
|
||||
}
|
||||
@@ -739,17 +735,16 @@ sm100_make_trivial_mixed_input_tiled_mma() {
|
||||
}
|
||||
if constexpr (cute::is_same_v<ElementBMma, cutlass::float_e4m3_t>) {
|
||||
return make_tiled_mma(
|
||||
cute::MMA_Traits<
|
||||
cute::SM100_MMA_F8F6F4_SS,
|
||||
cute::SM100_MMA_F8F6F4_SS<
|
||||
ElementAMma,
|
||||
ElementBMma,
|
||||
ElementAccumulator,
|
||||
cute::C<M>,
|
||||
cute::C<N>,
|
||||
cute::integral_constant<UMMA::Major, UmmaMajorA>,
|
||||
cute::integral_constant<UMMA::Major, UmmaMajorB>,
|
||||
cute::integral_constant<UMMA::ScaleIn, cute::UMMA::ScaleIn::One>,
|
||||
cute::integral_constant<UMMA::ScaleIn, cute::UMMA::ScaleIn::One>>{});
|
||||
M,
|
||||
N,
|
||||
UmmaMajorA,
|
||||
UmmaMajorB,
|
||||
cute::UMMA::ScaleIn::One,
|
||||
cute::UMMA::ScaleIn::One>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1820,14 +1820,21 @@ struct NumericArrayConverter<cutlass::bfloat16_t, cutlass::float_e4m3_t, 2, Roun
|
||||
#if defined(CUDA_PTX_FP8_CVT_ENABLED)
|
||||
uint32_t res_half;
|
||||
uint16_t const& src_packed = reinterpret_cast<uint16_t const&>(source);
|
||||
|
||||
asm volatile( \
|
||||
"{\n" \
|
||||
"cvt.rn.f16x2.e4m3x2 %0, %1;\n" \
|
||||
"}\n" : "=r"(res_half): "h"(src_packed));
|
||||
float2 res_float = __half22float2(reinterpret_cast<__half2 &>(res_half));
|
||||
NumericArrayConverter<cutlass::bfloat16_t, float, 2, Round> converter;
|
||||
return converter(reinterpret_cast<Array<float, 2> const&>(res_float));
|
||||
#if defined(CUDA_PTX_FP8_BF16_CVT_ENABLED)
|
||||
asm volatile( \
|
||||
"{\n" \
|
||||
"cvt.rn.bf16x2.e4m3x2 %0, %1;\n" \
|
||||
"}\n" : "=r"(res_half): "h"(src_packed));
|
||||
return reinterpret_cast<result_type const &>(res_half);
|
||||
#else
|
||||
asm volatile( \
|
||||
"{\n" \
|
||||
"cvt.rn.f16x2.e4m3x2 %0, %1;\n" \
|
||||
"}\n" : "=r"(res_half): "h"(src_packed));
|
||||
float2 res_float = __half22float2(reinterpret_cast<__half2 &>(res_half));
|
||||
NumericArrayConverter<cutlass::bfloat16_t, float, 2, Round> converter;
|
||||
return converter(reinterpret_cast<Array<float, 2> const&>(res_float));
|
||||
#endif
|
||||
#else
|
||||
result_type result;
|
||||
NumericConverter<result_element, source_element, Round> converter;
|
||||
@@ -2961,19 +2968,31 @@ struct NumericArrayConverterPacked4Element<cutlass::bfloat16_t, cutlass::float_e
|
||||
static result_type convert(source_type const & source) {
|
||||
|
||||
#if defined(CUDA_PTX_FP8_CVT_ENABLED)
|
||||
// Convert f8 to float
|
||||
NumericArrayConverterPacked4Element<float, source_element, Round> src2float;
|
||||
Array<float, 4> tmp_floats = src2float(source);
|
||||
|
||||
// Convert float to bf16
|
||||
result_type out;
|
||||
Array<float, 2>* packed_tmp = reinterpret_cast<Array<float, 2>*>(&tmp_floats);
|
||||
Array<result_element, 2>* packed_out = reinterpret_cast<Array<result_element, 2>*>(&out);
|
||||
NumericArrayConverter<result_element, float, 2, Round> float2result;
|
||||
packed_out[0] = float2result(packed_tmp[0]);
|
||||
packed_out[1] = float2result(packed_tmp[1]);
|
||||
#if defined(CUDA_PTX_FP8_BF16_CVT_ENABLED)
|
||||
uint32_t const& src_packed = reinterpret_cast<uint32_t const&>(source);
|
||||
Array<uint32_t, 2>& out_packed = reinterpret_cast<Array<uint32_t, 2>&>(out);
|
||||
asm volatile("{\n"
|
||||
".reg .b16 b0, b1;\n"
|
||||
"mov.b32 {b0, b1}, %2;\n"
|
||||
"cvt.rn.bf16x2.e4m3x2 %0, b0;\n"
|
||||
"cvt.rn.bf16x2.e4m3x2 %1, b1;\n"
|
||||
"}\n"
|
||||
: "=r"(out_packed[0]), "=r"(out_packed[1])
|
||||
: "r"(src_packed));
|
||||
#else
|
||||
// Convert f8 to float
|
||||
NumericArrayConverterPacked4Element<float, source_element, Round> src2float;
|
||||
Array<float, 4> tmp_floats = src2float(source);
|
||||
|
||||
return out;
|
||||
// Convert float to bf16
|
||||
Array<float, 2>* packed_tmp = reinterpret_cast<Array<float, 2>*>(&tmp_floats);
|
||||
Array<result_element, 2>* packed_out = reinterpret_cast<Array<result_element, 2>*>(&out);
|
||||
NumericArrayConverter<result_element, float, 2, Round> float2result;
|
||||
packed_out[0] = float2result(packed_tmp[0]);
|
||||
packed_out[1] = float2result(packed_tmp[1]);
|
||||
#endif
|
||||
return out;
|
||||
#else
|
||||
result_type result;
|
||||
NumericConverter<result_element, source_element, Round> converter;
|
||||
|
||||
@@ -36,7 +36,7 @@
|
||||
|
||||
#define CUTLASS_MAJOR 4
|
||||
#define CUTLASS_MINOR 5
|
||||
#define CUTLASS_PATCH 0
|
||||
#define CUTLASS_PATCH 1
|
||||
|
||||
#ifdef CUTLASS_VERSIONS_GENERATED
|
||||
#include "cutlass/version_extended.h"
|
||||
|
||||
@@ -23,3 +23,5 @@ CuTe DSL
|
||||
Compile with TVM FFI <cute_dsl_general/compile_with_tvm_ffi.rst>
|
||||
Ahead-of-Time (AOT) Compilation <cute_dsl_general/dsl_ahead_of_time_compilation.rst>
|
||||
Talks and Presentations <cute_dsl_general/resources.rst>
|
||||
Naming Conventions <cute_dsl_general/naming_conventions.rst>
|
||||
MMA Programming Guides <mma_docs/intro.rst>
|
||||
|
||||
@@ -83,6 +83,9 @@ an elementwise lambda function can be passed in as the ``epilogue_op`` argument.
|
||||
|
||||
Refer to the `Blackwell dense GEMM example <https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py>`__ for a complete example.
|
||||
|
||||
.. note::
|
||||
For the per-thread/partition naming convention used above (``tTR_rAcc``, ``tTR_rC``, and related tokens such as ``tAgA``, ``bSG_sC``, ``tQgQ_qdl``, …), see the :ref:`cute_dsl_naming_conventions`.
|
||||
|
||||
Type safety
|
||||
-----------
|
||||
|
||||
|
||||
253
media/docs/pythonDSL/cute_dsl_general/naming_conventions.rst
Normal file
253
media/docs/pythonDSL/cute_dsl_general/naming_conventions.rst
Normal file
@@ -0,0 +1,253 @@
|
||||
.. _cute_dsl_naming_conventions:
|
||||
|
||||
CuTe DSL Naming Conventions
|
||||
===========================
|
||||
|
||||
This page summarizes the Hungarian-style naming conventions used for identifiers across the DSL examples and epilogue helpers: tensor partitions, per-thread copy-partitioners, copy atoms, and the axis-order suffixes that encode tensor layouts. It is meant as a lookup reference while reading example code — not as a style rule enforced on new code.
|
||||
|
||||
Memory/space scopes
|
||||
-------------------
|
||||
|
||||
- ``g``: Global memory view (GMEM), e.g., ``gB_nkl``, ``tTR_gC``
|
||||
- ``s``: Shared memory view (SMEM), e.g., ``sA``, ``tRS_sC``, ``bSG_sC``
|
||||
- ``r``: Register view (RMEM), e.g., ``tTR_rAcc``, ``tRS_rC``
|
||||
- ``t``: Tensor-memory view (TMEM), used for any TMEM-resident fragment or layout regardless of role. The classical case is the accumulator (``tCtAcc``, ``tTR_tAcc``). The same scope letter also appears for non-accumulator TMEM tensors such as ``tCtE``, ``tCtState``, ``tCtQState``, ``tCtShared``. Read the operand suffix to distinguish the role from the memory scope.
|
||||
|
||||
Per-thread/partitioned views and families
|
||||
-----------------------------------------
|
||||
|
||||
- ``tA…`` / ``tB…``: TMA load path for A/B
|
||||
|
||||
- ``tAgA`` / ``tAsA``: per-thread partitioned global/shared A for TMA load
|
||||
- ``tBgB`` / ``tBsB``: per-thread partitioned global/shared B for TMA load
|
||||
- NVFP4/FP8 scale factors mirror this: ``tAgSFA`` / ``tAsSFA``, ``tBgSFB`` / ``tBsSFB``
|
||||
|
||||
- ``tC…``: Compute/epilogue path for C/Acc
|
||||
|
||||
- ``tCgA`` / ``tCgB`` / ``tCgC``: per-thread partitions used by MMA/epilogue (derived from global tensors)
|
||||
- ``tCrA`` / ``tCrB``: per-thread fragments used by MMA (derived from SMEM A/B)
|
||||
- ``tCtAcc``: per-thread accumulator fragment/layout in TMEM
|
||||
- Additional ``tC*`` tensors follow the same schema for kernels that carry more than the classical A/B/C/Acc operands (see Operands and roles below): e.g. ``tCtState`` / ``tCtQState`` / ``tCtShared`` (gated-delta-net recurrent state in TMEM), ``tCrValpha`` / ``tCrVbeta`` / ``tCrVbias`` (EVT/EFC broadcast vectors in registers), ``tCtAccInter`` / ``tCtAccIntra`` (hierarchical accumulators)
|
||||
|
||||
.. {$nv-internal-release begin}
|
||||
|
||||
- Sparse GEMM additionally defines ``tCtE`` for the sparsity metadata tensor in TMEM (sm_140 / Feynman sparse GEMM, not yet released)
|
||||
|
||||
.. {$nv-internal-release end}
|
||||
|
||||
- ``tTM…``: Per-thread TMEM tiled-copy partitions used by FMHA/attention kernels (e.g. ``tTMrO`` as the register-side view of a TMEM load partitioned through ``thr_tmem_load``)
|
||||
|
||||
- Attention/MLA path families (``tQ…``, ``tK…``, ``tV…``, ``tP…``, ``tO…``): same schema as ``tA…`` / ``tB…`` / ``tC…`` but specialised to the Q/K/V/P/O operands of attention kernels, e.g.:
|
||||
|
||||
- ``tQsQ`` / ``tQgQ_qdl``: per-thread SMEM / GMEM partitions of Q for TMA load
|
||||
- ``tKrK`` / ``tVrV``: per-thread register fragments for K / V
|
||||
- ``tOtO`` / ``tOrO``: per-thread TMEM / register views of the attention output accumulator O
|
||||
- ``tPrP``: per-thread register fragment for the softmax probability matrix P
|
||||
|
||||
Data-movement copy paths
|
||||
------------------------
|
||||
|
||||
- ``tTR_*``: TMEM → Register (T2R)
|
||||
|
||||
- ``tTR_tAcc``: TMEM accumulator source for T2R
|
||||
- ``tTR_rAcc``: Register destination for T2R
|
||||
- ``tTR_gC``: When not using TMA store, Register → Global C destination partition
|
||||
|
||||
- ``tRS_*``: Register → Shared (R2S)
|
||||
|
||||
- ``tRS_rC``: Register source (C dtype)
|
||||
- ``tRS_sC``: Shared destination
|
||||
|
||||
- ``bSG_*``: Thread(b)lock partition for Shared → Global via TMA store
|
||||
|
||||
- ``bSG_sC``: Shared source for TMA store
|
||||
- ``bSG_gC``: Global destination for TMA store
|
||||
- Also used for accumulator in some flows: ``bSG_sAcc``, ``bSG_gAcc``
|
||||
- The same schema extends to additional store operands: ``bSG_sD`` / ``bSG_gD``, ``bSG_sP`` / ``bSG_gP``, ``bSG_sY`` / ``bSG_gY``
|
||||
|
||||
- ``bGS_*``: Thread(b)lock partition for Global → Shared via TMA **load** (the load-path mirror of ``bSG_*``)
|
||||
|
||||
- ``bGS_gC`` / ``bGS_sC``: Global source / Shared destination for TMA load of C-like operands (seen in EFC row/column broadcast prologues)
|
||||
|
||||
- ``simt_atom``: SIMT copy path used when TMA store is disabled (Register → Global)
|
||||
- Generic SIMT / tiled copy atoms ``<src>2<dst>_atom[_suffix]`` name the copy direction between two memory scopes:
|
||||
|
||||
- ``s2r_atom_*``: Shared → Register atom used in specialised epilogues and attention loads (e.g. ``s2r_atom_delta``, ``s2r_atom_cumsum``, ``s2r_atom_d`` in Mamba2 SSD)
|
||||
- ``r2s_atom``: Register → Shared atom
|
||||
- ``t2r_atom`` / ``r2t_atom``: Tensor memory ↔ Register atoms (paired with ``thr_tmem_load`` / ``thr_tmem_store``)
|
||||
- ``s2s_atom``: Shared → Shared atom (reshape/remap without register spill)
|
||||
- ``s2t``: Shared → Tensor memory atom
|
||||
|
||||
.. {$nv-internal-release begin}
|
||||
|
||||
- ``sp2t_copy_op_*``: Sparse source → Tensor memory copy op (sm_140 / Feynman sparse GEMM, not yet released: e.g. ``Sp2TAsACopyOp``, ``Sp2TAsECopyOp``)
|
||||
|
||||
.. {$nv-internal-release end}
|
||||
|
||||
- Custom ``autovec_copy`` paths appear where the DSL auto-vectorises a bespoke layout
|
||||
|
||||
Operands and roles
|
||||
------------------
|
||||
|
||||
- ``A``, ``B``, ``C``: GEMM operands
|
||||
- ``Acc``: Accumulator (TMEM/Register paths). Hierarchical MMA kernels split this into ``AccInter`` / ``AccIntra`` for the inter-/intra-CTA accumulator halves
|
||||
- Classical extra outputs / intermediates: ``D`` (additional output), ``Y`` (fused output), ``SFA`` / ``SFB`` (per-operand scale-factor arrays for NVFP4/FP8), ``SF`` (generic scale factor)
|
||||
- Attention / MLA operand letters (Q/K/V/P/O schema):
|
||||
|
||||
- ``Q`` (query), ``K`` (key), ``V`` (value), ``P`` (softmax probability / score matrix), ``O`` (attention output)
|
||||
- Variants: ``Kt`` / ``Vt`` for the transposed view of K/V, ``Qi`` / ``Ki`` / ``Vi`` for per-iteration slices, ``QK`` / ``PV`` / ``QKV`` where a single fragment spans multiple operands of the two back-to-back matmuls
|
||||
- Mamba / recurrent-state letters: ``Delta`` / ``DeltaA`` (time-step and A-decay), ``State`` / ``QState`` / ``Shared`` (gated-delta-net recurrent state tensors), ``Cumsumlog`` / ``Cumprod`` (running reductions), ``Gate``, ``DecayV``
|
||||
|
||||
.. {$nv-internal-release begin}
|
||||
|
||||
- Sparse-GEMM letters (sm_140 / Feynman, not yet released): ``E`` (sparsity metadata tensor in TMEM; paired with ``sp2t_*`` copy ops)
|
||||
|
||||
.. {$nv-internal-release end}
|
||||
|
||||
- EVT / EFC broadcast vectors: ``Valpha`` / ``Vbeta`` (alpha/beta scalars broadcast as vectors), ``Vbias`` (bias vector), ``Ainv`` (inverse of A for fused solvers)
|
||||
|
||||
.. {$nv-internal-release begin}
|
||||
|
||||
- LUT-based block-scaled GEMM letter (Rubin, not yet released): ``LutB`` (look-up-table operand)
|
||||
|
||||
.. {$nv-internal-release end}
|
||||
- Communication operands (multi-CTA / multicast flows): ``CommInMC`` / ``CommOutMC`` (multicast in/out), ``CommOutUC`` (unicast out)
|
||||
- Head-dimension variants: ``Dv`` (value head dimension when distinct from Q/K dim), ``Nv`` (number of value heads)
|
||||
|
||||
Axis-order suffixes
|
||||
-------------------
|
||||
|
||||
- Suffix encodes axis order of the view (lowercase letters each stand for one tensor mode):
|
||||
|
||||
- GEMM layouts use ``m``/``n``/``k``/``l``:
|
||||
|
||||
- ``_mnl``, ``_nkl``, ``_mkl``, … map to (M, N, K, L) ordering
|
||||
- Example: ``gB_nkl`` is B with axes (N, K, L); ``gC_mnl`` is C with (M, N, L)
|
||||
|
||||
- Attention / FMHA layouts use ``q``/``k``/``d``/``l`` (sequence-Q, sequence-K, head-dim, batch):
|
||||
|
||||
- ``mQ_qdl``: Q tensor with axes (SeqQ, HeadDim, Batch)
|
||||
- ``mK_kdl``: K tensor with axes (SeqK, HeadDim, Batch)
|
||||
- ``mV_dkl``: V tensor with axes (HeadDim, SeqK, Batch) — the ``d``-first order reflects the V-transpose that makes the second matmul (P·V) a standard row-major ``MxK·KxN``
|
||||
|
||||
- Lower-rank 2D slices drop the batch letter: ``_mn``, ``_mk``, ``_nk``
|
||||
|
||||
- Internally, CuTe layouts also expose grouped modes like ``MMA_M/N/K``, ``EPI_M/N``, ``RestM/N/K/L``, ``STAGE``, etc. (these are typically implementation details not directly used in example code).
|
||||
|
||||
Reading compound tokens
|
||||
-----------------------
|
||||
|
||||
- From left to right: ``[t|b][A|B|C|Q|K|V|P|O|TR|RS|SG|GS|TM]_[g|s|r|t][Operand/Role][AxisSuffix?]``
|
||||
|
||||
- ``t`` = per-thread/partitioned view; ``b`` = block/threadblock partition context
|
||||
- family/path letters:
|
||||
|
||||
- Operand-based: ``A`` / ``B`` / ``C`` (GEMM), ``Q`` / ``K`` / ``V`` / ``P`` / ``O`` (attention)
|
||||
- Direction-based: ``TR`` (TMEM → Register), ``RS`` (Register → Shared), ``SG`` (Shared → Global, store), ``GS`` (Global → Shared, load), ``TM`` (TMEM tiled-copy partition), ``R2G`` / ``S2R`` / ``T2R`` / ``R2T`` convenience aliases
|
||||
- memory = ``g``/``s``/``r``/``t``
|
||||
- operand/role = ``A``/``B``/``C``/``Acc``/``SFA``/``SFB``/``Q``/``K``/``V``/``P``/``O``/``E``/``State``/…
|
||||
- axis suffix = ``_mnl``, ``_nkl``, ``_qdl``, ``_kdl``, ``_dkl``, ``_mn``, … when applicable
|
||||
|
||||
- Per-thread-partitioner objects follow a parallel ``thr_*`` vocabulary, grouped by role:
|
||||
|
||||
- MMA partitioner: ``thr_mma``
|
||||
- Tiled-copy direction variants ``thr_copy_<src>2<dst>``: ``thr_copy_g2s``, ``thr_copy_s2r``, ``thr_copy_t2r``, ``thr_copy_r2s``, ``thr_copy_r2t``, ``thr_copy_s2t``
|
||||
- Role-qualified copy variants: ``thr_copy_sfa``, ``thr_copy_sfb``, ``thr_copy_load``, ``thr_copy_beta_g2s``
|
||||
- MMA variants for multi-matmul kernels: ``thr_mma_qk``, ``thr_mma_pv``, ``thr_mma_kv``, ``thr_mma_qkv``, ``thr_mma_intra1`` / ``thr_mma_intra2``, ``thr_mma_leader_cta``, ``thr_mma_sfb``
|
||||
- TMEM access partitioners: ``thr_tmem_load``, ``thr_tmem_store`` (with ``_stats`` / ``_vec`` suffix variants)
|
||||
|
||||
The tensor produced by ``thr_foo.partition_S(X)`` or ``.partition_D(X)`` is then named by the ``[t|b]FamilyPrefix_*`` convention above.
|
||||
|
||||
Concrete references
|
||||
-------------------
|
||||
|
||||
Open these files in the repository to see each pattern in context:
|
||||
|
||||
- TMA load partitions for A/B:
|
||||
|
||||
- ``tAgA``, ``tAsA``, ``tBgB``, ``tBsB``
|
||||
- ``CuTeDSL/cute/blackwell/kernel/dense_gemm/dense_gemm.py`` (around TMA partition of A/B)
|
||||
|
||||
- Accumulator fragment in TMEM:
|
||||
|
||||
- ``tCtAcc``
|
||||
- ``CuTeDSL/cute/blackwell/kernel/dense_gemm/dense_gemm.py`` (accumulator creation and use)
|
||||
|
||||
- TMEM → Register (T2R):
|
||||
|
||||
- ``tTR_tAcc``, ``tTR_rAcc``, ``tTR_gC``
|
||||
- ``CuTeDSL/cute/blackwell/kernel/dense_gemm/dense_gemm.py`` (``epilog_tmem_copy_and_partition``)
|
||||
|
||||
- Register → Shared (R2S):
|
||||
|
||||
- ``tRS_rC``, ``tRS_sC``
|
||||
- ``CuTeDSL/cute/blackwell/kernel/mixed_input_gemm/mixed_input_gemm.py`` (``epilog_smem_copy_and_partition``)
|
||||
|
||||
- Shared → Global via TMA store:
|
||||
|
||||
- ``bSG_sC``, ``bSG_gC``
|
||||
- ``CuTeDSL/cute/blackwell/kernel/blockscaled_gemm/dense_blockscaled_gemm_persistent.py`` (``epilog_gmem_copy_and_partition``)
|
||||
|
||||
- NVFP4/FP8 scale factors:
|
||||
|
||||
- ``tAgSFA``/``tAsSFA``, ``tBgSFB``/``tBsSFB``
|
||||
- ``CuTeDSL/cute/blackwell/tutorial/tutorial_gemm/nvfp4_gemm_0.py`` (scale factor partition and usage)
|
||||
|
||||
- Additional examples across ``examples/``:
|
||||
|
||||
- Register → Global helper naming in MLA: ``tR2G_rO_src``, ``tR2G_rO_dst``
|
||||
- ``CuTeDSL/cute/blackwell/kernel/attention/mla/mla_decode_fp16.py`` (output store section)
|
||||
|
||||
- Shared → Register SIMT atoms in Mamba2 SSD: ``s2r_atom_delta``, ``s2r_atom_cumsum``, ``s2r_atom_d``
|
||||
- ``CuTeDSL/cute/blackwell/kernel/attention/mamba2_ssd/mamba2_ssd.py`` (SMEM load paths for delta and D)
|
||||
|
||||
- ``thr_*`` slices for partitioning per-thread work: ``thr_mma``, ``thr_copy_t2r``, ``thr_copy_r2s``, etc.
|
||||
- ``CuTeDSL/cute/blackwell/kernel/dense_gemm/dense_gemm.py`` (``thr_mma``, ``thr_copy_t2r``, ``thr_copy_r2s``)
|
||||
|
||||
- Axis-order suffix examples:
|
||||
|
||||
- ``gB_nkl``, ``gC_mnl``
|
||||
- ``CuTeDSL/cute/blackwell/kernel/dense_gemm/dense_gemm.py`` (global tensor tiling and partitioning)
|
||||
|
||||
- Global → Shared (TMA load) block partition ``bGS_*``:
|
||||
|
||||
- ``bGS_gC``, ``bGS_sC``
|
||||
- ``CuTeDSL/cute/blackwell/efc/common_efc.py`` (row/column broadcast prologue building the C-like input for EVT)
|
||||
|
||||
- Attention Q/K/V/P/O families and ``_qdl`` / ``_kdl`` / ``_dkl`` axis suffixes:
|
||||
|
||||
- ``tQsQ``, ``tQgQ_qdl``, ``mK_kdl``, ``mV_dkl``
|
||||
- ``CuTeDSL/cute/hopper/kernel/attention/fmha.py`` (Q/K/V TMA partitions)
|
||||
- ``tOtO``, ``tOrO``, ``tPrP``
|
||||
- ``CuTeDSL/cute/blackwell/tutorial/tutorial_fmha/fmha_0.py`` (output and softmax fragments)
|
||||
- ``tKrK``, ``tVrV``
|
||||
- ``CuTeDSL/cute/blackwell/kernel/attention/mixed_input_fmha/mixed_input_fmha_decode.py`` (mixed-input K/V register fragments)
|
||||
|
||||
- TMEM tiled-copy ``tTM*`` family and the generalised ``<src>2<dst>_atom`` naming:
|
||||
|
||||
- ``tTMrO`` driven by ``thr_tmem_load``
|
||||
- ``CuTeDSL/cute/blackwell/tutorial/tutorial_fmha/fmha_0.py``
|
||||
|
||||
- Recurrent-state operands (``State`` / ``QState`` / ``Shared``) in TMEM:
|
||||
|
||||
- ``tCtState``, ``tCtQState``, ``tCtShared``
|
||||
- ``CuTeDSL/cute/blackwell/kernel/attention/gated_delta_net/gated_delta_net_chunked.py``
|
||||
|
||||
.. {$nv-internal-release begin}
|
||||
|
||||
- Sparse-metadata operand ``E`` and ``sp2t_*`` copy ops (sm_140 / Feynman, not yet released):
|
||||
|
||||
- ``tCtE``, ``sp2t_copy_op_A``, ``sp2t_copy_op_E``
|
||||
- ``CuTeDSL/internal/feynman/sm140_sparse_gemm.py`` and ``sm140_sparse_gemm_temporal_split_k.py``
|
||||
|
||||
- LUT-based block-scaled GEMM operand ``LutB`` (Rubin, not yet released):
|
||||
|
||||
- ``CuTeDSL/cute/rubin/kernel/blockscaled_gemm/dense_blockscaled_gemm_lut.py``
|
||||
- ``CuTeDSL/cute_ext/rubin/dense_gemm_lutb.py``
|
||||
|
||||
.. {$nv-internal-release end}
|
||||
|
||||
- Richer ``thr_*`` and ``thr_copy_*`` / ``thr_mma_*`` / ``thr_tmem_*`` partitioner taxonomy:
|
||||
|
||||
- ``thr_copy_g2s``, ``thr_copy_s2r``, ``thr_copy_s2t``, ``thr_copy_r2t``, ``thr_mma_qk``, ``thr_mma_pv``, ``thr_tmem_load``, ``thr_tmem_store``
|
||||
- The attention and Mamba2 examples above are the densest references; any ``fmha_*.py`` or ``mamba2_ssd.py`` file will show the full vocabulary in use
|
||||
11
media/docs/pythonDSL/mma_docs/intro.rst
Normal file
11
media/docs/pythonDSL/mma_docs/intro.rst
Normal file
@@ -0,0 +1,11 @@
|
||||
Architecture-specific MMA Programming Guides
|
||||
=============================================
|
||||
|
||||
This section contains architecture-specific MMA programming guides.
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 2
|
||||
|
||||
wmma_programming
|
||||
wgmma_programming
|
||||
tcgen05_programming
|
||||
1528
media/docs/pythonDSL/mma_docs/tcgen05_programming.rst
Normal file
1528
media/docs/pythonDSL/mma_docs/tcgen05_programming.rst
Normal file
File diff suppressed because it is too large
Load Diff
987
media/docs/pythonDSL/mma_docs/wgmma_programming.rst
Normal file
987
media/docs/pythonDSL/mma_docs/wgmma_programming.rst
Normal file
@@ -0,0 +1,987 @@
|
||||
.. _wgmma_programming:
|
||||
|
||||
Warpgroup MMA Programming Guide
|
||||
================================
|
||||
|
||||
Hopper (SM90a) introduced the **warpgroup-level MMA** PTX instruction family
|
||||
``wgmma.mma_async.sync.aligned``. A warpgroup (128 threads / 4 warps)
|
||||
cooperates on one asynchronous ``D = A * B + C`` matrix multiply-accumulate.
|
||||
|
||||
Key architectural characteristics:
|
||||
|
||||
* **Warpgroup scope:** One MMA is issued collectively by a 128-thread
|
||||
warpgroup rather than by a single warp.
|
||||
* **Asynchronous issue model:** WGMMA instructions are ordered with
|
||||
``cute.nvgpu.warpgroup.fence()``, ``commit_group()``, and ``wait_group()``.
|
||||
* **Descriptor-based operand path:** Operand B is sourced from staged shared
|
||||
memory. Operand A can be sourced either from shared memory descriptors or
|
||||
from registers via ``OperandSource``.
|
||||
* **Register accumulator:** The accumulator lives in RMEM and serves as both
|
||||
the input C and output D of ``cute.gemm()``.
|
||||
* **Architecture-specific operand layouts:** F16/BF16 supports K-major and
|
||||
MN-major dense layouts when A comes from SMEM. FP8 and INT8 variants are
|
||||
K-major only.
|
||||
|
||||
The dense DSL op classes currently exposed are ``MmaF16BF16Op`` (F16/BF16),
|
||||
``MmaF8Op`` (FP8 E4M3/E5M2), and ``MmaI8Op`` (INT8/UINT8); see
|
||||
`Setting up the TiledMMA, MMA Ops`_ for their full constructor parameters,
|
||||
instruction K extents, and major-mode constraints.
|
||||
|
||||
This guide outlines the CuTe Python DSL programming model for WGMMA kernels:
|
||||
stage operands in SMEM, build fragment descriptors, launch asynchronous
|
||||
warpgroup MMAs, and stage the RMEM accumulator back to GMEM in the epilogue.
|
||||
|
||||
|
||||
.. contents:: **Contents**
|
||||
:local:
|
||||
:depth: 2
|
||||
|
||||
Global Memory (GMEM) to MMA data flow overview
|
||||
----------------------------------------------
|
||||
|
||||
WGMMA instructions require us to stage B input operands in Shared Memory (SMEM),
|
||||
while A input operands can be sourced from either SMEM or registers (RMEM).
|
||||
SMEM operands are read asynchronously by the hardware via SMEM descriptors.
|
||||
The accumulator is always kept in registers (RMEM) of the warpgroup.
|
||||
|
||||
The diagram below traces the full data flow of a WGMMA GEMM kernel, for the most
|
||||
common case where A and B matrices are stored in GMEM and both are staged through
|
||||
SMEM (``a_src=SMEM``), and the output matrix --accumulated in RMEM-- is written
|
||||
back to GMEM through an SMEM staging buffer.
|
||||
|
||||
There are 3 parallel tracks where each has 2 sub-tracks. Three parallel tracks are
|
||||
for operands A, B, and C/D, respectively. The two sub-tracks are for copying data between different memory
|
||||
spaces and for MMA execution.
|
||||
|
||||
- **Operand A** (and symmetrically **Operand B**):
|
||||
|
||||
- First, we need to create SMEM tensors for A and B matrices: ``sA`` and ``sB``. These
|
||||
tensors are physically allocated tensors that are the destination of TMA copy and
|
||||
the source operands for the WGMMA instructions.
|
||||
- Next the **data copy flow** creates the tensor views for copying data from GMEM to SMEM.
|
||||
It starts with ``mA`` tensor that represents the matrix A in global memory.
|
||||
Then ``mA`` → ``local_tile`` → ``gA`` operation creates the local tile view of A that is the
|
||||
slice of A matrix needed to compute the given CTA's output tile.
|
||||
Then ``tma_partition(tma, sA, gA)`` produces TMA views ``tAsA``, ``tAgA``,
|
||||
and the loop copies tiles from GMEM into SMEM via ``copy(tma, tAgA[k], tAsA[stage])``.
|
||||
- In parallel, the **MMA flow** turns the SMEM tensor into an iterable tensor of SMEM descriptors
|
||||
for the WGMMA instructions. ``sA`` (the same shared-memory allocation written by TMA)
|
||||
→ ``partition_A`` → ``tCsA`` (MMA-partitioned SMEM view)
|
||||
→ ``make_fragment_A`` → ``tCrA`` (SMEM descriptor passed to ``cute.gemm()``).
|
||||
Note that the SMEM descriptor is a view created from the SMEM tensor that is
|
||||
interpretable by the WGMMA instructions.
|
||||
|
||||
- **Accumulator C/D**:
|
||||
|
||||
- **RMEM accumulator flow** (gemm input/output): ``partition_C(gC)`` → ``tCgC``
|
||||
→ ``make_rmem_tensor(tCgC.shape)`` → ``acc``, which serves as both the accumulator
|
||||
input (C) and output (D) of ``cute.gemm()`` (and the WGMMA instruction).
|
||||
- **Output flow** (RMEM → SMEM → GMEM): After the main loop, the accumulator is
|
||||
type-converted and copied from registers to SMEM via ``stmatrix`` (R2S copy),
|
||||
then stored to global memory via TMA store (S2G copy):
|
||||
``mC`` → ``local_tile`` → ``gC`` → ``partition_C`` → ``tCgC`` on the destination side,
|
||||
and ``tRS_rAcc``/``tRS_sD`` / ``bSG_sD``/``bSG_gD`` views drive the two copy stages.
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
Operand A Dataflow Path Operand B Dataflow Path Accumulator C/D Dataflow Path
|
||||
─────────────────────── ─────────────────────── ─────────────────────────────
|
||||
|
||||
mA: (M, K) [GMEM] mB: (N, K) [GMEM] ┌──── RMEM ──────────┐
|
||||
│ │ │ make_rmem_tensor() │
|
||||
│ local_tile(mA, cta_tiler, coord) │ local_tile(mB, cta_tiler, coord) │ acc: accumulator │
|
||||
▼ ▼ └───────┬────────────┘
|
||||
gA: (BM, BK, k) [GMEM] gB: (BN, BK, k) [GMEM] │
|
||||
│ │ acc:(MMA,MMA_M,MMA_N) [RMEM]
|
||||
│ ┌──── SMEM ─────────┐ │ ┌──── SMEM ─────────┐ │
|
||||
│ │ sA = alloc(layout)│ │ │ sB = alloc(layout)│ │ mC: (M, N) [GMEM]
|
||||
│ └──┬────────┬───────┘ │ └──┬────────┬───────┘ │ │
|
||||
│ │ │ │ │ │ │ │ local_tile
|
||||
│ │ thr_mma.partition_A(sA) │ │ thr_mma.partition_B(sB) │ ▼
|
||||
│ │ ▼ │ │ ▼ │ gC: (BM, BN) [GMEM]
|
||||
│ │ tCsA:(MMA,MMA_M, │ │ tCsB:(MMA,MMA_N, │ │ partition_C
|
||||
│ │ MMA_K,PIPE) [SMEM] │ │ MMA_K,PIPE) [SMEM] │ ▼
|
||||
│ │ │ │ │ │ │ tCgC:(MMA,MMA_M,
|
||||
│ │ make_fragment_A(tCsA) │ │ make_fragment_B(tCsB) │ MMA_N)
|
||||
│ │ ▼ │ │ ▼ │ [GMEM] (epi dest)
|
||||
│ │ tCrA:(MMA,MMA_M, │ │ tCrB:(MMA,MMA_N, │ │
|
||||
│ │ MMA_K,PIPE) │ │ MMA_K,PIPE) │ │
|
||||
│ │ [SMEM descriptors] │ │ [SMEM descriptors] │ │
|
||||
│ │ └─────────────┐ │ │ └─────────────┐ │ │
|
||||
╰─────┤ │ ╰─────┤ │ │ │
|
||||
▼ │ ▼ │ │ │
|
||||
tma_partition(tma, │ tma_partition(tma, │ │ │
|
||||
sA, gA) │ sB, gB) │ │ │
|
||||
→ tAsA, tAgA │ → tBsB, tBgB │ │ │
|
||||
▼ │ ▼ │ │ │
|
||||
┌───┴────────────────────┐ │ ┌──────┴─────────────────┐│ │ │
|
||||
│ TMA copy loop (A path):│ │ │ TMA copy loop (B path):││ │ │
|
||||
│ copy(tma, tAgA[k], │ │ │ copy(tma, tBgB[k], ││ │ │
|
||||
│ tAsA[stage]) │ │ │ tBsB[stage]) ││ │ │
|
||||
┌─▶│ (writes into sA; │ │ ┌──▶│ (writes into sB; ││ │ │
|
||||
│ │ tCrA reads same sA) │ │ │ │ tCrB reads same sB) ││ │ │
|
||||
│ │ repeat for next k/stage│ │ │ │ repeat for next k/stage││ │ │
|
||||
│ └────────────────────────┘ │ │ └────────────────────────┘│ │ │
|
||||
│ │ │ │ │ │ │ │
|
||||
└────────┘ ▼ └─────────┘ ▼ ▼ │
|
||||
└───────┬───────────────────────────────┴───────────────────┘ │
|
||||
│ │
|
||||
▼ │
|
||||
┌──────────────────────────────────────────────┐ │
|
||||
│ GEMM Loop: | │
|
||||
│ warpgroup.fence() │ │
|
||||
│ cute.gemm(tiled_mma, │ │
|
||||
│ acc, D (output, RMEM), │ │
|
||||
┌──▶ │ tCrA[stage], A (SMEM desc -> sA), │ │
|
||||
│ │ tCrB[stage], B (SMEM desc -> sB), │ │
|
||||
│ │ acc) C (accumulator, RMEM) │ │
|
||||
│ │ warpgroup.commit_group() │ │
|
||||
│ │ warpgroup.wait_group(n) │ │
|
||||
│ └──────────────────────────────────────────────┘ │
|
||||
│ │ │ │
|
||||
└───────┘ | │
|
||||
▼ │
|
||||
Epilogue: │
|
||||
tRS_rAcc = retile(acc) │
|
||||
tRS_rD = type_convert(tRS_rAcc) │
|
||||
│ │
|
||||
▼ │
|
||||
R2S: copy(tiled_copy_r2s, tRS_rD, tRS_sD) │
|
||||
[RMEM → SMEM via stmatrix] │
|
||||
│ │
|
||||
▼ │
|
||||
sC = alloc(epi_layout) [SMEM] │
|
||||
bSG_sD, bSG_gD = tma_partition(tma_c, sC, gC) ◀───────────────────┘
|
||||
│
|
||||
▼
|
||||
S2G: copy(tma_c, bSG_sD[stage], bSG_gD[coord])
|
||||
[SMEM → GMEM via TMA store]
|
||||
|
||||
**Naming convention:**
|
||||
|
||||
* cta_tiler = (BM, BN, BK) = CTA-wide tiler dimensions
|
||||
* ``mX`` = a global tensor, e.g., (M, K) for A
|
||||
* ``gX`` = CTA-tiled GMEM slice, e.g., (BM, BK, k) for A
|
||||
* ``sX`` = SMEM allocation, e.g., (BM, BK, PIPE) for A
|
||||
* ``tAsA``/``tBsB`` = TMA-partitioned SMEM views
|
||||
* ``tAgA``/``tBgB`` = TMA-partitioned GMEM views
|
||||
* ``tCsX`` = MMA-partitioned SMEM view, e.g., (MMA, MMA_M, MMA_K, PIPE) for A
|
||||
* ``tCrX`` = SMEM descriptor fragment, e.g., (MMA, MMA_M, MMA_K, PIPE) for A
|
||||
* ``acc`` = RMEM accumulator, (MMA, MMA_M, MMA_N)
|
||||
* ``tCgC`` = MMA-partitioned GMEM, (MMA, MMA_M, MMA_N)
|
||||
* ``tRS_rAcc``/``tRS_sD`` = epilogue retile views for R2S (RMEM → SMEM) copy
|
||||
* ``bSG_sD``/``bSG_gD`` = TMA-partitioned SMEM/GMEM views for epilogue store
|
||||
* MMA = warpgroup atom thread-value layout; MMA_M/MMA_N/MMA_K = repeat counts
|
||||
(e.g., BM/inst_M), k = outer K-tiles, PIPE = pipeline stages
|
||||
|
||||
Setting up the TiledMMA, MMA Ops
|
||||
---------------------------------
|
||||
|
||||
As shown in the data flow overview, CuTe DSL provides many utilities to tile/partition
|
||||
the global memory tensors, and create fragment views of SMEM tensors for MMA instructions.
|
||||
|
||||
To utilize these functions, we need to setup the TiledMMA, MMA Ops first.
|
||||
|
||||
Creating a WGMMA Op
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
A WGMMA op describes the hardware instruction to use, it has parameters like
|
||||
data types, instruction shape, operand A source (SMEM or RMEM),
|
||||
and operand major modes.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cute.nvgpu import OperandMajorMode
|
||||
import cutlass.cute.nvgpu.warpgroup as warpgroup
|
||||
|
||||
op = warpgroup.MmaF16BF16Op(
|
||||
cutlass.Float16, # A/B element type
|
||||
cutlass.Float32, # accumulator type
|
||||
(64, 128, 16), # instruction shape (M, N, K)
|
||||
warpgroup.OperandSource.SMEM, # A operand from shared memory
|
||||
OperandMajorMode.K, # A is K-major
|
||||
OperandMajorMode.K, # B is K-major
|
||||
)
|
||||
|
||||
The key parameters are:
|
||||
|
||||
- **Instruction shape** ``(M, N, K)``: determines the size of one hardware MMA
|
||||
instruction. WGMMA requires ``M = 64`` and ``8 <= N <= 256`` in steps of 8.
|
||||
K is fixed by the op class (16 for F16/BF16, 32 for FP8 and INT8).
|
||||
- **OperandSource**: ``SMEM`` reads A from a shared memory descriptor; ``RMEM``
|
||||
reads A directly from registers.
|
||||
- **OperandMajorMode**: ``K`` for K-major (default), ``MN`` for transposed layout.
|
||||
F16/BF16 supports both K-major and MN-major for A and B when ``a_src=SMEM``;
|
||||
when ``a_src=RMEM``, only B can be transposed. FP8 and INT8 are K-major only.
|
||||
|
||||
|
||||
CuTe DSL provides implementation of the following WGMMA ops:
|
||||
|
||||
.. list-table:: WGMMA ops
|
||||
:header-rows: 1
|
||||
:widths: 30 24 46
|
||||
|
||||
* - PTX name
|
||||
- Python class
|
||||
- Constructor parameters
|
||||
* - ``wgmma.mma_async.m64n{N}k16.{acc}.f16.f16`` / ``.bf16.bf16``
|
||||
- ``warpgroup.MmaF16BF16Op``
|
||||
- ``ab_dtype, acc_dtype, instruction_shape, a_src, a_major_mode, b_major_mode``
|
||||
* - ``wgmma.mma_async.m64n{N}k32.{acc}.{e4m3|e5m2}.{e4m3|e5m2}``
|
||||
- ``warpgroup.MmaF8Op``
|
||||
- ``a_dtype, b_dtype, acc_dtype, instruction_shape, a_src, a_major_mode, b_major_mode``
|
||||
* - ``wgmma.mma_async.m64n{N}k32.s32.{s8|u8}.{s8|u8}``
|
||||
- ``warpgroup.MmaI8Op``
|
||||
- ``a_dtype, b_dtype, acc_dtype, instruction_shape, a_src, a_major_mode, b_major_mode``
|
||||
|
||||
|
||||
Creating a Tiled MMA
|
||||
~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
A ``TiledMma`` tiles the WGMMA atom across the CTA tile. You can pass the op
|
||||
directly or create an explicit atom first.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Option 1: directly from op (common shorthand)
|
||||
tiled_mma = cute.make_tiled_mma(op)
|
||||
|
||||
# Option 2: explicit atom creation
|
||||
atom = cute.make_mma_atom(op)
|
||||
tiled_mma = cute.make_tiled_mma(atom)
|
||||
|
||||
Spatial tiling with a repeat count
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
A repeat tuple ``(M_rep, N_rep, K_rep)`` replicates the WGMMA atom across the
|
||||
M, N, and K dimensions, producing a larger tiled MMA that covers a bigger CTA
|
||||
tile with a single ``cute.gemm`` call. Each entry in the repeat tuple
|
||||
corresponds to one **warpgroup** (128 threads / 4 warps), so ``(2, 1, 1)``
|
||||
uses two warpgroups — the standard configuration for large Hopper tiles:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
atom = cute.make_mma_atom(op) # op shape: (64, 128, 16)
|
||||
tiled_mma = cute.make_tiled_mma(
|
||||
atom,
|
||||
atom_layout_mnk=(2, 1, 1), # 2 warpgroups in M
|
||||
)
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
WGMMA Atom make_tiled_mma(atom, (2, 1, 1))
|
||||
+---------------+ +----------------+
|
||||
| | | | ^
|
||||
| 64 x 128 | | Atom (0,0,0) | |
|
||||
| x 16 | --(2,1,1)--> | 64 x 128 | | 2 x M_atom
|
||||
| | repeat | x 16 | | = 128
|
||||
| | | [Warpgroup 0] | |
|
||||
+---------------+ +----------------+ |
|
||||
| | |
|
||||
| Atom (1,0,0) | |
|
||||
| 64 x 128 | |
|
||||
| x 16 | |
|
||||
| [Warpgroup 1] | v
|
||||
+----------------+
|
||||
<-- N_atom = 128 -->
|
||||
K unchanged = 16
|
||||
|
||||
The Hopper dense GEMM examples
|
||||
(``examples/cute/hopper/kernel/dense_gemm/dense_gemm.py``) use this pattern.
|
||||
The helper ``sm90_utils.make_trivial_tiled_mma(...)`` selects the repeat count
|
||||
automatically:
|
||||
|
||||
- ``atom_layout_mnk = (2, 1, 1)`` when both ``tile_M > 64`` and
|
||||
``tile_N > 128`` (two warpgroups reduce register pressure).
|
||||
- ``atom_layout_mnk = (1, 1, 1)`` otherwise (a single warpgroup suffices).
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import cutlass.utils.hopper_helpers as sm90_utils
|
||||
|
||||
tiled_mma = sm90_utils.make_trivial_tiled_mma(
|
||||
a_dtype,
|
||||
b_dtype,
|
||||
a_major_mode,
|
||||
b_major_mode,
|
||||
acc_dtype,
|
||||
atom_layout_mnk=(2, 1, 1),
|
||||
tiler_mn=(64, 128), # atom instruction shape (M, N)
|
||||
)
|
||||
|
||||
Custom tile permutation with ``permutation_mnk``
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
``make_tiled_mma`` also accepts an optional ``permutation_mnk`` argument that
|
||||
controls how the tiled atom footprint is laid out across M, N, and K. At a
|
||||
high level:
|
||||
|
||||
- ``atom_layout_mnk`` tells CuTe how many atoms (warpgroups) to replicate.
|
||||
- ``permutation_mnk`` tells CuTe how the final tiled footprint is ordered.
|
||||
|
||||
``permutation_mnk`` is a tuple of layouts or integers that represent the
|
||||
tile size and ordering of values along each dimension. When a mode's
|
||||
permutation size is larger than the atom layout's natural coverage
|
||||
(``atom_layout x inst_shape``), each warpgroup receives additional values
|
||||
to fill the extended region — the warpgroup count stays the same, but each
|
||||
warpgroup holds more data.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
atom = cute.make_mma_atom(op) # op shape: (64, 128, 16)
|
||||
tiled_mma = cute.make_tiled_mma(
|
||||
atom,
|
||||
atom_layout_mnk=(2, 1, 1),
|
||||
permutation_mnk=(128, 256, 16), # extend N from 128 to 256
|
||||
)
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
Without permutation — natural atom coverage (M = 128, N = 128):
|
||||
|
||||
C tile (M=128, N=128)
|
||||
+----------------+
|
||||
| | ^
|
||||
| [Warpgroup 0] | |
|
||||
| 64 x 128 | | 2 x inst_M
|
||||
| | | = 128
|
||||
+----------------+ |
|
||||
| | |
|
||||
| [Warpgroup 1] | |
|
||||
| 64 x 128 | |
|
||||
| | v
|
||||
+----------------+
|
||||
<--- N = 128 --->
|
||||
(each warpgroup owns one (64, 128) atom)
|
||||
|
||||
With permutation_mnk = (128, 256, 16) — N extended to 256:
|
||||
|
||||
C tile (M=128, N=256)
|
||||
+----------------+----------------+
|
||||
| | | ^ N = 128 → 256:
|
||||
| [Warpgroup 0] | [Warpgroup 0] | | atom pattern repeats
|
||||
| 64 x 128 | 64 x 128 | | along N. Each warpgroup
|
||||
| | | | now holds 2x the values
|
||||
+----------------+----------------+ | along N (same threads,
|
||||
| | | | more data).
|
||||
| [Warpgroup 1] | [Warpgroup 1] | |
|
||||
| 64 x 128 | 64 x 128 | |
|
||||
| | | v
|
||||
+----------------+----------------+
|
||||
<------------ N = 256 ------------>
|
||||
| atom coverage | value repeat |
|
||||
|
||||
**Why WGMMA typically does not need permutation_mnk:** The WGMMA
|
||||
instruction already has a large N dimension (64, 128, or 256), so the natural
|
||||
atom coverage is wide enough that no permutation is needed to align with SMEM
|
||||
swizzle widths. The Hopper
|
||||
dense GEMM examples (``dense_gemm.py``, ``dense_gemm_persistent.py``) use
|
||||
``atom_layout_mnk`` alone without ``permutation_mnk``.
|
||||
|
||||
When ``permutation_mnk`` is not provided (default), the tile ordering is
|
||||
sequential and no permutation is applied.
|
||||
|
||||
|
||||
Partitioning Tensors
|
||||
---------------------
|
||||
|
||||
Before computing, partition the CTA-tiled tensors according to the
|
||||
tiled MMA layout. WGMMA partitioning is **warpgroup-oriented**: each
|
||||
warpgroup (128 threads / 4 warps) receives its own slice of the CTA
|
||||
tile, sized to match the SMEM descriptors and register accumulators
|
||||
that the WGMMA instruction expects.
|
||||
|
||||
**2-warpgroup example**
|
||||
|
||||
``GEMM (M, N, K) = (512, 768, 256)``, ``tile_shape_mnk = (128, 256, 64)``,
|
||||
F16 WGMMA atom = (64, 256, 16), ``atom_layout_mnk = (2, 1, 1)``,
|
||||
``num_stages = 4``, 2 warpgroups = 256 threads.
|
||||
|
||||
Global matrices:
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
mA: (M, K) = (512, 256) mB: (N, K) = (768, 256) mC: (M, N) = (512, 768)
|
||||
|
||||
K=256 K=256 N=768
|
||||
|<--------->| |<--------->| |<----------------->|
|
||||
+-----------+ +-----------+ +---+---+---+-------+
|
||||
| | ^ | | ^ | | | | | ^
|
||||
| mA | | M=512 | mB | | N=768 | | | | | | M=512
|
||||
| | v | | v | | | | | v
|
||||
+-----------+ +-----------+ +---+---+---+-------+
|
||||
|
||||
Tiling with ``tile_shape_mnk = (BM, BN, BK) = (128, 256, 64)`` gives
|
||||
M/BM = 4 tiles, N/BN = 3 tiles, K/BK = 4 tiles:
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
mA tiled into (M/BM x K/BK) mB tiled into (N/BN x K/BK) mC tiled into (M/BM x N/BN)
|
||||
= (4 x 4) blocks = (3 x 4) blocks = (4 x 3) blocks
|
||||
|
||||
BK=64 x4 BK=64 x4 BN=256 x3
|
||||
|<--->| |<--->| |<------>|
|
||||
+-----+-----+-----+-----+ +-----+-----+-----+-----+ +--------+--------+--------+
|
||||
| | | | | ^ | | | | | ^ | (0,0) | (0,1) | (0,2) | ^
|
||||
| | | | | |128 | | | | | |256 | | | | |128
|
||||
+-----+-----+-----+-----+ v +-----+-----+-----+-----+ v +--------+--------+--------+ v
|
||||
| | | | | ^ | | | | | ^ | (1,0) | (1,1) | (1,2) | ^
|
||||
| | | | | |128 | | | | | |256 | | | | |128
|
||||
+-----+-----+-----+-----+ v +-----+-----+-----+-----+ v +--------+--------+--------+ v
|
||||
| | | | | | | | | | | (2,0) | (2,1) | (2,2) |
|
||||
+-----+-----+-----+-----+ +-----+-----+-----+-----+ +--------+--------+--------+
|
||||
| | | | | | (3,0) | (3,1) | (3,2) |
|
||||
+-----+-----+-----+-----+ +--------+--------+--------+
|
||||
|
||||
Each CTA picks one (M-tile, N-tile) coordinate.
|
||||
For example, CTA at ``tile_coord = (1, 0, :)``.
|
||||
|
||||
After ``local_tile`` — one CTA's tile (``k = K/BK = 256/64 = 4``):
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
gA: (BM, BK, k) = (128, 64, 4) gB: (BN, BK, k) = (256, 64, 4) gC: (BM, BN) = (128, 256)
|
||||
|
||||
BK=64 BK=64 BN=256
|
||||
|<----->| |<----->| |<--------->|
|
||||
+-------+-- +-------+-- +-----------+
|
||||
| |.. | |.. | | ^
|
||||
BM= | gA | k=4 BN= | gB | k=4 BM= | gC | | 128
|
||||
128 | | 256 | | 128 | | v
|
||||
+-------+ +-------+ +-----------+
|
||||
|
||||
SMEM tensors ``sA`` and ``sB`` include a pipeline staging dimension:
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
sA: (BM, BK, PIPE) = (128, 64, 4) sB: (BN, BK, PIPE) = (256, 64, 4)
|
||||
|
||||
``get_slice(warp_group_thread_layout(warp_group_idx))`` — each
|
||||
warpgroup receives its slice of the tiled MMA footprint.
|
||||
With ``atom_layout_mnk = (2, 1, 1)`` and inst shape ``(64, 256, 16)``,
|
||||
the tiled MMA covers ``(2x64, 1x256, 16) = (128, 256, 16)`` which
|
||||
exactly matches the CTA tile in M and N. Each warpgroup owns one
|
||||
64-row slice of M:
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
sA (one pipeline stage, BM=128, BK=64):
|
||||
|
||||
Warpgroup 0's slice Warpgroup 1's slice
|
||||
inst_K inst_K inst_K inst_K
|
||||
=16 =16 =16 =16
|
||||
|<--->|<--->|<--->|<--->| |<--->|<--->|<--->|<--->|
|
||||
+-----+-----+-----+-----+ ^ +-----+-----+-----+-----+ ^
|
||||
| 0 | 1 | 2 | 3 | |64 | 0 | 1 | 2 | 3 | |64
|
||||
+-----+-----+-----+-----+ v +-----+-----+-----+-----+ v
|
||||
|<-- MMA_K = BK/inst_K = 4 -->| |<-- MMA_K = 4 ---------->|
|
||||
MMA_M = 64/64 = 1 MMA_M = 64/64 = 1
|
||||
|
||||
gC (BM=128, BN=256):
|
||||
|
||||
+---------------------------+ ^
|
||||
| Warpgroup 0: 64 x 256 | | 64
|
||||
| | |
|
||||
+---------------------------+ v
|
||||
| Warpgroup 1: 64 x 256 | ^
|
||||
| | | 64
|
||||
+---------------------------+ v
|
||||
<--------- N = 256 -------->
|
||||
MMA_M = 64/64 = 1, MMA_N = 256/256 = 1
|
||||
|
||||
After partition (per warpgroup):
|
||||
|
||||
- ``tCsA: (MMA, MMA_M, MMA_K, PIPE) = (MMA, 1, 4, 4)`` — MMA_M = BM / (atom_M x inst_M) = 128 / (2x64) = 1, MMA_K = BK / inst_K = 64 / 16 = 4
|
||||
- ``tCsB: (MMA, MMA_N, MMA_K, PIPE) = (MMA, 1, 4, 4)`` — MMA_N = BN / (atom_N x inst_N) = 256 / (1x256) = 1, MMA_K = 4
|
||||
- ``tCgC: (MMA, MMA_M, MMA_N) = (MMA, 1, 1)`` — MMA_M = 1, MMA_N = 1
|
||||
|
||||
The first mode ``MMA`` contains the atom's **thread x value** layout — it
|
||||
encodes which registers within a warpgroup hold which matrix elements.
|
||||
The remaining modes are repeat counts that tile the atom across the
|
||||
full CTA tile.
|
||||
|
||||
.. note:: Because the WGMMA instruction shape is large (64 x {64..256}),
|
||||
the tiled MMA footprint typically covers the entire CTA tile in M and N
|
||||
with just one or two warpgroups. This means MMA_M and MMA_N are often 1.
|
||||
The MMA_K dimension is where the repeat count is non-trivial (BK / inst_K
|
||||
iterations per pipeline stage).
|
||||
|
||||
**1-warpgroup example (contrast)**
|
||||
|
||||
For a smaller tile ``(128, 128, 64)`` with ``atom_layout_mnk = (1, 1, 1)``,
|
||||
inst shape ``(64, 128, 16)``, and ``num_stages = 4``,
|
||||
the tiled MMA covers only ``(64, 128, 16)``.
|
||||
Now a single warpgroup must iterate over two atom-blocks along M:
|
||||
|
||||
- ``tCsA: (MMA, MMA_M, MMA_K, PIPE) = (MMA, 2, 4, 4)`` — MMA_M = 128 / (1x64) = 2
|
||||
- ``tCsB: (MMA, MMA_N, MMA_K, PIPE) = (MMA, 1, 4, 4)`` — MMA_N = 128 / (1x128) = 1
|
||||
- ``tCgC: (MMA, MMA_M, MMA_N) = (MMA, 2, 1)``
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Based on examples/cute/hopper/kernel/dense_gemm/dense_gemm.py
|
||||
@cute.kernel
|
||||
def kernel(tiled_mma: cute.TiledMma, ...):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
|
||||
# CTA-tiled global tensors
|
||||
gA_mkl = cute.local_tile(
|
||||
mA_mkl, tile_shape_mnk, tile_coord_mnkl, proj=(1, None, 1)
|
||||
)
|
||||
gB_nkl = cute.local_tile(
|
||||
mB_nkl, tile_shape_mnk, tile_coord_mnkl, proj=(None, 1, 1)
|
||||
)
|
||||
gC_mnl = cute.local_tile(
|
||||
mC_mnl, tile_shape_mnk, tile_coord_mnkl, proj=(1, 1, None)
|
||||
)
|
||||
|
||||
# Warpgroup-oriented slicing (128 threads per warpgroup)
|
||||
warp_group_idx = cute.arch.make_warp_uniform(
|
||||
tidx // num_threads_per_warp_group # 128
|
||||
)
|
||||
warp_group_thread_layout = cute.make_layout(
|
||||
mma_warp_groups, # e.g. 2
|
||||
stride=num_threads_per_warp_group, # 128
|
||||
)
|
||||
thr_mma = tiled_mma.get_slice(
|
||||
warp_group_thread_layout(warp_group_idx)
|
||||
)
|
||||
|
||||
# Partition C from global
|
||||
tCgC = thr_mma.partition_C(gC_mnl) # (MMA, MMA_M, MMA_N)
|
||||
|
||||
# Partition A/B from staged SMEM
|
||||
tCsA = thr_mma.partition_A(sA) # (MMA, MMA_M, MMA_K, PIPE)
|
||||
tCsB = thr_mma.partition_B(sB) # (MMA, MMA_N, MMA_K, PIPE)
|
||||
|
||||
|
||||
Pre and Post-Conditions for Partitioning
|
||||
-----------------------------------------
|
||||
|
||||
* The inputs of ``partition_A``, ``partition_B``, and ``partition_C`` should be
|
||||
at least rank-2 tensors.
|
||||
* The output layout is constrained by the selected MMA atom:
|
||||
|
||||
- For A, the output has layout ``(MMA, MMA_M, MMA_K, ...)``.
|
||||
- For B, the output has layout ``(MMA, MMA_N, MMA_K, ...)``.
|
||||
- For C, the output has layout ``(MMA, MMA_M, MMA_N, ...)``.
|
||||
|
||||
* Partitioning reasons about layout, not memory space or element type.
|
||||
When ``a_src=OperandSource.RMEM``, the same tiled MMA shape still
|
||||
determines the logical A footprint, but A is materialized as a register
|
||||
fragment rather than a shared-memory descriptor.
|
||||
|
||||
|
||||
Making Fragments
|
||||
-----------------
|
||||
|
||||
Fragments are the tensors that the WGMMA instruction operates on. For dense
|
||||
WGMMA:
|
||||
|
||||
- **Fragment A**: an SMEM descriptor when ``a_src=OperandSource.SMEM``, or an
|
||||
RMEM register fragment when ``a_src=OperandSource.RMEM``.
|
||||
- **Fragment B**: an SMEM descriptor pointing into staged shared memory buffers.
|
||||
- **Fragment C (accumulator)**: an RMEM tensor that serves as both the input C
|
||||
and output D of ``cute.gemm()``.
|
||||
|
||||
WGMMA fragments for A and B are **SMEM descriptors** — the hardware reads
|
||||
directly from shared memory. There is no explicit SMEM → RMEM copy step for
|
||||
operands A and B. The accumulator, however, still lives in per-thread
|
||||
registers (RMEM).
|
||||
|
||||
Creating fragment descriptors and accumulator fragments
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Fragment creation has two parts:
|
||||
|
||||
**1. A and B fragment descriptors**
|
||||
|
||||
``make_fragment_A`` and ``make_fragment_B`` take the MMA-partitioned SMEM
|
||||
views (``tCsA`` / ``tCsB``) and produce descriptor tensors that the WGMMA
|
||||
instruction consumes. Each descriptor points to one tile within a pipeline
|
||||
stage in shared memory.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# MMA-partitioned SMEM views (see "Partitioning Tensors")
|
||||
tCsA = thr_mma.partition_A(sA) # (MMA, MMA_M, MMA_K, PIPE)
|
||||
tCsB = thr_mma.partition_B(sB) # (MMA, MMA_N, MMA_K, PIPE)
|
||||
|
||||
# SMEM descriptor fragments consumed by cute.gemm()
|
||||
tCrA = tiled_mma.make_fragment_A(tCsA) # (MMA, MMA_M, MMA_K, PIPE)
|
||||
tCrB = tiled_mma.make_fragment_B(tCsB) # (MMA, MMA_N, MMA_K, PIPE)
|
||||
|
||||
Continuing the 2-warpgroup example from `Partitioning Tensors`_
|
||||
(F16 atom = (64, 256, 16), ``tile_shape_mnk = (128, 256, 64)``,
|
||||
``atom_layout_mnk = (2, 1, 1)``, ``num_stages = 4``):
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
tCsA: (MMA, MMA_M=1, MMA_K=4, PIPE=4)
|
||||
tCsB: (MMA, MMA_N=1, MMA_K=4, PIPE=4)
|
||||
|
||||
make_fragment_A(tCsA) -> tCrA: (MMA, 1, 4, 4)
|
||||
make_fragment_B(tCsB) -> tCrB: (MMA, 1, 4, 4)
|
||||
|
||||
Each element of tCrA/tCrB is an SMEM descriptor — one per
|
||||
(MMA_K, PIPE) pair. The hardware reads SMEM directly via the
|
||||
descriptor; no explicit SMEM -> RMEM load is needed.
|
||||
|
||||
tCrA per warpgroup (4 pipeline stages, 4 K-blocks each):
|
||||
|
||||
|<-- MMA_K = BK/inst_K = 4 -->|
|
||||
stage 0: +------+------+------+------+
|
||||
| k=0 | k=1 | k=2 | k=3 | inst_M=64 (MMA_M=1)
|
||||
+------+------+------+------+
|
||||
stage 1: +------+------+------+------+
|
||||
| k=0 | k=1 | k=2 | k=3 | inst_M=64
|
||||
+------+------+------+------+
|
||||
stage 2: +------+------+------+------+
|
||||
| k=0 | k=1 | k=2 | k=3 | inst_M=64
|
||||
+------+------+------+------+
|
||||
stage 3: +------+------+------+------+
|
||||
| k=0 | k=1 | k=2 | k=3 | inst_M=64
|
||||
+------+------+------+------+
|
||||
|
||||
Similarly for tCrB with shape (MMA, MMA_N=1, MMA_K=4, PIPE=4).
|
||||
|
||||
.. note:: WGMMA fragments for A and B are SMEM descriptors — the hardware
|
||||
reads SMEM directly, so there is no ``ldmatrix`` retiling step required
|
||||
before ``cute.gemm()``.
|
||||
|
||||
**When A comes from registers (``OperandSource.RMEM``)**
|
||||
|
||||
In fused kernels, the output of one MMA can become the A operand of the
|
||||
next. The second ``TiledMma`` is created with
|
||||
``a_src=OperandSource.RMEM``, and ``make_fragment_A`` is **not** used.
|
||||
Instead:
|
||||
|
||||
1. The accumulator's C layout ``(MMA, MMA_M, MMA_N)`` is converted to the
|
||||
A layout ``(MMA, MMA_M, MMA_K)`` expected by the second ``TiledMma``.
|
||||
2. The accumulator values are type-converted and stored into an RMEM tensor
|
||||
with the A layout.
|
||||
3. The resulting RMEM tensor is passed directly to ``cute.gemm()`` as the A
|
||||
operand — no SMEM descriptor is involved.
|
||||
|
||||
See the Hopper FMHA example (``examples/cute/hopper/kernel/attention/fmha.py``) for the complete pattern.
|
||||
|
||||
**2. C fragment (accumulator)**
|
||||
|
||||
The accumulator lives in per-thread registers (RMEM). Its shape is derived
|
||||
from the partitioned C layout. The accumulator starts at zero before the K
|
||||
loop and is updated in-place by each ``cute.gemm()`` call.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Partition C from global (see "Partitioning Tensors")
|
||||
tCgC = thr_mma.partition_C(gC_mnl) # (MMA, MMA_M, MMA_N)
|
||||
|
||||
# Allocate RMEM accumulator with the same shape
|
||||
acc_shape = tCgC.shape
|
||||
acc = cute.make_rmem_tensor(acc_shape, cutlass.Float32)
|
||||
acc.fill(0.0)
|
||||
|
||||
For the same running example:
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
tCgC: (MMA, MMA_M=1, MMA_N=1)
|
||||
|
||||
make_rmem_tensor(tCgC.shape, Float32) -> acc: (MMA, 1, 1)
|
||||
|
||||
The accumulator stays in RMEM for the entire main loop.
|
||||
cute.gemm() reads A/B from SMEM descriptors and accumulates into acc.
|
||||
|
||||
+-----------------------------------+
|
||||
| acc: (MMA, 1, 1) in RMEM |
|
||||
| 64 x 256 elements per warpgroup |
|
||||
| Float32 |
|
||||
+-----------------------------------+
|
||||
|
||||
|
||||
Creating SMEM layouts for A and B
|
||||
----------------------------------
|
||||
|
||||
The SMEM layouts define how A and B tiles are staged in shared memory,
|
||||
including swizzling for bank-conflict-free descriptor access. The helper
|
||||
functions in ``cutlass.utils.hopper_helpers`` handle the details.
|
||||
|
||||
**Host side** (``@cute.jit``):
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import cutlass.utils.hopper_helpers as sm90_utils
|
||||
|
||||
# Create SMEM layouts (includes swizzle + staging)
|
||||
a_smem_layout = sm90_utils.make_smem_layout_a(
|
||||
a_layout, # LayoutEnum — row-major or col-major
|
||||
tile_shape_mnk, # CTA tile (M, N, K)
|
||||
a_dtype, # element type (e.g. Float16)
|
||||
num_stages, # pipeline depth
|
||||
)
|
||||
b_smem_layout = sm90_utils.make_smem_layout_b(
|
||||
b_layout,
|
||||
tile_shape_mnk,
|
||||
b_dtype,
|
||||
num_stages,
|
||||
)
|
||||
epi_smem_layout = sm90_utils.make_smem_layout_epi(
|
||||
c_dtype,
|
||||
c_layout,
|
||||
epi_tile,
|
||||
epi_stage,
|
||||
)
|
||||
|
||||
``make_smem_layout_a`` and ``make_smem_layout_b`` are convenience helpers that
|
||||
build a complete, staged SMEM layout in four steps:
|
||||
|
||||
1. **Extract the operand tile shape.** For A the ``(M, K)`` portion of
|
||||
``tile_shape_mnk`` is kept via ``cute.slice_``; for B the ``(N, K)``
|
||||
portion.
|
||||
|
||||
2. **Determine the major mode.** The major mode (K-major or MN-major) is read
|
||||
from the layout enum (``a_layout.is_k_major_a()``). The major-mode
|
||||
dimension size is used for swizzle selection.
|
||||
|
||||
3. **Select and materialise the swizzle atom.** A heuristic
|
||||
(``get_smem_layout_atom``) picks the widest swizzle whose contiguous
|
||||
size (in bits) evenly divides the major-mode dimension:
|
||||
|
||||
+------------+-----------------+
|
||||
| Swizzle | Contiguous bits |
|
||||
+============+=================+
|
||||
| SW128 | 1024 (128 B) |
|
||||
+------------+-----------------+
|
||||
| SW64 | 512 (64 B) |
|
||||
+------------+-----------------+
|
||||
| SW32 | 256 (32 B) |
|
||||
+------------+-----------------+
|
||||
| Interleave | 128 (16 B) |
|
||||
+------------+-----------------+
|
||||
|
||||
``make_smem_layout_atom`` then combines the chosen swizzle with a compact
|
||||
outer layout into a ``ComposedLayout(swizzle, outer)``.
|
||||
|
||||
4. **Tile to the operand shape and append the staging dimension.**
|
||||
``cute.tile_to_shape`` broadcasts the atom to the full ``(M_or_N, K)``
|
||||
shape with ``num_stages`` appended. The ``order`` argument controls which
|
||||
dimension is contiguous: ``(0, 1, 2)`` for K-major (K innermost),
|
||||
``(1, 0, 2)`` for MN-major (MN innermost).
|
||||
|
||||
For the running F16 example (``tile_shape_mnk = (128, 256, 64)``,
|
||||
``num_stages = 4``, K-major A, K-major B):
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
A operand (K-major, tile = (M=128, K=64)):
|
||||
major_mode_size = 64
|
||||
64 * 16 bits = 1024 bits → SW128
|
||||
atom = make_smem_layout_atom(K_SW128, Float16)
|
||||
tile_to_shape(atom, (128, 64, 4), order=(0,1,2))
|
||||
-> a_smem_layout: ComposedLayout with shape (128, 64, 4)
|
||||
|
||||
B operand (K-major, tile = (N=256, K=64)):
|
||||
major_mode_size = 64
|
||||
64 * 16 bits = 1024 bits → SW128
|
||||
atom = make_smem_layout_atom(K_SW128, Float16)
|
||||
tile_to_shape(atom, (256, 64, 4), order=(0,1,2))
|
||||
-> b_smem_layout: ComposedLayout with shape (256, 64, 4)
|
||||
|
||||
**Kernel side** (``@cute.kernel``):
|
||||
|
||||
The layout and swizzle are passed to shared-memory allocation. The result
|
||||
is a ``ComposedLayout`` whose ``.outer`` is the logical layout and ``.inner``
|
||||
is the swizzle:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Based on examples/cute/hopper/kernel/dense_gemm/dense_gemm.py
|
||||
sA = storage.sA.get_tensor(
|
||||
a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner
|
||||
)
|
||||
sB = storage.sB.get_tensor(
|
||||
b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner
|
||||
)
|
||||
|
||||
After allocation:
|
||||
|
||||
- ``sA`` has shape ``(BM, BK, PIPE) = (128, 64, 4)``.
|
||||
- ``sB`` has shape ``(BN, BK, PIPE) = (256, 64, 4)``.
|
||||
|
||||
These are the staged SMEM tensors consumed by ``partition_A`` /
|
||||
``partition_B`` and ``make_fragment_A`` / ``make_fragment_B``
|
||||
(see `Making Fragments`_).
|
||||
|
||||
.. note:: If you need finer control, you can build layout atoms directly with
|
||||
``cute.nvgpu.warpgroup.make_smem_layout_atom(...)`` and compose the final
|
||||
SMEM layout manually via ``cute.tile_to_shape``.
|
||||
|
||||
|
||||
Executing the GEMM (Main Loop)
|
||||
-------------------------------
|
||||
|
||||
The main loop iterates over K-tiles. The WGMMA-specific part of each
|
||||
iteration is the **fence / gemm / commit / wait** sequence:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
acc.fill(0.0)
|
||||
tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True)
|
||||
|
||||
for k_tile in cutlass.range(k_pipe_mmas, k_tile_cnt, 1, unroll=1):
|
||||
# ... wait for TMA load (pipeline details in dense_gemm.py) ...
|
||||
|
||||
cute.nvgpu.warpgroup.fence()
|
||||
tile_crd = (None, None, None, consumer_read.index)
|
||||
cute.gemm(tiled_mma, acc, tCrA[tile_crd], tCrB[tile_crd], acc)
|
||||
cute.nvgpu.warpgroup.commit_group()
|
||||
cute.nvgpu.warpgroup.wait_group(k_pipe_mmas)
|
||||
|
||||
# ... release buffer & advance pipeline (see dense_gemm.py) ...
|
||||
|
||||
cute.nvgpu.warpgroup.wait_group(0)
|
||||
|
||||
Key points:
|
||||
|
||||
- ``fence()`` orders prior SMEM writes before WGMMA issue.
|
||||
- ``commit_group()`` publishes queued WGMMA instructions as a group.
|
||||
- ``wait_group(n)`` waits until at most ``n`` groups remain in flight.
|
||||
``wait_group(0)`` after the loop drains all work before the epilogue.
|
||||
- ``Field.ACCUMULATE`` — ``True`` accumulates (``D += A*B``),
|
||||
``False`` overwrites (``D = A*B``). The dense GEMM sets ``True`` and
|
||||
zero-fills ``acc`` so the first iteration computes ``0 + A*B``.
|
||||
|
||||
|
||||
Complete Workflow
|
||||
------------------
|
||||
|
||||
Putting it all together, a typical Hopper WGMMA GEMM has this structure.
|
||||
The MMA-relevant steps are highlighted; see ``dense_gemm.py`` for the full
|
||||
kernel including TMA, pipeline, and epilogue details.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cute.nvgpu import OperandMajorMode
|
||||
import cutlass.cute.nvgpu.warpgroup as warpgroup
|
||||
import cutlass.utils.hopper_helpers as sm90_utils
|
||||
|
||||
# --- Host side (@cute.jit) ---
|
||||
|
||||
# 1. MMA op + tiled MMA
|
||||
op = warpgroup.MmaF16BF16Op(
|
||||
cutlass.Float16, cutlass.Float32, (64, 128, 16),
|
||||
warpgroup.OperandSource.SMEM, OperandMajorMode.K, OperandMajorMode.K,
|
||||
)
|
||||
tiled_mma = cute.make_tiled_mma(op)
|
||||
|
||||
# 2. SMEM layouts
|
||||
a_smem_layout = sm90_utils.make_smem_layout_a(a_layout, tile_shape_mnk, a_dtype, num_stages)
|
||||
b_smem_layout = sm90_utils.make_smem_layout_b(b_layout, tile_shape_mnk, b_dtype, num_stages)
|
||||
|
||||
# 3. TMA copy atoms + kernel launch (see dense_gemm.py)
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# --- Kernel side (@cute.kernel) ---
|
||||
|
||||
# 4. Allocate SMEM
|
||||
smem = cutlass.utils.SmemAllocator()
|
||||
storage = smem.allocate(SharedStorage)
|
||||
sA = storage.sA.get_tensor(
|
||||
a_smem_layout.outer, swizzle=a_smem_layout.inner) # (BM, BK, PIPE)
|
||||
sB = storage.sB.get_tensor(
|
||||
b_smem_layout.outer, swizzle=b_smem_layout.inner) # (BN, BK, PIPE)
|
||||
|
||||
# 5. CTA-tiled global tensors
|
||||
gA_mkl = cute.local_tile(mA_mkl, tile_shape_mnk, tile_coord, proj=(1, None, 1))
|
||||
gB_nkl = cute.local_tile(mB_nkl, tile_shape_mnk, tile_coord, proj=(None, 1, 1))
|
||||
gC_mnl = cute.local_tile(mC_mnl, tile_shape_mnk, tile_coord, proj=(1, 1, None))
|
||||
|
||||
# 6. Warpgroup slice, partition & make fragments
|
||||
warp_group_idx = cute.arch.make_warp_uniform(tidx // num_threads_per_warp_group)
|
||||
warp_group_thread_layout = cute.make_layout(mma_warp_groups, stride=num_threads_per_warp_group)
|
||||
thr_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx))
|
||||
|
||||
tCsA = thr_mma.partition_A(sA) # (MMA, MMA_M, MMA_K, PIPE)
|
||||
tCsB = thr_mma.partition_B(sB) # (MMA, MMA_N, MMA_K, PIPE)
|
||||
tCrA = tiled_mma.make_fragment_A(tCsA) # SMEM descriptor
|
||||
tCrB = tiled_mma.make_fragment_B(tCsB) # SMEM descriptor
|
||||
tCgC = thr_mma.partition_C(gC_mnl) # (MMA, MMA_M, MMA_N)
|
||||
acc = cute.make_rmem_tensor(tCgC.shape, acc_dtype)
|
||||
|
||||
# 7. TMA pipeline setup + prefetch (see dense_gemm.py)
|
||||
|
||||
# 8. Main loop — fence / gemm / commit / wait
|
||||
acc.fill(0.0)
|
||||
tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True)
|
||||
|
||||
for k_tile in cutlass.range(k_pipe_mmas, k_tile_cnt, 1, unroll=1):
|
||||
# ... wait for TMA load ...
|
||||
cute.nvgpu.warpgroup.fence()
|
||||
tile_crd = (None, None, None, consumer_read.index)
|
||||
cute.gemm(tiled_mma, acc, tCrA[tile_crd], tCrB[tile_crd], acc)
|
||||
cute.nvgpu.warpgroup.commit_group()
|
||||
cute.nvgpu.warpgroup.wait_group(k_pipe_mmas)
|
||||
# ... release buffer, advance pipeline ...
|
||||
|
||||
cute.nvgpu.warpgroup.wait_group(0)
|
||||
|
||||
# 9. Epilogue: RMEM → SMEM (stmatrix) → GMEM (TMA store)
|
||||
# ... (see dense_gemm.py)
|
||||
|
||||
|
||||
.. Beyond Simple Dense MMAs
|
||||
.. ------------------------
|
||||
|
||||
.. The current Python DSL coverage for warpgroup MMA is centered on the three
|
||||
.. dense ops above. PTX also defines additional WGMMA instruction families that
|
||||
.. do **not** yet have DSL op classes. These are tracked in the source at
|
||||
.. ``cutlass/cute/nvgpu/warpgroup/mma.py`` (marked ``✗`` in the instruction
|
||||
.. table).
|
||||
|
||||
.. **Structured-sparse WGMMA** (``wgmma.mma_async.sp``)
|
||||
|
||||
.. 2:4 structured sparsity in operand A: out of every 4 consecutive K-elements,
|
||||
.. exactly 2 are non-zero. The instruction K is **doubled** relative to the
|
||||
.. dense counterpart (e.g. ``m64nNk32`` for F16/BF16 vs ``m64nNk16`` dense)
|
||||
.. because A is stored in compressed form. Supported data types include
|
||||
.. F16/BF16, TF32, FP8, and INT8.
|
||||
|
||||
.. Compared to the dense workflow, a sparse kernel would add:
|
||||
|
||||
.. - A **compressed A tensor** storing only the non-zero values (half the
|
||||
.. K-elements), and a **metadata tensor E** encoding which 2 of 4 positions
|
||||
.. are non-zero.
|
||||
.. - Extra SMEM layouts, TMA loads, and allocations for both the compressed A
|
||||
.. and the metadata E.
|
||||
.. - A metadata staging step each K-tile (SMEM to the MMA instruction).
|
||||
|
||||
.. Once DSL support is added, the same fence/commit/wait workflow described in
|
||||
.. this guide applies, with the additional metadata operand.
|
||||
|
||||
.. **Dense TF32 WGMMA** (``m64nNk8``)
|
||||
|
||||
.. TF32 (19-bit truncated FP32) inputs with FP32 accumulator. The instruction
|
||||
.. K = 8 is smaller than F16's K = 16, so MMA_K repeat counts are larger for
|
||||
.. the same BK tile size. Otherwise the workflow is identical to the dense
|
||||
.. F16/BF16 path — the same SMEM layout, descriptor, and fence/commit/wait
|
||||
.. pattern applies.
|
||||
|
||||
.. **Dense B1 WGMMA** (``m64nNk256``)
|
||||
|
||||
.. 1-bit (binary) inputs with INT32 accumulator. The very large instruction
|
||||
.. K = 256 means each atom consumes 256 bits along K per operand, resulting in
|
||||
.. small MMA_K repeat counts. This is a niche instruction for binary neural
|
||||
.. networks.
|
||||
|
||||
|
||||
See also:
|
||||
|
||||
- Dense GEMM example: ``examples/cute/hopper/kernel/dense_gemm/dense_gemm.py``
|
||||
- Persistent GEMM example: ``examples/cute/hopper/kernel/dense_gemm/dense_gemm_persistent.py``
|
||||
- FMHA example (RMEM A path): ``examples/cute/hopper/kernel/attention/fmha.py``
|
||||
- Helper utilities: ``cutlass.utils.hopper_helpers``
|
||||
1358
media/docs/pythonDSL/mma_docs/wmma_programming.rst
Normal file
1358
media/docs/pythonDSL/mma_docs/wmma_programming.rst
Normal file
File diff suppressed because it is too large
Load Diff
@@ -720,7 +720,9 @@ class DSLPreprocessor(ast.NodeTransformer):
|
||||
offset = len(all_args) - len(func_ast.args.defaults)
|
||||
for i, default_node in enumerate(func_ast.args.defaults):
|
||||
ast_defaults[all_args[offset + i].arg] = default_node
|
||||
for kwarg, kw_default in zip(func_ast.args.kwonlyargs, func_ast.args.kw_defaults):
|
||||
for kwarg, kw_default in zip(
|
||||
func_ast.args.kwonlyargs, func_ast.args.kw_defaults
|
||||
):
|
||||
if kw_default is not None:
|
||||
ast_defaults[kwarg.arg] = kw_default
|
||||
for param_name, default_val in params_with_defaults.items():
|
||||
|
||||
@@ -1865,7 +1865,7 @@ class BaseDSL(metaclass=DSLSingletonMeta):
|
||||
sources = set(x.value for x in link_libraries_attributes)
|
||||
link_libraries = (
|
||||
link_libraries
|
||||
+ ("," if len(link_libraries) > 0 else "")
|
||||
+ ("," if link_libraries and len(sources) > 0 else "")
|
||||
+ ",".join(sources)
|
||||
)
|
||||
self.compile_options.options[LinkLibraries] = LinkLibraries(
|
||||
|
||||
@@ -88,6 +88,11 @@ def _get_gpu_arch_info(major: int, minor: int) -> tuple[str, str, list[str]]:
|
||||
"sm_120a",
|
||||
["sm_120a"],
|
||||
), # RTX PRO 6000 / RTX 50 Series
|
||||
(12, 1): (
|
||||
"Blackwell",
|
||||
"sm_121a",
|
||||
["sm_121a"],
|
||||
), # DGX Spark
|
||||
}
|
||||
return gpu_arch_map.get(
|
||||
(major, minor), ("Unknown", f"sm_{major}{minor}", [f"sm_{major}{minor}"])
|
||||
|
||||
@@ -3330,7 +3330,12 @@ def filter_zeros(
|
||||
if not isinstance(input, (Layout, Tensor)):
|
||||
raise TypeError(f"Expected layout or tensor as input, but got {type(input)=}")
|
||||
if isinstance(input, Tensor):
|
||||
input = input.value
|
||||
return _op_wrapper(
|
||||
partial(_cute_ir.filter_zeros, target_profile=target_profile),
|
||||
input,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
return _cute_ir.filter_zeros(input, target_profile=target_profile, loc=loc, ip=ip)
|
||||
|
||||
|
||||
@@ -3388,7 +3393,7 @@ def filter(
|
||||
input.inner, input.offset, filter(input.outer, loc=loc, ip=ip)
|
||||
)
|
||||
elif isinstance(input, _Tensor):
|
||||
return _cute_ir.filter(input.value, loc=loc, ip=ip)
|
||||
return _op_wrapper(_cute_ir.filter, input, loc=loc, ip=ip)
|
||||
else:
|
||||
return _cute_ir.filter(input, loc=loc, ip=ip)
|
||||
|
||||
@@ -5020,10 +5025,9 @@ def local_partition(
|
||||
raise NotImplementedError(
|
||||
f"Index value should be 32-bit or smaller integer type, but got {index_val.type}"
|
||||
)
|
||||
return _cute_ir.local_partition(
|
||||
input=target.value,
|
||||
tiler=dice(tiler, proj),
|
||||
index=index_val,
|
||||
return _op_wrapper(
|
||||
partial(_cute_ir.local_partition, tiler=dice(tiler, proj), index=index_val),
|
||||
target,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
@@ -5114,11 +5118,9 @@ def local_tile(
|
||||
proj_val = _pack_coord(proj, loc=loc, ip=ip)
|
||||
proj = proj_val.type.attribute
|
||||
|
||||
return _cute_ir.local_tile(
|
||||
input=input.value,
|
||||
tile=tiler_val,
|
||||
coord=coord_val,
|
||||
proj=proj,
|
||||
return _op_wrapper(
|
||||
partial(_cute_ir.local_tile, tile=tiler_val, coord=coord_val, proj=proj),
|
||||
input,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
@@ -21,6 +21,9 @@ __all__ = [
|
||||
"MmaFP8Op",
|
||||
"MmaMXF4Op",
|
||||
"MmaMXF4NVF4Op",
|
||||
"MmaMXF8Op",
|
||||
"MmaMXF8F6F4Op",
|
||||
"MXF8F6F4_SUPPORTED_PAIRS",
|
||||
# copy.py
|
||||
"LdMatrix8x8x16bOp",
|
||||
"LdMatrix16x8x8bOp",
|
||||
|
||||
@@ -224,7 +224,9 @@ class MmaSM120BlockScaledOp(MmaOp):
|
||||
|
||||
admissible_archs = [
|
||||
Arch.sm_120a,
|
||||
Arch.sm_120f,
|
||||
Arch.sm_121a,
|
||||
Arch.sm_121f,
|
||||
]
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
@@ -239,29 +241,44 @@ class MmaSM120BlockScaledOp(MmaOp):
|
||||
"CUTE_DSL_ARCH set to sm_120a or sm_121a",
|
||||
suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
|
||||
)
|
||||
if self.ab_dtype != Float4E2M1FN:
|
||||
# (ab_dtype, shape_mnk) consistency: FP4 uses (16,8,64); FP8 uses (16,8,32).
|
||||
if self.ab_dtype == Float4E2M1FN:
|
||||
if self.shape_mnk != (16, 8, 64):
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'shape_mnk' Op parameter to be (16,8,64) for Float4E2M1FN",
|
||||
)
|
||||
elif self.ab_dtype in (Float8E4M3FN, Float8E5M2):
|
||||
if self.shape_mnk != (16, 8, 32):
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'shape_mnk' Op parameter to be (16,8,32) for Float8E4M3FN/Float8E5M2",
|
||||
)
|
||||
else:
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'ab_dtype' Op parameter to be Float4E2M1FN",
|
||||
"expects the 'ab_dtype' Op parameter to be Float4E2M1FN, Float8E4M3FN, or Float8E5M2",
|
||||
)
|
||||
if self.acc_dtype != Float32:
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'acc_dtype' Op parameter to be Float32",
|
||||
)
|
||||
if self.shape_mnk != (16, 8, 64):
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'shape_mnk' Op parameter to be (16,8,64)",
|
||||
)
|
||||
|
||||
if self.sf_vec_size == 16:
|
||||
# vec_size=16 is only valid for FP4 (NVFP4) with E4M3 scale.
|
||||
if self.ab_dtype != Float4E2M1FN:
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'sf_vec_size' Op parameter to be 32 for Float8E4M3FN/Float8E5M2",
|
||||
)
|
||||
if self.sf_type != Float8E4M3FN:
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'sf_type' Op parameter to be Float8E4M3FN",
|
||||
)
|
||||
elif self.sf_vec_size == 32:
|
||||
# vec_size=32 path uses UE8M0 scale for both FP4 (MXF4) and FP8 (MXF8).
|
||||
if self.sf_type != Float8E8M0FNU:
|
||||
raise OpError(
|
||||
self,
|
||||
@@ -275,7 +292,7 @@ class MmaSM120BlockScaledOp(MmaOp):
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
"warp-level MXF4/MXF4NVF4 MMA Operation"
|
||||
"warp-level MXF4/MXF4NVF4/MXF8 MMA Operation"
|
||||
+ f"\n A/B data type = {self.ab_dtype}"
|
||||
+ f"\n Accumulator data type = {self.acc_dtype}"
|
||||
+ f"\n Instruction shape MNK = {self.shape_mnk}"
|
||||
@@ -474,3 +491,214 @@ class MmaMXF4NVF4Op(MmaSM120BlockScaledOp):
|
||||
|
||||
class MmaMXF4NVF4Trait(MmaBlockScaledTrait):
|
||||
pass
|
||||
|
||||
|
||||
#
|
||||
# MXF8 MMA
|
||||
#
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MmaMXF8Op(MmaSM120BlockScaledOp):
|
||||
"""
|
||||
MXF8 warp-level MMA Operation.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-mma>`__.
|
||||
This Operation covers the instructions using the ``.e4m3`` / ``.e5m2`` qualifiers for the input operands.
|
||||
.kind = {.kind::mxf8};
|
||||
.scale_vec_size = {.scale_vec::1X};
|
||||
.stype = {.ue8m0};
|
||||
"""
|
||||
|
||||
descriptive_name = "warp-level MXF8 MMA Operation"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ab_dtype: Type[Numeric],
|
||||
acc_dtype: Type[Numeric],
|
||||
sf_type: Type[Numeric],
|
||||
) -> None:
|
||||
super().__init__(
|
||||
ab_dtype,
|
||||
acc_dtype,
|
||||
(16, 8, 32),
|
||||
sf_type,
|
||||
32,
|
||||
)
|
||||
|
||||
def _make_trait(
|
||||
self,
|
||||
*,
|
||||
loc: Optional[ir.Location] = None,
|
||||
ip: Optional[ir.InsertionPoint] = None,
|
||||
**kwargs: Any,
|
||||
) -> "MmaMXF8Trait":
|
||||
shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
|
||||
ty = _cute_nvgpu_ir.MmaAtomSM120BlockScaledType.get(
|
||||
shape_mnk.type.attribute,
|
||||
32,
|
||||
False,
|
||||
self.ab_dtype.mlir_type,
|
||||
self.ab_dtype.mlir_type,
|
||||
self.acc_dtype.mlir_type,
|
||||
self.sf_type.mlir_type,
|
||||
)
|
||||
return MmaMXF8Trait(make_atom(ty, loc=loc, ip=ip))
|
||||
|
||||
|
||||
class MmaMXF8Trait(MmaBlockScaledTrait):
|
||||
pass
|
||||
|
||||
|
||||
#
|
||||
# MXF8F6F4 mixed-precision MMA (independent A/B dtypes)
|
||||
#
|
||||
|
||||
|
||||
MXF8F6F4_SUPPORTED_PAIRS = frozenset(
|
||||
{
|
||||
(Float4E2M1FN, Float8E4M3FN),
|
||||
(Float4E2M1FN, Float8E5M2),
|
||||
(Float8E4M3FN, Float4E2M1FN),
|
||||
(Float8E5M2, Float4E2M1FN),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MmaMXF8F6F4Op(MmaOp):
|
||||
"""
|
||||
SM120 MXF8F6F4 mixed-precision warp-level block-scaled MMA Operation.
|
||||
|
||||
Covers the PTX instructions using independent ``.<a_type>.<b_type>``
|
||||
qualifiers (one of e2m1.e4m3, e2m1.e5m2, e4m3.e2m1, e5m2.e2m1):
|
||||
|
||||
.kind = {.kind::mxf8f6f4};
|
||||
.scale_vec_size = {.scale_vec::1X};
|
||||
.stype = {.ue8m0};
|
||||
|
||||
A and B operand dtypes are independent. Same-dtype FP4/FP4 and FP8/FP8
|
||||
paths remain on ``MmaMXF4Op`` / ``MmaMXF4NVF4Op`` / ``MmaMXF8Op``
|
||||
respectively. Same-width mixed-FP8 (E4M3 + E5M2) and FP6 mixed pairs
|
||||
are not supported.
|
||||
"""
|
||||
|
||||
a_dtype: Type[Numeric]
|
||||
b_dtype: Type[Numeric]
|
||||
acc_dtype: Type[Numeric]
|
||||
sf_type: Type[Numeric]
|
||||
|
||||
descriptive_name = "warp-level MXF8F6F4 mixed-precision MMA Operation"
|
||||
|
||||
shape_mnk = (16, 8, 32)
|
||||
sf_vec_size = 32
|
||||
use_sf_layout_TV = False
|
||||
|
||||
admissible_archs = [
|
||||
Arch.sm_120a,
|
||||
Arch.sm_121a,
|
||||
]
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Verify arch
|
||||
arch = BaseDSL._get_dsl().get_arch_enum()
|
||||
if arch not in self.admissible_archs:
|
||||
raise OpError(
|
||||
self,
|
||||
f"expects arch to be one of {self.admissible_archs}, but got {arch}",
|
||||
suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
|
||||
)
|
||||
if self.acc_dtype != Float32:
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'acc_dtype' Op parameter to be Float32",
|
||||
)
|
||||
if self.sf_type != Float8E8M0FNU:
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'sf_type' Op parameter to be Float8E8M0FNU",
|
||||
)
|
||||
# Reject same-dtype pairs explicitly (route to dedicated ops).
|
||||
if self.a_dtype == self.b_dtype:
|
||||
if self.a_dtype == Float4E2M1FN:
|
||||
raise OpError(
|
||||
self,
|
||||
"same-dtype Float4E2M1FN/Float4E2M1FN is not supported by MmaMXF8F6F4Op; "
|
||||
"use MmaMXF4Op (sf_vec_size=32) or MmaMXF4NVF4Op (sf_vec_size=16) instead",
|
||||
)
|
||||
if self.a_dtype in (Float8E4M3FN, Float8E5M2):
|
||||
raise OpError(
|
||||
self,
|
||||
"same-dtype FP8/FP8 is not supported by MmaMXF8F6F4Op; "
|
||||
"use MmaMXF8Op instead",
|
||||
)
|
||||
# Reject same-width mixed-FP8 (E4M3 + E5M2) explicitly.
|
||||
fp8_dtypes = (Float8E4M3FN, Float8E5M2)
|
||||
if self.a_dtype in fp8_dtypes and self.b_dtype in fp8_dtypes:
|
||||
raise OpError(
|
||||
self,
|
||||
"same-width mixed-FP8 (Float8E4M3FN + Float8E5M2) is not supported; "
|
||||
"supported MXF8F6F4 pairs are (Float4E2M1FN x Float8E4M3FN/Float8E5M2) "
|
||||
"and the reverse",
|
||||
)
|
||||
# Final allow-list check (catches FP6 and any other unsupported dtype).
|
||||
if (self.a_dtype, self.b_dtype) not in MXF8F6F4_SUPPORTED_PAIRS:
|
||||
raise OpError(
|
||||
self,
|
||||
f"unsupported (a_dtype, b_dtype) = ({self.a_dtype}, {self.b_dtype}) "
|
||||
f"for MmaMXF8F6F4Op; supported pairs are "
|
||||
f"{sorted(repr(p) for p in MXF8F6F4_SUPPORTED_PAIRS)}. "
|
||||
f"FP6 mixed pairs are not supported.",
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
"warp-level MXF8F6F4 mixed-precision MMA Operation"
|
||||
+ f"\n A data type = {self.a_dtype}"
|
||||
+ f"\n B data type = {self.b_dtype}"
|
||||
+ f"\n Accumulator data type = {self.acc_dtype}"
|
||||
+ f"\n Instruction shape MNK = {self.shape_mnk}"
|
||||
+ f"\n Vector size = {self.sf_vec_size}"
|
||||
+ f"\n SF data type = {self.sf_type}"
|
||||
)
|
||||
|
||||
def _verify_fragment_A(
|
||||
self,
|
||||
input: Tensor,
|
||||
*,
|
||||
loc: Optional[ir.Location] = None,
|
||||
ip: Optional[ir.InsertionPoint] = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def _verify_fragment_B(
|
||||
self,
|
||||
input: Tensor,
|
||||
*,
|
||||
loc: Optional[ir.Location] = None,
|
||||
ip: Optional[ir.InsertionPoint] = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def _make_trait(
|
||||
self,
|
||||
*,
|
||||
loc: Optional[ir.Location] = None,
|
||||
ip: Optional[ir.InsertionPoint] = None,
|
||||
**kwargs: Any,
|
||||
) -> "MmaMXF8F6F4Trait":
|
||||
shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
|
||||
ty = _cute_nvgpu_ir.MmaAtomSM120BlockScaledType.get(
|
||||
shape_mnk.type.attribute,
|
||||
self.sf_vec_size,
|
||||
self.use_sf_layout_TV,
|
||||
self.a_dtype.mlir_type,
|
||||
self.b_dtype.mlir_type,
|
||||
self.acc_dtype.mlir_type,
|
||||
self.sf_type.mlir_type,
|
||||
)
|
||||
return MmaMXF8F6F4Trait(make_atom(ty, loc=loc, ip=ip))
|
||||
|
||||
|
||||
class MmaMXF8F6F4Trait(MmaBlockScaledTrait):
|
||||
pass
|
||||
|
||||
@@ -21,7 +21,11 @@ from jax._src.interpreters import batching
|
||||
|
||||
|
||||
from .compile import get_or_compile_kernel, build_function_spec
|
||||
from .types import cutlass_to_jax_layout_order, default_tensor_spec, TensorSpec
|
||||
from .types import (
|
||||
cutlass_to_jax_layout_order,
|
||||
default_tensor_spec,
|
||||
TensorSpec,
|
||||
)
|
||||
from .ffi import get_cutlass_call_ffi_name, is_ffi_registered, register_ffi
|
||||
|
||||
|
||||
@@ -77,8 +81,10 @@ def cutlass_call(
|
||||
objects with ``.shape`` and ``.dtype`` attributes) describing each
|
||||
output buffer.
|
||||
input_spec: A :class:`TensorSpec` or list thereof providing
|
||||
layout/mode/divisibility hints for input tensors. ``None`` infers
|
||||
defaults from each array.
|
||||
layout/mode/divisibility hints for input tensors. ``None`` infers
|
||||
defaults from each array. A ``TensorSpec`` with ``layout=None`` uses
|
||||
and constrains row-major physical layout; use ``mode`` to remap
|
||||
physical dimensions to the kernel's logical modes.
|
||||
output_spec: Same as *input_spec* but applied to output tensors.
|
||||
input_output_aliases: ``{input_index: output_index}`` mapping that
|
||||
allows an input buffer to alias an output, avoiding an extra copy.
|
||||
@@ -308,7 +314,9 @@ def cutlass_call_inner_p_impl(
|
||||
|
||||
call_name = get_cutlass_call_ffi_name(allow_cuda_graph)
|
||||
|
||||
# Convert layout from CuTeDSL to JAX order as ffi_call expects this.
|
||||
# Convert explicit layout constraints from CuTeDSL to JAX order. ``None`` is
|
||||
# passed through intentionally: jax.ffi.ffi_call treats it as default
|
||||
# row-major layout.
|
||||
input_layouts = [cutlass_to_jax_layout_order(s.layout) for s in input_spec_flat]
|
||||
output_layouts = [cutlass_to_jax_layout_order(s.layout) for s in output_spec_flat]
|
||||
|
||||
|
||||
@@ -15,12 +15,18 @@ import jax.numpy as jnp
|
||||
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cutlass_dsl import dsl_user_op
|
||||
from typing import Optional
|
||||
from typing import Optional, Sequence
|
||||
from cutlass._mlir import ir
|
||||
|
||||
|
||||
def reorder_modes(src: str, target: str) -> tuple[int, ...]:
|
||||
"""Computes the mode given a source and target order."""
|
||||
def reorder_modes(src: Sequence[str], target: Sequence[str]) -> tuple[int, ...]:
|
||||
"""Compute a ``TensorSpec.mode`` from physical input order to kernel order.
|
||||
|
||||
``src`` names the JAX array's physical dimension order. ``target`` names the
|
||||
logical mode order that the CuTe kernel expects. The returned tuple can be
|
||||
passed as ``TensorSpec(mode=...)`` while leaving ``layout`` at its default
|
||||
row-major value when the JAX buffer is physically row-major.
|
||||
"""
|
||||
src = tuple(src)
|
||||
target = tuple(target)
|
||||
src_map = {}
|
||||
@@ -29,52 +35,64 @@ def reorder_modes(src: str, target: str) -> tuple[int, ...]:
|
||||
return tuple([src_map[d] for d in target])
|
||||
|
||||
|
||||
def gemm_a_major(d: str):
|
||||
"""Returns order for A tensor major mode."""
|
||||
def gemm_a_major(d: str) -> str:
|
||||
"""Return the physical JAX dimension order for an A tensor major mode.
|
||||
|
||||
The returned string is not the kernel's canonical logical order. Use
|
||||
:func:`gemm_a_mode` to map this physical order to kernel logical ``mkl``.
|
||||
"""
|
||||
return {"k": "lmk", "m": "lkm"}[d]
|
||||
|
||||
|
||||
def gemm_a_mode(d: str) -> tuple[int, ...]:
|
||||
"""Returns mode for A tensor major mode."""
|
||||
"""Return ``TensorSpec.mode`` for A, mapping physical order to logical ``mkl``."""
|
||||
return reorder_modes(gemm_a_major(d), "mkl")
|
||||
|
||||
|
||||
def gemm_b_major(d: str):
|
||||
"""Returns order for B tensor major mode."""
|
||||
def gemm_b_major(d: str) -> str:
|
||||
"""Return the physical JAX dimension order for a B tensor major mode.
|
||||
|
||||
The returned string is not the kernel's canonical logical order. Use
|
||||
:func:`gemm_b_mode` to map this physical order to kernel logical ``nkl``.
|
||||
"""
|
||||
return {"k": "lnk", "n": "lkn"}[d]
|
||||
|
||||
|
||||
def gemm_b_mode(d: str) -> tuple[int, ...]:
|
||||
"""Returns mode for B tensor major mode."""
|
||||
"""Return ``TensorSpec.mode`` for B, mapping physical order to logical ``nkl``."""
|
||||
return reorder_modes(gemm_b_major(d), "nkl")
|
||||
|
||||
|
||||
def gemm_c_major(d: str):
|
||||
"""Returns order for C tensor major mode."""
|
||||
def gemm_c_major(d: str) -> str:
|
||||
"""Return the physical JAX dimension order for a C/D tensor major mode.
|
||||
|
||||
The returned string is not the kernel's canonical logical order. Use
|
||||
:func:`gemm_c_mode` to map this physical order to kernel logical ``mnl``.
|
||||
"""
|
||||
return {"n": "lmn", "m": "lnm"}[d]
|
||||
|
||||
|
||||
def gemm_c_mode(d: str) -> tuple[int, ...]:
|
||||
"""Returns mode for C tensor major mode."""
|
||||
"""Return ``TensorSpec.mode`` for C/D, mapping physical order to logical ``mnl``."""
|
||||
return reorder_modes(gemm_c_major(d), "mnl")
|
||||
|
||||
|
||||
def gemm_a_shape(l, m, k, major) -> tuple[int, ...]:
|
||||
"""Returns shape for A tensor given major mode."""
|
||||
def gemm_a_shape(l: int, m: int, k: int, major: str) -> tuple[int, ...]:
|
||||
"""Return the physical row-major JAX shape for A with the requested major mode."""
|
||||
assert major in ("k", "m")
|
||||
shape = (l, m, k) if major == "k" else (l, k, m)
|
||||
return shape
|
||||
|
||||
|
||||
def gemm_b_shape(l, n, k, major) -> tuple[int, ...]:
|
||||
"""Returns shape for B tensor given major mode."""
|
||||
def gemm_b_shape(l: int, n: int, k: int, major: str) -> tuple[int, ...]:
|
||||
"""Return the physical row-major JAX shape for B with the requested major mode."""
|
||||
assert major in ("k", "n")
|
||||
shape = (l, n, k) if major == "k" else (l, k, n)
|
||||
return shape
|
||||
|
||||
|
||||
def gemm_c_shape(l, m, n, major) -> tuple[int, ...]:
|
||||
"""Returns shape for C tensor given major mode."""
|
||||
def gemm_c_shape(l: int, m: int, n: int, major: str) -> tuple[int, ...]:
|
||||
"""Return the physical row-major JAX shape for C/D with the requested major mode."""
|
||||
assert major in ("m", "n")
|
||||
shape = (l, m, n) if major == "n" else (l, n, m)
|
||||
return shape
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
from typing import Optional, Sequence
|
||||
from typing import Any, Optional, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ import jax.numpy as jnp
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cute.core import IntValue
|
||||
from cutlass.cute.runtime import from_dlpack as _from_dlpack
|
||||
from cutlass.cute import AddressSpace
|
||||
from cutlass._mlir import ir
|
||||
@@ -58,35 +59,69 @@ DEFAULT_CUTLASS_DEVICE_BUFFER_ALIGNMENT = 256
|
||||
class TensorSpec:
|
||||
"""Specifies the layout and metadata for a JAX array passed to a CuTe kernel.
|
||||
|
||||
TensorSpec controls how a JAX array's dimensions are mapped to a cute.Tensor
|
||||
during jit lowering, including stride ordering, mode permutation, and whether
|
||||
shapes/strides are compiled as static constants.
|
||||
TensorSpec controls how a JAX array's input dimensions are mapped to a
|
||||
``cute.Tensor`` during jit lowering, including compact stride ordering,
|
||||
mode permutation, and whether shapes/strides are compiled as static
|
||||
constants. The JAX bridge models tensors as compact layouts: runtime
|
||||
strides are derived from runtime shapes using ``layout`` order rather than
|
||||
loaded from a strided view descriptor.
|
||||
|
||||
A useful way to choose a spec is to separate physical storage from logical
|
||||
kernel modes:
|
||||
|
||||
1. First choose the public JAX array shape and its compact physical memory
|
||||
order. If the buffer is a standard row-major JAX array, leave
|
||||
``layout=None``. ``cutlass_call`` will constrain the FFI operand/result
|
||||
to row-major physical layout, matching the CuTe tensor strides that are
|
||||
built from the default.
|
||||
2. Then use ``mode`` only when the kernel should see those input dimensions
|
||||
in a different logical order. ``mode`` is applied after the compact
|
||||
layout is built; it is not a request for JAX/XLA to transpose data.
|
||||
|
||||
For example, a row-major JAX buffer shaped ``(expert_count, N, K)`` can be
|
||||
presented to a kernel expecting logical ``(N, K, expert_count)`` with
|
||||
``TensorSpec(mode=(1, 2, 0))``. No explicit ``layout`` is needed because the
|
||||
physical buffer is still ordinary row-major, and the FFI call will be
|
||||
constrained accordingly. Use ``layout`` only when the compact physical
|
||||
stride order itself differs from the default row-major order, such as a
|
||||
column-major compact buffer.
|
||||
|
||||
Attributes:
|
||||
layout: A minor-to-major stride ordering in CuTeDSL convention. ``layout[i]``
|
||||
gives the stride rank of dimension ``i``, where rank 0 means the smallest
|
||||
(innermost) stride. For example, row-major order for a 3-D tensor is
|
||||
``(2, 1, 0)``. If ``None``, row-major is assumed. Use
|
||||
:func:`jax_to_cutlass_layout_order` to convert from JAX's major-to-minor
|
||||
convention.
|
||||
mode: A permutation that maps the stride-ordered dimensions to the mode
|
||||
positions of the resulting ``cute.Layout``. For example, ``mode=(2, 0, 1)``
|
||||
reorders an ``(M, K, L)`` layout into ``(K, L, M)`` mode order inside the
|
||||
kernel. If ``None``, modes match the natural dimension order ``(0, 1, ..., N-1)``.
|
||||
gives the compact physical stride rank of input dimension ``i``,
|
||||
where rank 0 means the smallest (innermost) stride. For example,
|
||||
row-major order for a 3-D tensor is ``(2, 1, 0)``. If ``None``,
|
||||
row-major is assumed. Use :func:`jax_to_cutlass_layout_order` to
|
||||
convert from JAX's major-to-minor convention. ``layout`` does not
|
||||
change which logical mode a dimension represents; combine it with
|
||||
``mode`` when physical order and kernel-logical order differ.
|
||||
mode: A permutation applied after the compact layout is constructed. It
|
||||
selects input dimensions into the mode positions seen by the kernel.
|
||||
For example, ``mode=(2, 0, 1)`` presents an input shaped
|
||||
``(M, K, L)`` to the kernel as logical ``(L, M, K)``. If ``None``,
|
||||
modes match the natural input-dimension order ``(0, 1, ..., N-1)``.
|
||||
``mode`` changes the tensor layout object seen by CuTe code but
|
||||
does not materialize a transpose or change the underlying buffer.
|
||||
static: If ``True``, shapes and strides are compiled as static ``constexpr``
|
||||
values, which may enable additional compiler optimisations. Kernels that
|
||||
do not support static shapes will raise a compile error. Must be ``False``
|
||||
when any dimension is symbolic (e.g. under ``jax.export``).
|
||||
ptr_assumed_align: Assumed byte alignment of the tensor's data pointer.
|
||||
Overrides the default of 256 bytes. Rarely needs to change.
|
||||
divisibility: Optional per-mode divisibility hints. If a single int is passed
|
||||
divisibility will be applied to the leading (stride=1) dimension only.
|
||||
divisibility: Optional divisibility hints for input dimensions, in the
|
||||
same order as the JAX array shape and before any ``mode`` reordering.
|
||||
Positive hints constrain dynamic shape values and are propagated
|
||||
through compact stride construction: a stride inherits the product
|
||||
of the divisibilities for dimensions with lower stride rank.
|
||||
Positive explicit hints take precedence over inferred concrete
|
||||
extents. If a single int is passed, it is applied to the leading
|
||||
compact dimension only, where ``layout[i] == 0``.
|
||||
"""
|
||||
|
||||
# Minor-to-major stride ordering in CuTeDSL convention (layout[i] = stride rank
|
||||
# of dimension i, 0 = innermost). Defaults to row-major if None.
|
||||
layout: tuple[int, ...] | None = field(metadata=dict(static=True), default=None)
|
||||
# Permutation from stride-ordered dimensions to cute.Layout mode positions.
|
||||
# Permutation from input dimensions to cute.Layout mode positions.
|
||||
# Defaults to identity (0, 1, ..., N-1) if None.
|
||||
mode: tuple[int, ...] | None = field(metadata=dict(static=True), default=None)
|
||||
# If True, shapes and strides are embedded as compile-time constants.
|
||||
@@ -96,7 +131,7 @@ class TensorSpec:
|
||||
ptr_assumed_align: int = field(
|
||||
metadata=dict(static=True), default=DEFAULT_CUTLASS_DEVICE_BUFFER_ALIGNMENT
|
||||
)
|
||||
# Per-mode divisibility hints.
|
||||
# Per-input-dimension divisibility hints, before mode reordering.
|
||||
divisibility: tuple[int | None, ...] | int | None = field(
|
||||
metadata=dict(static=True), default=None
|
||||
)
|
||||
@@ -128,9 +163,10 @@ def row_major_layout(shaped):
|
||||
def default_tensor_mode(shaped):
|
||||
"""Returns the identity mode permutation for an N-dimensional tensor.
|
||||
|
||||
The mode permutation maps stride-ordered dimensions to ``cute.Layout`` mode
|
||||
positions. The default identity ``(0, 1, ..., N-1)`` leaves the mode order
|
||||
unchanged relative to the dimension order.
|
||||
The mode permutation maps JAX input dimensions to ``cute.Layout`` mode
|
||||
positions after the compact layout has been constructed. The default
|
||||
identity ``(0, 1, ..., N-1)`` leaves the mode order unchanged relative to
|
||||
the JAX shape order.
|
||||
|
||||
Args:
|
||||
shaped: An object with a ``.shape`` attribute, or a shape tuple/sequence.
|
||||
@@ -151,12 +187,22 @@ def default_tensor_spec(shaped) -> TensorSpec:
|
||||
TensorSpec(layout=(N-1, ..., 1, 0), mode=(0, 1, ..., N-1), divisibility=(D0, D1, ... DN-1))
|
||||
|
||||
This is appropriate for standard row-major (C-contiguous) JAX arrays that
|
||||
do not require dimension reordering inside the kernel.
|
||||
do not require dimension reordering inside the kernel. The resulting JAX
|
||||
CuTe tensor is treated as compact: strides are derived from shapes using the
|
||||
row-major layout order.
|
||||
|
||||
Divisibility hints are inferred only for concrete integer dimensions.
|
||||
Symbolic dimensions always produce ``None`` for their slot; pass an
|
||||
explicit ``TensorSpec`` with ``divisibility`` set if you need alignment
|
||||
hints for symbolic shapes.
|
||||
If the JAX buffer is row-major but the kernel expects a different logical
|
||||
mode order, use an explicit :class:`TensorSpec` with ``mode`` set and leave
|
||||
``layout`` unset. ``cutlass_call`` still constrains the FFI buffer to
|
||||
row-major layout in this case. For example, ``TensorSpec(mode=(1, 2, 0))``
|
||||
maps a physical ``(L, M, K)`` row-major input to a logical ``(M, K, L)``
|
||||
tensor.
|
||||
|
||||
Divisibility hints are inferred only for concrete integer input dimensions.
|
||||
Symbolic dimensions always produce ``None`` for their slot; pass an explicit
|
||||
``TensorSpec`` with ``divisibility`` set if you need alignment hints for
|
||||
symbolic shapes or want a weaker explicit constraint than the concrete
|
||||
extent.
|
||||
|
||||
Args:
|
||||
shaped: An object with a ``.shape`` attribute, or a shape tuple/sequence.
|
||||
@@ -179,11 +225,12 @@ def default_tensor_spec(shaped) -> TensorSpec:
|
||||
def _expand_divisibility(
|
||||
divisibility, order: tuple[int, ...], ndim: int
|
||||
) -> tuple[int | None, ...] | None:
|
||||
"""Expand a divisibility spec to a full per-dimension tuple.
|
||||
"""Expand a divisibility spec to a full per-input-dimension tuple.
|
||||
|
||||
A bare ``int`` is placed at the leading-dimension slot (where
|
||||
``order[i] == 0``, i.e. stride == 1) and ``None`` everywhere else.
|
||||
A tuple is returned unchanged. ``None`` returns ``None``.
|
||||
A tuple is already in JAX input-dimension order and is returned unchanged.
|
||||
``None`` returns ``None``.
|
||||
"""
|
||||
if divisibility is None or isinstance(divisibility, tuple):
|
||||
return divisibility
|
||||
@@ -268,7 +315,20 @@ def from_dlpack(array, assumed_align: int = DEFAULT_CUTLASS_DEVICE_BUFFER_ALIGNM
|
||||
return _from_dlpack(array, assumed_align=assumed_align)
|
||||
|
||||
|
||||
def _validate_permutation(name: str, perm, shape):
|
||||
def _assume_divisible_int(
|
||||
value: Any,
|
||||
divby: int,
|
||||
*,
|
||||
loc: ir.Location | None = None,
|
||||
ip: ir.InsertionPoint | None = None,
|
||||
) -> Any:
|
||||
"""Attach a divisibility assumption to an integer value without narrowing it."""
|
||||
if divby <= 1:
|
||||
return value
|
||||
return cute.assume(IntValue(value, loc=loc, ip=ip), divby=divby, loc=loc, ip=ip)
|
||||
|
||||
|
||||
def _validate_permutation(name: str, perm: Sequence[int], shape: Sequence[Any]) -> None:
|
||||
if len(perm) != len(shape):
|
||||
raise ValueError(f"{name} must be same length as shape", perm, shape)
|
||||
for s in perm:
|
||||
@@ -292,9 +352,11 @@ class JaxArray:
|
||||
can be concrete or symbolic in the case of jax.export.
|
||||
3. mem_space: The memory space of the tensor. Defaults to gmem.
|
||||
4. assumed_align: The alignment of the tensor. Defaults to XLA alignment.
|
||||
5. order: Specifies the order of the shape to determine strides.
|
||||
6. mode: Specifies how to map ordered elements to the modes od a cute.Layout.
|
||||
5. order: Specifies the compact physical stride order of the shape.
|
||||
6. mode: Specifies how to map input dimensions to the logical modes seen by
|
||||
the kernel after the compact layout is constructed.
|
||||
7. static: If True, tensor shapes and strides are compiled statically.
|
||||
8. divisibility: Optional divisibility hints in input-dimension order.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -381,6 +443,21 @@ class JaxArrayValue(JaxArray):
|
||||
ip: Optional[ir.InsertionPoint] = None,
|
||||
):
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
|
||||
# Track the divisibility available for each input dimension. Explicit
|
||||
# hints win; otherwise concrete dimensions contribute their known extent.
|
||||
dim_divisibility = None
|
||||
if self.divisibility is not None:
|
||||
dim_divisibility = []
|
||||
for div_spec, static_s in zip(self.divisibility, self.shape):
|
||||
if div_spec is not None and div_spec > 0:
|
||||
dim_divisibility.append(div_spec)
|
||||
elif isinstance(static_s, int):
|
||||
dim_divisibility.append(static_s)
|
||||
else:
|
||||
dim_divisibility.append(1)
|
||||
dim_divisibility = tuple(dim_divisibility)
|
||||
|
||||
pairs = sorted(zip(shape, order), key=lambda x: x[1])
|
||||
|
||||
# Compute strides for each element in order.
|
||||
@@ -395,28 +472,29 @@ class JaxArrayValue(JaxArray):
|
||||
for i in range(len(shape)):
|
||||
strides_ordered.append(strides[order[i]])
|
||||
|
||||
if dim_divisibility is not None:
|
||||
# A compact stride is the product of all dimensions with a lower
|
||||
# stride order, so it inherits the product of their divisibility.
|
||||
stride_divisibility = []
|
||||
for dim_order in order:
|
||||
divby = 1
|
||||
for other_dim, other_order in enumerate(order):
|
||||
if other_order < dim_order:
|
||||
divby *= dim_divisibility[other_dim]
|
||||
stride_divisibility.append(divby)
|
||||
|
||||
strides_ordered = [
|
||||
_assume_divisible_int(s, divby, loc=loc, ip=ip)
|
||||
for s, divby in zip(strides_ordered, stride_divisibility)
|
||||
]
|
||||
|
||||
# Shapes are expected to be int32 so truncate to that before creating layout
|
||||
shape_i32 = tuple(arith.trunci(i32, s) for s in shape)
|
||||
|
||||
# Apply per-mode divisibility assumptions so the compiler can exploit alignment.
|
||||
if self.divisibility is not None:
|
||||
assumed = []
|
||||
for s32, div_spec, static_s in zip(
|
||||
shape_i32, self.divisibility, self.shape
|
||||
):
|
||||
if isinstance(static_s, int):
|
||||
# Pure static shape is known even though a dynamic shape is
|
||||
# used. We can assume the exact shape here. We keep the shape
|
||||
# as a dynamic value to avoid breaking code that may expect
|
||||
# a dynamic value.
|
||||
assumed.append(cute.assume(s32, divby=static_s))
|
||||
elif div_spec is not None:
|
||||
# Using a dynamic value so apply the div_spec if its provided.
|
||||
assumed.append(cute.assume(s32, divby=div_spec))
|
||||
else:
|
||||
# No divisibility specification for this shape
|
||||
assumed.append(s32)
|
||||
shape_i32 = tuple(assumed)
|
||||
if dim_divisibility is not None:
|
||||
shape_i32 = tuple(
|
||||
_assume_divisible_int(s, divby, loc=loc, ip=ip)
|
||||
for s, divby in zip(shape_i32, dim_divisibility)
|
||||
)
|
||||
|
||||
return cute.make_layout(shape_i32, stride=tuple(strides_ordered))
|
||||
|
||||
|
||||
@@ -84,6 +84,8 @@ from .tmem_allocator import (
|
||||
|
||||
from .layout import LayoutEnum
|
||||
|
||||
from .block import block_copy
|
||||
|
||||
from .mixed_input_helpers import (
|
||||
TransformMode,
|
||||
scale_tma_partition,
|
||||
@@ -176,6 +178,7 @@ __all__ = [
|
||||
"sm90",
|
||||
"sm100",
|
||||
"gemm",
|
||||
"block_copy",
|
||||
"ClcDynamicPersistentTileSchedulerParams",
|
||||
"ClcDynamicPersistentTileScheduler",
|
||||
"print_latex",
|
||||
|
||||
@@ -612,7 +612,7 @@ def get_tmem_load_op(
|
||||
def get_smem_layout_atom_ab(
|
||||
major_mode: OperandMajorMode,
|
||||
element_type: Type[Numeric],
|
||||
smem_shape_mn_k: Tuple[int, int],
|
||||
smem_shape_mn_k: cute.Tile,
|
||||
*,
|
||||
loc: Optional[ir.Location] = None,
|
||||
ip: Optional[ir.InsertionPoint] = None,
|
||||
@@ -625,13 +625,16 @@ def get_smem_layout_atom_ab(
|
||||
:param element_type: The element type for the SMEM tensor.
|
||||
:type element_type: Type[Numeric]
|
||||
:param smem_shape_mn_k: The shape of the SMEM tensor.
|
||||
:type smem_shape_mn_k: Tuple[int, int]
|
||||
:type smem_shape_mn_k: cute.Tile
|
||||
:return: The SMEM layout atom kind
|
||||
:rtype: cutlass.cute.nvgpu.tcgen05.SmemLayoutAtomKind
|
||||
"""
|
||||
is_k_major = major_mode == OperandMajorMode.K
|
||||
major_mode_size = smem_shape_mn_k[1] if is_k_major else smem_shape_mn_k[0]
|
||||
|
||||
major_mode_size = (
|
||||
cute.size(smem_shape_mn_k, mode=[1])
|
||||
if is_k_major
|
||||
else cute.size(smem_shape_mn_k, mode=[0])
|
||||
)
|
||||
assert major_mode_size % 8 == 0
|
||||
sw128_num_contiguous_bits = 1024
|
||||
sw64_num_contiguous_bits = 512
|
||||
@@ -711,6 +714,7 @@ def make_smem_layout(
|
||||
cute.append(smem_tile_shape, num_stages),
|
||||
order=(0, 1, 2) if is_k_major else (1, 0, 2),
|
||||
)
|
||||
|
||||
return cute.coalesce(smem_layout, target_profile=(1, 1, 1), loc=loc, ip=ip)
|
||||
|
||||
|
||||
@@ -1956,12 +1960,35 @@ def thrfrg_SFA(
|
||||
"""Thread-fragment scale factor A tensor for SM120 block-scaled MMA.
|
||||
|
||||
Implements the ThrFrg partitioning for scale factor A according to the
|
||||
corresponding C++ code.
|
||||
corresponding C++ code in cutlass/include/cute/atom/mma_traits_sm120.hpp:
|
||||
SFALayout for SM120 MXF4 16x8x64 uses K=64, SM120 MXF8F6F4 16x8x32 uses
|
||||
K=32; the stride pattern ``((_8,_0,_1), _16)`` is shared.
|
||||
"""
|
||||
assert cute.rank(sfa_tensor) >= 2
|
||||
|
||||
atom_shape_mnk = tiled_mma.shape_mnk
|
||||
atom_sfa_layout = cute.make_layout(shape=((2, 2, 8), 64), stride=((8, 0, 1), 16))
|
||||
# K-dim of the warp-MMA atom: FP4 -> 64, FP8 -> 32 (per mma_traits_sm120.hpp).
|
||||
# For FP8 (atom_K=32) where mma_nsf=1, wrap K in a 2-tuple ``(atom_K, 1)``
|
||||
# so the layout's K mode keeps its 2D structure and the resulting fragment
|
||||
# has the same rank as the FP4 path. For FP4 (atom_K=64) the original 1D
|
||||
# layout already produces a 2D K decomposition through SMEM-layout
|
||||
# composition, so we keep the original shape.
|
||||
atom_K = atom_shape_mnk[2]
|
||||
if atom_K == 32:
|
||||
atom_sfa_layout = cute.make_layout(
|
||||
shape=((2, 2, 8), (atom_K, 1)),
|
||||
stride=((8, 0, 1), (16, 0)),
|
||||
)
|
||||
elif atom_K == 64:
|
||||
atom_sfa_layout = cute.make_layout(
|
||||
shape=((2, 2, 8), atom_K),
|
||||
stride=((8, 0, 1), 16),
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"thrfrg_SFA: unsupported atom_K={atom_K}; SM120 block-scaled atoms "
|
||||
f"use atom_K=32 (mxf8/mxf8f6f4) or atom_K=64 (mxf4/mxf4nvf4)"
|
||||
)
|
||||
permutation_mnk = tiled_mma.permutation_mnk
|
||||
thr_layout_vmnk = tiled_mma.thr_layout_vmnk
|
||||
|
||||
@@ -2000,12 +2027,32 @@ def thrfrg_SFB(
|
||||
"""Thread-fragment scale factor B tensor for SM120 block-scaled MMA.
|
||||
|
||||
Implements the ThrFrg partitioning for scale factor B according to the
|
||||
corresponding C++ code.
|
||||
corresponding C++ code in cutlass/include/cute/atom/mma_traits_sm120.hpp:
|
||||
SFBLayout for SM120 MXF4 16x8x64 uses K=64, SM120 MXF8F6F4 16x8x32 uses
|
||||
K=32; the stride pattern ``((_0,_1), _8)`` is shared.
|
||||
"""
|
||||
assert cute.rank(sfb_tensor) >= 2
|
||||
|
||||
atom_shape_mnk = tiled_mma.shape_mnk
|
||||
atom_sfb_layout = cute.make_layout(shape=((4, 8), 64), stride=((0, 1), 8))
|
||||
# K-dim of the warp-MMA atom: FP4 -> 64, FP8 -> 32 (per mma_traits_sm120.hpp).
|
||||
# See :func:`thrfrg_SFA` for the rationale behind the FP8-only
|
||||
# ``(atom_K, 1)`` wrapping.
|
||||
atom_K = atom_shape_mnk[2]
|
||||
if atom_K == 32:
|
||||
atom_sfb_layout = cute.make_layout(
|
||||
shape=((4, 8), (atom_K, 1)),
|
||||
stride=((0, 1), (8, 0)),
|
||||
)
|
||||
elif atom_K == 64:
|
||||
atom_sfb_layout = cute.make_layout(
|
||||
shape=((4, 8), atom_K),
|
||||
stride=((0, 1), 8),
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"thrfrg_SFB: unsupported atom_K={atom_K}; SM120 block-scaled atoms "
|
||||
f"use atom_K=32 (mxf8/mxf8f6f4) or atom_K=64 (mxf4/mxf4nvf4)"
|
||||
)
|
||||
permutation_mnk = tiled_mma.permutation_mnk
|
||||
thr_layout_vmnk = tiled_mma.thr_layout_vmnk
|
||||
|
||||
|
||||
248
python/CuTeDSL/cutlass/utils/block.py
Normal file
248
python/CuTeDSL/cutlass/utils/block.py
Normal file
@@ -0,0 +1,248 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
|
||||
from cutlass.cutlass_dsl import dsl_user_op, CuTeDSL
|
||||
|
||||
from cutlass.cute.typing import Tensor
|
||||
from cutlass.cute.core import make_layout, filter_zeros
|
||||
from cutlass.cute.atom import TiledCopy
|
||||
from cutlass.cute.algorithm import copy
|
||||
from cutlass.cute.nvgpu import tcgen05
|
||||
from cutlass.cute.nvgpu.cpasync.copy import (
|
||||
TmaCopyOp,
|
||||
CopyBulkTensorTileG2SOp,
|
||||
CopyBulkTensorTileG2SMulticastOp,
|
||||
)
|
||||
from cutlass.cute.nvgpu.cpasync.helpers import tma_partition
|
||||
from cutlass.cute.nvgpu.tcgen05.copy import _S2TCopyBase
|
||||
from typing import Any, Optional
|
||||
from cutlass._mlir import ir
|
||||
|
||||
|
||||
def _check_required_args(
|
||||
required_args: list[str], kwargs: dict, condition: bool = True
|
||||
) -> None:
|
||||
if not condition:
|
||||
return
|
||||
for arg in required_args:
|
||||
if arg not in kwargs:
|
||||
raise ValueError(f"Argument {arg} is required.")
|
||||
|
||||
|
||||
def _tma_copy_impl(
|
||||
tiled_copy: TiledCopy,
|
||||
src: Tensor,
|
||||
dst: Tensor,
|
||||
*,
|
||||
loc: Optional[ir.Location] = None,
|
||||
ip: Optional[ir.InsertionPoint] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Internal implementation for TMA-based block-level copy."""
|
||||
#
|
||||
# Handle tma_multicast argument
|
||||
#
|
||||
if "tma_multicast" in kwargs:
|
||||
if not isinstance(
|
||||
tiled_copy.op,
|
||||
(
|
||||
CopyBulkTensorTileG2SOp,
|
||||
),
|
||||
):
|
||||
raise ValueError(
|
||||
"block_copy with tma_multicast expects a non-multicast G2S TMA copy atom "
|
||||
"(CopyBulkTensorTileG2SOp) for compiler-driven multicast"
|
||||
)
|
||||
# Mark as coming from block API
|
||||
kwargs["tma_multicast"]["from_block_api"] = True
|
||||
|
||||
#
|
||||
# Check if required arguments are provided
|
||||
#
|
||||
is_bar_ptr_required = isinstance(
|
||||
tiled_copy.op,
|
||||
(
|
||||
CopyBulkTensorTileG2SOp,
|
||||
CopyBulkTensorTileG2SMulticastOp,
|
||||
),
|
||||
)
|
||||
_check_required_args(["tma_bar_ptr"], kwargs, is_bar_ptr_required)
|
||||
|
||||
#
|
||||
# TMA bulk tensor copies: partition via tma_partition
|
||||
#
|
||||
is_g2s = isinstance(
|
||||
tiled_copy.op,
|
||||
(
|
||||
CopyBulkTensorTileG2SOp,
|
||||
),
|
||||
)
|
||||
stensor = dst if is_g2s else src
|
||||
gtensor = src if is_g2s else dst
|
||||
cta_coord = 0
|
||||
cta_layout = make_layout(1, loc=loc, ip=ip)
|
||||
s_ptn, g_ptn = tma_partition(
|
||||
tiled_copy, cta_coord, cta_layout, stensor, gtensor, loc=loc, ip=ip
|
||||
)
|
||||
|
||||
s_ptn = filter_zeros(s_ptn)
|
||||
g_ptn = filter_zeros(g_ptn)
|
||||
|
||||
src_arg = g_ptn if is_g2s else s_ptn
|
||||
dst_arg = s_ptn if is_g2s else g_ptn
|
||||
return copy(tiled_copy, src_arg, dst_arg, loc=loc, ip=ip, **kwargs)
|
||||
|
||||
|
||||
def _utccp_copy_impl(
|
||||
tiled_copy: TiledCopy,
|
||||
src: Tensor,
|
||||
dst: Tensor,
|
||||
*,
|
||||
loc: Optional[ir.Location] = None,
|
||||
ip: Optional[ir.InsertionPoint] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Internal implementation for S2T (SMEM to TMEM) copy operations.
|
||||
|
||||
This function abstracts the S2T copy pattern which involves:
|
||||
1. Filtering zeros from src (smem) and dst (tmem) tensors
|
||||
2. Creating a tiled copy using make_s2t_copy
|
||||
3. Partitioning source and destination
|
||||
4. Getting the SMEM descriptor tensor
|
||||
5. Executing the copy
|
||||
|
||||
:param tiled_copy: The tiled copy for S2T operations.
|
||||
:type tiled_copy: TiledCopy
|
||||
:param src: The source tensor in shared memory.
|
||||
:type src: Tensor
|
||||
:param dst: The destination tensor in TMEM.
|
||||
:type dst: Tensor
|
||||
"""
|
||||
# Filter zeros from src (smem) and dst (tmem) tensors
|
||||
src_compact = filter_zeros(src)
|
||||
dst_compact = filter_zeros(dst)
|
||||
|
||||
# S2T has a single thread slice; election handled automatically in lowering
|
||||
thr_copy = tiled_copy.get_slice(0)
|
||||
|
||||
# Partition source and destination
|
||||
src_partitioned = thr_copy.partition_S(src_compact, loc=loc, ip=ip)
|
||||
dst_partitioned = thr_copy.partition_D(dst_compact, loc=loc, ip=ip)
|
||||
|
||||
# Get SMEM descriptor tensor for the source
|
||||
smem_desc_tensor = tcgen05.get_s2t_smem_desc_tensor(
|
||||
tiled_copy, src_partitioned, loc=loc, ip=ip
|
||||
)
|
||||
|
||||
# Execute the copy
|
||||
return copy(tiled_copy, smem_desc_tensor, dst_partitioned, loc=loc, ip=ip, **kwargs)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
@CuTeDSL.jit
|
||||
def block_copy(
|
||||
tiled_copy: TiledCopy,
|
||||
src: Tensor,
|
||||
dst: Tensor,
|
||||
*,
|
||||
loc: Optional[ir.Location] = None,
|
||||
ip: Optional[ir.InsertionPoint] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Performs a block-level copy operation.
|
||||
|
||||
This function adds an abstraction layer over the `cute.copy` usage model by
|
||||
allowing operands with layouts shaped like tiles to be passed directly. This
|
||||
removes the need to manually partition. The API is designed to support multiple
|
||||
copy kinds; currently TMA-based copies and S2T (SMEM to TMEM) copies are supported.
|
||||
|
||||
**TMA copy requirements**:
|
||||
|
||||
When using TMA-based tiled copies, the ``src`` and ``dst`` tensors must have
|
||||
their first mode representing the TMATile, i.e. tensors shaped as ``(TMATile, Rest...)``.
|
||||
For a rank-2 tensor with logical layout (e.g., ``(TILE_M, TILE_N)``), call
|
||||
``group_modes(tensor, 0, 2)`` before passing it to this function.
|
||||
|
||||
**TMA multicast support**:
|
||||
|
||||
For TMA-based copies that enable compiler-driven multicast in a 2D cluster, pass the
|
||||
``tma_multicast`` argument as a dict with the following keys:
|
||||
|
||||
- ``cluster_shape``: a tuple of 2 integers ``(cluster_m, cluster_n)``
|
||||
representing the **2D cluster shape**.
|
||||
- ``multicast_dim``: either ``"M"`` or ``"N"`` indicating which
|
||||
cluster dimension the multicast happens along.
|
||||
- ``use_2cta_mma_inst`` (optional): a ``bool`` indicating whether to
|
||||
use 2CTA MMA instructions when the loaded data is consumed by MMA.
|
||||
Defaults to ``False`` when omitted.
|
||||
|
||||
**S2T (SMEM to TMEM) copy**:
|
||||
|
||||
When using S2T copy operations (e.g., ``tcgen05.Cp4x32x128bOp``), the function
|
||||
automatically handles the filtering, partitioning, and SMEM descriptor creation.
|
||||
Pass a copy atom created with ``cute.make_copy_atom(tcgen05.Cp*Op(...), dtype)``
|
||||
along with source (SMEM) and destination (TMEM) tensors.
|
||||
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# 1) TMA load without compiler-driven multicast
|
||||
# Note: group_modes is called to make the first mode TMATile
|
||||
block_copy(tma_atom_a, group_modes(tCgA_, 0, 2), group_modes(tCsA_, 0, 2),
|
||||
tma_bar_ptr=tma_bar_ptr)
|
||||
|
||||
# 2) TMA load with compiler-driven multicast along M in a (4,2) cluster
|
||||
block_copy(
|
||||
tma_atom_a,
|
||||
group_modes(tCgA_, 0, 2),
|
||||
group_modes(tCsA_, 0, 2),
|
||||
tma_multicast={
|
||||
"cluster_shape": (4, 2),
|
||||
"multicast_dim": "M",
|
||||
"use_2cta_mma_inst": True,
|
||||
},
|
||||
tma_bar_ptr=tma_bar_ptr,
|
||||
)
|
||||
|
||||
# 3) TMA store
|
||||
# Note that `tma_bar_ptr` and CTA params (`cta_coord` and `cta_layout`)
|
||||
# are not needed for TMA store
|
||||
block_copy(tma_atom_c, group_modes(tCsC_, 0, 2), group_modes(tCgC_, 0, 2))
|
||||
|
||||
# 4) S2T copy (SMEM to TMEM)
|
||||
copy_atom_s2t = cute.make_copy_atom(
|
||||
tcgen05.Cp4x32x128bOp(tcgen05.CtaGroup.ONE), sf_dtype
|
||||
)
|
||||
block_copy(copy_atom_s2t, tCsSF, tCtSF)
|
||||
|
||||
:param tiled_copy: The tiled_copy or copy_atom of the current copy operation.
|
||||
:type tiled_copy: TiledCopy
|
||||
:param src: The source tensor.
|
||||
:type src: Tensor
|
||||
:param dst: The destination tensor.
|
||||
:type dst: Tensor
|
||||
:param tma_multicast: Optional dict for TMA multicast configuration with keys
|
||||
``cluster_shape``, ``multicast_dim``, and optionally
|
||||
``use_2cta_mma_inst``.
|
||||
:type tma_multicast: dict, optional
|
||||
"""
|
||||
import cutlass # local import to avoid circular import at module load time
|
||||
|
||||
if cutlass.const_expr(isinstance(tiled_copy.op, TmaCopyOp)):
|
||||
return _tma_copy_impl(tiled_copy, src, dst, loc=loc, ip=ip, **kwargs)
|
||||
elif cutlass.const_expr(isinstance(tiled_copy.op, _S2TCopyBase)):
|
||||
return _utccp_copy_impl(tiled_copy, src, dst, loc=loc, ip=ip, **kwargs)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Copy op {type(tiled_copy.op).__name__} is not supported yet."
|
||||
)
|
||||
@@ -1,3 +1,3 @@
|
||||
# Use `pip install -r requirements-cu13.txt` with the present file to install a
|
||||
# wheel consistent with the present state of the github repository
|
||||
nvidia-cutlass-dsl[cu13]==4.4.2
|
||||
nvidia-cutlass-dsl[cu13]==4.5.1
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
# Use `pip install -r requirements.txt` with the present file to install a
|
||||
# wheel consistent with the present state of the github repository
|
||||
nvidia-cutlass-dsl==4.4.2
|
||||
nvidia-cutlass-dsl==4.5.1
|
||||
|
||||
@@ -133,7 +133,7 @@ def get_option_registry():
|
||||
this._option_registry = OptionRegistry(device_cc())
|
||||
return this._option_registry
|
||||
|
||||
this.__version__ = '4.5.0'
|
||||
this.__version__ = '4.5.1'
|
||||
|
||||
from cutlass_cppgen.backend import create_memory_pool
|
||||
from cutlass_cppgen.emit.pytorch import pytorch
|
||||
|
||||
@@ -51,7 +51,7 @@ setup_pycute.perform_setup()
|
||||
|
||||
setup(
|
||||
name='cutlass_cppgen',
|
||||
version='4.5.0',
|
||||
version='4.5.1',
|
||||
description='CUTLASS Pythonic Interface',
|
||||
package_dir={'': '.'},
|
||||
packages=[
|
||||
|
||||
@@ -36,7 +36,7 @@ from setuptools import setup
|
||||
def perform_setup():
|
||||
setup(
|
||||
name='cutlass_library',
|
||||
version='4.5.0',
|
||||
version='4.5.1',
|
||||
description='CUTLASS library generation scripts',
|
||||
packages=['cutlass_library']
|
||||
)
|
||||
|
||||
@@ -36,7 +36,7 @@ from setuptools import setup
|
||||
def perform_setup():
|
||||
setup(
|
||||
name='pycute',
|
||||
version='4.5.0',
|
||||
version='4.5.1',
|
||||
description='Python implementation of CuTe',
|
||||
packages=['pycute'],
|
||||
)
|
||||
|
||||
@@ -554,6 +554,112 @@ struct RandomUniformFunc {
|
||||
}
|
||||
};
|
||||
|
||||
/// Computes an exponent-uniform random distribution for UE8M0 scale factors.
|
||||
template <>
|
||||
struct RandomUniformFunc<float_ue8m0_t> {
|
||||
|
||||
using Element = float_ue8m0_t;
|
||||
using FloatType = float;
|
||||
|
||||
/// Parameters structure
|
||||
struct Params {
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
uint64_t seed;
|
||||
int exp_min;
|
||||
int exp_range;
|
||||
int int_scale; ///< Retained for Params compatibility; exponent is integral.
|
||||
double pnan;
|
||||
int exclude_zero; ///< Retained for Params compatibility; unused for UE8M0.
|
||||
|
||||
/// Default ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() { }
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
static int closest_log2_exp(FloatType value) {
|
||||
using CUTLASS_CMATH_NAMESPACE :: log2;
|
||||
using CUTLASS_CMATH_NAMESPACE :: nearbyint;
|
||||
|
||||
// UE8M0 scale factors are strictly positive. Keep invalid lower bounds
|
||||
// finite so callers using the generic [0, max] default do not produce NaN.
|
||||
FloatType min_scale = FloatType(Element::bitcast(0x01));
|
||||
FloatType positive_value = value > FloatType(0) ? value : min_scale;
|
||||
return int(nearbyint(log2(positive_value)));
|
||||
}
|
||||
|
||||
/// Construction of uniform RNG functor.
|
||||
Params(
|
||||
uint64_t seed_ = 0,
|
||||
FloatType max_ = FloatType(1),
|
||||
FloatType min_ = FloatType(0),
|
||||
int int_scale_ = -1,
|
||||
double pnan_ = 0,
|
||||
int exclude_zero_ = -1
|
||||
):
|
||||
seed(seed_),
|
||||
exp_min(closest_log2_exp(min_)),
|
||||
exp_range(closest_log2_exp(max_) - closest_log2_exp(min_)),
|
||||
int_scale(int_scale_),
|
||||
pnan(pnan_),
|
||||
exclude_zero(exclude_zero_) {
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Parameters object
|
||||
Params params;
|
||||
|
||||
/// RNG state object
|
||||
curandState_t rng_state;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Device-side initialization of RNG
|
||||
CUTLASS_DEVICE
|
||||
RandomUniformFunc(Params const ¶ms): params(params) {
|
||||
|
||||
uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
|
||||
curand_init(params.seed, gtid, 0, &rng_state);
|
||||
}
|
||||
|
||||
/// Compute random value and update RNG state
|
||||
CUTLASS_DEVICE
|
||||
Element operator()() {
|
||||
|
||||
// Draw random float in [0.0, 1.0] to determine if element should be NaN.
|
||||
if constexpr (std::numeric_limits<Element>::has_quiet_NaN) {
|
||||
if (params.pnan > 0 && (curand_uniform(&rng_state) < (params.pnan))) {
|
||||
return Element(NAN);
|
||||
}
|
||||
}
|
||||
|
||||
using CUTLASS_CMATH_NAMESPACE :: pow;
|
||||
|
||||
FloatType rnd = random_uniform_float<FloatType>(&rng_state);
|
||||
int exponent_count = params.exp_range + 1;
|
||||
int exponent_offset = int(rnd * FloatType(exponent_count));
|
||||
exponent_offset = exponent_offset < exponent_count ? exponent_offset : exponent_count - 1;
|
||||
FloatType exp = FloatType(params.exp_min + exponent_offset);
|
||||
FloatType sf = FloatType(pow(FloatType(2), exp));
|
||||
|
||||
return Element(sf);
|
||||
}
|
||||
};
|
||||
|
||||
/// Computes a random Gaussian distribution
|
||||
template <typename Real>
|
||||
struct RandomUniformFunc<complex<Real>> {
|
||||
@@ -763,6 +869,16 @@ struct TensorFillRandomUniformFunc {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Element>
|
||||
struct UniformDistributionValueType {
|
||||
using Type = typename RealType<Element>::Type;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct UniformDistributionValueType<float_ue8m0_t> {
|
||||
using Type = float;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -774,8 +890,10 @@ template <
|
||||
void TensorFillRandomUniform(
|
||||
TensorView<Element, Layout> view, ///< destination tensor
|
||||
uint64_t seed, ///< seed for RNG
|
||||
typename RealType<Element>::Type max = Element(1), ///< upper bound of distribution
|
||||
typename RealType<Element>::Type min = Element(0), ///< lower bound for distribution
|
||||
typename detail::UniformDistributionValueType<Element>::Type max =
|
||||
typename detail::UniformDistributionValueType<Element>::Type(1), ///< upper bound of distribution
|
||||
typename detail::UniformDistributionValueType<Element>::Type min =
|
||||
typename detail::UniformDistributionValueType<Element>::Type(0), ///< lower bound for distribution
|
||||
int bits = -1, ///< If non-negative, specifies number of fractional bits that
|
||||
/// are not truncated to zero. Permits reducing precision of
|
||||
/// data.
|
||||
@@ -805,8 +923,8 @@ void BlockFillRandomUniform(
|
||||
Element *ptr,
|
||||
size_t capacity,
|
||||
uint64_t seed, ///< seed for RNG
|
||||
typename RealType<Element>::Type max, ///< upper bound of distribution
|
||||
typename RealType<Element>::Type min, ///< lower bound for distribution
|
||||
typename detail::UniformDistributionValueType<Element>::Type max, ///< upper bound of distribution
|
||||
typename detail::UniformDistributionValueType<Element>::Type min, ///< lower bound for distribution
|
||||
int bits = -1, ///< If non-negative, specifies number of fractional bits that
|
||||
/// are not truncated to zero. Permits reducing precision of
|
||||
/// data.
|
||||
@@ -1768,6 +1886,7 @@ void TensorFillRandom(
|
||||
) {
|
||||
|
||||
using Real = typename RealType<Element>::Type;
|
||||
using UniformReal = typename detail::UniformDistributionValueType<Element>::Type;
|
||||
|
||||
if (dist.kind == Distribution::Gaussian) {
|
||||
TensorFillRandomGaussian<Element, Layout>(
|
||||
@@ -1782,8 +1901,8 @@ void TensorFillRandom(
|
||||
TensorFillRandomUniform<Element, Layout>(
|
||||
view,
|
||||
seed,
|
||||
static_cast<Real>(dist.uniform.max),
|
||||
static_cast<Real>(dist.uniform.min),
|
||||
static_cast<UniformReal>(dist.uniform.max),
|
||||
static_cast<UniformReal>(dist.uniform.min),
|
||||
dist.int_scale,
|
||||
dist.uniform.pnan,
|
||||
exclude_zero,
|
||||
@@ -1830,6 +1949,7 @@ void BlockFillRandom(
|
||||
cudaStream_t stream = nullptr) {
|
||||
|
||||
using Real = typename RealType<Element>::Type;
|
||||
using UniformReal = typename detail::UniformDistributionValueType<Element>::Type;
|
||||
|
||||
if (dist.kind == Distribution::Gaussian) {
|
||||
BlockFillRandomGaussian<Element>(
|
||||
@@ -1846,8 +1966,8 @@ void BlockFillRandom(
|
||||
ptr,
|
||||
capacity,
|
||||
seed,
|
||||
static_cast<Real>(dist.uniform.max),
|
||||
static_cast<Real>(dist.uniform.min),
|
||||
static_cast<UniformReal>(dist.uniform.max),
|
||||
static_cast<UniformReal>(dist.uniform.min),
|
||||
dist.int_scale,
|
||||
dist.uniform.pnan,
|
||||
stream);
|
||||
|
||||
@@ -658,6 +658,74 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
/// Computes an exponent-uniform random distribution for UE8M0 scale factors.
|
||||
template <>
|
||||
struct RandomUniformFunc<float_ue8m0_t> {
|
||||
|
||||
using Element = float_ue8m0_t;
|
||||
|
||||
uint64_t seed;
|
||||
int exp_min;
|
||||
int exp_range;
|
||||
int int_scale; ///< Retained for Params compatibility; exponent is integral.
|
||||
|
||||
double pnan;
|
||||
private:
|
||||
using engine_type = std::mt19937;
|
||||
public:
|
||||
engine_type bernoulli_rnd;
|
||||
std::bernoulli_distribution bernoulli_dist;
|
||||
|
||||
bool exclude_zero; ///< Retained for Params compatibility; unused for UE8M0.
|
||||
|
||||
static int closest_log2_exp(double value) {
|
||||
// UE8M0 scale factors are strictly positive. Keep invalid lower bounds
|
||||
// finite so callers using the generic [0, max] default do not produce NaN.
|
||||
double min_scale = double(Element::bitcast(0x01));
|
||||
double positive_value = value > 0.0 ? value : min_scale;
|
||||
return int(std::nearbyint(std::log2(positive_value)));
|
||||
}
|
||||
|
||||
RandomUniformFunc(
|
||||
uint64_t seed_ = 0,
|
||||
double max = 1,
|
||||
double min_ = 0,
|
||||
int int_scale_ = -1,
|
||||
double pnan_ = 0,
|
||||
bool exclude_zero_ = false
|
||||
):
|
||||
seed(seed_),
|
||||
exp_min(closest_log2_exp(min_)),
|
||||
exp_range(closest_log2_exp(max) - closest_log2_exp(min_)),
|
||||
int_scale(int_scale_),
|
||||
pnan(pnan_),
|
||||
bernoulli_rnd{static_cast<engine_type::result_type>(seed_)},
|
||||
bernoulli_dist(pnan_),
|
||||
exclude_zero(exclude_zero_)
|
||||
{
|
||||
std::srand((unsigned)seed);
|
||||
}
|
||||
|
||||
/// Compute random value and update RNG state
|
||||
Element operator()() {
|
||||
|
||||
// Sample from NaN distribution.
|
||||
if constexpr (std::numeric_limits<Element>::has_quiet_NaN) {
|
||||
if (pnan > 0 && bernoulli_dist(bernoulli_rnd)) {
|
||||
return Element(NAN);
|
||||
}
|
||||
}
|
||||
|
||||
double rnd = double(std::rand()) / double(RAND_MAX);
|
||||
int exponent_count = exp_range + 1;
|
||||
int exponent_offset = int(rnd * double(exponent_count));
|
||||
exponent_offset = exponent_offset < exponent_count ? exponent_offset : exponent_count - 1;
|
||||
double sf = std::pow(2.0, double(exp_min + exponent_offset));
|
||||
|
||||
return Element(sf);
|
||||
}
|
||||
};
|
||||
|
||||
/// Partial specialization for initializing a complex value.
|
||||
template <typename Element>
|
||||
struct RandomUniformFunc<complex<Element> > {
|
||||
|
||||
Reference in New Issue
Block a user