mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 03:07:02 +00:00
CK-UA: VGPR-pressure toggles for kv128 probing (all default OFF)
Adds compile-time levers, all guarded and bit-identical to production when
unset, used to characterise why prefill_d128 fp8 fits KV tile 64 but not 128
under the 256-VGPR/wave ceiling (see ua-test-scripts/kv128_vgpr_findings.md):
- UA_PREFILL_D128_BLOCKSIZE (default 64): KV-tile override for probing kv128.
- UA_FA4_INPLACE_DELTA (default 0): drop sp_delta, scale-shift/exp2 in place on
sp_compute (fmha_alu_D_upd reads only m/l/o_acc/rowsum_p, never raw scores, so
bit-identical). VGPR-neutral on its own (compiler already reclaims sp_delta).
- UA_FA4_SHARED_SPCOMPUTE (default 0): keep ONE shared fp32 sp_compute + a
2-slot fp8 P ping-pong instead of a 2-slot union{sp_compute,p}. The deferred
PV only needs one live fp32 score; this cuts kv128 spills 173 -> 126. (Forces
in-place delta; slightly regresses kv64 so it is a kv128-only lever.)
- UA_FA4_UNION_KV (default 0): union k_tile/v_tile (ASM-style). VGPR-neutral;
kept as a documented dead end (compiler already overlaps their live ranges).
P thread-buffer size exposed as a type-derived constexpr (kPThreadBufSize) so
the static_assert/static_for sites work when sp(idx) is the runtime proxy.
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -149,7 +149,14 @@ struct variant_config<KernelVariant::prefill_d128>
|
||||
// spills at kv64. Unlocking kv128 needs a VGPR-pressure cut (smaller kBlockM
|
||||
// for this tile, or sub-tiling kBlockN so the live score tile stays 32x64),
|
||||
// not an LDS change. Stay at 64 until that lands.
|
||||
static constexpr index_t BlockSize = 64;
|
||||
//
|
||||
// UA_PREFILL_D128_BLOCKSIZE: compile-time override of the KV tile so the
|
||||
// VGPR-pressure experiments (kv128 + sub-tiling) can be probed without
|
||||
// editing this line each build. Defaults to the production 64.
|
||||
#ifndef UA_PREFILL_D128_BLOCKSIZE
|
||||
#define UA_PREFILL_D128_BLOCKSIZE 64
|
||||
#endif
|
||||
static constexpr index_t BlockSize = UA_PREFILL_D128_BLOCKSIZE;
|
||||
using BlockWarps = sequence<8, 1, 1>;
|
||||
using WarpGemmShape = sequence<32, 32, 16>;
|
||||
template <typename Problem, index_t PageSize = 0, bool IsPaged = true>
|
||||
|
||||
@@ -23,6 +23,47 @@
|
||||
#define UA_FA4_VLOAD_IN_MATRIX 1
|
||||
#endif
|
||||
|
||||
// UA_FA4_INPLACE_DELTA (VGPR-pressure cut, blocks kv128):
|
||||
// 0 (default): fmha_alu0 writes the scaled-shifted score into a dedicated
|
||||
// `sp_delta` scratch (statically_indexed_array<sp_compute, 2>), and
|
||||
// fmha_alu1 reads it back to compute exp2. That scratch is sized like the
|
||||
// full score tile and declared per-slot (2x), so at kv128 it is up to
|
||||
// ~64 VGPR of live state that exists only to bridge alu0 -> D_upd -> alu1.
|
||||
// 1: do the scale-shift in place on sp_compute (alu0) and the exp2 in place
|
||||
// (alu1). fmha_alu_D_upd between them reads only m/l/o_acc/rowsum_p, never
|
||||
// the raw scores, so this is mathematically and bit identical. Removes the
|
||||
// sp_delta array entirely. (TODO at the sp_delta decl asked for exactly
|
||||
// this.)
|
||||
#ifndef UA_FA4_INPLACE_DELTA
|
||||
#define UA_FA4_INPLACE_DELTA 0
|
||||
#endif
|
||||
|
||||
// UA_FA4_SHARED_SPCOMPUTE (VGPR-pressure cut, blocks kv128):
|
||||
// 0 (default): `sp` is a 2-slot array of a union{sp_compute(fp32), p(fp8)}.
|
||||
// Each slot is sized at the fp32 score tile, so two slots reserve ~2x the
|
||||
// score tile (~128 VGPR at kv128) even though the deferred-PV pipeline
|
||||
// only keeps ONE fp32 score live at a time.
|
||||
// 1: keep only ONE shared fp32 sp_compute, plus a 2-slot fp8 `p` ping-pong
|
||||
// (the deferred PV reads p[pi] from a prior softmax while QK overwrites the
|
||||
// single sp_compute; the softmax that drains sp_compute always runs before
|
||||
// the next QK within a warp group, so one fp32 buffer is sufficient).
|
||||
// Saves one full fp32 score tile. Requires in-place delta (the sp_delta
|
||||
// decltype would otherwise see a reference member), so it force-enables it.
|
||||
#ifndef UA_FA4_SHARED_SPCOMPUTE
|
||||
#define UA_FA4_SHARED_SPCOMPUTE 0
|
||||
#endif
|
||||
|
||||
// UA_FA4_UNION_KV (VGPR-pressure cut, blocks kv128): union the register-
|
||||
// resident k_tile and v_tile (ASM unions K/V). Default 0 (separate, overlaps
|
||||
// the K ds_read with the PV MFMA). See the kv_tile_type definition.
|
||||
#ifndef UA_FA4_UNION_KV
|
||||
#define UA_FA4_UNION_KV 0
|
||||
#endif
|
||||
#if UA_FA4_SHARED_SPCOMPUTE
|
||||
#undef UA_FA4_INPLACE_DELTA
|
||||
#define UA_FA4_INPLACE_DELTA 1
|
||||
#endif
|
||||
|
||||
// FMHA_MASK PLACEMENT: pick exactly one of:
|
||||
// - both 0 → baseline (mask in K-side memory phase, W0-3 phase 1
|
||||
// / W4-7 phase 2, right after `cl_load(memK)`).
|
||||
@@ -621,6 +662,18 @@ struct UnifiedAttentionPipeline
|
||||
// Separate tiles let the K ds_read execute on the LSU *concurrently*
|
||||
// with the PV MFMA (it stays at the same program point, so the
|
||||
// cooperative-load residency slack is preserved -- see fa4_matrix).
|
||||
// UA_FA4_UNION_KV: union k_tile/v_tile (ASM-style) to halve their VGPR
|
||||
// footprint at the cost of serialising the K ds_read behind the PV MFMA
|
||||
// (which reads v_tile). Default off (separate tiles overlap the K read
|
||||
// with PV — measured ~3-4% faster); a VGPR-pressure lever for kv128.
|
||||
#if UA_FA4_UNION_KV
|
||||
union kv_tile_type
|
||||
{
|
||||
CK_TILE_DEVICE kv_tile_type() {}
|
||||
decltype(load_tile(k_lds_window_load(number<0>{}))) k_tile;
|
||||
decltype(load_tile_transpose(v_lds_window_load(number<0>{}))) v_tile;
|
||||
} kv_tile;
|
||||
#else
|
||||
struct kv_tile_type
|
||||
{
|
||||
CK_TILE_DEVICE kv_tile_type() {}
|
||||
@@ -629,7 +682,22 @@ struct UnifiedAttentionPipeline
|
||||
|
||||
decltype(load_tile_transpose(v_lds_window_load(number<0>{}))) v_tile;
|
||||
} kv_tile;
|
||||
#endif
|
||||
|
||||
#if UA_FA4_SHARED_SPCOMPUTE
|
||||
// ONE fp32 score tile shared by both slots + a 2-slot fp8 P ping-pong.
|
||||
using SpComputeT = decltype(gemm_0.MakeCBlockTile());
|
||||
using PTileT = decltype(make_static_distributed_tensor<PDataType>(
|
||||
Policy::template MakePRegTileDistribution<Problem>()));
|
||||
SpComputeT sp_compute_shared;
|
||||
statically_indexed_array<PTileT, 2> p_slots;
|
||||
struct SpRef
|
||||
{
|
||||
SpComputeT& sp_compute;
|
||||
PTileT& p;
|
||||
};
|
||||
auto sp = [&](auto idx) -> SpRef { return SpRef{sp_compute_shared, p_slots(idx)}; };
|
||||
#else
|
||||
union sp_compute_type
|
||||
{
|
||||
CK_TILE_DEVICE sp_compute_type() {}
|
||||
@@ -639,6 +707,13 @@ struct UnifiedAttentionPipeline
|
||||
Policy::template MakePRegTileDistribution<Problem>())) p;
|
||||
};
|
||||
statically_indexed_array<sp_compute_type, 2> sp;
|
||||
#endif
|
||||
// P thread-buffer length as a type-derived constexpr; usable in the
|
||||
// static_assert / static_for sites even when `sp(idx)` is a runtime
|
||||
// proxy (SHARED_SPCOMPUTE) rather than an array element.
|
||||
static constexpr index_t kPThreadBufSize =
|
||||
decltype(make_static_distributed_tensor<PDataType>(
|
||||
Policy::template MakePRegTileDistribution<Problem>()))::get_thread_buffer_size();
|
||||
|
||||
decltype(gemm_1.MakeCBlockTile()) o_acc;
|
||||
constexpr index_t fmha_alu_D_reg_cnt = 6; // threshold to decide how many fmha_alu_D_upd()
|
||||
@@ -1588,7 +1663,9 @@ struct UnifiedAttentionPipeline
|
||||
decltype(m) m_old;
|
||||
SMPLComputeDataType o_acc_scale; // rescale o_acc in fmha_alu1() & fmha_alu_D_upd()
|
||||
/// TODO: remove the sp_delta and use sp_compute directly
|
||||
#if !UA_FA4_INPLACE_DELTA
|
||||
statically_indexed_array<decltype(sp(number<0>{}).sp_compute), 2> sp_delta;
|
||||
#endif
|
||||
|
||||
auto fmha_alu0 = [&](auto sp_reg_idx) {
|
||||
m_old = m; // m{j-1}
|
||||
@@ -1651,8 +1728,13 @@ struct UnifiedAttentionPipeline
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
#if UA_FA4_INPLACE_DELTA
|
||||
sp(sp_reg_idx).sp_compute(i_j_idx) = detail::fma_impl_vsv(
|
||||
sp(sp_reg_idx).sp_compute(i_j_idx), scale_s, -scale_s * m_shift(i_j_idx));
|
||||
#else
|
||||
sp_delta(sp_reg_idx)(i_j_idx) = detail::fma_impl_vsv(
|
||||
sp(sp_reg_idx).sp_compute(i_j_idx), scale_s, -scale_s * m_shift(i_j_idx));
|
||||
#endif
|
||||
});
|
||||
});
|
||||
/// TODO: move some fmha_alu1() code here if necessary
|
||||
@@ -1664,8 +1746,13 @@ struct UnifiedAttentionPipeline
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
#if UA_FA4_INPLACE_DELTA
|
||||
sp(sp_reg_idx).sp_compute(i_j_idx) =
|
||||
ck_tile::exp2(sp(sp_reg_idx).sp_compute(i_j_idx));
|
||||
#else
|
||||
sp(sp_reg_idx).sp_compute(i_j_idx) =
|
||||
ck_tile::exp2(sp_delta(sp_reg_idx)(i_j_idx));
|
||||
#endif
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1735,7 +1822,7 @@ struct UnifiedAttentionPipeline
|
||||
/// result 'p' is only consumed later. To anchor them here, we rewrite
|
||||
/// the cast_tile() call as inline assembly, forcing the conversions to be
|
||||
/// emitted at this point.
|
||||
static_assert(sp(sp_reg_idx).p.thread_buf_.size() % 2 == 0);
|
||||
static_assert(kPThreadBufSize % 2 == 0);
|
||||
if constexpr(std::is_same_v<PDataType, fp8_t>)
|
||||
{
|
||||
// FP8 P packing for the PV gemm.
|
||||
@@ -1758,7 +1845,7 @@ struct UnifiedAttentionPipeline
|
||||
// bytes back into `p.thread_buf_`. We still anchor the work
|
||||
// inline (no `cast_tile(...)` indirection) so the conversions
|
||||
// stay at the end of `fmha_alu1` like the FP16/BF16 paths.
|
||||
static_assert(sp(sp_reg_idx).p.thread_buf_.size() % 4 == 0,
|
||||
static_assert(kPThreadBufSize % 4 == 0,
|
||||
"fp8 P conversion expects packs of 4 fp32 lanes per "
|
||||
"thread; widen the warp gemm M distribution if this "
|
||||
"trips.");
|
||||
@@ -1833,7 +1920,7 @@ struct UnifiedAttentionPipeline
|
||||
// sub=0 | slot[4..7] | N=8..11 | K=4..7 BAD
|
||||
// sub=1 | slot[0..3] | N=4..7 | K=8..11 BAD
|
||||
// sub=1 | slot[4..7] | N=12..15 | K=12..15 OK
|
||||
static_assert(sp(sp_reg_idx).p.thread_buf_.size() % 8 == 0,
|
||||
static_assert(kPThreadBufSize % 8 == 0,
|
||||
"FP8 32x32x16 + Single cross-lane permute "
|
||||
"expects PV per-thread buffer in chunks of 8 "
|
||||
"fp8 (one warp-gemm K iteration).");
|
||||
@@ -1851,7 +1938,7 @@ struct UnifiedAttentionPipeline
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wuninitialized"
|
||||
int dummy_old;
|
||||
static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 8>{}([&](auto k_base) {
|
||||
static_for<0, kPThreadBufSize, 8>{}([&](auto k_base) {
|
||||
auto& p = sp(sp_reg_idx).p;
|
||||
auto& sc = sp(sp_reg_idx).sp_compute;
|
||||
|
||||
@@ -1936,7 +2023,7 @@ struct UnifiedAttentionPipeline
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wuninitialized"
|
||||
int dummy_old;
|
||||
static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 4>{}([&](auto idx) {
|
||||
static_for<0, kPThreadBufSize, 4>{}([&](auto idx) {
|
||||
const float a =
|
||||
p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 0]);
|
||||
const float b =
|
||||
@@ -1992,7 +2079,7 @@ struct UnifiedAttentionPipeline
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wuninitialized"
|
||||
int dummy_old;
|
||||
static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 4>{}([&](auto idx) {
|
||||
static_for<0, kPThreadBufSize, 4>{}([&](auto idx) {
|
||||
const float a = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 0]);
|
||||
const float b = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 1]);
|
||||
const float c = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 2]);
|
||||
@@ -2039,7 +2126,7 @@ struct UnifiedAttentionPipeline
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 2>{}([&](auto idx) {
|
||||
static_for<0, kPThreadBufSize, 2>{}([&](auto idx) {
|
||||
float x = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx]);
|
||||
float y = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 1]);
|
||||
if constexpr(std::is_same_v<PDataType, fp16_t>)
|
||||
|
||||
Reference in New Issue
Block a user