mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 12:11:19 +00:00
[CK_TILE] Fix incorrect computation of group mode PagedAttention (#1688)
* Allow getting batch size from splitkv tile partitioner * Fix wrong paged-kvcache impl for group mode * Fix wrong example code for page-kvcache * Undo changes in fmha_fwd.cpp * Always use 2D block table * Add is_gappy kernel argument for paged-kvcache The is_gappy argument is used for differentiating seqstart_k_ptr usage in flash-attention & xformers * Remove out-of-date comments * Remove no-longer used method * Fix wrong # page-block calculation * Fix wrong comment --------- Co-authored-by: Qianfeng <qianfeng.zhang@amd.com>
This commit is contained in:
@@ -172,13 +172,18 @@ struct FmhaFwdSplitKVKernel
|
||||
float scale_p;
|
||||
};
|
||||
|
||||
struct PageBlockTableKargs
|
||||
struct CommonPageBlockTableKargs
|
||||
{
|
||||
const int32_t* block_table_ptr;
|
||||
ck_tile::index_t batch_stride_block_table;
|
||||
ck_tile::index_t page_block_size;
|
||||
};
|
||||
|
||||
struct GroupModePageBlockTableKargs : CommonPageBlockTableKargs
|
||||
{
|
||||
bool is_gappy = false;
|
||||
};
|
||||
|
||||
struct CacheBatchIdxKargs
|
||||
{
|
||||
const int32_t* cache_batch_idx;
|
||||
@@ -193,7 +198,7 @@ struct FmhaFwdSplitKVKernel
|
||||
EmptyKargs<0>>>,
|
||||
std::conditional_t<kHasMask, MaskKargs, EmptyKargs<1>>,
|
||||
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>,
|
||||
std::conditional_t<kIsPagedKV, PageBlockTableKargs, CacheBatchIdxKargs>
|
||||
std::conditional_t<kIsPagedKV, CommonPageBlockTableKargs, CacheBatchIdxKargs>
|
||||
{
|
||||
const int32_t* seqlen_k_ptr;
|
||||
|
||||
@@ -215,7 +220,7 @@ struct FmhaFwdSplitKVKernel
|
||||
EmptyKargs<0>>>,
|
||||
std::conditional_t<kHasMask, MaskKargs, EmptyKargs<1>>,
|
||||
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>,
|
||||
std::conditional_t<kIsPagedKV, PageBlockTableKargs, EmptyKargs<3>>
|
||||
std::conditional_t<kIsPagedKV, GroupModePageBlockTableKargs, EmptyKargs<3>>
|
||||
{
|
||||
const int32_t* seqstart_q_ptr;
|
||||
const int32_t* seqstart_k_ptr;
|
||||
@@ -375,6 +380,7 @@ struct FmhaFwdSplitKVKernel
|
||||
const void* block_table_ptr,
|
||||
ck_tile::index_t batch_stride_block_table,
|
||||
ck_tile::index_t page_block_size,
|
||||
bool is_gappy,
|
||||
float scale_s,
|
||||
float scale_p,
|
||||
ck_tile::index_t stride_q,
|
||||
@@ -461,6 +467,7 @@ struct FmhaFwdSplitKVKernel
|
||||
kargs.block_table_ptr = reinterpret_cast<const int32_t*>(block_table_ptr);
|
||||
kargs.batch_stride_block_table = batch_stride_block_table;
|
||||
kargs.page_block_size = page_block_size;
|
||||
kargs.is_gappy = is_gappy;
|
||||
}
|
||||
|
||||
return kargs;
|
||||
@@ -495,11 +502,13 @@ struct FmhaFwdSplitKVKernel
|
||||
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
|
||||
|
||||
long_index_t batch_offset_q = 0;
|
||||
long_index_t batch_offset_k = 0;
|
||||
long_index_t batch_offset_v = 0;
|
||||
long_index_t batch_offset_k = 0; // unused for paged-kvcache
|
||||
long_index_t batch_offset_v = 0; // unused for paged-kvcache
|
||||
long_index_t batch_offset_bias = 0;
|
||||
long_index_t batch_offset_lse_acc = 0;
|
||||
long_index_t batch_offset_o_acc = 0;
|
||||
index_t kv_l2p_offset =
|
||||
0; // logical-to-physical offset of seqlen_k coordinate. only used for paged-kvcache
|
||||
|
||||
if constexpr(kIsGroupMode)
|
||||
{
|
||||
@@ -508,22 +517,14 @@ struct FmhaFwdSplitKVKernel
|
||||
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
|
||||
|
||||
batch_offset_q = query_start * kargs.stride_q;
|
||||
if constexpr(kIsPagedKV)
|
||||
batch_offset_k = key_start * kargs.stride_k;
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
|
||||
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
|
||||
batch_offset_v = key_start * kargs.stride_v;
|
||||
}
|
||||
else
|
||||
{
|
||||
batch_offset_k = key_start * kargs.stride_k;
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
batch_offset_v = key_start * kargs.stride_v;
|
||||
}
|
||||
else
|
||||
{
|
||||
batch_offset_v = key_start;
|
||||
}
|
||||
batch_offset_v = key_start;
|
||||
}
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
@@ -551,6 +552,15 @@ struct FmhaFwdSplitKVKernel
|
||||
{
|
||||
kargs.seqlen_k = kargs.seqstart_k_ptr[i_batch + 1] - kargs.seqstart_k_ptr[i_batch];
|
||||
}
|
||||
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
if(kargs.is_gappy)
|
||||
{
|
||||
// seqstart_k_ptr has different meaning in this case
|
||||
kv_l2p_offset = kargs.seqstart_k_ptr[i_batch];
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -703,7 +713,7 @@ struct FmhaFwdSplitKVKernel
|
||||
reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
|
||||
i_batch_ * kargs.batch_stride_block_table;
|
||||
const index_t num_blocks =
|
||||
integer_divide_ceil(kargs.seqlen_k, kargs.page_block_size);
|
||||
integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size);
|
||||
|
||||
const long_index_t fixed_offset =
|
||||
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
|
||||
@@ -718,7 +728,8 @@ struct FmhaFwdSplitKVKernel
|
||||
kargs.page_block_size,
|
||||
k_dram,
|
||||
make_k_dram(nullptr,
|
||||
kargs.seqlen_k - (num_blocks - 1) * kargs.page_block_size));
|
||||
(kv_l2p_offset + kargs.seqlen_k) -
|
||||
(num_blocks - 1) * kargs.page_block_size));
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -733,7 +744,7 @@ struct FmhaFwdSplitKVKernel
|
||||
reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
|
||||
i_batch_ * kargs.batch_stride_block_table;
|
||||
const index_t num_blocks =
|
||||
integer_divide_ceil(kargs.seqlen_k, kargs.page_block_size);
|
||||
integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size);
|
||||
|
||||
const long_index_t fixed_offset =
|
||||
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
|
||||
@@ -748,7 +759,8 @@ struct FmhaFwdSplitKVKernel
|
||||
kargs.page_block_size,
|
||||
v_dram,
|
||||
make_v_dram(nullptr,
|
||||
kargs.seqlen_k - (num_blocks - 1) * kargs.page_block_size));
|
||||
(kv_l2p_offset + kargs.seqlen_k) -
|
||||
(num_blocks - 1) * kargs.page_block_size));
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -896,6 +908,7 @@ struct FmhaFwdSplitKVKernel
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.scale_s,
|
||||
kv_l2p_offset,
|
||||
smem_ptr);
|
||||
}
|
||||
else
|
||||
@@ -912,6 +925,7 @@ struct FmhaFwdSplitKVKernel
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.scale_s,
|
||||
kv_l2p_offset,
|
||||
smem_ptr);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -18,11 +18,11 @@ struct FmhaFwdSplitKVTilePartitioner
|
||||
static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1;
|
||||
static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1;
|
||||
|
||||
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
|
||||
ck_tile::index_t nhead,
|
||||
ck_tile::index_t max_seqlen_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_splits)
|
||||
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
|
||||
ck_tile::index_t nhead,
|
||||
ck_tile::index_t max_seqlen_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_splits)
|
||||
{
|
||||
// TODO: this may need tuning
|
||||
return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, kM0) *
|
||||
|
||||
Reference in New Issue
Block a user