mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
Add max error metric
This commit is contained in:
@@ -22,6 +22,99 @@
|
||||
/// etc, by the actual testing framework that the user has chosen.
|
||||
|
||||
namespace ck_tile::builder::test {
|
||||
template <DataType DT, typename GemmType = DataType>
|
||||
inline __host__ __device__ constexpr double get_rtol()
|
||||
{
|
||||
// if constexpr(std::is_same_v<DataType, float> && std::is_same_v<GemmType, ck::tf32_t>)
|
||||
// {
|
||||
// return 5e-3;
|
||||
// }
|
||||
// else
|
||||
if constexpr(DataType::FP32 == DT)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
// else if constexpr(std::is_same_v<DataType, double>)
|
||||
// {
|
||||
// return 1e-6;
|
||||
// }
|
||||
else if constexpr(DataType::FP16== DT)
|
||||
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(DataType::BF16 == DT)
|
||||
|
||||
{
|
||||
return 5e-2;
|
||||
}
|
||||
// else if constexpr(std::is_same_v<DataType, int32_t>)
|
||||
// {
|
||||
// return 1e-1;
|
||||
// }
|
||||
// else if constexpr(std::is_same_v<DataType, int8_t>)
|
||||
// {
|
||||
// return 1e-1;
|
||||
// }
|
||||
// else if constexpr(std::is_same_v<DataType, ck::f8_t>)
|
||||
// {
|
||||
// return 1e-1; // 240 and 224 are acceptable
|
||||
// }
|
||||
// else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
|
||||
// {
|
||||
// return 1.5e-1; // 57344 and 49152 are acceptable
|
||||
// }
|
||||
else
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
}
|
||||
|
||||
template <DataType DT, typename GemmType = DataType>
|
||||
inline __host__ __device__ constexpr double get_atol()
|
||||
{
|
||||
// if constexpr(std::is_same_v<DataType, float> && std::is_same_v<GemmType, ck::tf32_t>)
|
||||
// {
|
||||
// return 1e-3;
|
||||
// }
|
||||
if constexpr(DataType::FP32 == DT)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
// else if constexpr(std::is_same_v<DataType, double>)
|
||||
// {
|
||||
// return 1e-6;
|
||||
// }
|
||||
else if constexpr(DataType::FP16== DT)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(DataType::BF16 == DT)
|
||||
{
|
||||
return 5e-2;
|
||||
}
|
||||
// else if constexpr(std::is_same_v<DataType, int32_t>)
|
||||
// {
|
||||
// return 1e-1;
|
||||
// }
|
||||
// else if constexpr(std::is_same_v<DataType, int8_t>)
|
||||
// {
|
||||
// return 1e-1;
|
||||
// }
|
||||
// else if constexpr(std::is_same_v<DataType, ck::f8_t>)
|
||||
// {
|
||||
// return 16.1; // 240 and 224 are acceptable
|
||||
// }
|
||||
// else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
|
||||
// {
|
||||
// return 8192.1; // 57344 and 49152 are acceptable
|
||||
// }
|
||||
else
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// @brief Information about how a set of comparisons failed or succeeded.
|
||||
///
|
||||
@@ -51,6 +144,9 @@ struct ValidationReport
|
||||
/// The number of elements which were bitwise 0.
|
||||
uint64_t zero_elements;
|
||||
|
||||
// Max error.
|
||||
double max_error;
|
||||
|
||||
/// @brief Check whether both the output and reference tensor were both all zeros.
|
||||
///
|
||||
/// If both tensors are all zero, it indicates either an incorrect testing setup
|
||||
@@ -108,8 +204,8 @@ struct ValidationReport
|
||||
const TensorDescriptor<DT, RANK>& descriptor,
|
||||
const void* actual,
|
||||
const void* expected,
|
||||
double rtol = 1e-3,
|
||||
double atol = 1e-3);
|
||||
double rtol = get_rtol<DT>(),
|
||||
double atol = get_atol<DT>());
|
||||
|
||||
private:
|
||||
std::vector<Case> reports_;
|
||||
@@ -133,11 +229,12 @@ bool ValidationReport::check(std::string_view tensor_name,
|
||||
// Initial pass: count errors
|
||||
|
||||
// Allocate and reset counter
|
||||
auto d_counters = alloc_buffer(sizeof(uint64_t) * 2);
|
||||
check_hip(hipMemset(d_counters.get(), 0, sizeof(uint64_t) * 2));
|
||||
auto d_counters = alloc_buffer(sizeof(uint64_t) * 3);
|
||||
check_hip(hipMemset(d_counters.get(), 0, sizeof(uint64_t) * 3));
|
||||
|
||||
auto d_error_count = &reinterpret_cast<uint64_t*>(d_counters.get())[0];
|
||||
auto d_zero_count = &reinterpret_cast<uint64_t*>(d_counters.get())[1];
|
||||
auto d_max_error = &reinterpret_cast<double*>(d_counters.get())[2];
|
||||
|
||||
tensor_foreach(descriptor.get_lengths(), [=](auto index) {
|
||||
using CKType = typename factory::internal::DataTypeToCK<DT>::type;
|
||||
@@ -157,6 +254,7 @@ bool ValidationReport::check(std::string_view tensor_name,
|
||||
const auto r = static_cast<double>(type_convert<float>(b));
|
||||
const auto err = std::abs(o - r);
|
||||
|
||||
atomicMax(d_max_error, err);
|
||||
if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
|
||||
{
|
||||
// We expect the number of errors to be very low, so just use an atomic
|
||||
@@ -188,6 +286,8 @@ bool ValidationReport::check(std::string_view tensor_name,
|
||||
check_hip(hipMemcpy(&error_count, d_error_count, sizeof(uint64_t), hipMemcpyDeviceToHost));
|
||||
uint64_t zero_count = 0;
|
||||
check_hip(hipMemcpy(&zero_count, d_zero_count, sizeof(uint64_t), hipMemcpyDeviceToHost));
|
||||
double max_error = 0;
|
||||
check_hip(hipMemcpy(&max_error, d_max_error, sizeof(double), hipMemcpyDeviceToHost));
|
||||
|
||||
// TODO: Gather detailed coordinates.
|
||||
|
||||
@@ -196,6 +296,7 @@ bool ValidationReport::check(std::string_view tensor_name,
|
||||
.wrong_elements = error_count,
|
||||
.total_elements = descriptor.get_element_size(),
|
||||
.zero_elements = zero_count,
|
||||
.max_error= max_error,
|
||||
});
|
||||
|
||||
return reports_.back().is_ok();
|
||||
|
||||
@@ -125,7 +125,7 @@ run_grouped_conv_forward_tile_algs(const ckt::Args<SIGNATURE>& args,
|
||||
{
|
||||
valid = false;
|
||||
std::cout << "Number of incorrect values: " << error.wrong_elements
|
||||
<< " Is all zero:" << error.is_all_zero() << std::endl;
|
||||
<< " Is all zero:" << error.is_all_zero() << " max err: " << error.max_error << std::endl;
|
||||
}
|
||||
best_avg_time = std::min(best_avg_time, avg_time);
|
||||
best_op_name = best_avg_time < avg_time ? best_op_name : op_name;
|
||||
|
||||
Reference in New Issue
Block a user