mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Fix wrong tensor size
This commit is contained in:
@@ -383,9 +383,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
|
||||
const ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim");
|
||||
if(!(rotary_dim < hdim_q) || !(rotary_dim < hdim_v))
|
||||
if(!(rotary_dim <= hdim_q))
|
||||
{
|
||||
std::cerr << "rotary_dim should be less than head dim for q/q" << std::endl;
|
||||
std::cerr << "rotary_dim should be less than or equal to head dim for q" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -1053,7 +1053,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
// copy Knew to the end of K
|
||||
if(0 < seqlen_knew)
|
||||
{
|
||||
ck_tile::HostTensor<KDataType> knew_host_ref({nhead, real_seqlen_k, hdim_q});
|
||||
ck_tile::HostTensor<KDataType> knew_host_ref({nhead, seqlen_knew, hdim_q});
|
||||
if(i_perm) knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(b, i[0] / nr, i[1], i[2]); });
|
||||
else knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(b, i[1], i[0] / nr, i[2]); });
|
||||
|
||||
|
||||
Reference in New Issue
Block a user