Int8 qunatization gemm xdl (#108)

* Add int8 of mk_nk_mn to the ckProfiler

* Add example of int8 gemm

* Fix typo, use ushort instead of half_t for bfloat16

* replace ushortXXX_t to bhalfXXX_t

* rename ushort to bhalf_t

* Add bf16 example

* Add bf16 gemm to ckProfiler

* Fix alignment

* Fix typo

* Add unit test for gemm_xdl int8

* Add gemm_xdl fp32 unit test

* Add gemm_xdl bf16 unit test

* fix build

* fix build issue due to merge conflict

* Fix build

* Fix build error

* [What] gemm + relu inference
[How] gemm + requant + relu + requant + clamp

* clean

Co-authored-by: rocking <chunylai@amd.com>
Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
rocking5566
2022-03-05 14:48:09 +08:00
committed by GitHub
parent 5b178874a1
commit ad41aa0e7a
14 changed files with 191 additions and 151 deletions

View File

@@ -54,11 +54,11 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1>;
#elif 1
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle
//######|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>;
//######|AData| BData| CData| AccData| Shuffle| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| Type| Type| Type| Type| Data| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//######| | | | | Type| | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>;
#elif 0
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl
//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| Num|

View File

@@ -43,6 +43,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle
BDataType, // BDataType
CDataType, // CDataType
AccDataType, // AccDataType
CDataType, // CShuffleDataType
ALayout, // ALayout
BLayout, // BLayout
CLayout, // CLayout

View File

@@ -25,12 +25,14 @@ using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using RequantReluRequant = ck::tensor_operation::element_wise::RequantReluRequant;
using ADataType = int8_t;
using BDataType = int8_t;
using CDataType = int8_t;
using AccDataType = int32_t;
using ADataType = int8_t;
using BDataType = int8_t;
using CDataType = int8_t;
using AccDataType = int32_t;
using CShuffleDataType = int32_t;
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
@@ -42,12 +44,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle
BDataType, // BDataType
CDataType, // CDataType
AccDataType, // AccDataType
CShuffleDataType, // CShuffleDataType
ALayout, // ALayout
BLayout, // BLayout
CLayout, // CLayout
PassThrough, // AElementwiseOperation
PassThrough, // BElementwiseOperation
PassThrough, // CElementwiseOperation
RequantReluRequant, // CElementwiseOperation
256, // BlockSize
256, // MPerBlock
128, // NPerBlock
@@ -79,7 +82,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, PassThrough>;
ReferenceGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, RequantReluRequant>;
int main(int argc, char* argv[])
{
@@ -96,6 +99,9 @@ int main(int argc, char* argv[])
ck::index_t StrideB = 4096;
ck::index_t StrideC = 4096;
float scale_gemm = 0.03;
float scale_relu = 1;
if(argc == 4)
{
do_verification = std::stoi(argv[1]);
@@ -169,7 +175,7 @@ int main(int argc, char* argv[])
auto a_element_op = PassThrough{};
auto b_element_op = PassThrough{};
auto c_element_op = PassThrough{};
auto c_element_op = RequantReluRequant{scale_gemm, scale_relu};
// do GEMM
auto gemm = DeviceGemmInstance{};