mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-07 00:04:37 +00:00
Pass re-created tile window to pipeline
This commit is contained in:
@@ -255,7 +255,9 @@ struct FmhaFwdAppendKVKernel
|
||||
ck_tile::index_t nhead_stride_knew,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_vnew,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_knew,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
ck_tile::index_t batch_stride_vnew)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
@@ -288,7 +290,9 @@ struct FmhaFwdAppendKVKernel
|
||||
{}, // placeholder for rope
|
||||
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
|
||||
reinterpret_cast<const int32_t*>(seqlen_k_ptr),
|
||||
batch_stride_k,
|
||||
batch_stride_v};
|
||||
|
||||
if constexpr(kApplyRoPE)
|
||||
{
|
||||
@@ -371,22 +375,22 @@ struct FmhaFwdAppendKVKernel
|
||||
reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
|
||||
i_batch_ * kargs.batch_stride_block_table;
|
||||
const index_t num_blocks =
|
||||
integer_divide_ceil(kargs.seqlen_k, kargs.page_block_size);
|
||||
integer_divide_ceil(kargs.seqlen_k + kargs.seqlen_knew, kargs.page_block_size);
|
||||
|
||||
const long_index_t fixed_offset =
|
||||
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
|
||||
kargs.nhead_stride_k;
|
||||
|
||||
return PagedTileWindowNavigator<const KDataType, 0>(kargs.k_ptr,
|
||||
kargs.batch_stride_k,
|
||||
fixed_offset,
|
||||
block_indices,
|
||||
num_blocks,
|
||||
kargs.page_block_size);
|
||||
return PagedTileWindowNavigator<KDataType, 0>(kargs.k_ptr,
|
||||
kargs.batch_stride_k,
|
||||
fixed_offset,
|
||||
block_indices,
|
||||
num_blocks,
|
||||
kargs.page_block_size);
|
||||
}
|
||||
else
|
||||
{
|
||||
return SimpleTileWindowNavigator<const KDataType>();
|
||||
return SimpleTileWindowNavigator<KDataType>();
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -397,22 +401,22 @@ struct FmhaFwdAppendKVKernel
|
||||
reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
|
||||
i_batch_ * kargs.batch_stride_block_table;
|
||||
const index_t num_blocks =
|
||||
integer_divide_ceil(kargs.seqlen_k, kargs.page_block_size);
|
||||
integer_divide_ceil(kargs.seqlen_k + kargs.seqlen_knew, kargs.page_block_size);
|
||||
|
||||
const long_index_t fixed_offset =
|
||||
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
|
||||
kargs.nhead_stride_v;
|
||||
|
||||
return PagedTileWindowNavigator<const VDataType, 1>(kargs.v_ptr,
|
||||
kargs.batch_stride_v,
|
||||
fixed_offset,
|
||||
block_indices,
|
||||
num_blocks,
|
||||
kargs.page_block_size);
|
||||
return PagedTileWindowNavigator<VDataType, 1>(kargs.v_ptr,
|
||||
kargs.batch_stride_v,
|
||||
fixed_offset,
|
||||
block_indices,
|
||||
num_blocks,
|
||||
kargs.page_block_size);
|
||||
}
|
||||
else
|
||||
{
|
||||
return SimpleTileWindowNavigator<const VDataType>();
|
||||
return SimpleTileWindowNavigator<VDataType>();
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -464,7 +468,7 @@ struct FmhaFwdAppendKVKernel
|
||||
}();
|
||||
|
||||
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
k_ptr,
|
||||
k_ptr, // will update this pointer if using paged-kvcache
|
||||
lengths,
|
||||
make_tuple(kargs.stride_k, 1),
|
||||
number<FmhaPipeline::kAlignmentK>{},
|
||||
@@ -503,7 +507,7 @@ struct FmhaFwdAppendKVKernel
|
||||
}();
|
||||
|
||||
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
v_ptr,
|
||||
v_ptr, // will update this pointer if using paged-kvcache
|
||||
lengths,
|
||||
make_tuple(kargs.stride_v, 1),
|
||||
number<FmhaPipeline::kAlignmentV>{},
|
||||
@@ -511,8 +515,8 @@ struct FmhaFwdAppendKVKernel
|
||||
|
||||
const auto v_dram_transposed = transform_tensor_view(
|
||||
v_dram_naive,
|
||||
make_tuple(make_pass_through_transform(kargs.hdim_v),
|
||||
make_pass_through_transform(kargs.seqlen_k + kargs.seqlen_knew)),
|
||||
make_tuple(make_pass_through_transform(lengths.at(number<1>{})),
|
||||
make_pass_through_transform(lengths.at(number<0>{}))),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
@@ -535,7 +539,7 @@ struct FmhaFwdAppendKVKernel
|
||||
}();
|
||||
|
||||
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
v_ptr,
|
||||
v_ptr, // will update this pointer if using paged-kvcache
|
||||
lengths,
|
||||
make_tuple(kargs.stride_v, 1),
|
||||
number<FmhaPipeline::kAlignmentV>{},
|
||||
@@ -700,21 +704,44 @@ struct FmhaFwdAppendKVKernel
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
|
||||
{i_m0, 0});
|
||||
|
||||
/// FIXME: create tile window directly via TileWindowNavigator
|
||||
const bool skip_append_kv = kargs.seqlen_knew <= i_n0;
|
||||
auto k_dram_window =
|
||||
make_tile_window(k_dram,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
|
||||
{kargs.seqlen_k + i_n0, 0});
|
||||
{skip_append_kv ? 0 : kargs.seqlen_k + i_n0, 0});
|
||||
|
||||
auto [i_block0, k_dram_window_tmp] = k_tile_navigator.make_tile_window(
|
||||
k_dram_window, {skip_append_kv ? 0 : kargs.seqlen_k + i_n0, 0});
|
||||
DEVICE_DEBUG_STMTS
|
||||
{
|
||||
printf("[DEVICE] i_block0: %d\n", i_block0);
|
||||
auto local_origin = k_dram_window_tmp.get_window_origin();
|
||||
printf("[DEVICE] origin: (%d, %d)\n",
|
||||
local_origin.at(number<0>{}),
|
||||
local_origin.at(number<1>{}));
|
||||
}
|
||||
auto knew_dram_window =
|
||||
make_tile_window(knew_dram,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
|
||||
{i_n0, 0});
|
||||
|
||||
/// FIXME: create tile window directly via TileWindowNavigator
|
||||
auto v_dram_window =
|
||||
make_tile_window(v_dram,
|
||||
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
|
||||
{0, kargs.seqlen_k + i_n0});
|
||||
{0, skip_append_kv ? 0 : kargs.seqlen_k + i_n0});
|
||||
|
||||
auto [i_block1, v_dram_window_tmp] = v_tile_navigator.make_tile_window(
|
||||
v_dram_window, {0, skip_append_kv ? 0 : kargs.seqlen_k + i_n0});
|
||||
DEVICE_DEBUG_STMTS
|
||||
{
|
||||
printf("[DEVICE] i_block1: %d\n", i_block1);
|
||||
auto local_origin = v_dram_window_tmp.get_window_origin();
|
||||
printf("[DEVICE] origin: (%d, %d)\n",
|
||||
local_origin.at(number<0>{}),
|
||||
local_origin.at(number<1>{}));
|
||||
}
|
||||
auto vnew_dram_window =
|
||||
make_tile_window(vnew_dram,
|
||||
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
|
||||
@@ -723,9 +750,9 @@ struct FmhaFwdAppendKVKernel
|
||||
if constexpr(kApplyRoPE)
|
||||
{
|
||||
FmhaPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
k_dram_window_tmp,
|
||||
knew_dram_window,
|
||||
v_dram_window,
|
||||
v_dram_window_tmp,
|
||||
vnew_dram_window,
|
||||
q_rotary_cos_dram_window,
|
||||
q_rotary_sin_dram_window,
|
||||
@@ -740,9 +767,9 @@ struct FmhaFwdAppendKVKernel
|
||||
else
|
||||
{
|
||||
FmhaPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
k_dram_window_tmp,
|
||||
knew_dram_window,
|
||||
v_dram_window,
|
||||
v_dram_window_tmp,
|
||||
vnew_dram_window,
|
||||
q_rotary_cos_dram_window,
|
||||
q_rotary_sin_dram_window,
|
||||
|
||||
Reference in New Issue
Block a user