From 2f5eb26839dee2cf893126b72fd4b72e869d9049 Mon Sep 17 00:00:00 2001 From: Zzz9990 Date: Tue, 9 Dec 2025 07:25:26 -0600 Subject: [PATCH] compile pass --- .../ops/flatmm/kernel/moe_flatmm_kernel.hpp | 25 +++++++++------- ...ec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 30 +++++++++---------- ...mm_pipeline_agmem_bgmem_creg_v1_policy.hpp | 15 +++++----- 3 files changed, 36 insertions(+), 34 deletions(-) diff --git a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp index b54f319e01..2e8f7a19e1 100644 --- a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp @@ -242,10 +242,13 @@ struct MoeFlatmmKernel IsGateUp ? TilePartitioner::NPerBlock / 2 : TilePartitioner::NPerBlock; // MXF4_Pipeline only has the of scale B and granularityK is 32 - static constexpr bool AQUANT_Pipeline = std::is_same_v || - std::is_same_v || - std::is_same_v; + static constexpr bool AQUANT_Pipeline = std::is_same_v || + std::is_same_v || + std::is_same_v; static constexpr bool BMXFP4_Pipeline = std::is_same_v; + + static_assert(AQUANT_Pipeline); + static_assert(BMXFP4_Pipeline); static constexpr bool MXF8F6F4MFMA = #ifdef __gfx950__ AQUANT_Pipeline && BMXFP4_Pipeline; @@ -663,7 +666,7 @@ struct MoeFlatmmKernel number<8>{}, number<1>{}); } - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) { constexpr int MThreadPerXdl = BlockGemmShape::WarpTile::at(I0); constexpr int KThreadPerXdl = 64 / BlockGemmShape::WarpTile::at(I0); @@ -675,11 +678,11 @@ struct MoeFlatmmKernel const auto scale_a_desc = transform_tensor_descriptor( scale_a_naive_desc, make_tuple(make_merge_transform(make_tuple(scale_m_packs, MThreadPerXdl)), - make_merge_transform(make_tuple(scale_packs_k, KThreadPerXdl))), + make_merge_transform(make_tuple(scale_k_packs, KThreadPerXdl))), make_tuple(sequence<0, 3>{}, sequence<1, 2>{}), make_tuple(sequence<0>{}, sequence<1>{})); return make_tensor_view( - reinterpret_cast(scale_a.ptr), scale_a_desc); + reinterpret_cast(scale_m_desc.ptr), scale_a_desc); } }(); @@ -694,16 +697,16 @@ struct MoeFlatmmKernel index_t scale_n_packs = kargs.N / (MXFP4N_Pack * NThreadPerXdl); index_t scale_k_packs = kargs.K / (MXFP4K_Pack * BGranularityK * KThreadPerXdl); const auto scale_b_navie_desc = make_naive_tensor_descriptor_packed( - make_tuple(scale_packs_n, scale_packs_k, KThreadPerXdl, NThreadPerXdl)); + make_tuple(scale_n_packs, scale_k_packs, KThreadPerXdl, NThreadPerXdl)); const auto scale_b_desc = transform_tensor_descriptor( scale_b_navie_desc, - make_tuple(make_merge_transform(make_tuple(scale_packs_n, NThreadPerXdl)), - make_merge_transform(make_tuple(scale_packs_k, KThreadPerXdl))), + make_tuple(make_merge_transform(make_tuple(scale_n_packs, NThreadPerXdl)), + make_merge_transform(make_tuple(scale_k_packs, KThreadPerXdl))), make_tuple(sequence<0, 3>{}, sequence<1, 2>{}), make_tuple(sequence<0>{}, sequence<1>{})); return make_tensor_view( - reinterpret_cast(scale_b.ptr), scale_b_desc); + reinterpret_cast(scale_n.ptr), scale_b_desc); } else @@ -827,7 +830,7 @@ struct MoeFlatmmKernel make_tile_window(views.at(I3), make_tuple(number{}, number{}), - {i_m / M_Pack, 0}); + {coord_m / M_Pack, 0}); // constexpr int GranularityK = 32; // fixed config for MXF4_Pipeline constexpr int XDLPerLoadScaleB = diff --git a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 3dc37dd5b5..edb84ac1e6 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -2493,15 +2493,15 @@ template -struct F8xMXFlatmmPipelineProblem : FlatmmPipelineProblem +struct F8xMXF4FlatmmPipelineProblem : FlatmmPipelineProblem { using BlockGemmShape = BlockGemmShape_; @@ -2519,7 +2519,7 @@ struct F8xMXFlatmmPipelineProblem : FlatmmPipelineProblem +template struct F8xMXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1 { using Underlying = FlatmmPipelineAGmemBGmemCRegV1; @@ -2600,8 +2600,8 @@ struct F8xMXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1= DsReadPreload) ? DsReadPreload @@ -2612,8 +2612,8 @@ struct F8xMXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1()); + PipelinePolicy::template MakeADramTileDistribution()); __builtin_amdgcn_sched_barrier(0); diff --git a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index 459bfb050a..c87b93397b 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -536,19 +536,17 @@ struct F8xMXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy template static inline constexpr auto wg_attr_num_access = - std::is_same_v, pk_fp4_t> - ? WGAttrNumAccessEnum::Single - : WGAttrNumAccessEnum::Double; + WGAttrNumAccessEnum::Single; + // std::is_same_v, pk_fp4_t> + // ? WGAttrNumAccessEnum::Single + // : WGAttrNumAccessEnum::Double; template CK_TILE_HOST_DEVICE static constexpr auto GetBlockFlatmm() { using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; - static_assert( - sizeof(ADataType) * numeric_traits::PackedSize == - sizeof(BDataType) * numeric_traits::PackedSize, - "sizeof(ADataType) / APackedSize must be equal to sizeof(BDataType) / BPackedSize!"); + using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; using WarpTile = typename Problem::BlockGemmShape::WarpTile; using WarpGemm = WarpGemmDispatcher< // @@ -634,7 +632,8 @@ struct F8xMXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t APackedSize = numeric_traits::PackedSize; - constexpr index_t K2 = GetSmemPackA() * APackedSize; // f4=32; f8=16 + constexpr index_t K2 = MPerBlock == 16 ? GetSmemPackA() * APackedSize/ 4: + GetSmemPackA() * APackedSize; // f4=32; f8=16 constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8 constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256