mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
update
This commit is contained in:
@@ -41,7 +41,7 @@ struct A16W4_FlatmmConfig16
|
||||
|
||||
struct A16W4_FlatmmConfig16_950 : public A16W4_FlatmmConfig16
|
||||
{
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
|
||||
static constexpr int N_Repeat =
|
||||
|
||||
@@ -372,15 +372,17 @@ auto preShuffleScale(const ck_tile::HostTensor<T>& scale)
|
||||
static_assert(sizeof(T) * K_Pack * FlatmmConfig::N_Repeat <= 16, "inefficient pack policy");
|
||||
|
||||
ck_tile::HostTensor<T> shfl_scale({
|
||||
n_ / FlatmmConfig::N_Repeat / FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::N_Repeat,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
k_ / K_Pack / K_Lane,
|
||||
K_Pack,
|
||||
K_Lane,
|
||||
n_ / FlatmmConfig::N_Tile,
|
||||
FlatmmConfig::N_Repeat,
|
||||
FlatmmConfig::N_Warp,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
});
|
||||
std::copy(scale.begin(), scale.end(), shfl_scale.begin());
|
||||
return ck_tile::reference_permute(shfl_scale, {0, 3, 5, 2, 4, 1});
|
||||
// return ck_tile::reference_permute(shfl_scale, {0, 3, 5, 2, 4, 1});
|
||||
return ck_tile::reference_permute(shfl_scale, {3, 5, 0, 2, 6, 1, 4});
|
||||
}
|
||||
|
||||
#include "run_mixed_prec_flatmm.inc"
|
||||
|
||||
Reference in New Issue
Block a user