mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Fix wrong index into knew_host/vnew_host (use `wb` instead of `b` as the first index)
This commit is contained in:
@@ -1106,7 +1106,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::reference_batched_rotary_position_embedding<RoPEComputeDataType>(
|
||||
q_host_ref, rotary_cos_host, rotary_sin_host, is_rotary_interleaved, q_host_ref_ro);
|
||||
|
||||
#if 0
|
||||
q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host_ref_ro(i); });
|
||||
#endif
|
||||
}
|
||||
#if 0
|
||||
HOST_DEBUG_STMTS {
|
||||
@@ -1136,8 +1138,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
if(0 < seqlen_knew)
|
||||
{
|
||||
ck_tile::HostTensor<KDataType> knew_host_ref({nhead, seqlen_knew, hdim_q});
|
||||
if(i_perm) knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(b, i[0] / nr, i[1], i[2]); });
|
||||
else knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(b, i[1], i[0] / nr, i[2]); });
|
||||
if(i_perm) knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(wb, i[0] / nr, i[1], i[2]); });
|
||||
else knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(wb, i[1], i[0] / nr, i[2]); });
|
||||
|
||||
// optionally apply RoPE to the knew_host_ref
|
||||
auto* real_knew_host_ref = &knew_host_ref;
|
||||
@@ -1222,13 +1224,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::HostTensor<VDataType> vnew_host_ref({nhead, hdim_v, seqlen_knew});
|
||||
if(is_v_rowmajor)
|
||||
{
|
||||
if(i_perm) vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(b, i[0] / nr, i[2], i[1]); });
|
||||
else vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(b, i[2], i[0] / nr, i[1]); });
|
||||
if(i_perm) vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[0] / nr, i[2], i[1]); });
|
||||
else vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[2], i[0] / nr, i[1]); });
|
||||
}
|
||||
else
|
||||
{
|
||||
if(i_perm) vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(b, i[0] / nr, i[1], i[2]); });
|
||||
else vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(b, i[1], i[0] / nr, i[2]); });
|
||||
if(i_perm) vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[0] / nr, i[1], i[2]); });
|
||||
else vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[1], i[0] / nr, i[2]); });
|
||||
}
|
||||
|
||||
const std::size_t knew_start = real_seqlen_k - seqlen_knew;
|
||||
|
||||
Reference in New Issue
Block a user