mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Update GPU verification (#1596)
* Update inits
* Update static_cast to type_convert
* Add verification option selection
[ROCm/composable_kernel commit: 7d576f1748]
This commit is contained in:
@@ -75,9 +75,10 @@ struct ProblemSizeSplitK final
|
||||
|
||||
struct ExecutionConfig final
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 2;
|
||||
bool time_kernel = false;
|
||||
// 0 - no verification, 1 - CPU, 2 - GPU, 3 - CPU + GPU
|
||||
int do_verification = 3;
|
||||
int init_method = 2;
|
||||
bool time_kernel = false;
|
||||
};
|
||||
|
||||
template <ck::index_t... Is>
|
||||
@@ -126,7 +127,7 @@ bool parse_cmd_args<ProblemSize>(int argc,
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "arg1: verification (0=no, 1=CPU and GPU)" << std::endl
|
||||
std::cerr << "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
|
||||
<< std::endl
|
||||
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
|
||||
@@ -176,7 +177,7 @@ bool parse_cmd_args<ProblemSizeStreamK_universal>(int argc,
|
||||
else
|
||||
{
|
||||
std::cerr
|
||||
<< "arg1: verification (0=no, 1=CPU and GPU)" << std::endl
|
||||
<< "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl
|
||||
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
|
||||
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl
|
||||
@@ -225,7 +226,7 @@ bool parse_cmd_args<ProblemSizeStreamK>(int argc,
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "arg1: verification (0=no, 1=CPU and GPU)" << std::endl
|
||||
std::cerr << "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
|
||||
<< std::endl
|
||||
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
|
||||
@@ -275,7 +276,7 @@ bool parse_cmd_args<ProblemSizeSplitK>(int argc,
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "arg1: verification (0=no, 1=CPU and GPU)" << std::endl
|
||||
std::cerr << "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
|
||||
<< std::endl
|
||||
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
|
||||
|
||||
@@ -330,7 +330,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(config.do_verification)
|
||||
if((config.do_verification == 1) || (config.do_verification == 3))
|
||||
{
|
||||
// CPU verification
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
@@ -353,13 +353,16 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
#else
|
||||
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
pass &= !ck::utils::check_err(c_m_n_device_result,
|
||||
c_m_n_host_result,
|
||||
"Error: Incorrect results!",
|
||||
get_rtol<CDataType>(),
|
||||
get_atol<CDataType>());
|
||||
pass &= ck::utils::check_err(c_m_n_device_result,
|
||||
c_m_n_host_result,
|
||||
"Error: Incorrect results!",
|
||||
get_rtol<CDataType>(),
|
||||
get_atol<CDataType>());
|
||||
#endif
|
||||
}
|
||||
|
||||
if((config.do_verification == 2) || (config.do_verification == 3))
|
||||
{
|
||||
// GPU verification
|
||||
auto ref_gemm_gpu = ReferenceGemmInstanceGPU{};
|
||||
auto ref_invoker_gpu = ref_gemm_gpu.MakeInvoker();
|
||||
@@ -381,14 +384,14 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
c_m_n_device_ref_buf.FromDevice(c_m_n_device_ref_result.mData.data());
|
||||
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
pass &= !ck::utils::check_err(c_m_n_device_result,
|
||||
c_m_n_device_ref_result,
|
||||
"Error: Incorrect results!",
|
||||
get_rtol<CDataType>(),
|
||||
get_atol<CDataType>());
|
||||
pass &= ck::utils::check_err(c_m_n_device_result,
|
||||
c_m_n_device_ref_result,
|
||||
"Error: Incorrect results!",
|
||||
get_rtol<CDataType>(),
|
||||
get_atol<CDataType>());
|
||||
}
|
||||
|
||||
return !pass;
|
||||
return pass == true;
|
||||
}
|
||||
|
||||
bool run_gemm_example(int argc, char* argv[])
|
||||
|
||||
@@ -241,7 +241,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
}
|
||||
|
||||
bool pass = true;
|
||||
if(config.do_verification)
|
||||
if((config.do_verification == 1) || (config.do_verification == 3))
|
||||
{
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
@@ -228,7 +228,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
}
|
||||
|
||||
bool pass = true;
|
||||
if(config.do_verification)
|
||||
if((config.do_verification == 1) || (config.do_verification == 3))
|
||||
{
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
@@ -45,10 +45,10 @@ __global__ void
|
||||
if(row_idx < m && col_idx < n)
|
||||
{
|
||||
|
||||
AccDataType v_acc = static_cast<AccDataType>(0.0);
|
||||
ComputeTypeA v_a = static_cast<ComputeTypeA>(0.0);
|
||||
ComputeTypeB v_b = static_cast<ComputeTypeB>(0.0);
|
||||
CDataType v_c = static_cast<CDataType>(0.0);
|
||||
AccDataType v_acc{0};
|
||||
ComputeTypeA v_a{0};
|
||||
ComputeTypeB v_b{0};
|
||||
CDataType v_c{0};
|
||||
|
||||
for(int k_idx = 0; k_idx < k; ++k_idx)
|
||||
{
|
||||
@@ -76,7 +76,7 @@ __global__ void
|
||||
// apply b_element_op
|
||||
b_element_op(v_b, p_b_grid[element_idx_b]);
|
||||
// multiply and accumulate
|
||||
v_acc += static_cast<AccDataType>(v_a) * static_cast<AccDataType>(v_b);
|
||||
v_acc += type_convert<AccDataType>(v_a) * type_convert<AccDataType>(v_b);
|
||||
}
|
||||
// apply c_element_op
|
||||
c_element_op(v_c, v_acc);
|
||||
|
||||
Reference in New Issue
Block a user