Conv:TF32: add more instances - 2 (#2879)

* add instances of device_grouped_conv_fwd_xdl_f32_comp_instances
* add instances of device_grouped_conv_fwd_xdl_f32_tf32_mem_instances
* add instances of device_grouped_conv_fwd_xdl_large_tensor_f32_tf32_instances
* tf32:conv:add instances for base class DeviceConvFwd
* tf32:conv:add instances for base class DeviceGroupedConvBwdDataMultipleD
* tf32:conv:add instances for base class DeviceGroupedConvBwdWeight
* add tf32 in profiler
* remove gnhwc/ngchw/ngcdhw instances
* remove non-ndhwgc/nhwgc/nhwc instances
* add check in IsSupportedArgument()
This commit is contained in:
yinglu
2025-10-10 15:28:17 +08:00
committed by GitHub
parent ad7a215aba
commit fada1a3cae
56 changed files with 2119 additions and 152 deletions

View File

@@ -45,6 +45,24 @@ struct NumericUtils<float>
using bitwise_type = uint32_t;
};
template <>
struct NumericUtils<ck::tf32_t>
{
static constexpr int exp = 8;
static constexpr int mant = 10;
static constexpr int bias = 127;
static constexpr uint32_t nan_mask = 0x7F800000;
static constexpr uint32_t head_mask = 0xFF800000;
static constexpr uint32_t mant_mask = 0x7FFFFF;
static constexpr uint32_t exp_mask = 0xFF;
static constexpr uint32_t Inf = 0x7F800000;
static constexpr uint32_t NegInf = 0xFF800000;
static constexpr uint32_t NaN = 0x7F800001;
static constexpr uint32_t Neg0 = 0x80000000;
static constexpr bool has_inf = true;
using bitwise_type = uint32_t;
};
template <>
struct NumericUtils<half_t>
{