diff --git a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp index 08d177035e..a6110d2bfc 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp @@ -264,77 +264,152 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle= 256) ? 1 : 2; - constexpr auto MemoryDataOp = - IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd; - if(has_main_k_block_loop) + if(IsInputGemm || arg.TopK == 1) { - // Tail number always full - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + constexpr auto MemoryDataOp = InMemoryDataOperationEnum::Set; + + if(has_main_k_block_loop) { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_moe_gemm; + RunKernel(kernel); + } + else + { + const auto kernel = kernel_moe_gemm; + RunKernel(kernel); + } + } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) { if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) { - const auto kernel = kernel_moe_gemm; + const auto kernel = kernel_moe_gemm_2lds; RunKernel(kernel); } else { - const auto kernel = kernel_moe_gemm; + const auto kernel = kernel_moe_gemm_2lds; RunKernel(kernel); } } - } - else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 || - BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + else { - const auto kernel = kernel_moe_gemm_2lds; + throw std::runtime_error("todo: only v1 & v2 support now"); + } + } +#if 1 + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + const auto kernel = kernel_moe_gemm; RunKernel(kernel); } + } +#endif + } + else + { + constexpr auto MemoryDataOp = InMemoryDataOperationEnum::AtomicAdd; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_moe_gemm; + RunKernel(kernel); + } + else + { + const auto kernel = kernel_moe_gemm; + RunKernel(kernel); + } + } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_moe_gemm_2lds; + RunKernel(kernel); + } + else + { + const auto kernel = kernel_moe_gemm_2lds; + RunKernel(kernel); + } + } else { - const auto kernel = kernel_moe_gemm_2lds; + throw std::runtime_error("todo: only v1 & v2 support now"); + } + } +#if 1 + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + const auto kernel = kernel_moe_gemm; RunKernel(kernel); } } - else - { - throw std::runtime_error("todo: only v1 & v2 support now"); - } - } -#if 1 - else - { - // Tail number always 1 - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) - { - const auto kernel = kernel_moe_gemm; - RunKernel(kernel); - } - } #endif - + } return ave_time; }