mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Add Clamp/Relu bf16/fp16 cast fixes (#2279)
* Add Clamp/Relu bf16/fp16 fixes * fix
This commit is contained in:
@@ -389,10 +389,9 @@ struct AddClamp
|
||||
__host__ __device__ constexpr void
|
||||
operator()<half_t, float, half_t>(half_t& y, const float& x0, const half_t& x1) const
|
||||
{
|
||||
const float a = x0 + x1;
|
||||
y = a > type_convert<half_t>(floor_)
|
||||
? (a < type_convert<half_t>(ceil_) ? a : type_convert<half_t>(ceil_))
|
||||
: type_convert<half_t>(floor_);
|
||||
const float a = x0 + type_convert<float>(x1);
|
||||
const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
|
||||
y = type_convert<half_t>(b);
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -408,9 +407,8 @@ struct AddClamp
|
||||
operator()<bhalf_t, float, bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
|
||||
{
|
||||
const float a = x0 + type_convert<float>(x1);
|
||||
y = a > type_convert<bhalf_t>(floor_)
|
||||
? (a < type_convert<bhalf_t>(ceil_) ? a : type_convert<bhalf_t>(ceil_))
|
||||
: type_convert<bhalf_t>(floor_);
|
||||
const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
|
||||
y = type_convert<bhalf_t>(b);
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -418,9 +416,8 @@ struct AddClamp
|
||||
operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
|
||||
{
|
||||
const float a = type_convert<float>(x0) + type_convert<float>(x1);
|
||||
y = a > type_convert<bhalf_t>(floor_)
|
||||
? (a < type_convert<bhalf_t>(ceil_) ? a : type_convert<bhalf_t>(ceil_))
|
||||
: type_convert<bhalf_t>(floor_);
|
||||
const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
|
||||
y = type_convert<bhalf_t>(b);
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -476,8 +473,9 @@ struct AddRelu
|
||||
__host__ __device__ constexpr void
|
||||
operator()<half_t, float, half_t>(half_t& y, const float& x0, const half_t& x1) const
|
||||
{
|
||||
const float a = x0 + x1;
|
||||
y = a > type_convert<half_t>(0.0f) ? a : type_convert<half_t>(0.0f);
|
||||
const float a = x0 + type_convert<float>(x1);
|
||||
const float b = a > 0.0f ? a : 0.0f;
|
||||
y = type_convert<half_t>(b);
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -493,7 +491,8 @@ struct AddRelu
|
||||
operator()<bhalf_t, float, bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
|
||||
{
|
||||
const float a = x0 + type_convert<float>(x1);
|
||||
y = a > type_convert<bhalf_t>(0.0f) ? a : type_convert<bhalf_t>(0.0f);
|
||||
const float b = a > 0.0f ? a : 0.0f;
|
||||
y = type_convert<bhalf_t>(b);
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -501,7 +500,8 @@ struct AddRelu
|
||||
operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
|
||||
{
|
||||
const float a = type_convert<float>(x0) + type_convert<float>(x1);
|
||||
y = a > type_convert<bhalf_t>(0.0f) ? a : type_convert<bhalf_t>(0.0f);
|
||||
const float b = a > 0.0f ? a : 0.0f;
|
||||
y = type_convert<bhalf_t>(b);
|
||||
};
|
||||
|
||||
template <>
|
||||
|
||||
Reference in New Issue
Block a user