ck: add tf32 in DTYPES to control instances build(#3317)

This commit is contained in:
yinglu
2025-12-08 16:24:20 +08:00
committed by GitHub
parent 86a84ae611
commit 8fec8054b2
24 changed files with 177 additions and 140 deletions

View File

@@ -84,9 +84,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
using F32 = float;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
#if defined(__gfx942__)
using TF32 = ck::tf32_t;
#endif
using namespace ck::tensor_layout::convolution;
@@ -143,9 +141,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, F32{}, F32{}, F32{}, TF32{});
#endif
}
}
else if(layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -164,9 +160,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, F32{}, F32{}, F32{}, TF32{});
#endif
}
}
else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW)
@@ -185,9 +179,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I2, NGKHW{}, GKYXC{}, NGCHW{}, F32{}, F32{}, F32{}, TF32{});
#endif
}
}
else if(layout == ConvLayout::NGCHW_GKCYX_NGKHW)
@@ -206,9 +198,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I2, NGKHW{}, GKCYX{}, NGCHW{}, F32{}, F32{}, F32{}, TF32{});
#endif
}
}
}
@@ -230,9 +220,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I3, GNDHWK{}, GKZYXC{}, GNDHWC{}, F32{}, F32{}, F32{}, TF32{});
#endif
}
}
else if(layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -251,9 +239,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I3, NDHWGK{}, GKZYXC{}, NDHWGC{}, F32{}, F32{}, F32{}, TF32{});
#endif
}
}
else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW)
@@ -272,9 +258,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I3, NGKDHW{}, GKZYXC{}, NGCDHW{}, F32{}, F32{}, F32{}, TF32{});
#endif
}
}
else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW)
@@ -293,9 +277,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I3, NGKDHW{}, GKCZYX{}, NGCDHW{}, F32{}, F32{}, F32{}, TF32{});
#endif
}
}
}

View File

@@ -99,9 +99,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
using BF16 = ck::bhalf_t;
using F8 = ck::f8_t;
using BF8 = ck::bf8_t;
#if defined(__gfx942__)
using TF32 = ck::tf32_t;
#endif
using namespace ck::tensor_layout::convolution;
@@ -162,9 +160,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}
if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
@@ -184,9 +180,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}
if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -210,9 +204,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}
else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKYXC_NGKHW)
@@ -243,9 +235,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}
if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
@@ -270,9 +260,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}
if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -306,9 +294,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}
else if(num_dim_spatial == 3 && layout == ConvLayout::NGCHW_GKYXC_NGKHW)
@@ -340,9 +326,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}

View File

@@ -105,9 +105,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
using INT8 = int8_t;
using F8 = ck::f8_t;
using BF8 = ck::bf8_t;
#if defined(__gfx942__) || defined(__gfx950__)
using TF32 = ck::tf32_t;
#endif
//
using GNWC = ck::tensor_layout::convolution::GNWC;
@@ -228,9 +226,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__) || defined(__gfx950__)
return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}
else if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
@@ -253,9 +249,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__) || defined(__gfx950__)
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}
else if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
@@ -280,9 +274,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__) || defined(__gfx950__)
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}
// NHWGC_GKYXC_NHWGK
@@ -306,9 +298,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__) || defined(__gfx950__)
return profile(I1, NWGC{}, GKXC{}, NWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}
else if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -331,9 +321,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__) || defined(__gfx950__)
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}
else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKYXC_NGKHW)
@@ -352,9 +340,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__) || defined(__gfx950__)
return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}
else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKCYX_NGKHW)
@@ -373,9 +359,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__) || defined(__gfx950__)
return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -416,9 +400,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__) || defined(__gfx950__)
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}
// NGCDHW_GKCZYX_NGKDHW
@@ -439,9 +421,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__) || defined(__gfx950__)
return profile(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}

View File

@@ -105,9 +105,7 @@ int grouped_conv_fwd_bias_clamp(int argc, char* argv[])
using F32 = float;
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
#if defined(__gfx942__)
using TF32 = ck::tf32_t;
#endif
using GKZYXC = ck::tensor_layout::convolution::GKZYXC;
using NDHWGC = ck::tensor_layout::convolution::NDHWGC;
@@ -172,9 +170,7 @@ int grouped_conv_fwd_bias_clamp(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -194,9 +190,7 @@ int grouped_conv_fwd_bias_clamp(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}

View File

@@ -105,9 +105,7 @@ int grouped_conv_fwd_clamp(int argc, char* argv[])
using F32 = float;
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
#if defined(__gfx942__)
using TF32 = ck::tf32_t;
#endif
using GKZYXC = ck::tensor_layout::convolution::GKZYXC;
using NDHWGC = ck::tensor_layout::convolution::NDHWGC;
@@ -175,9 +173,7 @@ int grouped_conv_fwd_clamp(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -197,9 +193,7 @@ int grouped_conv_fwd_clamp(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}