From dc58110b06d50a24a3734aefc1612913bba42090 Mon Sep 17 00:00:00 2001 From: lalala-sh Date: Fri, 18 Apr 2025 10:45:49 +0800 Subject: [PATCH] enable do top k weights in moe stage1 gemm (#2094) * add switch for mul topk weights * fix bf16/f16 bugs * complete [ROCm/composable_kernel commit: bcf5bb41be976d948b504f3d66c29e5baa82618a] --- .../moe_gemm1_xdl_fp8.cpp | 64 +++++++++++-- .../moe_gemm1_xdl_pk_i4.cpp | 63 +++++++++++-- .../moe_gemm2_xdl_fp8.cpp | 8 +- .../moe_gemm2_xdl_pk_i4.cpp | 8 +- .../gpu/device/impl/device_moe_gemm.hpp | 8 +- .../gpu/grid/gridwise_moe_gemm.hpp | 93 +++++++++++-------- .../cpu/reference_moe_gemm.hpp | 15 ++- .../cpu/reference_moe_gemm2.hpp | 12 ++- 8 files changed, 203 insertions(+), 68 deletions(-) diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp index 66825edcf9..f594080755 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -39,14 +39,16 @@ using AccDataType = F32; using CShuffleDataType = F32; using D0DataType = F32; using D1DataType = F32; -using DsDataType = ck::Tuple; +using D2DataType = F32; +using DsDataType = ck::Tuple; using A0Layout = Row; using B0Layout = Col; using ELayout = Row; using D0Layout = Row; using D1Layout = Col; -using DsLayout = ck::Tuple; +using D2Layout = ELayout; +using DsLayout = ck::Tuple; // for gate, a_scale, b_scale struct MulABScale @@ -83,9 +85,36 @@ struct MulABScaleSilu } }; +struct MulABScaleExpertWeight +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const; + // for real kernel use + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + // for real kernel use + // warning: hack hack hack here!!!! ignore d0 right now as kernel mul d0 * d2 outside. + // tofix:felix + (void)d2; + e = ck::type_convert(c * d1 * d0); + } + // for reference cpu + template <> + __host__ __device__ constexpr void operator()( + float& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + // for reference cpu + e = ck::type_convert(c * d0 * d1 * d2); + } +}; + +using CDEElementOp = MulABScaleExpertWeight; // combine MulRoutedWeight = true // using DsLayout = DsLayoutGate; // using DsDataType = DsDataTypeGate; -using CDEElementOp = MulABScale; +// using CDEElementOp = MulABScale; // combine MulRoutedWeight = false // using CDEElementOp = MulABScaleSiluMulGate; @@ -133,11 +162,13 @@ static constexpr ck::index_t NPerBlock = 128; static constexpr ck::index_t MNPerXDL = 32; static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType); static constexpr ck::index_t Nswizzle = true; +static constexpr bool MulRoutedWeight = false; static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType); static constexpr ck::index_t EVec = 16 / sizeof(EDataType); static constexpr ck::index_t D0Vec = 1; static constexpr ck::index_t D1Vec = 1; +static constexpr ck::index_t D2Vec = 1; // using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // clang-format off @@ -157,8 +188,8 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - 2, 1, S<1, 32, 1, 8>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, A0DataType>; + 2, 1, S<1, 32, 1, 8>, S, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, MulRoutedWeight, A0DataType>; // clang-format on @@ -224,7 +255,7 @@ int main(int argc, char* argv[]) ck::index_t StrideB = K; ck::index_t StrideE = N; constexpr ck::index_t NumDTensor = DsDataType::Size(); - constexpr auto StrideDs = std::array{1, 0}; + constexpr auto StrideDs = std::array{0, 0, 0}; ck::index_t KBatch = 1; @@ -266,6 +297,7 @@ int main(int argc, char* argv[]) Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); Tensor d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0})); Tensor d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); Tensor e_t_n_device_result( HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); @@ -273,6 +305,7 @@ int main(int argc, char* argv[]) std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl; std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl; + std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl; switch(init_method) @@ -283,24 +316,28 @@ int main(int argc, char* argv[]) b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); d0_t_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); d1_e_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{-2, 2}); break; case 2: a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); d1_e_n.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{}); break; case 3: a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); d1_e_n.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{}); break; default: a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); d0_t_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); d1_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); } DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize()); @@ -310,6 +347,7 @@ int main(int argc, char* argv[]) DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize()); DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize()); DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize()); + DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize()); // a0_t_k.savetxt("a.txt"); // d0_t_n.savetxt("d0_t_n.txt", "int"); @@ -320,6 +358,7 @@ int main(int argc, char* argv[]) a0_device_buf.ToDevice(a0_t_k.mData.data()); d0_device_buf.ToDevice(d0_t_n.mData.data()); d1_device_buf.ToDevice(d1_e_n.mData.data()); + d2_device_buf.ToDevice(d2_e_n.mData.data()); auto a_element_op = AElementOp{}; auto b_element_op = BElementOp{}; @@ -342,7 +381,8 @@ int main(int argc, char* argv[]) a0_device_buf.GetDeviceBuffer(), b0_device_buf.GetDeviceBuffer(), std::array{d0_device_buf.GetDeviceBuffer(), - d1_device_buf.GetDeviceBuffer()}, + d1_device_buf.GetDeviceBuffer(), + d2_device_buf.GetDeviceBuffer()}, e_device_buf.GetDeviceBuffer(), tokens, topk, @@ -392,10 +432,12 @@ int main(int argc, char* argv[]) using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm; + PassThrough, + MulRoutedWeight>; auto ref_moe_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_moe_gemm.MakeInvoker(); @@ -406,6 +448,7 @@ int main(int argc, char* argv[]) a0_t_k, b0_e_n_k, c_t_k_n, + d2_e_n, PassThrough{}, PassThrough{}, PassThrough{}); @@ -428,7 +471,8 @@ int main(int argc, char* argv[]) cde_element_op(e_t_n_host_result(t, topk_id, n), c_t_k_n(t, topk_id, n), d0_t_n(t, n), - d1_e_n(e, n)); + d1_e_n(e, n), + 1.f); } } diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp index a25d1b5fa3..fb8a8b9826 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -39,14 +39,15 @@ using AccDataType = F32; using CShuffleDataType = F32; using D0DataType = F32; using D1DataType = F32; -using DsDataType = ck::Tuple; +using D2DataType = F32; +using DsDataType = ck::Tuple; using A0Layout = Row; using B0Layout = Col; using ELayout = Row; using D0Layout = Row; using D1Layout = Col; -using DsLayout = ck::Tuple; +using DsLayout = ck::Tuple; // for gate, a_scale, b_scale struct MulABScale @@ -91,7 +92,39 @@ struct MulABScaleSilu } }; -using CDEElementOp = MulABScale; +struct MulABScaleExpertWeight +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const; + // for real kernel use + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + (void)d2; + +#if CK_USE_PK4_LAYOUT_SHUFFLE + e = ck::type_convert(c * d1 * d0 * 16); +#else + e = ck::type_convert(c * d1 * d0); +#endif + } + // for reference cpu + template <> + __host__ __device__ constexpr void operator()( + float& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + // for reference cpu +#if CK_USE_PK4_LAYOUT_SHUFFLE + e = ck::type_convert(c * d0 * d1 * d2 * 16); +#else + e = ck::type_convert(c * d0 * d1 * d2); +#endif + } +}; + +using CDEElementOp = MulABScaleExpertWeight; #if 1 void preShuffleBuffer(const I4* src, I4* dst, int N, int K, int NXdl) @@ -164,6 +197,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm< #else static constexpr ck::index_t MPerBlock = 128; static constexpr ck::index_t Nswizzle = false; +static constexpr bool MulRoutedWeight = false; // clang-format off using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm< Row, Col, DsLayout, ELayout, @@ -175,8 +209,8 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm< 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, - 1, 1, S<1, 32, 1, 8>, S<8, 1, 1>, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, A0DataType>; + 1, 1, S<1, 32, 1, 8>, S<8, 1, 1, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, MulRoutedWeight, A0DataType>; // clang-format on #endif @@ -265,6 +299,7 @@ int main(int argc, char* argv[]) Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); Tensor d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0})); Tensor d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); Tensor e_t_n_device_result( HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); @@ -283,18 +318,21 @@ int main(int argc, char* argv[]) b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); d0_t_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); d1_e_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{-2, 2}); break; case 2: a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); d1_e_n.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{}); break; default: a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); d0_t_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); d1_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); } DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize()); @@ -304,6 +342,7 @@ int main(int argc, char* argv[]) DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize() / 2); DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize()); DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize()); + DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize()); sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); @@ -312,6 +351,7 @@ int main(int argc, char* argv[]) a0_device_buf.ToDevice(a0_t_k.mData.data()); d0_device_buf.ToDevice(d0_t_n.mData.data()); d1_device_buf.ToDevice(d1_e_n.mData.data()); + d2_device_buf.ToDevice(d2_e_n.mData.data()); auto a_element_op = AElementOp{}; auto b_element_op = BElementOp{}; @@ -424,7 +464,8 @@ int main(int argc, char* argv[]) a0_device_buf.GetDeviceBuffer(), b0_device_buf.GetDeviceBuffer(), std::array{d0_device_buf.GetDeviceBuffer(), - d1_device_buf.GetDeviceBuffer()}, + d1_device_buf.GetDeviceBuffer(), + d2_device_buf.GetDeviceBuffer()}, e_device_buf.GetDeviceBuffer(), tokens, topk, @@ -480,10 +521,12 @@ int main(int argc, char* argv[]) using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm; + PassThrough, + MulRoutedWeight>; auto ref_moe_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_moe_gemm.MakeInvoker(); @@ -494,6 +537,7 @@ int main(int argc, char* argv[]) a0_t_k, b0_e_n_k, c_t_k_n, + d2_e_n, PassThrough{}, PassThrough{}, PassThrough{}); @@ -516,7 +560,8 @@ int main(int argc, char* argv[]) cde_element_op(e_t_n_host_result(t, topk_id, n), c_t_k_n(t, topk_id, n), d0_t_n(t, n), - d1_e_n(e, n)); + d1_e_n(e, n), + 1.f); } } diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp index 0d12441016..04f10b53ae 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -135,6 +135,7 @@ static constexpr ck::index_t EVec = 2; static constexpr ck::index_t D0Vec = 1; static constexpr ck::index_t D1Vec = 1; static constexpr ck::index_t D2Vec = 1; +static constexpr bool MulRoutedWeight = false; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // clang-format off ///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -164,7 +165,7 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| 2, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, A0DataType>; + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, MulRoutedWeight, A0DataType>; // kernel 2: 128->32x128x128 // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>; @@ -409,7 +410,8 @@ int main(int argc, char* argv[]) AccDataType, PassThrough, PassThrough, - CDEElementOp>; + CDEElementOp, + MulRoutedWeight>; auto ref_moe_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_moe_gemm.MakeInvoker(); auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids, diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp index 8c2c70b4a1..ba4e40151f 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -138,6 +138,7 @@ static constexpr ck::index_t EVec = 2; static constexpr ck::index_t D0Vec = 1; static constexpr ck::index_t D1Vec = 1; static constexpr ck::index_t D2Vec = 1; +static constexpr bool MulRoutedWeight = true; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // clang-format off < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, @@ -149,7 +150,7 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0, 1, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, A0DataType>; + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, MulRoutedWeight, A0DataType>; // clang-format on int main(int argc, char* argv[]) @@ -455,7 +456,8 @@ int main(int argc, char* argv[]) AccDataType, PassThrough, PassThrough, - CDEElementOp>; + CDEElementOp, + MulRoutedWeight>; auto ref_moe_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_moe_gemm.MakeInvoker(); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp index f3fc1aaa9f..03db4bdd41 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -67,6 +67,7 @@ template ; RunKernel(kernel); } @@ -280,6 +282,7 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle; RunKernel(kernel); } @@ -295,6 +298,7 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle; RunKernel(kernel); } @@ -305,6 +309,7 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle; RunKernel(kernel); } @@ -325,6 +330,7 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle; RunKernel(kernel); } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp index 1924c27b2b..a2d1114bbe 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -31,6 +31,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS @@ -44,19 +45,22 @@ __global__ void auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - GridwiseGemm::template Run( - karg.p_sorted_token_ids, - karg.p_sorted_expert_ids, - karg.p_max_token_id, - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_ds_grid, - karg.p_c_grid, - p_shared, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run(karg.p_sorted_token_ids, + karg.p_sorted_expert_ids, + karg.p_max_token_id, + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -67,6 +71,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS @@ -81,21 +86,23 @@ __global__ void auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - GridwiseGemm:: - template Run_2Lds( - karg.p_sorted_token_ids, - karg.p_sorted_expert_ids, - karg.p_max_token_id, - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_ds_grid, - karg.p_c_grid, - p_shared, - p_shared1, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run_2Lds(karg.p_sorted_token_ids, + karg.p_sorted_expert_ids, + karg.p_max_token_id, + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + p_shared1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -1134,8 +1141,9 @@ struct GridwiseMoeGemm template + bool IsInputGemm = true, + bool MulRoutedWeight = true, + TailNumber TailNum = TailNumber::Odd> __device__ static void Run(const index_t* p_sorted_token_ids, const index_t* p_sorted_expert_ids, const index_t* p_max_token_id, @@ -1492,7 +1500,7 @@ struct GridwiseMoeGemm using CDEBlockTransferCluster = CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; - constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 3; // hack fix felix + constexpr index_t scatter_weight_idx = 3; // hack fix felix auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< ThisThreadBlock, decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), @@ -1579,10 +1587,13 @@ struct GridwiseMoeGemm { token_offset = token_offset * problem.TopK + (fused_token >> 24); } - else + if constexpr(MulRoutedWeight) { const float* p_sorted_weights_2 = p_ds_grid[I2]; - weight = weight * p_sorted_weights_2[c_token_pos + m0]; + if constexpr(sizeof(ADataType) < 2) + weight = p_sorted_weights_2[c_token_pos + m0] * weight; + else + weight = p_sorted_weights_2[c_token_pos + m0]; } scatter_offsets(m0) = token_offset * problem.N; scatter_weights(m0) = weight; @@ -1632,8 +1643,9 @@ struct GridwiseMoeGemm template + bool IsInputGemm = true, + bool MulRoutedWeight = true, + TailNumber TailNum = TailNumber::Odd> __device__ static void Run_2Lds(const index_t* p_sorted_token_ids, const index_t* p_sorted_expert_ids, const index_t* p_max_token_id, @@ -1998,7 +2010,7 @@ struct GridwiseMoeGemm using CDEBlockTransferCluster = CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; - constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 3; // hack fix felix + constexpr index_t scatter_weight_idx = 3; // hack fix felix auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< ThisThreadBlock, decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), @@ -2086,10 +2098,13 @@ struct GridwiseMoeGemm { token_offset = token_offset * problem.TopK + (fused_token >> 24); } - else + if constexpr(MulRoutedWeight) { const float* p_sorted_weights_2 = p_ds_grid[I2]; - weight = weight * p_sorted_weights_2[c_token_pos + m0]; + if constexpr(sizeof(ADataType) < 2) + weight = p_sorted_weights_2[c_token_pos + m0] * weight; + else + weight = p_sorted_weights_2[c_token_pos + m0]; } scatter_offsets(m0) = token_offset * problem.N; scatter_weights(m0) = weight; diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp index af735925ed..72c9dc86ac 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -18,10 +18,12 @@ namespace host { template struct ReferenceMoeGemm : public device::BaseOperator @@ -36,6 +38,7 @@ struct ReferenceMoeGemm : public device::BaseOperator const Tensor& a_t_k, const Tensor& b_e_n_k, Tensor& c_t_k_n, + const Tensor& d2, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) @@ -46,6 +49,7 @@ struct ReferenceMoeGemm : public device::BaseOperator a_t_k_{a_t_k}, b_e_n_k_{b_e_n_k}, c_t_k_n_{c_t_k_n}, + d2_{d2}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op} @@ -59,6 +63,7 @@ struct ReferenceMoeGemm : public device::BaseOperator const Tensor& a_t_k_; const Tensor& b_e_n_k_; Tensor& c_t_k_n_; + const Tensor& d2_; AElementwiseOperation a_element_op_; BElementwiseOperation b_element_op_; @@ -81,6 +86,7 @@ struct ReferenceMoeGemm : public device::BaseOperator const int topk_id = (arg.sorted_token_ids_(m) & 0xff000000) >> 24; const int e = arg.expert_ids_(m / arg.sorted_tile_size_); const int token_cnt = arg.a_t_k_.mDesc.GetLengths()[0]; + D2DataType v_topk_w = arg.d2_(m, 0); // expert if(t < token_cnt) { for(int k = 0; k < K; ++k) @@ -128,6 +134,11 @@ struct ReferenceMoeGemm : public device::BaseOperator } CDataType v_c{0}; + if constexpr(MulRoutedWeight) + { + v_acc *= v_topk_w; + } + arg.c_element_op_(v_c, v_acc); arg.c_t_k_n_(t, topk_id, n) = v_c; @@ -164,6 +175,7 @@ struct ReferenceMoeGemm : public device::BaseOperator const Tensor& a_t_k, const Tensor& b_e_n_k, Tensor& c_t_k_n, + const Tensor& d2, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) @@ -175,6 +187,7 @@ struct ReferenceMoeGemm : public device::BaseOperator a_t_k, b_e_n_k, c_t_k_n, + d2, a_element_op, b_element_op, c_element_op}; diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp index 1e8a086bc4..fb5c71e30a 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -25,6 +25,7 @@ template struct ReferenceMoeGemm2 : public device::BaseOperator @@ -143,7 +144,14 @@ struct ReferenceMoeGemm2 : public device::BaseOperator CDataType v_c{0}; D0DataType v_d0 = arg.d0_(m, n); // a D0DataType v_d1 = arg.d1_(e, n); // b - arg.c_element_op_(v_c, v_acc, v_d0, v_d1, v_topk_w); + if constexpr(MulRoutedWeight) + { + arg.c_element_op_(v_c, v_acc, v_d0, v_d1, v_topk_w); + } + else + { + arg.c_element_op_(v_c, v_acc, v_d0, v_d1, 1.f); + } arg.c_t_n_(t, n) += v_c; } };