From 22c0a59fa9be1e9636442f9f60224ee1fd72e70b Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Wed, 10 Sep 2025 15:12:01 +0000 Subject: [PATCH] Merge commit 'bbc8c7d99907b001106dd9fe5ad447dbeba97db4' into develop --- .../gpu/device/device_gemm_multiple_d.hpp | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp index 3dff1b28c6..6769ba347e 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp @@ -149,6 +149,52 @@ struct DeviceGemmMultipleDSplitKBPreShuffle : public BaseOperator #endif }; +template +struct DeviceMoEGemmMXBPreShuffle : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + +#ifndef CK_CODE_GEN_RTC + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_a_scale, + const void* p_b, + const void* p_b_scale, + std::array p_ds, + void* p_e, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideAScale, + ck::index_t StrideB, + ck::index_t StrideBScale, + std::array StrideDs, + ck::index_t StrideE, + ck::index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; + + virtual int GetPreShuffleParameters() = 0; +#endif +}; + /// @brief Wrapper for backward compatibility that allows to use instances of /// DeviceGemmMultipleDSplitK in contexts where DeviceGemmMultipleD is expected. ///