From 310efc556fc830b6094b99549939bab8af653a85 Mon Sep 17 00:00:00 2001 From: juuso-oskari Date: Tue, 26 May 2026 08:20:55 +0000 Subject: [PATCH] CK-UA: halve kBlockN for bf16/fp16 m16 decode + generalise PVAttrNumAccess MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The decode_d128_m16 tier was VGPR-saturated and LDS-bound on bf16/fp16 (probe_decode_d128 showed VGPR=256 + AGPR overflow, ~2x fp8's LDS at the same kBlockN), capping it at 1 CTA/CU. Halving kBlockN for the non-fp8 path on the m16 tier sheds enough LDS and VGPR pressure to fit 3-4 CTAs/CU (LDS-bound). The halved kBlockN forces a smaller-K MFMA shape on the m16 PV gemm (16x16x32 -> 16x16x16); we also auto-adjust WarpGemm::K so PVAttrNumAccess picks Single vs Double access correctly. The PVAttrNumAccess derivation is now generic — driven by (kABKPerLane, SubMinDim) rather than just (dtype) — so the new shape compiles without per-variant special-casing. Variants only affected where cfg::BlockSize/2 >= WarpGemm::N (i.e. decode_d128_m16); m32/m128/prefill keep their un-halved tiles since they use 32x32 N-warps. 
Co-authored-by: Cursor --- .../unified_attention_impl.hpp | 41 +++++++++++++++++-- ...fied_attention_pipeline_default_policy.hpp | 30 +++++++++----- 2 files changed, 56 insertions(+), 15 deletions(-) diff --git a/example/ck_tile/42_unified_attention/unified_attention_impl.hpp b/example/ck_tile/42_unified_attention/unified_attention_impl.hpp index e8061fa528..9db4ef397d 100644 --- a/example/ck_tile/42_unified_attention/unified_attention_impl.hpp +++ b/example/ck_tile/42_unified_attention/unified_attention_impl.hpp @@ -3,6 +3,7 @@ #pragma once +#include #include #include "ck_tile/core/numeric/bfloat16.hpp" @@ -259,9 +260,42 @@ struct unified_attention_kernel_traits static constexpr index_t HEAD_SIZE = cfg::HeadSize; static constexpr index_t kBlockM = cfg::BlockM; - static constexpr index_t BLOCK_SIZE = cfg::BlockSize; static constexpr bool kUseDecodeGrid = cfg::kUseDecodeGrid; + // bf16/fp16 carry a 2-byte element vs fp8's 1 byte, so at the same kBlockN + // they double both LDS usage and per-tile VGPR pressure. The decode probe + // (ua-test-scripts/probe_decode_d128.sh) showed bf16 saturating VGPR=256 + // with AGPR overflow (44-106 AGPRs) and ~2x the LDS of fp8 on the m16 + // tier — the LDS pressure alone caps decode_d128_m16 at 1 CTA/CU. We + // halve kBlockN for bf16/fp16 to shed LDS and VGPR pressure, trading a + // small per-iter overhead for a big occupancy boost. + // + // The halved kBlockN must satisfy both gemm constraints: + // QK gemm: kBlockN is the N axis -> kBlockN >= WarpGemm::N + // PV gemm: kBlockN is the K axis -> kBlockN >= WarpGemm::K + // When `cfg::BlockSize/2 < WarpGemm::K` we additionally swap WarpGemm::K + // to the halved kBlockN so the smaller-K MFMA is used. For our variants + // this only hits decode_d128_m16 (WG=<16,16,32> -> <16,16,16>); the d=64 + // tiers already fit the smaller tile under the same 16x16x32 / 32x32x16 + // MFMA. 
The 32x32 N-warps (m32/m128/prefill) cannot drop their kBlockN + // below 32 (WG::N=32) and stay at the un-halved size. + static constexpr index_t WGM_ = cfg::WarpGemmShape::at(number<0>{}); + static constexpr index_t WGN_ = cfg::WarpGemmShape::at(number<1>{}); + static constexpr index_t WGK_ = cfg::WarpGemmShape::at(number<2>{}); + static constexpr bool kBf16HalveBlockN = + (DataType != unified_attention_args::data_type_enum::fp8) && + (cfg::BlockSize / 2 >= WGN_); + static constexpr index_t BLOCK_SIZE = + kBf16HalveBlockN ? cfg::BlockSize / 2 : cfg::BlockSize; + // Swap WarpGemm::K down to BLOCK_SIZE when the halved kBlockN dropped + // below the original WarpGemm::K. PVAttrNumAccess in GetPVBlockGemm + // recomputes from the new WarpGemm shape (lanes_in_K * SubMinDim rule) + // so the smaller-K MFMA tiles cleanly. + using unified_attention_warp_gemm_shape = std::conditional_t< + (kBf16HalveBlockN && BLOCK_SIZE < WGK_), + sequence, + typename cfg::WarpGemmShape>; + // The 2nd entry of the BlockTile is the static `kBlockQ` exposed via // `UnifiedAttentionShape::kBlockQ`. Now that the kernel always reads // kBlockQ from `args.num_queries_per_kv` at runtime, this static value @@ -269,9 +303,8 @@ struct unified_attention_kernel_traits // happens in practice). Anchor it at kBlockM so the static "looks like // num_qpkv == 1" and any (kBlockM, num_qpkv) such that kBlockM % num_qpkv // == 0 works without touching this trait. - using unified_attention_block_tile = sequence; - using unified_attention_warp_gemm_shape = typename cfg::WarpGemmShape; - using unified_attention_block_warps = typename cfg::BlockWarps; + using unified_attention_block_tile = sequence; + using unified_attention_block_warps = typename cfg::BlockWarps; using unified_attention_shape = TileUnifiedAttentionShape>; // `load_tile_transpose` is only valid when the tile distribution's inner // packing matches the transpose engine's SubtileMinorDimension = - // 64 bits / sizeof(VDataType_in_bits). 
For BF16 / FP16 SubMinDim=4 and the - // PV warp gemm produces kABKPerLane / AttrNumAccess = 8 / 2 = 4 elements - // per lane on the K direction — Double access is needed there. For FP8 - // SubMinDim=8 and kABKPerLane=8, so we must pass `Single` (otherwise - // 8/2=4 mismatches the FP8 SubMinDim=8 and the load_tile_transpose - // validation static_asserts fire — see DefaultTranspose::Quad in - // load_tile_transpose.hpp). The select is purely a compile-time alias. + // 64 bits / sizeof(VDataType_in_bits). The PV warp gemm produces + // kABKPerLane = WG::K / lanes_in_K elements per lane on the K direction + // (lanes_in_K = 4 for 16x16x* MFMA, 2 for 32x32x*), so we must pick + // AttrNumAccess such that kABKPerLane / AttrNumAccess == SubMinDim: + // bf16/fp16 16x16x32 -> kABKPerLane=8, SubMinDim=4 -> Double. + // bf16/fp16 16x16x16 -> kABKPerLane=4, SubMinDim=4 -> Single. + // bf16/fp16 32x32x16 -> kABKPerLane=8, SubMinDim=4 -> Double. + // fp8/bf8 16x16x32 -> kABKPerLane=8, SubMinDim=8 -> Single. + // fp8/bf8 32x32x16 -> kABKPerLane=8, SubMinDim=8 -> Single. + // The select is a compile-time alias. + static constexpr index_t kPVWarpGemmM = + Problem::UnifiedAttentionShape::Gemm1WarpTile::at(number<0>{}); + static constexpr index_t kPVWarpGemmK = + Problem::UnifiedAttentionShape::Gemm1WarpTile::at(number<2>{}); + static constexpr index_t kPVLanesInK = (kPVWarpGemmM == 16) ? 4 : 2; + static constexpr index_t kPVABKPerLane = kPVWarpGemmK / kPVLanesInK; + static constexpr index_t kPVSubMinDim = 8 / sizeof(typename Problem::VDataType); static constexpr WGAttrNumAccessEnum PVAttrNumAccess = - std::is_same_v, ck_tile::fp8_t> || - std::is_same_v, ck_tile::bf8_t> - ? WGAttrNumAccessEnum::Single - : WGAttrNumAccessEnum::Double; + (kPVABKPerLane == kPVSubMinDim) ? WGAttrNumAccessEnum::Single + : WGAttrNumAccessEnum::Double; using WarpGemm = WarpGemmDispatcher