From 234d70f97dca5de7c6f9dee4bbf5d50e50d81691 Mon Sep 17 00:00:00 2001 From: mtgu0705 Date: Tue, 4 Mar 2025 11:12:33 +0800 Subject: [PATCH] Revert "fix build" This reverts commit f83e7e138af1440dfa0785f6859574be01e1aa54. --- .../gpu/device/device_gemm_multiple_d.hpp | 45 +++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp index 48fca67f56..403a1cb085 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp @@ -96,6 +96,51 @@ struct DeviceGemmMultipleDSplitK : public BaseOperator virtual std::unique_ptr MakeInvokerPointer() = 0; }; +// GEMM: +// input : A[M, K], B[K, N], +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... and E have the same layout +template +struct DeviceGemmMultipleDSplitKBPreShuffle : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + std::array StrideDs, + ck::index_t StrideE, + ck::index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; + + virtual int GetPreShuffleParameters() = 0; +}; + } // namespace device } // namespace tensor_operation } // namespace ck