Rangify constructor of HostTensorDescriptor & Tensor<> (#445)

* Rangify STL algorithms

This commit adapts rangified std::copy(), std::fill() & std::transform()

* Rangify check_err()

By rangifying check_err(), we can not only compare values between
std::vector<>s, but also compare any ranges which have same value
type.

* Allow constructing Tensor<> like a HostTensorDescriptor

* Simplify Tensor<> object construction logics

* Remove more unnecessary 'HostTensorDescriptor' objects

* Re-format example code

* Re-write more HostTensorDescriptor ctor call
This commit is contained in:
Po Yen Chen
2022-11-12 01:36:01 +08:00
committed by GitHub
parent 37f2e91832
commit 4a2a56c22f
103 changed files with 657 additions and 649 deletions

View File

@@ -86,12 +86,10 @@ int main()
constexpr auto index_length = 2048;
constexpr AccDataType epsilon = 1e-4;
auto f_host_tensor_desc_1d = [](std::size_t len_) {
return HostTensorDescriptor(std::vector<std::size_t>({len_}));
};
auto f_host_tensor_desc_1d = [](std::size_t len_) { return HostTensorDescriptor({len_}); };
auto f_host_tensor_desc_2d = [](std::size_t rows_, std::size_t cols_) {
return HostTensorDescriptor(std::vector<std::size_t>({rows_, cols_}));
return HostTensorDescriptor({rows_, cols_});
};
using ReferenceInstance =
@@ -203,8 +201,7 @@ int main()
ref_invoker.Run(ref_argument);
out_dev.FromDevice(out_from_dev.mData.data());
pass &= ck::utils::check_err(
out_from_dev.mData, out.mData, "Error: Incorrect results", 1e-3, 1e-3);
pass &= ck::utils::check_err(out_from_dev, out, "Error: Incorrect results", 1e-3, 1e-3);
}
double total_read = current_dim * index_length * 3 * sizeof(EmbType) +