mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 11:30:02 +00:00
Generic threshold calculation (#1546)
* Calculate generic relative threshold pool3dfwd
* Calculate absolute error threshold pool3d fwd
* Generic threshold calculation take max input for relative error pool3dfwd
* Remove max possible value for error calculation at runtime
* Remove debug print in pool3dfwd
* Pool3d fwd adjusted types in generic threshold calculation
* Generic threshold calculation take into account number of accumulations and accdatatype
* Generic threshold fix final error formula
* Generic threshold calculation - num of accs fix
* Generic threshold calculation - adjust absolute error
* Generic threshold calculation - OutDataType in absolute error
[ROCm/composable_kernel commit: 9385caa306]
This commit is contained in:
@@ -102,11 +102,22 @@ bool profile_pool3d_fwd_impl(PoolFwdInputParams& in_params, PoolFwdKernelParams&
|
||||
Tensor<IndexDataType> out_indices_n_c_do_ho_wo_device(
|
||||
f_host_tensor_descriptor(N, C, Do, Ho, Wo));
|
||||
|
||||
constexpr int inDataRangeTensor1{1};
|
||||
constexpr int inDataRangeTensor2{5};
|
||||
constexpr double inDataRangeTensor3{0.5};
|
||||
|
||||
switch(in_params.init_method)
|
||||
{
|
||||
case 0: in_n_c_di_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{}); break;
|
||||
case 1: in_n_c_di_hi_wi.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5}); break;
|
||||
default: in_n_c_di_hi_wi.GenerateTensorValue(GeneratorTensor_3<InDataType>{-0.5, 0.5});
|
||||
case 0:
|
||||
in_n_c_di_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{inDataRangeTensor1});
|
||||
break;
|
||||
case 1:
|
||||
in_n_c_di_hi_wi.GenerateTensorValue(
|
||||
GeneratorTensor_2<InDataType>{-inDataRangeTensor2, inDataRangeTensor2});
|
||||
break;
|
||||
default:
|
||||
in_n_c_di_hi_wi.GenerateTensorValue(
|
||||
GeneratorTensor_3<InDataType>{-inDataRangeTensor3, inDataRangeTensor3});
|
||||
}
|
||||
|
||||
DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_di_hi_wi.mDesc.GetElementSpaceSize());
|
||||
@@ -229,12 +240,25 @@ bool profile_pool3d_fwd_impl(PoolFwdInputParams& in_params, PoolFwdKernelParams&
|
||||
{
|
||||
out_device_buf.FromDevice(out_n_c_do_ho_wo_device.mData.data());
|
||||
|
||||
auto tolerance = 1e-3;
|
||||
bool pass = ck::utils::check_err(out_n_c_do_ho_wo_device.mData,
|
||||
auto absolute_error_threshold = 1.0;
|
||||
switch(in_params.init_method)
|
||||
{
|
||||
case 0: absolute_error_threshold = static_cast<double>(inDataRangeTensor1); break;
|
||||
case 1: absolute_error_threshold = static_cast<double>(inDataRangeTensor2); break;
|
||||
default: absolute_error_threshold = inDataRangeTensor3;
|
||||
}
|
||||
|
||||
absolute_error_threshold =
|
||||
ck::utils::get_absolute_threshold<ComputeDataType, OutDataType>(
|
||||
absolute_error_threshold);
|
||||
auto relative_error_threshold =
|
||||
ck::utils::get_relative_threshold<ComputeDataType, OutDataType>();
|
||||
|
||||
bool pass = ck::utils::check_err(out_n_c_do_ho_wo_device.mData,
|
||||
out_n_c_do_ho_wo_host.mData,
|
||||
"Error: Incorrect results",
|
||||
tolerance,
|
||||
tolerance);
|
||||
relative_error_threshold,
|
||||
absolute_error_threshold);
|
||||
|
||||
if constexpr(OutputIndex)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user