From 0435c33638375e98bf967efbe3dd30d69010e4aa Mon Sep 17 00:00:00 2001 From: letaoqin Date: Sat, 12 Oct 2024 09:35:02 +0800 Subject: [PATCH] refactor code --- .../66_gemm_bias_activation/gemm_bias_add.hpp | 20 ++----- .../gemm_bias_add_fp16.cpp | 60 +++++++------------ .../gemm_bias_add_xdl_fp16.cpp | 39 ++++++------ 3 files changed, 45 insertions(+), 74 deletions(-) diff --git a/example/66_gemm_bias_activation/gemm_bias_add.hpp b/example/66_gemm_bias_activation/gemm_bias_add.hpp index 07e2993627..5c94c4c201 100644 --- a/example/66_gemm_bias_activation/gemm_bias_add.hpp +++ b/example/66_gemm_bias_activation/gemm_bias_add.hpp @@ -71,19 +71,6 @@ using Relu = ck::tensor_operation::element_wise::Relu; using Silu = ck::tensor_operation::element_wise::Silu; using Sigmoid = ck::tensor_operation::element_wise::Sigmoid; -enum class ActivationType -{ - Gelu = 0, - Relu, - Silu, - Swiglu, - Geglu, - Sigmoid, - Identity, - GeluNoneApproximate, - GeGluNoneApproximate, - InvalidType -}; struct GemmBiasAddArgs { const void* mat_a; @@ -95,6 +82,7 @@ struct GemmBiasAddArgs ck::index_t K; }; -float gemm_bias_add_fp16(const GemmBiasAddArgs& args, - const StreamConfig& config, - ActivationType op_type); +float gemm_bias_add_relu_fp16(const GemmBiasAddArgs& args, const StreamConfig& config); +float gemm_bias_add_gelu_fp16(const GemmBiasAddArgs& args, const StreamConfig& config); +float gemm_bias_add_silu_fp16(const GemmBiasAddArgs& args, const StreamConfig& config); +float gemm_bias_add_sigmoid_fp16(const GemmBiasAddArgs& args, const StreamConfig& config); diff --git a/example/66_gemm_bias_activation/gemm_bias_add_fp16.cpp b/example/66_gemm_bias_activation/gemm_bias_add_fp16.cpp index 98bb74e39e..e0ecb0d3d5 100644 --- a/example/66_gemm_bias_activation/gemm_bias_add_fp16.cpp +++ b/example/66_gemm_bias_activation/gemm_bias_add_fp16.cpp @@ -24,6 +24,7 @@ using ALayout = Row; using BLayout = Row; using D0Layout = Row; using CLayout = Row; +using DsLayout = ck::Tuple; using AElementOp = PassThrough; using BElementOp = PassThrough; @@ -34,7 +35,7 @@ using S = ck::Sequence; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; // clang-format off -template +template using DeviceOpInstance_64_16_16_64 = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, DsDataType, CDataType, AccDataType, CShuffleDataType, @@ -51,7 +52,7 @@ using DeviceOpInstance_64_16_16_64 = ck::tensor_operation::device::DeviceGemmMul 1, 1, S<1, 16, 1, 4>, S<4, 4>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, F16>; -template +template using DeviceOpInstance_default = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, DsDataType, CDataType, AccDataType, CShuffleDataType, @@ -69,7 +70,7 @@ using DeviceOpInstance_default = ck::tensor_operation::device::DeviceGemmMultiD_ S<1, 16, 1, 4>, S<2, 2>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, F16>; // clang-format on -template +template float run_impl(const GemmBiasAddArgs& args, const StreamConfig& config) { using ADataType = ck::half_t; @@ -127,46 +128,31 @@ float run_impl(const GemmBiasAddArgs& args, const StreamConfig& config) return true; }; - auto gemm = DeviceOpInstance_64_16_16_64{}; + auto gemm = + DeviceOpInstance_64_16_16_64{}; if(!Run(gemm)) { - auto gemm_def = DeviceOpInstance_default{}; + auto gemm_def = + DeviceOpInstance_default{}; Run(gemm_def); } return ave_time; } -float gemm_bias_add_fp16(const GemmBiasAddArgs& args, - const StreamConfig& config, - ActivationType op_type) + +float gemm_bias_add_relu_fp16(const GemmBiasAddArgs& args, const StreamConfig& config) { - using DsLayout = ck::Tuple; - switch(op_type) - { - case ActivationType::Gelu: - case ActivationType::Geglu: - case ActivationType::GeluNoneApproximate: - case ActivationType::GeGluNoneApproximate: - return run_impl>(args, config); - case ActivationType::Relu: - return run_impl>(args, config); - case ActivationType::Silu: - case ActivationType::Swiglu: - return run_impl>(args, config); - case ActivationType::Sigmoid: - return run_impl>(args, config); - case ActivationType::Identity: - case ActivationType::InvalidType: - default: return 0; - } + return run_impl>(args, config); +} +float gemm_bias_add_gelu_fp16(const GemmBiasAddArgs& args, const StreamConfig& config) +{ + return run_impl>(args, config); +} +float gemm_bias_add_silu_fp16(const GemmBiasAddArgs& args, const StreamConfig& config) +{ + return run_impl>(args, config); +} +float gemm_bias_add_sigmoid_fp16(const GemmBiasAddArgs& args, const StreamConfig& config) +{ + return run_impl>(args, config); } diff --git a/example/66_gemm_bias_activation/gemm_bias_add_xdl_fp16.cpp b/example/66_gemm_bias_activation/gemm_bias_add_xdl_fp16.cpp index b111fb5548..2491d817ae 100644 --- a/example/66_gemm_bias_activation/gemm_bias_add_xdl_fp16.cpp +++ b/example/66_gemm_bias_activation/gemm_bias_add_xdl_fp16.cpp @@ -172,8 +172,8 @@ int main(int argc, char* argv[]) printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg3: time kernel (0=no, 1=yes)\n"); - printf("arg4 to 9: M (256x), N(128x), K(32x)m, op_type(Gelu = 0, Relu, Silu, Swiglu, " - "Geglu, Identity, GeluNoneApproximate, GeGluNoneApproximate)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x)m, op_type(Gelu = 0, Relu = 1, Silu = 2, " + "Sigmoid = 3\n"); exit(0); } @@ -238,9 +238,15 @@ int main(int argc, char* argv[]) N, K}; - float ave_time = gemm_bias_add_fp16(gemm_args, - StreamConfig{nullptr, time_kernel, 20, 50}, - static_cast(op_type)); + float ave_time = 0; + if(op_type == 0) + gemm_bias_add_gelu_fp16(gemm_args, StreamConfig{nullptr, time_kernel, 20, 50}); + else if(op_type == 1) + gemm_bias_add_relu_fp16(gemm_args, StreamConfig{nullptr, time_kernel, 20, 50}); + else if(op_type == 2) + gemm_bias_add_silu_fp16(gemm_args, StreamConfig{nullptr, time_kernel, 20, 50}); + else + gemm_bias_add_sigmoid_fp16(gemm_args, StreamConfig{nullptr, time_kernel, 20, 50}); // float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 20, 50}); std::size_t flop = std::size_t(2) * M * N * K; @@ -275,23 +281,14 @@ int main(int argc, char* argv[]) } } }; - ActivationType type = static_cast(op_type); - switch(type) - { - case ActivationType::Gelu: - case ActivationType::Geglu: - case ActivationType::GeluNoneApproximate: - case ActivationType::GeGluNoneApproximate: + if(op_type == 0) run_elementwise(ck::impl::AddActivation{}); - break; - case ActivationType::Relu: run_elementwise(ck::impl::AddActivation{}); break; - case ActivationType::Silu: - case ActivationType::Swiglu: run_elementwise(ck::impl::AddActivation{}); break; - case ActivationType::Sigmoid: run_elementwise(ck::impl::AddActivation{}); break; - case ActivationType::Identity: - case ActivationType::InvalidType: - default: break; - } + else if(op_type == 1) + run_elementwise(ck::impl::AddActivation{}); + else if(op_type == 2) + run_elementwise(ck::impl::AddActivation{}); + else + run_elementwise(ck::impl::AddActivation{}); e_device_buf.FromDevice(e_m_n_device_result.mData.data());