mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 20:51:23 +00:00
[CK_TILE] Add paged-kvcache support in group mode fmha fwd splitkv kernels (#1678)
* Generate group mode paged-attn kernel * Enable paged-kvcache + group mode support * Add missing header: fused_moe.hpp * Add comment to explain kernel arg usage * Make error message more clear * Add comment for confusing data member names * Add more comment for confusing variable names * Fix typo in option description
This commit is contained in:
@@ -46,8 +46,7 @@ struct FmhaFwdSplitKVKernel
|
||||
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
|
||||
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
|
||||
static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV;
|
||||
static_assert(!kIsGroupMode || (kIsGroupMode && !kIsPagedKV),
|
||||
"paged-kvcache only supported by batch mode kernels");
|
||||
|
||||
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
|
||||
static constexpr bool kHasMask = FmhaMask::IsMasking;
|
||||
|
||||
@@ -198,8 +197,10 @@ struct FmhaFwdSplitKVKernel
|
||||
const int32_t* seqlen_k_ptr;
|
||||
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
ck_tile::index_t batch_stride_k; // when using paged-kvcache, this will be stride/size for
|
||||
// single kcache page-block
|
||||
ck_tile::index_t batch_stride_v; // when using paged-kvcache, this will be stride/size for
|
||||
// single vcache page-block
|
||||
ck_tile::index_t batch_stride_lse_acc;
|
||||
ck_tile::index_t batch_stride_o_acc;
|
||||
};
|
||||
@@ -212,14 +213,17 @@ struct FmhaFwdSplitKVKernel
|
||||
AlibiKargs,
|
||||
EmptyKargs<0>>>,
|
||||
std::conditional_t<kHasMask, MaskKargs, EmptyKargs<1>>,
|
||||
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>
|
||||
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>,
|
||||
std::conditional_t<kIsPagedKV, PageBlockTableKargs, EmptyKargs<3>>
|
||||
{
|
||||
const int32_t* seqstart_q_ptr;
|
||||
const int32_t* seqstart_k_ptr;
|
||||
const int32_t* seqlen_k_ptr;
|
||||
|
||||
ck_tile::index_t batch_stride_k; // only used for paged-kvcache
|
||||
ck_tile::index_t batch_stride_v; // only used for paged-kvcache
|
||||
ck_tile::index_t batch_stride_k; // only used for paged-kvcache, this will be stride/size
|
||||
// for single kcache page-block
|
||||
ck_tile::index_t batch_stride_v; // only used for paged-kvcache, this will be stride/size
|
||||
// for single vcache page-block
|
||||
};
|
||||
|
||||
using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;
|
||||
@@ -363,6 +367,9 @@ struct FmhaFwdSplitKVKernel
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
ck_tile::index_t num_splits,
|
||||
const void* block_table_ptr,
|
||||
ck_tile::index_t batch_stride_block_table,
|
||||
ck_tile::index_t page_block_size,
|
||||
float scale_s,
|
||||
float scale_p,
|
||||
ck_tile::index_t stride_q,
|
||||
@@ -416,6 +423,7 @@ struct FmhaFwdSplitKVKernel
|
||||
{}, // placeholder for bias
|
||||
{}, // placeholder for mask
|
||||
{}, // placeholder for fp8_static_quant args
|
||||
{}, // placeholder for paged-block table
|
||||
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_k_ptr),
|
||||
@@ -443,6 +451,12 @@ struct FmhaFwdSplitKVKernel
|
||||
{
|
||||
kargs.scale_p = scale_p;
|
||||
}
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
kargs.block_table_ptr = reinterpret_cast<const int32_t*>(block_table_ptr);
|
||||
kargs.batch_stride_block_table = batch_stride_block_table;
|
||||
kargs.page_block_size = page_block_size;
|
||||
}
|
||||
|
||||
return kargs;
|
||||
}
|
||||
@@ -489,15 +503,22 @@ struct FmhaFwdSplitKVKernel
|
||||
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
|
||||
|
||||
batch_offset_q = query_start * kargs.stride_q;
|
||||
batch_offset_k = key_start * kargs.stride_k;
|
||||
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
batch_offset_v = key_start * kargs.stride_v;
|
||||
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
|
||||
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
|
||||
}
|
||||
else
|
||||
{
|
||||
batch_offset_v = key_start;
|
||||
batch_offset_k = key_start * kargs.stride_k;
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
batch_offset_v = key_start * kargs.stride_v;
|
||||
}
|
||||
else
|
||||
{
|
||||
batch_offset_v = key_start;
|
||||
}
|
||||
}
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
@@ -685,7 +706,7 @@ struct FmhaFwdSplitKVKernel
|
||||
|
||||
return make_page_block_navigator<const KDataType, 0>(
|
||||
kargs.k_ptr,
|
||||
kargs.batch_stride_k,
|
||||
kargs.batch_stride_k, // kcache page-block stride/size
|
||||
fixed_offset,
|
||||
block_indices,
|
||||
num_blocks,
|
||||
@@ -715,7 +736,7 @@ struct FmhaFwdSplitKVKernel
|
||||
|
||||
return make_page_block_navigator<const VDataType, 1>(
|
||||
kargs.v_ptr,
|
||||
kargs.batch_stride_v,
|
||||
kargs.batch_stride_v, // vcache page-block stride/size
|
||||
fixed_offset,
|
||||
block_indices,
|
||||
num_blocks,
|
||||
|
||||
Reference in New Issue
Block a user