mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Add tile navigators to the appendkv kernel
This commit is contained in:
@@ -139,6 +139,9 @@ struct FmhaFwdAppendKVKernel
|
||||
const int32_t* seqstart_q_ptr;
|
||||
const int32_t* seqstart_k_ptr;
|
||||
const int32_t* seqlen_k_ptr;
|
||||
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
};
|
||||
|
||||
using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;
|
||||
@@ -361,6 +364,58 @@ struct FmhaFwdAppendKVKernel
|
||||
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
|
||||
}
|
||||
|
||||
auto k_tile_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
const auto* block_indices =
|
||||
reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
|
||||
i_batch_ * kargs.batch_stride_block_table;
|
||||
const index_t num_blocks =
|
||||
integer_divide_ceil(kargs.seqlen_k, kargs.page_block_size);
|
||||
|
||||
const long_index_t fixed_offset =
|
||||
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
|
||||
kargs.nhead_stride_k;
|
||||
|
||||
return PagedTileWindowNavigator<const KDataType, 0>(kargs.k_ptr,
|
||||
kargs.batch_stride_k,
|
||||
fixed_offset,
|
||||
block_indices,
|
||||
num_blocks,
|
||||
kargs.page_block_size);
|
||||
}
|
||||
else
|
||||
{
|
||||
return SimpleTileWindowNavigator<const KDataType>();
|
||||
}
|
||||
}();
|
||||
|
||||
auto v_tile_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
const auto* block_indices =
|
||||
reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
|
||||
i_batch_ * kargs.batch_stride_block_table;
|
||||
const index_t num_blocks =
|
||||
integer_divide_ceil(kargs.seqlen_k, kargs.page_block_size);
|
||||
|
||||
const long_index_t fixed_offset =
|
||||
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
|
||||
kargs.nhead_stride_v;
|
||||
|
||||
return PagedTileWindowNavigator<const VDataType, 1>(kargs.v_ptr,
|
||||
kargs.batch_stride_v,
|
||||
fixed_offset,
|
||||
block_indices,
|
||||
num_blocks,
|
||||
kargs.page_block_size);
|
||||
}
|
||||
else
|
||||
{
|
||||
return SimpleTileWindowNavigator<const VDataType>();
|
||||
}
|
||||
}();
|
||||
|
||||
// for simplicity, batch stride we just modify the pointer
|
||||
QDataType* q_ptr = reinterpret_cast<QDataType*>(kargs.q_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
|
||||
@@ -397,9 +452,20 @@ struct FmhaFwdAppendKVKernel
|
||||
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
|
||||
}();
|
||||
const auto k_dram = [&]() {
|
||||
const auto lengths = [&]() {
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
return make_tuple(kargs.page_block_size, kargs.hdim_q);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tuple(kargs.seqlen_k + kargs.seqlen_knew, kargs.hdim_q);
|
||||
}
|
||||
}();
|
||||
|
||||
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
k_ptr,
|
||||
make_tuple(kargs.seqlen_k + kargs.seqlen_knew, kargs.hdim_q),
|
||||
lengths,
|
||||
make_tuple(kargs.stride_k, 1),
|
||||
number<FmhaPipeline::kAlignmentK>{},
|
||||
number<1>{});
|
||||
@@ -425,9 +491,20 @@ struct FmhaFwdAppendKVKernel
|
||||
const auto v_dram = [&]() {
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
const auto lengths = [&]() {
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
return make_tuple(kargs.page_block_size, kargs.hdim_v);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tuple(kargs.seqlen_k + kargs.seqlen_knew, kargs.hdim_v);
|
||||
}
|
||||
}();
|
||||
|
||||
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
v_ptr,
|
||||
make_tuple(kargs.seqlen_k + kargs.seqlen_knew, kargs.hdim_v),
|
||||
lengths,
|
||||
make_tuple(kargs.stride_v, 1),
|
||||
number<FmhaPipeline::kAlignmentV>{},
|
||||
number<1>{});
|
||||
@@ -446,9 +523,20 @@ struct FmhaFwdAppendKVKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto lengths = [&]() {
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
return make_tuple(kargs.hdim_v, kargs.page_block_size);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tuple(kargs.hdim_v, kargs.seqlen_k + kargs.seqlen_knew);
|
||||
}
|
||||
}();
|
||||
|
||||
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
v_ptr,
|
||||
make_tuple(kargs.hdim_v, kargs.seqlen_k + kargs.seqlen_knew),
|
||||
lengths,
|
||||
make_tuple(kargs.stride_v, 1),
|
||||
number<FmhaPipeline::kAlignmentV>{},
|
||||
number<1>{});
|
||||
@@ -644,6 +732,8 @@ struct FmhaFwdAppendKVKernel
|
||||
knew_rotary_cos_dram_window,
|
||||
knew_rotary_sin_dram_window,
|
||||
kargs.rotary_dim,
|
||||
k_tile_navigator,
|
||||
v_tile_navigator,
|
||||
kargs.seqlen_q <= i_m0,
|
||||
kargs.seqlen_knew <= i_n0);
|
||||
}
|
||||
@@ -658,7 +748,9 @@ struct FmhaFwdAppendKVKernel
|
||||
q_rotary_sin_dram_window,
|
||||
knew_rotary_cos_dram_window,
|
||||
knew_rotary_sin_dram_window,
|
||||
0,
|
||||
0, // rotary_dim not used
|
||||
k_tile_navigator,
|
||||
v_tile_navigator,
|
||||
kargs.seqlen_q <= i_m0,
|
||||
kargs.seqlen_knew <= i_n0);
|
||||
}
|
||||
|
||||
@@ -83,7 +83,9 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
typename QRotaryCosDramBlockWindow,
|
||||
typename QRotarySinDramBlockWindow,
|
||||
typename KnewRotaryCosDramBlockWindow,
|
||||
typename KnewRotarySinDramBlockWindow>
|
||||
typename KnewRotarySinDramBlockWindow,
|
||||
typename KTileWindowNavigator,
|
||||
typename VTileWindowNavigator>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(QDramBlockWindow& q_dram_block_window, // M0*K0 tile
|
||||
const QElementFunction& q_element_func,
|
||||
@@ -98,6 +100,8 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
const KnewRotaryCosDramBlockWindow knew_rotary_cos_dram_block_window,
|
||||
const KnewRotarySinDramBlockWindow knew_rotary_sin_dram_block_window,
|
||||
index_t rotary_dim,
|
||||
const KTileWindowNavigator& k_tile_navigator,
|
||||
const VTileWindowNavigator& v_tile_navigator,
|
||||
bool skip_transform_q,
|
||||
bool skip_append_kv) const
|
||||
{
|
||||
@@ -141,7 +145,16 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
rotary_dim,
|
||||
thread_end);
|
||||
}
|
||||
store_tile(k_dram_block_window, knew_tile);
|
||||
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
/// TODO: handle cross-page-block write
|
||||
store_tile(k_dram_block_window, knew_tile);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(k_dram_block_window, knew_tile);
|
||||
}
|
||||
|
||||
// append Vnew to V
|
||||
auto vnew_window = make_tile_window(
|
||||
@@ -151,7 +164,16 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
auto vnew = load_tile(vnew_window);
|
||||
return tile_elementwise_in(vnew_element_func, vnew);
|
||||
}();
|
||||
store_tile(v_dram_block_window, vnew_tile);
|
||||
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
/// TODO: handle cross-page-block write
|
||||
store_tile(v_dram_block_window, vnew_tile);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(v_dram_block_window, vnew_tile);
|
||||
}
|
||||
}
|
||||
|
||||
if(!skip_transform_q)
|
||||
@@ -201,7 +223,9 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
typename QRotaryCosDramBlockWindow,
|
||||
typename QRotarySinDramBlockWindow,
|
||||
typename KnewRotaryCosDramBlockWindow,
|
||||
typename KnewRotarySinDramBlockWindow>
|
||||
typename KnewRotarySinDramBlockWindow,
|
||||
typename KTileWindowNavigator,
|
||||
typename VTileWindowNavigator>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(QDramBlockWindow& q_dram_block_window,
|
||||
KDramBlockWindow& k_dram_block_window,
|
||||
@@ -213,6 +237,8 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
const KnewRotaryCosDramBlockWindow& knew_rotary_cos_dram_block_window,
|
||||
const KnewRotarySinDramBlockWindow& knew_rotary_sin_dram_block_window,
|
||||
index_t rotary_dim,
|
||||
const KTileWindowNavigator& k_tile_navigator,
|
||||
const VTileWindowNavigator& v_tile_navigator,
|
||||
bool skip_transform_q,
|
||||
bool skip_append_kv) const
|
||||
{
|
||||
@@ -229,6 +255,8 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
knew_rotary_cos_dram_block_window,
|
||||
knew_rotary_sin_dram_block_window,
|
||||
rotary_dim,
|
||||
k_tile_navigator,
|
||||
v_tile_navigator,
|
||||
skip_transform_q,
|
||||
skip_append_kv);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user