From 370babc99621212048a5b062351baaf12a04e6fc Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Tue, 13 Aug 2024 09:18:24 +0000 Subject: [PATCH] Make tile window directly via PageBlockNavigator --- .../ops/fmha/block/page_block_navigator.hpp | 34 ++++++++++++++--- .../fmha/kernel/fmha_fwd_appendkv_kernel.hpp | 38 ++++++++----------- 2 files changed, 45 insertions(+), 27 deletions(-) diff --git a/include/ck_tile/ops/fmha/block/page_block_navigator.hpp b/include/ck_tile/ops/fmha/block/page_block_navigator.hpp index 82e93c128c..1e5e63e393 100644 --- a/include/ck_tile/ops/fmha/block/page_block_navigator.hpp +++ b/include/ck_tile/ops/fmha/block/page_block_navigator.hpp @@ -16,7 +16,16 @@ struct TrivialPageBlockNavigator using WindowOrigin = multi_index<2>; template - CK_TILE_DEVICE static constexpr auto + CK_TILE_HOST_DEVICE static constexpr auto make_tile_window(const TensorView& tensor_view, + const WindowLengths& window_lengths, + const WindowOrigin& window_origin) + { + return make_tuple(/*block_index=*/0, + ck_tile::make_tile_window(tensor_view, window_lengths, window_origin)); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto make_tile_window(const tile_window_with_static_lengths& tile_window, const WindowOrigin& window_origin) { @@ -25,7 +34,7 @@ struct TrivialPageBlockNavigator } template - CK_TILE_DEVICE static constexpr auto + CK_TILE_HOST_DEVICE static constexpr auto make_tile_window(const tile_window_with_static_lengths& tile_window, const WindowOrigin& window_origin, const TileDistribution& tile_distribution) @@ -36,7 +45,7 @@ struct TrivialPageBlockNavigator } template - CK_TILE_DEVICE static index_t + CK_TILE_HOST_DEVICE static index_t move_tile_window(index_t /*block_index*/, TileWindow& tile_window, const typename remove_cvref_t::BottomTensorIndex& step) @@ -46,13 +55,13 @@ struct TrivialPageBlockNavigator return /*block_index=*/0; } - CK_TILE_DEVICE static constexpr WindowOrigin + CK_TILE_HOST_DEVICE static constexpr WindowOrigin 
to_local_window_origin(const WindowOrigin& global_window_origin) { return global_window_origin; } - CK_TILE_DEVICE static constexpr WindowOrigin + CK_TILE_HOST_DEVICE static constexpr WindowOrigin to_global_window_origin(index_t /*block_index*/, const WindowOrigin& local_window_origin) { return local_window_origin; @@ -83,6 +92,21 @@ struct PageBlockNavigator { } + template + CK_TILE_HOST_DEVICE auto make_tile_window(const TensorView& tensor_view, + const WindowLengths& window_lengths, + const WindowOrigin& window_origin) + { + const index_t block_index = get_block_index(window_origin); + const WindowOrigin local_window_origin = to_local_window_origin(window_origin); + + auto new_tile_window = + ck_tile::make_tile_window(tensor_view, window_lengths, local_window_origin); + new_tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(block_index)); + + return make_tuple(block_index, new_tile_window); + } + template CK_TILE_HOST_DEVICE auto make_tile_window(const tile_window_with_static_lengths& tile_window, diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp index e761da57de..e835fcf848 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp @@ -704,29 +704,23 @@ struct FmhaFwdAppendKVKernel make_tuple(number{}, number{}), {i_m0, 0}); - /// FIXME: create tile window directly via PageBlockNavigator const bool skip_append_kv = kargs.seqlen_knew <= i_n0; - auto k_dram_window = - make_tile_window(k_dram, - make_tuple(number{}, number{}), - {skip_append_kv ? 0 : kargs.seqlen_k + i_n0, 0}); - - auto [i_page_block_k, k_dram_window_tmp] = k_page_block_navigator.make_tile_window( - k_dram_window, {skip_append_kv ? 
0 : kargs.seqlen_k + i_n0, 0}); + // window origin = (0, 0) if no work to do for current block + auto [i_page_block_k, k_dram_window] = k_page_block_navigator.make_tile_window( + k_dram, + make_tuple(number{}, number{}), + {!skip_append_kv * (kargs.seqlen_k + i_n0), 0}); auto knew_dram_window = make_tile_window(knew_dram, make_tuple(number{}, number{}), {i_n0, 0}); - /// FIXME: create tile window directly via PageBlockNavigator - auto v_dram_window = - make_tile_window(v_dram, - make_tuple(number{}, number{}), - {0, skip_append_kv ? 0 : kargs.seqlen_k + i_n0}); - - auto [i_page_block_v, v_dram_window_tmp] = v_page_block_navigator.make_tile_window( - v_dram_window, {0, skip_append_kv ? 0 : kargs.seqlen_k + i_n0}); + // window origin = (0, 0) if no work to do for current block + auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window( + v_dram, + make_tuple(number{}, number{}), + {0, !skip_append_kv * (kargs.seqlen_k + i_n0)}); auto vnew_dram_window = make_tile_window(vnew_dram, @@ -736,11 +730,11 @@ struct FmhaFwdAppendKVKernel if constexpr(kApplyRoPE) { FmhaPipeline{}(q_dram_window, - k_dram_window_tmp, + k_dram_window, i_page_block_k, k_page_block_navigator, knew_dram_window, - v_dram_window_tmp, + v_dram_window, i_page_block_v, v_page_block_navigator, vnew_dram_window, @@ -750,16 +744,16 @@ struct FmhaFwdAppendKVKernel knew_rotary_sin_dram_window, kargs.rotary_dim, kargs.seqlen_q <= i_m0, - kargs.seqlen_knew <= i_n0); + skip_append_kv); } else { FmhaPipeline{}(q_dram_window, - k_dram_window_tmp, + k_dram_window, i_page_block_k, k_page_block_navigator, knew_dram_window, - v_dram_window_tmp, + v_dram_window, i_page_block_v, v_page_block_navigator, vnew_dram_window, @@ -769,7 +763,7 @@ struct FmhaFwdAppendKVKernel knew_rotary_sin_dram_window, 0, // rotary_dim not used kargs.seqlen_q <= i_m0, - kargs.seqlen_knew <= i_n0); + skip_append_kv); } } };