mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
add args for packed gemm (#54)
This commit is contained in:
@@ -70,8 +70,16 @@ int gemm_profiler(int argc, char* argv[])
|
||||
ck::half_t,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(
|
||||
do_verification, init_method, do_log, nrepeat, M, N, K, StrideA, StrideB, StrideC);
|
||||
ck::tensor_layout::gemm::RowMajor>(do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
nrepeat,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
(StrideA < 0) ? K : StrideA,
|
||||
(StrideB < 0) ? N : StrideB,
|
||||
(StrideC < 0) ? N : StrideC);
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
@@ -80,8 +88,16 @@ int gemm_profiler(int argc, char* argv[])
|
||||
ck::half_t,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(
|
||||
do_verification, init_method, do_log, nrepeat, M, N, K, StrideA, StrideB, StrideC);
|
||||
ck::tensor_layout::gemm::RowMajor>(do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
nrepeat,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
(StrideA < 0) ? K : StrideA,
|
||||
(StrideB < 0) ? K : StrideB,
|
||||
(StrideC < 0) ? N : StrideC);
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
|
||||
{
|
||||
@@ -90,8 +106,16 @@ int gemm_profiler(int argc, char* argv[])
|
||||
ck::half_t,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(
|
||||
do_verification, init_method, do_log, nrepeat, M, N, K, StrideA, StrideB, StrideC);
|
||||
ck::tensor_layout::gemm::RowMajor>(do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
nrepeat,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
(StrideA < 0) ? M : StrideA,
|
||||
(StrideB < 0) ? N : StrideB,
|
||||
(StrideC < 0) ? N : StrideC);
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
|
||||
{
|
||||
@@ -100,8 +124,16 @@ int gemm_profiler(int argc, char* argv[])
|
||||
ck::half_t,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(
|
||||
do_verification, init_method, do_log, nrepeat, M, N, K, StrideA, StrideB, StrideC);
|
||||
ck::tensor_layout::gemm::RowMajor>(do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
nrepeat,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
(StrideA < 0) ? M : StrideA,
|
||||
(StrideB < 0) ? K : StrideB,
|
||||
(StrideC < 0) ? N : StrideC);
|
||||
}
|
||||
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
@@ -110,8 +142,16 @@ int gemm_profiler(int argc, char* argv[])
|
||||
float,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(
|
||||
do_verification, init_method, do_log, nrepeat, M, N, K, StrideA, StrideB, StrideC);
|
||||
ck::tensor_layout::gemm::RowMajor>(do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
nrepeat,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
(StrideA < 0) ? K : StrideA,
|
||||
(StrideB < 0) ? N : StrideB,
|
||||
(StrideC < 0) ? N : StrideC);
|
||||
}
|
||||
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
@@ -120,8 +160,16 @@ int gemm_profiler(int argc, char* argv[])
|
||||
float,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(
|
||||
do_verification, init_method, do_log, nrepeat, M, N, K, StrideA, StrideB, StrideC);
|
||||
ck::tensor_layout::gemm::RowMajor>(do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
nrepeat,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
(StrideA < 0) ? K : StrideA,
|
||||
(StrideB < 0) ? K : StrideB,
|
||||
(StrideC < 0) ? N : StrideC);
|
||||
}
|
||||
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
|
||||
{
|
||||
@@ -130,8 +178,16 @@ int gemm_profiler(int argc, char* argv[])
|
||||
float,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(
|
||||
do_verification, init_method, do_log, nrepeat, M, N, K, StrideA, StrideB, StrideC);
|
||||
ck::tensor_layout::gemm::RowMajor>(do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
nrepeat,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
(StrideA < 0) ? M : StrideA,
|
||||
(StrideB < 0) ? N : StrideB,
|
||||
(StrideC < 0) ? N : StrideC);
|
||||
}
|
||||
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
|
||||
{
|
||||
@@ -140,8 +196,16 @@ int gemm_profiler(int argc, char* argv[])
|
||||
float,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(
|
||||
do_verification, init_method, do_log, nrepeat, M, N, K, StrideA, StrideB, StrideC);
|
||||
ck::tensor_layout::gemm::RowMajor>(do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
nrepeat,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
(StrideA < 0) ? M : StrideA,
|
||||
(StrideB < 0) ? K : StrideB,
|
||||
(StrideC < 0) ? N : StrideC);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -18,7 +18,28 @@ REPEAT=$7
|
||||
######## op datatype layout verify init log repeat M___ N___ K___ StrideA StrideB StrideC
|
||||
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 256 256 256 256 256 256
|
||||
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 960 1024 1024 1024 1024 1024
|
||||
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1024 1024 1024
|
||||
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 2048 2048 2048
|
||||
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 4096 4096 4096
|
||||
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 4096 4096 4096
|
||||
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 8192 8192 8192
|
||||
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1024 1024 1024
|
||||
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2048 2048 2048
|
||||
|
||||
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 960 1024 1024 -1 -1 -1
|
||||
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 -1 -1 -1
|
||||
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 -1 -1 -1
|
||||
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 -1 -1 -1
|
||||
|
||||
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1024 1024 1024
|
||||
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2048 2048 2048
|
||||
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4096 4096 4096
|
||||
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8192 8192 8192
|
||||
|
||||
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1056 1056 1056
|
||||
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2080 2080 2080
|
||||
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4128 4128 4128
|
||||
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8224 8224 8224
|
||||
|
||||
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1088 1088 1088
|
||||
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2112 2112 2112
|
||||
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4160 4160 4160
|
||||
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8256 8256 8256
|
||||
|
||||
Reference in New Issue
Block a user