mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-08 15:30:23 +00:00
refactor
This commit is contained in:
@@ -19,22 +19,36 @@
|
||||
#include "reference_gemm.hpp"
|
||||
#include "gemm_specialization.hpp"
|
||||
|
||||
struct RequantReluRequant
|
||||
{
|
||||
// FIXME: We just need one scale for Relu / Leaky Relu / PRelu
|
||||
RequantReluRequant(float scaleGemm, float scaleRelu)
|
||||
: scaleGemm_(scaleGemm), scaleRelu_(scaleRelu)
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr void operator()(float& y, const float& x) const
|
||||
{
|
||||
float gemm_requant = scaleGemm_ * x;
|
||||
float relu = gemm_requant > 0 ? gemm_requant : 0;
|
||||
float relu_requant = scaleRelu_ * relu;
|
||||
y = relu_requant > 127 ? 127 : relu_requant < -128 ? -128 : relu_requant;
|
||||
}
|
||||
|
||||
float scaleGemm_;
|
||||
float scaleRelu_;
|
||||
};
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using RequantReluRequant = ck::tensor_operation::element_wise::RequantReluRequant;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ADataType = int8_t;
|
||||
using BDataType = int8_t;
|
||||
using CDataType = int8_t;
|
||||
using AccDataType = int32_t;
|
||||
using CShuffleDataType = int32_t;
|
||||
using CShuffleDataType = float;
|
||||
|
||||
using ALayout = ck::tensor_layout::gemm::RowMajor;
|
||||
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
@@ -143,37 +143,6 @@ struct AddHardswishAdd
|
||||
}
|
||||
};
|
||||
|
||||
struct RequantReluRequant
|
||||
{
|
||||
// FIXME: We just need one scale for Relu / Leaky Relu / PRelu
|
||||
RequantReluRequant(float scaleGemm, float scaleRelu)
|
||||
: scaleGemm_(scaleGemm), scaleRelu_(scaleRelu)
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr void operator()(int8_t& y, const int& x) const
|
||||
{
|
||||
float gemm_requant = scaleGemm_ * static_cast<float>(x);
|
||||
float relu = gemm_requant > 0 ? gemm_requant : 0;
|
||||
float relu_requant = scaleRelu_ * relu;
|
||||
y = static_cast<int8_t>(relu_requant > 127 ? 127
|
||||
: relu_requant < -128 ? -128 : relu_requant);
|
||||
}
|
||||
|
||||
// for reference_gemm
|
||||
__host__ __device__ constexpr void operator()(float& y, const float& x) const
|
||||
{
|
||||
float gemm_requant = scaleGemm_ * x;
|
||||
float relu = gemm_requant > 0 ? gemm_requant : 0;
|
||||
float relu_requant = scaleRelu_ * relu;
|
||||
y = static_cast<float>(relu_requant > 127 ? 127
|
||||
: relu_requant < -128 ? -128 : relu_requant);
|
||||
}
|
||||
|
||||
float scaleGemm_;
|
||||
float scaleRelu_;
|
||||
};
|
||||
|
||||
// Unary operators are usually called element-wisely before/after the reduction is executed on the
|
||||
// elements. They are needed for easy implementation of reduction types of AVG, NRM1, NRM2
|
||||
|
||||
|
||||
@@ -171,9 +171,12 @@ check_err(const std::vector<T>& out,
|
||||
|
||||
for(std::size_t i = 0; i < ref.size(); ++i)
|
||||
{
|
||||
if(out[i] != ref[i])
|
||||
const int64_t out_v = static_cast<int64_t>(out[i]);
|
||||
const int64_t ref_v = static_cast<int64_t>(ref[i]);
|
||||
|
||||
if(out_v != ref_v)
|
||||
{
|
||||
std::cout << "out[" << i << "] != ref[" << i << "]: " << out[i] << " != " << ref[i]
|
||||
std::cout << "out[" << i << "] != ref[" << i << "]: " << out_v << " != " << ref_v
|
||||
<< std::endl
|
||||
<< msg << std::endl;
|
||||
return false;
|
||||
|
||||
Reference in New Issue
Block a user