mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
Generic threshold calculation (#1546)
* Calculate generic relative threshold pool3dfwd
* Calculate absolute error threshold pool3d fwd
* Generic threshold calculation: take the max input for the relative error in pool3dfwd
* Remove max possible value for error calculation at runtime
* Remove debug print in pool3dfwd
* Pool3d fwd adjusted types in generic threshold calculation
* Generic threshold calculation: take into account the number of accumulations and AccDataType
* Generic threshold fix final error formula
* Generic threshold calculation - num of accs fix
* Generic threshold calculation - adjust absolute error
* Generic threshold calculation - OutDataType in absolute error
[ROCm/composable_kernel commit: 9385caa306]
This commit is contained in:
@@ -1803,4 +1803,13 @@ struct NumericUtils<bf8_t>
|
||||
static constexpr int bias = 16; // negative zero nan mode
|
||||
// static constexpr int bias = 15; // ieee mode
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericUtils<bhalf_t>
|
||||
{
|
||||
static constexpr int exp = 8;
|
||||
static constexpr int mant = 7;
|
||||
static constexpr int bias = 128; // negative zero nan mode
|
||||
// static constexpr int bias = 127; // ieee mode
|
||||
};
|
||||
} // namespace ck
|
||||
|
||||
@@ -23,6 +23,130 @@
|
||||
namespace ck {
|
||||
namespace utils {
|
||||
|
||||
template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
|
||||
double get_relative_threshold(const int numberOfAccumulations = 1)
|
||||
{
|
||||
using F8 = ck::f8_t;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
using I8 = int8_t;
|
||||
using I32 = int32_t;
|
||||
|
||||
static_assert(is_same_v<ComputeDataType, F8> || is_same_v<ComputeDataType, F16> ||
|
||||
is_same_v<ComputeDataType, BF16> || is_same_v<ComputeDataType, F32> ||
|
||||
is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
|
||||
is_same_v<ComputeDataType, int>,
|
||||
"Warning: Unhandled ComputeDataType for setting up the relative threshold!");
|
||||
double compute_error = 0;
|
||||
if constexpr(is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
|
||||
is_same_v<ComputeDataType, int>)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
compute_error = std::pow(2, -NumericUtils<ComputeDataType>::mant) * 0.5;
|
||||
}
|
||||
|
||||
static_assert(is_same_v<OutDataType, F8> || is_same_v<OutDataType, F16> ||
|
||||
is_same_v<OutDataType, BF16> || is_same_v<OutDataType, F32> ||
|
||||
is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
|
||||
is_same_v<OutDataType, int>,
|
||||
"Warning: Unhandled OutDataType for setting up the relative threshold!");
|
||||
double output_error = 0;
|
||||
if constexpr(is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
|
||||
is_same_v<OutDataType, int>)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
output_error = std::pow(2, -NumericUtils<OutDataType>::mant) * 0.5;
|
||||
}
|
||||
double midway_error = std::max(compute_error, output_error);
|
||||
|
||||
static_assert(is_same_v<AccDataType, F8> || is_same_v<AccDataType, F16> ||
|
||||
is_same_v<AccDataType, BF16> || is_same_v<AccDataType, F32> ||
|
||||
is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
|
||||
is_same_v<AccDataType, int>,
|
||||
"Warning: Unhandled AccDataType for setting up the relative threshold!");
|
||||
double acc_error = 0;
|
||||
if constexpr(is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
|
||||
is_same_v<AccDataType, int>)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
acc_error = std::pow(2, -NumericUtils<AccDataType>::mant) * 0.5 * numberOfAccumulations;
|
||||
}
|
||||
return std::max(acc_error, midway_error);
|
||||
}
|
||||
|
||||
template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
|
||||
double get_absolute_threshold(const double max_possible_num, const int numberOfAccumulations = 1)
|
||||
{
|
||||
using F8 = ck::f8_t;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
using I8 = int8_t;
|
||||
using I32 = int32_t;
|
||||
|
||||
static_assert(is_same_v<ComputeDataType, F8> || is_same_v<ComputeDataType, F16> ||
|
||||
is_same_v<ComputeDataType, BF16> || is_same_v<ComputeDataType, F32> ||
|
||||
is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
|
||||
is_same_v<ComputeDataType, int>,
|
||||
"Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
|
||||
auto expo = std::log2(std::abs(max_possible_num));
|
||||
double compute_error = 0;
|
||||
if constexpr(is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
|
||||
is_same_v<ComputeDataType, int>)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
compute_error = std::pow(2, expo - NumericUtils<ComputeDataType>::mant) * 0.5;
|
||||
}
|
||||
|
||||
static_assert(is_same_v<OutDataType, F8> || is_same_v<OutDataType, F16> ||
|
||||
is_same_v<OutDataType, BF16> || is_same_v<OutDataType, F32> ||
|
||||
is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
|
||||
is_same_v<OutDataType, int>,
|
||||
"Warning: Unhandled OutDataType for setting up the absolute threshold!");
|
||||
double output_error = 0;
|
||||
if constexpr(is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
|
||||
is_same_v<OutDataType, int>)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
output_error = std::pow(2, expo - NumericUtils<OutDataType>::mant) * 0.5;
|
||||
}
|
||||
double midway_error = std::max(compute_error, output_error);
|
||||
|
||||
static_assert(is_same_v<AccDataType, F8> || is_same_v<AccDataType, F16> ||
|
||||
is_same_v<AccDataType, BF16> || is_same_v<AccDataType, F32> ||
|
||||
is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
|
||||
is_same_v<AccDataType, int>,
|
||||
"Warning: Unhandled AccDataType for setting up the absolute threshold!");
|
||||
double acc_error = 0;
|
||||
if constexpr(is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
|
||||
is_same_v<AccDataType, int>)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
acc_error =
|
||||
std::pow(2, expo - NumericUtils<AccDataType>::mant) * 0.5 * numberOfAccumulations;
|
||||
}
|
||||
return std::max(acc_error, midway_error);
|
||||
}
|
||||
|
||||
template <typename Range, typename RefRange>
|
||||
typename std::enable_if<
|
||||
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
|
||||
@@ -253,11 +377,13 @@ check_err(const Range& out,
|
||||
int err_count = 0;
|
||||
double err = 0;
|
||||
double max_err = std::numeric_limits<float>::min();
|
||||
|
||||
for(std::size_t i = 0; i < ref.size(); ++i)
|
||||
{
|
||||
const double o = type_convert<float>(*std::next(std::begin(out), i));
|
||||
const double r = type_convert<float>(*std::next(std::begin(ref), i));
|
||||
err = std::abs(o - r);
|
||||
|
||||
if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
|
||||
{
|
||||
max_err = err > max_err ? err : max_err;
|
||||
@@ -270,6 +396,7 @@ check_err(const Range& out,
|
||||
res = false;
|
||||
}
|
||||
}
|
||||
|
||||
if(!res)
|
||||
{
|
||||
std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err
|
||||
|
||||
@@ -102,11 +102,22 @@ bool profile_pool3d_fwd_impl(PoolFwdInputParams& in_params, PoolFwdKernelParams&
|
||||
Tensor<IndexDataType> out_indices_n_c_do_ho_wo_device(
|
||||
f_host_tensor_descriptor(N, C, Do, Ho, Wo));
|
||||
|
||||
constexpr int inDataRangeTensor1{1};
|
||||
constexpr int inDataRangeTensor2{5};
|
||||
constexpr double inDataRangeTensor3{0.5};
|
||||
|
||||
switch(in_params.init_method)
|
||||
{
|
||||
case 0: in_n_c_di_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{}); break;
|
||||
case 1: in_n_c_di_hi_wi.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5}); break;
|
||||
default: in_n_c_di_hi_wi.GenerateTensorValue(GeneratorTensor_3<InDataType>{-0.5, 0.5});
|
||||
case 0:
|
||||
in_n_c_di_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{inDataRangeTensor1});
|
||||
break;
|
||||
case 1:
|
||||
in_n_c_di_hi_wi.GenerateTensorValue(
|
||||
GeneratorTensor_2<InDataType>{-inDataRangeTensor2, inDataRangeTensor2});
|
||||
break;
|
||||
default:
|
||||
in_n_c_di_hi_wi.GenerateTensorValue(
|
||||
GeneratorTensor_3<InDataType>{-inDataRangeTensor3, inDataRangeTensor3});
|
||||
}
|
||||
|
||||
DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_di_hi_wi.mDesc.GetElementSpaceSize());
|
||||
@@ -229,12 +240,25 @@ bool profile_pool3d_fwd_impl(PoolFwdInputParams& in_params, PoolFwdKernelParams&
|
||||
{
|
||||
out_device_buf.FromDevice(out_n_c_do_ho_wo_device.mData.data());
|
||||
|
||||
auto tolerance = 1e-3;
|
||||
bool pass = ck::utils::check_err(out_n_c_do_ho_wo_device.mData,
|
||||
auto absolute_error_threshold = 1.0;
|
||||
switch(in_params.init_method)
|
||||
{
|
||||
case 0: absolute_error_threshold = static_cast<double>(inDataRangeTensor1); break;
|
||||
case 1: absolute_error_threshold = static_cast<double>(inDataRangeTensor2); break;
|
||||
default: absolute_error_threshold = inDataRangeTensor3;
|
||||
}
|
||||
|
||||
absolute_error_threshold =
|
||||
ck::utils::get_absolute_threshold<ComputeDataType, OutDataType>(
|
||||
absolute_error_threshold);
|
||||
auto relative_error_threshold =
|
||||
ck::utils::get_relative_threshold<ComputeDataType, OutDataType>();
|
||||
|
||||
bool pass = ck::utils::check_err(out_n_c_do_ho_wo_device.mData,
|
||||
out_n_c_do_ho_wo_host.mData,
|
||||
"Error: Incorrect results",
|
||||
tolerance,
|
||||
tolerance);
|
||||
relative_error_threshold,
|
||||
absolute_error_threshold);
|
||||
|
||||
if constexpr(OutputIndex)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user