CK-UA: halve kBlockN for bf16/fp16 m16 decode + generalise PVAttrNumAccess

The decode_d128_m16 tier was VGPR-saturated and LDS-bound on bf16/fp16
(probe_decode_d128 showed VGPR=256 + AGPR overflow, ~2x fp8's LDS at
the same kBlockN), capping it at 1 CTA/CU. Halving kBlockN for the
non-fp8 path on the m16 tier sheds enough LDS and VGPR pressure to
fit 3-4 CTAs/CU (LDS-bound). The halved kBlockN forces a smaller-K
MFMA shape on the m16 PV gemm (16x16x32 -> 16x16x16), so WarpGemm::K
is swapped down to the halved kBlockN automatically and
PVAttrNumAccess picks Single vs Double access from the new shape.
The PVAttrNumAccess derivation is now generic: it is driven by
(kABKPerLane, SubMinDim) rather than by dtype alone, so the new
shape compiles without per-variant special-casing.
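
The access-selection rule can be sketched in isolation as follows. This
is a minimal standalone illustration, not the ck_tile code: NumAccess,
lanes_in_k, sub_min_dim and pick_access are hypothetical names, and the
lane counts assume a 64-wide wavefront (4 lanes in K for 16x16xK, 2 for
32x32xK, as the policy comment in this commit states).

```cpp
#include <cstddef>

// Hypothetical stand-in for ck_tile's WGAttrNumAccessEnum.
enum class NumAccess { Single, Double };

// Lanes that split the K axis of one MFMA: 4 for 16x16xK, 2 for 32x32xK
// (values taken from the policy comment; wave size 64 is assumed).
constexpr int lanes_in_k(int warp_m) { return warp_m == 16 ? 4 : 2; }

// SubMinDim = 64 bits / element width = 8 bytes / sizeof(element).
constexpr int sub_min_dim(std::size_t elem_bytes)
{
    return 8 / static_cast<int>(elem_bytes);
}

// Single when kABKPerLane already equals SubMinDim, Double otherwise.
constexpr NumAccess pick_access(int warp_m, int warp_k, std::size_t elem_bytes)
{
    return (warp_k / lanes_in_k(warp_m) == sub_min_dim(elem_bytes))
               ? NumAccess::Single
               : NumAccess::Double;
}

// The five cases listed in the policy comment.
static_assert(pick_access(16, 32, 2) == NumAccess::Double, "bf16/fp16 16x16x32");
static_assert(pick_access(16, 16, 2) == NumAccess::Single, "bf16/fp16 16x16x16");
static_assert(pick_access(32, 16, 2) == NumAccess::Double, "bf16/fp16 32x32x16");
static_assert(pick_access(16, 32, 1) == NumAccess::Single, "fp8/bf8 16x16x32");
static_assert(pick_access(32, 16, 1) == NumAccess::Single, "fp8/bf8 32x32x16");
```

As in the commit, the rule only distinguishes "equal" from "not equal";
any shape where kABKPerLane is neither SubMinDim nor 2x SubMinDim would
still select Double.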

Only non-fp8 variants with cfg::BlockSize/2 >= WarpGemm::N are affected
(i.e. decode_d128_m16); m32/m128/prefill keep their un-halved tiles
because they use 32x32 N-warps.
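
The halving condition and the WarpGemm::K swap can be sketched as plain
constexpr functions. All function names are hypothetical, and the tile
sizes in the checks are illustrative assumptions: a configured kBlockN
of 32 is inferred from the <16,16,32> -> <16,16,16> swap described in
this commit, not read from the config headers.

```cpp
// Halve only for non-fp8 data, and only if the halved kBlockN still
// covers WarpGemm::N (the QK gemm constraint).
constexpr bool halves_block_n(bool is_fp8, int cfg_block_n, int warp_n)
{
    return !is_fp8 && (cfg_block_n / 2 >= warp_n);
}

// kBlockN actually used by the kernel traits.
constexpr int effective_block_n(bool is_fp8, int cfg_block_n, int warp_n)
{
    return halves_block_n(is_fp8, cfg_block_n, warp_n) ? cfg_block_n / 2
                                                       : cfg_block_n;
}

// WarpGemm::K is swapped down to the halved kBlockN when the halved
// tile is smaller than the original K (the PV gemm constraint).
constexpr int effective_warp_k(bool is_fp8, int cfg_block_n, int warp_n, int warp_k)
{
    return (halves_block_n(is_fp8, cfg_block_n, warp_n) &&
            effective_block_n(is_fp8, cfg_block_n, warp_n) < warp_k)
               ? effective_block_n(is_fp8, cfg_block_n, warp_n)
               : warp_k;
}

// bf16 with a 16x16x32 warp gemm and kBlockN = 32: halved, K swapped to 16.
static_assert(effective_block_n(false, 32, 16) == 16, "bf16 m16 halves");
static_assert(effective_warp_k(false, 32, 16, 32) == 16, "bf16 m16 swaps K");
// fp8 with the same shape is left alone.
static_assert(effective_block_n(true, 32, 16) == 32, "fp8 keeps kBlockN");
static_assert(effective_warp_k(true, 32, 16, 32) == 32, "fp8 keeps K");
// A 32x32x16 warp gemm with kBlockN = 32 cannot halve below WarpGemm::N.
static_assert(effective_block_n(false, 32, 32) == 32, "32x32 keeps kBlockN");
static_assert(effective_warp_k(false, 32, 32, 16) == 16, "32x32 keeps K");
```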

Co-authored-by: Cursor <cursoragent@cursor.com>
juuso-oskari
2026-05-26 08:20:55 +00:00
parent 89b54563b6
commit 310efc556f
2 changed files with 56 additions and 15 deletions

@@ -3,6 +3,7 @@
#pragma once
#include <type_traits>
#include <utility>
#include "ck_tile/core/numeric/bfloat16.hpp"
@@ -259,9 +260,42 @@ struct unified_attention_kernel_traits
 static constexpr index_t HEAD_SIZE = cfg::HeadSize;
 static constexpr index_t kBlockM = cfg::BlockM;
-static constexpr index_t BLOCK_SIZE = cfg::BlockSize;
 static constexpr bool kUseDecodeGrid = cfg::kUseDecodeGrid;
+// bf16/fp16 carry a 2-byte element vs fp8's 1 byte, so at the same kBlockN
+// they double both LDS usage and per-tile VGPR pressure. The decode probe
+// (ua-test-scripts/probe_decode_d128.sh) showed bf16 saturating VGPR=256
+// with AGPR overflow (44-106 AGPRs) and ~2x the LDS of fp8 on the m16
+// tier; the LDS pressure alone caps decode_d128_m16 at 1 CTA/CU. We
+// halve kBlockN for bf16/fp16 to shed LDS and VGPR pressure, trading a
+// small per-iter overhead for a big occupancy boost.
+//
+// The halved kBlockN must satisfy both gemm constraints:
+// QK gemm: kBlockN is the N axis -> kBlockN >= WarpGemm::N
+// PV gemm: kBlockN is the K axis -> kBlockN >= WarpGemm::K
+// When `cfg::BlockSize/2 < WarpGemm::K` we additionally swap WarpGemm::K
+// to the halved kBlockN so the smaller-K MFMA is used. For our variants
+// this only hits decode_d128_m16 (WG=<16,16,32> -> <16,16,16>); the d=64
+// tiers already fit the smaller tile under the same 16x16x32 / 32x32x16
+// MFMA. The 32x32 N-warps (m32/m128/prefill) cannot drop their kBlockN
+// below 32 (WG::N=32) and stay at the un-halved size.
+static constexpr index_t WGM_ = cfg::WarpGemmShape::at(number<0>{});
+static constexpr index_t WGN_ = cfg::WarpGemmShape::at(number<1>{});
+static constexpr index_t WGK_ = cfg::WarpGemmShape::at(number<2>{});
+static constexpr bool kBf16HalveBlockN =
+(DataType != unified_attention_args::data_type_enum::fp8) &&
+(cfg::BlockSize / 2 >= WGN_);
+static constexpr index_t BLOCK_SIZE =
+kBf16HalveBlockN ? cfg::BlockSize / 2 : cfg::BlockSize;
+// Swap WarpGemm::K down to BLOCK_SIZE when the halved kBlockN dropped
+// below the original WarpGemm::K. PVAttrNumAccess in GetPVBlockGemm
+// recomputes from the new WarpGemm shape (lanes_in_K * SubMinDim rule)
+// so the smaller-K MFMA tiles cleanly.
+using unified_attention_warp_gemm_shape = std::conditional_t<
+(kBf16HalveBlockN && BLOCK_SIZE < WGK_),
+sequence<WGM_, WGN_, BLOCK_SIZE>,
+typename cfg::WarpGemmShape>;
 // The 2nd entry of the BlockTile is the static `kBlockQ` exposed via
 // `UnifiedAttentionShape::kBlockQ`. Now that the kernel always reads
 // kBlockQ from `args.num_queries_per_kv` at runtime, this static value
@@ -269,9 +303,8 @@ struct unified_attention_kernel_traits
 // happens in practice). Anchor it at kBlockM so the static "looks like
 // num_qpkv == 1" and any (kBlockM, num_qpkv) such that kBlockM % num_qpkv
 // == 0 works without touching this trait.
-using unified_attention_block_tile = sequence<kBlockM, kBlockM, BLOCK_SIZE, HEAD_SIZE>;
-using unified_attention_warp_gemm_shape = typename cfg::WarpGemmShape;
-using unified_attention_block_warps = typename cfg::BlockWarps;
+using unified_attention_block_tile = sequence<kBlockM, kBlockM, BLOCK_SIZE, HEAD_SIZE>;
+using unified_attention_block_warps = typename cfg::BlockWarps;
 using unified_attention_shape = TileUnifiedAttentionShape<unified_attention_block_tile,
 unified_attention_block_warps,

@@ -333,18 +333,26 @@ struct UnifiedAttentionPipelineDefaultPolicy
 typename Problem::UnifiedAttentionShape::Gemm1WarpTile>>;
 // `load_tile_transpose` is only valid when the tile distribution's inner
 // packing matches the transpose engine's SubtileMinorDimension =
-// 64 bits / sizeof(VDataType_in_bits). For BF16 / FP16 SubMinDim=4 and the
-// PV warp gemm produces kABKPerLane / AttrNumAccess = 8 / 2 = 4 elements
-// per lane on the K direction; Double access is needed there. For FP8
-// SubMinDim=8 and kABKPerLane=8, so we must pass `Single` (otherwise
-// 8/2=4 mismatches the FP8 SubMinDim=8 and the load_tile_transpose
-// validation static_asserts fire; see DefaultTranspose::Quad in
-// load_tile_transpose.hpp). The select is purely a compile-time alias.
+// 64 bits / sizeof(VDataType_in_bits). The PV warp gemm produces
+// kABKPerLane = WG::K / lanes_in_K elements per lane on the K direction
+// (lanes_in_K = 4 for 16x16x* MFMA, 2 for 32x32x*), so we must pick
+// AttrNumAccess such that kABKPerLane / AttrNumAccess == SubMinDim:
+// bf16/fp16 16x16x32 -> kABKPerLane=8, SubMinDim=4 -> Double.
+// bf16/fp16 16x16x16 -> kABKPerLane=4, SubMinDim=4 -> Single.
+// bf16/fp16 32x32x16 -> kABKPerLane=8, SubMinDim=4 -> Double.
+// fp8/bf8 16x16x32 -> kABKPerLane=8, SubMinDim=8 -> Single.
+// fp8/bf8 32x32x16 -> kABKPerLane=8, SubMinDim=8 -> Single.
+// The select is a compile-time alias.
+static constexpr index_t kPVWarpGemmM =
+Problem::UnifiedAttentionShape::Gemm1WarpTile::at(number<0>{});
+static constexpr index_t kPVWarpGemmK =
+Problem::UnifiedAttentionShape::Gemm1WarpTile::at(number<2>{});
+static constexpr index_t kPVLanesInK = (kPVWarpGemmM == 16) ? 4 : 2;
+static constexpr index_t kPVABKPerLane = kPVWarpGemmK / kPVLanesInK;
+static constexpr index_t kPVSubMinDim = 8 / sizeof(typename Problem::VDataType);
 static constexpr WGAttrNumAccessEnum PVAttrNumAccess =
-std::is_same_v<remove_cvref_t<typename Problem::VDataType>, ck_tile::fp8_t> ||
-std::is_same_v<remove_cvref_t<typename Problem::VDataType>, ck_tile::bf8_t>
-? WGAttrNumAccessEnum::Single
-: WGAttrNumAccessEnum::Double;
+(kPVABKPerLane == kPVSubMinDim) ? WGAttrNumAccessEnum::Single
+: WGAttrNumAccessEnum::Double;
 using WarpGemm =
 WarpGemmDispatcher<typename Problem::PDataType,
 typename Problem::VDataType,