From 85bac939518f26b74e57a7291968cc1d3c8bda8a Mon Sep 17 00:00:00 2001
From: "PoYen, Chen"
Date: Tue, 23 Jul 2024 15:31:15 +0000
Subject: [PATCH] Fix wrong index into knew_host/vnew_host

---
 example/ck_tile/01_fmha/fmha_fwd.cpp | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp
index a0c55b3bdd..f2e966597c 100644
--- a/example/ck_tile/01_fmha/fmha_fwd.cpp
+++ b/example/ck_tile/01_fmha/fmha_fwd.cpp
@@ -1106,7 +1106,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
                 ck_tile::reference_batched_rotary_position_embedding(
                     q_host_ref, rotary_cos_host, rotary_sin_host, is_rotary_interleaved, q_host_ref_ro);
 
+                #if 0
                 q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host_ref_ro(i); });
+                #endif
             }
 #if 0
             HOST_DEBUG_STMTS {
@@ -1136,8 +1138,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
             if(0 < seqlen_knew)
             {
                 ck_tile::HostTensor knew_host_ref({nhead, seqlen_knew, hdim_q});
-                if(i_perm) knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(b, i[0] / nr, i[1], i[2]); });
-                else       knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(b, i[1], i[0] / nr, i[2]); });
+                if(i_perm) knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(wb, i[0] / nr, i[1], i[2]); });
+                else       knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(wb, i[1], i[0] / nr, i[2]); });
 
                 // optionally apply RoPE to the knew_host_ref
                 auto* real_knew_host_ref = &knew_host_ref;
@@ -1222,13 +1224,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
                 ck_tile::HostTensor vnew_host_ref({nhead, hdim_v, seqlen_knew});
                 if(is_v_rowmajor)
                 {
-                    if(i_perm) vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(b, i[0] / nr, i[2], i[1]); });
-                    else       vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(b, i[2], i[0] / nr, i[1]); });
+                    if(i_perm) vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[0] / nr, i[2], i[1]); });
+                    else       vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[2], i[0] / nr, i[1]); });
                 }
                 else
                 {
-                    if(i_perm) vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(b, i[0] / nr, i[1], i[2]); });
-                    else       vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(b, i[1], i[0] / nr, i[2]); });
+                    if(i_perm) vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[0] / nr, i[1], i[2]); });
+                    else       vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[1], i[0] / nr, i[2]); });
                 }
 
                 const std::size_t knew_start = real_seqlen_k - seqlen_knew;