mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
change g tile distribution
This commit is contained in:
@@ -126,7 +126,7 @@ struct FusedMoeGemmPipeline_General
|
||||
Policy::template MakeGlobalTileDistribution_G<Problem>());
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetBlockGemm0<Problem>();
|
||||
constexpr auto gemm_0 = Policy::template GetBlockGemm0<Problem>();
|
||||
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
|
||||
auto s_acc = SaccBlockTileType{};
|
||||
|
||||
@@ -138,7 +138,6 @@ struct FusedMoeGemmPipeline_General
|
||||
ignore = s_acc;
|
||||
|
||||
store_tile(o_window_, a_dram_block);
|
||||
|
||||
#if 0
|
||||
// check whether the A matrix gather is correct
|
||||
constexpr auto a_spans = decltype(a_dram_block)::get_distributed_spans();
|
||||
|
||||
@@ -17,7 +17,8 @@ namespace ck_tile {
|
||||
|
||||
struct FusedMoeGemmPipelineGeneralPolicy
|
||||
{
|
||||
static constexpr int kKIter = 2;
|
||||
static constexpr int kKIter = 2;
|
||||
static constexpr int kKPerBlock = 32;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetAsyncCopyDwords()
|
||||
{
|
||||
@@ -197,14 +198,18 @@ struct FusedMoeGemmPipelineGeneralPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_G()
|
||||
{
|
||||
using S_ = typename Problem::BlockShape;
|
||||
using S_ = typename Problem::BlockShape;
|
||||
constexpr index_t K2 = S_::Warp_K0;
|
||||
constexpr index_t K1 = get_warp_size() / S_::Warp_N0;
|
||||
constexpr index_t K0 = kKPerBlock / (K1 * K2);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<1>,
|
||||
tuple<sequence<S_::Repeat_N0, S_::WarpPerBlock_N0, S_::Warp_N0>,
|
||||
sequence<kKIter, get_warp_size() / S_::Warp_N0, S_::Warp_K0>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
sequence<K0, K1, K2>>,
|
||||
tuple<sequence<1>, sequence<2, 1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
sequence<1, 2, 2>,
|
||||
sequence<0, 0, 2>>{});
|
||||
}
|
||||
@@ -212,23 +217,21 @@ struct FusedMoeGemmPipelineGeneralPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm0()
|
||||
{
|
||||
using S_ = typename Problem::BlockShape;
|
||||
using GemmProblem =
|
||||
BlockGemmProblem<typename Problem::ADataType,
|
||||
typename Problem::GDataType,
|
||||
typename Problem::AccDataType,
|
||||
S_::BlockSize,
|
||||
TileGemmShape<typename S_::BlockTile_0,
|
||||
typename S_::WarpPerBlock_0,
|
||||
typename S_::WarpTile_0>>;
|
||||
using S_ = typename Problem::BlockShape;
|
||||
using GemmProblem = BlockGemmProblem<typename Problem::ADataType,
|
||||
typename Problem::GDataType,
|
||||
typename Problem::AccDataType,
|
||||
S_::BlockSize,
|
||||
TileGemmShape<typename S_::BlockTile_0,
|
||||
typename S_::WarpPerBlock_0,
|
||||
typename S_::WarpTile_0>>;
|
||||
|
||||
constexpr auto warp_gemm = GetWarpGemm0<Problem>();
|
||||
using BlockGemmPolicy =
|
||||
BlockGemmASmemBRegCRegV1CustomPolicy<typename Problem::ADataType,
|
||||
typename Problem::GDataType,
|
||||
typename Problem::AccDataType,
|
||||
typename S_::WarpPerBlock_0,
|
||||
decltype(warp_gemm)>;
|
||||
using BlockGemmPolicy = BlockGemmASmemBRegCRegV1CustomPolicy<typename Problem::ADataType,
|
||||
typename Problem::GDataType,
|
||||
typename Problem::AccDataType,
|
||||
typename S_::WarpPerBlock_0,
|
||||
decltype(warp_gemm)>;
|
||||
|
||||
return BlockGemmASmemBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user