diff --git a/include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp index 4bb82baabc..2991526851 100644 --- a/include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -433,7 +433,7 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ using namespace ck; const index_t Di = input_spatial_lengths[0]; - const index_t Hi = input_spatial_lengths[2]; + const index_t Hi = input_spatial_lengths[1]; const index_t Wi = input_spatial_lengths[2]; const index_t Do = output_spatial_lengths[0]; @@ -671,11 +671,14 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ return PadDescriptor_M0_1d(desc, gridSize, blockSize); } - using TypeConvertFunctor = + using TypeConvertFp32ToBf16Functor = ck::tensor_operation::element_wise::UnaryTypeConvert; - using GridDesc_M0 = decltype(MakeDescriptor_M0<1>({1}, {1}, 1, 1)); - using GridwiseUEltwise = - GridwiseUnaryElementwise_1D; + using GridDesc_M0 = decltype(MakeDescriptor_M0<1>({1}, {1}, 1, 1)); + using GridwiseUEltwise = GridwiseUnaryElementwise_1D; using ABCGridDescs = decltype(GetABCGridDesc()); @@ -979,33 +982,32 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); - const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); - float ave_time = 0; - const auto Run = [&](const auto& kernel) { + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + const auto run_conv = [&](const auto& kernel) { hipGetErrorString(hipMemset( arg.p_c_grid_, 0, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() * sizeof(CDataType))); - ave_time = - launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_); + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); }; // run kernel for bf16 with splitk @@ -1016,22 +1018,21 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() * sizeof(AccDataType))); - ave_time = - launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - static_cast(arg.p_workspace_), - arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_); + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + static_cast(arg.p_workspace_), + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); }; // kernel for type conversion @@ -1059,7 +1060,7 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ // run kernel for type conversion void* p_c_grid_tmp_ = static_cast(arg.p_c_grid_); InDataType* p_c_grid_tmp_bf16_ = static_cast(p_c_grid_tmp_); - const auto Run_type_convert = [&](const auto& kernel) { + const auto run_type_convert = [&](const auto& kernel) { float elapsed_time = launch_and_time_kernel(stream_config, kernel, @@ -1070,14 +1071,15 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ p_c_grid_tmp_bf16_, a_grid_desc_m0_, b_grid_desc_m0_, - TypeConvertFunctor{}); + TypeConvertFp32ToBf16Functor{}); return elapsed_time; }; if constexpr(std::is_same::value) { - if(has_main_k0_block_loop) - { + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + if(kbatch == 1) { const auto kernel = kernel_gemm_xdlops_bwd_weight< @@ -1092,9 +1094,9 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ InElementwiseOperation, WeiElementwiseOperation, remove_reference_t, - true>; + has_main_loop>; - Run(kernel); + return run_conv(kernel); } else { @@ -1103,7 +1105,7 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ AccDataType, InDataType, GridDesc_M0, - TypeConvertFunctor>; + TypeConvertFp32ToBf16Functor>; const auto kernel_conv = kernel_gemm_xdlops_bwd_weight< GridwiseGemmAtomicAddFloatBf16Splitk, @@ -1117,56 +1119,28 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ InElementwiseOperation, WeiElementwiseOperation, remove_reference_t, - true>; + has_main_loop>; - run_bf16_splitk(kernel_conv); - ave_time += Run_type_convert(kernel_type_convert); + float elapsed_time = 0; + elapsed_time += run_bf16_splitk(kernel_conv); + elapsed_time += run_type_convert(kernel_type_convert); + return elapsed_time; } + }; + if(has_main_k0_block_loop) + { + ave_time = launch_kernel(integral_constant{}); } else { - if(kbatch == 1) - { - const auto kernel = kernel_gemm_xdlops_bwd_weight< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t< - DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, - OutElementwiseOperation, - InElementwiseOperation, - WeiElementwiseOperation, - remove_reference_t, - false>; - - Run(kernel); - } - else - { - const auto kernel = kernel_gemm_xdlops_bwd_weight< - GridwiseGemmAtomicAddFloatBf16Splitk, - ADataType, // TODO: distiguish A/B datatype - AccDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t< - DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, - OutElementwiseOperation, - InElementwiseOperation, - WeiElementwiseOperation, - remove_reference_t, - false>; - - run_bf16_splitk(kernel); - } + ave_time = launch_kernel(integral_constant{}); } } else { - if(has_main_k0_block_loop) - { + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + if(kbatch == 1) { const auto kernel = kernel_gemm_xdlops_bwd_weight< @@ -1181,9 +1155,9 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ InElementwiseOperation, WeiElementwiseOperation, remove_reference_t, - true>; + has_main_loop>; - Run(kernel); + return run_conv(kernel); } else { @@ -1199,49 +1173,18 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ InElementwiseOperation, WeiElementwiseOperation, remove_reference_t, - true>; + has_main_loop>; - Run(kernel); + return run_conv(kernel); } + }; + if(has_main_k0_block_loop) + { + ave_time = launch_kernel(integral_constant{}); } else { - if(kbatch == 1) - { - const auto kernel = kernel_gemm_xdlops_bwd_weight< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t< - DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, - OutElementwiseOperation, - InElementwiseOperation, - WeiElementwiseOperation, - remove_reference_t, - false>; - - Run(kernel); - } - else - { - const auto kernel = kernel_gemm_xdlops_bwd_weight< - GridwiseGemmAtomicAdd, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t< - DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, - OutElementwiseOperation, - InElementwiseOperation, - WeiElementwiseOperation, - remove_reference_t, - false>; - - Run(kernel); - } + ave_time = launch_kernel(integral_constant{}); } }