Client examples AddFastGelu and FastGelu + instances. (#509)

* FastGelu support for more data types.

* AddFastGelu & FastGelu instances.

* Client example.

* clang-format

* Remove unused stride variable.

* Add new line at EOF.

Co-authored-by: Adam Osewski <aosewski@amd.com>
This commit is contained in:
Adam Osewski
2022-11-20 05:08:26 +01:00
committed by GitHub
parent 892a8d769d
commit 43a889b72e
17 changed files with 1645 additions and 10 deletions

View File

@@ -194,21 +194,36 @@ struct Relu
}
};
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
// Y = FastGelu(X)
struct FastGelu
{
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const;
template <>
__host__ __device__ void operator()<float, float>(float& y, const float& x) const
// Fast GeLU
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
__host__ __device__ static constexpr float GetFastGeLU(float x)
{
const float u = float(2) * x * (float(0.035677) * x * x + float(0.797885));
const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
const float emu = exp(-u);
const float cdf = float(0.5) + float(0.5) * (float(2) / (float(1) + emu) - float(1));
const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
return x * cdf;
}
y = x * cdf;
template <typename T>
static inline constexpr bool is_valid_param_type_v =
std::is_same_v<T, float> || std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|| std::is_same_v<T, ck::int4_t>
#endif
;
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const
{
static_assert(is_valid_param_type_v<Y> && is_valid_param_type_v<X>);
const float tmp_y = GetFastGeLU(type_convert<float>(x));
y = type_convert<Y>(tmp_y);
}
};