mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 03:07:02 +00:00
[AITERKER-112] PER_TOKEN_HEAD: support page_size < kN0 via cross-page dequant
- Pipeline: remove the PER_TOKEN_HEAD `kPageBlockSize >= kN0` static_assert; the QK dequant now precomputes `tile_k_pages[]` and indexes it per column. page_size >= kN0 stays on the original single-page fast path (kPagesPerTile == 1).
- Codegen: add page_size=64 to SUPPORTED_PAGE_SIZE; drop per_token_head from the `page_size < tile.F_bn0` filter (kv_blockscale is still filtered).
This commit is contained in:
@@ -48,7 +48,7 @@ DTYPE_BYTES = {k: v // 8 for k, v in DTYPE_BITS.items()}
|
||||
|
||||
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256}
|
||||
|
||||
SUPPORTED_PAGE_SIZE = [1, 16, 1024]
|
||||
SUPPORTED_PAGE_SIZE = [1, 16, 64, 1024]
|
||||
SUPPORTED_KV_MEMORY_LAYOUT = ["vectorized", "linear"]
|
||||
SUPPORTED_KV_LOOKUP_TABLE = ["vllm", "sglang"]
|
||||
KV_MEMORY_LAYOUT_ENUM_MAP = {
|
||||
@@ -819,10 +819,11 @@ def get_fwd_blobs(
|
||||
for page_size in SUPPORTED_PAGE_SIZE:
|
||||
if page_size == 1 and pipeline.F_kv_memory_layout != "linear":
|
||||
continue
|
||||
# kv_blockscale / per_token_head require page_size >= kN0 (tile.F_bn0)
|
||||
# This ensures all tokens in a main loop iteration belong to the same page
|
||||
# kv_blockscale requires page_size >= kN0 (tile.F_bn0): its dequant
|
||||
# loop only loads a single page per tile. per_token_head supports
|
||||
# cross-page tiles (per-column page lookup in the pipeline).
|
||||
if (
|
||||
pipeline.F_qscale in ("kv_blockscale", "per_token_head")
|
||||
pipeline.F_qscale == "kv_blockscale"
|
||||
and page_size < tile.F_bn0
|
||||
):
|
||||
continue
|
||||
|
||||
@@ -457,11 +457,10 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
{
|
||||
static_assert(kPageBlockSize >= kN0, "KV_BLOCKSCALE requires kPageBlockSize >= kN0");
|
||||
}
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
|
||||
{
|
||||
static_assert(kPageBlockSize >= kN0,
|
||||
"PER_TOKEN_HEAD requires kPageBlockSize >= kN0");
|
||||
}
|
||||
// PER_TOKEN_HEAD supports both kPageBlockSize >= kN0 (single page per
|
||||
// tile) and kPageBlockSize < kN0 (cross-page tile). The dequant loop
|
||||
// below precomputes per-(kPageBlockSize)-wide-slice physical page IDs
|
||||
// and applies them per column.
|
||||
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
@@ -1113,18 +1112,47 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
tile_elementwise_inout([&k_descale](auto& x) { x *= k_descale; }, s_acc);
|
||||
}
|
||||
// PER_TOKEN_HEAD: dequantize QK result with per-row Q descale and per-column K descale.
|
||||
// s_acc[i,j] *= q_descale[q_origin+i, qo_head] * k_descale[k_page, k_slot+j, kv_head]
|
||||
// s_acc[i,j] *= q_descale[q_origin+i, qo_head] * k_descale[k_page(j), k_slot(j), kv_head]
|
||||
// Supports cross-page tiles (kPageBlockSize < kN0): column j is looked up in the
|
||||
// page covering token (k_origin + j).
|
||||
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
|
||||
{
|
||||
const auto k_origin = k_dram_block_window.get_window_origin();
|
||||
const index_t k_page = k_physical_pages[number<0>{}];
|
||||
const index_t k_slot_base = k_origin.at(number<0>{}) % kPageBlockSize;
|
||||
const index_t qo_head = block_indices.qo_head_idx;
|
||||
const index_t kv_head = block_indices.kv_head_idx;
|
||||
const index_t q_row_base = q_origin.at(number<0>{});
|
||||
|
||||
const index_t k_page_base = k_page * nblock_stride_k_descale_page +
|
||||
kv_head * nhead_stride_k_descale;
|
||||
// Number of distinct pages this tile spans.
|
||||
// page_size >= kN0 -> 1 (fast path, identical to original behavior)
|
||||
// page_size < kN0 -> kN0 / page_size (cross-page tile)
|
||||
constexpr index_t kPagesPerTile =
|
||||
(kPageBlockSize >= kN0) ? 1 : (kN0 / kPageBlockSize);
|
||||
constexpr index_t kLog2PageBlockSize = []{
|
||||
index_t shift = 0;
|
||||
index_t val = kPageBlockSize;
|
||||
while(val > 1) { val >>= 1; ++shift; }
|
||||
return shift;
|
||||
}();
|
||||
constexpr index_t kPageSlotMask = kPageBlockSize - 1;
|
||||
|
||||
// Physical pages for each kPageBlockSize-wide column slice of the tile.
|
||||
// Tiny array (1 or kN0/kPageBlockSize entries); compiler keeps in registers.
|
||||
index_t tile_k_pages[kPagesPerTile];
|
||||
if constexpr(kPagesPerTile == 1)
|
||||
{
|
||||
// Single-page tile: reuse the page already loaded for K-gemm.
|
||||
tile_k_pages[0] = k_physical_pages[number<0>{}];
|
||||
}
|
||||
else
|
||||
{
|
||||
const index_t k_origin_n = k_origin.at(number<0>{});
|
||||
static_for<0, kPagesPerTile, 1>{}([&](auto p) {
|
||||
const index_t gp = (k_origin_n + p.value * kPageBlockSize)
|
||||
>> kLog2PageBlockSize;
|
||||
tile_k_pages[p.value] =
|
||||
page_idx[ck_tile::min(gp, max_page_table_idx)];
|
||||
});
|
||||
}
|
||||
|
||||
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
|
||||
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
|
||||
@@ -1137,8 +1165,16 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
const float qd = q_descale_per_token_ptr[
|
||||
(q_row_base + i) * stride_q_descale_token +
|
||||
qo_head * nhead_stride_q_descale];
|
||||
// Per-column page + slot. For kPagesPerTile==1 the
|
||||
// selector folds to 0 at compile time.
|
||||
const index_t k_page = tile_k_pages[
|
||||
(kPagesPerTile == 1) ? index_t{0}
|
||||
: (j >> kLog2PageBlockSize)];
|
||||
const index_t k_slot = j & kPageSlotMask;
|
||||
const float kd = k_descale_ptr[
|
||||
k_page_base + (k_slot_base + j) * stride_k_descale_token];
|
||||
k_page * nblock_stride_k_descale_page +
|
||||
kv_head * nhead_stride_k_descale +
|
||||
k_slot * stride_k_descale_token];
|
||||
s_acc(i_j_idx) *= qd * kd;
|
||||
});
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user