Add examples for GEMM + AddAddFastGelu (data type: int8, bf16, fp32) (#340)

* Add always_false<> util to delay symbol resolution

* Use always_false<> to prevent trying instantiate unwanted method

* Add new specializations of AddAddFastGelu::operator() method

* Add GEMM + AddAddFastGelu examples for data types: int8, bf16, fp32

* Use floating point literal to simplify code

* Remove unnecessary capture in lambda expressions

* Extract fast GeLU calculation as standalone method

* Mark methods as 'constexpr'

* Add constraint for HostTensorDescriptor templated ctors

* Simplify HostTensorDescriptor ctor calls

* Add C++23 std::size_t literal suffix

* Use _uz suffix to shorten example code

* Remove unnecessary conversion to std::array<>

* Re-order include directives

* Remove C-style casting by literal suffix

* Remove unnecessary statements in main()

* Remove unused type parameter of always_false<>

* Remove unused include directive

* Exit main() by returning meaningful value

* Use 'if constexpr' to switch example flow

* Use std::is_same_v<> to shorten example code

* Add 'inline' specifier to literal functions

* Unify output methods in example

* Move common codes into .inc file

* Add type check in type_convert<>()

* Add type_convert<float>() before computation

* Merge AddAddFastGelu method specializations

* Remove always_false<>

* Add constraint to AddAddFastGelu::operator() parameter types
This commit is contained in:
Po Yen Chen
2022-08-12 06:31:28 +08:00
committed by GitHub
parent fdfd7eb597
commit 68b61504a3
10 changed files with 469 additions and 225 deletions

View File

@@ -114,28 +114,33 @@ struct AddHardswishAdd
// E = FastGelu(C + D0 + D1)
struct AddAddFastGelu
{
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ void operator()(E&, const C&, const D0&, const D1&) const;
template <>
__host__ __device__ void operator()<half_t, float, half_t, half_t>(half_t& e,
const float& c,
const half_t& d0,
const half_t& d1) const
// Fast GeLU
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
__host__ __device__ static constexpr float GetFastGeLU(float x)
{
// Fast GeLU
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
const auto fast_gelu = [&](float x) {
const float u = float(2) * x * (float(0.035677) * x * x + float(0.797885));
const float emu = exp(-u);
const float cdf = float(0.5) + float(0.5) * (float(2) / (float(1) + emu) - float(1));
return x * cdf;
};
const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
const float emu = exp(-u);
const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
return x * cdf;
}
const float y = fast_gelu(c + float(d0) + float(d1));
template <typename T>
static inline constexpr bool is_valid_param_type_v =
std::is_same_v<T, float> || std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>;
e = type_convert<half_t>(y);
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1) const
{
static_assert(is_valid_param_type_v<E> && is_valid_param_type_v<C> &&
is_valid_param_type_v<D0> && is_valid_param_type_v<D1>);
const float y =
GetFastGeLU(type_convert<float>(c) + type_convert<float>(d0) + type_convert<float>(d1));
e = type_convert<E>(y);
}
};