From 8f455615a8092822637fb2e7691d38a98456f276 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Sun, 26 Feb 2023 21:19:11 -0800 Subject: [PATCH] Fast GeLU using built-in function (#587) * clean up * fast gelu using builtin function * clean * clean * clean * clean: * clean * fix compilation * clean * clean --------- Co-authored-by: zjing14 --- example/04_gemm_add_add_fastgelu/common.hpp | 2 +- .../gemm_add_add_fastgelu_xdl_bf16.cpp | 11 +- .../gemm_add_add_fastgelu_xdl_fp16.cpp | 11 +- .../gemm_add_add_fastgelu_xdl_fp32.cpp | 12 +-- .../gemm_add_add_fastgelu_xdl_int4.cpp | 11 +- .../gemm_add_add_fastgelu_xdl_int8.cpp | 11 +- .../run_gemm_add_add_fastgelu_example.inc | 2 +- include/ck/ck.hpp | 3 + .../element/binary_element_wise_operation.hpp | 54 +++++----- .../gpu/element/element_wise_operation.hpp | 102 +++++++++++------- .../element/unary_element_wise_operation.hpp | 91 ++++++++++++---- 11 files changed, 194 insertions(+), 116 deletions(-) diff --git a/example/04_gemm_add_add_fastgelu/common.hpp b/example/04_gemm_add_add_fastgelu/common.hpp index 3f9375e092..839587c148 100644 --- a/example/04_gemm_add_add_fastgelu/common.hpp +++ b/example/04_gemm_add_add_fastgelu/common.hpp @@ -62,7 +62,7 @@ struct ExecutionConfig final }; inline bool -parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig config) +parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config) { if(argc == 1) { diff --git a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_bf16.cpp b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_bf16.cpp index 5e50c14dc2..ba0476b9b9 100644 --- a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_bf16.cpp +++ b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_bf16.cpp @@ -7,10 +7,11 @@ using ADataType = BF16; using BDataType = BF16; using AccDataType = F32; using CShuffleDataType = F32; -using D0DataType = BF16; -using D1DataType = BF16; -using DsDataType = ck::Tuple; -using 
EDataType = BF16; +using CDataType = F32; // C matrix doesn't exist in GPU memory, this is used for host verification +using D0DataType = BF16; +using D1DataType = BF16; +using DsDataType = ck::Tuple; +using EDataType = BF16; using ALayout = Row; using BLayout = Col; @@ -36,7 +37,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; -using EDataType = F16; +using CDataType = F32; // C matrix doesn't exist in GPU memory, this is used for host verification +using D0DataType = F16; +using D1DataType = F16; +using DsDataType = ck::Tuple; +using EDataType = F16; using ALayout = Row; using BLayout = Col; @@ -36,7 +37,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; -using EDataType = F32; +using CDataType = F32; // C matrix doesn't exist in GPU memory, this is used for host verification +using D0DataType = F32; +using D1DataType = F32; +using DsDataType = ck::Tuple; +using EDataType = F32; using ALayout = Row; using BLayout = Col; @@ -36,7 +36,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; -using EDataType = I4; +using CDataType = I32; // C matrix doesn't exist in GPU memory, this is used for host verification +using D0DataType = I4; +using D1DataType = I4; +using DsDataType = ck::Tuple; +using EDataType = I4; using KernelADataType = I8; using KernelBDataType = I8; @@ -47,7 +48,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; -using EDataType = I8; +using CDataType = I32; // C matrix doesn't exist in GPU memory, this is used for host verification +using D0DataType = I8; +using D1DataType = I8; +using DsDataType = ck::Tuple; +using EDataType = 
I8; using ALayout = Row; using BLayout = Col; @@ -36,7 +37,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm c_m_n({M, N}); + Tensor c_m_n({M, N}); auto ref_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_gemm.MakeInvoker(); diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index ffd7e74f12..1257a77649 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -168,6 +168,9 @@ // tuning parameter #define CK_WORKAROUND_SWDEV_325164 0 +// workaround: compiler not emitting reciprocal instruction from __frcp_rn() +#define CK_WORKAROUND_SWDEV_383542 1 + // flag to enable (1) or disable (0) the debugging output in some kernels #define DEBUG_LOG 0 diff --git a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp index 69fa75c3fd..136017c6d1 100644 --- a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" namespace ck { namespace tensor_operation { @@ -280,43 +281,42 @@ struct AddHardswish }; }; -// C = A * B // E = FastGelu(C + D) struct AddFastGelu { - // Fast GeLU - // https://paperswithcode.com/method/gelu - // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3))) - __host__ __device__ static constexpr float GetFastGeLU(float x) - { - const float u = 2.f * x * (0.035677f * x * x + 0.797885f); - const float emu = exp(-u); - const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f); - return x * cdf; - } - - template - static inline constexpr bool is_valid_param_type_v = - std::is_same_v || std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v; - template - __host__ __device__ constexpr void operator()(E& e, const 
C& c, const D& d) const + __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const; + + template <> + __host__ __device__ constexpr void + operator()(float& e, const float& c, const float& d) const { - static_assert(is_valid_param_type_v && is_valid_param_type_v && - is_valid_param_type_v); + const float x = c + d; - const float y = GetFastGeLU(type_convert(c) + type_convert(d)); - - e = type_convert(y); + FastGelu{}.template operator()(e, x); } - template - __host__ __device__ constexpr void operator()(float& e, const float& c, const D& d) const + template <> + __host__ __device__ constexpr void + operator()(half_t& e, const half_t& c, const half_t& d) const { - static_assert(is_valid_param_type_v); + const half_t x = c + d; - e = GetFastGeLU(c + type_convert(d)); + ck::tensor_operation::element_wise::FastGelu{}.template operator()(e, x); + } + + template <> + __host__ __device__ constexpr void + operator()(half_t& e, const float& c, const half_t& d) const + { + const float x0_f = c + d; + + float x1_f = 0; + + ck::tensor_operation::element_wise::FastGelu{}.template operator()(x1_f, + x0_f); + + e = type_convert(x1_f); } }; diff --git a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp index 7f3d450a39..ceb2b665b9 100644 --- a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp @@ -16,7 +16,7 @@ namespace element_wise { // Need to ensure compiler will fail if there is no matching candidate, instead of compiler // siliently do implicit type conversion // -// Method 1: +// Example: // // struct ExampleElementwiseOp // { @@ -30,19 +30,6 @@ namespace element_wise { // { // } // }; -// -// Method 2: -// -// template -// struct ExampleElementwiseOp; -// -// template <> -// struct ExampleElementwiseOp -// { -// __host__ __device__ void operator()(float& y, ck::bhalf_t& x) 
const -// { -// } -// }; struct AddReluAdd { @@ -208,41 +195,74 @@ struct AddMultiply } }; -// C = A * B // E = FastGelu(C + D0 + D1) struct AddAddFastGelu { - // Fast GeLU - // https://paperswithcode.com/method/gelu - // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3))) - __host__ __device__ static constexpr float GetFastGeLU(float x) - { - const float u = 2.f * x * (0.035677f * x * x + 0.797885f); - const float emu = exp(-u); - const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f); - return x * cdf; - } - - template - static inline constexpr bool is_valid_param_type_v = - std::is_same_v || std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v -#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 - || std::is_same_v -#endif - ; - template __host__ __device__ constexpr void - operator()(E& e, const C& c, const D0& d0, const D1& d1) const + operator()(E& e, const C& c, const D0& d0, const D1& d1) const; + + template <> + __host__ __device__ constexpr void operator()(float& e, + const float& c, + const float& d0, + const float& d1) const { - static_assert(is_valid_param_type_v && is_valid_param_type_v && - is_valid_param_type_v && is_valid_param_type_v); + const float x = c + d0 + d1; - const float y = - GetFastGeLU(type_convert(c) + type_convert(d0) + type_convert(d1)); + FastGelu{}.template operator()(e, x); + } - e = type_convert(y); + template <> + __host__ __device__ constexpr void operator()( + half_t& e, const half_t& c, const half_t& d0, const half_t& d1) const + { + const half_t x = c + d0 + d1; + + ck::tensor_operation::element_wise::FastGelu{}.template operator()(e, x); + } + + template <> + __host__ __device__ constexpr void operator()( + half_t& e, const float& c, const half_t& d0, const half_t& d1) const + { + const float x0_f = c + d0 + d1; + + float x1_f = 0; + + ck::tensor_operation::element_wise::FastGelu{}.template operator()(x1_f, + x0_f); + + e = type_convert(x1_f); + } + + template <> + __host__ __device__ constexpr void operator()( 
+ bhalf_t& e, const float& c, const bhalf_t& d0, const bhalf_t& d1) const + { + const float x0_f = c + type_convert(d0) + type_convert(d1); + + float x1_f = 0; + + ck::tensor_operation::element_wise::FastGelu{}.template operator()(x1_f, + x0_f); + + e = type_convert(x1_f); + } + + template <> + __host__ __device__ constexpr void operator()( + int8_t& e, const int32_t& c, const int8_t& d0, const int8_t& d1) const + { + const float x0_f = + type_convert(c) + type_convert(d0) + type_convert(d1); + + float x1_f = 0; + + ck::tensor_operation::element_wise::FastGelu{}.template operator()(x1_f, + x0_f); + + e = type_convert(x1_f); } }; diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 2167a79e01..6b4df3b60e 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -11,6 +11,10 @@ namespace ck { namespace tensor_operation { namespace element_wise { +#if CK_WORKAROUND_SWDEV_383542 +extern "C" __device__ float __ocml_native_recip_f32(float); +#endif + struct PassThrough { template @@ -200,36 +204,83 @@ struct Relu } }; -// Y = FastGelu(X) +// Fast GeLU +// https://paperswithcode.com/method/gelu +// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3))) +// host code uses higher accuracy "exp" and "div" +// gpu code uses lower accuracy "__expf" and "rcp" functions struct FastGelu { - // Fast GeLU - // https://paperswithcode.com/method/gelu - // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3))) - __host__ __device__ static constexpr float GetFastGeLU(float x) + template + __host__ void operator()(Y& y, const X& x) const; + + template + __device__ void operator()(Y& y, const X& x) const; + + template <> + __host__ void operator()(float& y, const float& x) const { const float u = 2.f * x * (0.035677f * x * x + 0.797885f); const float emu = exp(-u); const float cdf = 
0.5f + 0.5f * (2.f / (1.f + emu) - 1.f); - return x * cdf; + + y = x * cdf; } - template - static inline constexpr bool is_valid_param_type_v = - std::is_same_v || std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v -#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 - || std::is_same_v -#endif - ; - - template - __host__ __device__ void operator()(Y& y, const X& x) const + // device code, use lower precision "__expf" and "rcp" + template <> + __device__ void operator()(float& y, const float& x) const { - static_assert(is_valid_param_type_v && is_valid_param_type_v); + const float u = 2.f * x * (0.035677f * x * x + 0.797885f); + const float emu = __expf(-u); - const float tmp_y = GetFastGeLU(type_convert(x)); - y = type_convert(tmp_y); +#if !CK_WORKAROUND_SWDEV_383542 + const float cdf = 0.5f + 0.5f * (2.f * __frcp_rn(1.f + emu) - 1.f); +#else + const float cdf = 0.5f + 0.5f * (2.f * __ocml_native_recip_f32(1.f + emu) - 1.f); +#endif + + y = x * cdf; + } + + template <> + __host__ void operator()(half_t& y, const half_t& x) const + { + float y_f; + + this->operator()(y_f, type_convert(x)); + + y = type_convert(y_f); + } + + template <> + __device__ void operator()(half_t& y, const half_t& x) const + { + float y_f; + + this->operator()(y_f, type_convert(x)); + + y = type_convert(y_f); + } + + template <> + __host__ void operator()(half_t& y, const float& x) const + { + float y_f; + + this->operator()(y_f, x); + + y = type_convert(y_f); + } + + template <> + __device__ void operator()(half_t& y, const float& x) const + { + float y_f; + + this->operator()(y_f, x); + + y = type_convert(y_f); } };