Conv:TF32: add more instances - 2 (#2879)

* add instances of device_grouped_conv_fwd_xdl_f32_comp_instances
* add instances of device_grouped_conv_fwd_xdl_f32_tf32_mem_instances
* add instances of device_grouped_conv_fwd_xdl_large_tensor_f32_tf32_instances
* tf32:conv:add instances for base class DeviceConvFwd
* tf32:conv:add instances for base class DeviceGroupedConvBwdDataMultipleD
* tf32:conv:add instances for base class DeviceGroupedConvBwdWeight
* add tf32 in profiler
* remove gnhwc/ngchw/ngcdhw instances
* remove non-ndhwgc/nhwgc/nhwc instances
* add check in IsSupportedArgument()
This commit is contained in:
yinglu
2025-10-10 15:28:17 +08:00
committed by GitHub
parent ad7a215aba
commit fada1a3cae
56 changed files with 2119 additions and 152 deletions

View File

@@ -31,13 +31,15 @@ double get_relative_threshold(const int number_of_accumulations = 1)
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F32 = float;
using TF32 = ck::tf32_t;
using I8 = int8_t;
using I32 = int32_t;
static_assert(is_same_v<ComputeDataType, F4> || is_same_v<ComputeDataType, F8> ||
is_same_v<ComputeDataType, F16> || is_same_v<ComputeDataType, BF16> ||
is_same_v<ComputeDataType, F32> || is_same_v<ComputeDataType, I8> ||
is_same_v<ComputeDataType, I32> || is_same_v<ComputeDataType, int>,
is_same_v<ComputeDataType, F32> || is_same_v<ComputeDataType, TF32> ||
is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
is_same_v<ComputeDataType, int>,
"Warning: Unhandled ComputeDataType for setting up the relative threshold!");
double compute_error = 0;
if constexpr(is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
@@ -52,8 +54,9 @@ double get_relative_threshold(const int number_of_accumulations = 1)
static_assert(is_same_v<OutDataType, F4> || is_same_v<OutDataType, F8> ||
is_same_v<OutDataType, F16> || is_same_v<OutDataType, BF16> ||
is_same_v<OutDataType, F32> || is_same_v<OutDataType, I8> ||
is_same_v<OutDataType, I32> || is_same_v<OutDataType, int>,
is_same_v<OutDataType, F32> || is_same_v<ComputeDataType, TF32> ||
is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
is_same_v<OutDataType, int>,
"Warning: Unhandled OutDataType for setting up the relative threshold!");
double output_error = 0;
if constexpr(is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
@@ -69,8 +72,9 @@ double get_relative_threshold(const int number_of_accumulations = 1)
static_assert(is_same_v<AccDataType, F4> || is_same_v<AccDataType, F8> ||
is_same_v<AccDataType, F16> || is_same_v<AccDataType, BF16> ||
is_same_v<AccDataType, F32> || is_same_v<AccDataType, I8> ||
is_same_v<AccDataType, I32> || is_same_v<AccDataType, int>,
is_same_v<AccDataType, F32> || is_same_v<ComputeDataType, TF32> ||
is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
is_same_v<AccDataType, int>,
"Warning: Unhandled AccDataType for setting up the relative threshold!");
double acc_error = 0;
if constexpr(is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
@@ -93,13 +97,15 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F32 = float;
using TF32 = ck::tf32_t;
using I8 = int8_t;
using I32 = int32_t;
static_assert(is_same_v<ComputeDataType, F4> || is_same_v<ComputeDataType, F8> ||
is_same_v<ComputeDataType, F16> || is_same_v<ComputeDataType, BF16> ||
is_same_v<ComputeDataType, F32> || is_same_v<ComputeDataType, I8> ||
is_same_v<ComputeDataType, I32> || is_same_v<ComputeDataType, int>,
is_same_v<ComputeDataType, F32> || is_same_v<ComputeDataType, TF32> ||
is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
is_same_v<ComputeDataType, int>,
"Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
auto expo = std::log2(std::abs(max_possible_num));
double compute_error = 0;
@@ -115,8 +121,9 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
static_assert(is_same_v<OutDataType, F4> || is_same_v<OutDataType, F8> ||
is_same_v<OutDataType, F16> || is_same_v<OutDataType, BF16> ||
is_same_v<OutDataType, F32> || is_same_v<OutDataType, I8> ||
is_same_v<OutDataType, I32> || is_same_v<OutDataType, int>,
is_same_v<OutDataType, F32> || is_same_v<ComputeDataType, TF32> ||
is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
is_same_v<OutDataType, int>,
"Warning: Unhandled OutDataType for setting up the absolute threshold!");
double output_error = 0;
if constexpr(is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
@@ -132,8 +139,9 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
static_assert(is_same_v<AccDataType, F4> || is_same_v<AccDataType, F8> ||
is_same_v<AccDataType, F16> || is_same_v<AccDataType, BF16> ||
is_same_v<AccDataType, F32> || is_same_v<AccDataType, I8> ||
is_same_v<AccDataType, I32> || is_same_v<AccDataType, int>,
is_same_v<AccDataType, F32> || is_same_v<ComputeDataType, TF32> ||
is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
is_same_v<AccDataType, int>,
"Warning: Unhandled AccDataType for setting up the absolute threshold!");
double acc_error = 0;
if constexpr(is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
@@ -149,11 +157,67 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
return std::max(acc_error, midway_error);
}
template <typename Range, typename RefRange>
template <typename Range,
typename RefRange,
typename ComputeDataType = ranges::range_value_t<Range>>
typename std::enable_if<
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_same_v<ranges::range_value_t<Range>, float> &&
std::is_same_v<ComputeDataType, ck::tf32_t>,
bool>::type
check_err(const Range& out,
const RefRange& ref,
const std::string& msg = "Error: Incorrect results!",
double rtol = 1e-5,
double atol = 3e-5)
{
if(out.size() != ref.size())
{
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
<< std::endl;
return false;
}
bool res{true};
int err_count = 0;
double err = 0;
double max_err = std::numeric_limits<double>::min();
for(std::size_t i = 0; i < ref.size(); ++i)
{
const double o = *std::next(std::begin(out), i);
const double r = *std::next(std::begin(ref), i);
err = std::abs(o - r);
if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
{
max_err = err > max_err ? err : max_err;
if(err_count < 5)
{
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
}
res = false;
err_count++;
}
}
if(!res)
{
const float error_percent =
static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
std::cerr << "max err: " << max_err;
std::cerr << ", number of errors: " << err_count;
std::cerr << ", " << error_percent << "% wrong values" << std::endl;
}
return res;
}
template <typename Range,
typename RefRange,
typename ComputeDataType = ranges::range_value_t<Range>>
typename std::enable_if<
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_floating_point_v<ranges::range_value_t<Range>> &&
!std::is_same_v<ranges::range_value_t<Range>, half_t>,
!std::is_same_v<ranges::range_value_t<Range>, half_t> &&
!std::is_same_v<ComputeDataType, ck::tf32_t>,
bool>::type
check_err(const Range& out,
const RefRange& ref,
@@ -200,7 +264,9 @@ check_err(const Range& out,
return res;
}
template <typename Range, typename RefRange>
template <typename Range,
typename RefRange,
typename ComputeDataType = ranges::range_value_t<Range>>
typename std::enable_if<
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_same_v<ranges::range_value_t<Range>, bhalf_t>,
@@ -251,7 +317,9 @@ check_err(const Range& out,
return res;
}
template <typename Range, typename RefRange>
template <typename Range,
typename RefRange,
typename ComputeDataType = ranges::range_value_t<Range>>
typename std::enable_if<
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_same_v<ranges::range_value_t<Range>, half_t>,
@@ -301,7 +369,9 @@ check_err(const Range& out,
return res;
}
template <typename Range, typename RefRange>
template <typename Range,
typename RefRange,
typename ComputeDataType = ranges::range_value_t<Range>>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_integral_v<ranges::range_value_t<Range>> &&
!std::is_same_v<ranges::range_value_t<Range>, bhalf_t> &&
@@ -358,7 +428,9 @@ check_err(const Range& out,
return res;
}
template <typename Range, typename RefRange>
template <typename Range,
typename RefRange,
typename ComputeDataType = ranges::range_value_t<Range>>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_same_v<ranges::range_value_t<Range>, f8_t>),
bool>
@@ -407,7 +479,9 @@ check_err(const Range& out,
return res;
}
template <typename Range, typename RefRange>
template <typename Range,
typename RefRange,
typename ComputeDataType = ranges::range_value_t<Range>>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_same_v<ranges::range_value_t<Range>, bf8_t>),
bool>
@@ -452,7 +526,9 @@ check_err(const Range& out,
return res;
}
template <typename Range, typename RefRange>
template <typename Range,
typename RefRange,
typename ComputeDataType = ranges::range_value_t<Range>>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_same_v<ranges::range_value_t<Range>, f4_t>),
bool>