Mirror of https://github.com/ROCm/composable_kernel.git
Add ScaleAddScaleAddRelu post op for conv fwd (#1006)
* Add ScaleAddScaleAddRelu post op for conv fwd
* Fixes
* Fix instance file name
* Minor fix
@@ -311,6 +311,71 @@ struct AddAddFastGelu
    }
};

// E = Relu(alpha1 * C + alpha2 * D0 + D1)
struct ScaleAddScaleAddRelu
{
    ScaleAddScaleAddRelu(const float alpha1 = 1.f, const float alpha2 = 1.f)
        : alpha1_(alpha1), alpha2_(alpha2)
    {
    }

    template <typename E, typename C, typename D0, typename D1>
    __host__ __device__ constexpr void
    operator()(E& e, const C& c, const D0& d0, const D1& d1) const;

    template <>
    __host__ __device__ constexpr void operator()<float, float, float, float>(float& e,
                                                                               const float& c,
                                                                               const float& d0,
                                                                               const float& d1) const
    {
        const float x = c * alpha1_ + alpha2_ * d0 + d1;
        Relu{}.template operator()<float>(e, x);
    }

    template <>
    __host__ __device__ constexpr void operator()<half_t, half_t, half_t, half_t>(
        half_t& e, const half_t& c, const half_t& d0, const half_t& d1) const
    {
        const float x = type_convert<float>(c) * alpha1_ + alpha2_ * type_convert<float>(d0) +
                        type_convert<float>(d1);

        float result = 0;
        Relu{}.template operator()<float>(result, x);

        e = type_convert<half_t>(result);
    }

    template <>
    __host__ __device__ constexpr void operator()<bhalf_t, bhalf_t, bhalf_t, bhalf_t>(
        bhalf_t& e, const bhalf_t& c, const bhalf_t& d0, const bhalf_t& d1) const
    {
        const float x = type_convert<float>(c) * alpha1_ + alpha2_ * type_convert<float>(d0) +
                        type_convert<float>(d1);

        float result = 0;
        Relu{}.template operator()<float>(result, x);

        e = type_convert<bhalf_t>(result);
    }

    template <>
    __host__ __device__ constexpr void operator()<int8_t, int8_t, float, float>(
        int8_t& e, const int8_t& c, const float& d0, const float& d1) const
    {
        const float x = type_convert<float>(c) * alpha1_ + alpha2_ * d0 + d1;

        float result = 0;
        Relu{}.template operator()<float>(result, x);

        e = type_convert<int8_t>(result);
    }

    const float alpha1_;
    const float alpha2_;
};

struct Normalize
{
    // FIXME: is double absolutely necessary?
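For illustration, here is a minimal host-side sketch of the float path of the new post op; it is not part of the commit. A hypothetical standalone Relu functor stands in for CK's element-wise Relu, the struct is renamed ScaleAddScaleAddReluSketch to avoid confusion with the real one, and the __host__ __device__ qualifiers are dropped so it compiles with any plain C++ compiler. The half_t/bhalf_t/int8_t paths would additionally round-trip through float via type_convert, as in the diff above.

// Illustration only: standalone re-implementation of the float path.
// "Relu" here is a hypothetical stand-in for CK's element-wise Relu functor.
#include <cstdio>

struct Relu
{
    template <typename T>
    void operator()(T& y, const T& x) const
    {
        y = x > 0 ? x : 0;
    }
};

struct ScaleAddScaleAddReluSketch
{
    ScaleAddScaleAddReluSketch(const float alpha1 = 1.f, const float alpha2 = 1.f)
        : alpha1_(alpha1), alpha2_(alpha2)
    {
    }

    // E = Relu(alpha1 * C + alpha2 * D0 + D1)
    void operator()(float& e, const float& c, const float& d0, const float& d1) const
    {
        const float x = c * alpha1_ + alpha2_ * d0 + d1;
        Relu{}.template operator()<float>(e, x);
    }

    const float alpha1_;
    const float alpha2_;
};

int main()
{
    const ScaleAddScaleAddReluSketch op{2.f, 0.5f};

    float e = 0.f;
    op(e, 1.f, 4.f, -1.f);  // x = 2*1 + 0.5*4 + (-1) = 3,  Relu(3)  = 3
    std::printf("e = %.1f\n", e);

    op(e, -3.f, 0.f, 1.f);  // x = 2*(-3) + 0.5*0 + 1 = -5, Relu(-5) = 0
    std::printf("e = %.1f\n", e);

    return 0;
}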