Merge commit '6b1bceca7baea62941793e562d6ff58c571d9191' into develop

This commit is contained in:
assistant-librarian[bot]
2025-12-05 18:14:37 +00:00
parent e4b2f98d0d
commit b2019db495
7 changed files with 257 additions and 36 deletions

View File

@@ -14,6 +14,7 @@ auto shuffle_aq(const ck_tile::HostTensor<T>* t, int block_aq_k)
}
int m_ = t->get_lengths()[0];
int aqk_ = t->get_lengths()[1];
if(aqk_ % block_aq_k != 0)
{
throw std::runtime_error("shuffle_aq needs a aqk of multiple times of block_aq_k.");
@@ -110,7 +111,7 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
}
template <typename GemmConfig, typename T>
auto bq_permuteN(const ck_tile::HostTensor<T>& t)
auto bq_permuteN(const ck_tile::HostTensor<T>& t, index_t group_n)
{
assert(t.get_lengths().size() == 2);
@@ -118,8 +119,11 @@ auto bq_permuteN(const ck_tile::HostTensor<T>& t)
int bqk_ = t.get_lengths()[0];
constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
ck_tile::HostTensor<T> t_view(
{n_ / GemmConfig::N_Tile, GemmConfig::N_Warp, GemmConfig::N_Warp_Tile, NRepeat, bqk_});
ck_tile::HostTensor<T> t_view({n_ / (GemmConfig::N_Tile / group_n),
GemmConfig::N_Warp,
GemmConfig::N_Warp_Tile / group_n,
NRepeat,
bqk_});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 3, 1, 2, 4});
}