CK-UA: VGPR-pressure toggles for kv128 probing (all default OFF)

Adds compile-time levers, all guarded and bit-identical to production when
unset, used to characterise why prefill_d128 fp8 fits KV tile 64 but not 128
under the 256-VGPR/wave ceiling (see ua-test-scripts/kv128_vgpr_findings.md):

- UA_PREFILL_D128_BLOCKSIZE (default 64): KV-tile override for probing kv128.
- UA_FA4_INPLACE_DELTA (default 0): drop sp_delta and do the scale-shift/exp2 in
  place on sp_compute (fmha_alu_D_upd reads only m/l/o_acc/rowsum_p, never the
  raw scores, so the result is bit-identical). VGPR-neutral on its own (the
  compiler already reclaims sp_delta).
- UA_FA4_SHARED_SPCOMPUTE (default 0): keep ONE shared fp32 sp_compute + a
  2-slot fp8 P ping-pong instead of a 2-slot union{sp_compute,p}. The deferred
  PV only needs one live fp32 score; this cuts kv128 spills 173 -> 126. (Forces
  in-place delta; slightly regresses kv64 so it is a kv128-only lever.)
- UA_FA4_UNION_KV (default 0): union k_tile/v_tile (ASM-style). VGPR-neutral;
  kept as a documented dead end (compiler already overlaps their live ranges).

P thread-buffer size exposed as a type-derived constexpr (kPThreadBufSize) so
the static_assert/static_for sites work when sp(idx) is the runtime proxy.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
juuso-oskari
2026-06-11 15:29:04 +00:00
parent 9aa380e6c2
commit 8abbd21a01
2 changed files with 102 additions and 8 deletions

View File

@@ -149,7 +149,14 @@ struct variant_config<KernelVariant::prefill_d128>
// spills at kv64. Unlocking kv128 needs a VGPR-pressure cut (smaller kBlockM
// for this tile, or sub-tiling kBlockN so the live score tile stays 32x64),
// not an LDS change. Stay at 64 until that lands.
static constexpr index_t BlockSize = 64;
//
// UA_PREFILL_D128_BLOCKSIZE: compile-time override of the KV tile so the
// VGPR-pressure experiments (kv128 + sub-tiling) can be probed without
// editing this line each build. Defaults to the production 64.
// Pass e.g. -DUA_PREFILL_D128_BLOCKSIZE=128 on the compile line to probe kv128.
// NOTE(review): although it is defined here, next to its only visible use,
// the macro is not scoped to this struct; like any preprocessor macro it
// stays defined for the rest of the translation unit.
// NOTE(review): the override value is not validated here; presumably only
// 64 and 128 are meaningful -- confirm before using other values.
#ifndef UA_PREFILL_D128_BLOCKSIZE
#define UA_PREFILL_D128_BLOCKSIZE 64
#endif
// KV tile size for this variant; equals 64 unless overridden above.
static constexpr index_t BlockSize = UA_PREFILL_D128_BLOCKSIZE;
using BlockWarps = sequence<8, 1, 1>;
using WarpGemmShape = sequence<32, 32, 16>;
template <typename Problem, index_t PageSize = 0, bool IsPaged = true>

View File

