mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 17:55:48 +00:00
ck: add tf32 in DTYPES to control instances build(#3317)
This commit is contained in:
@@ -84,9 +84,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
|
||||
using F32 = float;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
#if defined(__gfx942__)
|
||||
using TF32 = ck::tf32_t;
|
||||
#endif
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
@@ -143,9 +141,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, F32{}, F32{}, F32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if(layout == ConvLayout::NHWGC_GKYXC_NHWGK)
|
||||
@@ -164,9 +160,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, F32{}, F32{}, F32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW)
|
||||
@@ -185,9 +179,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I2, NGKHW{}, GKYXC{}, NGCHW{}, F32{}, F32{}, F32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if(layout == ConvLayout::NGCHW_GKCYX_NGKHW)
|
||||
@@ -206,9 +198,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I2, NGKHW{}, GKCYX{}, NGCHW{}, F32{}, F32{}, F32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -230,9 +220,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I3, GNDHWK{}, GKZYXC{}, GNDHWC{}, F32{}, F32{}, F32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if(layout == ConvLayout::NHWGC_GKYXC_NHWGK)
|
||||
@@ -251,9 +239,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I3, NDHWGK{}, GKZYXC{}, NDHWGC{}, F32{}, F32{}, F32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW)
|
||||
@@ -272,9 +258,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I3, NGKDHW{}, GKZYXC{}, NGCDHW{}, F32{}, F32{}, F32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW)
|
||||
@@ -293,9 +277,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I3, NGKDHW{}, GKCZYX{}, NGCDHW{}, F32{}, F32{}, F32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -99,9 +99,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F8 = ck::f8_t;
|
||||
using BF8 = ck::bf8_t;
|
||||
#if defined(__gfx942__)
|
||||
using TF32 = ck::tf32_t;
|
||||
#endif
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
@@ -162,9 +160,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
|
||||
@@ -184,9 +180,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
|
||||
@@ -210,9 +204,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKYXC_NGKHW)
|
||||
@@ -243,9 +235,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
|
||||
@@ -270,9 +260,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
|
||||
@@ -306,9 +294,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 3 && layout == ConvLayout::NGCHW_GKYXC_NGKHW)
|
||||
@@ -340,9 +326,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -105,9 +105,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
using INT8 = int8_t;
|
||||
using F8 = ck::f8_t;
|
||||
using BF8 = ck::bf8_t;
|
||||
#if defined(__gfx942__) || defined(__gfx950__)
|
||||
using TF32 = ck::tf32_t;
|
||||
#endif
|
||||
|
||||
//
|
||||
using GNWC = ck::tensor_layout::convolution::GNWC;
|
||||
@@ -228,9 +226,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__) || defined(__gfx950__)
|
||||
return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
|
||||
@@ -253,9 +249,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__) || defined(__gfx950__)
|
||||
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
|
||||
@@ -280,9 +274,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__) || defined(__gfx950__)
|
||||
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
// NHWGC_GKYXC_NHWGK
|
||||
@@ -306,9 +298,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__) || defined(__gfx950__)
|
||||
return profile(I1, NWGC{}, GKXC{}, NWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
|
||||
@@ -331,9 +321,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__) || defined(__gfx950__)
|
||||
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKYXC_NGKHW)
|
||||
@@ -352,9 +340,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__) || defined(__gfx950__)
|
||||
return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKCYX_NGKHW)
|
||||
@@ -373,9 +359,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__) || defined(__gfx950__)
|
||||
return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
|
||||
@@ -416,9 +400,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__) || defined(__gfx950__)
|
||||
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
// NGCDHW_GKCZYX_NGKDHW
|
||||
@@ -439,9 +421,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__) || defined(__gfx950__)
|
||||
return profile(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -105,9 +105,7 @@ int grouped_conv_fwd_bias_clamp(int argc, char* argv[])
|
||||
using F32 = float;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
#if defined(__gfx942__)
|
||||
using TF32 = ck::tf32_t;
|
||||
#endif
|
||||
|
||||
using GKZYXC = ck::tensor_layout::convolution::GKZYXC;
|
||||
using NDHWGC = ck::tensor_layout::convolution::NDHWGC;
|
||||
@@ -172,9 +170,7 @@ int grouped_conv_fwd_bias_clamp(int argc, char* argv[])
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
|
||||
@@ -194,9 +190,7 @@ int grouped_conv_fwd_bias_clamp(int argc, char* argv[])
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -105,9 +105,7 @@ int grouped_conv_fwd_clamp(int argc, char* argv[])
|
||||
using F32 = float;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
#if defined(__gfx942__)
|
||||
using TF32 = ck::tf32_t;
|
||||
#endif
|
||||
|
||||
using GKZYXC = ck::tensor_layout::convolution::GKZYXC;
|
||||
using NDHWGC = ck::tensor_layout::convolution::NDHWGC;
|
||||
@@ -175,9 +173,7 @@ int grouped_conv_fwd_clamp(int argc, char* argv[])
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
|
||||
@@ -197,9 +193,7 @@ int grouped_conv_fwd_clamp(int argc, char* argv[])
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user