mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Rangify constructor of HostTensorDescriptor & Tensor<> (#445)
* Rangify STL algorithms This commit adapts rangified std::copy(), std::fill() & std::transform() * Rangify check_err() By rangifying check_err(), we can not only compare values between std::vector<>s, but also compare any ranges which have same value type. * Allow constructing Tensor<> like a HostTensorDescriptor * Simplify Tensor<> object construction logics * Remove more unnecessary 'HostTensorDescriptor' objects * Re-format example code * Re-write more HostTensorDescriptor ctor call
This commit is contained in:
@@ -86,12 +86,10 @@ int main()
|
||||
constexpr auto index_length = 2048;
|
||||
constexpr AccDataType epsilon = 1e-4;
|
||||
|
||||
auto f_host_tensor_desc_1d = [](std::size_t len_) {
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({len_}));
|
||||
};
|
||||
auto f_host_tensor_desc_1d = [](std::size_t len_) { return HostTensorDescriptor({len_}); };
|
||||
|
||||
auto f_host_tensor_desc_2d = [](std::size_t rows_, std::size_t cols_) {
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({rows_, cols_}));
|
||||
return HostTensorDescriptor({rows_, cols_});
|
||||
};
|
||||
|
||||
using ReferenceInstance =
|
||||
@@ -203,8 +201,7 @@ int main()
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
out_dev.FromDevice(out_from_dev.mData.data());
|
||||
pass &= ck::utils::check_err(
|
||||
out_from_dev.mData, out.mData, "Error: Incorrect results", 1e-3, 1e-3);
|
||||
pass &= ck::utils::check_err(out_from_dev, out, "Error: Incorrect results", 1e-3, 1e-3);
|
||||
}
|
||||
|
||||
double total_read = current_dim * index_length * 3 * sizeof(EmbType) +
|
||||
|
||||
Reference in New Issue
Block a user