Fix wrong tensor size

This commit is contained in:
PoYen, Chen
2024-07-14 15:40:56 +00:00
parent 93e5125d7a
commit 55f55025ee

View File

@@ -383,9 +383,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
const ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim");
if(!(rotary_dim < hdim_q) || !(rotary_dim < hdim_v))
if(!(rotary_dim <= hdim_q))
{
std::cerr << "rotary_dim should be less than head dim for q/q" << std::endl;
std::cerr << "rotary_dim should be less than or equal to head dim for q" << std::endl;
return false;
}
@@ -1053,7 +1053,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
// copy Knew to the end of K
if(0 < seqlen_knew)
{
ck_tile::HostTensor<KDataType> knew_host_ref({nhead, real_seqlen_k, hdim_q});
ck_tile::HostTensor<KDataType> knew_host_ref({nhead, seqlen_knew, hdim_q});
if(i_perm) knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(b, i[0] / nr, i[1], i[2]); });
else knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(b, i[1], i[0] / nr, i[2]); });