diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp index 6d21752ba8..1f488fef83 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp @@ -134,7 +134,7 @@ static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType); static constexpr ck::index_t EVec = 16 / sizeof(EDataType); static constexpr ck::index_t D0Vec = 1; static constexpr ck::index_t D1Vec = 1; -static constexpr ck::index_t ActOP = 0; +static constexpr ck::index_t Act_OP = 0; // 0: gelu, 1: silu, 2: swiglu // using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // clang-format off diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp index 744a164a2f..d2db33cd57 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -59,8 +59,8 @@ struct MulABScale __host__ __device__ constexpr void operator()( EDataType& e, const EDataType& c, const float& d0, const float& d1) const { - (void) d0; - (void) d1; + (void)d0; + (void)d1; #if CK_USE_PK4_LAYOUT_SHUFFLE e = ck::type_convert(c * 16); #else @@ -71,8 +71,8 @@ struct MulABScale __host__ __device__ constexpr void operator()( EDataType& e, const float& c, const float& d0, const float& d1) const { - (void) d0; - (void) d1; + (void)d0; + (void)d1; #if CK_USE_PK4_LAYOUT_SHUFFLE e = ck::type_convert(c * 16); #else @@ -81,7 +81,6 @@ struct MulABScale } }; - using CDEElementOp = MulABScale; #if 1 @@ -125,7 +124,8 @@ using BElementOp = PassThrough; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr ck::index_t MPerBlock = 128; -static constexpr ck::index_t Nswizzle = false; +static constexpr ck::index_t Nswizzle = false; +static constexpr ck::index_t Act_OP = 0; // 0: gelu, 1: silu, 2: swiglu // clang-format off using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm< Row, Col, DsLayout, ELayout, @@ -138,7 +138,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm< S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 2, 1, S<1, 32, 1, 8>, S<8, 1, 1>, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, true, ck::index_t, A0DataType>; + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Act_OP, Nswizzle, true, true, ck::index_t, A0DataType>; // clang-format on int main(int argc, char* argv[]) @@ -152,7 +152,7 @@ int main(int argc, char* argv[]) // experts = 8 // per expert: // GEMM shape - ck::index_t N = 14336 * 2; + ck::index_t N = 14336; ck::index_t K = 4096; ck::index_t experts = 8; ck::index_t sorted_tile_num = 16; @@ -402,8 +402,7 @@ int main(int argc, char* argv[]) b_element_op, cde_element_op); - if(!device_op.IsSupportedArgument(argument) || ck::get_device_name() != "gfx942" || - ck::get_device_name() != "gfx950") + if(!device_op.IsSupportedArgument(argument)) { throw std::runtime_error( "wrong! device_gemm with the specified compilation parameters does " @@ -440,7 +439,8 @@ int main(int argc, char* argv[]) AccDataType, PassThrough, PassThrough, - PassThrough>; + PassThrough, + Act_OP>; auto ref_moe_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_moe_gemm.MakeInvoker(); diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp index e0ee6a4459..f2335e9175 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -59,7 +59,7 @@ struct MulABScaleExpertWeight // for real kernel use template <> __host__ __device__ constexpr void operator()( - EDataType& e, const EDataType& c, const float& d0, const float& d1, const float& d2) const + EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const { (void)d0; (void)d1; @@ -151,7 +151,7 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0, 2, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, false, ck::index_t, A0DataType>; + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, false, ck::index_t, A0DataType>; // clang-format on int main(int argc, char* argv[]) @@ -409,8 +409,7 @@ int main(int argc, char* argv[]) b_element_op, cde_element_op); - if(!device_op.IsSupportedArgument(argument) || ck::get_device_name() != "gfx942" || - ck::get_device_name() != "gfx950") + if(!device_op.IsSupportedArgument(argument)) { throw std::runtime_error( "wrong! device_gemm with the specified compilation parameters does "