From c534ed750d13e2967478052fbb13d00a79a28fc3 Mon Sep 17 00:00:00 2001 From: aledudek Date: Fri, 25 Oct 2024 12:46:24 +0200 Subject: [PATCH] Generic threshold calculation (#1546) * Calculate generic relative threshold pool3dfwd * Calculate absolute error threshold pool3d fwd * Generic threshold calculation take max input for relative error pool3dfwd * Remove max possible value for error calculation at runtime * Remove debug print in pool3dfwd * Pool3d fwd adjusted types in generic threshold calculation * Generic threshold calculation take into account number of accumulations and accdatatype * Generic threshold fix final error formula * Generic threshold calculation - num of accs fix * Generic threshold calculation - adjust absolute error * Generic threshold calculation - OutDataType in absolute error [ROCm/composable_kernel commit: 9385caa3069b8b366c365765164df0c0b6b32925] --- include/ck/utility/data_type.hpp | 9 ++ .../include/ck/library/utility/check_err.hpp | 127 ++++++++++++++++++ .../profiler/profile_pool3d_fwd_impl.hpp | 38 +++++- 3 files changed, 167 insertions(+), 7 deletions(-) diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index debeb472ad..39f532e0e9 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -1803,4 +1803,13 @@ struct NumericUtils static constexpr int bias = 16; // negative zero nan mode // static constexpr int bias = 15; // ieee mode }; + +template <> +struct NumericUtils +{ + static constexpr int exp = 8; + static constexpr int mant = 7; + static constexpr int bias = 128; // negative zero nan mode + // static constexpr int bias = 127; // ieee mode +}; } // namespace ck diff --git a/library/include/ck/library/utility/check_err.hpp b/library/include/ck/library/utility/check_err.hpp index 58479f2127..73ac2a189f 100644 --- a/library/include/ck/library/utility/check_err.hpp +++ b/library/include/ck/library/utility/check_err.hpp @@ -23,6 +23,130 @@ namespace ck { namespace utils { +template +double get_relative_threshold(const int numberOfAccumulations = 1) +{ + using F8 = ck::f8_t; + using F16 = ck::half_t; + using BF16 = ck::bhalf_t; + using F32 = float; + using I8 = int8_t; + using I32 = int32_t; + + static_assert(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v, + "Warning: Unhandled ComputeDataType for setting up the relative threshold!"); + double compute_error = 0; + if constexpr(is_same_v || is_same_v || + is_same_v) + { + return 0; + } + else + { + compute_error = std::pow(2, -NumericUtils::mant) * 0.5; + } + + static_assert(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v, + "Warning: Unhandled OutDataType for setting up the relative threshold!"); + double output_error = 0; + if constexpr(is_same_v || is_same_v || + is_same_v) + { + return 0; + } + else + { + output_error = std::pow(2, -NumericUtils::mant) * 0.5; + } + double midway_error = std::max(compute_error, output_error); + + static_assert(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v, + "Warning: Unhandled AccDataType for setting up the relative threshold!"); + double acc_error = 0; + if constexpr(is_same_v || is_same_v || + is_same_v) + { + return 0; + } + else + { + acc_error = std::pow(2, -NumericUtils::mant) * 0.5 * numberOfAccumulations; + } + return std::max(acc_error, midway_error); +} + +template +double get_absolute_threshold(const double max_possible_num, const int numberOfAccumulations = 1) +{ + using F8 = ck::f8_t; + using F16 = ck::half_t; + using BF16 = ck::bhalf_t; + using F32 = float; + using I8 = int8_t; + using I32 = int32_t; + + static_assert(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v, + "Warning: Unhandled ComputeDataType for setting up the absolute threshold!"); + auto expo = std::log2(std::abs(max_possible_num)); + double compute_error = 0; + if constexpr(is_same_v || is_same_v || + is_same_v) + { + return 0; + } + else + { + compute_error = std::pow(2, expo - NumericUtils::mant) * 0.5; + } + + static_assert(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v, + "Warning: Unhandled OutDataType for setting up the absolute threshold!"); + double output_error = 0; + if constexpr(is_same_v || is_same_v || + is_same_v) + { + return 0; + } + else + { + output_error = std::pow(2, expo - NumericUtils::mant) * 0.5; + } + double midway_error = std::max(compute_error, output_error); + + static_assert(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v, + "Warning: Unhandled AccDataType for setting up the absolute threshold!"); + double acc_error = 0; + if constexpr(is_same_v || is_same_v || + is_same_v) + { + return 0; + } + else + { + acc_error = + std::pow(2, expo - NumericUtils::mant) * 0.5 * numberOfAccumulations; + } + return std::max(acc_error, midway_error); +} + template typename std::enable_if< std::is_same_v, ranges::range_value_t> && @@ -253,11 +377,13 @@ check_err(const Range& out, int err_count = 0; double err = 0; double max_err = std::numeric_limits::min(); + for(std::size_t i = 0; i < ref.size(); ++i) { const double o = type_convert(*std::next(std::begin(out), i)); const double r = type_convert(*std::next(std::begin(ref), i)); err = std::abs(o - r); + if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) { max_err = err > max_err ? err : max_err; @@ -270,6 +396,7 @@ check_err(const Range& out, res = false; } } + if(!res) { std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err diff --git a/profiler/include/profiler/profile_pool3d_fwd_impl.hpp b/profiler/include/profiler/profile_pool3d_fwd_impl.hpp index 3bdaa5c838..a0890028ac 100644 --- a/profiler/include/profiler/profile_pool3d_fwd_impl.hpp +++ b/profiler/include/profiler/profile_pool3d_fwd_impl.hpp @@ -102,11 +102,22 @@ bool profile_pool3d_fwd_impl(PoolFwdInputParams& in_params, PoolFwdKernelParams& Tensor out_indices_n_c_do_ho_wo_device( f_host_tensor_descriptor(N, C, Do, Ho, Wo)); + constexpr int inDataRangeTensor1{1}; + constexpr int inDataRangeTensor2{5}; + constexpr double inDataRangeTensor3{0.5}; + switch(in_params.init_method) { - case 0: in_n_c_di_hi_wi.GenerateTensorValue(GeneratorTensor_1{}); break; - case 1: in_n_c_di_hi_wi.GenerateTensorValue(GeneratorTensor_2{-5, 5}); break; - default: in_n_c_di_hi_wi.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + case 0: + in_n_c_di_hi_wi.GenerateTensorValue(GeneratorTensor_1{inDataRangeTensor1}); + break; + case 1: + in_n_c_di_hi_wi.GenerateTensorValue( + GeneratorTensor_2{-inDataRangeTensor2, inDataRangeTensor2}); + break; + default: + in_n_c_di_hi_wi.GenerateTensorValue( + GeneratorTensor_3{-inDataRangeTensor3, inDataRangeTensor3}); } DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_di_hi_wi.mDesc.GetElementSpaceSize()); @@ -229,12 +240,25 @@ bool profile_pool3d_fwd_impl(PoolFwdInputParams& in_params, PoolFwdKernelParams& { out_device_buf.FromDevice(out_n_c_do_ho_wo_device.mData.data()); - auto tolerance = 1e-3; - bool pass = ck::utils::check_err(out_n_c_do_ho_wo_device.mData, + auto absolute_error_threshold = 1.0; + switch(in_params.init_method) + { + case 0: absolute_error_threshold = static_cast(inDataRangeTensor1); break; + case 1: absolute_error_threshold = static_cast(inDataRangeTensor2); break; + default: absolute_error_threshold = inDataRangeTensor3; + } + + absolute_error_threshold = + ck::utils::get_absolute_threshold( + absolute_error_threshold); + auto relative_error_threshold = + ck::utils::get_relative_threshold(); + + bool pass = ck::utils::check_err(out_n_c_do_ho_wo_device.mData, out_n_c_do_ho_wo_host.mData, "Error: Incorrect results", - tolerance, - tolerance); + relative_error_threshold, + absolute_error_threshold); if constexpr(OutputIndex) {