Consider gemm requant relu requant as gemm fusion (#116)

* [What] Separate fixpoint gemm from gemm example
[Why] Let the gemm_int8 example be a pure gemm.
[What]
1. Add gemm_requant_relu_requant,
2. Let CDataType be int32 in pure gemm, because no one uses int8 CDataType. It is also part of gemm_requant_relu_requant.

* Fix path

* Revise cmakelist due to merge develop

Co-authored-by: rocking <chunylai@amd.com>

[ROCm/composable_kernel commit: 9a17e7fbfd]
This commit is contained in:
rocking5566
2022-03-12 10:41:03 +08:00
committed by GitHub
parent ea5f57fa92
commit 900ea4ae3e
4 changed files with 240 additions and 10 deletions

View File

@@ -25,12 +25,11 @@ using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using RequantReluRequant = ck::tensor_operation::element_wise::RequantReluRequant;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = int8_t;
using BDataType = int8_t;
using CDataType = int8_t;
using CDataType = int32_t;
using AccDataType = int32_t;
using CShuffleDataType = int32_t;
@@ -50,7 +49,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle
CLayout, // CLayout
PassThrough, // AElementwiseOperation
PassThrough, // BElementwiseOperation
RequantReluRequant, // CElementwiseOperation
PassThrough, // CElementwiseOperation
256, // BlockSize
256, // MPerBlock
128, // NPerBlock
@@ -78,11 +77,11 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
8>; // CBlockTransferScalarPerVector_NWaveNPerXdl
4>; // CBlockTransferScalarPerVector_NWaveNPerXdl
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, RequantReluRequant>;
ReferenceGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, PassThrough>;
int main(int argc, char* argv[])
{
@@ -99,9 +98,6 @@ int main(int argc, char* argv[])
ck::index_t StrideB = 4096;
ck::index_t StrideC = 4096;
float scale_gemm = 0.03;
float scale_relu = 1;
if(argc == 4)
{
do_verification = std::stoi(argv[1]);
@@ -175,7 +171,7 @@ int main(int argc, char* argv[])
auto a_element_op = PassThrough{};
auto b_element_op = PassThrough{};
auto c_element_op = RequantReluRequant{scale_gemm, scale_relu};
auto c_element_op = PassThrough{};
// do GEMM
auto gemm = DeviceGemmInstance{};