From 73da271e037b4a60f3091fd4bb1961f221fbf599 Mon Sep 17 00:00:00 2001 From: huaiguxu <145733371+huaiguxu@users.noreply.github.com> Date: Wed, 16 Jul 2025 15:44:34 +0800 Subject: [PATCH] Handle moe_fp8 no-mainloop cases. Supprese no-mainloop check (#2438) Co-authored-by: felix [ROCm/composable_kernel commit: c1badfd30c1679f4c8e176c8f0608db2c6ac6505] --- .../gpu/device/impl/device_moe_gemm.hpp | 50 ++++++++++++++++--- .../gpu/grid/gridwise_moe_gemm.hpp | 2 +- 2 files changed, 45 insertions(+), 7 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp index 08d177035e..27d3c378ac 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp @@ -325,12 +325,50 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle; - RunKernel(kernel); + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_moe_gemm; + RunKernel(kernel); + } + else + { + const auto kernel = kernel_moe_gemm; + RunKernel(kernel); + } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_moe_gemm_2lds; + RunKernel(kernel); + } + else + { + const auto kernel = kernel_moe_gemm_2lds; + RunKernel(kernel); + } + } + else + { + throw std::runtime_error("todo: only v1 & v2 support now"); } } #endif diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp index 36f8fd7cc1..3d5066d52d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp @@ -1112,7 +1112,7 @@ struct GridwiseMoeGemm } // check gridwise gemm pipeline -#if 1 +#if 0 const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)