[rocm-libraries] ROCm/rocm-libraries#4263 (commit f34aec2)

[CK] Add FP8 KV_BLOCKSCALE support for batch prefill
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Implement per-page K/V quantization for paged attention:
  - Add KV_BLOCKSCALE enum to BlockAttentionQuantScaleEnum
  - Use exp2 shift trick to eliminate explicit P scaling overhead
  - Prefetch physical page offsets for the KV cache, overlapping with
    computation

## Proposed changes

Please describe the motivation behind the pull request, whether it
enables a new feature or fixes a bug. If there are associated pull
requests or issues, please link them to the pull request.

## Checklist

Please put an `x` into the boxes that apply. You can also fill these out
after creating the PR. If you're not sure, please don't hesitate to ask.

- [ ] I have added tests relevant to the introduced functionality, and
the unit tests are passing locally
- [ ] I have added the test to REGRESSION_TESTS list defined at the top
of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more
than 30 seconds to run.
- [ ] I have added inline documentation which enables the maintainers
to understand the motivation
- [ ] I have removed the stale documentation which is no longer relevant
after this pull request
- [ ] (If this change is user-facing) I have added release notes which
provide the end users with a brief summary of the improvement from this
pull request
- [ ] I have run `clang-format` on all changed files
- [ ] Any dependent changes have been merged

## Discussion

If this is a relatively large or complex change, feel free to start a
discussion by explaining why you chose the solution you did and what
alternatives you considered.
This commit is contained in:
Jeff Huang
2026-02-04 23:26:20 +00:00
committed by assistant-librarian[bot]
parent 62fbda4d1e
commit 7b18f5fed2
8 changed files with 559 additions and 105 deletions

View File

