Restore example tolerance calculation

This commit is contained in:
Enrico Degregori
2025-12-12 11:17:31 +00:00
parent a87256a676
commit df75061576

View File

@@ -131,28 +131,21 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
wei_device_buf.FromDevice(wei_device_result.mData.data());
if(split_k == 1)
{
return ck::utils::check_err(wei_device_result.mData, wei_host_result.mData);
}
else
{
float max_accumulated_value =
*std::max_element(wei_host_result.mData.begin(), wei_host_result.mData.end());
float max_accumulated_value =
*std::max_element(wei_host_result.mData.begin(), wei_host_result.mData.end());
const ck::index_t num_accums = out.GetElementSize() / conv_param.K_;
const ck::index_t num_accums_split_k = split_k;
double rtol = ck::utils::get_relative_threshold<InDataType, WeiDataType, AccDataType>(
num_accums / num_accums_split_k);
double atol = ck::utils::get_absolute_threshold<InDataType, WeiDataType, AccDataType>(
max_accumulated_value / num_accums_split_k, num_accums / num_accums_split_k);
const ck::index_t num_accums = out.GetElementSize() / conv_param.K_;
const ck::index_t num_accums_split_k = split_k;
double rtol = ck::utils::get_relative_threshold<InDataType, WeiDataType, AccDataType>(
num_accums / num_accums_split_k);
double atol = ck::utils::get_absolute_threshold<InDataType, WeiDataType, AccDataType>(
max_accumulated_value / num_accums_split_k, num_accums / num_accums_split_k);
return ck::utils::check_err(wei_device_result.mData,
wei_host_result.mData,
"Error: Incorrect results!",
rtol,
atol);
}
return ck::utils::check_err(wei_device_result.mData,
wei_host_result.mData,
"Error: Incorrect results!",
rtol,
atol);
}
else if(config.do_verification == 2)
{