From a66068e978f59bf5a78bf0ef33d8f3574fc557ec Mon Sep 17 00:00:00 2001 From: msaffari-amd Date: Thu, 28 May 2026 13:26:37 +0000 Subject: [PATCH] p scale support revert comments --- example/ck_tile/01_fmha/fmha_fwd.hpp | 23 +++++++++-- .../fmha/kernel/fmha_batch_prefill_kernel.hpp | 16 ++++++-- ..._batch_prefill_pipeline_qr_ks_vs_async.hpp | 39 +++++++++++++++---- 3 files changed, 63 insertions(+), 15 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 9f177c255e..c24db93acc 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -679,9 +679,19 @@ struct fmha_batch_prefill_args ck_tile::index_t stride_k_descale_token = 0; // K descale: within-page token stride ck_tile::index_t nhead_stride_k_descale = 0; // K descale: head stride ck_tile::index_t nhead_stride_v_descale = 0; // V descale: head stride (per-head only) + + // PER_TOKEN_HEAD optional per-q-head P scale [num_head_q] fp32. + const void* p_scale_ptr = nullptr; }; -// Select KV-cache load mode for batch-prefill. +// Selects the KV-cache load mode for a batch-prefill dispatch arm. +// GLOBAL_LOAD_LDS: required when (a) the page is smaller than one K/V tile +// so per-page SRD is impossible, AND (b) the total KV-pool byte size +// exceeds INT32_MAX so SRD's 32-bit byte offset cannot address it. +// BUFFER_LOAD: every other case — the SGPR-resident SRD path is fastest. +// Inputs are taken as plain integers so the helper has no template parameter +// and can be called from each codegen-emitted dispatcher arm with the arm's +// compile-time kN0 / element_bytes substituted as constants. inline ck_tile::BlockAttentionKVCacheLoadModeEnum fmha_batch_prefill_select_kv_load_mode(ck_tile::index_t page_block_size, ck_tile::index_t kN0, @@ -689,7 +699,10 @@ fmha_batch_prefill_select_kv_load_mode(ck_tile::index_t page_block_size, ck_tile::index_t batch_stride_k, ck_tile::index_t element_bytes) { - // Promote all operands before multiply to avoid intermediate overflow. + // Promote every operand to long_index_t so overflow is impossible regardless + // of multiplication order. A bare `static_cast(num_total_pages) + // * batch_stride_k * element_bytes` only works because of left-to-right + // associativity — a future reorder of the operands would silently truncate. const auto kv_pool_bytes = static_cast(num_total_pages) * static_cast(batch_stride_k) * static_cast(element_bytes); @@ -1344,7 +1357,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args) args.nblock_stride_k_descale_page, args.stride_k_descale_token, args.nhead_stride_k_descale, - args.nhead_stride_v_descale); + args.nhead_stride_v_descale, + args.p_scale_ptr); } else { // create batch mode kernel arguments @@ -1405,7 +1419,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args) args.nblock_stride_k_descale_page, args.stride_k_descale_token, args.nhead_stride_k_descale, - args.nhead_stride_v_descale); + args.nhead_stride_v_descale, + args.p_scale_ptr); } }(); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp index 81a431a3e9..7be4ef8d2a 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp @@ -220,6 +220,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel ck_tile::index_t stride_k_descale_token = 0; ck_tile::index_t nhead_stride_k_descale = 0; ck_tile::index_t nhead_stride_v_descale = 0; + // Optional per-q-head P scale [num_head_q] fp32. + const void* p_scale_ptr = nullptr; }; // Helper template to select QScale Kargs type based on QScaleEnum @@ -409,7 +411,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel ck_tile::index_t nblock_stride_k_descale_page = 0, ck_tile::index_t stride_k_descale_token = 0, ck_tile::index_t nhead_stride_k_descale = 0, - ck_tile::index_t nhead_stride_v_descale = 0) + ck_tile::index_t nhead_stride_v_descale = 0, + const void* p_scale_ptr = nullptr) { Kargs kargs{{q_ptr, k_ptr, @@ -499,6 +502,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel kargs.stride_k_descale_token = stride_k_descale_token; kargs.nhead_stride_k_descale = nhead_stride_k_descale; kargs.nhead_stride_v_descale = nhead_stride_v_descale; + kargs.p_scale_ptr = p_scale_ptr; } if constexpr(kHasDropout) { @@ -585,7 +589,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel ck_tile::index_t nblock_stride_k_descale_page = 0, ck_tile::index_t stride_k_descale_token = 0, ck_tile::index_t nhead_stride_k_descale = 0, - ck_tile::index_t nhead_stride_v_descale = 0) + ck_tile::index_t nhead_stride_v_descale = 0, + const void* p_scale_ptr = nullptr) { Kargs kargs{{q_ptr, k_ptr, @@ -672,6 +677,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel kargs.stride_k_descale_token = stride_k_descale_token; kargs.nhead_stride_k_descale = nhead_stride_k_descale; kargs.nhead_stride_v_descale = nhead_stride_v_descale; + kargs.p_scale_ptr = p_scale_ptr; } if constexpr(kHasDropout) { @@ -1416,6 +1422,9 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel const float* k_descale_ptr = reinterpret_cast(kargs.k_descale_ptr); const float* v_descale_ptr = reinterpret_cast(kargs.v_descale_ptr); + const float* p_scale_ptr = + reinterpret_cast(kargs.p_scale_ptr); + return FmhaPipeline{}(q_dram_window, k_dram_window, v_dram_window, @@ -1445,7 +1454,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel kargs.nblock_stride_k_descale_page, kargs.stride_k_descale_token, kargs.nhead_stride_k_descale, - kargs.nhead_stride_v_descale); + kargs.nhead_stride_v_descale, + p_scale_ptr); } else { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index 22dca51734..7c97f07673 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -449,7 +449,11 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync index_t nblock_stride_k_descale_page = 0, index_t stride_k_descale_token = 0, index_t nhead_stride_k_descale = 0, - index_t nhead_stride_v_descale = 0) const + index_t nhead_stride_v_descale = 0, + // PER_TOKEN_HEAD caller P scale [num_head_q] fp32, optional. + // Folded into the exp2 row-max shift; the rowsum carries the + // same factor so it cancels in O/l with no v_descale fixup. + const float* p_scale_ptr = nullptr) const { // KV_BLOCKSCALE requires page_block_size >= kN0 to ensure // all tokens in a main loop iteration belong to the same page @@ -462,6 +466,17 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync // below precomputes per-(kPageBlockSize)-wide-slice physical page IDs // and applies them per column. + // Per-q-head P scale (PER_TOKEN_HEAD only): folded into the exp2 + // row-max shift below; null pointer reduces to the default 2^shift. + float p_scale_log2 = 0.f; + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD) + { + if(p_scale_ptr != nullptr) + { + p_scale_log2 = log2f(p_scale_ptr[block_indices.qo_head_idx]); + } + } + static_assert( std::is_same_v> && std::is_same_v> && @@ -1463,21 +1478,27 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); #if CK_TILE_FMHA_FWD_FAST_EXP2 - // For KV_BLOCKSCALE: precompute (m - shift) once per row - // exp2(s - (m - shift)) = exp2(s - m + shift) = exp2(s - m) * 2^shift - // This scales P by 2^shift (≈448 for fp8_e4m3) without explicit multiply + // exp2(s - (m - shift)) = exp2(s - m) * 2^shift, i.e. P is + // scaled by 2^shift before the fp8 cast: OCP fp8 (e4m3) + // uses 8 -> 256, FNUZ fp8 uses 7 -> 128. auto validated_m = get_validated_m(m[i_idx]); auto row_max = scale_s * validated_m; if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE || QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD) { #if CK_TILE_USE_OCP_FP8 - validated_m -= OCP_FP8_SHIFT; // for Bias/Alibi/SoftCap - row_max -= OCP_FP8_SHIFT; // for else branch + validated_m -= OCP_FP8_SHIFT; + row_max -= OCP_FP8_SHIFT; #else validated_m -= FNUZ_FP8_SHIFT; row_max -= FNUZ_FP8_SHIFT; #endif + if constexpr(QScaleEnum == + BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD) + { + validated_m -= p_scale_log2; + row_max -= p_scale_log2; + } } #endif sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { @@ -1991,7 +2012,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync index_t nblock_stride_k_descale_page, index_t stride_k_descale_token, index_t nhead_stride_k_descale, - index_t nhead_stride_v_descale) const + index_t nhead_stride_v_descale, + const float* p_scale_ptr = nullptr) const { return operator()(q_dram_block_window_tmp, identity{}, @@ -2032,7 +2054,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync nblock_stride_k_descale_page, stride_k_descale_token, nhead_stride_k_descale, - nhead_stride_v_descale); + nhead_stride_v_descale, + p_scale_ptr); } };