From b05ec1b096560e994b9532bb3cf329ff8cbcd829 Mon Sep 17 00:00:00 2001 From: Adam Osewski <19374865+aosewski@users.noreply.github.com> Date: Sat, 12 Oct 2024 08:05:11 +0200 Subject: [PATCH] Implement GetWorkSpaceSize from BaseOperator. (#1564) [ROCm/composable_kernel commit: 29d384d0b2f266ba8fbf3f7728d2bba4f5a7b852] --- .../gpu/device/device_cgemm.hpp | 6 +++--- .../impl/device_cgemm_4gemm_xdl_cshuffle.hpp | 18 +++++++++++++++++- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/device_cgemm.hpp b/include/ck/tensor_operation/gpu/device/device_cgemm.hpp index 8484212118..44dedeeef9 100644 --- a/include/ck/tensor_operation/gpu/device/device_cgemm.hpp +++ b/include/ck/tensor_operation/gpu/device/device_cgemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "device_base.hpp" @@ -31,13 +31,13 @@ struct DeviceCGemm : public BaseOperator CElementwiseOperation c_element_op, ck::index_t KBatch = 1) = 0; - virtual std::unique_ptr MakeInvokerPointer() = 0; + virtual std::unique_ptr MakeInvokerPointer() = 0; virtual std::size_t GetWorkspaceSize(index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, - index_t StrideC) = 0; + index_t StrideC) const = 0; }; template (base_arg); + + if(!parg) + { + std::ostringstream err; + err << "Provided argument pointer is not of an Argument class!" + << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + return GetWorkspaceSize( + parg->M, parg->N, parg->K, parg->StrideA, parg->StrideB, parg->StrideC); + } }; } // namespace device