Review notes (text before the first "diff --git" header is ignored by git apply):
- This patch was recovered from a copy in which indentation, blank lines and all
  template argument lists had been stripped. Indentation and the argument lists on
  context lines (e.g. TensorDescriptor<DT, RANK>, std::vector<Case>,
  type_convert<float>) are inferred - confirm them against the tree, or apply with
  --ignore-whitespace.
- The "index" hashes are carried over from the original patch and do not describe
  the reviewed post-image.
- Tolerance values are unchanged: BF16 uses 5e-2, every other data type uses 1e-3.

diff --git a/experimental/builder/include/ck_tile/builder/testing/validation.hpp b/experimental/builder/include/ck_tile/builder/testing/validation.hpp
index 158f271e21..b3eee9b4b4 100644
--- a/experimental/builder/include/ck_tile/builder/testing/validation.hpp
+++ b/experimental/builder/include/ck_tile/builder/testing/validation.hpp
@@ -22,6 +22,44 @@
 /// etc, by the actual testing framework that the user has chosen.

 namespace ck_tile::builder::test {
+
+/// @brief Default relative tolerance for validating tensors of data type `DT`.
+///
+/// BF16 is compared with a looser bound of 5e-2. Every other data type uses 1e-3
+/// until a dedicated tolerance is chosen for it.
+///
+/// @tparam DT Data type of the tensor elements being compared.
+template <DataType DT>
+inline __host__ __device__ constexpr double get_rtol()
+{
+    if constexpr(DT == DataType::BF16)
+    {
+        return 5e-2;
+    }
+    else
+    {
+        return 1e-3;
+    }
+}
+
+/// @brief Default absolute tolerance for validating tensors of data type `DT`.
+///
+/// BF16 is compared with a looser bound of 5e-2. Every other data type uses 1e-3
+/// until a dedicated tolerance is chosen for it.
+///
+/// @tparam DT Data type of the tensor elements being compared.
+template <DataType DT>
+inline __host__ __device__ constexpr double get_atol()
+{
+    if constexpr(DT == DataType::BF16)
+    {
+        return 5e-2;
+    }
+    else
+    {
+        return 1e-3;
+    }
+}

 /// @brief Information about how a set of comparisons failed or succeeded.
 ///
@@ -51,6 +89,9 @@ struct ValidationReport
         /// The number of elements which were bitwise 0.
         uint64_t zero_elements;

+        /// The largest absolute element-wise error that was observed.
+        double max_error;
+
         /// @brief Check whether both the output and reference tensor were both all zeros.
         ///
         /// If both tensors are all zero, it indicates either an incorrect testing setup
@@ -108,8 +149,8 @@ struct ValidationReport
                const TensorDescriptor<DT, RANK>& descriptor,
                const void* actual,
                const void* expected,
-               double rtol = 1e-3,
-               double atol = 1e-3);
+               double rtol = get_rtol<DT>(),
+               double atol = get_atol<DT>());

     private:
     std::vector<Case> reports_;
@@ -133,11 +174,15 @@ bool ValidationReport::check(std::string_view tensor_name,
     // Initial pass: count errors

     // Allocate and reset counter
-    auto d_counters = alloc_buffer(sizeof(uint64_t) * 2);
-    check_hip(hipMemset(d_counters.get(), 0, sizeof(uint64_t) * 2));
+    // The third slot stores the maximum error as a double, so it must be exactly as
+    // wide as the uint64_t counters it shares the buffer with.
+    static_assert(sizeof(double) == sizeof(uint64_t));
+    auto d_counters = alloc_buffer(sizeof(uint64_t) * 3);
+    check_hip(hipMemset(d_counters.get(), 0, sizeof(uint64_t) * 3));

     auto d_error_count = &reinterpret_cast<uint64_t*>(d_counters.get())[0];
     auto d_zero_count = &reinterpret_cast<uint64_t*>(d_counters.get())[1];
+    auto d_max_error = &reinterpret_cast<double*>(d_counters.get())[2];

     tensor_foreach(descriptor.get_lengths(), [=](auto index) {
         using CKType = typename factory::internal::DataTypeToCK<DT>::type;
@@ -157,6 +202,10 @@ bool ValidationReport::check(std::string_view tensor_name,
         const auto r = static_cast<double>(type_convert<float>(b));

         const auto err = std::abs(o - r);
+        // Track the largest absolute error seen across all elements.
+        // NOTE(review): relies on an atomicMax overload for double - confirm it is
+        // available on every targeted ROCm version.
+        atomicMax(d_max_error, err);
         if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
         {
             // We expect the number of errors to be very low, so just use an atomic
@@ -188,6 +237,8 @@ bool ValidationReport::check(std::string_view tensor_name,
     check_hip(hipMemcpy(&error_count, d_error_count, sizeof(uint64_t), hipMemcpyDeviceToHost));
     uint64_t zero_count = 0;
     check_hip(hipMemcpy(&zero_count, d_zero_count, sizeof(uint64_t), hipMemcpyDeviceToHost));
+    double max_error = 0.0;
+    check_hip(hipMemcpy(&max_error, d_max_error, sizeof(double), hipMemcpyDeviceToHost));

     // TODO: Gather detailed coordinates.

@@ -196,6 +247,7 @@ bool ValidationReport::check(std::string_view tensor_name,
         .wrong_elements = error_count,
         .total_elements = descriptor.get_element_size(),
         .zero_elements = zero_count,
+        .max_error = max_error,
     });

     return reports_.back().is_ok();
diff --git a/profiler/include/profiler/grouped_convolution_forward_tile_algs.hpp b/profiler/include/profiler/grouped_convolution_forward_tile_algs.hpp
index e58c884729..7e6cd7e9a9 100644
--- a/profiler/include/profiler/grouped_convolution_forward_tile_algs.hpp
+++ b/profiler/include/profiler/grouped_convolution_forward_tile_algs.hpp
@@ -125,7 +125,8 @@ run_grouped_conv_forward_tile_algs(const ckt::Args& args,
             {
                 valid = false;
                 std::cout << "Number of incorrect values: " << error.wrong_elements
-                          << " Is all zero:" << error.is_all_zero() << std::endl;
+                          << " Is all zero:" << error.is_all_zero()
+                          << " max err: " << error.max_error << std::endl;
             }
             best_avg_time = std::min(best_avg_time, avg_time);
             best_op_name = best_avg_time < avg_time ? best_op_name : op_name;