diff --git a/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc b/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc index c4f5165e42..e8f630bdcb 100644 --- a/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc +++ b/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc @@ -128,7 +128,21 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, wei_device_buf.FromDevice(wei_device_result.mData.data()); - return ck::utils::check_err(wei_device_result.mData, wei_host_result.mData); + float max_accumulated_value = + *std::max_element(wei_host_result.mData.begin(),wei_host_result.mData.end()); + + const ck::index_t num_accums = out.GetElementSize() / conv_param.K_; + const ck::index_t num_accums_split_k = split_k; + double rtol = + ck::utils::get_relative_threshold( + num_accums / num_accums_split_k); + double atol = + ck::utils::get_absolute_threshold( + max_accumulated_value / num_accums_split_k, + num_accums / num_accums_split_k); + + return ck::utils::check_err(wei_device_result.mData, wei_host_result.mData, + "Error: Incorrect results!", rtol, atol); } float avg_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});