mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
Regulate reduction accumulator operations and Element-wise operations (#274)
* Remove template from Reduction operation classes and add template to their operator() and GetIdentityValue() interfaces * Change to unary elementwise operators and the reduce_unary_operator (class for mapping) and dependent variations in all host layers * Remove the data type template parameter from reduce_binary_operator (class for mapping) and dependent variations in host layers * Add InMemoryDataOperatonSupportedOnDataType to check the matching between data type and InMemoryDataOperation * Use struct-scope operator template instantiation for binary and unary element-wise operations * Change a few more elementwise operations to use template for operator() * Tiny correction in Normalize operator * Add static_assert to check the data type applicability for some reduction accumulator and element-wise operations * Correction in some examples with regard to using ReduceAccDataType * Use static_assert for UnaryDivide * Update to merged codes to use Element-wise operations and Reduction Accumulator operations correctly * Tiny fix with regard to SetWorkSpacePointer()
This commit is contained in:
@@ -28,100 +28,189 @@
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace binary_element_wise {
|
||||
|
||||
template <typename Y, typename X1, typename X2>
|
||||
struct Add;
|
||||
namespace element_wise {
|
||||
|
||||
template <>
|
||||
struct Add<double, double, double>
|
||||
struct Add
|
||||
{
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()(double& dst, const double& src1, const double& src2) const
|
||||
operator()<float>(float& y, const float& x0, const float& x1) const
|
||||
{
|
||||
dst = src1 + src2;
|
||||
y = x0 + x1;
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<double>(double& y, const double& x0, const double& x1) const
|
||||
{
|
||||
y = x0 + x1;
|
||||
};
|
||||
|
||||
// Question: should half_t be supported ?
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
|
||||
{
|
||||
y = x0 + x1;
|
||||
};
|
||||
|
||||
// Question: should bhalf_t be supported ?
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
|
||||
{
|
||||
const float x1_tmp = ck::type_convert<float>(x0);
|
||||
const float x2_tmp = ck::type_convert<float>(x1);
|
||||
const float y_tmp = x1_tmp + x2_tmp;
|
||||
y = ck::type_convert<bhalf_t>(y_tmp);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Add<float, float, float>
|
||||
struct Subtract
|
||||
{
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()(float& dst, const float& src1, const float& src2) const
|
||||
operator()<float>(float& y, const float& x0, const float& x1) const
|
||||
{
|
||||
dst = src1 + src2;
|
||||
y = x0 - x1;
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<double>(double& y, const double& x0, const double& x1) const
|
||||
{
|
||||
y = x0 - x1;
|
||||
};
|
||||
|
||||
// Question: should half_t be supported ?
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
|
||||
{
|
||||
y = x0 - x1;
|
||||
};
|
||||
|
||||
// Question: should bhalf_t be supported ?
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
|
||||
{
|
||||
const float x1_tmp = ck::type_convert<float>(x0);
|
||||
const float x2_tmp = ck::type_convert<float>(x1);
|
||||
const float y_tmp = x1_tmp - x2_tmp;
|
||||
y = ck::type_convert<bhalf_t>(y_tmp);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Add<half_t, half_t, half_t>
|
||||
struct AlphaBetaAdd
|
||||
{
|
||||
AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){};
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()(half_t& dst, const half_t& src1, const half_t& src2) const
|
||||
operator()<float>(float& y, const float& x0, const float& x1) const
|
||||
{
|
||||
dst = src1 + src2;
|
||||
}
|
||||
y = alpha_ * x0 + beta_ * x1;
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<double>(double& y, const double& x0, const double& x1) const
|
||||
{
|
||||
y = static_cast<double>(alpha_) * x0 + static_cast<double>(beta_) * x1;
|
||||
};
|
||||
|
||||
// Question: should half_t be supported ?
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
|
||||
{
|
||||
y = static_cast<half_t>(alpha_ * static_cast<float>(x0) + beta_ * static_cast<float>(x1));
|
||||
};
|
||||
|
||||
float alpha_;
|
||||
float beta_;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Add<bhalf_t, bhalf_t, bhalf_t>
|
||||
struct AddRelu
|
||||
{
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()(bhalf_t& dst, const bhalf_t& src1, const bhalf_t& src2) const
|
||||
operator()<float>(float& y, const float& x0, const float& x1) const
|
||||
{
|
||||
const float x1 = ck::type_convert<float>(src1);
|
||||
const float x2 = ck::type_convert<float>(src2);
|
||||
const float y = x1 + x2;
|
||||
dst = ck::type_convert<bhalf_t>(y);
|
||||
}
|
||||
const float a = x0 + x1;
|
||||
y = a > 0.0f ? a : 0.0f;
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<double>(double& y, const double& x0, const double& x1) const
|
||||
{
|
||||
const double a = x0 + x1;
|
||||
y = a > 0.0 ? a : 0.0;
|
||||
};
|
||||
|
||||
// Question: should half_t be supported ?
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
|
||||
{
|
||||
const half_t a = x0 + x1;
|
||||
y = a > static_cast<half_t>(0.0f) ? a : static_cast<half_t>(0.0f);
|
||||
};
|
||||
};
|
||||
|
||||
template <typename Y, typename X1, typename X2>
|
||||
struct Substract;
|
||||
|
||||
template <>
|
||||
struct Substract<double, double, double>
|
||||
struct AddHardswish
|
||||
{
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()(double& dst, const double& src1, const double& src2) const
|
||||
operator()<float>(float& y, const float& x0, const float& x1) const
|
||||
{
|
||||
dst = src1 - src2;
|
||||
}
|
||||
float a = x0 + x1;
|
||||
float b = a + float{3};
|
||||
float c = (b > 0) * (b > 6.0f ? 6.0f : b) * a * 0.166667f;
|
||||
y = c;
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<double>(double& y, const double& x0, const double& x1) const
|
||||
{
|
||||
double a = x0 + x1;
|
||||
double b = a + 3.0;
|
||||
double c = (b > 0) * (b > 6.0 ? 6.0 : b) * a * 0.166667;
|
||||
y = c;
|
||||
};
|
||||
|
||||
// Question: should half_t be supported ?
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
|
||||
{
|
||||
float a = x0 + x1;
|
||||
float b = a + 3.0f;
|
||||
float c = (b > 0) * (b > 6.0f ? 6.0f : b) * a * 0.166667f;
|
||||
y = c;
|
||||
};
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Substract<float, float, float>
|
||||
{
|
||||
__host__ __device__ constexpr void
|
||||
operator()(float& dst, const float& src1, const float& src2) const
|
||||
{
|
||||
dst = src1 - src2;
|
||||
}
|
||||
};
|
||||
} // namespace element_wise
|
||||
|
||||
template <>
|
||||
struct Substract<half_t, half_t, half_t>
|
||||
{
|
||||
__host__ __device__ constexpr void
|
||||
operator()(half_t& dst, const half_t& src1, const half_t& src2) const
|
||||
{
|
||||
dst = src1 - src2;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Substract<bhalf_t, bhalf_t, bhalf_t>
|
||||
{
|
||||
__host__ __device__ constexpr void
|
||||
operator()(bhalf_t& dst, const bhalf_t& src1, const bhalf_t& src2) const
|
||||
{
|
||||
const float x1 = ck::type_convert<float>(src1);
|
||||
const float x2 = ck::type_convert<float>(src2);
|
||||
const float y = x1 - x2;
|
||||
dst = ck::type_convert<bhalf_t>(y);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace binary_element_wise
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
|
||||
@@ -1,97 +1,13 @@
|
||||
#pragma once
|
||||
#include "data_type.hpp"
|
||||
#include "math_v2.hpp"
|
||||
#include "unary_element_wise_operation.hpp"
|
||||
#include "binary_element_wise_operation.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace element_wise {
|
||||
|
||||
struct PassThrough
|
||||
{
|
||||
__host__ __device__ void operator()(float& y, const float& x) const { y = x; }
|
||||
|
||||
__host__ __device__ void operator()(half_t& y, const half_t& x) const { y = x; }
|
||||
|
||||
__host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const { y = x; }
|
||||
|
||||
__host__ __device__ void operator()(int32_t& y, const int32_t& x) const { y = x; }
|
||||
|
||||
__host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = x; }
|
||||
|
||||
__host__ __device__ void operator()(double& y, const double& x) const { y = x; }
|
||||
};
|
||||
|
||||
struct Add
|
||||
{
|
||||
__host__ __device__ constexpr void operator()(float& y, const float& x0, const float& x1) const
|
||||
{
|
||||
y = x0 + x1;
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr void
|
||||
operator()(half_t& y, const half_t& x0, const half_t& x1) const
|
||||
{
|
||||
// FIXME - Use float (acc type) bias in the future.
|
||||
y = x0 + x1;
|
||||
}
|
||||
};
|
||||
|
||||
struct AlphaBetaAdd
|
||||
{
|
||||
AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta) {}
|
||||
|
||||
__host__ __device__ constexpr void operator()(float& y, const float& x0, const float& x1) const
|
||||
{
|
||||
y = alpha_ * x0 + beta_ * x1;
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr void
|
||||
operator()(half_t& y, const half_t& x0, const half_t& x1) const
|
||||
{
|
||||
// FIXME - Let x0 be acc type
|
||||
y = static_cast<half_t>(alpha_ * static_cast<float>(x0) + beta_ * static_cast<float>(x1));
|
||||
}
|
||||
|
||||
float alpha_;
|
||||
float beta_;
|
||||
};
|
||||
|
||||
struct AddRelu
|
||||
{
|
||||
__host__ __device__ constexpr void operator()(float& y, const float& x0, const float& x1) const
|
||||
{
|
||||
const float a = x0 + x1;
|
||||
y = a > 0 ? a : 0;
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr void
|
||||
operator()(half_t& y, const half_t& x0, const half_t& x1) const
|
||||
{
|
||||
const half_t a = x0 + x1;
|
||||
y = a > 0 ? a : 0;
|
||||
}
|
||||
};
|
||||
|
||||
struct AddHardswish
|
||||
{
|
||||
__host__ __device__ constexpr void operator()(float& y, const float& x0, const float& x1) const
|
||||
{
|
||||
float a = x0 + x1;
|
||||
float b = a + float{3};
|
||||
float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667};
|
||||
y = c;
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr void
|
||||
operator()(half_t& y, const half_t& x0, const half_t& x1) const
|
||||
{
|
||||
float a = x0 + x1;
|
||||
float b = a + float{3};
|
||||
float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667};
|
||||
y = c;
|
||||
}
|
||||
};
|
||||
|
||||
struct AddReluAdd
|
||||
{
|
||||
__host__ __device__ constexpr void
|
||||
@@ -167,204 +83,41 @@ struct Relu
|
||||
|
||||
struct Normalize
|
||||
{
|
||||
Normalize(float epsilon = 1e-4) : epsilon_(epsilon) {}
|
||||
Normalize(double epsilon = 1e-4) : epsilon_(epsilon) {}
|
||||
|
||||
__host__ __device__ constexpr void operator()(float& y,
|
||||
const float& x,
|
||||
const float& mean,
|
||||
const float& mean_square,
|
||||
const float& gamma,
|
||||
const float& beta) const
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr void operator()(
|
||||
T& y, const T& x, const T& mean, const T& mean_square, const T& gamma, const T& beta) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<float>(float& y,
|
||||
const float& x,
|
||||
const float& mean,
|
||||
const float& mean_square,
|
||||
const float& gamma,
|
||||
const float& beta) const
|
||||
{
|
||||
using ck::math::sqrt;
|
||||
|
||||
float variance = mean_square - (mean * mean);
|
||||
y = ((x - mean) / sqrtf(variance + epsilon_)) * gamma + beta;
|
||||
}
|
||||
|
||||
float epsilon_;
|
||||
};
|
||||
|
||||
// Unary operators are usually called element-wisely before/after the reduction is executed on the
|
||||
// elements. They are needed for easy implementation of reduction types of AVG, NRM1, NRM2
|
||||
|
||||
template <typename Y, typename X, bool HasDividing = false>
|
||||
struct UnaryIdentic;
|
||||
|
||||
template <>
|
||||
struct UnaryIdentic<float, float, false>
|
||||
{
|
||||
__host__ __device__ UnaryIdentic(const int32_t divider = 1) { (void)divider; };
|
||||
|
||||
__host__ __device__ void operator()(float& y, const float& x) const { y = x; };
|
||||
};
|
||||
|
||||
template <>
|
||||
struct UnaryIdentic<float, float, true>
|
||||
{
|
||||
__host__ __device__ UnaryIdentic(const int32_t divider = 1) { divider_ = divider; };
|
||||
|
||||
__host__ __device__ void operator()(float& y, const float& x) const
|
||||
{
|
||||
y = x / type_convert<float>(divider_);
|
||||
y = ((x - mean) / sqrt(variance + static_cast<float>(epsilon_))) * gamma + beta;
|
||||
};
|
||||
|
||||
int32_t divider_ = 1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct UnaryIdentic<half_t, half_t, false>
|
||||
{
|
||||
__host__ __device__ UnaryIdentic(const int32_t divider = 1) { (void)divider; };
|
||||
|
||||
__host__ __device__ void operator()(half_t& y, const half_t& x) const { y = x; };
|
||||
};
|
||||
|
||||
template <>
|
||||
struct UnaryIdentic<double, double, false>
|
||||
{
|
||||
__host__ __device__ UnaryIdentic(const int32_t divider = 1) { (void)divider; };
|
||||
|
||||
__host__ __device__ void operator()(double& y, const double& x) const { y = x; };
|
||||
};
|
||||
|
||||
template <>
|
||||
struct UnaryIdentic<double, double, true>
|
||||
{
|
||||
__host__ __device__ UnaryIdentic(const int32_t divider = 1) { divider_ = divider; };
|
||||
|
||||
__host__ __device__ void operator()(double& y, const double& x) const
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<double>(double& y,
|
||||
const double& x,
|
||||
const double& mean,
|
||||
const double& mean_square,
|
||||
const double& gamma,
|
||||
const double& beta) const
|
||||
{
|
||||
y = x / type_convert<double>(divider_);
|
||||
using ck::math::sqrt;
|
||||
|
||||
double variance = mean_square - (mean * mean);
|
||||
y = ((x - mean) / sqrt(variance + epsilon_)) * gamma + beta;
|
||||
};
|
||||
|
||||
int32_t divider_ = 1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct UnaryIdentic<int32_t, int32_t, false>
|
||||
{
|
||||
__host__ __device__ UnaryIdentic(const int32_t divider = 1) { (void)divider; };
|
||||
|
||||
__host__ __device__ void operator()(int32_t& y, const int32_t& x) const { y = x; };
|
||||
};
|
||||
|
||||
template <>
|
||||
struct UnaryIdentic<int32_t, int32_t, true>
|
||||
{
|
||||
__host__ __device__ UnaryIdentic(const int32_t divider = 1) { divider_ = divider; };
|
||||
|
||||
__host__ __device__ void operator()(int32_t& y, const int32_t& x) const { y = x / divider_; };
|
||||
|
||||
int32_t divider_ = 1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct UnaryIdentic<int8_t, int8_t, false>
|
||||
{
|
||||
__host__ __device__ UnaryIdentic(const int8_t divider = 1) { (void)divider; };
|
||||
|
||||
__host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = x; };
|
||||
};
|
||||
|
||||
template <typename Y, typename X, bool HasDividing = false>
|
||||
struct UnarySquare;
|
||||
|
||||
template <>
|
||||
struct UnarySquare<float, float, false>
|
||||
{
|
||||
__host__ __device__ UnarySquare(const int32_t divider = 1) { (void)divider; };
|
||||
|
||||
__host__ __device__ void operator()(float& y, const float& x) const { y = x * x; };
|
||||
};
|
||||
|
||||
template <>
|
||||
struct UnarySquare<float, float, true>
|
||||
{
|
||||
__host__ __device__ UnarySquare(const int32_t divider = 1) { divider_ = divider; };
|
||||
|
||||
__host__ __device__ void operator()(float& y, const float& x) const
|
||||
{
|
||||
y = x * x / type_convert<float>(divider_);
|
||||
};
|
||||
|
||||
int32_t divider_ = 1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct UnarySquare<double, double, false>
|
||||
{
|
||||
__host__ __device__ UnarySquare(const int32_t divider = 1) { (void)divider; };
|
||||
|
||||
__host__ __device__ void operator()(double& y, const double& x) const { y = x * x; };
|
||||
};
|
||||
|
||||
template <>
|
||||
struct UnarySquare<double, double, true>
|
||||
{
|
||||
__host__ __device__ UnarySquare(const int32_t divider = 1) { divider_ = divider; };
|
||||
|
||||
__host__ __device__ void operator()(double& y, const double& x) const
|
||||
{
|
||||
y = x * x / type_convert<double>(divider_);
|
||||
};
|
||||
|
||||
int32_t divider_ = 1;
|
||||
};
|
||||
|
||||
template <typename Y, typename X>
|
||||
struct UnaryAbs;
|
||||
|
||||
template <>
|
||||
struct UnaryAbs<float, float>
|
||||
{
|
||||
__host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
|
||||
|
||||
__host__ __device__ void operator()(float& y, const float& x) const { y = ck::math::abs(x); };
|
||||
};
|
||||
|
||||
template <>
|
||||
struct UnaryAbs<half_t, half_t>
|
||||
{
|
||||
__host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
|
||||
|
||||
__host__ __device__ void operator()(half_t& y, const half_t& x) const { y = ck::math::abs(x); };
|
||||
};
|
||||
|
||||
template <>
|
||||
struct UnaryAbs<double, double>
|
||||
{
|
||||
__host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
|
||||
|
||||
__host__ __device__ void operator()(double& y, const double& x) const { y = ck::math::abs(x); };
|
||||
};
|
||||
|
||||
template <>
|
||||
struct UnaryAbs<int8_t, int8_t>
|
||||
{
|
||||
__host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
|
||||
|
||||
__host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = ck::math::abs(x); };
|
||||
};
|
||||
|
||||
template <typename Y, typename X>
|
||||
struct UnarySqrt;
|
||||
|
||||
template <>
|
||||
struct UnarySqrt<float, float>
|
||||
{
|
||||
__host__ __device__ UnarySqrt(const int32_t divider = 1) { (void)divider; };
|
||||
|
||||
__host__ __device__ void operator()(float& y, const float& x) const { y = ck::math::sqrt(x); };
|
||||
};
|
||||
|
||||
template <>
|
||||
struct UnarySqrt<double, double>
|
||||
{
|
||||
__host__ __device__ UnarySqrt(const int32_t divider = 1) { (void)divider; };
|
||||
|
||||
__host__ __device__ void operator()(double& y, const double& x) const
|
||||
{
|
||||
y = ck::math::sqrt(x);
|
||||
};
|
||||
double epsilon_;
|
||||
};
|
||||
|
||||
template <typename Y, typename X>
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
#pragma once
|
||||
#include "data_type.hpp"
|
||||
#include "math_v2.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace element_wise {
|
||||
|
||||
struct PassThrough
|
||||
{
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
|
||||
is_same<T, half_t>::value || is_same<T, bhalf_t>::value ||
|
||||
is_same<T, int32_t>::value || is_same<T, int8_t>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
y = x;
|
||||
};
|
||||
};
|
||||
|
||||
struct UnaryDivide
|
||||
{
|
||||
__host__ __device__ UnaryDivide(const int32_t divider = 1) : divider_(divider){};
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
|
||||
is_same<T, int32_t>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
y = x / type_convert<T>(divider_);
|
||||
};
|
||||
|
||||
int32_t divider_ = 1;
|
||||
};
|
||||
|
||||
struct UnarySquare
|
||||
{
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
static_assert(is_same<T, float>::value || is_same<T, double>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
y = x * x;
|
||||
};
|
||||
};
|
||||
|
||||
struct UnaryAbs
|
||||
{
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
|
||||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
|
||||
is_same<T, int8_t>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
y = ck::math::abs(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct UnarySqrt
|
||||
{
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
static_assert(is_same<T, float>::value || is_same<T, double>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
y = ck::math::sqrt(x);
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace element_wise
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user