mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
Implement GetWorkSpaceSize from BaseOperator. (#1564)
This commit is contained in:
@@ -598,10 +598,26 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
|
||||
[[maybe_unused]] index_t K,
|
||||
[[maybe_unused]] index_t StrideA,
|
||||
[[maybe_unused]] index_t StrideB,
|
||||
index_t StrideC) override
|
||||
index_t StrideC) const override
|
||||
{
|
||||
return 2 * sizeof(CDataType) * GetCElementSpaceSize(M, N, StrideC);
|
||||
}
|
||||
|
||||
std::size_t GetWorkSpaceSize(const BaseArgument* base_arg) const override
|
||||
{
|
||||
const auto* parg = dynamic_cast<const Argument*>(base_arg);
|
||||
|
||||
if(!parg)
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "Provided argument pointer is not of an Argument class!"
|
||||
<< " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
|
||||
return GetWorkspaceSize(
|
||||
parg->M, parg->N, parg->K, parg->StrideA, parg->StrideB, parg->StrideC);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
|
||||
Reference in New Issue
Block a user