mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Generic threshold calculation after merge fixes (#1618)
* Generic threshold calculation: add passing the number of accumulations
* Generic threshold - after merge fixes
* Fix CMakeLists
---------
Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
[ROCm/composable_kernel commit: dcafb1de15]
This commit is contained in:
@@ -24,7 +24,7 @@ namespace ck {
|
||||
namespace utils {
|
||||
|
||||
template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
|
||||
double get_relative_threshold(const int numberOfAccumulations = 1)
|
||||
double get_relative_threshold(const int number_of_accumulations = 1)
|
||||
{
|
||||
using F8 = ck::f8_t;
|
||||
using F16 = ck::half_t;
|
||||
@@ -79,13 +79,13 @@ double get_relative_threshold(const int numberOfAccumulations = 1)
|
||||
}
|
||||
else
|
||||
{
|
||||
acc_error = std::pow(2, -NumericUtils<AccDataType>::mant) * 0.5 * numberOfAccumulations;
|
||||
acc_error = std::pow(2, -NumericUtils<AccDataType>::mant) * 0.5 * number_of_accumulations;
|
||||
}
|
||||
return std::max(acc_error, midway_error);
|
||||
}
|
||||
|
||||
template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
|
||||
double get_absolute_threshold(const double max_possible_num, const int numberOfAccumulations = 1)
|
||||
double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations = 1)
|
||||
{
|
||||
using F8 = ck::f8_t;
|
||||
using F16 = ck::half_t;
|
||||
@@ -142,7 +142,7 @@ double get_absolute_threshold(const double max_possible_num, const int numberOfA
|
||||
else
|
||||
{
|
||||
acc_error =
|
||||
std::pow(2, expo - NumericUtils<AccDataType>::mant) * 0.5 * numberOfAccumulations;
|
||||
std::pow(2, expo - NumericUtils<AccDataType>::mant) * 0.5 * number_of_accumulations;
|
||||
}
|
||||
return std::max(acc_error, midway_error);
|
||||
}
|
||||
|
||||
@@ -240,6 +240,19 @@ bool profile_pool3d_fwd_impl(PoolFwdInputParams& in_params, PoolFwdKernelParams&
|
||||
{
|
||||
out_device_buf.FromDevice(out_n_c_do_ho_wo_device.mData.data());
|
||||
|
||||
auto number_of_accumulations = 1;
|
||||
static_assert(
|
||||
ReduceOpId == ck::ReduceTensorOp::AVG || ReduceOpId == ck::ReduceTensorOp::MAX,
|
||||
"Warning: Unhandled ReduceOpId for setting up the number of accumulations!");
|
||||
|
||||
if constexpr(ReduceOpId == ck::ReduceTensorOp::AVG)
|
||||
{
|
||||
for(size_t i = 0; i < kernel_params.window_spatial_lengths.size(); ++i)
|
||||
{
|
||||
number_of_accumulations *= kernel_params.window_spatial_lengths.at(i);
|
||||
}
|
||||
}
|
||||
|
||||
auto absolute_error_threshold = 1.0;
|
||||
switch(in_params.init_method)
|
||||
{
|
||||
@@ -250,9 +263,10 @@ bool profile_pool3d_fwd_impl(PoolFwdInputParams& in_params, PoolFwdKernelParams&
|
||||
|
||||
absolute_error_threshold =
|
||||
ck::utils::get_absolute_threshold<ComputeDataType, OutDataType>(
|
||||
absolute_error_threshold);
|
||||
absolute_error_threshold, number_of_accumulations);
|
||||
auto relative_error_threshold =
|
||||
ck::utils::get_relative_threshold<ComputeDataType, OutDataType>();
|
||||
ck::utils::get_relative_threshold<ComputeDataType, OutDataType>(
|
||||
number_of_accumulations);
|
||||
|
||||
bool pass = ck::utils::check_err(out_n_c_do_ho_wo_device.mData,
|
||||
out_n_c_do_ho_wo_host.mData,
|
||||
|
||||
Reference in New Issue
Block a user