mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Rangify constructor of HostTensorDescriptor & Tensor<> (#445)
* Rangify STL algorithms This commit adapts rangified std::copy(), std::fill() & std::transform() * Rangify check_err() By rangifying check_err(), we can not only compare values between std::vector<>s, but also compare any ranges which have same value type. * Allow constructing Tensor<> like a HostTensorDescriptor * Simplify Tensor<> object construction logics * Remove more unnecessary 'HostTensorDescriptor' objects * Re-format example code * Re-write more HostTensorDescriptor ctor call
This commit is contained in:
@@ -12,6 +12,7 @@
|
||||
#include "ck/utility/reduction_enums.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "ck/library/utility/algorithm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
@@ -253,10 +254,10 @@ int mean_meansquare_dual_reduce_test(size_t n,
|
||||
std::array<ck::index_t, NumOutputDim> i_outLengths;
|
||||
std::array<ck::index_t, NumOutputDim> i_outStrides;
|
||||
|
||||
std::copy(inLengths.begin(), inLengths.end(), i_inLengths.begin());
|
||||
std::copy(inStrides.begin(), inStrides.end(), i_inStrides.begin());
|
||||
std::copy(outLengths.begin(), outLengths.end(), i_outLengths.begin());
|
||||
std::copy(outStrides.begin(), outStrides.end(), i_outStrides.begin());
|
||||
ck::ranges::copy(inLengths, i_inLengths.begin());
|
||||
ck::ranges::copy(inStrides, i_inStrides.begin());
|
||||
ck::ranges::copy(outLengths, i_outLengths.begin());
|
||||
ck::ranges::copy(outStrides, i_outStrides.begin());
|
||||
|
||||
auto dual_reduce_op = DeviceDualReduce{};
|
||||
|
||||
@@ -305,8 +306,8 @@ int mean_meansquare_dual_reduce_test(size_t n,
|
||||
{
|
||||
mean_dev.FromDevice(mean.mData.data());
|
||||
meansquare_dev.FromDevice(meansquare.mData.data());
|
||||
pass = pass && ck::utils::check_err(mean.mData, mean_ref.mData);
|
||||
pass = pass && ck::utils::check_err(meansquare.mData, meansquare_ref.mData);
|
||||
pass = pass && ck::utils::check_err(mean, mean_ref);
|
||||
pass = pass && ck::utils::check_err(meansquare, meansquare_ref);
|
||||
};
|
||||
|
||||
return (pass ? 0 : 1);
|
||||
|
||||
Reference in New Issue
Block a user