mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
Pool2d max/avg kernel in the BWD version (#1494)
* Add pool2d instance BWD AVG * Add pool2d instance BWD MAX * Fix: avg review * Fix review: part2 * Fix - enable test when type is compiled * Fix review part3
This commit is contained in:
@@ -355,12 +355,39 @@ struct UnaryDivide
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
|
||||
is_same<T, int32_t>::value,
|
||||
is_same<T, int32_t>::value || is_same<T, int8_t>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
y = x / type_convert<T>(divider_);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<half_t>(half_t& y, const half_t& x) const
|
||||
{
|
||||
float x_ = type_convert<float>(x);
|
||||
float divider_f_ = type_convert<float>(divider_);
|
||||
|
||||
y = type_convert<half_t>(x_ / divider_f_);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x) const
|
||||
{
|
||||
float x_ = type_convert<float>(x);
|
||||
float divider_f_ = type_convert<float>(divider_);
|
||||
|
||||
y = type_convert<bhalf_t>(x_ / divider_f_);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<f8_t>(f8_t& y, const f8_t& x) const
|
||||
{
|
||||
float x_ = type_convert<float>(x);
|
||||
float divider_f_ = type_convert<float>(divider_);
|
||||
|
||||
y = type_convert<f8_t>(x_ / divider_f_);
|
||||
};
|
||||
|
||||
int32_t divider_ = 1;
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user