diff --git a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp index a26af18a58..b54f319e01 100644 --- a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp @@ -252,10 +252,11 @@ struct MoeFlatmmKernel #else false; #endif - static constexpr int MXFP4M_Pack = MXF8F6F4MFMA ? 1 : 2; - static constexpr int MXFP4N_Pack = MXF8F6F4MFMA ? 1 : 2; - static constexpr int MXFP4K_Pack = MXF8F6F4MFMA ? 4 : 2; + static constexpr int MXFP4M_Pack = 2; + static constexpr int MXFP4N_Pack = 2; + static constexpr int MXFP4K_Pack = 2; + static constexpr int M_Pack = AQUANT_Pipeline ? MXFP4M_Pack : 1; static constexpr int N_Pack = BMXFP4_Pipeline ? MXFP4N_Pack : 1; static constexpr int K_Pack = BMXFP4_Pipeline ? MXFP4K_Pack : 1; @@ -647,12 +648,12 @@ struct MoeFlatmmKernel constexpr int AGranularityK = 32; //TODO: enable e8m0_t scale - using AScaleType = float; //std::conditional_t; - // using AScaleType = e8m0_t; //std::conditional_t; + // using AScaleType = float; //std::conditional_t; + using AScaleType = e8m0_t; //std::conditional_t; const auto& scale_a_tensor_view = [&]() { - // if constexpr(std::is_same_v) - // { + if constexpr(std::is_same_v) + { index_t scale_m = kargs.M; index_t scale_k = AGranularityK == 0 ? 1 : (kargs.K + AGranularityK - 1) / AGranularityK; return make_naive_tensor_view( @@ -661,57 +662,57 @@ struct MoeFlatmmKernel make_tuple(scale_k, 1), number<8>{}, number<1>{}); - // } - // else if constexpr(std::is_same_v) - // { - // constexpr int MThreadPerXdl = BlockGemmShape::WarpTile::at(I0); - // constexpr int KThreadPerXdl = 64 / BlockGemmShape::WarpTile::at(I0); - // index_t scale_m_packs = kargs.M / (MXFP4M_Pack * MThreadPerXdl); - // index_t scale_k_packs = kargs.K / (MXFP4K_Pack * AGranularityK * KThreadPerXdl); - // // Pack 2x2 e8m0 over M/K dimension into 1 int32_t to trigger dword width load - // const auto scale_a_naive_desc = make_naive_tensor_descriptor_packed( - // make_tuple(scale_m_packs, scale_k_packs, KThreadPerXdl, MThreadPerXdl)); - // const auto scale_a_desc = transform_tensor_descriptor( - // scale_a_naive_desc, - // make_tuple(make_merge_transform(make_tuple(scale_m_packs, MThreadPerXdl)), - // make_merge_transform(make_tuple(scale_packs_k, KThreadPerXdl))), - // make_tuple(sequence<0, 3>{}, sequence<1, 2>{}), - // make_tuple(sequence<0>{}, sequence<1>{})); - // return make_tensor_view( - // reinterpret_cast(scale_a.ptr), scale_a_desc); - // } + } + else if constexpr(std::is_same_v) + { + constexpr int MThreadPerXdl = BlockGemmShape::WarpTile::at(I0); + constexpr int KThreadPerXdl = 64 / BlockGemmShape::WarpTile::at(I0); + index_t scale_m_packs = kargs.M / (MXFP4M_Pack * MThreadPerXdl); + index_t scale_k_packs = kargs.K / (MXFP4K_Pack * AGranularityK * KThreadPerXdl); + // Pack 2x2 e8m0 over M/K dimension into 1 int32_t to trigger dword width load + const auto scale_a_naive_desc = make_naive_tensor_descriptor_packed( + make_tuple(scale_m_packs, scale_k_packs, KThreadPerXdl, MThreadPerXdl)); + const auto scale_a_desc = transform_tensor_descriptor( + scale_a_naive_desc, + make_tuple(make_merge_transform(make_tuple(scale_m_packs, MThreadPerXdl)), + make_merge_transform(make_tuple(scale_packs_k, KThreadPerXdl))), + make_tuple(sequence<0, 3>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return make_tensor_view( + reinterpret_cast(scale_a.ptr), scale_a_desc); + } }(); auto scale_n = kargs.scale_n; constexpr int BGranularityK = decltype(scale_n)::GranularityK; const auto scale_b_flat_view = [&]() { - // if constexpr(AQUANT_Pipeline) - // { - // constexpr int NThreadPerXdl = BlockGemmShape::WarpTile::at(I1); - // constexpr int KThreadPerXdl = 64 / BlockGemmShape::WarpTile::at(I1); - // index_t scale_n_packs = kargs.N / (MXFP4N_Pack * NThreadPerXdl); - // index_t scale_k_packs = kargs.K / (MXFP4K_Pack * BGranularityK * KThreadPerXdl); - // const auto scale_b_navie_desc = make_naive_tensor_descriptor_packed( - // make_tuple(scale_packs_n, scale_packs_k, KThreadPerXdl, NThreadPerXdl)); - // const auto scale_b_desc = transform_tensor_descriptor( - // scale_b_navie_desc, - // make_tuple(make_merge_transform(make_tuple(scale_packs_n, NThreadPerXdl)), - // make_merge_transform(make_tuple(scale_packs_k, KThreadPerXdl))), - // make_tuple(sequence<0, 3>{}, sequence<1, 2>{}), - // make_tuple(sequence<0>{}, sequence<1>{})); - // - // return make_tensor_view( - // reinterpret_cast(scale_b.ptr), scale_b_desc); - // - // } - // else - // { + if constexpr(AQUANT_Pipeline) + { + constexpr int NThreadPerXdl = BlockGemmShape::WarpTile::at(I1); + constexpr int KThreadPerXdl = 64 / BlockGemmShape::WarpTile::at(I1); + index_t scale_n_packs = kargs.N / (MXFP4N_Pack * NThreadPerXdl); + index_t scale_k_packs = kargs.K / (MXFP4K_Pack * BGranularityK * KThreadPerXdl); + const auto scale_b_navie_desc = make_naive_tensor_descriptor_packed( + make_tuple(scale_packs_n, scale_packs_k, KThreadPerXdl, NThreadPerXdl)); + const auto scale_b_desc = transform_tensor_descriptor( + scale_b_navie_desc, + make_tuple(make_merge_transform(make_tuple(scale_packs_n, NThreadPerXdl)), + make_merge_transform(make_tuple(scale_packs_k, KThreadPerXdl))), + make_tuple(sequence<0, 3>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return make_tensor_view( + reinterpret_cast(scale_b.ptr), scale_b_desc); + + } + else + { index_t scale_k = BGranularityK == 0 ? 1 : (kargs.K + BGranularityK - 1) / BGranularityK; index_t FlatScaleK = scale_k * N_Pack * BlockGemmShape::WarpTile::at(I1); index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1); - using ScaleType = std::conditional_t; + using ScaleType = e8m0_t; return make_naive_tensor_view( reinterpret_cast(scale_n.ptr) + expert_id * kargs.N * scale_k, @@ -719,7 +720,7 @@ struct MoeFlatmmKernel make_tuple(FlatScaleK, 1), number<8>{}, number<1>{}); - // } + } }(); @@ -819,21 +820,36 @@ struct MoeFlatmmKernel constexpr int GranularityK = 32; // fixed config for MXF4_Pipeline auto a_scale_block_window = + // make_tile_window(views.at(I3), + // make_tuple(number{}, + // number{}), + // {coord_m, 0}); make_tile_window(views.at(I3), - make_tuple(number{}, - number{}), - {coord_m, 0}); + make_tuple(number{}, + number{}), + {i_m / M_Pack, 0}); // constexpr int GranularityK = 32; // fixed config for MXF4_Pipeline constexpr int XDLPerLoadScaleB = BMXFP4_Pipeline ? 4 : 1; // GranularityK32 / XDL16x16x32_K8 = 4 - auto b_scale_block_window = - make_tile_window(views.at(I4), - make_tuple(number{}, - number{}), - {coord_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0}); + auto b_scale_block_window = [&]() { + if constexpr(MXF8F6F4MFMA) + { + return make_tile_window(views.at(I4), + make_tuple(number{}, + number{}), + {coord_n / N_Pack, 0}); + } + else + { + return make_tile_window(views.at(I4), + make_tuple(number{}, + number{}), + {coord_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0}); + } + }(); return make_tuple(a_block_window, b_flat_block_window, @@ -947,26 +963,9 @@ struct MoeFlatmmKernel // so don't need extra processing if constexpr(AQUANT_Pipeline) { - constexpr int AGranularityK = decltype(kargs.scale_m)::GranularityK; - constexpr auto a_scale_dram_dist = FlatmmPipeline::GetAScaleDramTileDistribution(); - constexpr ck_tile::index_t DramMScaleRepeat = - decltype(a_scale_dram_dist)::DstrEncode::hs_lengthss_[number<0>{}][number<0>{}]; - statically_indexed_array a_scale_offsets; - static_for<0, DramMScaleRepeat, 1>{}([&](auto m0) { - const auto row_idx = - coord_m + m0 * (TilePartitioner::MPerBlock / DramMScaleRepeat) + a_coord[I0]; - index_t gather_token_id = row_to_token_idx(row_idx); - a_scale_offsets[m0] = gather_token_id * kargs.stride_A / AGranularityK; - }); - auto a_scale_gather_block_tile = - ck_tile::make_tile_scatter_gather(a_scale_block_window.get_bottom_tensor_view(), - a_scale_block_window.get_window_lengths(), - a_scale_block_window.get_window_origin(), - a_scale_dram_dist, - a_scale_offsets); // K DRAM tile window for return FlatmmPipeline{}(a_gather_block_tile, b_block_window, - a_scale_gather_block_tile, // weight scale with granularityK = 32 + a_scale_block_window, // weight scale with granularityK = 32 b_scale_block_window, // weight scale with granularityK = 32 num_loop, // kargs.k_padded_zeros, diff --git a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 1015c2f3fd..3dc37dd5b5 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -2493,7 +2493,7 @@ template -struct F8xMXF4FlatmmPipelineProblem : FlatmmPipelineProblem +template struct F8xMXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1 { using Underlying = FlatmmPipelineAGmemBGmemCRegV1; @@ -2945,11 +2943,6 @@ struct F8xMXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1(); } - CK_TILE_HOST_DEVICE static constexpr auto GetAScaleDramTileDistribution() - { - return PipelinePolicy::template MakeMXFP4_ScaleA_FlatDramTileDistribution(); - } - template ()); + PipelinePolicy::template MakeMXFP4_ADramTileDistribution()); __builtin_amdgcn_sched_barrier(0); diff --git a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index 6a8bb1ba6c..459bfb050a 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -237,7 +237,292 @@ struct F16xMXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy } }; -struct F8xMXF4FlatmmPipelineAgBgCrPolicy : MXF4FlatmmPipelineAgBgCrPolicy +// struct F8xMXF4FlatmmPipelineAgBgCrPolicy : MXF4FlatmmPipelineAgBgCrPolicy +// { +// static constexpr auto I0 = number<0>{}; +// static constexpr auto I1 = number<1>{}; +// static constexpr auto I2 = number<2>{}; +// +// static constexpr index_t kDramLoadPackBytes = 128; +// +// static constexpr int MXdlPack = 1; +// static constexpr int NXdlPack = 1; +// static constexpr int KXdlPack = 4; +// +// template +// static inline constexpr auto wg_attr_num_access = +// std::is_same_v, pk_fp4_t> +// ? WGAttrNumAccessEnum::Single +// : WGAttrNumAccessEnum::Double; +// +// template +// CK_TILE_HOST_DEVICE static constexpr auto GetBlockFlatmm() +// { +// using ADataType = remove_cvref_t; +// using BDataType = remove_cvref_t; +// static_assert( +// sizeof(ADataType) * numeric_traits::PackedSize == +// sizeof(BDataType) * numeric_traits::PackedSize, +// "sizeof(ADataType) / APackedSize must be equal to sizeof(BDataType) / BPackedSize!"); +// using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; +// using WarpTile = typename Problem::BlockGemmShape::WarpTile; +// using WarpGemm = WarpGemmDispatcher< // +// ADataType, +// BDataType, +// typename Problem::CDataType, +// WarpTile::at(I0), +// WarpTile::at(I1), +// WarpTile::at(I2), +// Problem::TransposeC, +// false, +// false, +// wg_attr_num_access>; +// using BlockFlatmmPolicy = BlockFlatmmASmemBSmemCRegV1CustomPolicy< // +// ADataType, +// BDataType, +// typename Problem::CDataType, +// BlockWarps, +// WarpGemm>; +// return BlockFlatmmASmemBSmemCRegV1{}; +// } +// +// template +// CK_TILE_DEVICE static constexpr auto +// MakeMXFP4_AAsyncLoadDramDescriptor(const TensorView& naive_view) +// { +// using ADataType = remove_cvref_t; +// using ALayout = remove_cvref_t; +// constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0); +// constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1); +// static_assert(MPerXdl == 16 && NPerXdl == 16); +// static_assert(std::is_same_v); +// +// const auto& naive_desc = naive_view.get_tensor_descriptor(); +// constexpr auto ndims = remove_cvref_t::get_num_of_dimension(); +// static_assert(ndims == 2, "only support 2D tensor"); +// const auto rows = naive_desc.get_length(number<0>{}); +// const auto cols = naive_desc.get_length(number<1>{}); +// +// constexpr index_t APackedSize = numeric_traits::PackedSize; +// constexpr index_t K2 = GetSmemPackA() * APackedSize; // f4=32; f8=16 +// constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8 +// const index_t K0 = cols / (K1 * K2); +// const auto col_lens = make_tuple(K0, number{}, number{}); +// +// constexpr index_t M1 = 4; // so that we can use imm offset to load lds +// const index_t M0 = rows / M1; +// const auto row_lens = make_tuple(M0, number{}); +// +// const auto desc_0 = +// make_naive_tensor_descriptor_packed(container_concat(row_lens, col_lens)); +// const auto desc_1 = transform_tensor_descriptor( +// desc_0, +// make_tuple(make_pass_through_transform(M0), +// make_xor_transform(make_tuple(number{}, number{})), +// make_pass_through_transform(K0), +// make_pass_through_transform(number{})), +// make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}), +// make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{})); +// const auto desc = transform_tensor_descriptor( // +// desc_1, +// make_tuple(make_merge_transform_v3_division_mod(row_lens), +// make_merge_transform_v3_division_mod(col_lens)), +// make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}), +// make_tuple(sequence<0>{}, sequence<1>{})); +// // printf("A async load dram desc %d x %d: \n", desc.get_length(I0), desc.get_length(I1)); +// +// return tensor_view, +// TensorView::DstInMemOp>{naive_view.buf_, desc}; +// } +// +// template +// CK_TILE_DEVICE static constexpr auto MakeADramTileDistribution() +// { +// +// using ADataType = remove_cvref_t; +// using ALayout = remove_cvref_t; +// static_assert(std::is_same_v); +// +// constexpr index_t BlockSize = Problem::kBlockSize; +// constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; +// constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; +// constexpr index_t APackedSize = numeric_traits::PackedSize; +// +// constexpr index_t K2 = GetSmemPackA() * APackedSize; // f4=32; f8=16 +// constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8 +// constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256 +// +// constexpr index_t M2 = get_warp_size() / K1; // 8 +// constexpr index_t M1 = BlockSize / get_warp_size(); // 4 +// constexpr index_t M0 = MPerBlock / (M2 * M1); +// static_assert(M0 * M1 * M2 == MPerBlock, "M0, M1, M2 must cover whole MPerBlock!"); +// static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!"); +// +// return make_static_tile_distribution( +// tile_distribution_encoding< // +// sequence<1>, +// tuple, sequence>, // ?,4,8 1,8,32 or 2,8,16 +// tuple, sequence<1, 2>>, // M1 M2,K1 +// tuple, sequence<2, 1>>, +// sequence<1, 2, 2>, // M0,K0,K2 +// sequence<0, 0, 2>>{}); +// } +// +// template +// CK_TILE_DEVICE static constexpr auto MakeMXFP4_ALdsBlockDescriptor() +// { +// using ADataType = remove_cvref_t; +// using ALayout = remove_cvref_t; +// constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0); +// constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1); +// static_assert(MPerXdl == 16 && NPerXdl == 16); +// static_assert(std::is_same_v); +// +// /*reduce transform layers,compare with old ck*/ +// constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; +// constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; +// constexpr index_t APackedSize = numeric_traits::PackedSize; +// constexpr index_t K2 = GetSmemPackA() * APackedSize; // f4=32; f8=16 +// constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8 +// constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256 +// static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!"); +// +// constexpr index_t M3 = 4; // so that we can use imm offset to load lds +// constexpr index_t M2 = get_warp_size() / K1 / M3; // 2 +// constexpr index_t M1 = MPerXdl / (M2 * M3); // 2 +// constexpr index_t M0 = MPerBlock / (M1 * M2 * M3); // MPerBlock/16 +// static_assert(M0 * M1 * M2 * M3 == MPerBlock, "M0, M1, M2, M3 must cover whole MPerBlock!"); +// +// constexpr index_t Pad = 4 * K2; // 4 * 32 +// +// constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( // +// make_tuple(number{}, +// number{}, +// number{}, +// number{}, +// number{}, +// number{}, +// number{}), +// make_tuple(number{}, +// number{}, +// number{}, +// number{}, +// number{}, +// number{}, +// number<1>{}), +// number{}, +// number<1>{}); +// +// constexpr auto a_lds_block_desc_1 = transform_tensor_descriptor( +// a_lds_block_desc_0, +// make_tuple(make_pass_through_transform(M0), +// make_pass_through_transform(M1), +// make_pass_through_transform(K0), +// make_pass_through_transform(M2), +// make_xor_transform(make_tuple(number{}, number{})), +// make_pass_through_transform(number{})), +// make_tuple(sequence<0>{}, +// sequence<1>{}, +// sequence<2>{}, +// sequence<3>{}, +// sequence<4, 5>{}, +// sequence<6>{}), +// make_tuple(sequence<0>{}, +// sequence<1>{}, +// sequence<2>{}, +// sequence<3>{}, +// sequence<4, 5>{}, +// sequence<6>{})); +// constexpr auto a_lds_block_desc = transform_tensor_descriptor( +// a_lds_block_desc_1, +// make_tuple(make_merge_transform_v3_division_mod( +// make_tuple(number{}, number{}, number{}, number{})), +// make_merge_transform_v3_division_mod( +// make_tuple(number{}, number{}, number{}))), +// make_tuple(sequence<0, 1, 3, 4>{}, sequence<2, 5, 6>{}), +// make_tuple(sequence<0>{}, sequence<1>{})); +// +// // return a_lds_block_desc_permuted; +// return a_lds_block_desc; +// } +// +// template +// CK_TILE_HOST_DEVICE static constexpr auto MakeMXF4_ALDS_TileDistribution() +// { +// using TileShape = typename Problem::BlockGemmShape; +// +// static_assert(TileShape::WarpTile::at(I1) == 16, "requires XDL_N == 16"); +// static_assert(TileShape::BlockWarps::at(I0) == 1, "requires Wave_M == 1"); +// +// constexpr int M_warps = TileShape::BlockWarps::at(number<0>{}); +// constexpr int N_warps = TileShape::BlockWarps::at(number<1>{}); +// constexpr int M_Lane = TileShape::WarpTile::at(I0); // 16 +// +// constexpr int K_Lane = 64 / M_Lane; // 4 +// +// constexpr int K_Thread = TileShape::WarpTile::at(I2) / K_Lane; // 32 +// constexpr index_t num_access_v = static_cast(wg_attr_num_access); +// constexpr int K1 = K_Thread / num_access_v; // 16 +// +// return make_static_tile_distribution( +// std::conditional_t< +// num_access_v == 1, +// tile_distribution_encoding< +// sequence, +// tuple, sequence>, +// tuple, sequence<2, 1>>, +// tuple, sequence<0, 2>>, +// sequence<2>, +// sequence<1>>, +// tile_distribution_encoding< // +// sequence, +// tuple, sequence>, +// tuple, sequence<2, 1>>, +// tuple, sequence<1, 2>>, +// sequence<2, 2>, +// sequence<0, 2>>>{}); +// } +// +// template +// CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleA_DramTileDistribution() +// { +// using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape +// +// constexpr index_t BlockSize = Problem::kBlockSize; +// constexpr index_t WaveSize = get_warp_size(); +// constexpr index_t WaveNum = BlockSize / WaveSize; +// +// constexpr index_t kMPerBlock = TileShape::BlockTile::at(I0); +// +// constexpr index_t M_Warps = TileShape::BlockWarps::at(I0); +// constexpr index_t N_Warps = TileShape::BlockWarps::at(I1); +// +// static_assert(WaveNum == M_Warps * N_Warps, "Block warps do not match block size"); +// +// constexpr index_t M_Lanes = TileShape::WarpTile::at(I0); +// constexpr index_t K_Lanes = 64 / M_Lanes; +// +// // Y dimension (M) decomposition +// constexpr index_t Y2 = M_Lanes; +// constexpr index_t Y1 = M_Warps; +// constexpr index_t Y0 = kMPerBlock / (MXdlPack * Y1 * Y2); +// +// // X dimension (K) decomposition +// constexpr index_t X0 = K_Lanes; +// constexpr index_t X1 = 1; // packed 2x2 E8M0 data into 1 int32_t for load +// +// return make_static_tile_distribution( +// tile_distribution_encoding, // repeat N_warps +// tuple, sequence>, +// tuple, sequence<2, 1>>, +// tuple, sequence<0, 2>>, +// sequence<1, 2>, +// sequence<0, 1>>{}); +// } +// }; + +struct F8xMXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy { static constexpr auto I0 = number<0>{}; static constexpr auto I1 = number<1>{}; @@ -245,9 +530,9 @@ struct F8xMXF4FlatmmPipelineAgBgCrPolicy : MXF4FlatmmPipelineAgBgCrPolicy static constexpr index_t kDramLoadPackBytes = 128; - static constexpr int MXdlPack = 1; - static constexpr int NXdlPack = 1; - static constexpr int KXdlPack = 4; + static constexpr int MXdlPack = 2; + static constexpr int NXdlPack = 2; + static constexpr int KXdlPack = 2; template static inline constexpr auto wg_attr_num_access =