mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 13:29:20 +00:00
Batched gemm bf16 (#142)
* add bf16 for batched gemm * batched_gemm_bf16 works * recover accidently changed files
This commit is contained in:
@@ -32,7 +32,8 @@ enum GemmDataType
|
||||
{
|
||||
F32_F32_F32, // 0
|
||||
F16_F16_F16, // 1
|
||||
Int8_Int8_Int8, // 2
|
||||
BF16_BF16_BF16, // 2
|
||||
INT8_INT8_INT8, // 3
|
||||
};
|
||||
|
||||
int profile_batched_gemm(int argc, char* argv[])
|
||||
@@ -40,7 +41,7 @@ int profile_batched_gemm(int argc, char* argv[])
|
||||
if(!(argc == 15))
|
||||
{
|
||||
printf("arg1: tensor operation (batched_gemm: Batched GEMM)\n");
|
||||
printf("arg2: data type (0: fp32; 1: fp16, 2: int8)\n");
|
||||
printf("arg2: data type (0: fp32; 1: fp16, 2: bf16, 3: int8)\n");
|
||||
printf("arg3: matrix layout (0: A[g, m, k] * B[g, k, n] = C[g, m, n];\n");
|
||||
printf(" 1: A[g, m, k] * B[g, n, k] = C[g, m, n];\n");
|
||||
printf(" 2: A[g, k, m] * B[g, k, n] = C[g, m, n];\n");
|
||||
@@ -148,6 +149,84 @@ int profile_batched_gemm(int argc, char* argv[])
|
||||
(StrideB < 0) ? K : StrideB,
|
||||
(StrideC < 0) ? N : StrideC);
|
||||
}
|
||||
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
ck::profiler::profile_batched_gemm_impl<ck::bhalf_t,
|
||||
ck::bhalf_t,
|
||||
ck::bhalf_t,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(
|
||||
do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
nrepeat,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
(StrideA < 0) ? K : StrideA,
|
||||
(StrideB < 0) ? N : StrideB,
|
||||
(StrideC < 0) ? N : StrideC,
|
||||
BatchCount);
|
||||
}
|
||||
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
ck::profiler::profile_batched_gemm_impl<ck::bhalf_t,
|
||||
ck::bhalf_t,
|
||||
ck::bhalf_t,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(
|
||||
do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
nrepeat,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
(StrideA < 0) ? K : StrideA,
|
||||
(StrideB < 0) ? K : StrideB,
|
||||
(StrideC < 0) ? N : StrideC,
|
||||
BatchCount);
|
||||
}
|
||||
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_KN_MN)
|
||||
{
|
||||
ck::profiler::profile_batched_gemm_impl<ck::bhalf_t,
|
||||
ck::bhalf_t,
|
||||
ck::bhalf_t,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(
|
||||
do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
nrepeat,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
(StrideA < 0) ? M : StrideA,
|
||||
(StrideB < 0) ? N : StrideB,
|
||||
(StrideC < 0) ? N : StrideC);
|
||||
}
|
||||
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_NK_MN)
|
||||
{
|
||||
ck::profiler::profile_batched_gemm_impl<ck::bhalf_t,
|
||||
ck::bhalf_t,
|
||||
ck::bhalf_t,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(
|
||||
do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
nrepeat,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
(StrideA < 0) ? M : StrideA,
|
||||
(StrideB < 0) ? K : StrideB,
|
||||
(StrideC < 0) ? N : StrideC);
|
||||
}
|
||||
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
ck::profiler::profile_batched_gemm_impl<float,
|
||||
@@ -226,7 +305,7 @@ int profile_batched_gemm(int argc, char* argv[])
|
||||
(StrideB < 0) ? K : StrideB,
|
||||
(StrideC < 0) ? N : StrideC);
|
||||
}
|
||||
else if(data_type == GemmDataType::Int8_Int8_Int8 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
ck::profiler::profile_batched_gemm_impl<int8_t,
|
||||
int8_t,
|
||||
@@ -246,7 +325,7 @@ int profile_batched_gemm(int argc, char* argv[])
|
||||
(StrideC < 0) ? N : StrideC,
|
||||
BatchCount);
|
||||
}
|
||||
else if(data_type == GemmDataType::Int8_Int8_Int8 && layout == GemmMatrixLayout::MK_NK_MN)
|
||||
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
ck::profiler::profile_batched_gemm_impl<int8_t,
|
||||
int8_t,
|
||||
@@ -266,7 +345,7 @@ int profile_batched_gemm(int argc, char* argv[])
|
||||
(StrideC < 0) ? N : StrideC,
|
||||
BatchCount);
|
||||
}
|
||||
else if(data_type == GemmDataType::Int8_Int8_Int8 && layout == GemmMatrixLayout::KM_KN_MN)
|
||||
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_KN_MN)
|
||||
{
|
||||
ck::profiler::profile_batched_gemm_impl<int8_t,
|
||||
int8_t,
|
||||
@@ -285,7 +364,7 @@ int profile_batched_gemm(int argc, char* argv[])
|
||||
(StrideB < 0) ? N : StrideB,
|
||||
(StrideC < 0) ? N : StrideC);
|
||||
}
|
||||
else if(data_type == GemmDataType::Int8_Int8_Int8 && layout == GemmMatrixLayout::KM_NK_MN)
|
||||
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_NK_MN)
|
||||
{
|
||||
ck::profiler::profile_batched_gemm_impl<int8_t,
|
||||
int8_t,
|
||||
|
||||
Reference in New Issue
Block a user