From 6dfe24c53e26fdc2ea254297059de5ef2b919be6 Mon Sep 17 00:00:00 2001 From: mtgu0705 Date: Tue, 13 May 2025 04:15:53 -0500 Subject: [PATCH] updated --- .../moe_gemm2_xdl_mx_fp4.cpp | 126 ++++++++++-------- .../gpu/grid/gridwise_moe_mx_gemm.hpp | 2 +- .../cpu/reference_moe_mx_gemm2.hpp | 14 +- 3 files changed, 83 insertions(+), 59 deletions(-) diff --git a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp index 59e623151a..948ddb441f 100644 --- a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp +++ b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp @@ -155,22 +155,22 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic // clang-format on #else -static constexpr ck::index_t MPerBlock = 128; -static constexpr bool MulRoutedWeight = true; +static constexpr ck::index_t MPerBlock = 16; +static constexpr bool MulRoutedWeight = true; // clang-format off using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMX< A0Layout, B0Layout, DsLayout, ELayout, A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, - ScaleBlockSize, 256, - MPerBlock, 128, 128, + ScaleBlockSize, 64, + MPerBlock, 16, 128, 32, 32, 16, 16, - 8, 2, - S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, - S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, - 1, 1, S<1, 16, 1, 16>, S<2, 1, 1, 1>, + 1, 1, + S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, + S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, + 1, 1, S<1, 8, 1, 8>, S<2, 1, 1, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, false, ck::index_t, A0DataType>; // clang-format on #endif @@ -183,14 +183,14 @@ int main(int argc, char* argv[]) // per expert: // GEMM shape - constexpr ck::index_t sorted_tile_num = 19; - constexpr ck::index_t valid_tile_num = 16; + constexpr ck::index_t sorted_tile_num = 2; + constexpr ck::index_t valid_tile_num = 2; ck::index_t sorted_size = sorted_tile_num * MPerBlock; ck::index_t valid_size = valid_tile_num * MPerBlock; ck::index_t N = 6144; ck::index_t K = 4096; - ck::index_t experts = 8; + ck::index_t experts = 2; ck::index_t tokens = 832; ck::index_t topk = 2; @@ -285,7 +285,7 @@ int main(int argc, char* argv[]) Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); Tensor b1_e_n_k( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, - {(N * Scale_Stride_BN), Scale_Stride_BN, 1})); + {(N * Scale_Stride_BN), 1, Scale_Stride_BN})); Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); Tensor d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0})); Tensor d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]})); @@ -371,34 +371,32 @@ int main(int argc, char* argv[]) b0_device_buf.ToDevice(b0_preshuffled.mData.data()); - auto invoker = device_op.MakeInvoker(); - auto argument = - device_op.MakeArgument(sorted_token_ids_dev.GetDeviceBuffer(), - expert_ids_dev.GetDeviceBuffer(), - max_token_id_dev.GetDeviceBuffer(), - a0_device_buf.GetDeviceBuffer(), - a1_device_buf.GetDeviceBuffer(), - b0_device_buf.GetDeviceBuffer(), - b1_device_buf.GetDeviceBuffer(), - std::array{d0_device_buf.GetDeviceBuffer(), - d1_device_buf.GetDeviceBuffer(), - d2_device_buf.GetDeviceBuffer()}, - e_device_buf.GetDeviceBuffer(), - tokens, - topk, - sorted_size, - N, - K, - StrideA, - Scale_Stride_AM, - StrideB, - Scale_Stride_BN, - StrideDs, - StrideE, - KBatch, - a_element_op, - b_element_op, - cde_element_op); + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument( + sorted_token_ids_dev.GetDeviceBuffer(), + expert_ids_dev.GetDeviceBuffer(), + max_token_id_dev.GetDeviceBuffer(), + a0_device_buf.GetDeviceBuffer(), + a1_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer(), + std::array{nullptr, nullptr, d2_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + tokens, + topk, + sorted_size, + N, + K, + StrideA, + Scale_Stride_AM, + StrideB, + Scale_Stride_BN, + StrideDs, + StrideE, + KBatch, + a_element_op, + b_element_op, + cde_element_op); if(!device_op.IsSupportedArgument(argument)) { @@ -439,19 +437,19 @@ int main(int argc, char* argv[]) Tensor c_t_n({tokens, N}); using ReferenceGemmInstance = - ck::tensor_operation::host::ReferenceMoeGemm2; + ck::tensor_operation::host::ReferenceMoeMXGemm2; auto ref_moe_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_moe_gemm.MakeInvoker(); @@ -480,6 +478,28 @@ int main(int argc, char* argv[]) e_device_buf.FromDevice(e_t_n_device_result.mData.data()); +#if 1 + printf("e_t_n_device_result:\n"); + for(int t = 0; t < tokens; ++t) + { + for(int n = 0; n < N; ++n) + { + printf("%f ", ck::type_convert(e_t_n_device_result(t, n))); + } + printf("\n"); + } + + printf("e_t_n_host_result:\n"); + for(int t = 0; t < tokens; ++t) + { + for(int n = 0; n < N; ++n) + { + printf("%f ", ck::type_convert(e_t_n_host_result(t, n))); + } + printf("\n"); + } +#endif + return ck::utils::check_err( e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2) ? 0 diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp index 3996b0c1a5..0341428d8f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp @@ -1167,7 +1167,7 @@ struct GridwiseMoeGemmMX } // check gridwise gemm pipeline -#if 1 +#if 0 const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm2.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm2.hpp index da03ce7f10..5a2d33f0bf 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm2.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm2.hpp @@ -28,7 +28,7 @@ template -struct ReferenceMoeGemm2 : public device::BaseOperator +struct ReferenceMoeMXGemm2 : public device::BaseOperator { // Argument struct Argument : public device::BaseArgument @@ -81,14 +81,18 @@ struct ReferenceMoeGemm2 : public device::BaseOperator // Invoker struct Invoker : public device::BaseInvoker { - using Argument = ReferenceMoeGemm2::Argument; + using Argument = ReferenceMoeMXGemm2::Argument; float Run(const Argument& arg) { arg.c_t_n_.SetZero(); - const ck::index_t SCALE_BLOCK = arg.b_e_n_k_.mDesc.GetLengths()[2]; - auto f_mk_kn_mn = [&](auto m, auto n) { - const int K = arg.a_t_k_k_.mDesc.GetLengths()[2]; + auto f_mk_kn_mn = [&](auto m, auto n) { + const int K = arg.a_t_k_k_.mDesc.GetLengths()[2]; + const ck::index_t SCALE_BLOCK = K / arg.b_e_n_k_scale_.mDesc.GetLengths()[1]; + if(m == 0 && n == 0) + { + printf("SCALE_BLOCK: %d\n", SCALE_BLOCK); + } AccDataType v_acc{0}; ComputeTypeA v_a{0}; ComputeTypeB v_b{0};