mirror of https://github.com/ROCm/composable_kernel.git
[rocm-libraries] ROCm/rocm-libraries#4263 (commit f34aec2)
[CK] Add FP8 KV_BLOCKSCALE support for batch prefill

Implement per-page K/V quantization for paged attention:
- Add KV_BLOCKSCALE enum to BlockAttentionQuantScaleEnum
- Use the exp2 shift trick to eliminate explicit P scaling overhead
- Prefetch physical page offsets for the KV cache, overlapping the fetch with computation
commit 7b18f5fed2
parent 62fbda4d1e
committed by assistant-librarian[bot]
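
The exp2 shift trick called out in the commit message can be sketched as follows. This is an illustrative stand-alone function with hypothetical names, not the CK kernel code: rather than multiplying every element of P by the K descale after exponentiation, the descale is folded into the exponent in the log2 domain, since d * exp2(x) == exp2(x + log2(d)), so a per-element multiply becomes a single add per page:

#include <cmath>
#include <cstddef>

// Illustrative only: softmax numerator for one row, with a per-page K
// descale folded into the exp2 argument. std::log2 runs once per page;
// no explicit scaling pass over P is needed afterwards.
void softmax_row_with_descale(const float* s, float* p, std::size_t n,
                              float row_max, float k_descale)
{
    const float log2e = 1.4426950408889634f; // log2(e)
    const float shift = std::log2(k_descale);
    for(std::size_t j = 0; j < n; ++j)
    {
        // exp(s[j] - row_max) * k_descale, computed as one exp2
        p[j] = std::exp2((s[j] - row_max) * log2e + shift);
    }
}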
@@ -78,12 +78,14 @@ QSCALE_MAP = {
     "no": "ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE",
     "pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR",
     "blockscale": "ck_tile::BlockAttentionQuantScaleEnum::BLOCKSCALE",
+    "kv_blockscale": "ck_tile::BlockAttentionQuantScaleEnum::KV_BLOCKSCALE",
 }
 
 QSCALE_CHECK_MAP = {
     "no": "quant_scale_enum::no_scale",
     "pertensor": "quant_scale_enum::pertensor",
     "blockscale": "quant_scale_enum::blockscale",
+    "kv_blockscale": "quant_scale_enum::kv_blockscale",
 }
 
 BIAS_MAP = {
@@ -677,7 +677,7 @@ class KernelComponentFactory:
                 kv_lookup_table,
             ) in itertools.product(
                 ["t", "f"],
-                ["pertensor"],
+                ["pertensor", "kv_blockscale"],
                 get_mask_map(mask_impl).keys(),
                 ["no"],
                 SUPPORTED_KV_MEMORY_LAYOUT,
@@ -740,6 +740,10 @@ def get_fwd_blobs(
         for page_size in SUPPORTED_PAGE_SIZE:
             if page_size == 1 and pipeline.F_kv_memory_layout != "linear":
                 continue
+            # kv_blockscale requires page_size >= kN0 (tile.F_bn0)
+            # This ensures all tokens in a main loop iteration belong to the same page
+            if pipeline.F_qscale == "kv_blockscale" and page_size < tile.F_bn0:
+                continue
             k = FmhaFwdKernel(
                 F_idx=0,
                 F_hdim=hdim,
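
The page-size check above enforces an invariant worth spelling out: the main loop consumes K/V in kN0-column tiles, and with kv_blockscale one (k_descale, v_descale) pair must cover a whole tile. A hedged sketch of why page_size >= kN0 is sufficient, with illustrative names only:

#include <cassert>

// Illustrative only: page sizes and kN0 are powers of two in practice, so
// page_size >= kN0 implies kN0 divides page_size. A tile starting at
// i_n0 = t * kN0 then lies entirely inside one physical page, and a single
// per-page descale pair is valid for the whole main-loop iteration.
int page_of_tile(int i_n0, int kN0, int page_size)
{
    assert(page_size >= kN0 && page_size % kN0 == 0 && i_n0 % kN0 == 0);
    int first_page = i_n0 / page_size;
    int last_page  = (i_n0 + kN0 - 1) / page_size;
    assert(first_page == last_page); // tile never straddles a page boundary
    return first_page;
}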
@@ -602,6 +602,13 @@ struct fmha_batch_prefill_args
 
     std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
         drop_seed_offset;
+
+    // KV_BLOCKSCALE: per-page K/V descales (Q per-tensor, K/V per-page)
+    // k_descale_ptr/v_descale_ptr are reused for KV_BLOCKSCALE mode:
+    //   k_descale_ptr: [num_block, num_kv_head] - points to k block descale
+    //   v_descale_ptr: [num_block, num_kv_head] - points to v block descale
+    ck_tile::index_t nblock_stride_kv_block_descale = 0; // Stride along num_block dimension
+    ck_tile::index_t nhead_stride_kv_block_descale = 0;  // Stride along num_kv_head dimension
 };
 
 template <typename FmhaKernel>
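
How the two new strides would be consumed, sketched with a hypothetical helper (the layout follows the struct comment above: descales are indexed by [num_block, num_kv_head]):

#include <cstdint>

// Illustrative only: fetch the K descale for one physical page and one KV
// head using the strides added above; the V lookup is identical with
// v_descale_ptr.
inline float load_k_block_descale(const float* k_descale_ptr,
                                  std::int64_t physical_page,
                                  std::int64_t kv_head,
                                  std::int64_t nblock_stride_kv_block_descale,
                                  std::int64_t nhead_stride_kv_block_descale)
{
    return k_descale_ptr[physical_page * nblock_stride_kv_block_descale +
                         kv_head * nhead_stride_kv_block_descale];
}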
@@ -1225,7 +1232,9 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
             args.p_drop,
             args.s_randval,
             args.drop_seed_offset,
-            args.sink_ptr);
+            args.sink_ptr,
+            args.nblock_stride_kv_block_descale,
+            args.nhead_stride_kv_block_descale);
     }
     else
     { // create batch mode kernel arguments
@@ -1278,7 +1287,9 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
             args.p_drop,
             args.s_randval,
             args.drop_seed_offset,
-            args.sink_ptr);
+            args.sink_ptr,
+            args.nblock_stride_kv_block_descale,
+            args.nhead_stride_kv_block_descale);
     }
 }();
 
@@ -14,9 +14,10 @@
 // keep sync with BlockAttentionQuantScaleEnum
 enum class quant_scale_enum
 {
-    no_scale = 0,
-    pertensor = 1,
-    blockscale,
+    no_scale = 0,
+    pertensor = 1,
+    blockscale = 2,
+    kv_blockscale = 3, // Q per-tensor, K/V per-page block scale
 };
 
 struct quant_scale_info
@@ -31,6 +32,8 @@ struct quant_scale_info
             os << "pt";
         else if(type == quant_scale_enum::blockscale)
             os << "bs";
+        else if(type == quant_scale_enum::kv_blockscale)
+            os << "kvbs";
     }
 
     static quant_scale_info decode(std::string str)
@@ -48,6 +51,10 @@ struct quant_scale_info
         {
             info.type = quant_scale_enum::blockscale;
         }
+        else if(str == "kvbs" || str == "3")
+        {
+            info.type = quant_scale_enum::kv_blockscale;
+        }
         else
         {
             throw std::invalid_argument("invalid quant scale value: " + str);
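
A quick usage check of the extended decode path, assuming the quant_scale_info and quant_scale_enum definitions from the diff above are in scope:

#include <cassert>

int main()
{
    // Both the short name and the numeric form select the new mode,
    // matching the two accepted spellings in decode().
    assert(quant_scale_info::decode("kvbs").type == quant_scale_enum::kv_blockscale);
    assert(quant_scale_info::decode("3").type == quant_scale_enum::kv_blockscale);
    return 0;
}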