diff --git a/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc b/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc index c9ff4a3c1d..8cc9f582eb 100644 --- a/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc +++ b/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc @@ -131,28 +131,21 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, wei_device_buf.FromDevice(wei_device_result.mData.data()); - if(split_k == 1) - { - return ck::utils::check_err(wei_device_result.mData, wei_host_result.mData); - } - else - { - float max_accumulated_value = - *std::max_element(wei_host_result.mData.begin(), wei_host_result.mData.end()); + float max_accumulated_value = + *std::max_element(wei_host_result.mData.begin(), wei_host_result.mData.end()); - const ck::index_t num_accums = out.GetElementSize() / conv_param.K_; - const ck::index_t num_accums_split_k = split_k; - double rtol = ck::utils::get_relative_threshold( - num_accums / num_accums_split_k); - double atol = ck::utils::get_absolute_threshold( - max_accumulated_value / num_accums_split_k, num_accums / num_accums_split_k); + const ck::index_t num_accums = out.GetElementSize() / conv_param.K_; + const ck::index_t num_accums_split_k = split_k; + double rtol = ck::utils::get_relative_threshold( + num_accums / num_accums_split_k); + double atol = ck::utils::get_absolute_threshold( + max_accumulated_value / num_accums_split_k, num_accums / num_accums_split_k); - return ck::utils::check_err(wei_device_result.mData, - wei_host_result.mData, - "Error: Incorrect results!", - rtol, - atol); - } + return ck::utils::check_err(wei_device_result.mData, + wei_host_result.mData, + "Error: Incorrect results!", + rtol, + atol); } else if(config.do_verification == 2) {