diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 0342ea2d67..c012ba8823 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -1029,25 +1029,17 @@ bool run(const ck_tile::ArgParser& arg_parser) // append (override) Knew to the end of K if(0 < seqlen_knew) { + ck_tile::HostTensor knew_host_ref({nhead, real_seqlen_k, hdim_q}); + if(i_perm) knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(b, i[0] / nr, i[1], i[2]); }); + else knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(b, i[1], i[0] / nr, i[2]); }); + const std::size_t knew_start = real_seqlen_k - seqlen_knew; - if(i_perm) - { - k_host_ref.ForEach([&](auto& self, auto i) { - if(knew_start <= i[1]) - { - self(i) = knew_host(b, i[0], i[1] - knew_start, i[2]); - } - }); - } - else - { - k_host_ref.ForEach([&](auto& self, auto i) { - if(knew_start <= i[1]) - { - self(i) = knew_host(b, i[1] - knew_start, i[0], i[2]); - } - }); - } + k_host_ref.ForEach([&](auto& self, auto i) { + if(knew_start <= i[1]) + { + self(i) = knew_host_ref(i[0], i[1] - knew_start, i[2]); + } + }); } if (is_v_rowmajor) {