From b6c780fc7f4fbc568bf94ba005de1bf0a3fe4e88 Mon Sep 17 00:00:00 2001 From: lalala-sh Date: Tue, 24 Jun 2025 14:51:29 +0800 Subject: [PATCH] fix moe i4 bug from aiter (#2339) [ROCm/composable_kernel commit: bb571a033019fd5a8ba6de31119395c3621a4235] --- ...dlops_b_preshuffle_gufusion_dequant_v1.hpp | 50 ++++++++----------- 1 file changed, 21 insertions(+), 29 deletions(-) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp index 4f7b8e768c..29750b8baa 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp @@ -122,7 +122,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< using Base::B_K1; using Base::I0; using Base::I1; - using Base::KGroup; using Base::KRepeat; using Base::xdlops_gemm; using typename Base::HotLoopInstList; @@ -154,9 +153,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{}); constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); - constexpr index_t K2 = KPack / KGroup; + constexpr index_t K2 = KPack; constexpr index_t K1 = 64 / NPerXDL; - constexpr index_t K0 = KRepeat * KGroup; + constexpr index_t K0 = KRepeat; return transform_tensor_descriptor( TileDesc_M0_M1_M2_K{}, @@ -291,14 +290,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< block_sync_lds(); static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); }); }); // B VGPR->VGPR dequant @@ -391,15 +388,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); }); }); // B VGPR->VGPR dequant @@ -483,14 +477,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); }); }); // B VGPR->VGPR dequant @@ -596,7 +588,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< ComputeDataType, decltype(a_block_desc_m0_m1_m2_k0_k1_k2), decltype(a_thread_desc_), - Sequence<1, 1, 1, 1, 1, KPack / KGroup>, + Sequence<1, 1, 1, 1, 1, KPack>, Sequence<0, 1, 2, 3, 4, 5>, 5, A_K1,