This commit is contained in:
Feng Shijie
2025-08-11 11:24:34 +00:00
parent 200a11afc8
commit edb58d0680
7 changed files with 112 additions and 50 deletions

View File

@@ -41,7 +41,7 @@ struct A16W4_FlatmmConfig16
struct A16W4_FlatmmConfig16_950 : public A16W4_FlatmmConfig16
{
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr int kBlockPerCu = 1;
static constexpr int N_Repeat =

View File

@@ -372,15 +372,17 @@ auto preShuffleScale(const ck_tile::HostTensor<T>& scale)
static_assert(sizeof(T) * K_Pack * FlatmmConfig::N_Repeat <= 16, "inefficient pack policy");
ck_tile::HostTensor<T> shfl_scale({
n_ / FlatmmConfig::N_Repeat / FlatmmConfig::N_Warp_Tile,
FlatmmConfig::N_Repeat,
FlatmmConfig::N_Warp_Tile,
k_ / K_Pack / K_Lane,
K_Pack,
K_Lane,
n_ / FlatmmConfig::N_Tile,
FlatmmConfig::N_Repeat,
FlatmmConfig::N_Warp,
FlatmmConfig::N_Warp_Tile,
});
std::copy(scale.begin(), scale.end(), shfl_scale.begin());
return ck_tile::reference_permute(shfl_scale, {0, 3, 5, 2, 4, 1});
// return ck_tile::reference_permute(shfl_scale, {0, 3, 5, 2, 4, 1});
return ck_tile::reference_permute(shfl_scale, {3, 5, 0, 2, 6, 1, 4});
}
#include "run_mixed_prec_flatmm.inc"