From aa11f2fb52c8c203c3ea39b5935f9f0570f3c890 Mon Sep 17 00:00:00 2001 From: Enrico Degregori <73224202+EnricoDeg@users.noreply.github.com> Date: Wed, 10 Sep 2025 17:03:29 +0200 Subject: [PATCH] Fix merge bug: add DeviceMoEGemmMXBPreShuffle again (#2816) [ROCm/composable_kernel commit: bbc8c7d99907b001106dd9fe5ad447dbeba97db4] --- .../gpu/device/device_gemm_multiple_d.hpp | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp index 3dff1b28c6..6769ba347e 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp @@ -149,6 +149,52 @@ struct DeviceGemmMultipleDSplitKBPreShuffle : public BaseOperator #endif }; +template +struct DeviceMoEGemmMXBPreShuffle : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + +#ifndef CK_CODE_GEN_RTC + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_a_scale, + const void* p_b, + const void* p_b_scale, + std::array p_ds, + void* p_e, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideAScale, + ck::index_t StrideB, + ck::index_t StrideBScale, + std::array StrideDs, + ck::index_t StrideE, + ck::index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; + + virtual int GetPreShuffleParameters() = 0; +#endif +}; + /// @brief Wrapper for backward compatibility that allows to use instances of /// DeviceGemmMultipleDSplitK in contexts where DeviceGemmMultipleD is expected. ///