mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 00:40:09 +00:00
Fast GeLU using built-in function (#587)
* clean up * fast gelu using builtin function * clean * clean * clean * clean: * clean * fix compilation * clean * clean --------- Co-authored-by: zjing14 <zhangjing14@gmail.com>
This commit is contained in:
@@ -62,7 +62,7 @@ struct ExecutionConfig final
|
|||||||
};
|
};
|
||||||
|
|
||||||
inline bool
|
inline bool
|
||||||
parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig config)
|
parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config)
|
||||||
{
|
{
|
||||||
if(argc == 1)
|
if(argc == 1)
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -7,10 +7,11 @@ using ADataType = BF16;
|
|||||||
using BDataType = BF16;
|
using BDataType = BF16;
|
||||||
using AccDataType = F32;
|
using AccDataType = F32;
|
||||||
using CShuffleDataType = F32;
|
using CShuffleDataType = F32;
|
||||||
using D0DataType = BF16;
|
using CDataType = F32; // C matrix doesn't exist in GPU memory, this is used for host verification
|
||||||
using D1DataType = BF16;
|
using D0DataType = BF16;
|
||||||
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
|
using D1DataType = BF16;
|
||||||
using EDataType = BF16;
|
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
|
||||||
|
using EDataType = BF16;
|
||||||
|
|
||||||
using ALayout = Row;
|
using ALayout = Row;
|
||||||
using BLayout = Col;
|
using BLayout = Col;
|
||||||
@@ -36,7 +37,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C
|
|||||||
|
|
||||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||||
BDataType,
|
BDataType,
|
||||||
AccDataType,
|
CDataType,
|
||||||
AccDataType,
|
AccDataType,
|
||||||
AElementOp,
|
AElementOp,
|
||||||
BElementOp,
|
BElementOp,
|
||||||
|
|||||||
@@ -7,10 +7,11 @@ using ADataType = F16;
|
|||||||
using BDataType = F16;
|
using BDataType = F16;
|
||||||
using AccDataType = F32;
|
using AccDataType = F32;
|
||||||
using CShuffleDataType = F32;
|
using CShuffleDataType = F32;
|
||||||
using D0DataType = F16;
|
using CDataType = F32; // C matrix doesn't exist in GPU memory, this is used for host verification
|
||||||
using D1DataType = F16;
|
using D0DataType = F16;
|
||||||
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
|
using D1DataType = F16;
|
||||||
using EDataType = F16;
|
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
|
||||||
|
using EDataType = F16;
|
||||||
|
|
||||||
using ALayout = Row;
|
using ALayout = Row;
|
||||||
using BLayout = Col;
|
using BLayout = Col;
|
||||||
@@ -36,7 +37,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C
|
|||||||
|
|
||||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||||
BDataType,
|
BDataType,
|
||||||
AccDataType,
|
CDataType,
|
||||||
AccDataType,
|
AccDataType,
|
||||||
AElementOp,
|
AElementOp,
|
||||||
BElementOp,
|
BElementOp,
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
// SPDX-License-Identifier: MIT
|
|
||||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||||
|
|
||||||
#include "common.hpp"
|
#include "common.hpp"
|
||||||
@@ -7,10 +6,11 @@ using ADataType = F32;
|
|||||||
using BDataType = F32;
|
using BDataType = F32;
|
||||||
using AccDataType = F32;
|
using AccDataType = F32;
|
||||||
using CShuffleDataType = F32;
|
using CShuffleDataType = F32;
|
||||||
using D0DataType = F32;
|
using CDataType = F32; // C matrix doesn't exist in GPU memory, this is used for host verification
|
||||||
using D1DataType = F32;
|
using D0DataType = F32;
|
||||||
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
|
using D1DataType = F32;
|
||||||
using EDataType = F32;
|
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
|
||||||
|
using EDataType = F32;
|
||||||
|
|
||||||
using ALayout = Row;
|
using ALayout = Row;
|
||||||
using BLayout = Col;
|
using BLayout = Col;
|
||||||
@@ -36,7 +36,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C
|
|||||||
|
|
||||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||||
BDataType,
|
BDataType,
|
||||||
AccDataType,
|
CDataType,
|
||||||
AccDataType,
|
AccDataType,
|
||||||
AElementOp,
|
AElementOp,
|
||||||
BElementOp,
|
BElementOp,
|
||||||
|
|||||||
@@ -11,10 +11,11 @@ using ADataType = I4;
|
|||||||
using BDataType = I4;
|
using BDataType = I4;
|
||||||
using AccDataType = I32;
|
using AccDataType = I32;
|
||||||
using CShuffleDataType = I32;
|
using CShuffleDataType = I32;
|
||||||
using D0DataType = I4;
|
using CDataType = I32; // C matrix doesn't exist in GPU memory, this is used for host verification
|
||||||
using D1DataType = I4;
|
using D0DataType = I4;
|
||||||
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
|
using D1DataType = I4;
|
||||||
using EDataType = I4;
|
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
|
||||||
|
using EDataType = I4;
|
||||||
|
|
||||||
using KernelADataType = I8;
|
using KernelADataType = I8;
|
||||||
using KernelBDataType = I8;
|
using KernelBDataType = I8;
|
||||||
@@ -47,7 +48,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C
|
|||||||
|
|
||||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||||
BDataType,
|
BDataType,
|
||||||
AccDataType,
|
CDataType,
|
||||||
AccDataType,
|
AccDataType,
|
||||||
AElementOp,
|
AElementOp,
|
||||||
BElementOp,
|
BElementOp,
|
||||||
|
|||||||
@@ -7,10 +7,11 @@ using ADataType = I8;
|
|||||||
using BDataType = I8;
|
using BDataType = I8;
|
||||||
using AccDataType = I32;
|
using AccDataType = I32;
|
||||||
using CShuffleDataType = I32;
|
using CShuffleDataType = I32;
|
||||||
using D0DataType = I8;
|
using CDataType = I32; // C matrix doesn't exist in GPU memory, this is used for host verification
|
||||||
using D1DataType = I8;
|
using D0DataType = I8;
|
||||||
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
|
using D1DataType = I8;
|
||||||
using EDataType = I8;
|
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
|
||||||
|
using EDataType = I8;
|
||||||
|
|
||||||
using ALayout = Row;
|
using ALayout = Row;
|
||||||
using BLayout = Col;
|
using BLayout = Col;
|
||||||
@@ -36,7 +37,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C
|
|||||||
|
|
||||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||||
BDataType,
|
BDataType,
|
||||||
AccDataType,
|
CDataType,
|
||||||
AccDataType,
|
AccDataType,
|
||||||
AElementOp,
|
AElementOp,
|
||||||
BElementOp,
|
BElementOp,
|
||||||
|
|||||||
@@ -124,7 +124,7 @@ bool run_gemm_add_add_fastgelu(const ProblemSize& problem_size, const ExecutionC
|
|||||||
|
|
||||||
if(config.do_verification)
|
if(config.do_verification)
|
||||||
{
|
{
|
||||||
Tensor<AccDataType> c_m_n({M, N});
|
Tensor<CDataType> c_m_n({M, N});
|
||||||
|
|
||||||
auto ref_gemm = ReferenceGemmInstance{};
|
auto ref_gemm = ReferenceGemmInstance{};
|
||||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||||
|
|||||||
@@ -168,6 +168,9 @@
|
|||||||
// tuning parameter
|
// tuning parameter
|
||||||
#define CK_WORKAROUND_SWDEV_325164 0
|
#define CK_WORKAROUND_SWDEV_325164 0
|
||||||
|
|
||||||
|
// workaround: compiler not emitting reciprocal instruction from __frcp_rn()
|
||||||
|
#define CK_WORKAROUND_SWDEV_383542 1
|
||||||
|
|
||||||
// flag to enable (1) or disable (0) the debugging output in some kernels
|
// flag to enable (1) or disable (0) the debugging output in some kernels
|
||||||
#define DEBUG_LOG 0
|
#define DEBUG_LOG 0
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "ck/utility/data_type.hpp"
|
#include "ck/utility/data_type.hpp"
|
||||||
|
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||||
|
|
||||||
namespace ck {
|
namespace ck {
|
||||||
namespace tensor_operation {
|
namespace tensor_operation {
|
||||||
@@ -280,43 +281,42 @@ struct AddHardswish
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
// C = A * B
|
|
||||||
// E = FastGelu(C + D)
|
// E = FastGelu(C + D)
|
||||||
struct AddFastGelu
|
struct AddFastGelu
|
||||||
{
|
{
|
||||||
// Fast GeLU
|
|
||||||
// https://paperswithcode.com/method/gelu
|
|
||||||
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
|
|
||||||
__host__ __device__ static constexpr float GetFastGeLU(float x)
|
|
||||||
{
|
|
||||||
const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
|
|
||||||
const float emu = exp(-u);
|
|
||||||
const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
|
|
||||||
return x * cdf;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
static inline constexpr bool is_valid_param_type_v =
|
|
||||||
std::is_same_v<T, float> || std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t> ||
|
|
||||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>;
|
|
||||||
|
|
||||||
template <typename E, typename C, typename D>
|
template <typename E, typename C, typename D>
|
||||||
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const
|
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__host__ __device__ constexpr void
|
||||||
|
operator()<float, float, float>(float& e, const float& c, const float& d) const
|
||||||
{
|
{
|
||||||
static_assert(is_valid_param_type_v<E> && is_valid_param_type_v<C> &&
|
const float x = c + d;
|
||||||
is_valid_param_type_v<D>);
|
|
||||||
|
|
||||||
const float y = GetFastGeLU(type_convert<float>(c) + type_convert<float>(d));
|
FastGelu{}.template operator()<float, float>(e, x);
|
||||||
|
|
||||||
e = type_convert<E>(y);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename D>
|
template <>
|
||||||
__host__ __device__ constexpr void operator()(float& e, const float& c, const D& d) const
|
__host__ __device__ constexpr void
|
||||||
|
operator()<half_t, half_t, half_t>(half_t& e, const half_t& c, const half_t& d) const
|
||||||
{
|
{
|
||||||
static_assert(is_valid_param_type_v<D>);
|
const half_t x = c + d;
|
||||||
|
|
||||||
e = GetFastGeLU(c + type_convert<float>(d));
|
ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__host__ __device__ constexpr void
|
||||||
|
operator()<half_t, float, half_t>(half_t& e, const float& c, const half_t& d) const
|
||||||
|
{
|
||||||
|
const float x0_f = c + d;
|
||||||
|
|
||||||
|
float x1_f = 0;
|
||||||
|
|
||||||
|
ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
|
||||||
|
x0_f);
|
||||||
|
|
||||||
|
e = type_convert<half_t>(x1_f);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ namespace element_wise {
|
|||||||
// Need to ensure compiler will fail if there is no matching candidate, instead of compiler
|
// Need to ensure compiler will fail if there is no matching candidate, instead of compiler
|
||||||
// silently doing implicit type conversion
|
// silently doing implicit type conversion
|
||||||
//
|
//
|
||||||
// Method 1:
|
// Example:
|
||||||
//
|
//
|
||||||
// struct ExampleElementwiseOp
|
// struct ExampleElementwiseOp
|
||||||
// {
|
// {
|
||||||
@@ -30,19 +30,6 @@ namespace element_wise {
|
|||||||
// {
|
// {
|
||||||
// }
|
// }
|
||||||
// };
|
// };
|
||||||
//
|
|
||||||
// Method 2:
|
|
||||||
//
|
|
||||||
// template <typename Y, typename X>
|
|
||||||
// struct ExampleElementwiseOp;
|
|
||||||
//
|
|
||||||
// template <>
|
|
||||||
// struct ExampleElementwiseOp<float, ck::bhalf_t>
|
|
||||||
// {
|
|
||||||
// __host__ __device__ void operator()(float& y, ck::bhalf_t& x) const
|
|
||||||
// {
|
|
||||||
// }
|
|
||||||
// };
|
|
||||||
|
|
||||||
struct AddReluAdd
|
struct AddReluAdd
|
||||||
{
|
{
|
||||||
@@ -208,41 +195,74 @@ struct AddMultiply
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// C = A * B
|
|
||||||
// E = FastGelu(C + D0 + D1)
|
// E = FastGelu(C + D0 + D1)
|
||||||
struct AddAddFastGelu
|
struct AddAddFastGelu
|
||||||
{
|
{
|
||||||
// Fast GeLU
|
|
||||||
// https://paperswithcode.com/method/gelu
|
|
||||||
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
|
|
||||||
__host__ __device__ static constexpr float GetFastGeLU(float x)
|
|
||||||
{
|
|
||||||
const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
|
|
||||||
const float emu = exp(-u);
|
|
||||||
const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
|
|
||||||
return x * cdf;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
static inline constexpr bool is_valid_param_type_v =
|
|
||||||
std::is_same_v<T, float> || std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t> ||
|
|
||||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>
|
|
||||||
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
|
||||||
|| std::is_same_v<T, ck::int4_t>
|
|
||||||
#endif
|
|
||||||
;
|
|
||||||
|
|
||||||
template <typename E, typename C, typename D0, typename D1>
|
template <typename E, typename C, typename D0, typename D1>
|
||||||
__host__ __device__ constexpr void
|
__host__ __device__ constexpr void
|
||||||
operator()(E& e, const C& c, const D0& d0, const D1& d1) const
|
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__host__ __device__ constexpr void operator()<float, float, float, float>(float& e,
|
||||||
|
const float& c,
|
||||||
|
const float& d0,
|
||||||
|
const float& d1) const
|
||||||
{
|
{
|
||||||
static_assert(is_valid_param_type_v<E> && is_valid_param_type_v<C> &&
|
const float x = c + d0 + d1;
|
||||||
is_valid_param_type_v<D0> && is_valid_param_type_v<D1>);
|
|
||||||
|
|
||||||
const float y =
|
FastGelu{}.template operator()<float, float>(e, x);
|
||||||
GetFastGeLU(type_convert<float>(c) + type_convert<float>(d0) + type_convert<float>(d1));
|
}
|
||||||
|
|
||||||
e = type_convert<E>(y);
|
template <>
|
||||||
|
__host__ __device__ constexpr void operator()<half_t, half_t, half_t, half_t>(
|
||||||
|
half_t& e, const half_t& c, const half_t& d0, const half_t& d1) const
|
||||||
|
{
|
||||||
|
const half_t x = c + d0 + d1;
|
||||||
|
|
||||||
|
ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__host__ __device__ constexpr void operator()<half_t, float, half_t, half_t>(
|
||||||
|
half_t& e, const float& c, const half_t& d0, const half_t& d1) const
|
||||||
|
{
|
||||||
|
const float x0_f = c + d0 + d1;
|
||||||
|
|
||||||
|
float x1_f = 0;
|
||||||
|
|
||||||
|
ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
|
||||||
|
x0_f);
|
||||||
|
|
||||||
|
e = type_convert<half_t>(x1_f);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__host__ __device__ constexpr void operator()<bhalf_t, float, bhalf_t, bhalf_t>(
|
||||||
|
bhalf_t& e, const float& c, const bhalf_t& d0, const bhalf_t& d1) const
|
||||||
|
{
|
||||||
|
const float x0_f = c + type_convert<float>(d0) + type_convert<float>(d1);
|
||||||
|
|
||||||
|
float x1_f = 0;
|
||||||
|
|
||||||
|
ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
|
||||||
|
x0_f);
|
||||||
|
|
||||||
|
e = type_convert<bhalf_t>(x1_f);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__host__ __device__ constexpr void operator()<int8_t, int32_t, int8_t, int8_t>(
|
||||||
|
int8_t& e, const int32_t& c, const int8_t& d0, const int8_t& d1) const
|
||||||
|
{
|
||||||
|
const float x0_f =
|
||||||
|
type_convert<float>(c) + type_convert<float>(d0) + type_convert<float>(d1);
|
||||||
|
|
||||||
|
float x1_f = 0;
|
||||||
|
|
||||||
|
ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
|
||||||
|
x0_f);
|
||||||
|
|
||||||
|
e = type_convert<int8_t>(x1_f);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,10 @@ namespace ck {
|
|||||||
namespace tensor_operation {
|
namespace tensor_operation {
|
||||||
namespace element_wise {
|
namespace element_wise {
|
||||||
|
|
||||||
|
#if CK_WORKAROUND_SWDEV_383542
|
||||||
|
extern "C" __device__ float __ocml_native_recip_f32(float);
|
||||||
|
#endif
|
||||||
|
|
||||||
struct PassThrough
|
struct PassThrough
|
||||||
{
|
{
|
||||||
template <typename Y, typename X>
|
template <typename Y, typename X>
|
||||||
@@ -200,36 +204,83 @@ struct Relu
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Y = FastGelu(X)
|
// Fast GeLU
|
||||||
|
// https://paperswithcode.com/method/gelu
|
||||||
|
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
|
||||||
|
// host code uses higher-accuracy "exp" and "div"
|
||||||
|
// GPU code uses lower-accuracy "__expf" and "rcp" functions
|
||||||
struct FastGelu
|
struct FastGelu
|
||||||
{
|
{
|
||||||
// Fast GeLU
|
template <typename Y, typename X>
|
||||||
// https://paperswithcode.com/method/gelu
|
__host__ void operator()(Y& y, const X& x) const;
|
||||||
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
|
|
||||||
__host__ __device__ static constexpr float GetFastGeLU(float x)
|
template <typename Y, typename X>
|
||||||
|
__device__ void operator()(Y& y, const X& x) const;
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__host__ void operator()<float, float>(float& y, const float& x) const
|
||||||
{
|
{
|
||||||
const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
|
const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
|
||||||
const float emu = exp(-u);
|
const float emu = exp(-u);
|
||||||
const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
|
const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
|
||||||
return x * cdf;
|
|
||||||
|
y = x * cdf;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
// device code, use lower precision "__expf" and "rcp"
|
||||||
static inline constexpr bool is_valid_param_type_v =
|
template <>
|
||||||
std::is_same_v<T, float> || std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t> ||
|
__device__ void operator()<float, float>(float& y, const float& x) const
|
||||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>
|
|
||||||
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
|
||||||
|| std::is_same_v<T, ck::int4_t>
|
|
||||||
#endif
|
|
||||||
;
|
|
||||||
|
|
||||||
template <typename Y, typename X>
|
|
||||||
__host__ __device__ void operator()(Y& y, const X& x) const
|
|
||||||
{
|
{
|
||||||
static_assert(is_valid_param_type_v<Y> && is_valid_param_type_v<X>);
|
const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
|
||||||
|
const float emu = __expf(-u);
|
||||||
|
|
||||||
const float tmp_y = GetFastGeLU(type_convert<float>(x));
|
#if !CK_WORKAROUND_SWDEV_383542
|
||||||
y = type_convert<Y>(tmp_y);
|
const float cdf = 0.5f + 0.5f * (2.f * __frcp_rn(1.f + emu) - 1.f);
|
||||||
|
#else
|
||||||
|
const float cdf = 0.5f + 0.5f * (2.f * __ocml_native_recip_f32(1.f + emu) - 1.f);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
y = x * cdf;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__host__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
|
||||||
|
{
|
||||||
|
float y_f;
|
||||||
|
|
||||||
|
this->operator()<float, float>(y_f, type_convert<float>(x));
|
||||||
|
|
||||||
|
y = type_convert<half_t>(y_f);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
|
||||||
|
{
|
||||||
|
float y_f;
|
||||||
|
|
||||||
|
this->operator()<float, float>(y_f, type_convert<float>(x));
|
||||||
|
|
||||||
|
y = type_convert<half_t>(y_f);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__host__ void operator()<half_t, float>(half_t& y, const float& x) const
|
||||||
|
{
|
||||||
|
float y_f;
|
||||||
|
|
||||||
|
this->operator()<float, float>(y_f, x);
|
||||||
|
|
||||||
|
y = type_convert<half_t>(y_f);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ void operator()<half_t, float>(half_t& y, const float& x) const
|
||||||
|
{
|
||||||
|
float y_f;
|
||||||
|
|
||||||
|
this->operator()<float, float>(y_f, x);
|
||||||
|
|
||||||
|
y = type_convert<half_t>(y_f);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user