diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp index a33270e311..e35587b71d 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp @@ -514,8 +514,9 @@ struct FmhaFwdAppendKVKernel sequence{}); }(); + /// TODO: use tile idx for q return make_tile_window( - rotary_cos_dram, q_rotary_cos_sin_dram_window_lengths, {0, 0}); + rotary_cos_dram, q_rotary_cos_sin_dram_window_lengths, {i_sk, 0}); } else { @@ -539,8 +540,9 @@ struct FmhaFwdAppendKVKernel sequence{}); }(); + /// TODO: use tile idx for q return make_tile_window( - rotary_sin_dram, q_rotary_cos_sin_dram_window_lengths, {0, 0}); + rotary_sin_dram, q_rotary_cos_sin_dram_window_lengths, {i_sk, 0}); } else { @@ -568,7 +570,7 @@ struct FmhaFwdAppendKVKernel }(); return make_tile_window( - rotary_cos_dram, knew_rotary_cos_sin_dram_window_lengths, {0, 0}); + rotary_cos_dram, knew_rotary_cos_sin_dram_window_lengths, {i_sk, 0}); } else { @@ -593,7 +595,7 @@ struct FmhaFwdAppendKVKernel }(); return make_tile_window( - rotary_sin_dram, knew_rotary_cos_sin_dram_window_lengths, {0, 0}); + rotary_sin_dram, knew_rotary_cos_sin_dram_window_lengths, {i_sk, 0}); } else { @@ -601,6 +603,7 @@ struct FmhaFwdAppendKVKernel } }(); + /// TODO: use tile idx for q auto q_dram_window = make_tile_window( q_dram, make_tuple(number{}, number{}), @@ -609,7 +612,7 @@ struct FmhaFwdAppendKVKernel auto k_dram_window = make_tile_window( k_dram, make_tuple(number{}, number{}), - {kargs.seqlen_k, 0}); + {kargs.seqlen_k + i_sk, 0}); auto knew_dram_window = make_tile_window( knew_dram, @@ -619,7 +622,7 @@ struct FmhaFwdAppendKVKernel auto v_dram_window = make_tile_window( v_dram, make_tuple(number{}, number{}), - {0, kargs.seqlen_k}); + {0, kargs.seqlen_k + i_sk}); auto vnew_dram_window = make_tile_window( vnew_dram,