Fix wrong index into knew_host/vnew_host

This commit is contained in:
PoYen, Chen
2024-07-23 15:31:15 +00:00
parent b11f92dc4c
commit 85bac93951

View File

@@ -1106,7 +1106,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::reference_batched_rotary_position_embedding<RoPEComputeDataType>(
q_host_ref, rotary_cos_host, rotary_sin_host, is_rotary_interleaved, q_host_ref_ro);
#if 0
q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host_ref_ro(i); });
#endif
}
#if 0
HOST_DEBUG_STMTS {
@@ -1136,8 +1138,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(0 < seqlen_knew)
{
ck_tile::HostTensor<KDataType> knew_host_ref({nhead, seqlen_knew, hdim_q});
if(i_perm) knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(b, i[0] / nr, i[1], i[2]); });
else knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(b, i[1], i[0] / nr, i[2]); });
if(i_perm) knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(wb, i[0] / nr, i[1], i[2]); });
else knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(wb, i[1], i[0] / nr, i[2]); });
// optionally apply RoPE to the knew_host_ref
auto* real_knew_host_ref = &knew_host_ref;
@@ -1222,13 +1224,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<VDataType> vnew_host_ref({nhead, hdim_v, seqlen_knew});
if(is_v_rowmajor)
{
if(i_perm) vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(b, i[0] / nr, i[2], i[1]); });
else vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(b, i[2], i[0] / nr, i[1]); });
if(i_perm) vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[0] / nr, i[2], i[1]); });
else vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[2], i[0] / nr, i[1]); });
}
else
{
if(i_perm) vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(b, i[0] / nr, i[1], i[2]); });
else vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(b, i[1], i[0] / nr, i[2]); });
if(i_perm) vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[0] / nr, i[1], i[2]); });
else vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[1], i[0] / nr, i[2]); });
}
const std::size_t knew_start = real_seqlen_k - seqlen_knew;