diff --git a/example/01_gemm/common.hpp b/example/01_gemm/common.hpp index d08196924b..6e1c9f2a0d 100644 --- a/example/01_gemm/common.hpp +++ b/example/01_gemm/common.hpp @@ -75,9 +75,10 @@ struct ProblemSizeSplitK final struct ExecutionConfig final { - bool do_verification = true; - int init_method = 2; - bool time_kernel = false; + // 0 - no verification, 1 - CPU, 2 - GPU, 3 - CPU + GPU + int do_verification = 3; + int init_method = 2; + bool time_kernel = false; }; template @@ -126,7 +127,7 @@ bool parse_cmd_args(int argc, } else { - std::cerr << "arg1: verification (0=no, 1=CPU and GPU)" << std::endl + std::cerr << "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl << "arg3: time kernel (0=no, 1=yes)" << std::endl @@ -176,7 +177,7 @@ bool parse_cmd_args(int argc, else { std::cerr - << "arg1: verification (0=no, 1=CPU and GPU)" << std::endl + << "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl << "arg3: time kernel (0=no, 1=yes)" << std::endl << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl @@ -225,7 +226,7 @@ bool parse_cmd_args(int argc, } else { - std::cerr << "arg1: verification (0=no, 1=CPU and GPU)" << std::endl + std::cerr << "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl << "arg3: time kernel (0=no, 1=yes)" << std::endl @@ -275,7 +276,7 @@ bool parse_cmd_args(int argc, } else { - std::cerr << "arg1: verification (0=no, 1=CPU and GPU)" << std::endl + std::cerr << "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl << "arg3: time kernel (0=no, 1=yes)" << std::endl diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc index fe12998e35..bafec3f358 100644 --- a/example/01_gemm/run_gemm_example.inc +++ b/example/01_gemm/run_gemm_example.inc @@ -330,7 +330,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) bool pass = true; - if(config.do_verification) + if((config.do_verification == 1) || (config.do_verification == 3)) { // CPU verification auto ref_gemm = ReferenceGemmInstance{}; @@ -353,13 +353,16 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) #else c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); - pass &= !ck::utils::check_err(c_m_n_device_result, - c_m_n_host_result, - "Error: Incorrect results!", - get_rtol(), - get_atol()); + pass &= ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); #endif + } + if((config.do_verification == 2) || (config.do_verification == 3)) + { // GPU verification auto ref_gemm_gpu = ReferenceGemmInstanceGPU{}; auto ref_invoker_gpu = ref_gemm_gpu.MakeInvoker(); @@ -381,14 +384,14 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) c_m_n_device_ref_buf.FromDevice(c_m_n_device_ref_result.mData.data()); c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); - pass &= !ck::utils::check_err(c_m_n_device_result, - c_m_n_device_ref_result, - "Error: Incorrect results!", - get_rtol(), - get_atol()); + pass &= ck::utils::check_err(c_m_n_device_result, + c_m_n_device_ref_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); } - return !pass; + return pass == true; } bool run_gemm_example(int argc, char* argv[]) diff --git a/example/01_gemm/run_gemm_example_streamk_v2.inc b/example/01_gemm/run_gemm_example_streamk_v2.inc index 6679f95157..8ed8b81bec 100644 --- a/example/01_gemm/run_gemm_example_streamk_v2.inc +++ b/example/01_gemm/run_gemm_example_streamk_v2.inc @@ -241,7 +241,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) } bool pass = true; - if(config.do_verification) + if((config.do_verification == 1) || (config.do_verification == 3)) { auto ref_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_gemm.MakeInvoker(); diff --git a/example/01_gemm/run_gemm_example_v2.inc b/example/01_gemm/run_gemm_example_v2.inc index 0bcee658b9..71524fdecf 100644 --- a/example/01_gemm/run_gemm_example_v2.inc +++ b/example/01_gemm/run_gemm_example_v2.inc @@ -228,7 +228,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) } bool pass = true; - if(config.do_verification) + if((config.do_verification == 1) || (config.do_verification == 3)) { auto ref_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_gemm.MakeInvoker(); diff --git a/library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp index 639b5fe80f..2c2cac77e3 100644 --- a/library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp @@ -45,10 +45,10 @@ __global__ void if(row_idx < m && col_idx < n) { - AccDataType v_acc = static_cast(0.0); - ComputeTypeA v_a = static_cast(0.0); - ComputeTypeB v_b = static_cast(0.0); - CDataType v_c = static_cast(0.0); + AccDataType v_acc{0}; + ComputeTypeA v_a{0}; + ComputeTypeB v_b{0}; + CDataType v_c{0}; for(int k_idx = 0; k_idx < k; ++k_idx) { @@ -76,7 +76,7 @@ __global__ void // apply b_element_op b_element_op(v_b, p_b_grid[element_idx_b]); // multiply and accumulate - v_acc += static_cast(v_a) * static_cast(v_b); + v_acc += type_convert(v_a) * type_convert(v_b); } // apply c_element_op c_element_op(v_c, v_acc);