@@ -23,6 +23,47 @@
#define UA_FA4_VLOAD_IN_MATRIX 1
#endif
// UA_FA4_INPLACE_DELTA (VGPR-pressure lever; VGPR pressure is what blocks kv128):
// 0 (default): fmha_alu0 writes the scaled-shifted score into a dedicated
// `sp_delta` scratch (statically_indexed_array<sp_compute, 2>), and
// fmha_alu1 reads it back to compute exp2. That scratch is sized like the
// full score tile and declared per-slot (2x), so at kv128 it is up to
// ~64 VGPR of live state that exists only to bridge alu0 -> D_upd -> alu1.
// 1: do the scale-shift in place on sp_compute (alu0) and the exp2 in place
// (alu1). fmha_alu_D_upd between them reads only m/l/o_acc/rowsum_p, never
// the raw scores, so this is mathematically and bit identical. Removes the
// sp_delta array entirely. (TODO at the sp_delta decl asked for exactly
// this.)
// Measured outcome: VGPR-neutral on its own, because the compiler already
// reclaims sp_delta. The ~64 VGPR figure above describes the declared state,
// not an observed saving.
#ifndef UA_FA4_INPLACE_DELTA
#define UA_FA4_INPLACE_DELTA 0
#endif
// UA_FA4_SHARED_SPCOMPUTE (VGPR-pressure lever; VGPR pressure is what blocks kv128):
// 0 (default): `sp` is a 2-slot array of a union{sp_compute(fp32), p(fp8)}.
// Each slot is sized at the fp32 score tile, so two slots reserve ~2x the
// score tile (~128 VGPR at kv128) even though the deferred-PV pipeline
// only keeps ONE fp32 score live at a time.
// 1: keep only ONE shared fp32 sp_compute, plus a 2-slot fp8 `p` ping-pong
// (the deferred PV reads p[pi] from a prior softmax while QK overwrites the
// single sp_compute; the softmax that drains sp_compute always runs before
// the next QK within a warp group, so one fp32 buffer is sufficient).
// Saves one full fp32 score tile. Requires in-place delta (the sp_delta
// decltype would otherwise see a reference member), so it force-enables it.
// Measured outcome: cuts kv128 spills 173 -> 126 but slightly regresses kv64,
// so treat it as a kv128-only lever.
#ifndef UA_FA4_SHARED_SPCOMPUTE
#define UA_FA4_SHARED_SPCOMPUTE 0
#endif
// UA_FA4_UNION_KV (VGPR-pressure lever; VGPR pressure is what blocks kv128):
// union the register-
// resident k_tile and v_tile (ASM unions K/V). Default 0 (separate, overlaps
// the K ds_read with the PV MFMA). See the kv_tile_type definition.
// Measured outcome: VGPR-neutral (the compiler already overlaps the live
// ranges of k_tile and v_tile); kept only as a documented dead end.
#ifndef UA_FA4_UNION_KV
#define UA_FA4_UNION_KV 0
#endif
// SHARED_SPCOMPUTE implies INPLACE_DELTA. This unconditionally replaces
// whatever value UA_FA4_INPLACE_DELTA had, including an explicit
// -DUA_FA4_INPLACE_DELTA=0 from the build; no diagnostic is emitted.
#if UA_FA4_SHARED_SPCOMPUTE
#undef UA_FA4_INPLACE_DELTA
#define UA_FA4_INPLACE_DELTA 1
#endif
// FMHA_MASK PLACEMENT: pick exactly one of:
// - both 0 → baseline (mask in K-side memory phase, W0-3 phase 1
// / W4-7 phase 2, right after `cl_load(memK)`).
@@ -621,6 +662,18 @@ struct UnifiedAttentionPipeline
// Separate tiles let the K ds_read execute on the LSU *concurrently*
// with the PV MFMA (it stays at the same program point, so the
// cooperative-load residency slack is preserved -- see fa4_matrix).
// UA_FA4_UNION_KV: union k_tile/v_tile (ASM-style) to halve their VGPR
// footprint at the cost of serialising the K ds_read behind the PV MFMA
// (which reads v_tile). Default off (separate tiles overlap the K read
// with PV — measured ~3-4% faster); a VGPR-pressure lever for kv128.
// Measured outcome of enabling it: VGPR-neutral, since the compiler already
// overlaps the live ranges of the two tiles; retained as a documented dead end.
#if UA_FA4_UNION_KV
// k_tile and v_tile share the same storage here, so at any program point only
// the most recently written of the two holds a valid value.
union kv_tile_type
{
// Empty user-provided constructor: neither member is initialised on
// construction. NOTE(review): presumably required because the tile types are
// not trivially default-constructible -- confirm.
CK_TILE_DEVICE kv_tile_type() {}
// Type of the tile returned by load_tile on the K LDS load window.
decltype(load_tile(k_lds_window_load(number<0>{}))) k_tile;
// Type of the tile returned by load_tile_transpose on the V LDS load window.
decltype(load_tile_transpose(v_lds_window_load(number<0>{}))) v_tile;
} kv_tile;
#else
struct kv_tile_type
{
CK_TILE_DEVICE kv_tile_type() {}
@@ -629,7 +682,22 @@ struct UnifiedAttentionPipeline
decltype(load_tile_transpose(v_lds_window_load(number<0>{}))) v_tile;
} kv_tile;
#endif
#if UA_FA4_SHARED_SPCOMPUTE
// ONE fp32 score tile shared by both slots + a 2-slot fp8 P ping-pong.
// (The P tile's element type is PDataType; "fp8" refers to the configuration
// this toggle is aimed at.)
// Tile type produced by the first gemm (the score accumulator).
using SpComputeT = decltype(gemm_0.MakeCBlockTile());
// Tile type of the converted probabilities P handed to the PV gemm.
using PTileT = decltype(make_static_distributed_tensor<PDataType>(
Policy::template MakePRegTileDistribution<Problem>()));
// The single score tile; every slot index refers to this same object.
SpComputeT sp_compute_shared;
// Two independent P tiles, selected by slot index.
statically_indexed_array<PTileT, 2> p_slots;
// Proxy exposing the member names `.sp_compute` and `.p` that call sites use
// on `sp(idx)`, so they compile unchanged in this mode. Both members are
// references, so writes through the proxy hit the underlying storage.
struct SpRef
{
SpComputeT& sp_compute;
PTileT& p;
};
// `sp(idx)` call syntax: `.p` selects p_slots(idx), while `.sp_compute`
// always aliases sp_compute_shared regardless of idx. Captures by reference;
// the lambda must not outlive the enclosing scope.
auto sp = [&](auto idx) -> SpRef { return SpRef{sp_compute_shared, p_slots(idx)}; };
#else
union sp_compute_type
{
CK_TILE_DEVICE sp_compute_type() {}
@@ -639,6 +707,13 @@ struct UnifiedAttentionPipeline
Policy::template MakePRegTileDistribution<Problem>())) p;
};
statically_indexed_array<sp_compute_type, 2> sp;
#endif
// P thread-buffer length as a type-derived constexpr; usable in the
// static_assert / static_for sites even when `sp(idx)` is a runtime
// proxy (SHARED_SPCOMPUTE) rather than an array element.
// Computed from the P tile type alone (a static member call on the decltype),
// so no tile object is needed to evaluate it. The P-conversion code uses it as
// its static_for bound and in its divisibility static_asserts.
static constexpr index_t kPThreadBufSize =
decltype(make_static_distributed_tensor<PDataType>(
Policy::template MakePRegTileDistribution<Problem>()))::get_thread_buffer_size();
decltype(gemm_1.MakeCBlockTile()) o_acc;
constexpr index_t fmha_alu_D_reg_cnt = 6; // threshold to decide how many fmha_alu_D_upd()
@@ -1588,7 +1663,9 @@ struct UnifiedAttentionPipeline
decltype(m) m_old;
SMPLComputeDataType o_acc_scale; // rescale o_acc in fmha_alu1() & fmha_alu_D_upd()
/// TODO: remove the sp_delta and use sp_compute directly
/// (UA_FA4_INPLACE_DELTA=1 already does this; the scratch below is only
/// declared when that toggle is 0.)
#if !UA_FA4_INPLACE_DELTA
// Two-slot scratch with the same tile type as sp_compute: fmha_alu0 stores the
// scaled-shifted scores here and fmha_alu1 reads them back for exp2.
statically_indexed_array<decltype(sp(number<0>{}).sp_compute), 2> sp_delta;
#endif
auto fmha_alu0 = [&](auto sp_reg_idx) {
m_old = m; // m{j-1}
@@ -1651,8 +1728,13 @@ struct UnifiedAttentionPipeline
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if UA_FA4_INPLACE_DELTA
// In-place variant: overwrite the score with its scaled-shifted value. After
// this store sp_compute no longer holds the raw score for this element.
sp(sp_reg_idx).sp_compute(i_j_idx) = detail::fma_impl_vsv(
sp(sp_reg_idx).sp_compute(i_j_idx), scale_s, -scale_s * m_shift(i_j_idx));
#else
// Default variant: same fma, but the result goes to the sp_delta scratch and
// sp_compute keeps the raw score.
sp_delta(sp_reg_idx)(i_j_idx) = detail::fma_impl_vsv(
sp(sp_reg_idx).sp_compute(i_j_idx), scale_s, -scale_s * m_shift(i_j_idx));
#endif
});
});
/// TODO: move some fmha_alu1() code here if necessary
@@ -1664,8 +1746,13 @@ struct UnifiedAttentionPipeline
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if UA_FA4_INPLACE_DELTA
// In-place variant: sp_compute already holds the scaled-shifted value written
// by the scale-shift step, so exponentiate it where it sits.
sp(sp_reg_idx).sp_compute(i_j_idx) =
ck_tile::exp2(sp(sp_reg_idx).sp_compute(i_j_idx));
#else
// Default variant: read the scaled-shifted value from the sp_delta scratch and
// store its exp2 into sp_compute.
sp(sp_reg_idx).sp_compute(i_j_idx) =
ck_tile::exp2(sp_delta(sp_reg_idx)(i_j_idx));
#endif
});
});
@@ -1735,7 +1822,7 @@ struct UnifiedAttentionPipeline
/// result 'p' is only consumed later. To anchor them here, we rewrite
/// the cast_tile() call as inline assembly, forcing the conversions to be
/// emitted at this point.
static_assert(sp(sp_reg_idx).p.thread_buf_.size() % 2 == 0);
static_assert(kPThreadBufSize % 2 == 0);
if constexpr(std::is_same_v<PDataType, fp8_t>)
{
// FP8 P packing for the PV gemm.
@@ -1758,7 +1845,7 @@ struct UnifiedAttentionPipeline
// bytes back into `p.thread_buf_`. We still anchor the work
// inline (no `cast_tile(...)` indirection) so the conversions
// stay at the end of `fmha_alu1` like the FP16/BF16 paths.
static_assert(sp(sp_reg_idx).p.thread_buf_.size() % 4 == 0,
static_assert(kPThreadBufSize % 4 == 0,
"fp8 P conversion expects packs of 4 fp32 lanes per "
"thread; widen the warp gemm M distribution if this "
"trips.");
@@ -1833,7 +1920,7 @@ struct UnifiedAttentionPipeline
// sub=0 | slot[4..7] | N=8..11 | K=4..7 BAD
// sub=1 | slot[0..3] | N=4..7 | K=8..11 BAD
// sub=1 | slot[4..7] | N=12..15 | K=12..15 OK
static_assert(sp(sp_reg_idx).p.thread_buf_.size() % 8 == 0,
static_assert(kPThreadBufSize % 8 == 0,
"FP8 32x32x16 + Single cross-lane permute "
"expects PV per-thread buffer in chunks of 8 "
"fp8 (one warp-gemm K iteration).");
@@ -1851,7 +1938,7 @@ struct UnifiedAttentionPipeline
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wuninitialized"
int dummy_old;
static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 8>{}([&](auto k_base) {
static_for<0, kPThreadBufSize, 8>{}([&](auto k_base) {
auto& p = sp(sp_reg_idx).p;
auto& sc = sp(sp_reg_idx).sp_compute;
@@ -1936,7 +2023,7 @@ struct UnifiedAttentionPipeline
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wuninitialized"
int dummy_old;
static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 4>{}([&](auto idx) {
static_for<0, kPThreadBufSize, 4>{}([&](auto idx) {
const float a =
p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 0]);
const float b =
@@ -1992,7 +2079,7 @@ struct UnifiedAttentionPipeline
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wuninitialized"
int dummy_old;
static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 4>{}([&](auto idx) {
static_for<0, kPThreadBufSize, 4>{}([&](auto idx) {
const float a = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 0]);
const float b = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 1]);
const float c = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 2]);
@@ -2039,7 +2126,7 @@ struct UnifiedAttentionPipeline
}
else
{
static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 2>{}([&](auto idx) {
static_for<0, kPThreadBufSize, 2>{}([&](auto idx) {
float x = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx]);
float y = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 1]);
if constexpr(std::is_same_v<PDataType, fp16_t>)