mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
Gemm+Bilinear (#316)
* refactor * update example * update example * gemm bilinear * clean * update
This commit is contained in:
@@ -35,7 +35,6 @@ struct Add
|
||||
y = type_convert<half_t>(x0) + x1;
|
||||
};
|
||||
|
||||
// Question: should half_t be supported ?
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
|
||||
@@ -43,7 +42,6 @@ struct Add
|
||||
y = x0 + x1;
|
||||
};
|
||||
|
||||
// Question: should bhalf_t be supported ?
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
|
||||
@@ -74,7 +72,6 @@ struct Subtract
|
||||
y = x0 - x1;
|
||||
};
|
||||
|
||||
// Question: should half_t be supported ?
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
|
||||
@@ -82,7 +79,6 @@ struct Subtract
|
||||
y = x0 - x1;
|
||||
};
|
||||
|
||||
// Question: should bhalf_t be supported ?
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
|
||||
@@ -94,33 +90,25 @@ struct Subtract
|
||||
}
|
||||
};
|
||||
|
||||
struct AlphaBetaAdd
|
||||
struct Bilinear
|
||||
{
|
||||
AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){};
|
||||
Bilinear(float alpha, float beta) : alpha_(alpha), beta_(beta){};
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
|
||||
template <typename Y, typename X0, typename X1>
|
||||
__host__ __device__ constexpr void operator()(Y&, const X0&, const X1&) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<float>(float& y, const float& x0, const float& x1) const
|
||||
operator()<float, float, float>(float& y, const float& x0, const float& x1) const
|
||||
{
|
||||
y = alpha_ * x0 + beta_ * x1;
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<double>(double& y, const double& x0, const double& x1) const
|
||||
operator()<half_t, float, half_t>(half_t& y, const float& x0, const half_t& x1) const
|
||||
{
|
||||
y = static_cast<double>(alpha_) * x0 + static_cast<double>(beta_) * x1;
|
||||
};
|
||||
|
||||
// Question: should half_t be supported ?
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
|
||||
{
|
||||
y = static_cast<half_t>(alpha_ * static_cast<float>(x0) + beta_ * static_cast<float>(x1));
|
||||
y = type_convert<half_t>(alpha_ * x0 + beta_ * ck::type_convert<float>(x1));
|
||||
};
|
||||
|
||||
float alpha_;
|
||||
@@ -148,13 +136,12 @@ struct AddRelu
|
||||
y = a > 0.0 ? a : 0.0;
|
||||
};
|
||||
|
||||
// Question: should half_t be supported ?
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
|
||||
{
|
||||
const half_t a = x0 + x1;
|
||||
y = a > static_cast<half_t>(0.0f) ? a : static_cast<half_t>(0.0f);
|
||||
y = a > type_convert<half_t>(0.0f) ? a : type_convert<half_t>(0.0f);
|
||||
};
|
||||
};
|
||||
|
||||
@@ -183,7 +170,6 @@ struct AddHardswish
|
||||
y = c;
|
||||
};
|
||||
|
||||
// Question: should half_t be supported ?
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
|
||||
|
||||
@@ -159,7 +159,7 @@ struct Normalize
|
||||
using ck::math::sqrt;
|
||||
|
||||
float variance = mean_square - (mean * mean);
|
||||
y = ((x - mean) / sqrt(variance + static_cast<float>(epsilon_))) * gamma + beta;
|
||||
y = ((x - mean) / sqrt(variance + type_convert<float>(epsilon_))) * gamma + beta;
|
||||
};
|
||||
|
||||
template <>
|
||||
|
||||
Reference in New Issue
Block a user