From 91228f5e5030415e68700c35eb5401ffdd5519c5 Mon Sep 17 00:00:00 2001 From: aledudek Date: Wed, 6 Nov 2024 10:44:58 +0100 Subject: [PATCH] Generic threshold calculation after merge fixes (#1618) * Generic threshold calculation add passing num of accums * Generic threshold - after merge fixes * Fix cmakelists --------- Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> [ROCm/composable_kernel commit: dcafb1de15a8fd1de3496f19fd806ac9cb185012] --- .../include/ck/library/utility/check_err.hpp | 8 ++++---- .../profiler/profile_pool3d_fwd_impl.hpp | 18 ++++++++++++++++-- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/library/include/ck/library/utility/check_err.hpp b/library/include/ck/library/utility/check_err.hpp index 73ac2a189f..88741c3b96 100644 --- a/library/include/ck/library/utility/check_err.hpp +++ b/library/include/ck/library/utility/check_err.hpp @@ -24,7 +24,7 @@ namespace ck { namespace utils { template -double get_relative_threshold(const int numberOfAccumulations = 1) +double get_relative_threshold(const int number_of_accumulations = 1) { using F8 = ck::f8_t; using F16 = ck::half_t; @@ -79,13 +79,13 @@ double get_relative_threshold(const int numberOfAccumulations = 1) } else { - acc_error = std::pow(2, -NumericUtils::mant) * 0.5 * numberOfAccumulations; + acc_error = std::pow(2, -NumericUtils::mant) * 0.5 * number_of_accumulations; } return std::max(acc_error, midway_error); } template -double get_absolute_threshold(const double max_possible_num, const int numberOfAccumulations = 1) +double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations = 1) { using F8 = ck::f8_t; using F16 = ck::half_t; @@ -142,7 +142,7 @@ double get_absolute_threshold(const double max_possible_num, const int numberOfA else { acc_error = - std::pow(2, expo - NumericUtils::mant) * 0.5 * numberOfAccumulations; + std::pow(2, expo - NumericUtils::mant) * 0.5 * number_of_accumulations; } return std::max(acc_error, midway_error); } diff --git a/profiler/include/profiler/profile_pool3d_fwd_impl.hpp b/profiler/include/profiler/profile_pool3d_fwd_impl.hpp index a0890028ac..cbdacad53b 100644 --- a/profiler/include/profiler/profile_pool3d_fwd_impl.hpp +++ b/profiler/include/profiler/profile_pool3d_fwd_impl.hpp @@ -240,6 +240,19 @@ bool profile_pool3d_fwd_impl(PoolFwdInputParams& in_params, PoolFwdKernelParams& { out_device_buf.FromDevice(out_n_c_do_ho_wo_device.mData.data()); + auto number_of_accumulations = 1; + static_assert( + ReduceOpId == ck::ReduceTensorOp::AVG || ReduceOpId == ck::ReduceTensorOp::MAX, + "Warning: Unhandled ReduceOpId for setting up the number of accumulations!"); + + if constexpr(ReduceOpId == ck::ReduceTensorOp::AVG) + { + for(size_t i = 0; i < kernel_params.window_spatial_lengths.size(); ++i) + { + number_of_accumulations *= kernel_params.window_spatial_lengths.at(i); + } + } + auto absolute_error_threshold = 1.0; switch(in_params.init_method) { @@ -250,9 +263,10 @@ bool profile_pool3d_fwd_impl(PoolFwdInputParams& in_params, PoolFwdKernelParams& absolute_error_threshold = ck::utils::get_absolute_threshold( - absolute_error_threshold); + absolute_error_threshold, number_of_accumulations); auto relative_error_threshold = - ck::utils::get_relative_threshold(); + ck::utils::get_relative_threshold( + number_of_accumulations); bool pass = ck::utils::check_err(out_n_c_do_ho_wo_device.mData, out_n_c_do_ho_wo_host.mData,