Add ScaleAddScaleAddRelu post op for conv fwd (#1006)

* Add ScaleAddScaleAddRelu post op for conv fwd

* Fixes

* Fix instance file name

* Minor fix
This commit is contained in:
Bartłomiej Kocot
2023-11-02 00:31:30 +01:00
committed by GitHub
parent 306fd506b1
commit f27ea94ecb
18 changed files with 1235 additions and 9 deletions

View File

@@ -311,6 +311,71 @@ struct AddAddFastGelu
}
};
// Fused element-wise post-op: E = Relu(alpha1 * C + alpha2 * D0 + D1).
// The two scale factors are fixed at construction. Every specialization below
// evaluates the weighted sum in float; non-float outputs are produced by
// converting the Relu result afterwards.
struct ScaleAddScaleAddRelu
{
    // Both scales default to 1, which reduces the op to Relu(C + D0 + D1).
    ScaleAddScaleAddRelu(const float alpha1 = 1.f, const float alpha2 = 1.f)
        : alpha1_(alpha1), alpha2_(alpha2)
    {
    }

    // Generic call operator: declared only; the type combinations handled in
    // this struct are the explicit specializations that follow.
    template <typename E, typename C, typename D0, typename D1>
    __host__ __device__ constexpr void
    operator()(E& e, const C& c, const D0& d0, const D1& d1) const;

    // float inputs/output: no conversions, Relu writes directly into e.
    template <>
    __host__ __device__ constexpr void operator()<float, float, float, float>(float& e,
                                                                              const float& c,
                                                                              const float& d0,
                                                                              const float& d1) const
    {
        const float pre_activation = c * alpha1_ + alpha2_ * d0 + d1;
        Relu{}.template operator()<float>(e, pre_activation);
    }

    // half_t inputs/output: widen to float, apply the op, narrow back to half_t.
    template <>
    __host__ __device__ constexpr void operator()<half_t, half_t, half_t, half_t>(
        half_t& e, const half_t& c, const half_t& d0, const half_t& d1) const
    {
        const float c_f32  = type_convert<float>(c);
        const float d0_f32 = type_convert<float>(d0);
        const float d1_f32 = type_convert<float>(d1);

        const float pre_activation = c_f32 * alpha1_ + alpha2_ * d0_f32 + d1_f32;

        float activated = 0;
        Relu{}.template operator()<float>(activated, pre_activation);
        e = type_convert<half_t>(activated);
    }

    // bhalf_t inputs/output: widen to float, apply the op, narrow back to bhalf_t.
    template <>
    __host__ __device__ constexpr void operator()<bhalf_t, bhalf_t, bhalf_t, bhalf_t>(
        bhalf_t& e, const bhalf_t& c, const bhalf_t& d0, const bhalf_t& d1) const
    {
        const float c_f32  = type_convert<float>(c);
        const float d0_f32 = type_convert<float>(d0);
        const float d1_f32 = type_convert<float>(d1);

        const float pre_activation = c_f32 * alpha1_ + alpha2_ * d0_f32 + d1_f32;

        float activated = 0;
        Relu{}.template operator()<float>(activated, pre_activation);
        e = type_convert<bhalf_t>(activated);
    }

    // int8_t C and E with float D0/D1: only c needs widening; the float Relu
    // result is converted to int8_t via type_convert.
    template <>
    __host__ __device__ constexpr void operator()<int8_t, int8_t, float, float>(
        int8_t& e, const int8_t& c, const float& d0, const float& d1) const
    {
        const float c_f32 = type_convert<float>(c);

        const float pre_activation = c_f32 * alpha1_ + alpha2_ * d0 + d1;

        float activated = 0;
        Relu{}.template operator()<float>(activated, pre_activation);
        e = type_convert<int8_t>(activated);
    }

    // Scale applied to C.
    const float alpha1_;
    // Scale applied to D0.
    const float alpha2_;
};
struct Normalize
{
// FIXME: is double absolutely necessary?