mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
use proper rtol/atol
This commit is contained in:
@@ -123,16 +123,37 @@ int run_mx_gemm_with_layouts(int argc,
|
||||
|
||||
ck_tile::reference_mx_gemm<ADataType, BDataType, ScaleType, AccDataType, CDataType>(
|
||||
a_host, b_host, c_m_n_host_ref, scale_a_host, scale_b_host);
|
||||
// ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
// a_host, b_host, c_m_n_host_ref);
|
||||
|
||||
const float rtol = std::is_same_v<ADataType, ck_tile::half_t> ? 1e-3 : 1e-2;
|
||||
const float atol = std::is_same_v<ADataType, ck_tile::half_t> ? 1e-3 : 1e-2;
|
||||
auto calculate_rtol_atol = [&K, &kbatch](const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
};
|
||||
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto [rtol, atol] = calculate_rtol_atol(max_accumulated_value);
|
||||
|
||||
pass = ck_tile::check_err(
|
||||
c_host, c_m_n_host_ref, "Error: Incorrect results!", rtol, atol);
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol << " Absolute error threshold: " << atol
|
||||
<< std::endl;
|
||||
std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
|
||||
std::cout << "The CPU verification result is: " << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
return pass ? 0 : -1;
|
||||
}
|
||||
@@ -170,7 +191,7 @@ int run_mx_gemm_example(int argc, char* argv[])
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Only fp4 is supported currently!");
|
||||
throw std::runtime_error("Only fp4/8 is supported currently!");
|
||||
}
|
||||
}
|
||||
else
|
||||
|
||||
Reference in New Issue
Block a user