mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Change relu to clamp for grouped conv fwd instances (#2249)
This commit is contained in:
@@ -351,6 +351,98 @@ struct Bilinear
|
||||
float beta_;
|
||||
};
|
||||
|
||||
struct AddClamp
|
||||
{
|
||||
AddClamp(float floor = 0.f, float ceil = NumericLimits<float>::Max())
|
||||
: floor_(floor), ceil_(ceil){};
|
||||
|
||||
template <typename Y, typename X0, typename X1>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<float, float, float>(float& y, const float& x0, const float& x1) const
|
||||
{
|
||||
const float a = x0 + x1;
|
||||
y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<double, double, double>(double& y, const double& x0, const double& x1) const
|
||||
{
|
||||
const double a = x0 + x1;
|
||||
y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<half_t, half_t, half_t>(half_t& y, const half_t& x0, const half_t& x1) const
|
||||
{
|
||||
const half_t a = x0 + x1;
|
||||
y = a > type_convert<half_t>(floor_)
|
||||
? (a < type_convert<half_t>(ceil_) ? a : type_convert<half_t>(ceil_))
|
||||
: type_convert<half_t>(floor_);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<half_t, float, half_t>(half_t& y, const float& x0, const half_t& x1) const
|
||||
{
|
||||
const float a = x0 + x1;
|
||||
y = a > type_convert<half_t>(floor_)
|
||||
? (a < type_convert<half_t>(ceil_) ? a : type_convert<half_t>(ceil_))
|
||||
: type_convert<half_t>(floor_);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<float, float, half_t>(float& y, const float& x0, const half_t& x1) const
|
||||
{
|
||||
const float a = x0 + type_convert<float>(x1);
|
||||
y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<bhalf_t, float, bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
|
||||
{
|
||||
const float a = x0 + type_convert<float>(x1);
|
||||
y = a > type_convert<bhalf_t>(floor_)
|
||||
? (a < type_convert<bhalf_t>(ceil_) ? a : type_convert<bhalf_t>(ceil_))
|
||||
: type_convert<bhalf_t>(floor_);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
|
||||
{
|
||||
const float a = type_convert<float>(x0) + type_convert<float>(x1);
|
||||
y = a > type_convert<bhalf_t>(floor_)
|
||||
? (a < type_convert<bhalf_t>(ceil_) ? a : type_convert<bhalf_t>(ceil_))
|
||||
: type_convert<bhalf_t>(floor_);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<int, int, int8_t>(int& y, const int& x0, const int8_t& x1) const
|
||||
{
|
||||
const int8_t a = x0 + x1;
|
||||
y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<int8_t, int8_t, int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
|
||||
{
|
||||
const int8_t a = x0 + x1;
|
||||
y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
|
||||
};
|
||||
|
||||
const float floor_;
|
||||
const float ceil_;
|
||||
};
|
||||
|
||||
struct AddRelu
|
||||
{
|
||||
template <typename Y, typename X0, typename X1>
|
||||
|
||||
Reference in New Issue
Block a user