diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index d917ee3293..889e969314 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -1198,12 +1198,8 @@ bool run(const ck_tile::ArgParser& arg_parser) real_knew_host_ref = &knew_host_ref_ro.value(); } - const std::size_t knew_start = real_seqlen_k - seqlen_knew; - k_host_ref.ForEach([&](auto& self, auto i) { - if(knew_start <= i[1]) - { - self(i) = (*real_knew_host_ref)(i[0], i[1] - knew_start, i[2]); - } + (*real_knew_host_ref).ForEach([&](auto& self, auto i) { + k_host_ref(i[0], i[1] + cache_seqlen_ks[wb], i[2]) = self(i); }); } #endif @@ -1264,12 +1260,8 @@ bool run(const ck_tile::ArgParser& arg_parser) else vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[1], i[0] / nr, i[2]); }); } - const std::size_t knew_start = real_seqlen_k - seqlen_knew; - v_host_ref.ForEach([&](auto& self, auto i) { - if(knew_start <= i[2]) - { - self(i) = vnew_host_ref(i[0], i[1], i[2] - knew_start); - } + vnew_host_ref.ForEach([&](auto& self, auto i) { + v_host_ref(i[0], i[1], i[2] + cache_seqlen_ks[wb]) = self(i); }); } #endif