From ee3ada6e4a5f9dd80b1505378d37bd7b97dbacef Mon Sep 17 00:00:00 2001 From: msaffari-amd Date: Wed, 20 May 2026 14:21:12 +0000 Subject: [PATCH] [AITERKER-112] PER_TOKEN_HEAD: support page_size < kN0 via cross-page dequant - Pipeline: remove kPageBlockSize >= kN0 static_assert; QK dequant now precomputes tile_k_pages[] and indexes per-column. page_size >= kN0 stays on the original single-page fast path (kPagesPerTile==1). - Codegen: add page_size=64 to SUPPORTED_PAGE_SIZE; drop per_token_head from the page_size < tile.F_bn0 filter (kv_blockscale still filtered). --- .../01_fmha/codegen/ops/fmha_batch_prefill.py | 9 +-- ..._batch_prefill_pipeline_qr_ks_vs_async.hpp | 58 +++++++++++++++---- 2 files changed, 52 insertions(+), 15 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index 733f16ef35..34d87b0d66 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -48,7 +48,7 @@ DTYPE_BYTES = {k: v // 8 for k, v in DTYPE_BITS.items()} K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256} -SUPPORTED_PAGE_SIZE = [1, 16, 1024] +SUPPORTED_PAGE_SIZE = [1, 16, 64, 1024] SUPPORTED_KV_MEMORY_LAYOUT = ["vectorized", "linear"] SUPPORTED_KV_LOOKUP_TABLE = ["vllm", "sglang"] KV_MEMORY_LAYOUT_ENUM_MAP = { @@ -819,10 +819,11 @@ def get_fwd_blobs( for page_size in SUPPORTED_PAGE_SIZE: if page_size == 1 and pipeline.F_kv_memory_layout != "linear": continue - # kv_blockscale / per_token_head require page_size >= kN0 (tile.F_bn0) - # This ensures all tokens in a main loop iteration belong to the same page + # kv_blockscale requires page_size >= kN0 (tile.F_bn0): its dequant + # loop only loads a single page per tile. per_token_head supports + # cross-page tiles (per-column page lookup in the pipeline). 
if ( - pipeline.F_qscale in ("kv_blockscale", "per_token_head") + pipeline.F_qscale == "kv_blockscale" and page_size < tile.F_bn0 ): continue diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index 3c19745e79..4ef7aaea21 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -457,11 +457,10 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync { static_assert(kPageBlockSize >= kN0, "KV_BLOCKSCALE requires kPageBlockSize >= kN0"); } - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD) - { - static_assert(kPageBlockSize >= kN0, - "PER_TOKEN_HEAD requires kPageBlockSize >= kN0"); - } + // PER_TOKEN_HEAD supports both kPageBlockSize >= kN0 (single page per + // tile) and kPageBlockSize < kN0 (cross-page tile). The dequant loop + // below precomputes per-(kPageBlockSize)-wide-slice physical page IDs + // and applies them per column. static_assert( std::is_same_v> && @@ -1113,18 +1112,47 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync tile_elementwise_inout([&k_descale](auto& x) { x *= k_descale; }, s_acc); } // PER_TOKEN_HEAD: dequantize QK result with per-row Q descale and per-column K descale. - // s_acc[i,j] *= q_descale[q_origin+i, qo_head] * k_descale[k_page, k_slot+j, kv_head] + // s_acc[i,j] *= q_descale[q_origin+i, qo_head] * k_descale[k_page(j), k_slot(j), kv_head] + // Supports cross-page tiles (kPageBlockSize < kN0): column j is looked up in the + // page covering token (k_origin + j). 
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD) { const auto k_origin = k_dram_block_window.get_window_origin(); - const index_t k_page = k_physical_pages[number<0>{}]; - const index_t k_slot_base = k_origin.at(number<0>{}) % kPageBlockSize; const index_t qo_head = block_indices.qo_head_idx; const index_t kv_head = block_indices.kv_head_idx; const index_t q_row_base = q_origin.at(number<0>{}); - const index_t k_page_base = k_page * nblock_stride_k_descale_page + - kv_head * nhead_stride_k_descale; + // Number of distinct pages this tile spans. + // page_size >= kN0 -> 1 (fast path, identical to original behavior) + // page_size < kN0 -> kN0 / page_size (cross-page tile) + constexpr index_t kPagesPerTile = + (kPageBlockSize >= kN0) ? 1 : (kN0 / kPageBlockSize); + constexpr index_t kLog2PageBlockSize = []{ + index_t shift = 0; + index_t val = kPageBlockSize; + while(val > 1) { val >>= 1; ++shift; } + return shift; + }(); + constexpr index_t kPageSlotMask = kPageBlockSize - 1; + + // Physical pages for each kPageBlockSize-wide column slice of the tile. + // Tiny array (1 or kN0/kPageBlockSize entries); compiler keeps in registers. + index_t tile_k_pages[kPagesPerTile]; + if constexpr(kPagesPerTile == 1) + { + // Single-page tile: reuse the page already loaded for K-gemm. 
+ tile_k_pages[0] = k_physical_pages[number<0>{}]; + } + else + { + const index_t k_origin_n = k_origin.at(number<0>{}); + static_for<0, kPagesPerTile, 1>{}([&](auto p) { + const index_t gp = (k_origin_n + p.value * kPageBlockSize) + >> kLog2PageBlockSize; + tile_k_pages[p.value] = + page_idx[ck_tile::min(gp, max_page_table_idx)]; + }); + } constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { @@ -1137,8 +1165,16 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync const float qd = q_descale_per_token_ptr[ (q_row_base + i) * stride_q_descale_token + qo_head * nhead_stride_q_descale]; + // Per-column page + slot; the slot keeps the tile's in-page offset. For + // kPagesPerTile==1 the page selector folds to 0 at compile time. + const index_t k_page = tile_k_pages[ + (kPagesPerTile == 1) ? index_t{0} + : (j >> kLog2PageBlockSize)]; + const index_t k_slot = (k_origin.at(number<0>{}) + j) & kPageSlotMask; const float kd = k_descale_ptr[ - k_page_base + (k_slot_base + j) * stride_k_descale_token]; + k_page * nblock_stride_k_descale_page + + kv_head * nhead_stride_k_descale + + k_slot * stride_k_descale_token]; s_acc(i_j_idx) *= qd * kd; }); });