Avoid using too small rotary_cos & rotary_sin

This commit is contained in:
PoYen, Chen
2024-08-18 18:27:37 +00:00
parent e5db71cc59
commit 4cd3432361

View File

@@ -603,8 +603,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
: std::array<ck_tile::index_t, 2>{batch, nhead})
: std::array<ck_tile::index_t, 2>{1, 1});
auto [rotary_cos_host, rotary_sin_host] =
generate_rotary_cos_sin<KDataType>(shape_seqlen_k, rotary_dim, seed);
auto [rotary_cos_host, rotary_sin_host] = generate_rotary_cos_sin<KDataType>(
std::max(shape_seqlen_q, shape_seqlen_k), rotary_dim, seed);
ck_tile::HostTensor<LSEDataType> lse_acc_host(
1 < num_splits || use_kvcache