Conv:TF32: add more instances - 2 (#2879)

* add instances of device_grouped_conv_fwd_xdl_f32_comp_instances
* add instances of device_grouped_conv_fwd_xdl_f32_tf32_mem_instances
* add instances of device_grouped_conv_fwd_xdl_large_tensor_f32_tf32_instances
* tf32:conv:add instances for base class DeviceConvFwd
* tf32:conv:add instances for base class DeviceGroupedConvBwdDataMultipleD
* tf32:conv:add instances for base class DeviceGroupedConvBwdWeight
* add tf32 in profiler
* remove gnhwc/ngchw/ngcdhw instances
* remove non-ndhwgc/nhwgc/nhwc instances
* add check in IsSupportedArgument()
This commit is contained in:
yinglu
2025-10-10 15:28:17 +08:00
committed by GitHub
parent ad7a215aba
commit fada1a3cae
56 changed files with 2119 additions and 152 deletions

View File

@@ -226,6 +226,12 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
{
return profile(I1, GNWC{}, GKXC{}, GNWK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{});
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}
else if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
{
@@ -245,6 +251,12 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
{
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{});
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}
else if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
{
@@ -292,6 +304,12 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
{
return profile(I1, NWGC{}, GKXC{}, NWGK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{});
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I1, NWGC{}, GKXC{}, NWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}
else if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
{
@@ -311,6 +329,12 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
{
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{});
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}
else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKYXC_NGKHW)
{
@@ -326,6 +350,12 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
{
return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}
else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKCYX_NGKHW)
{
@@ -341,6 +371,12 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
{
return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
{