mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 20:21:23 +00:00
[rocm-libraries] ROCm/rocm-libraries#4263 (commit f34aec2)
[CK] Add FP8 KV_BLOCKSCALE support for batch prefill MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implement per-page K/V quantization for paged attention: - Add KV_BLOCKSCALE enum to BlockAttentionQuantScaleEnum - Use exp2 shift trick to eliminate explicit P scaling overhead - Prefetch physical pages offset for KV cache, overlaps with computations ## Proposed changes Please describe the motivation behind the pull request, whether it enables a new feature or fixes a bug. If there are associated pull requests or issues, please link them to the pull request. ## Checklist Please put an `x` into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask. - [ ] I have added tests relevant to the introduced functionality, and the unit tests are passing locally - [ ] I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more than 30 seconds to run. - [ ] I have added inline documentation which enables the maintainers with understanding the motivation - [ ] I have removed the stale documentation which is no longer relevant after this pull request - [ ] (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request - [ ] I have run `clang-format` on all changed files - [ ] Any dependent changes have been merged ## Discussion If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered
This commit is contained in:
committed by
assistant-librarian[bot]
parent
62fbda4d1e
commit
7b18f5fed2
@@ -677,7 +677,7 @@ class KernelComponentFactory:
|
||||
kv_lookup_table,
|
||||
) in itertools.product(
|
||||
["t", "f"],
|
||||
["pertensor"],
|
||||
["pertensor", "kv_blockscale"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
["no"],
|
||||
SUPPORTED_KV_MEMORY_LAYOUT,
|
||||
@@ -740,6 +740,10 @@ def get_fwd_blobs(
|
||||
for page_size in SUPPORTED_PAGE_SIZE:
|
||||
if page_size == 1 and pipeline.F_kv_memory_layout != "linear":
|
||||
continue
|
||||
# kv_blockscale requires page_size >= kN0 (tile.F_bn0)
|
||||
# This ensures all tokens in a main loop iteration belong to the same page
|
||||
if pipeline.F_qscale == "kv_blockscale" and page_size < tile.F_bn0:
|
||||
continue
|
||||
k = FmhaFwdKernel(
|
||||
F_idx=0,
|
||||
F_hdim=hdim,
|
||||
|
||||
Reference in New Issue
Block a user