mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
Add examples for GEMM + AddAddFastGelu (data type: int8, bf16, fp32) (#340)
* Add always_false<> util to delay symbol resolution * Use always_false<> to prevent trying instantiate unwanted method * Add new specializations of AddAddFastGelu::operator() method * Add GEMM + AddAddFastGelu examples for data types: int8, bf16, fp32 * Use floating point literal to simplify code * Remove unnecessary capture in lambda expressions * Extract fast GeLU calculation as standalone method * Mark methods as 'constexpr' * Add constraint for HostTensorDescriptor templated ctors * Simplify HostTensorDescriptor ctor calls * Add C++23 std::size_t literal suffix * Use _uz suffix to shorten example code * Remove unnecessary conversion to std::array<> * Re-order include directives * Remove C-style casting by literal suffix * Remove unnecessary statements in main() * Remove unused type parameter of always_false<> * Remove unused include directive * Exit main() by returning meaningful value * Use 'if constexpr' to switch example flow * Use std::is_same_v<> to shorten example code * Add 'inline' specifier to literal functions * Unify output methods in example * Move common codes into .inc file * Add type check in type_convert<>() * Add type_convert<float>() before computation * Merge AddAddFastGelu method specializations * Remove always_false<> * Add constraint to AddAddFastGelu::operator() parameter types
This commit is contained in:
@@ -114,28 +114,33 @@ struct AddHardswishAdd
|
||||
// E = FastGelu(C + D0 + D1)
|
||||
struct AddAddFastGelu
|
||||
{
|
||||
template <typename E, typename C, typename D0, typename D1>
|
||||
__host__ __device__ void operator()(E&, const C&, const D0&, const D1&) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<half_t, float, half_t, half_t>(half_t& e,
|
||||
const float& c,
|
||||
const half_t& d0,
|
||||
const half_t& d1) const
|
||||
// Fast GeLU
|
||||
// https://paperswithcode.com/method/gelu
|
||||
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
|
||||
__host__ __device__ static constexpr float GetFastGeLU(float x)
|
||||
{
|
||||
// Fast GeLU
|
||||
// https://paperswithcode.com/method/gelu
|
||||
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
|
||||
const auto fast_gelu = [&](float x) {
|
||||
const float u = float(2) * x * (float(0.035677) * x * x + float(0.797885));
|
||||
const float emu = exp(-u);
|
||||
const float cdf = float(0.5) + float(0.5) * (float(2) / (float(1) + emu) - float(1));
|
||||
return x * cdf;
|
||||
};
|
||||
const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
|
||||
const float emu = exp(-u);
|
||||
const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
|
||||
return x * cdf;
|
||||
}
|
||||
|
||||
const float y = fast_gelu(c + float(d0) + float(d1));
|
||||
template <typename T>
|
||||
static inline constexpr bool is_valid_param_type_v =
|
||||
std::is_same_v<T, float> || std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t> ||
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>;
|
||||
|
||||
e = type_convert<half_t>(y);
|
||||
template <typename E, typename C, typename D0, typename D1>
|
||||
__host__ __device__ constexpr void
|
||||
operator()(E& e, const C& c, const D0& d0, const D1& d1) const
|
||||
{
|
||||
static_assert(is_valid_param_type_v<E> && is_valid_param_type_v<C> &&
|
||||
is_valid_param_type_v<D0> && is_valid_param_type_v<D1>);
|
||||
|
||||
const float y =
|
||||
GetFastGeLU(type_convert<float>(c) + type_convert<float>(d0) + type_convert<float>(d1));
|
||||
|
||||
e = type_convert<E>(y);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user