ck: add tf32 in DTYPES to control instances build(#3317)

[ROCm/composable_kernel commit: 8fec8054b2]
This commit is contained in:
yinglu
2025-12-08 16:24:20 +08:00
committed by GitHub
parent 771f37e4aa
commit fc7547a552
24 changed files with 177 additions and 140 deletions

View File

@@ -105,9 +105,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
using INT8 = int8_t;
using F8 = ck::f8_t;
using BF8 = ck::bf8_t;
#if defined(__gfx942__) || defined(__gfx950__)
using TF32 = ck::tf32_t;
#endif
//
using GNWC = ck::tensor_layout::convolution::GNWC;
@@ -228,9 +226,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__) || defined(__gfx950__)
return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}
else if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
@@ -253,9 +249,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__) || defined(__gfx950__)
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}
else if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
@@ -280,9 +274,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__) || defined(__gfx950__)
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}
// NHWGC_GKYXC_NHWGK
@@ -306,9 +298,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__) || defined(__gfx950__)
return profile(I1, NWGC{}, GKXC{}, NWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}
else if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -331,9 +321,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__) || defined(__gfx950__)
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}
else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKYXC_NGKHW)
@@ -352,9 +340,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__) || defined(__gfx950__)
return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}
else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKCYX_NGKHW)
@@ -373,9 +359,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__) || defined(__gfx950__)
return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -416,9 +400,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__) || defined(__gfx950__)
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}
// NGCDHW_GKCZYX_NGKDHW
@@ -439,9 +421,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__) || defined(__gfx950__)
return profile(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}