mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 02:54:21 +00:00
Rangify constructor of HostTensorDescriptor & Tensor<> (#445)
* Rangify STL algorithms
This commit adapts rangified std::copy(), std::fill() & std::transform()
* Rangify check_err()
By rangifying check_err(), we can not only compare values between
std::vector<>s, but also compare any ranges which have same value
type.
* Allow constructing Tensor<> like a HostTensorDescriptor
* Simplify Tensor<> object construction logics
* Remove more unnecessary 'HostTensorDescriptor' objects
* Re-format example code
* Re-write more HostTensorDescriptor ctor call
[ROCm/composable_kernel commit: 4a2a56c22f]
This commit is contained in:
@@ -44,8 +44,8 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
|
||||
size_t M = acc.mDesc.GetLengths()[0];
|
||||
size_t N = acc.mDesc.GetLengths()[1];
|
||||
|
||||
Tensor<ComputeDataType> avg_acc_sq(HostTensorDescriptor(std::vector<size_t>({M})));
|
||||
Tensor<ComputeDataType> avg_acc(HostTensorDescriptor(std::vector<size_t>({M})));
|
||||
Tensor<ComputeDataType> avg_acc_sq({M});
|
||||
Tensor<ComputeDataType> avg_acc({M});
|
||||
Tensor<ComputeDataType> acc_layernorm(acc);
|
||||
|
||||
// reduce N dim
|
||||
|
||||
43
library/include/ck/library/utility/algorithm.hpp
Normal file
43
library/include/ck/library/utility/algorithm.hpp
Normal file
@@ -0,0 +1,43 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <iterator>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
namespace ck {
|
||||
namespace ranges {
|
||||
template <typename InputRange, typename OutputIterator>
|
||||
auto copy(InputRange&& range, OutputIterator iter)
|
||||
-> decltype(std::copy(std::begin(std::forward<InputRange>(range)),
|
||||
std::end(std::forward<InputRange>(range)),
|
||||
iter))
|
||||
{
|
||||
return std::copy(std::begin(std::forward<InputRange>(range)),
|
||||
std::end(std::forward<InputRange>(range)),
|
||||
iter);
|
||||
}
|
||||
|
||||
template <typename T, typename OutputRange>
|
||||
auto fill(OutputRange&& range, const T& init)
|
||||
-> std::void_t<decltype(std::fill(std::begin(std::forward<OutputRange>(range)),
|
||||
std::end(std::forward<OutputRange>(range)),
|
||||
init))>
|
||||
{
|
||||
std::fill(std::begin(std::forward<OutputRange>(range)),
|
||||
std::end(std::forward<OutputRange>(range)),
|
||||
init);
|
||||
}
|
||||
|
||||
template <typename InputRange, typename OutputIterator, typename UnaryOperation>
|
||||
auto transform(InputRange&& range, OutputIterator iter, UnaryOperation unary_op)
|
||||
-> decltype(std::transform(std::begin(range), std::end(range), iter, unary_op))
|
||||
{
|
||||
return std::transform(std::begin(range), std::end(range), iter, unary_op);
|
||||
}
|
||||
|
||||
} // namespace ranges
|
||||
} // namespace ck
|
||||
@@ -15,18 +15,22 @@
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/span.hpp"
|
||||
#include "ck/utility/type.hpp"
|
||||
#include "ck/host_utility/io.hpp"
|
||||
|
||||
#include "ck/library/utility/ranges.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace utils {
|
||||
|
||||
template <typename T>
|
||||
typename std::enable_if<std::is_floating_point<T>::value && !std::is_same<T, half_t>::value,
|
||||
bool>::type
|
||||
check_err(const std::vector<T>& out,
|
||||
const std::vector<T>& ref,
|
||||
template <typename Range, typename RefRange>
|
||||
typename std::enable_if<
|
||||
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
|
||||
std::is_floating_point_v<ranges::range_value_t<Range>> &&
|
||||
!std::is_same_v<ranges::range_value_t<Range>, half_t>,
|
||||
bool>::type
|
||||
check_err(const Range& out,
|
||||
const RefRange& ref,
|
||||
const std::string& msg = "Error: Incorrect results!",
|
||||
double rtol = 1e-5,
|
||||
double atol = 3e-6)
|
||||
@@ -44,15 +48,17 @@ check_err(const std::vector<T>& out,
|
||||
double max_err = std::numeric_limits<double>::min();
|
||||
for(std::size_t i = 0; i < ref.size(); ++i)
|
||||
{
|
||||
err = std::abs(out[i] - ref[i]);
|
||||
if(err > atol + rtol * std::abs(ref[i]) || !std::isfinite(out[i]) || !std::isfinite(ref[i]))
|
||||
const double o = *std::next(std::begin(out), i);
|
||||
const double r = *std::next(std::begin(ref), i);
|
||||
err = std::abs(o - r);
|
||||
if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
|
||||
{
|
||||
max_err = err > max_err ? err : max_err;
|
||||
err_count++;
|
||||
if(err_count < 5)
|
||||
{
|
||||
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
|
||||
<< "] != ref[" << i << "]: " << out[i] << " != " << ref[i] << std::endl;
|
||||
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
|
||||
}
|
||||
res = false;
|
||||
}
|
||||
@@ -64,10 +70,13 @@ check_err(const std::vector<T>& out,
|
||||
return res;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
typename std::enable_if<std::is_same<T, bhalf_t>::value, bool>::type
|
||||
check_err(const std::vector<T>& out,
|
||||
const std::vector<T>& ref,
|
||||
template <typename Range, typename RefRange>
|
||||
typename std::enable_if<
|
||||
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
|
||||
std::is_same_v<ranges::range_value_t<Range>, bhalf_t>,
|
||||
bool>::type
|
||||
check_err(const Range& out,
|
||||
const RefRange& ref,
|
||||
const std::string& msg = "Error: Incorrect results!",
|
||||
double rtol = 1e-3,
|
||||
double atol = 1e-3)
|
||||
@@ -86,9 +95,9 @@ check_err(const std::vector<T>& out,
|
||||
double max_err = std::numeric_limits<float>::min();
|
||||
for(std::size_t i = 0; i < ref.size(); ++i)
|
||||
{
|
||||
double o = type_convert<float>(out[i]);
|
||||
double r = type_convert<float>(ref[i]);
|
||||
err = std::abs(o - r);
|
||||
const double o = type_convert<float>(*std::next(std::begin(out), i));
|
||||
const double r = type_convert<float>(*std::next(std::begin(ref), i));
|
||||
err = std::abs(o - r);
|
||||
if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
|
||||
{
|
||||
max_err = err > max_err ? err : max_err;
|
||||
@@ -108,10 +117,13 @@ check_err(const std::vector<T>& out,
|
||||
return res;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
typename std::enable_if<std::is_same_v<T, half_t>, bool>::type
|
||||
check_err(span<const T> out,
|
||||
span<const T> ref,
|
||||
template <typename Range, typename RefRange>
|
||||
typename std::enable_if<
|
||||
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
|
||||
std::is_same_v<ranges::range_value_t<Range>, half_t>,
|
||||
bool>::type
|
||||
check_err(const Range& out,
|
||||
const RefRange& ref,
|
||||
const std::string& msg = "Error: Incorrect results!",
|
||||
double rtol = 1e-3,
|
||||
double atol = 1e-3)
|
||||
@@ -126,12 +138,12 @@ check_err(span<const T> out,
|
||||
bool res{true};
|
||||
int err_count = 0;
|
||||
double err = 0;
|
||||
double max_err = std::numeric_limits<T>::min();
|
||||
double max_err = std::numeric_limits<ranges::range_value_t<Range>>::min();
|
||||
for(std::size_t i = 0; i < ref.size(); ++i)
|
||||
{
|
||||
double o = type_convert<float>(out[i]);
|
||||
double r = type_convert<float>(ref[i]);
|
||||
err = std::abs(o - r);
|
||||
const double o = type_convert<float>(*std::next(std::begin(out), i));
|
||||
const double r = type_convert<float>(*std::next(std::begin(ref), i));
|
||||
err = std::abs(o - r);
|
||||
if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
|
||||
{
|
||||
max_err = err > max_err ? err : max_err;
|
||||
@@ -151,26 +163,17 @@ check_err(span<const T> out,
|
||||
return res;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
typename std::enable_if<std::is_same<T, half_t>::value, bool>::type
|
||||
check_err(const std::vector<T>& out,
|
||||
const std::vector<T>& ref,
|
||||
const std::string& msg = "Error: Incorrect results!",
|
||||
double rtol = 1e-3,
|
||||
double atol = 1e-3)
|
||||
{
|
||||
return check_err(span<const T>{out}, span<const T>{ref}, msg, rtol, atol);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::enable_if_t<(std::is_integral_v<T> && !std::is_same_v<T, bhalf_t>)
|
||||
template <typename Range, typename RefRange>
|
||||
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
|
||||
std::is_integral_v<ranges::range_value_t<Range>> &&
|
||||
!std::is_same_v<ranges::range_value_t<Range>, bhalf_t>)
|
||||
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
|| std::is_same_v<T, int4_t>
|
||||
|| std::is_same_v<ranges::range_value_t<Range>, int4_t>
|
||||
#endif
|
||||
,
|
||||
bool>
|
||||
check_err(const std::vector<T>& out,
|
||||
const std::vector<T>& ref,
|
||||
check_err(const Range& out,
|
||||
const RefRange& ref,
|
||||
const std::string& msg = "Error: Incorrect results!",
|
||||
double = 0,
|
||||
double atol = 0)
|
||||
@@ -188,9 +191,9 @@ check_err(const std::vector<T>& out,
|
||||
int64_t max_err = std::numeric_limits<int64_t>::min();
|
||||
for(std::size_t i = 0; i < ref.size(); ++i)
|
||||
{
|
||||
int64_t o = out[i];
|
||||
int64_t r = ref[i];
|
||||
err = std::abs(o - r);
|
||||
const int64_t o = *std::next(std::begin(out), i);
|
||||
const int64_t r = *std::next(std::begin(ref), i);
|
||||
err = std::abs(o - r);
|
||||
|
||||
if(err > atol)
|
||||
{
|
||||
|
||||
@@ -14,6 +14,9 @@
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/span.hpp"
|
||||
|
||||
#include "ck/library/utility/algorithm.hpp"
|
||||
#include "ck/library/utility/ranges.hpp"
|
||||
|
||||
template <typename Range>
|
||||
std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim)
|
||||
{
|
||||
@@ -84,10 +87,10 @@ struct HostTensorDescriptor
|
||||
this->CalculateStrides();
|
||||
}
|
||||
|
||||
template <typename Range,
|
||||
template <typename Lengths,
|
||||
typename = std::enable_if_t<
|
||||
std::is_convertible_v<decltype(*std::begin(std::declval<Range>())), std::size_t>>>
|
||||
HostTensorDescriptor(const Range& lens) : mLens(lens.begin(), lens.end())
|
||||
std::is_convertible_v<ck::ranges::range_value_t<Lengths>, std::size_t>>>
|
||||
HostTensorDescriptor(const Lengths& lens) : mLens(lens.begin(), lens.end())
|
||||
{
|
||||
this->CalculateStrides();
|
||||
}
|
||||
@@ -102,13 +105,12 @@ struct HostTensorDescriptor
|
||||
{
|
||||
}
|
||||
|
||||
template <
|
||||
typename Range1,
|
||||
typename Range2,
|
||||
typename = std::enable_if_t<
|
||||
std::is_convertible_v<decltype(*std::begin(std::declval<Range1>())), std::size_t> &&
|
||||
std::is_convertible_v<decltype(*std::begin(std::declval<Range2>())), std::size_t>>>
|
||||
HostTensorDescriptor(const Range1& lens, const Range2& strides)
|
||||
template <typename Lengths,
|
||||
typename Strides,
|
||||
typename = std::enable_if_t<
|
||||
std::is_convertible_v<ck::ranges::range_value_t<Lengths>, std::size_t> &&
|
||||
std::is_convertible_v<ck::ranges::range_value_t<Strides>, std::size_t>>>
|
||||
HostTensorDescriptor(const Lengths& lens, const Strides& strides)
|
||||
: mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
|
||||
{
|
||||
}
|
||||
@@ -244,14 +246,20 @@ struct Tensor
|
||||
{
|
||||
}
|
||||
|
||||
template <typename X>
|
||||
Tensor(std::vector<X> lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize())
|
||||
template <typename X, typename Y>
|
||||
Tensor(std::initializer_list<X> lens, std::initializer_list<Y> strides)
|
||||
: mDesc(lens, strides), mData(mDesc.GetElementSpaceSize())
|
||||
{
|
||||
}
|
||||
|
||||
template <typename X, typename Y>
|
||||
Tensor(std::vector<X> lens, std::vector<Y> strides)
|
||||
: mDesc(lens, strides), mData(mDesc.GetElementSpaceSize())
|
||||
template <typename Lengths>
|
||||
Tensor(const Lengths& lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize())
|
||||
{
|
||||
}
|
||||
|
||||
template <typename Lengths, typename Strides>
|
||||
Tensor(const Lengths& lens, const Strides& strides)
|
||||
: mDesc(lens, strides), mData(GetElementSpaceSize())
|
||||
{
|
||||
}
|
||||
|
||||
@@ -261,10 +269,10 @@ struct Tensor
|
||||
Tensor<OutT> CopyAsType() const
|
||||
{
|
||||
Tensor<OutT> ret(mDesc);
|
||||
for(size_t i = 0; i < mData.size(); i++)
|
||||
{
|
||||
ret.mData[i] = ck::type_convert<OutT>(mData[i]);
|
||||
}
|
||||
|
||||
ck::ranges::transform(
|
||||
mData, ret.mData.begin(), [](auto value) { return ck::type_convert<OutT>(value); });
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
@@ -294,13 +302,7 @@ struct Tensor
|
||||
|
||||
std::size_t GetElementSpaceSizeInBytes() const { return sizeof(T) * GetElementSpaceSize(); }
|
||||
|
||||
void SetZero()
|
||||
{
|
||||
for(auto& v : mData)
|
||||
{
|
||||
v = T{0};
|
||||
}
|
||||
}
|
||||
void SetZero() { ck::ranges::fill<T>(mData, 0); }
|
||||
|
||||
template <typename F>
|
||||
void ForEach_impl(F&& f, std::vector<size_t>& idx, size_t rank)
|
||||
|
||||
22
library/include/ck/library/utility/iterator.hpp
Normal file
22
library/include/ck/library/utility/iterator.hpp
Normal file
@@ -0,0 +1,22 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iterator>
|
||||
#include <utility>
|
||||
|
||||
#include "ck/utility/type.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename T>
|
||||
using iter_value_t = typename std::iterator_traits<remove_cvref_t<T>>::value_type;
|
||||
|
||||
template <typename T>
|
||||
using iter_reference_t = decltype(*std::declval<T&>());
|
||||
|
||||
template <typename T>
|
||||
using iter_difference_t = typename std::iterator_traits<remove_cvref_t<T>>::difference_type;
|
||||
|
||||
} // namespace ck
|
||||
60
library/include/ck/library/utility/ranges.hpp
Normal file
60
library/include/ck/library/utility/ranges.hpp
Normal file
@@ -0,0 +1,60 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iterator>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
#include "ck/library/utility/iterator.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace ranges {
|
||||
|
||||
template <typename R>
|
||||
using iterator_t = decltype(std::begin(std::declval<R&>()));
|
||||
|
||||
template <typename R>
|
||||
using sentinel_t = decltype(std::end(std::declval<R&>()));
|
||||
|
||||
template <typename R>
|
||||
using range_size_t = decltype(std::size(std::declval<R&>()));
|
||||
|
||||
template <typename R>
|
||||
using range_difference_t = ck::iter_difference_t<ranges::iterator_t<R>>;
|
||||
|
||||
template <typename R>
|
||||
using range_value_t = iter_value_t<ranges::iterator_t<R>>;
|
||||
|
||||
template <typename R>
|
||||
using range_reference_t = iter_reference_t<ranges::iterator_t<R>>;
|
||||
|
||||
template <typename T, typename = void>
|
||||
struct is_range : std::false_type
|
||||
{
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct is_range<
|
||||
T,
|
||||
std::void_t<decltype(std::begin(std::declval<T&>())), decltype(std::end(std::declval<T&>()))>>
|
||||
: std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline constexpr bool is_range_v = is_range<T>::value;
|
||||
|
||||
template <typename T, typename = void>
|
||||
struct is_sized_range : std::false_type
|
||||
{
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct is_sized_range<T, std::void_t<decltype(std::size(std::declval<T&>()))>>
|
||||
: std::bool_constant<is_range_v<T>>
|
||||
{
|
||||
};
|
||||
} // namespace ranges
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user