From 5be42bb39876865ac11bba856e175587f852aee5 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Wed, 5 Feb 2025 08:22:14 +0000 Subject: [PATCH] fix errors --- .../gpu/device/device_gemm_multiple_d.hpp | 45 ------------------- .../src/profile_gemm_multiply_multiply.cpp | 2 +- 2 files changed, 1 insertion(+), 46 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp index 403a1cb085..48fca67f56 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp @@ -96,51 +96,6 @@ struct DeviceGemmMultipleDSplitK : public BaseOperator virtual std::unique_ptr MakeInvokerPointer() = 0; }; -// GEMM: -// input : A[M, K], B[K, N], -// input : D0[M, N], D1[M, N], ... -// output : E[M, N] -// C = a_op(A) * b_op(B) -// E = cde_op(C, D0, D1, ...) -// Assume: -// D0, D1, ... and E have the same layout -template -struct DeviceGemmMultipleDSplitKBPreShuffle : public BaseOperator -{ - static constexpr index_t NumDTensor = DsDataType::Size(); - - virtual std::unique_ptr - MakeArgumentPointer(const void* p_a, - const void* p_b, - std::array p_ds, - void* p_e, - ck::index_t M, - ck::index_t N, - ck::index_t K, - ck::index_t StrideA, - ck::index_t StrideB, - std::array StrideDs, - ck::index_t StrideE, - ck::index_t KBatch, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) = 0; - - virtual std::unique_ptr MakeInvokerPointer() = 0; - - virtual int GetPreShuffleParameters() = 0; -}; - } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/profiler/src/profile_gemm_multiply_multiply.cpp b/profiler/src/profile_gemm_multiply_multiply.cpp index 0bec53b045..5f5cf35af7 100644 --- a/profiler/src/profile_gemm_multiply_multiply.cpp +++ b/profiler/src/profile_gemm_multiply_multiply.cpp @@ -175,7 +175,7 @@ int profile_gemm_multiply_multiply(int argc, char* argv[]) else if(data_type == GemmDataType::INT8_INT8_BF16 && layout == GemmMatrixLayout::MK_NK_MN) { return profile( - F8{}, F8{}, F8{}, F32{}, F32{}, F32{}, F16{}, Row{}, Col{}, Row{}, Col{}, Row{}); + I8{}, I8{}, I8{}, I32{}, F32{}, F32{}, BF16{}, Row{}, Col{}, Row{}, Col{}, Row{}); } else if(data_type == GemmDataType::INT8_INT8_F16 && layout == GemmMatrixLayout::MK_NK_MN) {