[rocm-libraries] ROCm/rocm-libraries#4263 (commit f34aec2)

[CK] Add FP8 KV_BLOCKSCALE support for batch prefill
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Implement per-page K/V quantization for paged attention:
  - Add KV_BLOCKSCALE enum to BlockAttentionQuantScaleEnum
  - Use exp2 shift trick to eliminate explicit P scaling overhead
- Prefetch physical pages offset for KV cache, overlaps with
computations

## Proposed changes

Please describe the motivation behind the pull request, whether it
enables a new feature or fixes a bug. If there are associated pull
requests or issues, please link them to the pull request.

## Checklist

Please put an `x` into the boxes that apply. You can also fill these out
after creating the PR. If you're not sure, please don't hesitate to ask.

- [ ] I have added tests relevant to the introduced functionality, and
the unit tests are passing locally
- [ ] I have added the test to REGRESSION_TESTS list defined at the top
of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more
than 30 seconds to run.
- [ ] I have added inline documentation which enables the maintainers
with understanding the motivation
- [ ] I have removed the stale documentation which is no longer relevant
after this pull request
- [ ] (If this change is user-facing) I have added release notes which
provide the end users with a brief summary of the improvement from this
pull request
- [ ] I have run `clang-format` on all changed files
- [ ] Any dependent changes have been merged

## Discussion

If this is a relatively large or complex change, feel free to start a
discussion by explaining why you chose the solution you did and what
alternatives you considered
This commit is contained in:
Jeff Huang
2026-02-04 23:26:20 +00:00
committed by assistant-librarian[bot]
parent 62fbda4d1e
commit 7b18f5fed2
8 changed files with 559 additions and 105 deletions

View File

@@ -14,9 +14,10 @@
// keep sync with BlockAttentionQuantScaleEnum
enum class quant_scale_enum
{
no_scale = 0,
pertensor = 1,
blockscale,
no_scale = 0,
pertensor = 1,
blockscale = 2,
kv_blockscale = 3, // Q per-tensor, K/V per-page block scale
};
struct quant_scale_info
@@ -31,6 +32,8 @@ struct quant_scale_info
os << "pt";
else if(type == quant_scale_enum::blockscale)
os << "bs";
else if(type == quant_scale_enum::kv_blockscale)
os << "kvbs";
}
static quant_scale_info decode(std::string str)
@@ -48,6 +51,10 @@ struct quant_scale_info
{
info.type = quant_scale_enum::blockscale;
}
else if(str == "kvbs" || str == "3")
{
info.type = quant_scale_enum::kv_blockscale;
}
else
{
throw std::invalid_argument("invalid quant scale value: " + str);