Generic threshold calculation (#1546)

* Calculate generic relative threshold pool3dfwd

* Calculate absolute error threshold pool3d fwd

* Generic threshold calculation take max input for relative error pool3dfwd

* Remove max possible value for error calculation at runtime

* Remove debug print in pool3dfwd

* Pool3d fwd adjusted types in generic threshold calculation

* Generic threshold calculation take into account number of accumulations and accdatatype

* Generic threshold fix final error formula

* Generic threshold calculation - num of accs fix

* Generic threshold calculation - adjust absolute error

* Generic threshold calculation - OutDataType in absolute error

[ROCm/composable_kernel commit: 9385caa306]
This commit is contained in:
aledudek
2024-10-25 12:46:24 +02:00
committed by GitHub
parent 6cd6bf04fb
commit c534ed750d
3 changed files with 167 additions and 7 deletions

View File

@@ -102,11 +102,22 @@ bool profile_pool3d_fwd_impl(PoolFwdInputParams& in_params, PoolFwdKernelParams&
Tensor<IndexDataType> out_indices_n_c_do_ho_wo_device(
f_host_tensor_descriptor(N, C, Do, Ho, Wo));
constexpr int inDataRangeTensor1{1};
constexpr int inDataRangeTensor2{5};
constexpr double inDataRangeTensor3{0.5};
switch(in_params.init_method)
{
case 0: in_n_c_di_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{}); break;
case 1: in_n_c_di_hi_wi.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5}); break;
default: in_n_c_di_hi_wi.GenerateTensorValue(GeneratorTensor_3<InDataType>{-0.5, 0.5});
case 0:
in_n_c_di_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{inDataRangeTensor1});
break;
case 1:
in_n_c_di_hi_wi.GenerateTensorValue(
GeneratorTensor_2<InDataType>{-inDataRangeTensor2, inDataRangeTensor2});
break;
default:
in_n_c_di_hi_wi.GenerateTensorValue(
GeneratorTensor_3<InDataType>{-inDataRangeTensor3, inDataRangeTensor3});
}
DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_di_hi_wi.mDesc.GetElementSpaceSize());
@@ -229,12 +240,25 @@ bool profile_pool3d_fwd_impl(PoolFwdInputParams& in_params, PoolFwdKernelParams&
{
out_device_buf.FromDevice(out_n_c_do_ho_wo_device.mData.data());
auto tolerance = 1e-3;
bool pass = ck::utils::check_err(out_n_c_do_ho_wo_device.mData,
auto absolute_error_threshold = 1.0;
switch(in_params.init_method)
{
case 0: absolute_error_threshold = static_cast<double>(inDataRangeTensor1); break;
case 1: absolute_error_threshold = static_cast<double>(inDataRangeTensor2); break;
default: absolute_error_threshold = inDataRangeTensor3;
}
absolute_error_threshold =
ck::utils::get_absolute_threshold<ComputeDataType, OutDataType>(
absolute_error_threshold);
auto relative_error_threshold =
ck::utils::get_relative_threshold<ComputeDataType, OutDataType>();
bool pass = ck::utils::check_err(out_n_c_do_ho_wo_device.mData,
out_n_c_do_ho_wo_host.mData,
"Error: Incorrect results",
tolerance,
tolerance);
relative_error_threshold,
absolute_error_threshold);
if constexpr(OutputIndex)
{