From 73378ff95da3b64d0d8525cdfc0191abfada7870 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Tue, 20 Aug 2024 23:58:47 +0000 Subject: [PATCH] Fix wrong knew/vew appending logic on host --- example/ck_tile/01_fmha/fmha_fwd.cpp | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index d917ee3293..889e969314 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -1198,12 +1198,8 @@ bool run(const ck_tile::ArgParser& arg_parser) real_knew_host_ref = &knew_host_ref_ro.value(); } - const std::size_t knew_start = real_seqlen_k - seqlen_knew; - k_host_ref.ForEach([&](auto& self, auto i) { - if(knew_start <= i[1]) - { - self(i) = (*real_knew_host_ref)(i[0], i[1] - knew_start, i[2]); - } + (*real_knew_host_ref).ForEach([&](auto& self, auto i) { + k_host_ref(i[0], i[1] + cache_seqlen_ks[wb], i[2]) = self(i); }); } #endif @@ -1264,12 +1260,8 @@ bool run(const ck_tile::ArgParser& arg_parser) else vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[1], i[0] / nr, i[2]); }); } - const std::size_t knew_start = real_seqlen_k - seqlen_knew; - v_host_ref.ForEach([&](auto& self, auto i) { - if(knew_start <= i[2]) - { - self(i) = vnew_host_ref(i[0], i[1], i[2] - knew_start); - } + vnew_host_ref.ForEach([&](auto& self, auto i) { + v_host_ref(i[0], i[1], i[2] + cache_seqlen_ks[wb]) = self(i); }); } #endif