From 52d313938ac2acf03a2dec54bf3f8d5b46ba6ec9 Mon Sep 17 00:00:00 2001 From: letaoqin Date: Fri, 11 Oct 2024 23:06:47 +0800 Subject: [PATCH] finish gelu,relu and silu --- .../66_gemm_bias_activation/CMakeLists.txt | 2 +- .../66_gemm_bias_activation/gemm_bias_add.hpp | 13 ++++- .../gemm_bias_add_fp16.cpp | 56 +++++++++++++----- .../gemm_bias_add_xdl_fp16.cpp | 58 +++++++++++++------ 4 files changed, 94 insertions(+), 35 deletions(-) diff --git a/example/66_gemm_bias_activation/CMakeLists.txt b/example/66_gemm_bias_activation/CMakeLists.txt index 5541b87551..385a90e4a2 100644 --- a/example/66_gemm_bias_activation/CMakeLists.txt +++ b/example/66_gemm_bias_activation/CMakeLists.txt @@ -1,6 +1,6 @@ set(GEMM_BIAS_ADD_SOURCES - gemm_bias_add_xdl_fp16.cpp gemm_bias_add_fp16.cpp + gemm_bias_add_xdl_fp16.cpp ) add_executable(example_gemm_bias_add_xdl_fp16 ${GEMM_BIAS_ADD_SOURCES}) target_link_libraries(example_gemm_bias_add_xdl_fp16 PRIVATE utility) diff --git a/example/66_gemm_bias_activation/gemm_bias_add.hpp b/example/66_gemm_bias_activation/gemm_bias_add.hpp index 6b15659020..07e2993627 100644 --- a/example/66_gemm_bias_activation/gemm_bias_add.hpp +++ b/example/66_gemm_bias_activation/gemm_bias_add.hpp @@ -7,6 +7,8 @@ #include "ck/stream_config.hpp" #include "ck/utility/data_type.hpp" #include "ck/utility/type_convert.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" namespace ck { namespace impl { @@ -63,6 +65,12 @@ struct AddActivation } // namespace impl } // namespace ck +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Gelu = ck::tensor_operation::element_wise::Gelu; +using Relu = ck::tensor_operation::element_wise::Relu; +using Silu = ck::tensor_operation::element_wise::Silu; +using Sigmoid = ck::tensor_operation::element_wise::Sigmoid; + enum class ActivationType { Gelu = 0, @@ -70,6 +78,7 @@ enum class ActivationType Silu, Swiglu, Geglu, + Sigmoid, Identity, GeluNoneApproximate, GeGluNoneApproximate, @@ -86,4 +95,6 @@ struct GemmBiasAddArgs ck::index_t K; }; -float gemm_bias_add_fp16(const GemmBiasAddArgs& args, const StreamConfig& config); +float gemm_bias_add_fp16(const GemmBiasAddArgs& args, + const StreamConfig& config, + ActivationType op_type); diff --git a/example/66_gemm_bias_activation/gemm_bias_add_fp16.cpp b/example/66_gemm_bias_activation/gemm_bias_add_fp16.cpp index e031f216ac..98bb74e39e 100644 --- a/example/66_gemm_bias_activation/gemm_bias_add_fp16.cpp +++ b/example/66_gemm_bias_activation/gemm_bias_add_fp16.cpp @@ -23,16 +23,10 @@ using CShuffleDataType = F32; using ALayout = Row; using BLayout = Row; using D0Layout = Row; -using DsLayout = ck::Tuple; using CLayout = Row; -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using Relu = ck::tensor_operation::element_wise::Relu; - -using AElementOp = PassThrough; -using BElementOp = PassThrough; -using CDEElementOp = ck::impl::AddActivation; -; +using AElementOp = PassThrough; +using BElementOp = PassThrough; template using S = ck::Sequence; @@ -40,7 +34,7 @@ using S = ck::Sequence; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; // clang-format off -template +template using DeviceOpInstance_64_16_16_64 = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, DsDataType, CDataType, AccDataType, CShuffleDataType, @@ -57,7 +51,7 @@ using DeviceOpInstance_64_16_16_64 = ck::tensor_operation::device::DeviceGemmMul 1, 1, S<1, 16, 1, 4>, S<4, 4>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, F16>; -template +template using DeviceOpInstance_default = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, DsDataType, CDataType, AccDataType, CShuffleDataType, @@ -75,8 +69,8 @@ using DeviceOpInstance_default = ck::tensor_operation::device::DeviceGemmMultiD_ S<1, 16, 1, 4>, S<2, 2>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, F16>; // clang-format on - -float gemm_bias_add_fp16(const GemmBiasAddArgs& args, const StreamConfig& config) +template +float run_impl(const GemmBiasAddArgs& args, const StreamConfig& config) { using ADataType = ck::half_t; using BDataType = ck::half_t; @@ -133,12 +127,46 @@ float gemm_bias_add_fp16(const GemmBiasAddArgs& args, const StreamConfig& config return true; }; - auto gemm = DeviceOpInstance_64_16_16_64{}; + auto gemm = DeviceOpInstance_64_16_16_64{}; if(!Run(gemm)) { - auto gemm_def = DeviceOpInstance_default{}; + auto gemm_def = DeviceOpInstance_default{}; Run(gemm_def); } return ave_time; } +float gemm_bias_add_fp16(const GemmBiasAddArgs& args, + const StreamConfig& config, + ActivationType op_type) +{ + using DsLayout = ck::Tuple; + switch(op_type) + { + case ActivationType::Gelu: + case ActivationType::Geglu: + case ActivationType::GeluNoneApproximate: + case ActivationType::GeGluNoneApproximate: + return run_impl>(args, config); + case ActivationType::Relu: + return run_impl>(args, config); + case ActivationType::Silu: + case ActivationType::Swiglu: + return run_impl>(args, config); + case ActivationType::Sigmoid: + return run_impl>(args, config); + case ActivationType::Identity: + case ActivationType::InvalidType: + default: return 0; + } +} diff --git a/example/66_gemm_bias_activation/gemm_bias_add_xdl_fp16.cpp b/example/66_gemm_bias_activation/gemm_bias_add_xdl_fp16.cpp index a4056d8564..b111fb5548 100644 --- a/example/66_gemm_bias_activation/gemm_bias_add_xdl_fp16.cpp +++ b/example/66_gemm_bias_activation/gemm_bias_add_xdl_fp16.cpp @@ -13,8 +13,6 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" -#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" @@ -39,12 +37,14 @@ using D0Layout = Row; using DsLayout = ck::Tuple; using ELayout = Row; -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using Relu = ck::tensor_operation::element_wise::Relu; +// using PassThrough = ck::tensor_operation::element_wise::PassThrough; +// using Gelu = ck::tensor_operation::element_wise::Gelu; +// using Relu = ck::tensor_operation::element_wise::Relu; +// using Silu = ck::tensor_operation::element_wise::Silu; +// using Sigmoid = ck::tensor_operation::element_wise::Sigmoid; using AElementOp = PassThrough; using BElementOp = PassThrough; -using CElementOp = ck::impl::AddActivation; using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm(op_type)); // float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 20, 50}); std::size_t flop = std::size_t(2) * M * N * K; @@ -253,9 +258,6 @@ int main(int argc, char* argv[]) if(do_verification) { - - // RunUnfusedTest(a0_m_k.mData, b0_k_n.mData, d0_m_n.mData, e_m_n_host_result.mData, K, M, - // N); auto ref_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_gemm.MakeInvoker(); @@ -264,13 +266,31 @@ int main(int argc, char* argv[]) ref_invoker.Run(ref_argument); - CElementOp cde_element_op; - for(int m = 0; m < M; ++m) - { - for(int n = 0; n < N; ++n) + auto run_elementwise = [&](auto cde_element_op) { + for(int m = 0; m < M; ++m) { - cde_element_op(e_m_n_host_result(m, n), e_m_n_host_result(m, n), d0_m_n(m, n)); + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), e_m_n_host_result(m, n), d0_m_n(m, n)); + } } + }; + ActivationType type = static_cast(op_type); + switch(type) + { + case ActivationType::Gelu: + case ActivationType::Geglu: + case ActivationType::GeluNoneApproximate: + case ActivationType::GeGluNoneApproximate: + run_elementwise(ck::impl::AddActivation{}); + break; + case ActivationType::Relu: run_elementwise(ck::impl::AddActivation{}); break; + case ActivationType::Silu: + case ActivationType::Swiglu: run_elementwise(ck::impl::AddActivation{}); break; + case ActivationType::Sigmoid: run_elementwise(ck::impl::AddActivation{}); break; + case ActivationType::Identity: + case ActivationType::InvalidType: + default: break; } e_device_buf.FromDevice(e_m_n_device_result.mData.data());