mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-26 08:00:13 +00:00
Consider gemm requant relu requant as gemm fusuion (#116)
* [What] Separate fixpoint gemm from gemm example
[Why] let example of gemm_int8 be pure gemm.
[What]
1. Add gemm_requant_relu_requant,
2. Let CDataType be int32 in pure gemm, because no one use int8 CDataType. It is also part of gemm_requant_relu_requant
* Fix path
* Revise cmakelist due to merge develop
Co-authored-by: rocking <chunylai@amd.com>
[ROCm/composable_kernel commit: 9a17e7fbfd]
This commit is contained in:
@@ -25,12 +25,11 @@ using F32 = float;
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using RequantReluRequant = ck::tensor_operation::element_wise::RequantReluRequant;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ADataType = int8_t;
|
||||
using BDataType = int8_t;
|
||||
using CDataType = int8_t;
|
||||
using CDataType = int32_t;
|
||||
using AccDataType = int32_t;
|
||||
using CShuffleDataType = int32_t;
|
||||
|
||||
@@ -50,7 +49,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle
|
||||
CLayout, // CLayout
|
||||
PassThrough, // AElementwiseOperation
|
||||
PassThrough, // BElementwiseOperation
|
||||
RequantReluRequant, // CElementwiseOperation
|
||||
PassThrough, // CElementwiseOperation
|
||||
256, // BlockSize
|
||||
256, // MPerBlock
|
||||
128, // NPerBlock
|
||||
@@ -78,11 +77,11 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
1, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
|
||||
8>; // CBlockTransferScalarPerVector_NWaveNPerXdl
|
||||
4>; // CBlockTransferScalarPerVector_NWaveNPerXdl
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, RequantReluRequant>;
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, PassThrough>;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
@@ -99,9 +98,6 @@ int main(int argc, char* argv[])
|
||||
ck::index_t StrideB = 4096;
|
||||
ck::index_t StrideC = 4096;
|
||||
|
||||
float scale_gemm = 0.03;
|
||||
float scale_relu = 1;
|
||||
|
||||
if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
@@ -175,7 +171,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
auto a_element_op = PassThrough{};
|
||||
auto b_element_op = PassThrough{};
|
||||
auto c_element_op = RequantReluRequant{scale_gemm, scale_relu};
|
||||
auto c_element_op = PassThrough{};
|
||||
|
||||
// do GEMM
|
||||
auto gemm = DeviceGemmInstance{};
|
||||
|
||||
Reference in New Issue
Block a user