mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 21:27:45 +00:00
[AITERKER-112] Add PER_TOKEN_HEAD FP8 quant scheme to batch_prefill
- New BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD enum value
- Pipeline overload in block_fmha_batch_prefill_pipeline_qr_ks_vs_async applying per-token Q/K descale via GEMM0-post outer product and per-head V descale at epilogue
- fmha_batch_prefill_kernel kargs + MakeKargs + pipeline dispatch
- fmha_fwd.hpp host-side traits/args wiring
- quant.hpp trait specialization
- Codegen emits PER_TOKEN_HEAD kernel variants
This commit is contained in:
@@ -81,6 +81,7 @@ QSCALE_MAP = {
|
||||
"pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR",
|
||||
"blockscale": "ck_tile::BlockAttentionQuantScaleEnum::BLOCKSCALE",
|
||||
"kv_blockscale": "ck_tile::BlockAttentionQuantScaleEnum::KV_BLOCKSCALE",
|
||||
"per_token_head": "ck_tile::BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD",
|
||||
"mx": "ck_tile::BlockAttentionQuantScaleEnum::MX",
|
||||
}
|
||||
|
||||
@@ -89,6 +90,7 @@ QSCALE_CHECK_MAP = {
|
||||
"pertensor": "quant_scale_enum::pertensor",
|
||||
"blockscale": "quant_scale_enum::blockscale",
|
||||
"kv_blockscale": "quant_scale_enum::kv_blockscale",
|
||||
"per_token_head": "quant_scale_enum::per_token_head",
|
||||
"mx": "quant_scale_enum::mx",
|
||||
}
|
||||
|
||||
|
||||
@@ -733,7 +733,7 @@ class KernelComponentFactory:
|
||||
kv_lookup_table,
|
||||
) in itertools.product(
|
||||
["t", "f"],
|
||||
["pertensor", "kv_blockscale"],
|
||||
["pertensor", "kv_blockscale", "per_token_head"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
["no"],
|
||||
["t", "f"],
|
||||
@@ -819,9 +819,12 @@ def get_fwd_blobs(
|
||||
for page_size in SUPPORTED_PAGE_SIZE:
|
||||
if page_size == 1 and pipeline.F_kv_memory_layout != "linear":
|
||||
continue
|
||||
# kv_blockscale requires page_size >= kN0 (tile.F_bn0)
|
||||
# kv_blockscale / per_token_head require page_size >= kN0 (tile.F_bn0)
|
||||
# This ensures all tokens in a main loop iteration belong to the same page
|
||||
if pipeline.F_qscale == "kv_blockscale" and page_size < tile.F_bn0:
|
||||
if (
|
||||
pipeline.F_qscale in ("kv_blockscale", "per_token_head")
|
||||
and page_size < tile.F_bn0
|
||||
):
|
||||
continue
|
||||
k = FmhaFwdKernel(
|
||||
F_idx=0,
|
||||
|
||||
@@ -671,6 +671,19 @@ struct fmha_batch_prefill_args
|
||||
// v_descale_ptr: [num_block, num_kv_head] - points to v block descale
|
||||
ck_tile::index_t nblock_stride_kv_block_descale = 0; // Stride along num_block dimension
|
||||
ck_tile::index_t nhead_stride_kv_block_descale = 0; // Stride along num_kv_head dimension
|
||||
|
||||
// PER_TOKEN_HEAD: Q/K per-token per-head, V per-head (FP8 fine-grained).
|
||||
// q_descale_ptr/k_descale_ptr/v_descale_ptr are reused; layout:
|
||||
// q_descale: [total_q_tokens, nhead_q] fp32
|
||||
// k_descale: [num_total_pages, page_block_size, nhead_k] fp32
|
||||
// (aligned with paged K cache so we can reuse k_physical_pages[])
|
||||
// v_descale: [nhead_k] fp32
|
||||
ck_tile::index_t stride_q_descale_token = 0; // Q descale: row stride (per-token)
|
||||
ck_tile::index_t nhead_stride_q_descale = 0; // Q descale: head stride
|
||||
ck_tile::index_t nblock_stride_k_descale_page = 0; // K descale: page stride
|
||||
ck_tile::index_t stride_k_descale_token = 0; // K descale: within-page token stride
|
||||
ck_tile::index_t nhead_stride_k_descale = 0; // K descale: head stride
|
||||
ck_tile::index_t nhead_stride_v_descale = 0; // V descale: head stride (per-head only)
|
||||
};
|
||||
|
||||
// Selects the KV-cache load mode for a batch-prefill dispatch arm.
|
||||
@@ -1340,7 +1353,13 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
|
||||
args.drop_seed_offset,
|
||||
args.sink_ptr,
|
||||
args.nblock_stride_kv_block_descale,
|
||||
args.nhead_stride_kv_block_descale);
|
||||
args.nhead_stride_kv_block_descale,
|
||||
args.stride_q_descale_token,
|
||||
args.nhead_stride_q_descale,
|
||||
args.nblock_stride_k_descale_page,
|
||||
args.stride_k_descale_token,
|
||||
args.nhead_stride_k_descale,
|
||||
args.nhead_stride_v_descale);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
@@ -1395,7 +1414,13 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
|
||||
args.drop_seed_offset,
|
||||
args.sink_ptr,
|
||||
args.nblock_stride_kv_block_descale,
|
||||
args.nhead_stride_kv_block_descale);
|
||||
args.nhead_stride_kv_block_descale,
|
||||
args.stride_q_descale_token,
|
||||
args.nhead_stride_q_descale,
|
||||
args.nblock_stride_k_descale_page,
|
||||
args.stride_k_descale_token,
|
||||
args.nhead_stride_k_descale,
|
||||
args.nhead_stride_v_descale);
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
@@ -14,11 +14,12 @@
|
||||
// keep sync with BlockAttentionQuantScaleEnum
|
||||
enum class quant_scale_enum
|
||||
{
|
||||
no_scale = 0,
|
||||
pertensor = 1,
|
||||
blockscale = 2,
|
||||
kv_blockscale = 3, // Q per-tensor, K/V per-page block scale
|
||||
mx = 4, // Microscaling (MX)
|
||||
no_scale = 0,
|
||||
pertensor = 1,
|
||||
blockscale = 2,
|
||||
kv_blockscale = 3, // Q per-tensor, K/V per-page block scale
|
||||
mx = 4, // Microscaling (MX)
|
||||
per_token_head = 5, // Q/K per-token per-head, V per-head (FP8 fine-grained)
|
||||
};
|
||||
|
||||
struct quant_scale_info
|
||||
@@ -37,6 +38,8 @@ struct quant_scale_info
|
||||
os << "kvbs";
|
||||
else if(type == quant_scale_enum::mx)
|
||||
os << "mx";
|
||||
else if(type == quant_scale_enum::per_token_head)
|
||||
os << "pth";
|
||||
}
|
||||
|
||||
static quant_scale_info decode(std::string str)
|
||||
@@ -62,6 +65,10 @@ struct quant_scale_info
|
||||
{
|
||||
info.type = quant_scale_enum::mx;
|
||||
}
|
||||
else if(str == "pth" || str == "5")
|
||||
{
|
||||
info.type = quant_scale_enum::per_token_head;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::invalid_argument("invalid quant scale value: " + str);
|
||||
|
||||
Reference in New Issue
Block a user