[AITERKER-112] PER_TOKEN_HEAD: support page_size < kN0 via cross-page dequant

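- Pipeline: remove the PER_TOKEN_HEAD static_assert (kPageBlockSize >= kN0);
  the KV_BLOCKSCALE assert is unchanged. The QK dequant now precomputes
  tile_k_pages[] (one physical page per kPageBlockSize-wide slice of the
  tile) and selects the page and slot per column. page_size >= kN0 keeps the
  single-page path (kPagesPerTile==1), reusing the page already loaded for
  the K-gemm.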
- Codegen: add page_size=64 to SUPPORTED_PAGE_SIZE; drop per_token_head from
  the page_size < tile.F_bn0 filter (kv_blockscale still filtered).
This commit is contained in:
msaffari-amd
2026-05-20 14:21:12 +00:00
parent 403d99124d
commit ee3ada6e4a
2 changed files with 52 additions and 15 deletions

View File

@@ -48,7 +48,7 @@ DTYPE_BYTES = {k: v // 8 for k, v in DTYPE_BITS.items()}
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256}
SUPPORTED_PAGE_SIZE = [1, 16, 1024]
SUPPORTED_PAGE_SIZE = [1, 16, 64, 1024]
SUPPORTED_KV_MEMORY_LAYOUT = ["vectorized", "linear"]
SUPPORTED_KV_LOOKUP_TABLE = ["vllm", "sglang"]
KV_MEMORY_LAYOUT_ENUM_MAP = {
@@ -819,10 +819,11 @@ def get_fwd_blobs(
for page_size in SUPPORTED_PAGE_SIZE:
if page_size == 1 and pipeline.F_kv_memory_layout != "linear":
continue
# kv_blockscale / per_token_head require page_size >= kN0 (tile.F_bn0)
# This ensures all tokens in a main loop iteration belong to the same page
# kv_blockscale requires page_size >= kN0 (tile.F_bn0): its dequant
# loop only loads a single page per tile. per_token_head supports
# cross-page tiles (per-column page lookup in the pipeline).
if (
pipeline.F_qscale in ("kv_blockscale", "per_token_head")
pipeline.F_qscale == "kv_blockscale"
and page_size < tile.F_bn0
):
continue

View File

@@ -457,11 +457,10 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
{
static_assert(kPageBlockSize >= kN0, "KV_BLOCKSCALE requires kPageBlockSize >= kN0");
}
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
{
static_assert(kPageBlockSize >= kN0,
"PER_TOKEN_HEAD requires kPageBlockSize >= kN0");
}
// PER_TOKEN_HEAD supports both kPageBlockSize >= kN0 (single page per
// tile) and kPageBlockSize < kN0 (cross-page tile). The dequant loop
// below precomputes per-(kPageBlockSize)-wide-slice physical page IDs
// and applies them per column.
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
@@ -1113,18 +1112,47 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
tile_elementwise_inout([&k_descale](auto& x) { x *= k_descale; }, s_acc);
}
// PER_TOKEN_HEAD: dequantize QK result with per-row Q descale and per-column K descale.
// s_acc[i,j] *= q_descale[q_origin+i, qo_head] * k_descale[k_page, k_slot+j, kv_head]
// s_acc[i,j] *= q_descale[q_origin+i, qo_head] * k_descale[k_page(j), k_slot(j), kv_head]
// Supports cross-page tiles (kPageBlockSize < kN0): column j is looked up in the
// page covering token (k_origin + j).
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
{
const auto k_origin = k_dram_block_window.get_window_origin();
const index_t k_page = k_physical_pages[number<0>{}];
const index_t k_slot_base = k_origin.at(number<0>{}) % kPageBlockSize;
const index_t qo_head = block_indices.qo_head_idx;
const index_t kv_head = block_indices.kv_head_idx;
const index_t q_row_base = q_origin.at(number<0>{});
const index_t k_page_base = k_page * nblock_stride_k_descale_page +
kv_head * nhead_stride_k_descale;
// Number of distinct pages this tile spans.
// page_size >= kN0 -> 1 (fast path, identical to original behavior)
// page_size < kN0 -> kN0 / page_size (cross-page tile)
constexpr index_t kPagesPerTile =
(kPageBlockSize >= kN0) ? 1 : (kN0 / kPageBlockSize);
constexpr index_t kLog2PageBlockSize = []{
index_t shift = 0;
index_t val = kPageBlockSize;
while(val > 1) { val >>= 1; ++shift; }
return shift;
}();
constexpr index_t kPageSlotMask = kPageBlockSize - 1;
// Physical pages for each kPageBlockSize-wide column slice of the tile.
// Tiny array (1 or kN0/kPageBlockSize entries); compiler keeps in registers.
index_t tile_k_pages[kPagesPerTile];
if constexpr(kPagesPerTile == 1)
{
// Single-page tile: reuse the page already loaded for K-gemm.
tile_k_pages[0] = k_physical_pages[number<0>{}];
}
else
{
const index_t k_origin_n = k_origin.at(number<0>{});
static_for<0, kPagesPerTile, 1>{}([&](auto p) {
const index_t gp = (k_origin_n + p.value * kPageBlockSize)
>> kLog2PageBlockSize;
tile_k_pages[p.value] =
page_idx[ck_tile::min(gp, max_page_table_idx)];
});
}
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
@@ -1137,8 +1165,16 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
const float qd = q_descale_per_token_ptr[
(q_row_base + i) * stride_q_descale_token +
qo_head * nhead_stride_q_descale];
// Per-column page + slot. For kPagesPerTile==1 the
// selector folds to 0 at compile time.
const index_t k_page = tile_k_pages[
(kPagesPerTile == 1) ? index_t{0}
: (j >> kLog2PageBlockSize)];
const index_t k_slot = j & kPageSlotMask;
const float kd = k_descale_ptr[
k_page_base + (k_slot_base + j) * stride_k_descale_token];
k_page * nblock_stride_k_descale_page +
kv_head * nhead_stride_k_descale +
k_slot * stride_k_descale_token];
s_acc(i_j_idx) *= qd * kd;
});
});