From 220bc28498680e8c335ab5c5e8be092941416f07 Mon Sep 17 00:00:00 2001 From: zjing14 Date: Wed, 24 Nov 2021 12:33:55 -0600 Subject: [PATCH] add args for packed gemm (#54) [ROCm/composable_kernel commit: 567f5e9c5f0aa6481570fba9267224626014542f] --- profiler/gemm_profiler.cpp | 96 +++++++++++++++++++++++++++++++------- script/profile_gemm.sh | 25 +++++++++- 2 files changed, 103 insertions(+), 18 deletions(-) diff --git a/profiler/gemm_profiler.cpp b/profiler/gemm_profiler.cpp index d832c7db50..31b2d84c53 100644 --- a/profiler/gemm_profiler.cpp +++ b/profiler/gemm_profiler.cpp @@ -70,8 +70,16 @@ int gemm_profiler(int argc, char* argv[]) ck::half_t, ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, - ck::tensor_layout::gemm::RowMajor>( - do_verification, init_method, do_log, nrepeat, M, N, K, StrideA, StrideB, StrideC); + ck::tensor_layout::gemm::RowMajor>(do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC); } else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) { @@ -80,8 +88,16 @@ int gemm_profiler(int argc, char* argv[]) ck::half_t, ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, - ck::tensor_layout::gemm::RowMajor>( - do_verification, init_method, do_log, nrepeat, M, N, K, StrideA, StrideB, StrideC); + ck::tensor_layout::gemm::RowMajor>(do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC); } else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) { @@ -90,8 +106,16 @@ int gemm_profiler(int argc, char* argv[]) ck::half_t, ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, - ck::tensor_layout::gemm::RowMajor>( - do_verification, init_method, do_log, nrepeat, M, N, K, StrideA, StrideB, StrideC); + ck::tensor_layout::gemm::RowMajor>(do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC); } else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) { @@ -100,8 +124,16 @@ int gemm_profiler(int argc, char* argv[]) ck::half_t, ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, - ck::tensor_layout::gemm::RowMajor>( - do_verification, init_method, do_log, nrepeat, M, N, K, StrideA, StrideB, StrideC); + ck::tensor_layout::gemm::RowMajor>(do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC); } else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN) { @@ -110,8 +142,16 @@ int gemm_profiler(int argc, char* argv[]) float, ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, - ck::tensor_layout::gemm::RowMajor>( - do_verification, init_method, do_log, nrepeat, M, N, K, StrideA, StrideB, StrideC); + ck::tensor_layout::gemm::RowMajor>(do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC); } else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN) { @@ -120,8 +160,16 @@ int gemm_profiler(int argc, char* argv[]) float, ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, - ck::tensor_layout::gemm::RowMajor>( - do_verification, init_method, do_log, nrepeat, M, N, K, StrideA, StrideB, StrideC); + ck::tensor_layout::gemm::RowMajor>(do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC); } else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN) { @@ -130,8 +178,16 @@ int gemm_profiler(int argc, char* argv[]) float, ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, - ck::tensor_layout::gemm::RowMajor>( - do_verification, init_method, do_log, nrepeat, M, N, K, StrideA, StrideB, StrideC); + ck::tensor_layout::gemm::RowMajor>(do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC); } else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN) { @@ -140,8 +196,16 @@ int gemm_profiler(int argc, char* argv[]) float, ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, - ck::tensor_layout::gemm::RowMajor>( - do_verification, init_method, do_log, nrepeat, M, N, K, StrideA, StrideB, StrideC); + ck::tensor_layout::gemm::RowMajor>(do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC); } else { diff --git a/script/profile_gemm.sh b/script/profile_gemm.sh index bbd9ad051e..036d0440e0 100755 --- a/script/profile_gemm.sh +++ b/script/profile_gemm.sh @@ -18,7 +18,28 @@ REPEAT=$7 ######## op datatype layout verify init log repeat M___ N___ K___ StrideA StrideB StrideC #$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 256 256 256 256 256 256 #$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 960 1024 1024 1024 1024 1024 -#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1024 1024 1024 #$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 2048 2048 2048 - $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 4096 4096 4096 +#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 4096 4096 4096 #$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 8192 8192 8192 +#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1024 1024 1024 +#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2048 2048 2048 + +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 960 1024 1024 -1 -1 -1 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 -1 -1 -1 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 -1 -1 -1 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 -1 -1 -1 + +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1024 1024 1024 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2048 2048 2048 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4096 4096 4096 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8192 8192 8192 + +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1056 1056 1056 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2080 2080 2080 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4128 4128 4128 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8224 8224 8224 + +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1088 1088 1088 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2112 2112 2112 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4160 4160 4160 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8256 8256 8256