mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 06:01:23 +00:00
Conv:TF32: add more instances - 2 (#2879)
* add instances of device_grouped_conv_fwd_xdl_f32_comp_instances
* add instances of device_grouped_conv_fwd_xdl_f32_tf32_mem_instances
* add instances of device_grouped_conv_fwd_xdl_large_tensor_f32_tf32_instances
* tf32:conv:add instances for base class DeviceConvFwd
* tf32:conv:add instances for base class DeviceGroupedConvBwdDataMultipleD
* tf32:conv:add instances for base class DeviceGroupedConvBwdWeight
* add tf32 in profiler
* remove gnhwc/ngchw/ngcdhw instances
* remove non-ndhwgc/nhwgc/nhwc instances
* add check in IsSupportedArgument()
This commit is contained in:
@@ -226,6 +226,12 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
{
|
||||
return profile(I1, GNWC{}, GKXC{}, GNWK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
|
||||
{
|
||||
@@ -245,6 +251,12 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
{
|
||||
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
|
||||
{
|
||||
@@ -292,6 +304,12 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
{
|
||||
return profile(I1, NWGC{}, GKXC{}, NWGK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I1, NWGC{}, GKXC{}, NWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
|
||||
{
|
||||
@@ -311,6 +329,12 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
{
|
||||
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKYXC_NGKHW)
|
||||
{
|
||||
@@ -326,6 +350,12 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
{
|
||||
return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKCYX_NGKHW)
|
||||
{
|
||||
@@ -341,6 +371,12 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
{
|
||||
return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user