Gemm_c_shuffle (4 layouts) X (fp32 bf16 int8) (#131)

* [What] Separate fixpoint gemm from gemm example
[Why] let example of gemm_int8 be pure gemm.
[What]
1. Add gemm_requant_relu_requant,
2. Let CDataType be int32 in pure gemm, because no one use int8 CDataType. It is also part of gemm_requant_relu_requant

* Fix path

* Revise cmakelist due to merge develop

* Add gemm fp16 test

* Extract PrepareGemmTensor

* Extract TestGemm

* Add test for different layout

* Add 4 layouts of shuffle version of fp32

* Add 4 layouts of shuffle version of int8

* Add 4 layouts of shuffle version of bf16

* replace all DeviceGemmPtr_ with DeviceGemmNoOpPtr to fit naming convension

* Add test for non-shuffle verstion of gemm

* Fix typo

* Print kernel information

* Add rest of the fp32 kernel to the test

* 1. Add rest of the fp16 device iop.
2. Mark the invalid device operation

Co-authored-by: rocking <chunylai@amd.com>
This commit is contained in:
rocking5566
2022-03-22 04:59:51 +08:00
committed by GitHub
parent b51808d7a5
commit 485ea46a40
24 changed files with 1497 additions and 322 deletions

View File

@@ -223,6 +223,26 @@ int profile_gemm(int argc, char* argv[])
(StrideC < 0) ? N : StrideC,
KBatch);
}
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_gemm_impl<int8_t,
int8_t,
int8_t,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? K : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
}
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_gemm_impl<int8_t,
@@ -243,6 +263,66 @@ int profile_gemm(int argc, char* argv[])
(StrideC < 0) ? N : StrideC,
KBatch);
}
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_KN_MN)
{
ck::profiler::profile_gemm_impl<int8_t,
int8_t,
int8_t,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
}
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_NK_MN)
{
ck::profiler::profile_gemm_impl<int8_t,
int8_t,
int8_t,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
}
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_gemm_impl<ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? K : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
}
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_gemm_impl<ck::bhalf_t,
@@ -263,6 +343,46 @@ int profile_gemm(int argc, char* argv[])
(StrideC < 0) ? N : StrideC,
KBatch);
}
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_KN_MN)
{
ck::profiler::profile_gemm_impl<ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
}
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_NK_MN)
{
ck::profiler::profile_gemm_impl<ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
}
else
{
throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented");