mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
add GetWorkSpaceSize to base arg (#253)
* add GetWorkSpaceSize to base arg and make an example on convnd_bwd_weight * remove redundant compute * use datatype and split k to check whether a workspace is used * remove unused computation for work space size
This commit is contained in:
@@ -40,6 +40,8 @@ struct BaseOperator
|
||||
virtual bool IsSupportedArgument(const BaseArgument*) { return false; }
|
||||
virtual std::string GetTypeString() const { return ""; }
|
||||
|
||||
virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; }
|
||||
|
||||
virtual ~BaseOperator() {}
|
||||
};
|
||||
|
||||
|
||||
@@ -1175,6 +1175,57 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
|
||||
return str.str();
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
|
||||
static size_t GetWorkSpaceSize(const Argument& arg)
|
||||
{
|
||||
size_t WorkSpaceSize = 0;
|
||||
if(arg.k_batch_ > 1)
|
||||
{
|
||||
if constexpr(std::is_same<InDataType, ck::bhalf_t>::value)
|
||||
{
|
||||
WorkSpaceSize =
|
||||
arg.Conv_K_ * arg.Conv_C_ * arg.filter_spatial_lengths_[0] * sizeof(float);
|
||||
}
|
||||
}
|
||||
return WorkSpaceSize;
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
|
||||
static size_t GetWorkSpaceSize(const Argument& arg)
|
||||
{
|
||||
size_t WorkSpaceSize = 0;
|
||||
if(arg.k_batch_ > 1)
|
||||
{
|
||||
if constexpr(std::is_same<InDataType, ck::bhalf_t>::value)
|
||||
{
|
||||
WorkSpaceSize = arg.Conv_K_ * arg.Conv_C_ * arg.filter_spatial_lengths_[0] *
|
||||
arg.filter_spatial_lengths_[1] * sizeof(float);
|
||||
}
|
||||
}
|
||||
return WorkSpaceSize;
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
|
||||
static size_t GetWorkSpaceSize(const Argument& arg)
|
||||
{
|
||||
size_t WorkSpaceSize = 0;
|
||||
if(arg.k_batch_ > 1)
|
||||
{
|
||||
if constexpr(std::is_same<InDataType, ck::bhalf_t>::value)
|
||||
{
|
||||
WorkSpaceSize = arg.Conv_K_ * arg.Conv_C_ * arg.filter_spatial_lengths_[0] *
|
||||
arg.filter_spatial_lengths_[1] * arg.filter_spatial_lengths_[2] *
|
||||
sizeof(float);
|
||||
}
|
||||
}
|
||||
return WorkSpaceSize;
|
||||
}
|
||||
|
||||
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override final
|
||||
{
|
||||
return GetWorkSpaceSize<NumDimSpatial>(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
|
||||
Reference in New Issue
Block a user