From 403d99124db341bc8a78c65c92ccd305b5243cfb Mon Sep 17 00:00:00 2001 From: msaffari-amd Date: Tue, 19 May 2026 15:41:32 +0000 Subject: [PATCH] [AITERKER-112] Add PER_TOKEN_HEAD FP8 quant scheme to batch_prefill - New BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD enum value - Pipeline overload in block_fmha_batch_prefill_pipeline_qr_ks_vs_async applying per-token Q/K descale via GEMM0-post outer product and per-head V descale at epilogue - fmha_batch_prefill_kernel kargs + MakeKargs + pipeline dispatch - fmha_fwd.hpp host-side traits/args wiring - quant.hpp trait specialization - Codegen emits PER_TOKEN_HEAD kernel variants --- .../ck_tile/01_fmha/codegen/cpp_symbol_map.py | 2 + .../01_fmha/codegen/ops/fmha_batch_prefill.py | 9 +- example/ck_tile/01_fmha/fmha_fwd.hpp | 29 +++- example/ck_tile/01_fmha/quant.hpp | 17 +- .../block_attention_quant_scale_enum.hpp | 16 +- .../fmha/kernel/fmha_batch_prefill_kernel.hpp | 112 ++++++++++++- ..._batch_prefill_pipeline_qr_ks_vs_async.hpp | 154 +++++++++++++++++- 7 files changed, 317 insertions(+), 22 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index 79fe6492a6..a169eb0ea6 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -81,6 +81,7 @@ QSCALE_MAP = { "pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR", "blockscale": "ck_tile::BlockAttentionQuantScaleEnum::BLOCKSCALE", "kv_blockscale": "ck_tile::BlockAttentionQuantScaleEnum::KV_BLOCKSCALE", + "per_token_head": "ck_tile::BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD", "mx": "ck_tile::BlockAttentionQuantScaleEnum::MX", } @@ -89,6 +90,7 @@ QSCALE_CHECK_MAP = { "pertensor": "quant_scale_enum::pertensor", "blockscale": "quant_scale_enum::blockscale", "kv_blockscale": "quant_scale_enum::kv_blockscale", + "per_token_head": "quant_scale_enum::per_token_head", "mx": "quant_scale_enum::mx", } diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index 475631a885..733f16ef35 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -733,7 +733,7 @@ class KernelComponentFactory: kv_lookup_table, ) in itertools.product( ["t", "f"], - ["pertensor", "kv_blockscale"], + ["pertensor", "kv_blockscale", "per_token_head"], get_mask_map(mask_impl).keys(), ["no"], ["t", "f"], @@ -819,9 +819,12 @@ def get_fwd_blobs( for page_size in SUPPORTED_PAGE_SIZE: if page_size == 1 and pipeline.F_kv_memory_layout != "linear": continue - # kv_blockscale requires page_size >= kN0 (tile.F_bn0) + # kv_blockscale / per_token_head require page_size >= kN0 (tile.F_bn0) # This ensures all tokens in a main loop iteration belong to the same page - if pipeline.F_qscale == "kv_blockscale" and page_size < tile.F_bn0: + if ( + pipeline.F_qscale in ("kv_blockscale", "per_token_head") + and page_size < tile.F_bn0 + ): continue k = FmhaFwdKernel( F_idx=0, diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 98e2df2e1e..3913d84cff 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -671,6 +671,19 @@ struct fmha_batch_prefill_args // v_descale_ptr: [num_block, num_kv_head] - points to v block descale ck_tile::index_t nblock_stride_kv_block_descale = 0; // Stride along num_block dimension ck_tile::index_t nhead_stride_kv_block_descale = 0; // Stride along num_kv_head dimension + + // PER_TOKEN_HEAD: Q/K per-token per-head, V per-head (FP8 fine-grained). + // q_descale_ptr/k_descale_ptr/v_descale_ptr are reused; layout: + // q_descale: [total_q_tokens, nhead_q] fp32 + // k_descale: [num_total_pages, page_block_size, nhead_k] fp32 + // (aligned with paged K cache so we can reuse k_physical_pages[]) + // v_descale: [nhead_k] fp32 + ck_tile::index_t stride_q_descale_token = 0; // Q descale: row stride (per-token) + ck_tile::index_t nhead_stride_q_descale = 0; // Q descale: head stride + ck_tile::index_t nblock_stride_k_descale_page = 0; // K descale: page stride + ck_tile::index_t stride_k_descale_token = 0; // K descale: within-page token stride + ck_tile::index_t nhead_stride_k_descale = 0; // K descale: head stride + ck_tile::index_t nhead_stride_v_descale = 0; // V descale: head stride (per-head only) }; // Selects the KV-cache load mode for a batch-prefill dispatch arm. @@ -1340,7 +1353,13 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args) args.drop_seed_offset, args.sink_ptr, args.nblock_stride_kv_block_descale, - args.nhead_stride_kv_block_descale); + args.nhead_stride_kv_block_descale, + args.stride_q_descale_token, + args.nhead_stride_q_descale, + args.nblock_stride_k_descale_page, + args.stride_k_descale_token, + args.nhead_stride_k_descale, + args.nhead_stride_v_descale); } else { // create batch mode kernel arguments @@ -1395,7 +1414,13 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args) args.drop_seed_offset, args.sink_ptr, args.nblock_stride_kv_block_descale, - args.nhead_stride_kv_block_descale); + args.nhead_stride_kv_block_descale, + args.stride_q_descale_token, + args.nhead_stride_q_descale, + args.nblock_stride_k_descale_page, + args.stride_k_descale_token, + args.nhead_stride_k_descale, + args.nhead_stride_v_descale); } }(); diff --git a/example/ck_tile/01_fmha/quant.hpp b/example/ck_tile/01_fmha/quant.hpp index 4b8cd2e9a4..70c4f843ab 100644 --- a/example/ck_tile/01_fmha/quant.hpp +++ b/example/ck_tile/01_fmha/quant.hpp @@ -14,11 +14,12 @@ // keep sync with BlockAttentionQuantScaleEnum enum class quant_scale_enum { - no_scale = 0, - pertensor = 1, - blockscale = 2, - kv_blockscale = 3, // Q per-tensor, K/V per-page block scale - mx = 4, // Microscaling (MX) + no_scale = 0, + pertensor = 1, + blockscale = 2, + kv_blockscale = 3, // Q per-tensor, K/V per-page block scale + mx = 4, // Microscaling (MX) + per_token_head = 5, // Q/K per-token per-head, V per-head (FP8 fine-grained) }; struct quant_scale_info @@ -37,6 +38,8 @@ struct quant_scale_info os << "kvbs"; else if(type == quant_scale_enum::mx) os << "mx"; + else if(type == quant_scale_enum::per_token_head) + os << "pth"; } static quant_scale_info decode(std::string str) @@ -62,6 +65,10 @@ struct quant_scale_info { info.type = quant_scale_enum::mx; } + else if(str == "pth" || str == "5") + { + info.type = quant_scale_enum::per_token_head; + } else { throw std::invalid_argument("invalid quant scale value: " + str); diff --git a/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp b/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp index 61051cc08a..401f2050a4 100644 --- a/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp +++ b/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp @@ -10,11 +10,12 @@ namespace ck_tile { // This class is used for codegen pattern matching enum class BlockAttentionQuantScaleEnum { - NO_SCALE = 0, - PERTENSOR = 1, - BLOCKSCALE = 2, - KV_BLOCKSCALE = 3, // Q per-tensor, K/V per-page block scale - MX = 4, // Microscaling + NO_SCALE = 0, + PERTENSOR = 1, + BLOCKSCALE = 2, + KV_BLOCKSCALE = 3, // Q per-tensor, K/V per-page block scale + MX = 4, // Microscaling + PER_TOKEN_HEAD = 5, // Q/K per-token per-head, V per-head (FP8 fine-grained) }; template @@ -45,5 +46,10 @@ struct BlockAttentionQuantScaleEnumToStr { static constexpr const char* name = "mx"; }; +template <> +struct BlockAttentionQuantScaleEnumToStr +{ + static constexpr const char* name = "per_token_head"; +}; } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp index cab9ee5944..81a431a3e9 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp @@ -205,6 +205,23 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel ck_tile::index_t nhead_stride_kv_block_descale = 0; // Stride along num_kv_head dimension }; + // PER_TOKEN_HEAD: Q per-token-per-head, K per-token-per-head (paged-aligned), V per-head + // q_descale: [total_q, nhead_q] + // k_descale: [num_total_pages, page_block_size, nhead_k] + // v_descale: [nhead_k] + struct FmhaFwdPerTokenHeadKargs + { + const void* q_descale_ptr = nullptr; + const void* k_descale_ptr = nullptr; + const void* v_descale_ptr = nullptr; + ck_tile::index_t stride_q_descale_token = 0; + ck_tile::index_t nhead_stride_q_descale = 0; + ck_tile::index_t nblock_stride_k_descale_page = 0; + ck_tile::index_t stride_k_descale_token = 0; + ck_tile::index_t nhead_stride_k_descale = 0; + ck_tile::index_t nhead_stride_v_descale = 0; + }; + // Helper template to select QScale Kargs type based on QScaleEnum // EmptyType: type to use when QScaleEnum is NO_SCALE (e.g., FmhaFwdEmptyKargs<3>) template @@ -225,6 +242,12 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel using type = FmhaFwdKVBlockScaleKargs; }; + template + struct GetQScaleKargs + { + using type = FmhaFwdPerTokenHeadKargs; + }; + struct FmhaFwdDropoutSeedOffset { template @@ -379,7 +402,14 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel drop_seed_offset, const void* sink_ptr = nullptr, ck_tile::index_t nblock_stride_kv_block_descale = 0, - ck_tile::index_t nhead_stride_kv_block_descale = 0) + ck_tile::index_t nhead_stride_kv_block_descale = 0, + // PER_TOKEN_HEAD strides (only used when QScaleEnum == PER_TOKEN_HEAD) + ck_tile::index_t stride_q_descale_token = 0, + ck_tile::index_t nhead_stride_q_descale = 0, + ck_tile::index_t nblock_stride_k_descale_page = 0, + ck_tile::index_t stride_k_descale_token = 0, + ck_tile::index_t nhead_stride_k_descale = 0, + ck_tile::index_t nhead_stride_v_descale = 0) { Kargs kargs{{q_ptr, k_ptr, @@ -458,6 +488,18 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel kargs.nblock_stride_kv_block_descale = nblock_stride_kv_block_descale; kargs.nhead_stride_kv_block_descale = nhead_stride_kv_block_descale; } + else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD) + { + kargs.q_descale_ptr = q_descale_ptr; + kargs.k_descale_ptr = k_descale_ptr; + kargs.v_descale_ptr = v_descale_ptr; + kargs.stride_q_descale_token = stride_q_descale_token; + kargs.nhead_stride_q_descale = nhead_stride_q_descale; + kargs.nblock_stride_k_descale_page = nblock_stride_k_descale_page; + kargs.stride_k_descale_token = stride_k_descale_token; + kargs.nhead_stride_k_descale = nhead_stride_k_descale; + kargs.nhead_stride_v_descale = nhead_stride_v_descale; + } if constexpr(kHasDropout) { if(drop_seed_offset.index() == 0) // seed & offset come from host @@ -536,7 +578,14 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel drop_seed_offset, const void* sink_ptr = nullptr, ck_tile::index_t nblock_stride_kv_block_descale = 0, - ck_tile::index_t nhead_stride_kv_block_descale = 0) + ck_tile::index_t nhead_stride_kv_block_descale = 0, + // PER_TOKEN_HEAD strides (only used when QScaleEnum == PER_TOKEN_HEAD) + ck_tile::index_t stride_q_descale_token = 0, + ck_tile::index_t nhead_stride_q_descale = 0, + ck_tile::index_t nblock_stride_k_descale_page = 0, + ck_tile::index_t stride_k_descale_token = 0, + ck_tile::index_t nhead_stride_k_descale = 0, + ck_tile::index_t nhead_stride_v_descale = 0) { Kargs kargs{{q_ptr, k_ptr, @@ -612,6 +661,18 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel kargs.nblock_stride_kv_block_descale = nblock_stride_kv_block_descale; kargs.nhead_stride_kv_block_descale = nhead_stride_kv_block_descale; } + else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD) + { + kargs.q_descale_ptr = q_descale_ptr; + kargs.k_descale_ptr = k_descale_ptr; + kargs.v_descale_ptr = v_descale_ptr; + kargs.stride_q_descale_token = stride_q_descale_token; + kargs.nhead_stride_q_descale = nhead_stride_q_descale; + kargs.nblock_stride_k_descale_page = nblock_stride_k_descale_page; + kargs.stride_k_descale_token = stride_k_descale_token; + kargs.nhead_stride_k_descale = nhead_stride_k_descale; + kargs.nhead_stride_v_descale = nhead_stride_v_descale; + } if constexpr(kHasDropout) { if(drop_seed_offset.index() == 0) // seed & offset come from host @@ -1222,6 +1283,12 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel float q_descale = *(reinterpret_cast(kargs.q_descale_ptr)); return kargs.scale_s * q_descale; } + else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD) + { + // Q/K descales are per-token-per-head, applied as outer product in pipeline. + // Here we only forward the softmax scale (1/sqrt(d)). + return kargs.scale_s; + } else { return kargs.scale_s; @@ -1339,6 +1406,47 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel kargs.nblock_stride_kv_block_descale, kargs.nhead_stride_kv_block_descale); } + else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD) + { + // PER_TOKEN_HEAD: Q/K descales are per-token-per-head, V is per-head. + assert(kargs.q_descale_ptr != nullptr); + assert(kargs.k_descale_ptr != nullptr); + assert(kargs.v_descale_ptr != nullptr); + const float* q_descale_ptr = reinterpret_cast(kargs.q_descale_ptr); + const float* k_descale_ptr = reinterpret_cast(kargs.k_descale_ptr); + const float* v_descale_ptr = reinterpret_cast(kargs.v_descale_ptr); + + return FmhaPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + bias_dram_window, + randval_dram_window, + lse_dram_window, + mask, + position_encoding, + variant_params.sm_scale, + variant, + variant_params, + block_indices, + smem_ptr, + page_idx, + stride_k_for_pipeline, + stride_v_for_pipeline, + kargs.batch_stride_k, + kargs.batch_stride_v, + dropout, + sink_value, + max_page_table_idx, + q_descale_ptr, + k_descale_ptr, + v_descale_ptr, + kargs.stride_q_descale_token, + kargs.nhead_stride_q_descale, + kargs.nblock_stride_k_descale_page, + kargs.stride_k_descale_token, + kargs.nhead_stride_k_descale, + kargs.nhead_stride_v_descale); + } else { return FmhaPipeline{}(q_dram_window, diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index adc24943e6..3c19745e79 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -436,7 +436,20 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync const float* k_descale_ptr = nullptr, const float* v_descale_ptr = nullptr, index_t nblock_stride_kv_block_descale = 0, - index_t nhead_stride_kv_block_descale = 0) const + index_t nhead_stride_kv_block_descale = 0, + // PER_TOKEN_HEAD parameters (only used when QScaleEnum == PER_TOKEN_HEAD) + // Reuses k_descale_ptr / v_descale_ptr above; q_descale provided here. + // Layouts: + // q_descale_per_token_ptr: [total_q, nhead_q] + // k_descale_ptr (when PER_TOKEN_HEAD): [num_total_pages, page_block_size, nhead_k] + // v_descale_ptr (when PER_TOKEN_HEAD): [nhead_k] + const float* q_descale_per_token_ptr = nullptr, + index_t stride_q_descale_token = 0, + index_t nhead_stride_q_descale = 0, + index_t nblock_stride_k_descale_page = 0, + index_t stride_k_descale_token = 0, + index_t nhead_stride_k_descale = 0, + index_t nhead_stride_v_descale = 0) const { // KV_BLOCKSCALE requires page_block_size >= kN0 to ensure // all tokens in a main loop iteration belong to the same page @@ -444,6 +457,11 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync { static_assert(kPageBlockSize >= kN0, "KV_BLOCKSCALE requires kPageBlockSize >= kN0"); } + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD) + { + static_assert(kPageBlockSize >= kN0, + "PER_TOKEN_HEAD requires kPageBlockSize >= kN0"); + } static_assert( std::is_same_v> && @@ -1027,6 +1045,13 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync k_descale = k_descale_ptr[scale_offset]; v_descale = v_descale_ptr[scale_offset]; } + else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD) + { + // V scale is per-head only; load scalar from v_descale_ptr[kv_head_idx]. + // K scale is per-token-per-head and is applied as a vector after GEMM0 + // (see PER_TOKEN_HEAD branch below). + v_descale = v_descale_ptr[block_indices.kv_head_idx * nhead_stride_v_descale]; + } // Prefetch V physical pages early - overlaps with GEMM0 computation save_and_prefetch_v_pages(number{}); @@ -1087,6 +1112,37 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync { tile_elementwise_inout([&k_descale](auto& x) { x *= k_descale; }, s_acc); } + // PER_TOKEN_HEAD: dequantize QK result with per-row Q descale and per-column K descale. + // s_acc[i,j] *= q_descale[q_origin+i, qo_head] * k_descale[k_page, k_slot+j, kv_head] + else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + const index_t k_page = k_physical_pages[number<0>{}]; + const index_t k_slot_base = k_origin.at(number<0>{}) % kPageBlockSize; + const index_t qo_head = block_indices.qo_head_idx; + const index_t kv_head = block_indices.kv_head_idx; + const index_t q_row_base = q_origin.at(number<0>{}); + + const index_t k_page_base = k_page * nblock_stride_k_descale_page + + kv_head * nhead_stride_k_descale; + + constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); + sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + const index_t i = tile_idx.at(number<0>{}); + const index_t j = tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + const float qd = q_descale_per_token_ptr[ + (q_row_base + i) * stride_q_descale_token + + qo_head * nhead_stride_q_descale]; + const float kd = k_descale_ptr[ + k_page_base + (k_slot_base + j) * stride_k_descale_token]; + s_acc(i_j_idx) *= qd * kd; + }); + }); + } const auto p = [&]() { const auto bias_tile = load_tile(bias_dram_window); // load bias tile @@ -1309,7 +1365,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync // This scales P by 2^shift (≈448 for fp8_e4m3) without explicit multiply auto validated_m = get_validated_m(m[i_idx]); auto row_max = scale_s * validated_m; - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE) + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE || + QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD) { #if CK_TILE_USE_OCP_FP8 validated_m -= OCP_FP8_SHIFT; // for Bias/Alibi/SoftCap @@ -1427,7 +1484,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync // STAGE 3, KV gemm // KV_BLOCKSCALE: accumulate P*V into temporary tile before applying v_descale auto o_acc_unscaled = decltype(o_acc){}; - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE) + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE || + QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD) { clear_tile(o_acc_unscaled); } @@ -1435,7 +1493,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync // Select GEMM1 target: o_acc_unscaled for KV_BLOCKSCALE (needs v_descale), o_acc // otherwise auto& gemm1_acc = [&]() -> auto& { - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE) + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE || + QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD) return o_acc_unscaled; else return o_acc; @@ -1586,7 +1645,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync // 1. P was scaled by 2^shift through exp2 shift trick // 2. rowsum l was also scaled by 2^shift // 3. Final O = sum(P*V) / l, so the 2^shift cancels out - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE) + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE || + QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD) { tile_elementwise_inout( [&v_descale](auto& o, auto& o_unscaled) { o += o_unscaled * v_descale; }, @@ -1787,6 +1847,90 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync nblock_stride_kv_block_descale, nhead_stride_kv_block_descale); } + + // Overload for PER_TOKEN_HEAD: Q/K per-token-per-head, V per-head + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile + LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile + FmhaMask mask, + PositionEncoding position_encoding, + float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, + void* smem_ptr, + const index_t* page_idx, + const index_t stride_k, + const index_t stride_v, + const index_t page_stride_k, + const index_t page_stride_v, + DropoutType& dropout, + float sink_v, + const index_t max_page_table_idx, + const float* q_descale_per_token_ptr, + const float* k_descale_per_token_ptr, + const float* v_descale_per_head_ptr, + index_t stride_q_descale_token, + index_t nhead_stride_q_descale, + index_t nblock_stride_k_descale_page, + index_t stride_k_descale_token, + index_t nhead_stride_k_descale, + index_t nhead_stride_v_descale) const + { + return operator()(q_dram_block_window_tmp, + identity{}, + k_dram_block_window_tmp, + identity{}, + v_dram_block_window_tmp, + identity{}, + bias_dram_block_window_tmp, + identity{}, + randval_dram_block_window_tmp, + lse_dram_block_window_tmp, + identity{}, + identity{}, + identity{}, + identity{}, + mask, + position_encoding, + scale_s, + variant, + variant_params, + block_indices, + smem_ptr, + page_idx, + stride_k, + stride_v, + page_stride_k, + page_stride_v, + dropout, + sink_v, + max_page_table_idx, + k_descale_per_token_ptr, // reused: k_descale_ptr slot + v_descale_per_head_ptr, // reused: v_descale_ptr slot + /*nblock_stride_kv_block_descale*/ 0, + /*nhead_stride_kv_block_descale*/ 0, + q_descale_per_token_ptr, + stride_q_descale_token, + nhead_stride_q_descale, + nblock_stride_k_descale_page, + stride_k_descale_token, + nhead_stride_k_descale, + nhead_stride_v_descale); + } }; } // namespace ck_tile