diff --git a/include/ck/tensor_operation/gpu/device/device_cgemm.hpp b/include/ck/tensor_operation/gpu/device/device_cgemm.hpp index 8484212118..44dedeeef9 100644 --- a/include/ck/tensor_operation/gpu/device/device_cgemm.hpp +++ b/include/ck/tensor_operation/gpu/device/device_cgemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "device_base.hpp" @@ -31,13 +31,13 @@ struct DeviceCGemm : public BaseOperator CElementwiseOperation c_element_op, ck::index_t KBatch = 1) = 0; - virtual std::unique_ptr MakeInvokerPointer() = 0; + virtual std::unique_ptr MakeInvokerPointer() = 0; virtual std::size_t GetWorkspaceSize(index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, - index_t StrideC) = 0; + index_t StrideC) const = 0; }; template (base_arg); + + if(!parg) + { + std::ostringstream err; + err << "Provided argument pointer is not of an Argument class!" + << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + return GetWorkspaceSize( + parg->M, parg->N, parg->K, parg->StrideA, parg->StrideB, parg->StrideC); + } }; } // namespace device