Pool2d max/avg kernel in the BWD version (#1494)

* Add pool2d instance BWD AVG

* Add pool2d instance BWD MAX

* Fix: avg review

* Fix review: part2

* Fix - enable test when type is compiled

* Fix review part3
This commit is contained in:
Mateusz Ozga
2024-09-12 11:47:52 +02:00
committed by GitHub
parent e8d2887cb2
commit 448c0f56d8
25 changed files with 2168 additions and 16 deletions

View File

@@ -355,12 +355,39 @@ struct UnaryDivide
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, int32_t>::value,
is_same<T, int32_t>::value || is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
y = x / type_convert<T>(divider_);
};
template <>
__host__ __device__ void operator()<half_t>(half_t& y, const half_t& x) const
{
float x_ = type_convert<float>(x);
float divider_f_ = type_convert<float>(divider_);
y = type_convert<half_t>(x_ / divider_f_);
};
template <>
__host__ __device__ void operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x) const
{
float x_ = type_convert<float>(x);
float divider_f_ = type_convert<float>(divider_);
y = type_convert<bhalf_t>(x_ / divider_f_);
};
template <>
__host__ __device__ void operator()<f8_t>(f8_t& y, const f8_t& x) const
{
float x_ = type_convert<float>(x);
float divider_f_ = type_convert<float>(divider_);
y = type_convert<f8_t>(x_ / divider_f_);
};
int32_t divider_ = 1;
};