mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[rocm-libraries] ROCm/rocm-libraries#4263 (commit f34aec2)
[CK] Add FP8 KV_BLOCKSCALE support for batch prefill

Implement per-page K/V quantization for paged attention:

- Add KV_BLOCKSCALE enum to BlockAttentionQuantScaleEnum
- Use exp2 shift trick to eliminate explicit P scaling overhead
- Prefetch physical pages offset for KV cache, overlapping it with computation

## Proposed changes

Please describe the motivation behind the pull request, whether it enables a new feature or fixes a bug. If there are associated pull requests or issues, please link them to the pull request.

## Checklist

Please put an `x` into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.

- [ ] I have added tests relevant to the introduced functionality, and the unit tests are passing locally
- [ ] I have added the test to the REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more than 30 seconds to run.
- [ ] I have added inline documentation which enables the maintainers to understand the motivation
- [ ] I have removed the stale documentation which is no longer relevant after this pull request
- [ ] (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request
- [ ] I have run `clang-format` on all changed files
- [ ] Any dependent changes have been merged

## Discussion

If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered
This commit is contained in:
committed by
assistant-librarian[bot]
parent
62fbda4d1e
commit
7b18f5fed2
@@ -10,6 +10,7 @@
|
||||
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/variants.hpp"
|
||||
|
||||
#include <cassert>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
@@ -185,13 +186,45 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
ck_tile::index_t batch_stride_lse = 0;
|
||||
};
|
||||
|
||||
struct FmhaFwdCommonQScaleKargs
|
||||
// PERTENSOR: Q/K/V all use per-tensor descales
|
||||
struct FmhaFwdPerTensorQScaleKargs
|
||||
{
|
||||
const void* q_descale_ptr = nullptr;
|
||||
const void* k_descale_ptr = nullptr;
|
||||
const void* v_descale_ptr = nullptr;
|
||||
};
|
||||
|
||||
// KV_BLOCKSCALE: Q per-tensor, K/V per-page descales
|
||||
// K descale: [num_block, num_kv_head], V descale: [num_block, num_kv_head]
|
||||
struct FmhaFwdKVBlockScaleKargs
|
||||
{
|
||||
const void* q_descale_ptr = nullptr; // Per-tensor Q descale
|
||||
const void* k_descale_ptr = nullptr; // [num_block, num_kv_head]
|
||||
const void* v_descale_ptr = nullptr; // [num_block, num_kv_head]
|
||||
ck_tile::index_t nblock_stride_kv_block_descale = 0; // Stride along num_block dimension
|
||||
ck_tile::index_t nhead_stride_kv_block_descale = 0; // Stride along num_kv_head dimension
|
||||
};
|
||||
|
||||
// Helper template to select QScale Kargs type based on QScaleEnum
|
||||
// EmptyType: type to use when QScaleEnum is NO_SCALE (e.g., FmhaFwdEmptyKargs<3>)
|
||||
template <BlockAttentionQuantScaleEnum QScale, typename EmptyType>
|
||||
struct GetQScaleKargs
|
||||
{
|
||||
using type = EmptyType;
|
||||
};
|
||||
|
||||
template <typename EmptyType>
|
||||
struct GetQScaleKargs<BlockAttentionQuantScaleEnum::PERTENSOR, EmptyType>
|
||||
{
|
||||
using type = FmhaFwdPerTensorQScaleKargs;
|
||||
};
|
||||
|
||||
template <typename EmptyType>
|
||||
struct GetQScaleKargs<BlockAttentionQuantScaleEnum::KV_BLOCKSCALE, EmptyType>
|
||||
{
|
||||
using type = FmhaFwdKVBlockScaleKargs;
|
||||
};
|
||||
|
||||
struct FmhaFwdDropoutSeedOffset
|
||||
{
|
||||
template <typename T>
|
||||
@@ -255,9 +288,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
FmhaFwdEmptyKargs<0>>>,
|
||||
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
|
||||
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
|
||||
FmhaFwdCommonQScaleKargs,
|
||||
FmhaFwdEmptyKargs<3>>,
|
||||
GetQScaleKargs<QScaleEnum, FmhaFwdEmptyKargs<3>>::type,
|
||||
std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<4>>,
|
||||
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
|
||||
{
|
||||
@@ -276,9 +307,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
FmhaFwdEmptyKargs<0>>>,
|
||||
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
|
||||
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
|
||||
FmhaFwdCommonQScaleKargs,
|
||||
FmhaFwdEmptyKargs<3>>,
|
||||
GetQScaleKargs<QScaleEnum, FmhaFwdEmptyKargs<3>>::type,
|
||||
std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
|
||||
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
|
||||
{
|
||||
@@ -348,7 +377,9 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
bool s_randval,
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset,
|
||||
const void* sink_ptr = nullptr)
|
||||
const void* sink_ptr = nullptr,
|
||||
ck_tile::index_t nblock_stride_kv_block_descale = 0,
|
||||
ck_tile::index_t nhead_stride_kv_block_descale = 0)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -419,6 +450,14 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
kargs.k_descale_ptr = k_descale_ptr;
|
||||
kargs.v_descale_ptr = v_descale_ptr;
|
||||
}
|
||||
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
|
||||
{
|
||||
kargs.q_descale_ptr = q_descale_ptr;
|
||||
kargs.k_descale_ptr = k_descale_ptr;
|
||||
kargs.v_descale_ptr = v_descale_ptr;
|
||||
kargs.nblock_stride_kv_block_descale = nblock_stride_kv_block_descale;
|
||||
kargs.nhead_stride_kv_block_descale = nhead_stride_kv_block_descale;
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
if(drop_seed_offset.index() == 0) // seed & offset come from host
|
||||
@@ -495,7 +534,9 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
bool s_randval,
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset,
|
||||
const void* sink_ptr = nullptr)
|
||||
const void* sink_ptr = nullptr,
|
||||
ck_tile::index_t nblock_stride_kv_block_descale = 0,
|
||||
ck_tile::index_t nhead_stride_kv_block_descale = 0)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -563,6 +604,14 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
kargs.k_descale_ptr = k_descale_ptr;
|
||||
kargs.v_descale_ptr = v_descale_ptr;
|
||||
}
|
||||
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
|
||||
{
|
||||
kargs.q_descale_ptr = q_descale_ptr;
|
||||
kargs.k_descale_ptr = k_descale_ptr;
|
||||
kargs.v_descale_ptr = v_descale_ptr;
|
||||
kargs.nblock_stride_kv_block_descale = nblock_stride_kv_block_descale;
|
||||
kargs.nhead_stride_kv_block_descale = nhead_stride_kv_block_descale;
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
if(drop_seed_offset.index() == 0) // seed & offset come from host
|
||||
@@ -1157,11 +1206,20 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
const float scale_s = [&] {
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
|
||||
{
|
||||
assert(kargs.q_descale_ptr != nullptr);
|
||||
assert(kargs.k_descale_ptr != nullptr);
|
||||
float q_descale = *(reinterpret_cast<const float*>(kargs.q_descale_ptr));
|
||||
float k_descale = *(reinterpret_cast<const float*>(kargs.k_descale_ptr));
|
||||
|
||||
return kargs.scale_s * q_descale * k_descale;
|
||||
}
|
||||
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
|
||||
{
|
||||
// Q is per-tensor, K is per-page (handled in pipeline)
|
||||
assert(kargs.q_descale_ptr != nullptr);
|
||||
float q_descale = *(reinterpret_cast<const float*>(kargs.q_descale_ptr));
|
||||
return kargs.scale_s * q_descale;
|
||||
}
|
||||
else
|
||||
{
|
||||
return kargs.scale_s;
|
||||
@@ -1194,6 +1252,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
|
||||
{
|
||||
// TODO - move global load of descale to pipeline
|
||||
assert(kargs.v_descale_ptr != nullptr);
|
||||
float v_descale = *(reinterpret_cast<const float*>(kargs.v_descale_ptr));
|
||||
|
||||
float scale_p = ck_tile::type_convert<float>(ck_tile::numeric<PDataType>::max());
|
||||
@@ -1237,6 +1296,39 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
dropout,
|
||||
sink_value);
|
||||
}
|
||||
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
|
||||
{
|
||||
// KV_BLOCKSCALE: K/V descale is per-page, handled in pipeline
|
||||
assert(kargs.k_descale_ptr != nullptr);
|
||||
assert(kargs.v_descale_ptr != nullptr);
|
||||
const float* k_descale_ptr = reinterpret_cast<const float*>(kargs.k_descale_ptr);
|
||||
const float* v_descale_ptr = reinterpret_cast<const float*>(kargs.v_descale_ptr);
|
||||
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
bias_dram_window,
|
||||
randval_dram_window,
|
||||
lse_dram_window,
|
||||
mask,
|
||||
position_encoding,
|
||||
variant_params.sm_scale,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
page_idx,
|
||||
stride_k_for_pipeline,
|
||||
stride_v_for_pipeline,
|
||||
kargs.batch_stride_k,
|
||||
kargs.batch_stride_v,
|
||||
dropout,
|
||||
sink_value,
|
||||
k_descale_ptr,
|
||||
v_descale_ptr,
|
||||
kargs.nblock_stride_kv_block_descale,
|
||||
kargs.nhead_stride_kv_block_descale);
|
||||
}
|
||||
else
|
||||
{
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
|
||||
Reference in New Issue
Block a user