Gemm+Bilinear (#316)

* refactor

* update example

* update example

* gemm bilinear

* clean

* update
This commit is contained in:
Chao Liu
2022-07-02 09:15:38 -05:00
committed by GitHub
parent 8e374781d5
commit 9e4429f9c3
75 changed files with 1485 additions and 4658 deletions

View File

@@ -35,7 +35,6 @@ struct Add
y = type_convert<half_t>(x0) + x1;
};
// Question: should half_t be supported ?
template <>
__host__ __device__ constexpr void
operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
@@ -43,7 +42,6 @@ struct Add
y = x0 + x1;
};
// Question: should bhalf_t be supported ?
template <>
__host__ __device__ constexpr void
operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
@@ -74,7 +72,6 @@ struct Subtract
y = x0 - x1;
};
// Question: should half_t be supported ?
template <>
__host__ __device__ constexpr void
operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
@@ -82,7 +79,6 @@ struct Subtract
y = x0 - x1;
};
// Question: should bhalf_t be supported ?
template <>
__host__ __device__ constexpr void
operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
@@ -94,33 +90,25 @@ struct Subtract
}
};
struct AlphaBetaAdd
struct Bilinear
{
AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){};
Bilinear(float alpha, float beta) : alpha_(alpha), beta_(beta){};
template <typename T>
__host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
template <typename Y, typename X0, typename X1>
__host__ __device__ constexpr void operator()(Y&, const X0&, const X1&) const;
template <>
__host__ __device__ constexpr void
operator()<float>(float& y, const float& x0, const float& x1) const
operator()<float, float, float>(float& y, const float& x0, const float& x1) const
{
y = alpha_ * x0 + beta_ * x1;
};
template <>
__host__ __device__ constexpr void
operator()<double>(double& y, const double& x0, const double& x1) const
operator()<half_t, float, half_t>(half_t& y, const float& x0, const half_t& x1) const
{
y = static_cast<double>(alpha_) * x0 + static_cast<double>(beta_) * x1;
};
// Question: should half_t be supported ?
template <>
__host__ __device__ constexpr void
operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
{
y = static_cast<half_t>(alpha_ * static_cast<float>(x0) + beta_ * static_cast<float>(x1));
y = type_convert<half_t>(alpha_ * x0 + beta_ * ck::type_convert<float>(x1));
};
float alpha_;
@@ -148,13 +136,12 @@ struct AddRelu
y = a > 0.0 ? a : 0.0;
};
// Question: should half_t be supported ?
template <>
__host__ __device__ constexpr void
operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
{
const half_t a = x0 + x1;
y = a > static_cast<half_t>(0.0f) ? a : static_cast<half_t>(0.0f);
y = a > type_convert<half_t>(0.0f) ? a : type_convert<half_t>(0.0f);
};
};
@@ -183,7 +170,6 @@ struct AddHardswish
y = c;
};
// Question: should half_t be supported ?
template <>
__host__ __device__ constexpr void
operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const

View File

@@ -159,7 +159,7 @@ struct Normalize
using ck::math::sqrt;
float variance = mean_square - (mean * mean);
y = ((x - mean) / sqrt(variance + static_cast<float>(epsilon_))) * gamma + beta;
y = ((x - mean) / sqrt(variance + type_convert<float>(epsilon_))) * gamma + beta;
};
template <>