diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp index 4aa10094fa..e761da57de 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp @@ -738,17 +738,17 @@ struct FmhaFwdAppendKVKernel FmhaPipeline{}(q_dram_window, k_dram_window_tmp, i_page_block_k, + k_page_block_navigator, knew_dram_window, v_dram_window_tmp, i_page_block_v, + v_page_block_navigator, vnew_dram_window, q_rotary_cos_dram_window, q_rotary_sin_dram_window, knew_rotary_cos_dram_window, knew_rotary_sin_dram_window, kargs.rotary_dim, - k_page_block_navigator, - v_page_block_navigator, kargs.seqlen_q <= i_m0, kargs.seqlen_knew <= i_n0); } @@ -757,17 +757,17 @@ struct FmhaFwdAppendKVKernel FmhaPipeline{}(q_dram_window, k_dram_window_tmp, i_page_block_k, + k_page_block_navigator, knew_dram_window, v_dram_window_tmp, i_page_block_v, + v_page_block_navigator, vnew_dram_window, q_rotary_cos_dram_window, q_rotary_sin_dram_window, knew_rotary_cos_dram_window, knew_rotary_sin_dram_window, 0, // rotary_dim not used - k_page_block_navigator, - v_page_block_navigator, kargs.seqlen_q <= i_m0, kargs.seqlen_knew <= i_n0); } diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index a29e61d540..88c9ccd853 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -829,8 +829,10 @@ struct FmhaFwdSplitKVKernel return FmhaPipeline{}(q_dram_window, identity{}, // q_element_func k_dram_window, + k_page_block_navigator, identity{}, // k_element_func v_dram_window, + v_page_block_navigator, identity{}, // v_element_func bias_dram_window, identity{}, // bias_element_func @@ -844,15 +846,15 @@ struct FmhaFwdSplitKVKernel mask, position_encoding, kargs.scale_s, - smem_ptr, - k_page_block_navigator, - v_page_block_navigator); + smem_ptr); } else { return FmhaPipeline{}(q_dram_window, k_dram_window, + k_page_block_navigator, v_dram_window, + v_page_block_navigator, bias_dram_window, lse_acc_dram_window, kargs.num_splits, @@ -860,9 +862,7 @@ struct FmhaFwdSplitKVKernel mask, position_encoding, kargs.scale_s, - smem_ptr, - k_page_block_navigator, - v_page_block_navigator); + smem_ptr); } }(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp index 7734a10f2e..734abefe63 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp @@ -74,8 +74,10 @@ struct BlockFmhaFwdAppendKVPipeline template + typename KnewRotarySinDramBlockWindow> CK_TILE_HOST_DEVICE auto operator()(QDramBlockWindow& q_dram_block_window, // M0*K0 tile const QElementFunction& q_element_func, KDramBlockWindow& k_dram_block_window, // N0*K0 tile index_t i_page_block_k, + const KPageBlockNavigator& k_page_block_navigator, const KnewDramBlockWindow& knew_dram_block_window, // N0*K0 tile const KnewElementFunction& knew_element_func, VDramBlockWindow& v_dram_block_window, // N1*N0 tile index_t i_page_block_v, + const VPageBlockNavigator& v_page_block_navigator, const VnewDramBlockWindow& vnew_dram_block_window, // N1*N0 tile const VnewElementFunction& vnew_element_func, const QRotaryCosDramBlockWindow q_rotary_cos_dram_block_window, @@ -102,12 +104,10 @@ struct BlockFmhaFwdAppendKVPipeline const KnewRotaryCosDramBlockWindow knew_rotary_cos_dram_block_window, const KnewRotarySinDramBlockWindow knew_rotary_sin_dram_block_window, index_t rotary_dim, - const KPageBlockNavigator& k_page_block_navigator, - const VPageBlockNavigator& v_page_block_navigator, - bool skip_transform_q, - bool skip_append_kv) const + bool skip_rotate_q, + bool skip_rotate_append_kv) const { - if(!skip_append_kv) + if(!skip_rotate_append_kv) { // append Knew to K auto knew_window = make_tile_window( @@ -190,7 +190,7 @@ struct BlockFmhaFwdAppendKVPipeline } } - if(!skip_transform_q) + if(!skip_rotate_q) { // optionally apply rotary embedding to Q if constexpr(RotaryEnum != RotaryEmbeddingEnum::NONE) @@ -231,41 +231,43 @@ struct BlockFmhaFwdAppendKVPipeline template + typename KnewRotarySinDramBlockWindow> CK_TILE_HOST_DEVICE auto operator()(QDramBlockWindow& q_dram_block_window, KDramBlockWindow& k_dram_block_window, index_t i_page_block_k, + const KPageBlockNavigator& k_page_block_navigator, const KnewDramBlockWindow& knew_dram_block_window, VDramBlockWindow& v_dram_block_window, index_t i_page_block_v, + const VPageBlockNavigator& v_page_block_navigator, const VnewDramBlockWindow& vnew_dram_block_window, const QRotaryCosDramBlockWindow& q_rotary_cos_dram_block_window, const QRotarySinDramBlockWindow& q_rotary_sin_dram_block_window, const KnewRotaryCosDramBlockWindow& knew_rotary_cos_dram_block_window, const KnewRotarySinDramBlockWindow& knew_rotary_sin_dram_block_window, index_t rotary_dim, - const KPageBlockNavigator& k_page_block_navigator, - const VPageBlockNavigator& v_page_block_navigator, - bool skip_transform_q, - bool skip_append_kv) const + bool skip_rotate_q, + bool skip_rotate_append_kv) const { return operator()(q_dram_block_window, identity{}, k_dram_block_window, i_page_block_k, + k_page_block_navigator, knew_dram_block_window, identity{}, v_dram_block_window, i_page_block_v, + v_page_block_navigator, vnew_dram_block_window, identity{}, q_rotary_cos_dram_block_window, @@ -273,10 +275,8 @@ struct BlockFmhaFwdAppendKVPipeline knew_rotary_cos_dram_block_window, knew_rotary_sin_dram_block_window, rotary_dim, - k_page_block_navigator, - v_page_block_navigator, - skip_transform_q, - skip_append_kv); + skip_rotate_q, + skip_rotate_append_kv); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp index cf37d3ed1e..2af54994a2 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp @@ -105,7 +105,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS template + typename PositionEncoding> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const QElementFunction& q_element_func, const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const KPageBlockNavigator& k_page_block_navigator, const KElementFunction& k_element_func, const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const VPageBlockNavigator& v_page_block_navigator, const VElementFunction& v_element_func, const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasElementFunction& bias_element_func, @@ -138,9 +140,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS FmhaMask mask, PositionEncoding position_encoding, float scale_s, - void* smem_ptr, - const KPageBlockNavigator& k_page_block_navigator, - const VPageBlockNavigator& v_page_block_navigator) const + void* smem_ptr) const { static_assert( std::is_same_v> && @@ -634,16 +634,18 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS template + typename PositionEncoding> CK_TILE_HOST_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const KPageBlockNavigator& k_page_block_navigator, + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const VPageBlockNavigator& v_page_block_navigator, const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0*1 tile index_t num_splits, @@ -651,15 +653,15 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS FmhaMask mask, PositionEncoding position_encoding, float scale_s, - void* smem_ptr, - const KPageBlockNavigator& k_page_block_navigator, - const VPageBlockNavigator& v_page_block_navigator) const + void* smem_ptr) const { return operator()(q_dram_block_window_tmp, identity{}, k_dram_block_window_tmp, + k_page_block_navigator, identity{}, v_dram_block_window_tmp, + v_page_block_navigator, identity{}, bias_dram_block_window_tmp, identity{}, @@ -673,9 +675,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS mask, position_encoding, scale_s, - smem_ptr, - k_page_block_navigator, - v_page_block_navigator); + smem_ptr); } };