From 4693c2c2f1c57e9ec45aebc773d7b521db77f316 Mon Sep 17 00:00:00 2001 From: yinglu Date: Fri, 19 Dec 2025 09:17:29 +0800 Subject: [PATCH] ck:tf32:complement CK_ENABLE_TF32 controls (#3426) [ROCm/composable_kernel commit: ba897f8435338dbc94db5ddccdf2c7b4cdc4f142] --- include/ck/config.h.in | 4 - .../grouped_convolution_backward_data_xdl.inc | 59 +-- ...rouped_convolution_backward_weight_xdl.inc | 92 ++-- ...nvolution_forward_bias_bnorm_clamp_xdl.inc | 214 ++++----- ...ped_convolution_forward_bias_clamp_xdl.inc | 436 +++++++++--------- .../grouped_convolution_forward_clamp_xdl.inc | 404 ++++++++-------- .../grouped_convolution_forward_comp_xdl.inc | 7 + ...uped_convolution_forward_mem_inter_xdl.inc | 7 + ...uped_convolution_forward_mem_intra_xdl.inc | 8 + .../gpu/grouped_convolution_forward_xdl.inc | 36 +- ...d_convolution_forward_xdl_large_tensor.inc | 7 + ..._convolution_forward_xdl_merged_groups.inc | 33 +- .../src/profile_grouped_conv_bwd_data.cpp | 3 +- 13 files changed, 665 insertions(+), 645 deletions(-) diff --git a/include/ck/config.h.in b/include/ck/config.h.in index 113bf99243..f5421e7d5e 100644 --- a/include/ck/config.h.in +++ b/include/ck/config.h.in @@ -56,10 +56,8 @@ #define CK_ENABLE_FP32 "ON" #endif #ifndef CK_ENABLE_TF32 -#if defined(__gfx942__) || defined(__gfx95__) #define CK_ENABLE_TF32 "ON" #endif -#endif #ifndef CK_ENABLE_FP64 #define CK_ENABLE_FP64 "ON" #endif @@ -91,10 +89,8 @@ #endif #ifndef CK_ENABLE_TF32 -#if defined(__gfx942__) || defined(__gfx95__) #cmakedefine CK_ENABLE_TF32 @CK_ENABLE_TF32@ #endif -#endif #ifndef CK_ENABLE_FP64 #cmakedefine CK_ENABLE_FP64 @CK_ENABLE_FP64@ diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc index eb92f803ae..7c61f3ee66 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc @@ -127,6 +127,21 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_16_16_instance PassThrough, PassThrough>>>& instances); +void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_optimized_loads_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_TF32 void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_instances( std::vector>>& instances); -void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_optimized_loads_instances( - std::vector>>& instances); - void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_optimized_loads_instances( std::vector>>& instances); +void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_optimized_loads_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_TF32 void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances( std::vector>>& instances); -void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_optimized_loads_instances( - std::vector>>& instances); - void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_optimized_loads_instances( std::vector>>& instances); -void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances( - std::vector>>& instances); - void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_instances( std::vector>>& instances); -void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_default_pipev2_instances( +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev5_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev2_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev5_instances( + std::vector>>& instances); + +#endif + +#ifdef CK_ENABLE_TF32 +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances( std::vector>>& instances); -void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev5_instances( +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_default_pipev2_instances( std::vector>>& instances); + PassThrough, + TF32, + TF32>>>& instances); void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_default_pipev5_instances( std::vector>>& instances); -void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev2_instances( - std::vector>>& instances); - void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_pad0_pipev2_instances( std::vector>>& instances); -void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev5_instances( - std::vector>>& instances); - void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_pad0_pipev5_instances( std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_TF32 void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( std::vector>>& instances); - void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev2_instances( std::vector>>& instances); - void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev5_instances( std::vector>>& instances); - void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev2_instances( std::vector>>& instances); - void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev5_instances( std::vector>>& instances); +void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp>>>& instances); + +void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp>>>& instances); + +void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp>>>& instances); + +void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp>>>& instances); + +void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp>>>& instances); + +void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp>>>& instances); + +void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp>>>& instances); + +#endif + +#ifdef CK_ENABLE_TF32 void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances( std::vector< std::unique_ptr>>& instances); - -void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( - std::vector, - NDHWGK, - F32, - F32, - Tuple, - F32, - PassThrough, - PassThrough, - BiasNormalizeInInferClamp>>>& instances); - -void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances( - std::vector, - NDHWGK, - F32, - F32, - Tuple, - F32, - PassThrough, - PassThrough, - BiasNormalizeInInferClamp>>>& instances); - -void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances( - std::vector, - NDHWGK, - F32, - F32, - Tuple, - F32, - PassThrough, - PassThrough, - BiasNormalizeInInferClamp>>>& instances); - -void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances( - std::vector, - NDHWGK, - F32, - F32, - Tuple, - F32, - PassThrough, - PassThrough, - BiasNormalizeInInferClamp>>>& instances); - -void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances( - std::vector, - NDHWGK, - F32, - F32, - Tuple, - F32, - PassThrough, - PassThrough, - BiasNormalizeInInferClamp>>>& instances); - -void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances( - std::vector, - NDHWGK, - F32, - F32, - Tuple, - F32, - PassThrough, - PassThrough, - BiasNormalizeInInferClamp>>>& instances); - -void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances( - std::vector, - NDHWGK, - F32, - F32, - Tuple, - F32, - PassThrough, - PassThrough, - BiasNormalizeInInferClamp>>>& instances); - void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( std::vector>>& instances); -void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances( - std::vector, - NHWGK, - F32, - F32, - Tuple, - F32, - PassThrough, - PassThrough, - AddClamp, - TF32, - TF32>>>& instances); - void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances( std::vector>>& instances); -void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16x16_instances( - std::vector, - NHWGK, - F32, - F32, - Tuple, - F32, - PassThrough, - PassThrough, - AddClamp, - TF32, - TF32>>>& instances); - void add_device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances( std::vector>>& instances); -void add_device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances( - std::vector, - NHWGK, - F32, - F32, - Tuple, - F32, - PassThrough, - PassThrough, - AddClamp, - TF32, - TF32>>>& instances); - void add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances( std::vector>>& instances); -void add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances( - std::vector, - NHWGK, - F32, - F32, - Tuple, - F32, - PassThrough, - PassThrough, - AddClamp, - TF32, - TF32>>>& instances); - -void add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances( - std::vector, - NHWGK, - F32, - F32, - Tuple, - F32, - PassThrough, - PassThrough, - AddClamp, - TF32, - TF32>>>& instances); - void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances( std::vector>>& instances); -void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances( - std::vector, - NHWGK, - F32, - F32, - Tuple, - F32, - PassThrough, - PassThrough, - AddClamp, - TF32, - TF32>>>& instances); - void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances( std::vector>>& instances); -void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instances( - std::vector, - NHWGK, - F32, - F32, - Tuple, - F32, - PassThrough, - PassThrough, - AddClamp, - TF32, - TF32>>>& instances); - void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances( std::vector>>& instances); -void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances( - std::vector, - NHWGK, - F32, - F32, - Tuple, - F32, - PassThrough, - PassThrough, - AddClamp, - TF32, - TF32>>>& instances); - void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( std::vector>>& instances); -void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( - std::vector, - NDHWGK, - F32, - F32, - Tuple, - F32, - PassThrough, - PassThrough, - AddClamp, - TF32, - TF32>>>& instances); - void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances( std::vector>>& instances); -void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_16x16_instances( - std::vector, - NDHWGK, - F32, - F32, - Tuple, - F32, - PassThrough, - PassThrough, - AddClamp, - TF32, - TF32>>>& instances); - void add_device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances( std::vector>>& instances); -void add_device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( - std::vector, - NDHWGK, - F32, - F32, - Tuple, - F32, - PassThrough, - PassThrough, - AddClamp, - TF32, - TF32>>>& instances); - void add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances( std::vector>>& instances); -void add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( - std::vector, - NDHWGK, - F32, - F32, - Tuple, - F32, - PassThrough, - PassThrough, - AddClamp, - TF32, - TF32>>>& instances); - void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances( std::vector>>& instances); -void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances( - std::vector, - NDHWGK, - F32, - F32, - Tuple, - F32, - PassThrough, - PassThrough, - AddClamp, - TF32, - TF32>>>& instances); - void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances( std::vector>>& instances); +void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp>>>& instances); + +#endif + +#ifdef CK_ENABLE_TF32 +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances); +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16x16_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances); +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances); +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances); +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances); +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances); +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances); +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances); +void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances); +void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_16x16_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances); +void add_device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances); +void add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances); +void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances); void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances( std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances( - std::vector, - NDHWGK, - F32, - F32, - Tuple, - F32, - PassThrough, - PassThrough, - AddClamp>>>& instances); - void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances( std::vector>>& instances); - #endif } // namespace instance diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp_xdl.inc index 7455bb4e49..bceea56c62 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp_xdl.inc @@ -508,22 +508,6 @@ void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances( PassThrough, Clamp>>>& instances); -void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances( - std::vector, - NHWGK, - F32, - F32, - Tuple<>, - F32, - PassThrough, - PassThrough, - Clamp, - TF32, - TF32>>>& instances); - void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances( std::vector>>& instances); -void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16x16_instances( - std::vector, - NHWGK, - F32, - F32, - Tuple<>, - F32, - PassThrough, - PassThrough, - Clamp, - TF32, - TF32>>>& instances); - void add_device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances( std::vector>>& instances); -void add_device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances( - std::vector, - NHWGK, - F32, - F32, - Tuple<>, - F32, - PassThrough, - PassThrough, - Clamp, - TF32, - TF32>>>& instances); - void add_device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances( std::vector>>& instances); -void add_device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances( - std::vector, - NHWGK, - F32, - F32, - Tuple<>, - F32, - PassThrough, - PassThrough, - Clamp, - TF32, - TF32>>>& instances); - void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances( std::vector>>& instances); -void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances( - std::vector, - NHWGK, - F32, - F32, - Tuple<>, - F32, - PassThrough, - PassThrough, - Clamp, - TF32, - TF32>>>& instances); - void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances( std::vector>>& instances); -void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instances( - std::vector, - NHWGK, - F32, - F32, - Tuple<>, - F32, - PassThrough, - PassThrough, - Clamp, - TF32, - TF32>>>& instances); - void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances( std::vector>>& instances); -void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances( - std::vector, - NHWGK, - F32, - F32, - Tuple<>, - F32, - PassThrough, - PassThrough, - Clamp, - TF32, - TF32>>>& instances); - void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( std::vector>>& instances); -void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( - std::vector, - NDHWGK, - F32, - F32, - Tuple<>, - F32, - PassThrough, - PassThrough, - Clamp, - TF32, - TF32>>>& instances); - void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances( std::vector>>& instances); -void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_16x16_instances( - std::vector, - NDHWGK, - F32, - F32, - Tuple<>, - F32, - PassThrough, - PassThrough, - Clamp, - TF32, - TF32>>>& instances); - void add_device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances( std::vector>>& instances); -void add_device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( - std::vector, - NDHWGK, - F32, - F32, - Tuple<>, - F32, - PassThrough, - PassThrough, - Clamp, - TF32, - TF32>>>& instances); - void add_device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances( std::vector>>& instances); -void add_device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( - std::vector, - NDHWGK, - F32, - F32, - Tuple<>, - F32, - PassThrough, - PassThrough, - Clamp, - TF32, - TF32>>>& instances); - void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances( std::vector>>& instances); -void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances( - std::vector, - NDHWGK, - F32, - F32, - Tuple<>, - F32, - PassThrough, - PassThrough, - Clamp, - TF32, - TF32>>>& instances); - void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances( std::vector>>& instances); +void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp>>>& instances); + +#endif + +#ifdef CK_ENABLE_TF32 +void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances); +void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16x16_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances); +void add_device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances); +void add_device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances); +void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances); +void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances); +void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances); +void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances); +void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_16x16_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances); +void add_device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances); +void add_device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances); +void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances); void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances( std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances( - std::vector, - NDHWGK, - F32, - F32, - Tuple<>, - F32, - PassThrough, - PassThrough, - Clamp>>>& instances); - void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances( std::vector>>& instances); + +#endif + +#ifdef CK_ENABLE_TF32 void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances( std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_TF32 void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances( std::vector>>& instances); + +#endif + +#ifdef CK_ENABLE_TF32 void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances( std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_TF32 void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances( std::vector>>& instances); + +#endif + +#ifdef CK_ENABLE_TF32 void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instances( std::vector>>& instances); + +#endif + +#ifdef CK_ENABLE_TF32 void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances( std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_TF32 void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances( std::vector>>& instances); - void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16x16_instances( std::vector>>& instances); +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances( + std::vector>>& instances); + +#endif + +#ifdef CK_ENABLE_TF32 void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances( - std::vector>>& instances); - void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_16x16_instances( std::vector>>& instances); + +#endif + +#ifdef CK_ENABLE_TF32 void add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances( std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_TF32 void add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( std::vector>>& instances); +void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f32_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_TF32 void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances( std::vector>>& instances); - -void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f32_instances( - std::vector>>& instances); #endif #ifdef CK_ENABLE_INT8 @@ -193,6 +195,9 @@ void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_in PassThrough, PassThrough>>>& instances); +#endif + +#ifdef CK_ENABLE_TF32 void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( std::vector