Add constraint to Tensor<> templated methods

This commit is contained in:
Po-Yen, Chen
2022-08-19 03:27:41 -04:00
parent f3f61f836b
commit 463d15f9b5

View File

@@ -254,7 +254,7 @@ struct Tensor
Tensor(const HostTensorDescriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpaceSize()) {}
template <typename OutT>
Tensor<OutT> CopyAsType() const
std::enable_if_t<std::is_convertible_v<T, OutT>, Tensor<OutT>> CopyAsType() const
{
Tensor<OutT> ret(mDesc);
for(size_t i = 0; i < mData.size(); i++)
@@ -268,8 +268,8 @@ struct Tensor
Tensor(Tensor&& other) = default;
template <typename OtherT>
Tensor(const Tensor<OtherT>& other) : Tensor(other.template CopyAsType<T>())
template <typename FromT, typename = std::enable_if_t<std::is_convertible_v<FromT, T>>>
Tensor(const Tensor<FromT>& other) : Tensor(other.template CopyAsType<T>())
{
}