[rocm-libraries] ROCm/rocm-libraries#4263 (commit f34aec2)

[CK] Add FP8 KV_BLOCKSCALE support for batch prefill
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Implement per-page K/V quantization for paged attention:
  - Add KV_BLOCKSCALE enum to BlockAttentionQuantScaleEnum
  - Use exp2 shift trick to eliminate explicit P scaling overhead
- Prefetch physical pages offset for KV cache, overlaps with
computations

## Proposed changes

Please describe the motivation behind the pull request, whether it
enables a new feature or fixes a bug. If there are associated pull
requests or issues, please link them to the pull request.

## Checklist

Please put an `x` into the boxes that apply. You can also fill these out
after creating the PR. If you're not sure, please don't hesitate to ask.

- [ ] I have added tests relevant to the introduced functionality, and
the unit tests are passing locally
- [ ] I have added the test to REGRESSION_TESTS list defined at the top
of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more
than 30 seconds to run.
- [ ] I have added inline documentation which enables the maintainers
with understanding the motivation
- [ ] I have removed the stale documentation which is no longer relevant
after this pull request
- [ ] (If this change is user-facing) I have added release notes which
provide the end users with a brief summary of the improvement from this
pull request
- [ ] I have run `clang-format` on all changed files
- [ ] Any dependent changes have been merged

## Discussion

If this is a relatively large or complex change, feel free to start a
discussion by explaining why you chose the solution you did and what
alternatives you considered
This commit is contained in:
Jeff Huang
2026-02-04 23:26:20 +00:00
committed by assistant-librarian[bot]
parent 62fbda4d1e
commit 7b18f5fed2
8 changed files with 559 additions and 105 deletions

View File

@@ -78,12 +78,14 @@ QSCALE_MAP = {
"no": "ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE",
"pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR",
"blockscale": "ck_tile::BlockAttentionQuantScaleEnum::BLOCKSCALE",
"kv_blockscale": "ck_tile::BlockAttentionQuantScaleEnum::KV_BLOCKSCALE",
}
QSCALE_CHECK_MAP = {
"no": "quant_scale_enum::no_scale",
"pertensor": "quant_scale_enum::pertensor",
"blockscale": "quant_scale_enum::blockscale",
"kv_blockscale": "quant_scale_enum::kv_blockscale",
}
BIAS_MAP = {

View File

@@ -677,7 +677,7 @@ class KernelComponentFactory:
kv_lookup_table,
) in itertools.product(
["t", "f"],
["pertensor"],
["pertensor", "kv_blockscale"],
get_mask_map(mask_impl).keys(),
["no"],
SUPPORTED_KV_MEMORY_LAYOUT,
@@ -740,6 +740,10 @@ def get_fwd_blobs(
for page_size in SUPPORTED_PAGE_SIZE:
if page_size == 1 and pipeline.F_kv_memory_layout != "linear":
continue
# kv_blockscale requires page_size >= kN0 (tile.F_bn0)
# This ensures all tokens in a main loop iteration belong to the same page
if pipeline.F_qscale == "kv_blockscale" and page_size < tile.F_bn0:
continue
k = FmhaFwdKernel(
F_idx=0,
F_hdim=hdim,

View File

@@ -602,6 +602,13 @@ struct fmha_batch_prefill_args
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset;
// KV_BLOCKSCALE: per-page K/V descales (Q per-tensor, K/V per-page)
// k_descale_ptr/v_descale_ptr are reused for KV_BLOCKSCALE mode:
// k_descale_ptr: [num_block, num_kv_head] - points to k block descale
// v_descale_ptr: [num_block, num_kv_head] - points to v block descale
ck_tile::index_t nblock_stride_kv_block_descale = 0; // Stride along num_block dimension
ck_tile::index_t nhead_stride_kv_block_descale = 0; // Stride along num_kv_head dimension
};
template <typename FmhaKernel>
@@ -1225,7 +1232,9 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
args.p_drop,
args.s_randval,
args.drop_seed_offset,
args.sink_ptr);
args.sink_ptr,
args.nblock_stride_kv_block_descale,
args.nhead_stride_kv_block_descale);
}
else
{ // create batch mode kernel arguments
@@ -1278,7 +1287,9 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
args.p_drop,
args.s_randval,
args.drop_seed_offset,
args.sink_ptr);
args.sink_ptr,
args.nblock_stride_kv_block_descale,
args.nhead_stride_kv_block_descale);
}
}();

View File

@@ -14,9 +14,10 @@
// keep sync with BlockAttentionQuantScaleEnum
enum class quant_scale_enum
{
no_scale = 0,
pertensor = 1,
blockscale,
no_scale = 0,
pertensor = 1,
blockscale = 2,
kv_blockscale = 3, // Q per-tensor, K/V per-page block scale
};
struct quant_scale_info
@@ -31,6 +32,8 @@ struct quant_scale_info
os << "pt";
else if(type == quant_scale_enum::blockscale)
os << "bs";
else if(type == quant_scale_enum::kv_blockscale)
os << "kvbs";
}
static quant_scale_info decode(std::string str)
@@ -48,6 +51,10 @@ struct quant_scale_info
{
info.type = quant_scale_enum::blockscale;
}
else if(str == "kvbs" || str == "3")
{
info.type = quant_scale_enum::kv_blockscale;
}
else
{
throw std::invalid_argument("invalid quant scale value: " + str);