Simplify K appending logic

This commit is contained in:
PoYen, Chen
2024-07-12 06:37:23 +00:00
parent 3578c6f836
commit e5885cab83

View File

@@ -1029,25 +1029,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
// append (override) Knew to the end of K
if(0 < seqlen_knew)
{
ck_tile::HostTensor<KDataType> knew_host_ref({nhead, real_seqlen_k, hdim_q});
if(i_perm) knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(b, i[0] / nr, i[1], i[2]); });
else knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(b, i[1], i[0] / nr, i[2]); });
const std::size_t knew_start = real_seqlen_k - seqlen_knew;
if(i_perm)
{
k_host_ref.ForEach([&](auto& self, auto i) {
if(knew_start <= i[1])
{
self(i) = knew_host(b, i[0], i[1] - knew_start, i[2]);
}
});
}
else
{
k_host_ref.ForEach([&](auto& self, auto i) {
if(knew_start <= i[1])
{
self(i) = knew_host(b, i[1] - knew_start, i[0], i[2]);
}
});
}
k_host_ref.ForEach([&](auto& self, auto i) {
if(knew_start <= i[1])
{
self(i) = knew_host_ref(i[0], i[1] - knew_start, i[2]);
}
});
}
if (is_v_rowmajor) {