mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
update kernel
This commit is contained in:
@@ -1590,96 +1590,23 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// if(filter_y == 3 && filter_x == 3)
|
||||
// {
|
||||
// if(stride_y == 1 && stride_x == 1 && pad_y == 1 && pad_x == 1)
|
||||
// {
|
||||
// // return kernel_grouped_conv_bwd_data_optimized<ADataType,
|
||||
// // EDataType,
|
||||
// // GroupPerBlock,
|
||||
// // BatchPerBlock,
|
||||
// // BlockDim,
|
||||
// // 3,
|
||||
// // 3,
|
||||
// // 1,
|
||||
// // 1,
|
||||
// // 1,
|
||||
// // 1>;
|
||||
// return kernel_grouped_conv_bwd_data_optimized_v2<ADataType,
|
||||
// EDataType,
|
||||
// DIRECTION_BACKWARD,
|
||||
// BlockDim,
|
||||
// BatchPerBlock,
|
||||
// GroupPerBlock,
|
||||
// 4,
|
||||
// 4,
|
||||
// 3,
|
||||
// 3,
|
||||
// 1,
|
||||
// 1,
|
||||
// 1,
|
||||
// 1,
|
||||
// 1,
|
||||
// 1>;
|
||||
// }
|
||||
// else if(stride_y == 2 && stride_x == 2 && pad_y == 1 && pad_x == 1)
|
||||
// {
|
||||
// return kernel_grouped_conv_bwd_data_optimized<ADataType,
|
||||
// EDataType,
|
||||
// GroupPerBlock,
|
||||
// BatchPerBlock,
|
||||
// BlockDim,
|
||||
// 3,
|
||||
// 3,
|
||||
// 2,
|
||||
// 2,
|
||||
// 1,
|
||||
// 1>;
|
||||
// }
|
||||
// }
|
||||
// else if(filter_y == 5 && filter_x == 5)
|
||||
// {
|
||||
// if(stride_y == 1 && stride_x == 1 && pad_y == 2 && pad_x == 2)
|
||||
// {
|
||||
// return kernel_grouped_conv_bwd_data_optimized<ADataType,
|
||||
// EDataType,
|
||||
// GroupPerBlock,
|
||||
// BatchPerBlock,
|
||||
// BlockDim,
|
||||
// 5,
|
||||
// 5,
|
||||
// 1,
|
||||
// 1,
|
||||
// 2,
|
||||
// 2>;
|
||||
// }
|
||||
// else if(stride_y == 2 && stride_x == 2 && pad_y == 2 && pad_x == 2)
|
||||
// {
|
||||
// return kernel_grouped_conv_bwd_data_optimized<ADataType,
|
||||
// EDataType,
|
||||
// GroupPerBlock,
|
||||
// BatchPerBlock,
|
||||
// BlockDim,
|
||||
// 5,
|
||||
// 5,
|
||||
// 2,
|
||||
// 2,
|
||||
// 2,
|
||||
// 2>;
|
||||
// }
|
||||
// }
|
||||
auto default_kernel = &kernel_grouped_conv_bwd_data_optimized<ADataType,
|
||||
EDataType,
|
||||
GroupPerBlock,
|
||||
BatchPerBlock,
|
||||
512,
|
||||
5,
|
||||
5,
|
||||
1,
|
||||
1,
|
||||
2,
|
||||
2>;
|
||||
auto default_kernel =
|
||||
&kernel_grouped_conv_bwd_data_optimized_v2<ADataType,
|
||||
EDataType,
|
||||
DIRECTION_BACKWARD,
|
||||
BlockDim,
|
||||
BatchPerBlock,
|
||||
GroupPerBlock,
|
||||
4,
|
||||
4,
|
||||
6,
|
||||
6,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
2,
|
||||
2>;
|
||||
return static_cast<decltype(default_kernel)>(nullptr);
|
||||
};
|
||||
const auto kernel = kernel_selector();
|
||||
@@ -1705,58 +1632,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
arg.a_g_n_k_wos_lengths_[NDimSpatial + 2],
|
||||
arg.e_g_n_c_wis_lengths_[NDimSpatial + 1],
|
||||
arg.e_g_n_c_wis_lengths_[NDimSpatial + 2],
|
||||
arg.b_g_k_c_xs_lengths_[NDimSpatial + 1],
|
||||
arg.b_g_k_c_xs_lengths_[NDimSpatial + 2],
|
||||
arg.a_g_n_k_wos_lengths_[0],
|
||||
arg.a_g_n_k_wos_lengths_[1]);
|
||||
// const auto kernel =
|
||||
// kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle<
|
||||
// GridwiseGemm,
|
||||
// ADataType, // TODO: distiguish A/B datatype
|
||||
// typename GridwiseGemm::DsGridPointer,
|
||||
// EDataType,
|
||||
// AElementwiseOp,
|
||||
// BElementwiseOp,
|
||||
// CDEElementwiseOp,
|
||||
// DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
// DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
// DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
// DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
// Block2ETileMap,
|
||||
// ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
|
||||
// ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
|
||||
// has_main_loop,
|
||||
// ElementOp>;
|
||||
|
||||
// return launch_and_time_kernel(
|
||||
// stream_config,
|
||||
// kernel,
|
||||
// dim3(gdx, gdy, gdz),
|
||||
// dim3(BlockSize),
|
||||
// 0,
|
||||
// p_a_grid,
|
||||
// p_b_grid,
|
||||
// arg.p_ds_grid_,
|
||||
// p_e_grid,
|
||||
// arg.a_element_op_,
|
||||
// arg.b_element_op_,
|
||||
// arg.cde_element_op_,
|
||||
// arg.a_grid_desc_ak0_m_ak1_container_[i],
|
||||
// arg.b_grid_desc_bk0_n_bk1_container_[i],
|
||||
// arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_container_[i],
|
||||
// arg.e_grid_desc_mblock_mperblock_nblock_nperblock_container_[i],
|
||||
// arg.block_2_etile_map_container_[i],
|
||||
// arg.compute_ptr_offset_of_batch_,
|
||||
// arg.compute_ptr_offset_of_n_,
|
||||
// arg.k_batch_);
|
||||
};
|
||||
|
||||
// if(GridwiseGemm::CalculateHasMainKBlockLoop(GemmK, arg.k_batch_))
|
||||
// {
|
||||
// ave_time += launch_kernel(integral_constant<bool, true>{});
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
ave_time += launch_kernel();
|
||||
// }
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
|
||||
Reference in New Issue
Block a user