ck: add tf32 in DTYPES to control instances build(#3317)

This commit is contained in:
yinglu
2025-12-08 16:24:20 +08:00
committed by GitHub
parent 86a84ae611
commit 8fec8054b2
24 changed files with 177 additions and 140 deletions

View File

@@ -84,9 +84,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
using F32 = float;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
#if defined(__gfx942__)
using TF32 = ck::tf32_t;
#endif
using namespace ck::tensor_layout::convolution;
@@ -143,9 +141,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, F32{}, F32{}, F32{}, TF32{});
#endif
}
}
else if(layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -164,9 +160,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, F32{}, F32{}, F32{}, TF32{});
#endif
}
}
else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW)
@@ -185,9 +179,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I2, NGKHW{}, GKYXC{}, NGCHW{}, F32{}, F32{}, F32{}, TF32{});
#endif
}
}
else if(layout == ConvLayout::NGCHW_GKCYX_NGKHW)
@@ -206,9 +198,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I2, NGKHW{}, GKCYX{}, NGCHW{}, F32{}, F32{}, F32{}, TF32{});
#endif
}
}
}
@@ -230,9 +220,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I3, GNDHWK{}, GKZYXC{}, GNDHWC{}, F32{}, F32{}, F32{}, TF32{});
#endif
}
}
else if(layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -251,9 +239,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I3, NDHWGK{}, GKZYXC{}, NDHWGC{}, F32{}, F32{}, F32{}, TF32{});
#endif
}
}
else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW)
@@ -272,9 +258,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I3, NGKDHW{}, GKZYXC{}, NGCDHW{}, F32{}, F32{}, F32{}, TF32{});
#endif
}
}
else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW)
@@ -293,9 +277,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I3, NGKDHW{}, GKCZYX{}, NGCDHW{}, F32{}, F32{}, F32{}, TF32{});
#endif
}
}
}