[CK_TILE] Add paged-kvcache support in group mode fmha fwd splitkv kernels (#1678)

* Generate group mode paged-attn kernel

* Enable paged-kvcache + group mode support

* Add missing header: fused_moe.hpp

* Add comment to explain kernel arg usage

* Make error message clearer

* Add comments for confusing data member names

* Add more comments for confusing variable names

* Fix typo in option description

Author: Po Yen Chen
Date:   2024-11-21 14:53:10 +08:00 (committed by GitHub)
Parent: 6916d8cc03
Commit: fb1ccfa9df

6 changed files with 95 additions and 43 deletions


@@ -46,8 +46,7 @@ struct FmhaFwdSplitKVKernel
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV;
static_assert(!kIsGroupMode || (kIsGroupMode && !kIsPagedKV),
"paged-kvcache only supported by batch mode kernels");
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
static constexpr bool kHasMask = FmhaMask::IsMasking;
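
The removed guard is logically equivalent to requiring that kIsGroupMode and kIsPagedKV are never both true; deleting it is what allows the group-mode paged-kvcache kernels this commit adds. A standalone restatement of the old constraint (not part of the diff; the flag values are placeholders so the snippet compiles on its own):

    // Equivalent form of the removed compile-time guard. Before this commit the pair
    // (kIsGroupMode, kIsPagedKV) was not allowed to be (true, true); now it is.
    constexpr bool kIsGroupMode = true;  // placeholder value for illustration
    constexpr bool kIsPagedKV   = false; // placeholder value for illustration
    static_assert(!(kIsGroupMode && kIsPagedKV),
                  "pre-commit restriction: paged-kvcache only in batch mode");
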
@@ -198,8 +197,10 @@ struct FmhaFwdSplitKVKernel
const int32_t* seqlen_k_ptr;
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_k; // when using paged-kvcache, this will be stride/size for
// single kcache page-block
ck_tile::index_t batch_stride_v; // when using paged-kvcache, this will be stride/size for
// single vcache page-block
ck_tile::index_t batch_stride_lse_acc;
ck_tile::index_t batch_stride_o_acc;
};
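
With paged-kvcache enabled, batch_stride_k/batch_stride_v therefore describe one page block rather than one batch entry. A hypothetical helper (not part of this commit) sketching how a logical key position could be resolved to a physical element offset through the block table; the parameter names mirror the kernel-arg fields, but the helper itself and its exact layout assumptions are illustrative only:

    #include <cstdint>

    // Hypothetical helper: maps a logical key position inside a sequence to a physical
    // element offset when the K cache is stored as page blocks.
    std::int64_t paged_k_offset(const std::int32_t* block_table_ptr, // this batch's row of physical block ids
                                std::int64_t batch_stride_block_table, // entries per block-table row
                                std::int64_t page_block_size,          // tokens held by one page block
                                std::int64_t batch_stride_k,           // stride/size of a single kcache page block
                                std::int64_t stride_k,                 // stride between two tokens inside a block
                                std::int64_t i_batch,
                                std::int64_t key_pos)                  // logical position in the key sequence
    {
        const std::int32_t physical_block =
            block_table_ptr[i_batch * batch_stride_block_table + key_pos / page_block_size];

        return static_cast<std::int64_t>(physical_block) * batch_stride_k +
               (key_pos % page_block_size) * stride_k;
    }
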
@@ -212,14 +213,17 @@ struct FmhaFwdSplitKVKernel
AlibiKargs,
EmptyKargs<0>>>,
std::conditional_t<kHasMask, MaskKargs, EmptyKargs<1>>,
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>,
std::conditional_t<kIsPagedKV, PageBlockTableKargs, EmptyKargs<3>>
{
const int32_t* seqstart_q_ptr;
const int32_t* seqstart_k_ptr;
const int32_t* seqlen_k_ptr;
ck_tile::index_t batch_stride_k; // only used for paged-kvcache
ck_tile::index_t batch_stride_v; // only used for paged-kvcache
ck_tile::index_t batch_stride_k; // only used for paged-kvcache, this will be stride/size
// for single kcache page-block
ck_tile::index_t batch_stride_v; // only used for paged-kvcache, this will be stride/size
// for single vcache page-block
};
using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;
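
PageBlockTableKargs is mixed in with the same std::conditional_t/EmptyKargs pattern already used for the bias, mask and fp8 arguments, so non-paged instantiations pay no size cost. A minimal sketch of that pattern, with std::int64_t standing in for ck_tile::index_t and placeholder member lists:

    #include <cstdint>
    #include <type_traits>

    template <int> struct EmptyKargs {};          // zero-size placeholder base

    struct PageBlockTableKargs                    // only mixed in when kIsPagedKV
    {
        const std::int32_t* block_table_ptr;
        std::int64_t batch_stride_block_table;
        std::int64_t page_block_size;
    };

    template <bool kIsPagedKV>
    struct GroupModeKargs
        : std::conditional_t<kIsPagedKV, PageBlockTableKargs, EmptyKargs<3>>
    {
        const std::int32_t* seqstart_q_ptr;
        const std::int32_t* seqstart_k_ptr;
        const std::int32_t* seqlen_k_ptr;
    };

    static_assert(sizeof(GroupModeKargs<false>) <= sizeof(GroupModeKargs<true>),
                  "the block-table members only cost space when kIsPagedKV is true");
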
@@ -363,6 +367,9 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
ck_tile::index_t num_splits,
const void* block_table_ptr,
ck_tile::index_t batch_stride_block_table,
ck_tile::index_t page_block_size,
float scale_s,
float scale_p,
ck_tile::index_t stride_q,
@@ -416,6 +423,7 @@ struct FmhaFwdSplitKVKernel
{}, // placeholder for bias
{}, // placeholder for mask
{}, // placeholder for fp8_static_quant args
{}, // placeholder for paged-block table
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
reinterpret_cast<const int32_t*>(seqlen_k_ptr),
@@ -443,6 +451,12 @@ struct FmhaFwdSplitKVKernel
{
kargs.scale_p = scale_p;
}
if constexpr(kIsPagedKV)
{
kargs.block_table_ptr = reinterpret_cast<const int32_t*>(block_table_ptr);
kargs.batch_stride_block_table = batch_stride_block_table;
kargs.page_block_size = page_block_size;
}
return kargs;
}
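
Group-mode MakeKargs now also accepts block_table_ptr, batch_stride_block_table and page_block_size and forwards them only when kIsPagedKV is set. A hypothetical host-side illustration (not from this commit) of what those three values describe, namely a flattened [num_batches, max_blocks_per_seq] table of physical page-block ids:

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Hypothetical host-side setup: the block table is a flattened
    // [num_batches, max_blocks_per_seq] array of physical page-block ids.
    struct BlockTable
    {
        std::vector<std::int32_t> ids;            // flattened [num_batches, max_blocks_per_seq]
        std::int64_t batch_stride_block_table;    // == max_blocks_per_seq
        std::int64_t page_block_size;             // tokens held by each page block
    };

    BlockTable make_block_table(std::int64_t num_batches,
                                std::int64_t max_blocks_per_seq,
                                std::int64_t page_block_size)
    {
        BlockTable table;
        table.ids.assign(static_cast<std::size_t>(num_batches * max_blocks_per_seq), 0);
        // ^ fill with real physical block ids in practice
        table.batch_stride_block_table = max_blocks_per_seq;
        table.page_block_size          = page_block_size;
        return table;
    }

table.ids.data() would then be what a caller passes as block_table_ptr, together with the other (unchanged) MakeKargs arguments shown in the hunk above.
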
@@ -489,15 +503,22 @@ struct FmhaFwdSplitKVKernel
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
batch_offset_q = query_start * kargs.stride_q;
batch_offset_k = key_start * kargs.stride_k;
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
if constexpr(kIsPagedKV)
{
batch_offset_v = key_start * kargs.stride_v;
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
}
else
{
batch_offset_v = key_start;
batch_offset_k = key_start * kargs.stride_k;
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
batch_offset_v = key_start * kargs.stride_v;
}
else
{
batch_offset_v = key_start;
}
}
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
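
In group mode the offset computation now branches on kIsPagedKV first: with paging, the K/V offsets come from the batch index times the per-page-block stride; otherwise the original seqstart_k-based path is kept. A standalone restatement of that logic (illustrative only: runtime bools replace the if constexpr checks and std::int64_t stands in for ck_tile::index_t / long_index_t):

    #include <cstdint>

    struct GroupOffsets
    {
        std::int64_t q, k, v;
    };

    GroupOffsets compute_group_offsets(bool is_paged_kv,
                                       bool v_is_row_major,
                                       const std::int32_t* seqstart_q_ptr,
                                       const std::int32_t* seqstart_k_ptr,
                                       std::int64_t i_batch,
                                       std::int64_t stride_q,
                                       std::int64_t stride_k,
                                       std::int64_t stride_v,
                                       std::int64_t batch_stride_k,
                                       std::int64_t batch_stride_v)
    {
        const std::int64_t query_start = seqstart_q_ptr[i_batch];
        const std::int64_t key_start   = seqstart_k_ptr[i_batch];

        GroupOffsets off{};
        off.q = query_start * stride_q;

        if(is_paged_kv)
        {
            // K/V live in page blocks: start from the blocks owned by this batch
            off.k = i_batch * batch_stride_k;
            off.v = i_batch * batch_stride_v;
        }
        else
        {
            // contiguous K/V: offsets come from the ragged seqstart_k prefix sums
            off.k = key_start * stride_k;
            off.v = v_is_row_major ? key_start * stride_v : key_start;
        }
        return off;
    }
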
@@ -685,7 +706,7 @@ struct FmhaFwdSplitKVKernel
return make_page_block_navigator<const KDataType, 0>(
kargs.k_ptr,
kargs.batch_stride_k,
kargs.batch_stride_k, // kcache page-block stride/size
fixed_offset,
block_indices,
num_blocks,
@@ -715,7 +736,7 @@ struct FmhaFwdSplitKVKernel
return make_page_block_navigator<const VDataType, 1>(
kargs.v_ptr,
kargs.batch_stride_v,
kargs.batch_stride_v, // vcache page-block stride/size
fixed_offset,
block_indices,
num_blocks,
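
Both navigators now receive the per-page-block stride for their cache. Conceptually, a page-block navigator resolves a key position to a physical address through block_indices; the sketch below is not the ck_tile implementation, only an illustration of that mapping (base, row_stride and at() are invented names, the other fields mirror the call sites above):

    #include <cstdint>

    // Conceptual sketch only, NOT the ck_tile navigator.
    template <typename DataType>
    struct SimplePageBlockNavigator
    {
        const DataType* base;               // e.g. kargs.k_ptr or kargs.v_ptr
        std::int64_t block_stride;          // kargs.batch_stride_k / _v: elements per page block
        std::int64_t fixed_offset;          // offset shared by every block (head offset etc.)
        const std::int32_t* block_indices;  // this sequence's row of the block table
        std::int64_t num_blocks;            // how many page blocks the sequence occupies
        std::int64_t page_block_size;       // tokens held by one page block
        std::int64_t row_stride;            // stride between two tokens inside a block

        const DataType* at(std::int64_t key_pos) const
        {
            const std::int32_t physical_block = block_indices[key_pos / page_block_size];
            return base + physical_block * block_stride + fixed_offset +
                   (key_pos % page_block_size) * row_stride;
        }
    };
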