mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 12:59:49 +00:00
Add bilinear conv fwd and bwd data instances (#1164)
[ROCm/composable_kernel commit: bf98b47697]
This commit is contained in:
@@ -165,7 +165,7 @@ struct Subtract
|
||||
|
||||
struct Bilinear
|
||||
{
|
||||
Bilinear(float alpha, float beta) : alpha_(alpha), beta_(beta){};
|
||||
Bilinear(float alpha = 1.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
|
||||
|
||||
template <typename Y, typename X0, typename X1>
|
||||
__host__ __device__ constexpr void operator()(Y&, const X0&, const X1&) const;
|
||||
@@ -184,6 +184,14 @@ struct Bilinear
|
||||
y = alpha_ * x0 + beta_ * x1;
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<int8_t, int8_t, int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
|
||||
{
|
||||
y = type_convert<int8_t>(alpha_ * type_convert<float>(x0) +
|
||||
beta_ * type_convert<float>(x1));
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<half_t, half_t, half_t>(half_t& y, const half_t& x0, const half_t& x1) const
|
||||
@@ -221,7 +229,8 @@ struct Bilinear
|
||||
__host__ __device__ constexpr void operator()<std::int8_t, std::int32_t, std::int8_t>(
|
||||
std::int8_t& y, const std::int32_t& x0, const std::int8_t& x1) const
|
||||
{
|
||||
y = type_convert<std::int8_t>(x0 + ck::type_convert<std::int32_t>(x1));
|
||||
y = type_convert<int8_t>(alpha_ * type_convert<float>(x0) +
|
||||
beta_ * type_convert<float>(x1));
|
||||
};
|
||||
|
||||
float alpha_;
|
||||
|
||||
Reference in New Issue
Block a user