@@ -10,6 +10,7 @@
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
#include "ck_tile/ops/fmha/block/variants.hpp"
#include <cassert>
#include <string>
#include <type_traits>
#include <utility>
@@ -185,13 +186,45 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
ck_tile::index_t batch_stride_lse = 0;
};
struct FmhaFwdCommonQScaleKargs
// PERTENSOR: Q/K/V all use per-tensor descales
// Kernel arguments for per-tensor quantization: each pointer refers to a
// single float descale factor applied uniformly to the whole tensor (the
// kernel dereferences each as one `const float` and asserts non-null
// before use).
struct FmhaFwdPerTensorQScaleKargs
{
const void* q_descale_ptr = nullptr; // per-tensor Q descale (single float)
const void* k_descale_ptr = nullptr; // per-tensor K descale (single float)
const void* v_descale_ptr = nullptr; // per-tensor V descale (single float)
};
// KV_BLOCKSCALE: Q per-tensor, K/V per-page descales
// K descale: [num_block, num_kv_head], V descale: [num_block, num_kv_head]
// Kernel arguments for block-scale quantization: Q keeps a single-float
// per-tensor descale, while K/V point to 2-D float tables indexed by
// (page/block, kv head). The two strides below are element strides into
// those tables and are shared by the K and V tables.
struct FmhaFwdKVBlockScaleKargs
{
const void* q_descale_ptr = nullptr; // Per-tensor Q descale (single float)
const void* k_descale_ptr = nullptr; // float table, shape [num_block, num_kv_head]
const void* v_descale_ptr = nullptr; // float table, shape [num_block, num_kv_head]
ck_tile::index_t nblock_stride_kv_block_descale = 0; // Stride along num_block dimension
ck_tile::index_t nhead_stride_kv_block_descale = 0; // Stride along num_kv_head dimension
};
// Helper template to select QScale Kargs type based on QScaleEnum
// EmptyType: type to use when QScaleEnum is NO_SCALE (e.g., FmhaFwdEmptyKargs<3>)
// Primary template: any enum value without a dedicated specialization
// (i.e. no quantization) yields the caller-supplied empty placeholder,
// keeping the Kargs base-class list well-formed in every configuration.
template <BlockAttentionQuantScaleEnum QScale, typename EmptyType>
struct GetQScaleKargs
{
using type = EmptyType;
};
// PERTENSOR -> per-tensor descale arguments for Q, K and V.
template <typename EmptyType>
struct GetQScaleKargs<BlockAttentionQuantScaleEnum::PERTENSOR, EmptyType>
{
using type = FmhaFwdPerTensorQScaleKargs;
};
// KV_BLOCKSCALE -> per-tensor Q descale plus per-page K/V descale tables.
template <typename EmptyType>
struct GetQScaleKargs<BlockAttentionQuantScaleEnum::KV_BLOCKSCALE, EmptyType>
{
using type = FmhaFwdKVBlockScaleKargs;
};
struct FmhaFwdDropoutSeedOffset
{
template <typename T>
@@ -255,9 +288,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
FmhaFwdEmptyKargs<0>>>,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
FmhaFwdCommonQScaleKargs,
FmhaFwdEmptyKargs<3>>,
GetQScaleKargs<QScaleEnum, FmhaFwdEmptyKargs<3>>::type,
std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<4>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
{
@@ -276,9 +307,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
FmhaFwdEmptyKargs<0>>>,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
FmhaFwdCommonQScaleKargs,
FmhaFwdEmptyKargs<3>>,
GetQScaleKargs<QScaleEnum, FmhaFwdEmptyKargs<3>>::type,
std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
{
@@ -348,7 +377,9 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset,
const void* sink_ptr = nullptr)
const void* sink_ptr = nullptr,
ck_tile::index_t nblock_stride_kv_block_descale = 0,
ck_tile::index_t nhead_stride_kv_block_descale = 0)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -419,6 +450,14 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
kargs.k_descale_ptr = k_descale_ptr;
kargs.v_descale_ptr = v_descale_ptr;
}
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
{
kargs.q_descale_ptr = q_descale_ptr;
kargs.k_descale_ptr = k_descale_ptr;
kargs.v_descale_ptr = v_descale_ptr;
kargs.nblock_stride_kv_block_descale = nblock_stride_kv_block_descale;
kargs.nhead_stride_kv_block_descale = nhead_stride_kv_block_descale;
}
if constexpr(kHasDropout)
{
if(drop_seed_offset.index() == 0) // seed & offset come from host
@@ -495,7 +534,9 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset,
const void* sink_ptr = nullptr)
const void* sink_ptr = nullptr,
ck_tile::index_t nblock_stride_kv_block_descale = 0,
ck_tile::index_t nhead_stride_kv_block_descale = 0)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -563,6 +604,14 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
kargs.k_descale_ptr = k_descale_ptr;
kargs.v_descale_ptr = v_descale_ptr;
}
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
{
kargs.q_descale_ptr = q_descale_ptr;
kargs.k_descale_ptr = k_descale_ptr;
kargs.v_descale_ptr = v_descale_ptr;
kargs.nblock_stride_kv_block_descale = nblock_stride_kv_block_descale;
kargs.nhead_stride_kv_block_descale = nhead_stride_kv_block_descale;
}
if constexpr(kHasDropout)
{
if(drop_seed_offset.index() == 0) // seed & offset come from host
@@ -1157,11 +1206,20 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
const float scale_s = [&] {
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
{
assert(kargs.q_descale_ptr != nullptr);
assert(kargs.k_descale_ptr != nullptr);
float q_descale = *(reinterpret_cast<const float*>(kargs.q_descale_ptr));
float k_descale = *(reinterpret_cast<const float*>(kargs.k_descale_ptr));
return kargs.scale_s * q_descale * k_descale;
}
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
{
// Q is per-tensor, K is per-page (handled in pipeline)
assert(kargs.q_descale_ptr != nullptr);
float q_descale = *(reinterpret_cast<const float*>(kargs.q_descale_ptr));
return kargs.scale_s * q_descale;
}
else
{
return kargs.scale_s;
@@ -1194,6 +1252,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
{
// TODO - move global load of descale to pipeline
assert(kargs.v_descale_ptr != nullptr);
float v_descale = *(reinterpret_cast<const float*>(kargs.v_descale_ptr));
float scale_p = ck_tile::type_convert<float>(ck_tile::numeric<PDataType>::max());
@@ -1237,6 +1296,39 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
dropout,
sink_value);
}
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
{
// KV_BLOCKSCALE: K/V descale is per-page, handled in pipeline
assert(kargs.k_descale_ptr != nullptr);
assert(kargs.v_descale_ptr != nullptr);
const float* k_descale_ptr = reinterpret_cast<const float*>(kargs.k_descale_ptr);
const float* v_descale_ptr = reinterpret_cast<const float*>(kargs.v_descale_ptr);
return FmhaPipeline{}(q_dram_window,
k_dram_window,
v_dram_window,
bias_dram_window,
randval_dram_window,
lse_dram_window,
mask,
position_encoding,
variant_params.sm_scale,
variant,
variant_params,
block_indices,
smem_ptr,
page_idx,
stride_k_for_pipeline,
stride_v_for_pipeline,
kargs.batch_stride_k,
kargs.batch_stride_v,
dropout,
sink_value,
k_descale_ptr,
v_descale_ptr,
kargs.nblock_stride_kv_block_descale,
kargs.nhead_stride_kv_block_descale);
}
else
{
return FmhaPipeline{}(q_dram_window,