From 7cd48ef11e28a070488a7d865f31178184a6b335 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Thu, 21 Apr 2022 17:28:53 +0000 Subject: [PATCH] refactor --- .../gemm_xdl_requant_relu_requant_int8.cpp | 30 +++++++++++++----- .../gpu/element/element_wise_operation.hpp | 31 ------------------- .../include/ck/library/utility/check_err.hpp | 7 +++-- 3 files changed, 27 insertions(+), 41 deletions(-) diff --git a/example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp b/example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp index ca3b58bd00..482f9457ee 100644 --- a/example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp +++ b/example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp @@ -19,22 +19,36 @@ #include "reference_gemm.hpp" #include "gemm_specialization.hpp" +struct RequantReluRequant +{ + // FIXME: We just need one scale for Relu / Leaky Relu / PRelu + RequantReluRequant(float scaleGemm, float scaleRelu) + : scaleGemm_(scaleGemm), scaleRelu_(scaleRelu) + { + } + + __host__ __device__ constexpr void operator()(float& y, const float& x) const + { + float gemm_requant = scaleGemm_ * x; + float relu = gemm_requant > 0 ? gemm_requant : 0; + float relu_requant = scaleRelu_ * relu; + y = relu_requant > 127 ? 127 : relu_requant < -128 ? -128 : relu_requant; + } + + float scaleGemm_; + float scaleRelu_; +}; + template using S = ck::Sequence; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using RequantReluRequant = ck::tensor_operation::element_wise::RequantReluRequant; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ADataType = int8_t; using BDataType = int8_t; using CDataType = int8_t; using AccDataType = int32_t; -using CShuffleDataType = int32_t; +using CShuffleDataType = float; using ALayout = ck::tensor_layout::gemm::RowMajor; using BLayout = ck::tensor_layout::gemm::ColumnMajor; diff --git a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp index 5b3606e859..ab1cbfed45 100644 --- a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp @@ -143,37 +143,6 @@ struct AddHardswishAdd } }; -struct RequantReluRequant -{ - // FIXME: We just need one scale for Relu / Leaky Relu / PRelu - RequantReluRequant(float scaleGemm, float scaleRelu) - : scaleGemm_(scaleGemm), scaleRelu_(scaleRelu) - { - } - - __host__ __device__ constexpr void operator()(int8_t& y, const int& x) const - { - float gemm_requant = scaleGemm_ * static_cast(x); - float relu = gemm_requant > 0 ? gemm_requant : 0; - float relu_requant = scaleRelu_ * relu; - y = static_cast(relu_requant > 127 ? 127 - : relu_requant < -128 ? -128 : relu_requant); - } - - // for reference_gemm - __host__ __device__ constexpr void operator()(float& y, const float& x) const - { - float gemm_requant = scaleGemm_ * x; - float relu = gemm_requant > 0 ? gemm_requant : 0; - float relu_requant = scaleRelu_ * relu; - y = static_cast(relu_requant > 127 ? 127 - : relu_requant < -128 ? -128 : relu_requant); - } - - float scaleGemm_; - float scaleRelu_; -}; - // Unary operators are usually called element-wisely before/after the reduction is executed on the // elements. They are needed for easy implementation of reduction types of AVG, NRM1, NRM2 diff --git a/library/include/ck/library/utility/check_err.hpp b/library/include/ck/library/utility/check_err.hpp index 280ac83883..3a2396507c 100644 --- a/library/include/ck/library/utility/check_err.hpp +++ b/library/include/ck/library/utility/check_err.hpp @@ -171,9 +171,12 @@ check_err(const std::vector& out, for(std::size_t i = 0; i < ref.size(); ++i) { - if(out[i] != ref[i]) + const int64_t out_v = static_cast(out[i]); + const int64_t ref_v = static_cast(ref[i]); + + if(out_v != ref_v) { - std::cout << "out[" << i << "] != ref[" << i << "]: " << out[i] << " != " << ref[i] + std::cout << "out[" << i << "] != ref[" << i << "]: " << out_v << " != " << ref_v << std::endl << msg << std::endl; return false;