From fe1648f0ce0065019018fedbda8a0dbbeb429073 Mon Sep 17 00:00:00 2001 From: mtgu0705 Date: Wed, 30 Apr 2025 13:33:31 +0800 Subject: [PATCH] fix device kernel selection. --- .../impl/device_moe_gemm_blockscale.hpp | 137 +++++++++++------- 1 file changed, 82 insertions(+), 55 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp index 2a2d4116ca..f707f1600b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp @@ -96,60 +96,60 @@ struct DeviceMoeGemmBlockScale { static constexpr index_t NumDTensor = DsDataType::Size(); using GridwiseGemm = GridwiseMoeGemmBlockScale< - ALayout, - BLayout, - DsLayout, - CLayout, - ADataType, - BDataType, - GemmAccDataType, - CShuffleDataType, - DsDataType, - CDataType, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - GemmSpec, - BlockSize, - ScaleBlockM, - ScaleBlockN, - ScaleBlockK, - MPerBlock, - NPerBlock, - KPerBlock, - AK1, - BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - false, - ABlockLdsExtraM, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - false, - BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - CDEShuffleBlockTransferScalarPerVectors, - BlkGemmPipeSched, - BlkGemmPipelineVer, - NSwizzle, - ComputeTypeA, - ComputeTypeB, - LDSTypeA, - LDSTypeB>; + ALayout, + BLayout, + DsLayout, + CLayout, + ADataType, + BDataType, + GemmAccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + ScaleBlockM, + ScaleBlockN, + ScaleBlockK, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + NSwizzle, + ComputeTypeA, + ComputeTypeB, + LDSTypeA, + LDSTypeB>; using Argument = typename GridwiseGemm::Argument; @@ -317,9 +317,11 @@ struct DeviceMoeGemmBlockScale #if 1 else { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) { - printf("here to do odd.\n"); const auto kernel = kernel_moe_gemm; RunKernel(kernel); } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_moe_gemm_2lds; + RunKernel(kernel); + } + else + { + const auto kernel = kernel_moe_gemm_2lds; + RunKernel(kernel); + } + } } #endif