mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 06:01:23 +00:00
merge from dteng_flatmm_opt
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -340,6 +340,37 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeADramDistribution()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
// constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K1 = 16 / sizeof(ADataType);
|
||||
constexpr index_t K0 = KPerBlock / K1;
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
constexpr index_t M1 = BlockSize / get_warp_size();
|
||||
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
|
||||
// constexpr index_t M0 = MPerBlock / (M2 * M1);
|
||||
// static_assert(M0 * M1 * M2 == MPerBlock,
|
||||
// "Incorrect M0, M2, M1 configuration! "
|
||||
// "M0, M1, M2 must cover whole MPerBlock!");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user