Add Clamp/Relu bf16/fp16 cast fixes (#2279)

* Add Clamp/Relu bf16/fp16 fixes

* fix
This commit is contained in:
Bartłomiej Kocot
2025-06-03 18:31:46 +02:00
committed by GitHub
parent 7f9eef40b0
commit 6e5acee0f9
2 changed files with 25 additions and 18 deletions

View File

@@ -389,10 +389,9 @@ struct AddClamp
__host__ __device__ constexpr void
operator()<half_t, float, half_t>(half_t& y, const float& x0, const half_t& x1) const
{
const float a = x0 + x1;
y = a > type_convert<half_t>(floor_)
? (a < type_convert<half_t>(ceil_) ? a : type_convert<half_t>(ceil_))
: type_convert<half_t>(floor_);
const float a = x0 + type_convert<float>(x1);
const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
y = type_convert<half_t>(b);
};
template <>
@@ -408,9 +407,8 @@ struct AddClamp
operator()<bhalf_t, float, bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
{
const float a = x0 + type_convert<float>(x1);
y = a > type_convert<bhalf_t>(floor_)
? (a < type_convert<bhalf_t>(ceil_) ? a : type_convert<bhalf_t>(ceil_))
: type_convert<bhalf_t>(floor_);
const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
y = type_convert<bhalf_t>(b);
};
template <>
@@ -418,9 +416,8 @@ struct AddClamp
operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
{
const float a = type_convert<float>(x0) + type_convert<float>(x1);
y = a > type_convert<bhalf_t>(floor_)
? (a < type_convert<bhalf_t>(ceil_) ? a : type_convert<bhalf_t>(ceil_))
: type_convert<bhalf_t>(floor_);
const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
y = type_convert<bhalf_t>(b);
};
template <>
@@ -476,8 +473,9 @@ struct AddRelu
__host__ __device__ constexpr void
operator()<half_t, float, half_t>(half_t& y, const float& x0, const half_t& x1) const
{
const float a = x0 + x1;
y = a > type_convert<half_t>(0.0f) ? a : type_convert<half_t>(0.0f);
const float a = x0 + type_convert<float>(x1);
const float b = a > 0.0f ? a : 0.0f;
y = type_convert<half_t>(b);
};
template <>
@@ -493,7 +491,8 @@ struct AddRelu
operator()<bhalf_t, float, bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
{
const float a = x0 + type_convert<float>(x1);
y = a > type_convert<bhalf_t>(0.0f) ? a : type_convert<bhalf_t>(0.0f);
const float b = a > 0.0f ? a : 0.0f;
y = type_convert<bhalf_t>(b);
};
template <>
@@ -501,7 +500,8 @@ struct AddRelu
operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
{
const float a = type_convert<float>(x0) + type_convert<float>(x1);
y = a > type_convert<bhalf_t>(0.0f) ? a : type_convert<bhalf_t>(0.0f);
const float b = a > 0.0f ? a : 0.0f;
y = type_convert<bhalf_t>(b);
};
template <>