diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index 01b06f8d79..aeffadddd0 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -1590,96 +1590,23 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 } } } - - // if(filter_y == 3 && filter_x == 3) - // { - // if(stride_y == 1 && stride_x == 1 && pad_y == 1 && pad_x == 1) - // { - // // return kernel_grouped_conv_bwd_data_optimized; - // return kernel_grouped_conv_bwd_data_optimized_v2; - // } - // else if(stride_y == 2 && stride_x == 2 && pad_y == 1 && pad_x == 1) - // { - // return kernel_grouped_conv_bwd_data_optimized; - // } - // } - // else if(filter_y == 5 && filter_x == 5) - // { - // if(stride_y == 1 && stride_x == 1 && pad_y == 2 && pad_x == 2) - // { - // return kernel_grouped_conv_bwd_data_optimized; - // } - // else if(stride_y == 2 && stride_x == 2 && pad_y == 2 && pad_x == 2) - // { - // return kernel_grouped_conv_bwd_data_optimized; - // } - // } - auto default_kernel = &kernel_grouped_conv_bwd_data_optimized; + auto default_kernel = + &kernel_grouped_conv_bwd_data_optimized_v2; return static_cast(nullptr); }; const auto kernel = kernel_selector(); @@ -1705,58 +1632,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 arg.a_g_n_k_wos_lengths_[NDimSpatial + 2], arg.e_g_n_c_wis_lengths_[NDimSpatial + 1], arg.e_g_n_c_wis_lengths_[NDimSpatial + 2], + arg.b_g_k_c_xs_lengths_[NDimSpatial + 1], + arg.b_g_k_c_xs_lengths_[NDimSpatial + 2], arg.a_g_n_k_wos_lengths_[0], arg.a_g_n_k_wos_lengths_[1]); - // const auto kernel = - // kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle< - // GridwiseGemm, - // ADataType, // TODO: distiguish A/B datatype - // typename GridwiseGemm::DsGridPointer, - // EDataType, - // AElementwiseOp, - // BElementwiseOp, - // CDEElementwiseOp, - // DeviceOp::AGridDesc_AK0_M_AK1, - // DeviceOp::BGridDesc_BK0_N_BK1, - // DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - // DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - // Block2ETileMap, - // ComputePtrOffsetOfStridedBatch, - // ComputePtrOffsetOfStridedBatch, - // has_main_loop, - // ElementOp>; - - // return launch_and_time_kernel( - // stream_config, - // kernel, - // dim3(gdx, gdy, gdz), - // dim3(BlockSize), - // 0, - // p_a_grid, - // p_b_grid, - // arg.p_ds_grid_, - // p_e_grid, - // arg.a_element_op_, - // arg.b_element_op_, - // arg.cde_element_op_, - // arg.a_grid_desc_ak0_m_ak1_container_[i], - // arg.b_grid_desc_bk0_n_bk1_container_[i], - // arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_container_[i], - // arg.e_grid_desc_mblock_mperblock_nblock_nperblock_container_[i], - // arg.block_2_etile_map_container_[i], - // arg.compute_ptr_offset_of_batch_, - // arg.compute_ptr_offset_of_n_, - // arg.k_batch_); }; - // if(GridwiseGemm::CalculateHasMainKBlockLoop(GemmK, arg.k_batch_)) - // { - // ave_time += launch_kernel(integral_constant{}); - // } - // else - // { ave_time += launch_kernel(); - // } } return ave_time;