diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp index 7ad83d5ad6..2f048097a1 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp @@ -291,7 +291,7 @@ int main(int argc, char* argv[]) float tflops = static_cast(flop) / 1.E9 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time; - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << conv->GetTypeString() << std::endl; if(do_verification) @@ -320,18 +320,15 @@ int main(int argc, char* argv[]) { case 3: { auto ref_conv = ReferenceConvNDFwdInstance<3>(); - verify_f(ref_conv); - break; + return verify_f(ref_conv); } case 2: { auto ref_conv = ReferenceConvNDFwdInstance<2>(); - verify_f(ref_conv); - break; + return verify_f(ref_conv); } case 1: { auto ref_conv = ReferenceConvNDFwdInstance<1>(); - verify_f(ref_conv); - break; + return verify_f(ref_conv); } default: { throw std::runtime_error("Unsupported number of spatial dimensions provided!"); diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp index 8a9633d84a..7fa0f0d275 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp @@ -324,18 +324,15 @@ int main(int argc, char* argv[]) { case 3: { auto ref_conv = ReferenceConvNDFwdInstance<3>(); - verify_f(ref_conv); - break; + return verify_f(ref_conv); } case 2: { auto ref_conv = ReferenceConvNDFwdInstance<2>(); - verify_f(ref_conv); - break; + return verify_f(ref_conv); } case 1: { auto ref_conv = ReferenceConvNDFwdInstance<1>(); - verify_f(ref_conv); - break; + return verify_f(ref_conv); } default: { throw std::runtime_error("Unsupported number of spatial dimensions provided!"); diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp index f196d27182..9a1028f88b 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp @@ -322,18 +322,15 @@ int main(int argc, char* argv[]) { case 3: { auto ref_conv = ReferenceConvNDFwdInstance<3>(); - verify_f(ref_conv); - break; + return verify_f(ref_conv); } case 2: { auto ref_conv = ReferenceConvNDFwdInstance<2>(); - verify_f(ref_conv); - break; + return verify_f(ref_conv); } case 1: { auto ref_conv = ReferenceConvNDFwdInstance<1>(); - verify_f(ref_conv); - break; + return verify_f(ref_conv); } default: { throw std::runtime_error("Unsupported number of spatial dimensions provided!"); diff --git a/example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp b/example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp index ff2cfac1fa..0383197358 100644 --- a/example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp +++ b/example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp @@ -332,18 +332,15 @@ int main(int argc, char* argv[]) { case 3: { auto ref_conv = ReferenceConvBwdDataInstance<3>(); - verify_f(ref_conv); - break; + return verify_f(ref_conv); } case 2: { auto ref_conv = ReferenceConvBwdDataInstance<2>(); - verify_f(ref_conv); - break; + return verify_f(ref_conv); } case 1: { auto ref_conv = ReferenceConvBwdDataInstance<1>(); - verify_f(ref_conv); - break; + return verify_f(ref_conv); } default: { throw std::runtime_error("Unsupported number of spatial dimensions provided!"); diff --git a/example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl.cpp b/example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl.cpp index 0fc976c34a..65725d3ae8 100644 --- a/example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl.cpp +++ b/example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl.cpp @@ -403,18 +403,15 @@ int main(int argc, char* argv[]) { case 3: { auto ref_conv = ReferenceConvBwdWeightInstance<3>(); - verify_f(ref_conv); - break; + return verify_f(ref_conv); } case 2: { auto ref_conv = ReferenceConvBwdWeightInstance<2>(); - verify_f(ref_conv); - break; + return verify_f(ref_conv); } case 1: { auto ref_conv = ReferenceConvBwdWeightInstance<1>(); - verify_f(ref_conv); - break; + return verify_f(ref_conv); } default: { throw std::runtime_error("Unsupported number of spatial dimensions provided!"); diff --git a/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp index f29e59039e..707413dfd3 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -417,6 +417,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W using BGridDesc_K0_N_K1 = remove_cvref_t; using CGridDesc_M_N = remove_cvref_t; + using Block2CTileMap = BlockToCTileMap_M00_N0_M01; + // GridwiseGemm using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1< BlockSize, @@ -477,8 +479,6 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W std::vector conv_filter_dilations, std::vector input_left_pads, std::vector input_right_pads, - ck::index_t M01, - ck::index_t N01, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op) @@ -490,8 +490,6 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W c_grid_desc_m_n_{}, c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, block_2_ctile_map_{}, - M01_{M01}, - N01_{N01}, in_element_op_{in_element_op}, wei_element_op_{wei_element_op}, out_element_op_{out_element_op}, @@ -520,10 +518,9 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W a_grid_desc_k0_m_k1_ = descs[I0]; b_grid_desc_k0_n_k1_ = descs[I1]; - block_2_ctile_map_ = - GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + c_grid_desc_m_n_ = descs[I2]; - c_grid_desc_m_n_ = descs[I2]; + block_2_ctile_map_ = Block2CTileMap{c_grid_desc_m_n_}; if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, @@ -546,9 +543,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W typename GridwiseGemm:: CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; - typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; - index_t M01_; - index_t N01_; + Block2CTileMap block_2_ctile_map_; InElementwiseOperation in_element_op_; WeiElementwiseOperation wei_element_op_; OutElementwiseOperation out_element_op_; @@ -661,7 +656,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation, - remove_reference_t, + Block2CTileMap, true>; ave_time = launch_and_time_kernel( @@ -695,7 +690,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation, - remove_reference_t, + Block2CTileMap, false>; ave_time = launch_and_time_kernel( @@ -814,8 +809,6 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W conv_filter_dilations, input_left_pads, input_right_pads, - 1, - 1, in_element_op, wei_element_op, out_element_op}; @@ -854,8 +847,6 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W conv_filter_dilations, input_left_pads, input_right_pads, - 1, - 1, in_element_op, wei_element_op, out_element_op); diff --git a/include/ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp index f0be2498e7..1678f9991e 100644 --- a/include/ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp @@ -607,6 +607,8 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K using BGridDesc_K0_N_K1 = remove_cvref_t; using CGridDesc_M_N = remove_cvref_t; + using Block2CTileMap = BlockToCTileMap_M00_N0_M01; + // GridwiseGemm using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< BlockSize, @@ -664,8 +666,6 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K std::vector conv_filter_dilations, std::vector input_left_pads, std::vector input_right_pads, - ck::index_t M01, - ck::index_t N01, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op) @@ -677,8 +677,6 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K c_grid_desc_m_n_{}, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, block_2_ctile_map_{}, - M01_{M01}, - N01_{N01}, in_element_op_{in_element_op}, wei_element_op_{wei_element_op}, out_element_op_{out_element_op}, @@ -705,8 +703,8 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K a_grid_desc_k0_m_k1_ = descs[I0]; b_grid_desc_k0_n_k1_ = descs[I1]; c_grid_desc_m_n_ = descs[I2]; - block_2_ctile_map_ = - GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + + block_2_ctile_map_ = Block2CTileMap{c_grid_desc_m_n_}; if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, @@ -727,9 +725,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K CGridDesc_M_N c_grid_desc_m_n_; typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; - typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; - index_t M01_; - index_t N01_; + Block2CTileMap block_2_ctile_map_; InElementwiseOperation in_element_op_; WeiElementwiseOperation wei_element_op_; OutElementwiseOperation out_element_op_; @@ -793,7 +789,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation, - remove_reference_t, + Block2CTileMap, true>; ave_time = launch_and_time_kernel(stream_config, @@ -824,7 +820,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation, - remove_reference_t, + Block2CTileMap, false>; ave_time = launch_and_time_kernel(stream_config, @@ -955,8 +951,6 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K conv_filter_dilations, input_left_pads, input_right_pads, - 1, - 1, in_element_op, wei_element_op, out_element_op}; @@ -995,8 +989,6 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K conv_filter_dilations, input_left_pads, input_right_pads, - 1, - 1, in_element_op, wei_element_op, out_element_op); @@ -1012,8 +1004,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K auto str = std::stringstream(); // clang-format off - str << "DeviceConv" << std::to_string(NumDimSpatial) - << "DFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" + str << "DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" << "<" << BlockSize << ", " << MPerBlock << ", " diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp index ffa82a7570..2828655f51 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp @@ -314,7 +314,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 using DefaultBlock2CTileMap = remove_cvref_t; - template + template __device__ static void Run(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid,