diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 5ac7a149d8..fc42506ea7 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -1102,8 +1102,11 @@ bool run(const ck_tile::ArgParser& arg_parser) { decltype(q_host_ref) q_host_ref_ro(q_host_ref.get_lengths()); + auto [rotary_cos_slice, rotary_sin_slice] = + slice_rotary_cos_sin(rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], real_seqlen_q); + ck_tile::reference_batched_rotary_position_embedding( - q_host_ref, rotary_cos_host, rotary_sin_host, is_rotary_interleaved, q_host_ref_ro); + q_host_ref, rotary_cos_slice, rotary_sin_slice, is_rotary_interleaved, q_host_ref_ro); q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host_ref_ro(i); }); } @@ -1165,10 +1168,13 @@ bool run(const ck_tile::ArgParser& arg_parser) { knew_host_ref_ro.emplace(knew_host_ref.get_lengths()); + auto [rotary_cos_slice, rotary_sin_slice] = + slice_rotary_cos_sin(rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], seqlen_knew); + ck_tile::reference_batched_rotary_position_embedding( knew_host_ref, - rotary_cos_host, - rotary_sin_host, + rotary_cos_slice, + rotary_sin_slice, is_rotary_interleaved, knew_host_ref_ro.value()); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp index 73991ce85d..a0b5d69e5a 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp @@ -494,7 +494,7 @@ struct FmhaFwdAppendKVKernel const auto rotary_cos_dram_native = make_naive_tensor_view( reinterpret_cast(kargs.rotary_cos_ptr), - make_tuple(kargs.seqlen_q, kargs.rotary_dim / 2), + make_tuple(kargs.seqlen_k + kargs.seqlen_q, kargs.rotary_dim / 2), make_tuple(kargs.rotary_dim / 2, 1), number<8>{}, number<1>{}); @@ -505,8 +505,9 @@ struct FmhaFwdAppendKVKernel sequence{}); }(); - return make_tile_window( - rotary_cos_dram, q_rotary_cos_sin_dram_window_lengths, {i_m0, 0}); + return make_tile_window(rotary_cos_dram, + q_rotary_cos_sin_dram_window_lengths, + {kargs.seqlen_k + i_m0, 0}); } else { @@ -519,7 +520,7 @@ struct FmhaFwdAppendKVKernel const auto rotary_sin_dram_native = make_naive_tensor_view( reinterpret_cast(kargs.rotary_sin_ptr), - make_tuple(kargs.seqlen_q, kargs.rotary_dim / 2), + make_tuple(kargs.seqlen_k + kargs.seqlen_q, kargs.rotary_dim / 2), make_tuple(kargs.rotary_dim / 2, 1), number<8>{}, number<1>{}); @@ -530,8 +531,9 @@ struct FmhaFwdAppendKVKernel sequence{}); }(); - return make_tile_window( - rotary_sin_dram, q_rotary_cos_sin_dram_window_lengths, {i_m0, 0}); + return make_tile_window(rotary_sin_dram, + q_rotary_cos_sin_dram_window_lengths, + {kargs.seqlen_k + i_m0, 0}); } else { @@ -558,8 +560,9 @@ struct FmhaFwdAppendKVKernel sequence{}); }(); - return make_tile_window( - rotary_cos_dram, knew_rotary_cos_sin_dram_window_lengths, {i_n0, 0}); + return make_tile_window(rotary_cos_dram, + knew_rotary_cos_sin_dram_window_lengths, + {kargs.seqlen_k + i_n0, 0}); } else { @@ -583,8 +586,9 @@ struct FmhaFwdAppendKVKernel sequence{}); }(); - return make_tile_window( - rotary_sin_dram, knew_rotary_cos_sin_dram_window_lengths, {i_n0, 0}); + return make_tile_window(rotary_sin_dram, + knew_rotary_cos_sin_dram_window_lengths, + {kargs.seqlen_k + i_n0, 0}); } else {