mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 03:07:02 +00:00
CK-UA: halve kBlockN for bf16/fp16 m16 decode + generalise PVAttrNumAccess
The decode_d128_m16 tier was VGPR-saturated and LDS-bound on bf16/fp16 (probe_decode_d128 showed VGPR=256 + AGPR overflow, ~2x fp8's LDS at the same kBlockN), capping it at 1 CTA/CU. Halving kBlockN for the non-fp8 path on the m16 tier sheds enough LDS and VGPR pressure to fit 3-4 CTAs/CU (LDS-bound). The halved kBlockN forces a smaller-K MFMA shape on the m16 PV gemm (16x16x32 -> 16x16x16); we also auto- adjust WarpGemm::K so PVAttrNumAccess picks Single vs Double access correctly. The PVAttrNumAccess derivation is now generic — driven by (kABKPerLane, SubMinDim) rather than just (dtype) — so the new shape compiles without per-variant special-casing. Variants only affected where cfg::BlockSize/2 >= WarpGemm::N (i.e. decode_d128_m16); m32/m128/prefill keep their un-halved tiles since they use 32x32 N-warps. Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
#include "ck_tile/core/numeric/bfloat16.hpp"
|
||||
@@ -259,9 +260,42 @@ struct unified_attention_kernel_traits
|
||||
|
||||
static constexpr index_t HEAD_SIZE = cfg::HeadSize;
|
||||
static constexpr index_t kBlockM = cfg::BlockM;
|
||||
static constexpr index_t BLOCK_SIZE = cfg::BlockSize;
|
||||
static constexpr bool kUseDecodeGrid = cfg::kUseDecodeGrid;
|
||||
|
||||
// bf16/fp16 carry a 2-byte element vs fp8's 1 byte, so at the same kBlockN
|
||||
// they double both LDS usage and per-tile VGPR pressure. The decode probe
|
||||
// (ua-test-scripts/probe_decode_d128.sh) showed bf16 saturating VGPR=256
|
||||
// with AGPR overflow (44-106 AGPRs) and ~2x the LDS of fp8 on the m16
|
||||
// tier — the LDS pressure alone caps decode_d128_m16 at 1 CTA/CU. We
|
||||
// halve kBlockN for bf16/fp16 to shed LDS and VGPR pressure, trading a
|
||||
// small per-iter overhead for a big occupancy boost.
|
||||
//
|
||||
// The halved kBlockN must satisfy both gemm constraints:
|
||||
// QK gemm: kBlockN is the N axis -> kBlockN >= WarpGemm::N
|
||||
// PV gemm: kBlockN is the K axis -> kBlockN >= WarpGemm::K
|
||||
// When `cfg::BlockSize/2 < WarpGemm::K` we additionally swap WarpGemm::K
|
||||
// to the halved kBlockN so the smaller-K MFMA is used. For our variants
|
||||
// this only hits decode_d128_m16 (WG=<16,16,32> -> <16,16,16>); the d=64
|
||||
// tiers already fit the smaller tile under the same 16x16x32 / 32x32x16
|
||||
// MFMA. The 32x32 N-warps (m32/m128/prefill) cannot drop their kBlockN
|
||||
// below 32 (WG::N=32) and stay at the un-halved size.
|
||||
static constexpr index_t WGM_ = cfg::WarpGemmShape::at(number<0>{});
|
||||
static constexpr index_t WGN_ = cfg::WarpGemmShape::at(number<1>{});
|
||||
static constexpr index_t WGK_ = cfg::WarpGemmShape::at(number<2>{});
|
||||
static constexpr bool kBf16HalveBlockN =
|
||||
(DataType != unified_attention_args::data_type_enum::fp8) &&
|
||||
(cfg::BlockSize / 2 >= WGN_);
|
||||
static constexpr index_t BLOCK_SIZE =
|
||||
kBf16HalveBlockN ? cfg::BlockSize / 2 : cfg::BlockSize;
|
||||
// Swap WarpGemm::K down to BLOCK_SIZE when the halved kBlockN dropped
|
||||
// below the original WarpGemm::K. PVAttrNumAccess in GetPVBlockGemm
|
||||
// recomputes from the new WarpGemm shape (lanes_in_K * SubMinDim rule)
|
||||
// so the smaller-K MFMA tiles cleanly.
|
||||
using unified_attention_warp_gemm_shape = std::conditional_t<
|
||||
(kBf16HalveBlockN && BLOCK_SIZE < WGK_),
|
||||
sequence<WGM_, WGN_, BLOCK_SIZE>,
|
||||
typename cfg::WarpGemmShape>;
|
||||
|
||||
// The 2nd entry of the BlockTile is the static `kBlockQ` exposed via
|
||||
// `UnifiedAttentionShape::kBlockQ`. Now that the kernel always reads
|
||||
// kBlockQ from `args.num_queries_per_kv` at runtime, this static value
|
||||
@@ -269,9 +303,8 @@ struct unified_attention_kernel_traits
|
||||
// happens in practice). Anchor it at kBlockM so the static "looks like
|
||||
// num_qpkv == 1" and any (kBlockM, num_qpkv) such that kBlockM % num_qpkv
|
||||
// == 0 works without touching this trait.
|
||||
using unified_attention_block_tile = sequence<kBlockM, kBlockM, BLOCK_SIZE, HEAD_SIZE>;
|
||||
using unified_attention_warp_gemm_shape = typename cfg::WarpGemmShape;
|
||||
using unified_attention_block_warps = typename cfg::BlockWarps;
|
||||
using unified_attention_block_tile = sequence<kBlockM, kBlockM, BLOCK_SIZE, HEAD_SIZE>;
|
||||
using unified_attention_block_warps = typename cfg::BlockWarps;
|
||||
|
||||
using unified_attention_shape = TileUnifiedAttentionShape<unified_attention_block_tile,
|
||||
unified_attention_block_warps,
|
||||
|
||||
@@ -333,18 +333,26 @@ struct UnifiedAttentionPipelineDefaultPolicy
|
||||
typename Problem::UnifiedAttentionShape::Gemm1WarpTile>>;
|
||||
// `load_tile_transpose` is only valid when the tile distribution's inner
|
||||
// packing matches the transpose engine's SubtileMinorDimension =
|
||||
// 64 bits / sizeof(VDataType_in_bits). For BF16 / FP16 SubMinDim=4 and the
|
||||
// PV warp gemm produces kABKPerLane / AttrNumAccess = 8 / 2 = 4 elements
|
||||
// per lane on the K direction — Double access is needed there. For FP8
|
||||
// SubMinDim=8 and kABKPerLane=8, so we must pass `Single` (otherwise
|
||||
// 8/2=4 mismatches the FP8 SubMinDim=8 and the load_tile_transpose
|
||||
// validation static_asserts fire — see DefaultTranspose::Quad in
|
||||
// load_tile_transpose.hpp). The select is purely a compile-time alias.
|
||||
// 64 bits / sizeof(VDataType_in_bits). The PV warp gemm produces
|
||||
// kABKPerLane = WG::K / lanes_in_K elements per lane on the K direction
|
||||
// (lanes_in_K = 4 for 16x16x* MFMA, 2 for 32x32x*), so we must pick
|
||||
// AttrNumAccess such that kABKPerLane / AttrNumAccess == SubMinDim:
|
||||
// bf16/fp16 16x16x32 -> kABKPerLane=8, SubMinDim=4 -> Double.
|
||||
// bf16/fp16 16x16x16 -> kABKPerLane=4, SubMinDim=4 -> Single.
|
||||
// bf16/fp16 32x32x16 -> kABKPerLane=8, SubMinDim=4 -> Double.
|
||||
// fp8/bf8 16x16x32 -> kABKPerLane=8, SubMinDim=8 -> Single.
|
||||
// fp8/bf8 32x32x16 -> kABKPerLane=8, SubMinDim=8 -> Single.
|
||||
// The select is a compile-time alias.
|
||||
static constexpr index_t kPVWarpGemmM =
|
||||
Problem::UnifiedAttentionShape::Gemm1WarpTile::at(number<0>{});
|
||||
static constexpr index_t kPVWarpGemmK =
|
||||
Problem::UnifiedAttentionShape::Gemm1WarpTile::at(number<2>{});
|
||||
static constexpr index_t kPVLanesInK = (kPVWarpGemmM == 16) ? 4 : 2;
|
||||
static constexpr index_t kPVABKPerLane = kPVWarpGemmK / kPVLanesInK;
|
||||
static constexpr index_t kPVSubMinDim = 8 / sizeof(typename Problem::VDataType);
|
||||
static constexpr WGAttrNumAccessEnum PVAttrNumAccess =
|
||||
std::is_same_v<remove_cvref_t<typename Problem::VDataType>, ck_tile::fp8_t> ||
|
||||
std::is_same_v<remove_cvref_t<typename Problem::VDataType>, ck_tile::bf8_t>
|
||||
? WGAttrNumAccessEnum::Single
|
||||
: WGAttrNumAccessEnum::Double;
|
||||
(kPVABKPerLane == kPVSubMinDim) ? WGAttrNumAccessEnum::Single
|
||||
: WGAttrNumAccessEnum::Double;
|
||||
using WarpGemm =
|
||||
WarpGemmDispatcher<typename Problem::PDataType,
|
||||
typename Problem::VDataType,
|
||||
|
||||
Reference in New Issue
Block a user