Make tile window directly via PageBlockNavigator

This commit is contained in:
PoYen, Chen
2024-08-13 09:18:24 +00:00
parent a8a2275aca
commit 370babc996
2 changed files with 45 additions and 27 deletions

View File

@@ -16,7 +16,16 @@ struct TrivialPageBlockNavigator
using WindowOrigin = multi_index<2>;
template <typename TensorView, typename WindowLengths>
CK_TILE_DEVICE static constexpr auto
/// Create a tile window over `tensor_view` with the given lengths and origin,
/// and return it paired with a block index.
///
/// @param tensor_view     tensor view the window is created on
/// @param window_lengths  lengths of the window
/// @param window_origin   origin of the window
/// @return tuple of (block index, tile window); the block index is always 0
///
/// The inner call is explicitly qualified with `ck_tile::`. An unqualified
/// `make_tile_window(...)` inside this struct would find the member overloads
/// first (hiding the namespace-level factory and disabling ADL), so it would
/// resolve back to this very function, whose return type is still being deduced.
CK_TILE_HOST_DEVICE static constexpr auto make_tile_window(const TensorView& tensor_view,
                                                           const WindowLengths& window_lengths,
                                                           const WindowOrigin& window_origin)
{
    return make_tuple(/*block_index=*/0,
                      ck_tile::make_tile_window(tensor_view, window_lengths, window_origin));
}
template <typename TensorView, typename WindowLengths>
CK_TILE_HOST_DEVICE static constexpr auto
make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
const WindowOrigin& window_origin)
{
@@ -25,7 +34,7 @@ struct TrivialPageBlockNavigator
}
template <typename TensorView, typename WindowLengths, typename TileDistribution>
CK_TILE_DEVICE static constexpr auto
CK_TILE_HOST_DEVICE static constexpr auto
make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
const WindowOrigin& window_origin,
const TileDistribution& tile_distribution)
@@ -36,7 +45,7 @@ struct TrivialPageBlockNavigator
}
template <typename TileWindow>
CK_TILE_DEVICE static index_t
CK_TILE_HOST_DEVICE static index_t
move_tile_window(index_t /*block_index*/,
TileWindow& tile_window,
const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step)
@@ -46,13 +55,13 @@ struct TrivialPageBlockNavigator
return /*block_index=*/0;
}
CK_TILE_DEVICE static constexpr WindowOrigin
CK_TILE_HOST_DEVICE static constexpr WindowOrigin
to_local_window_origin(const WindowOrigin& global_window_origin)
{
return global_window_origin;
}
CK_TILE_DEVICE static constexpr WindowOrigin
CK_TILE_HOST_DEVICE static constexpr WindowOrigin
to_global_window_origin(index_t /*block_index*/, const WindowOrigin& local_window_origin)
{
return local_window_origin;
@@ -83,6 +92,21 @@ struct PageBlockNavigator
{
}
/// Create a tile window for the block selected by `window_origin`.
///
/// The block index is obtained from get_block_index(window_origin) and the
/// origin is converted with to_local_window_origin() before the window is
/// created through ck_tile::make_tile_window(). The window's bottom tensor
/// view data pointer is then set to get_block_ptr(block index).
///
/// @param tensor_view     tensor view the window is created on
/// @param window_lengths  lengths of the window
/// @param window_origin   window origin, before conversion to a local origin
/// @return tuple of (block index, tile window)
template <typename TensorView, typename WindowLengths>
CK_TILE_HOST_DEVICE auto make_tile_window(const TensorView& tensor_view,
                                          const WindowLengths& window_lengths,
                                          const WindowOrigin& window_origin)
{
    const index_t page_block_id     = get_block_index(window_origin);
    const WindowOrigin local_origin = to_local_window_origin(window_origin);

    auto window = ck_tile::make_tile_window(tensor_view, window_lengths, local_origin);
    // Point the freshly created window at the storage of the selected block.
    window.set_bottom_tensor_view_data_ptr(get_block_ptr(page_block_id));

    return make_tuple(page_block_id, window);
}
template <typename TensorView, typename WindowLengths>
CK_TILE_HOST_DEVICE auto
make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,

View File

@@ -704,29 +704,23 @@ struct FmhaFwdAppendKVKernel
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
{i_m0, 0});
/// FIXME: create tile window directly via PageBlockNavigator
const bool skip_append_kv = kargs.seqlen_knew <= i_n0;
auto k_dram_window =
make_tile_window(k_dram,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
{skip_append_kv ? 0 : kargs.seqlen_k + i_n0, 0});
auto [i_page_block_k, k_dram_window_tmp] = k_page_block_navigator.make_tile_window(
k_dram_window, {skip_append_kv ? 0 : kargs.seqlen_k + i_n0, 0});
// window origin = (0, 0) if no work to do for current block
auto [i_page_block_k, k_dram_window] = k_page_block_navigator.make_tile_window(
k_dram,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
{!skip_append_kv * (kargs.seqlen_k + i_n0), 0});
auto knew_dram_window =
make_tile_window(knew_dram,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
{i_n0, 0});
/// FIXME: create tile window directly via PageBlockNavigator
auto v_dram_window =
make_tile_window(v_dram,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
{0, skip_append_kv ? 0 : kargs.seqlen_k + i_n0});
auto [i_page_block_v, v_dram_window_tmp] = v_page_block_navigator.make_tile_window(
v_dram_window, {0, skip_append_kv ? 0 : kargs.seqlen_k + i_n0});
// window origin = (0, 0) if no work to do for current block
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
v_dram,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
{0, !skip_append_kv * (kargs.seqlen_k + i_n0)});
auto vnew_dram_window =
make_tile_window(vnew_dram,
@@ -736,11 +730,11 @@ struct FmhaFwdAppendKVKernel
if constexpr(kApplyRoPE)
{
FmhaPipeline{}(q_dram_window,
k_dram_window_tmp,
k_dram_window,
i_page_block_k,
k_page_block_navigator,
knew_dram_window,
v_dram_window_tmp,
v_dram_window,
i_page_block_v,
v_page_block_navigator,
vnew_dram_window,
@@ -750,16 +744,16 @@ struct FmhaFwdAppendKVKernel
knew_rotary_sin_dram_window,
kargs.rotary_dim,
kargs.seqlen_q <= i_m0,
kargs.seqlen_knew <= i_n0);
skip_append_kv);
}
else
{
FmhaPipeline{}(q_dram_window,
k_dram_window_tmp,
k_dram_window,
i_page_block_k,
k_page_block_navigator,
knew_dram_window,
v_dram_window_tmp,
v_dram_window,
i_page_block_v,
v_page_block_navigator,
vnew_dram_window,
@@ -769,7 +763,7 @@ struct FmhaFwdAppendKVKernel
knew_rotary_sin_dram_window,
0, // rotary_dim not used
kargs.seqlen_q <= i_m0,
kargs.seqlen_knew <= i_n0);
skip_append_kv);
}
}
};