diff --git a/example/ck_tile/42_unified_attention/unified_attention_impl.hpp b/example/ck_tile/42_unified_attention/unified_attention_impl.hpp index 97642d8a5f..6272cf7ccc 100644 --- a/example/ck_tile/42_unified_attention/unified_attention_impl.hpp +++ b/example/ck_tile/42_unified_attention/unified_attention_impl.hpp @@ -149,7 +149,14 @@ struct variant_config // spills at kv64. Unlocking kv128 needs a VGPR-pressure cut (smaller kBlockM // for this tile, or sub-tiling kBlockN so the live score tile stays 32x64), // not an LDS change. Stay at 64 until that lands. - static constexpr index_t BlockSize = 64; + // + // UA_PREFILL_D128_BLOCKSIZE: compile-time override of the KV tile so the + // VGPR-pressure experiments (kv128 + sub-tiling) can be probed without + // editing this line each build. Defaults to the production 64. +#ifndef UA_PREFILL_D128_BLOCKSIZE +#define UA_PREFILL_D128_BLOCKSIZE 64 +#endif + static constexpr index_t BlockSize = UA_PREFILL_D128_BLOCKSIZE; using BlockWarps = sequence<8, 1, 1>; using WarpGemmShape = sequence<32, 32, 16>; template diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index 1b484e6009..cd2629609f 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -23,6 +23,47 @@ #define UA_FA4_VLOAD_IN_MATRIX 1 #endif +// UA_FA4_INPLACE_DELTA (VGPR-pressure cut for kv128): +// 0 (default): fmha_alu0 writes the scaled-shifted score into a dedicated +// `sp_delta` scratch (a 2-slot statically_indexed_array), and +// fmha_alu1 reads it back to compute exp2. That scratch is sized like the +// full score tile and declared per-slot (2x), so at kv128 it is up to +// ~64 VGPR of live state that exists only to bridge alu0 -> D_upd -> alu1. 
+// 1: do the scale-shift in place on sp_compute (alu0) and the exp2 in place +// (alu1). fmha_alu_D_upd between them reads only m/l/o_acc/rowsum_p, never +// the raw scores, so this is mathematically and bit-identical. Removes the +// sp_delta array entirely. (TODO at the sp_delta decl asked for exactly +// this.) +#ifndef UA_FA4_INPLACE_DELTA +#define UA_FA4_INPLACE_DELTA 0 +#endif + +// UA_FA4_SHARED_SPCOMPUTE (VGPR-pressure cut for kv128): +// 0 (default): `sp` is a 2-slot array of a union{sp_compute(fp32), p(fp8)}. +// Each slot is sized at the fp32 score tile, so two slots reserve ~2x the +// score tile (~128 VGPR at kv128) even though the deferred-PV pipeline +// only keeps ONE fp32 score live at a time. +// 1: keep only ONE shared fp32 sp_compute, plus a 2-slot fp8 `p` ping-pong +// (the deferred PV reads p[pi] from a prior softmax while QK overwrites the +// single sp_compute; the softmax that drains sp_compute always runs before +// the next QK within a warp group, so one fp32 buffer is sufficient). +// Saves one full fp32 score tile. Requires in-place delta (the sp_delta +// decltype would otherwise see a reference member), so it force-enables it. +#ifndef UA_FA4_SHARED_SPCOMPUTE +#define UA_FA4_SHARED_SPCOMPUTE 0 +#endif + +// UA_FA4_UNION_KV (VGPR-pressure cut for kv128): union the register- +// resident k_tile and v_tile (ASM unions K/V). Default 0 (separate, overlaps +// the K ds_read with the PV MFMA). See the kv_tile_type definition. +#ifndef UA_FA4_UNION_KV +#define UA_FA4_UNION_KV 0 +#endif +#if UA_FA4_SHARED_SPCOMPUTE +#undef UA_FA4_INPLACE_DELTA +#define UA_FA4_INPLACE_DELTA 1 +#endif + // FMHA_MASK PLACEMENT: pick exactly one of: // - both 0 → baseline (mask in K-side memory phase, W0-3 phase 1 // / W4-7 phase 2, right after `cl_load(memK)`). 
@@ -621,6 +662,18 @@ struct UnifiedAttentionPipeline // Separate tiles let the K ds_read execute on the LSU *concurrently* // with the PV MFMA (it stays at the same program point, so the // cooperative-load residency slack is preserved -- see fa4_matrix). + // UA_FA4_UNION_KV: union k_tile/v_tile (ASM-style) to halve their VGPR + // footprint at the cost of serialising the K ds_read behind the PV MFMA + // (which reads v_tile). Default off (separate tiles overlap the K read + // with PV — measured ~3-4% faster); a VGPR-pressure lever for kv128. +#if UA_FA4_UNION_KV + union kv_tile_type + { + CK_TILE_DEVICE kv_tile_type() {} + decltype(load_tile(k_lds_window_load(number<0>{}))) k_tile; + decltype(load_tile_transpose(v_lds_window_load(number<0>{}))) v_tile; + } kv_tile; +#else struct kv_tile_type { CK_TILE_DEVICE kv_tile_type() {} @@ -629,7 +682,22 @@ struct UnifiedAttentionPipeline decltype(load_tile_transpose(v_lds_window_load(number<0>{}))) v_tile; } kv_tile; +#endif +#if UA_FA4_SHARED_SPCOMPUTE + // ONE fp32 score tile shared by both slots + a 2-slot fp8 P ping-pong. + using SpComputeT = decltype(gemm_0.MakeCBlockTile()); + using PTileT = decltype(make_static_distributed_tensor( + Policy::template MakePRegTileDistribution())); + SpComputeT sp_compute_shared; + statically_indexed_array p_slots; + struct SpRef + { + SpComputeT& sp_compute; + PTileT& p; + }; + auto sp = [&](auto idx) -> SpRef { return SpRef{sp_compute_shared, p_slots(idx)}; }; +#else union sp_compute_type { CK_TILE_DEVICE sp_compute_type() {} @@ -639,6 +707,13 @@ struct UnifiedAttentionPipeline Policy::template MakePRegTileDistribution())) p; }; statically_indexed_array sp; +#endif + // P thread-buffer length as a type-derived constexpr; usable in the + // static_assert / static_for sites even when `sp(idx)` is a runtime + // proxy (SHARED_SPCOMPUTE) rather than an array element. 
+ static constexpr index_t kPThreadBufSize = + decltype(make_static_distributed_tensor( + Policy::template MakePRegTileDistribution()))::get_thread_buffer_size(); decltype(gemm_1.MakeCBlockTile()) o_acc; constexpr index_t fmha_alu_D_reg_cnt = 6; // threshold to decide how many fmha_alu_D_upd() @@ -1588,7 +1663,9 @@ struct UnifiedAttentionPipeline decltype(m) m_old; SMPLComputeDataType o_acc_scale; // rescale o_acc in fmha_alu1() & fmha_alu_D_upd() /// TODO: remove the sp_delta and use sp_compute directly +#if !UA_FA4_INPLACE_DELTA statically_indexed_array{}).sp_compute), 2> sp_delta; +#endif auto fmha_alu0 = [&](auto sp_reg_idx) { m_old = m; // m{j-1} @@ -1651,8 +1728,13 @@ struct UnifiedAttentionPipeline sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); +#if UA_FA4_INPLACE_DELTA + sp(sp_reg_idx).sp_compute(i_j_idx) = detail::fma_impl_vsv( + sp(sp_reg_idx).sp_compute(i_j_idx), scale_s, -scale_s * m_shift(i_j_idx)); +#else sp_delta(sp_reg_idx)(i_j_idx) = detail::fma_impl_vsv( sp(sp_reg_idx).sp_compute(i_j_idx), scale_s, -scale_s * m_shift(i_j_idx)); +#endif }); }); /// TODO: move some fmha_alu1() code here if necessary @@ -1664,8 +1746,13 @@ struct UnifiedAttentionPipeline sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); +#if UA_FA4_INPLACE_DELTA + sp(sp_reg_idx).sp_compute(i_j_idx) = + ck_tile::exp2(sp(sp_reg_idx).sp_compute(i_j_idx)); +#else sp(sp_reg_idx).sp_compute(i_j_idx) = ck_tile::exp2(sp_delta(sp_reg_idx)(i_j_idx)); +#endif }); }); @@ -1735,7 +1822,7 @@ struct UnifiedAttentionPipeline /// result 'p' is only consumed later. To anchor them here, we rewrite /// the cast_tile() call as inline assembly, forcing the conversions to be /// emitted at this point. 
- static_assert(sp(sp_reg_idx).p.thread_buf_.size() % 2 == 0); + static_assert(kPThreadBufSize % 2 == 0); if constexpr(std::is_same_v) { // FP8 P packing for the PV gemm. @@ -1758,7 +1845,7 @@ struct UnifiedAttentionPipeline // bytes back into `p.thread_buf_`. We still anchor the work // inline (no `cast_tile(...)` indirection) so the conversions // stay at the end of `fmha_alu1` like the FP16/BF16 paths. - static_assert(sp(sp_reg_idx).p.thread_buf_.size() % 4 == 0, + static_assert(kPThreadBufSize % 4 == 0, "fp8 P conversion expects packs of 4 fp32 lanes per " "thread; widen the warp gemm M distribution if this " "trips."); @@ -1833,7 +1920,7 @@ struct UnifiedAttentionPipeline // sub=0 | slot[4..7] | N=8..11 | K=4..7 BAD // sub=1 | slot[0..3] | N=4..7 | K=8..11 BAD // sub=1 | slot[4..7] | N=12..15 | K=12..15 OK - static_assert(sp(sp_reg_idx).p.thread_buf_.size() % 8 == 0, + static_assert(kPThreadBufSize % 8 == 0, "FP8 32x32x16 + Single cross-lane permute " "expects PV per-thread buffer in chunks of 8 " "fp8 (one warp-gemm K iteration)."); @@ -1851,7 +1938,7 @@ struct UnifiedAttentionPipeline #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wuninitialized" int dummy_old; - static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 8>{}([&](auto k_base) { + static_for<0, kPThreadBufSize, 8>{}([&](auto k_base) { auto& p = sp(sp_reg_idx).p; auto& sc = sp(sp_reg_idx).sp_compute; @@ -1936,7 +2023,7 @@ struct UnifiedAttentionPipeline #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wuninitialized" int dummy_old; - static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 4>{}([&](auto idx) { + static_for<0, kPThreadBufSize, 4>{}([&](auto idx) { const float a = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 0]); const float b = @@ -1992,7 +2079,7 @@ struct UnifiedAttentionPipeline #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wuninitialized" int dummy_old; - static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 
4>{}([&](auto idx) { + static_for<0, kPThreadBufSize, 4>{}([&](auto idx) { const float a = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 0]); const float b = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 1]); const float c = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 2]); @@ -2039,7 +2126,7 @@ struct UnifiedAttentionPipeline } else { - static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 2>{}([&](auto idx) { + static_for<0, kPThreadBufSize, 2>{}([&](auto idx) { float x = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx]); float y = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 1]); if constexpr(std::is_same_v)