add GetWorkSpaceSize to base arg (#253)

* add GetWorkSpaceSize to base arg and make an example on convnd_bwd_weight

* remove redundant compute

* use datatype and split k to check whether a workspace is used

* remove unused computation for work space size
This commit is contained in:
Shaojie WANG
2022-05-25 00:13:00 +08:00
committed by GitHub
parent ba58a93f60
commit 0d08cf1893
3 changed files with 100 additions and 9 deletions

View File

@@ -40,6 +40,8 @@ struct BaseOperator
virtual bool IsSupportedArgument(const BaseArgument*) { return false; }
virtual std::string GetTypeString() const { return ""; }
virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; }
virtual ~BaseOperator() {}
};

View File

@@ -1175,6 +1175,57 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
return str.str();
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
static size_t GetWorkSpaceSize(const Argument& arg)
{
size_t WorkSpaceSize = 0;
if(arg.k_batch_ > 1)
{
if constexpr(std::is_same<InDataType, ck::bhalf_t>::value)
{
WorkSpaceSize =
arg.Conv_K_ * arg.Conv_C_ * arg.filter_spatial_lengths_[0] * sizeof(float);
}
}
return WorkSpaceSize;
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static size_t GetWorkSpaceSize(const Argument& arg)
{
size_t WorkSpaceSize = 0;
if(arg.k_batch_ > 1)
{
if constexpr(std::is_same<InDataType, ck::bhalf_t>::value)
{
WorkSpaceSize = arg.Conv_K_ * arg.Conv_C_ * arg.filter_spatial_lengths_[0] *
arg.filter_spatial_lengths_[1] * sizeof(float);
}
}
return WorkSpaceSize;
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static size_t GetWorkSpaceSize(const Argument& arg)
{
size_t WorkSpaceSize = 0;
if(arg.k_batch_ > 1)
{
if constexpr(std::is_same<InDataType, ck::bhalf_t>::value)
{
WorkSpaceSize = arg.Conv_K_ * arg.Conv_C_ * arg.filter_spatial_lengths_[0] *
arg.filter_spatial_lengths_[1] * arg.filter_spatial_lengths_[2] *
sizeof(float);
}
}
return WorkSpaceSize;
}
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override final
{
return GetWorkSpaceSize<NumDimSpatial>(*dynamic_cast<const Argument*>(p_arg));
}
};
} // namespace device