[AITERKER-112] Add PER_TOKEN_HEAD FP8 quant scheme to batch_prefill

- New BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD enum value
- Pipeline overload in block_fmha_batch_prefill_pipeline_qr_ks_vs_async
  applying per-token Q/K descale via GEMM0-post outer product and
  per-head V descale at epilogue
- fmha_batch_prefill_kernel kargs + MakeKargs + pipeline dispatch
- fmha_fwd.hpp host-side traits/args wiring
- quant.hpp trait specialization
- Codegen emits PER_TOKEN_HEAD kernel variants
This commit is contained in:
msaffari-amd
2026-05-19 15:41:32 +00:00
parent 83566edb0f
commit 403d99124d
7 changed files with 317 additions and 22 deletions

View File

@@ -81,6 +81,7 @@ QSCALE_MAP = {
"pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR",
"blockscale": "ck_tile::BlockAttentionQuantScaleEnum::BLOCKSCALE",
"kv_blockscale": "ck_tile::BlockAttentionQuantScaleEnum::KV_BLOCKSCALE",
"per_token_head": "ck_tile::BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD",
"mx": "ck_tile::BlockAttentionQuantScaleEnum::MX",
}
@@ -89,6 +90,7 @@ QSCALE_CHECK_MAP = {
"pertensor": "quant_scale_enum::pertensor",
"blockscale": "quant_scale_enum::blockscale",
"kv_blockscale": "quant_scale_enum::kv_blockscale",
"per_token_head": "quant_scale_enum::per_token_head",
"mx": "quant_scale_enum::mx",
}

View File

@@ -733,7 +733,7 @@ class KernelComponentFactory:
kv_lookup_table,
) in itertools.product(
["t", "f"],
["pertensor", "kv_blockscale"],
["pertensor", "kv_blockscale", "per_token_head"],
get_mask_map(mask_impl).keys(),
["no"],
["t", "f"],
@@ -819,9 +819,12 @@ def get_fwd_blobs(
for page_size in SUPPORTED_PAGE_SIZE:
if page_size == 1 and pipeline.F_kv_memory_layout != "linear":
continue
# kv_blockscale requires page_size >= kN0 (tile.F_bn0)
# kv_blockscale / per_token_head require page_size >= kN0 (tile.F_bn0)
# This ensures all tokens in a main loop iteration belong to the same page
if pipeline.F_qscale == "kv_blockscale" and page_size < tile.F_bn0:
if (
pipeline.F_qscale in ("kv_blockscale", "per_token_head")
and page_size < tile.F_bn0
):
continue
k = FmhaFwdKernel(
F_idx=0,

View File

@@ -671,6 +671,19 @@ struct fmha_batch_prefill_args
// v_descale_ptr: [num_block, num_kv_head] - points to v block descale
ck_tile::index_t nblock_stride_kv_block_descale = 0; // Stride along num_block dimension
ck_tile::index_t nhead_stride_kv_block_descale = 0; // Stride along num_kv_head dimension
// PER_TOKEN_HEAD: Q/K per-token per-head, V per-head (FP8 fine-grained).
// q_descale_ptr/k_descale_ptr/v_descale_ptr are reused; layout:
// q_descale: [total_q_tokens, nhead_q] fp32
// k_descale: [num_total_pages, page_block_size, nhead_k] fp32
// (aligned with paged K cache so we can reuse k_physical_pages[])
// v_descale: [nhead_k] fp32
ck_tile::index_t stride_q_descale_token = 0; // Q descale: row stride (per-token)
ck_tile::index_t nhead_stride_q_descale = 0; // Q descale: head stride
ck_tile::index_t nblock_stride_k_descale_page = 0; // K descale: page stride
ck_tile::index_t stride_k_descale_token = 0; // K descale: within-page token stride
ck_tile::index_t nhead_stride_k_descale = 0; // K descale: head stride
ck_tile::index_t nhead_stride_v_descale = 0; // V descale: head stride (per-head only)
};
// Selects the KV-cache load mode for a batch-prefill dispatch arm.
@@ -1340,7 +1353,13 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
args.drop_seed_offset,
args.sink_ptr,
args.nblock_stride_kv_block_descale,
args.nhead_stride_kv_block_descale);
args.nhead_stride_kv_block_descale,
args.stride_q_descale_token,
args.nhead_stride_q_descale,
args.nblock_stride_k_descale_page,
args.stride_k_descale_token,
args.nhead_stride_k_descale,
args.nhead_stride_v_descale);
}
else
{ // create batch mode kernel arguments
@@ -1395,7 +1414,13 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
args.drop_seed_offset,
args.sink_ptr,
args.nblock_stride_kv_block_descale,
args.nhead_stride_kv_block_descale);
args.nhead_stride_kv_block_descale,
args.stride_q_descale_token,
args.nhead_stride_q_descale,
args.nblock_stride_k_descale_page,
args.stride_k_descale_token,
args.nhead_stride_k_descale,
args.nhead_stride_v_descale);
}
}();

View File

@@ -14,11 +14,12 @@
// keep sync with BlockAttentionQuantScaleEnum
enum class quant_scale_enum
{
no_scale = 0,
pertensor = 1,
blockscale = 2,
kv_blockscale = 3, // Q per-tensor, K/V per-page block scale
mx = 4, // Microscaling (MX)
no_scale = 0,
pertensor = 1,
blockscale = 2,
kv_blockscale = 3, // Q per-tensor, K/V per-page block scale
mx = 4, // Microscaling (MX)
per_token_head = 5, // Q/K per-token per-head, V per-head (FP8 fine-grained)
};
struct quant_scale_info
@@ -37,6 +38,8 @@ struct quant_scale_info
os << "kvbs";
else if(type == quant_scale_enum::mx)
os << "mx";
else if(type == quant_scale_enum::per_token_head)
os << "pth";
}
static quant_scale_info decode(std::string str)
@@ -62,6 +65,10 @@ struct quant_scale_info
{
info.type = quant_scale_enum::mx;
}
else if(str == "pth" || str == "5")
{
info.type = quant_scale_enum::per_token_head;
}
else
{
throw std::invalid_argument("invalid quant scale value: " + str);