mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 18:42:06 +00:00
Add examples for GEMM + AddAddFastGelu (data type: int8, bf16, fp32) (#340)
* Add always_false<> util to delay symbol resolution
* Use always_false<> to prevent trying instantiate unwanted method
* Add new specializations of AddAddFastGelu::operator() method
* Add GEMM + AddAddFastGelu examples for data types: int8, bf16, fp32
* Use floating point literal to simplify code
* Remove unnecessary capture in lambda expressions
* Extract fast GeLU calculation as standalone method
* Mark methods as 'constexpr'
* Add constraint for HostTensorDescriptor templated ctors
* Simplify HostTensorDescriptor ctor calls
* Add C++23 std::size_t literal suffix
* Use _uz suffix to shorten example code
* Remove unnecessary conversion to std::array<>
* Re-order include directives
* Remove C-style casting by literal suffix
* Remove unnecessary statements in main()
* Remove unused type parameter of always_false<>
* Remove unused include directive
* Exit main() by returning meaningful value
* Use 'if constexpr' to switch example flow
* Use std::is_same_v<> to shorten example code
* Add 'inline' specifier to literal functions
* Unify output methods in example
* Move common codes into .inc file
* Add type check in type_convert<>()
* Add type_convert<float>() before computation
* Merge AddAddFastGelu method specializations
* Remove always_false<>
* Add constraint to AddAddFastGelu::operator() parameter types
[ROCm/composable_kernel commit: 68b61504a3]
This commit is contained in:
@@ -114,28 +114,33 @@ struct AddHardswishAdd
|
||||
// E = FastGelu(C + D0 + D1)
|
||||
struct AddAddFastGelu
|
||||
{
|
||||
template <typename E, typename C, typename D0, typename D1>
|
||||
__host__ __device__ void operator()(E&, const C&, const D0&, const D1&) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<half_t, float, half_t, half_t>(half_t& e,
|
||||
const float& c,
|
||||
const half_t& d0,
|
||||
const half_t& d1) const
|
||||
// Fast GeLU
|
||||
// https://paperswithcode.com/method/gelu
|
||||
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
|
||||
__host__ __device__ static constexpr float GetFastGeLU(float x)
|
||||
{
|
||||
// Fast GeLU
|
||||
// https://paperswithcode.com/method/gelu
|
||||
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
|
||||
const auto fast_gelu = [&](float x) {
|
||||
const float u = float(2) * x * (float(0.035677) * x * x + float(0.797885));
|
||||
const float emu = exp(-u);
|
||||
const float cdf = float(0.5) + float(0.5) * (float(2) / (float(1) + emu) - float(1));
|
||||
return x * cdf;
|
||||
};
|
||||
const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
|
||||
const float emu = exp(-u);
|
||||
const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
|
||||
return x * cdf;
|
||||
}
|
||||
|
||||
const float y = fast_gelu(c + float(d0) + float(d1));
|
||||
template <typename T>
|
||||
static inline constexpr bool is_valid_param_type_v =
|
||||
std::is_same_v<T, float> || std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t> ||
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>;
|
||||
|
||||
e = type_convert<half_t>(y);
|
||||
template <typename E, typename C, typename D0, typename D1>
|
||||
__host__ __device__ constexpr void
|
||||
operator()(E& e, const C& c, const D0& d0, const D1& d1) const
|
||||
{
|
||||
static_assert(is_valid_param_type_v<E> && is_valid_param_type_v<C> &&
|
||||
is_valid_param_type_v<D0> && is_valid_param_type_v<D1>);
|
||||
|
||||
const float y =
|
||||
GetFastGeLU(type_convert<float>(c) + type_convert<float>(d0) + type_convert<float>(d1));
|
||||
|
||||
e = type_convert<E>(y);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -934,6 +934,8 @@ using int8x64_t = typename vector_type<int8_t, 64>::type;
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr Y type_convert(X x)
|
||||
{
|
||||
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
|
||||
|
||||
return static_cast<Y>(x);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user