From aa0b979887a12eec94d5ea13e1c69e00a001d271 Mon Sep 17 00:00:00 2001 From: Daming Feng Date: Mon, 6 Nov 2023 18:33:11 -0600 Subject: [PATCH] Add compute type check for convolution instances (#1015) * add compute type check for fp16 in forward convolution instances * Add compute type check for default compute types --------- Co-authored-by: Bartlomiej Kocot --- .../gpu/grouped_convolution_backward_data.hpp | 49 +++++++++----- .../grouped_convolution_backward_weight.hpp | 65 +++++++++++++------ .../gpu/grouped_convolution_forward.hpp | 51 ++++++++------- 3 files changed, 108 insertions(+), 57 deletions(-) diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp index 5cdaa48bec..09885ccd90 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp @@ -496,7 +496,8 @@ struct DeviceOperationInstanceFactory< { #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && - is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(op_ptrs); add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances( @@ -507,14 +508,16 @@ struct DeviceOperationInstanceFactory< #endif #ifdef CK_ENABLE_FP32 else if constexpr(is_same_v && is_same_v && - is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(op_ptrs); } #endif #ifdef CK_ENABLE_BF16 else if constexpr(is_same_v && is_same_v && - is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances( op_ptrs); @@ -522,7 +525,9 @@ struct DeviceOperationInstanceFactory< #endif #ifdef CK_ENABLE_INT8 else if constexpr(is_same_v && is_same_v && - is_same_v) + is_same_v && + is_same_v && + is_same_v) { add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances(op_ptrs); add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances( @@ -535,7 +540,8 @@ struct DeviceOperationInstanceFactory< { #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && - is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs); add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances( @@ -546,14 +552,16 @@ struct DeviceOperationInstanceFactory< #endif #ifdef CK_ENABLE_FP32 else if constexpr(is_same_v && is_same_v && - is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(op_ptrs); } #endif #ifdef CK_ENABLE_BF16 else if constexpr(is_same_v && is_same_v && - is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances( op_ptrs); @@ -561,7 +569,9 @@ struct DeviceOperationInstanceFactory< #endif #ifdef CK_ENABLE_INT8 else if constexpr(is_same_v && is_same_v && - is_same_v) + is_same_v && + is_same_v && + is_same_v) { add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances(op_ptrs); add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances( @@ -578,7 +588,8 @@ struct DeviceOperationInstanceFactory< { #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && - is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances( op_ptrs); @@ -590,7 +601,8 @@ struct DeviceOperationInstanceFactory< #endif #ifdef CK_ENABLE_FP32 else if constexpr(is_same_v && is_same_v && - is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances( op_ptrs); @@ -598,7 +610,8 @@ struct DeviceOperationInstanceFactory< #endif #ifdef CK_ENABLE_BF16 else if constexpr(is_same_v && is_same_v && - is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances( op_ptrs); @@ -606,7 +619,9 @@ struct DeviceOperationInstanceFactory< #endif #ifdef CK_ENABLE_INT8 else if constexpr(is_same_v && is_same_v && - is_same_v) + is_same_v && + is_same_v && + is_same_v) { add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_instances( op_ptrs); @@ -642,7 +657,8 @@ struct DeviceOperationInstanceFactory< #endif #ifdef CK_ENABLE_FP32 else if constexpr(is_same_v && is_same_v && - is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances( op_ptrs); @@ -650,7 +666,8 @@ struct DeviceOperationInstanceFactory< #endif #ifdef CK_ENABLE_BF16 else if constexpr(is_same_v && is_same_v && - is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances( op_ptrs); @@ -658,7 +675,9 @@ struct DeviceOperationInstanceFactory< #endif #ifdef CK_ENABLE_INT8 else if constexpr(is_same_v && is_same_v && - is_same_v) + is_same_v && + is_same_v && + is_same_v) { add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_instances( op_ptrs); diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp index f15008349a..b8ca2c5fac 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp @@ -618,7 +618,8 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + is_same_v && is_same_v && + is_same_v) { #ifdef DL_KERNELS add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instances(op_ptrs); @@ -628,7 +629,8 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + is_same_v && is_same_v && + is_same_v) { #ifdef DL_KERNELS add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instances(op_ptrs); @@ -638,7 +640,9 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + is_same_v && + is_same_v && + is_same_v) { #ifdef DL_KERNELS add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( @@ -655,21 +659,25 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instances(op_ptrs); } #endif #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && - is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instances(op_ptrs); } #endif #ifdef CK_ENABLE_BF16 if constexpr(is_same_v && is_same_v && - is_same_v) + is_same_v && + is_same_v && + is_same_v) { add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_f32_bf16_instances( op_ptrs); @@ -685,7 +693,8 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + is_same_v && is_same_v && + is_same_v) { #ifdef DL_KERNELS add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instances( @@ -697,7 +706,8 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + is_same_v && is_same_v && + is_same_v) { #ifdef DL_KERNELS add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instances( @@ -709,7 +719,9 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + is_same_v && + is_same_v && + is_same_v) { #ifdef DL_KERNELS add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( @@ -725,7 +737,8 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + is_same_v && is_same_v && + is_same_v) { #ifdef DL_KERNELS add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instances( @@ -737,7 +750,8 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + is_same_v && is_same_v && + is_same_v) { #ifdef DL_KERNELS add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instances( @@ -749,7 +763,9 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + is_same_v && + is_same_v && + is_same_v) { #ifdef DL_KERNELS add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances( @@ -768,7 +784,8 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + is_same_v && is_same_v && + is_same_v) { #ifdef DL_KERNELS add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f32_instances( @@ -780,7 +797,8 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + is_same_v && is_same_v && + is_same_v) { #ifdef DL_KERNELS add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instances( @@ -796,7 +814,9 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + is_same_v && + is_same_v && + is_same_v) { #ifdef DL_KERNELS add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( @@ -808,7 +828,9 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + is_same_v && + is_same_v && + is_same_v) { add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instances( op_ptrs); @@ -822,7 +844,8 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + is_same_v && is_same_v && + is_same_v) { #ifdef DL_KERNELS add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instances( @@ -851,7 +874,9 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + is_same_v && + is_same_v && + is_same_v) { #ifdef DL_KERNELS add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( @@ -863,7 +888,9 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + is_same_v && + is_same_v && + is_same_v) { add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances( op_ptrs); diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index 888c00f900..9043ebc545 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -928,28 +928,29 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + is_same_v && is_same_v) { add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs); } #endif #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && - is_same_v) + is_same_v && is_same_v) { add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs); } #endif #ifdef CK_ENABLE_BF16 if constexpr(is_same_v && - is_same_v && is_same_v) + is_same_v && + is_same_v && is_same_v) { add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances(op_ptrs); } #endif #ifdef CK_ENABLE_INT8 if constexpr(is_same_v && is_same_v && - is_same_v) + is_same_v && is_same_v) { add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances(op_ptrs); } @@ -961,7 +962,7 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + is_same_v && is_same_v) { add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs); } @@ -969,7 +970,7 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + is_same_v && is_same_v) { add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs); } @@ -977,7 +978,7 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + is_same_v && is_same_v) { add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); @@ -989,7 +990,7 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + is_same_v && is_same_v) { add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); } @@ -997,7 +998,8 @@ struct DeviceOperationInstanceFactory && - is_same_v && is_same_v) + is_same_v && + is_same_v && is_same_v) { add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(op_ptrs); } @@ -1005,7 +1007,7 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + is_same_v && is_same_v) { add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(op_ptrs); add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instances(op_ptrs); @@ -1021,7 +1023,7 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + is_same_v && is_same_v) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); } @@ -1029,7 +1031,7 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + is_same_v && is_same_v) { add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); } @@ -1037,7 +1039,7 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + is_same_v && is_same_v) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs); } @@ -1045,7 +1047,7 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + is_same_v && is_same_v) { add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs); } @@ -1053,14 +1055,15 @@ struct DeviceOperationInstanceFactory && - is_same_v && is_same_v) + is_same_v && + is_same_v && is_same_v) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs); } #endif #ifdef CK_ENABLE_INT8 else if constexpr(is_same_v && is_same_v && - is_same_v) + is_same_v && is_same_v) { add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instances(op_ptrs); add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instances(op_ptrs); @@ -1075,14 +1078,14 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + is_same_v && is_same_v) { add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(op_ptrs); } #endif #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && - is_same_v) + is_same_v && is_same_v) { add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(op_ptrs); add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instances(op_ptrs); @@ -1095,14 +1098,15 @@ struct DeviceOperationInstanceFactory && - is_same_v && is_same_v) + is_same_v && + is_same_v && is_same_v) { add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances(op_ptrs); } #endif #ifdef CK_ENABLE_INT8 if constexpr(is_same_v && is_same_v && - is_same_v) + is_same_v && is_same_v) { add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances(op_ptrs); add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instances(op_ptrs); @@ -1119,7 +1123,7 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + is_same_v && is_same_v) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs); } @@ -1148,14 +1152,15 @@ struct DeviceOperationInstanceFactory && - is_same_v && is_same_v) + is_same_v && + is_same_v && is_same_v) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(op_ptrs); } #endif #ifdef CK_ENABLE_INT8 if constexpr(is_same_v && is_same_v && - is_same_v) + is_same_v && is_same_v) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances(op_ptrs); add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances(op_ptrs);