fix bug in gemm profiler (#344)

[ROCm/composable_kernel commit: 146972f447]
This commit is contained in:
Chao Liu
2022-08-07 12:23:32 -05:00
committed by GitHub
parent c5a39f834f
commit be8f189a9e
6 changed files with 166 additions and 115 deletions

View File

@@ -72,43 +72,43 @@ int profile_gemm(int argc, char* argv[])
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
auto profile = [&](auto a_type,
auto profile = [&](auto a_layout,
auto b_layout,
auto c_layout,
auto a_type,
auto b_type,
auto acc_type,
auto c_type,
auto a_layout,
auto b_layout,
auto c_layout) {
auto c_type) {
using ALayout = decltype(a_layout);
using BLayout = decltype(b_layout);
using CLayout = decltype(c_layout);
using ADataType = decltype(a_type);
using BDataType = decltype(b_type);
using AccDataType = decltype(acc_type);
using CDataType = decltype(c_type);
using ALayout = decltype(a_layout);
using BLayout = decltype(b_layout);
using CLayout = decltype(c_layout);
const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? N : M;
bool pass =
ck::profiler::profile_gemm_impl<ADataType,
ck::profiler::profile_gemm_impl<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(do_verification,
init_method,
do_log,
time_kernel,
M,
N,
K,
(StrideA < 0) ? DefaultStrideA : StrideA,
(StrideB < 0) ? DefaultStrideB : StrideB,
(StrideC < 0) ? DefaultStrideC : StrideC);
CDataType>(do_verification,
init_method,
do_log,
time_kernel,
M,
N,
K,
(StrideA < 0) ? DefaultStrideA : StrideA,
(StrideB < 0) ? DefaultStrideB : StrideB,
(StrideC < 0) ? DefaultStrideC : StrideC);
return pass ? 0 : 1;
};