mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-08 15:30:23 +00:00
Make tile window directly via PageBlockNavigator
This commit is contained in:
@@ -16,7 +16,16 @@ struct TrivialPageBlockNavigator
|
||||
using WindowOrigin = multi_index<2>;
|
||||
|
||||
template <typename TensorView, typename WindowLengths>
|
||||
CK_TILE_DEVICE static constexpr auto
|
||||
CK_TILE_HOST_DEVICE static constexpr auto make_tile_window(const TensorView& tensor_view,
|
||||
const WindowLengths& window_lengths,
|
||||
const WindowOrigin& window_origin)
|
||||
{
|
||||
return make_tuple(/*block_index=*/0,
|
||||
make_tile_window(tensor_view, window_lengths, window_origin));
|
||||
}
|
||||
|
||||
template <typename TensorView, typename WindowLengths>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto
|
||||
make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
|
||||
const WindowOrigin& window_origin)
|
||||
{
|
||||
@@ -25,7 +34,7 @@ struct TrivialPageBlockNavigator
|
||||
}
|
||||
|
||||
template <typename TensorView, typename WindowLengths, typename TileDistribution>
|
||||
CK_TILE_DEVICE static constexpr auto
|
||||
CK_TILE_HOST_DEVICE static constexpr auto
|
||||
make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
|
||||
const WindowOrigin& window_origin,
|
||||
const TileDistribution& tile_distribution)
|
||||
@@ -36,7 +45,7 @@ struct TrivialPageBlockNavigator
|
||||
}
|
||||
|
||||
template <typename TileWindow>
|
||||
CK_TILE_DEVICE static index_t
|
||||
CK_TILE_HOST_DEVICE static index_t
|
||||
move_tile_window(index_t /*block_index*/,
|
||||
TileWindow& tile_window,
|
||||
const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step)
|
||||
@@ -46,13 +55,13 @@ struct TrivialPageBlockNavigator
|
||||
return /*block_index=*/0;
|
||||
}
|
||||
|
||||
/// Convert a global window origin into a local one.
/// For this navigator the conversion is the identity: the argument is returned unchanged.
// NOTE(review): the two qualifier lines below look like the old/new sides of the rendered
// diff (CK_TILE_DEVICE -> CK_TILE_HOST_DEVICE) -- confirm against the actual header.
CK_TILE_DEVICE static constexpr WindowOrigin
|
||||
CK_TILE_HOST_DEVICE static constexpr WindowOrigin
|
||||
to_local_window_origin(const WindowOrigin& global_window_origin)
|
||||
{
|
||||
return global_window_origin;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr WindowOrigin
|
||||
CK_TILE_HOST_DEVICE static constexpr WindowOrigin
|
||||
to_global_window_origin(index_t /*block_index*/, const WindowOrigin& local_window_origin)
|
||||
{
|
||||
return local_window_origin;
|
||||
@@ -83,6 +92,21 @@ struct PageBlockNavigator
|
||||
{
|
||||
}
|
||||
|
||||
template <typename TensorView, typename WindowLengths>
|
||||
CK_TILE_HOST_DEVICE auto make_tile_window(const TensorView& tensor_view,
|
||||
const WindowLengths& window_lengths,
|
||||
const WindowOrigin& window_origin)
|
||||
{
|
||||
const index_t block_index = get_block_index(window_origin);
|
||||
const WindowOrigin local_window_origin = to_local_window_origin(window_origin);
|
||||
|
||||
auto new_tile_window =
|
||||
ck_tile::make_tile_window(tensor_view, window_lengths, local_window_origin);
|
||||
new_tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(block_index));
|
||||
|
||||
return make_tuple(block_index, new_tile_window);
|
||||
}
|
||||
|
||||
template <typename TensorView, typename WindowLengths>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
|
||||
|
||||
@@ -704,29 +704,23 @@ struct FmhaFwdAppendKVKernel
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
|
||||
{i_m0, 0});
|
||||
|
||||
/// FIXME: create tile window directly via PageBlockNavigator
|
||||
const bool skip_append_kv = kargs.seqlen_knew <= i_n0;
|
||||
auto k_dram_window =
|
||||
make_tile_window(k_dram,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
|
||||
{skip_append_kv ? 0 : kargs.seqlen_k + i_n0, 0});
|
||||
|
||||
auto [i_page_block_k, k_dram_window_tmp] = k_page_block_navigator.make_tile_window(
|
||||
k_dram_window, {skip_append_kv ? 0 : kargs.seqlen_k + i_n0, 0});
|
||||
// window origin = (0, 0) if no work to do for current block
|
||||
auto [i_page_block_k, k_dram_window] = k_page_block_navigator.make_tile_window(
|
||||
k_dram,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
|
||||
{!skip_append_kv * (kargs.seqlen_k + i_n0), 0});
|
||||
|
||||
auto knew_dram_window =
|
||||
make_tile_window(knew_dram,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
|
||||
{i_n0, 0});
|
||||
|
||||
/// FIXME: create tile window directly via PageBlockNavigator
|
||||
auto v_dram_window =
|
||||
make_tile_window(v_dram,
|
||||
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
|
||||
{0, skip_append_kv ? 0 : kargs.seqlen_k + i_n0});
|
||||
|
||||
auto [i_page_block_v, v_dram_window_tmp] = v_page_block_navigator.make_tile_window(
|
||||
v_dram_window, {0, skip_append_kv ? 0 : kargs.seqlen_k + i_n0});
|
||||
// window origin = (0, 0) if no work to do for current block
|
||||
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
|
||||
v_dram,
|
||||
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
|
||||
{0, !skip_append_kv * (kargs.seqlen_k + i_n0)});
|
||||
|
||||
auto vnew_dram_window =
|
||||
make_tile_window(vnew_dram,
|
||||
@@ -736,11 +730,11 @@ struct FmhaFwdAppendKVKernel
|
||||
if constexpr(kApplyRoPE)
|
||||
{
|
||||
FmhaPipeline{}(q_dram_window,
|
||||
k_dram_window_tmp,
|
||||
k_dram_window,
|
||||
i_page_block_k,
|
||||
k_page_block_navigator,
|
||||
knew_dram_window,
|
||||
v_dram_window_tmp,
|
||||
v_dram_window,
|
||||
i_page_block_v,
|
||||
v_page_block_navigator,
|
||||
vnew_dram_window,
|
||||
@@ -750,16 +744,16 @@ struct FmhaFwdAppendKVKernel
|
||||
knew_rotary_sin_dram_window,
|
||||
kargs.rotary_dim,
|
||||
kargs.seqlen_q <= i_m0,
|
||||
kargs.seqlen_knew <= i_n0);
|
||||
skip_append_kv);
|
||||
}
|
||||
else
|
||||
{
|
||||
FmhaPipeline{}(q_dram_window,
|
||||
k_dram_window_tmp,
|
||||
k_dram_window,
|
||||
i_page_block_k,
|
||||
k_page_block_navigator,
|
||||
knew_dram_window,
|
||||
v_dram_window_tmp,
|
||||
v_dram_window,
|
||||
i_page_block_v,
|
||||
v_page_block_navigator,
|
||||
vnew_dram_window,
|
||||
@@ -769,7 +763,7 @@ struct FmhaFwdAppendKVKernel
|
||||
knew_rotary_sin_dram_window,
|
||||
0, // rotary_dim not used
|
||||
kargs.seqlen_q <= i_m0,
|
||||
kargs.seqlen_knew <= i_n0);
|
||||
skip_append_kv);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user