merge from dteng_flatmm_opt

This commit is contained in:
lalala-sh
2025-07-16 10:12:19 +00:00
parent 3499fe67ff
commit fb76450e63
9 changed files with 1324 additions and 267 deletions

View File

@@ -340,6 +340,37 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeADramDistribution()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
constexpr index_t BlockSize = Problem::kBlockSize;
// constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t K1 = 16 / sizeof(ADataType);
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t M2 = get_warp_size() / K0;
constexpr index_t M1 = BlockSize / get_warp_size();
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
// constexpr index_t M0 = MPerBlock / (M2 * M1);
// static_assert(M0 * M1 * M2 == MPerBlock,
// "Incorrect M0, M2, M1 configuration! "
// "M0, M1, M2 must cover whole MPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<2>,
sequence<1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
{