use proper rtol/atol

This commit is contained in:
Sami Remes
2026-02-03 09:57:20 +00:00
parent 6b50755cd2
commit 16fa73db63
2 changed files with 26 additions and 4 deletions

View File

@@ -123,16 +123,37 @@ int run_mx_gemm_with_layouts(int argc,
ck_tile::reference_mx_gemm<ADataType, BDataType, ScaleType, AccDataType, CDataType>(
a_host, b_host, c_m_n_host_ref, scale_a_host, scale_b_host);
// ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
// a_host, b_host, c_m_n_host_ref);
const float rtol = std::is_same_v<ADataType, ck_tile::half_t> ? 1e-3 : 1e-2;
const float atol = std::is_same_v<ADataType, ck_tile::half_t> ? 1e-3 : 1e-2;
// Builds the (rtol, atol) pair used to compare the device result with the host reference.
// K and kbatch are captured by reference from the enclosing function;
// max_accumulated_value is the largest reference output value supplied by the caller.
// Returns a ck_tile tuple holding the larger relative and the larger absolute threshold.
auto calculate_rtol_atol = [&K, &kbatch](const float max_accumulated_value)
{
// ComputeType is whichever of ADataType / BDataType has the smaller sizeof;
// when both sizes are equal the strict '<' selects BDataType.
using ComputeType =
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
// Calculate thresholds
// Both calls receive integer_divide_ceil(K, kbatch); the absolute threshold additionally
// receives max_accumulated_value / kbatch.
// NOTE(review): presumably the per-split accumulation count and magnitude - confirm
// against the ck_tile threshold helpers.
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
ck_tile::integer_divide_ceil(K, kbatch));
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
// Calculate error due to split_k accumulation
// Here every template argument is CDataType and kbatch is passed as the count.
const auto rtol_split_k =
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
max_accumulated_value, kbatch);
// Use higher threshold
// The larger value of each pair is returned, so the looser bound is the one applied.
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
};
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto [rtol, atol] = calculate_rtol_atol(max_accumulated_value);
pass = ck_tile::check_err(
c_host, c_m_n_host_ref, "Error: Incorrect results!", rtol, atol);
std::cout << "Relative error threshold: " << rtol << " Absolute error threshold: " << atol
<< std::endl;
std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
std::cout << "The CPU verification result is: " << (pass ? "correct" : "fail") << std::endl;
}
return pass ? 0 : -1;
}
@@ -170,7 +191,7 @@ int run_mx_gemm_example(int argc, char* argv[])
}
else
{
throw std::runtime_error("Only fp4 is supported currently!");
throw std::runtime_error("Only fp4/8 is supported currently!");
}
}
else