From b1d03f7e8a23fdad9d09e893e35d7069d68abafd Mon Sep 17 00:00:00 2001 From: joyeamd Date: Mon, 9 Jun 2025 19:37:44 +0800 Subject: [PATCH] update kernel --- ...nv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp | 110 +++++++++++------- 1 file changed, 65 insertions(+), 45 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index 20afe6f595..bc83a1c7a4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -74,28 +74,28 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle( - const ABDataType* __restrict__ p_a_grid, - const ABDataType* __restrict__ p_b_grid, - DsPointer p_ds_grid, - EDataType* __restrict__ p_e_grid, - const AElementwiseOp a_element_op, - const BElementwiseOp b_element_op, - const CDEElementwiseOp cde_element_op, - const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, - const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock, - const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock_, - const Block2ETileMap block_2_ctile_map, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - const ComputePtrOffsetOfN compute_ptr_offset_of_n, - const index_t KBatch) + kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle( + const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOp a_element_op, + const BElementwiseOp b_element_op, + const CDEElementwiseOp cde_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_, + const Block2ETileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const ComputePtrOffsetOfN compute_ptr_offset_of_n, + const index_t KBatch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) // offset base pointer for each work-group const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / KBatch); @@ -394,22 +394,30 @@ enum DepthwiseConv2dDirection DIRECTION_BACKWARD }; -template -__global__ void kernel_grouped_conv_bwd_data_optimized_v2(Argument& arg) + index_t down_w, + index_t pad_h, + index_t pad_w> +__global__ void kernel_grouped_conv_bwd_data_optimized_v2(const ABDataType* __restrict__ p_gradOut, + const ABDataType* __restrict__ p_weight, + EDataType* __restrict__ p_gradIn, + const index_t out_width, + const index_t out_height, + const index_t in_width, + const index_t in_height, + const index_t group_num, + const index_t whole_batch_num) { - using ABDataType = typename Argument::ABDataType; - using EDataType = typename Argument::EDataType; - constexpr index_t ElementPerInFP4 = 16 / sizeof(ABDataType); constexpr index_t ElementPerOutFP4 = 16 / sizeof(EDataType); @@ -418,14 +426,12 @@ __global__ void kernel_grouped_conv_bwd_data_optimized_v2(Argument& arg) static_assert(GroupPerBlock == 32, "Currently only support GroupPerBlock == 32"); static_assert(BatchPerBlock % WaveNum == 0, "Currently BatchPerBlock should be dividable by WaveNum"); - constexpr index_t BatchIter = BatchPerBlock / WaveNum; constexpr index_t TileInW = ((TileOutW - 1) * down_w + kernelW - 1) / up_w + 1; constexpr index_t TileInH = ((TileOutH - 1) * down_h + kernelH - 1) / up_h + 1; // WaveNum * GroupPerBlk * TileInH * TileInW constexpr index_t GroupPerBlockInFP4 = GroupPerBlock / ElementPerInFP4; - constexpr index_t WaveNum = BlockSize / warpSize; __shared__ volatile ABDataType shmem_k[kernelH * kernelW * GroupPerBlock]; // layout : H->W->G __shared__ volatile ABDataType shmem_x[WaveNum * TileInH * TileInW * GroupPerBlock]; // layout : B->H->W->G will use double buffer to go through BatchPerBlock @@ -513,11 +519,11 @@ __global__ void kernel_grouped_conv_bwd_data_optimized_v2(Argument& arg) rel_in_h * TileInW * GroupPerBlockInFP4 + rel_in_w * GroupPerBlockInFP4 + group_id * ElementPerInFP4; - bool is_in_bound = (in_x >= 0 && in_x < inWidth) && (in_y >= 0 && in_y < inHeight); - float4_t v_0{0.f, 0.f, 0.f, 0.f}; + bool is_in_bound = (in_x >= 0 && in_x < in_width) && (in_y >= 0 && in_y < in_height); + float4_t v{0.f, 0.f, 0.f, 0.f}; if(is_in_bound) { - v_0 = reinterpret_cast( + v = reinterpret_cast( arg.p_a_grid_)[ingrad_offset / ElementPerInFP4]; } @@ -577,7 +583,7 @@ __global__ void kernel_grouped_conv_bwd_data_optimized_v2(Argument& arg) int outgrad_offset = (group_start_id_per_blk + group_out_id * ElementPerInFP4) * outgrad_group_stride + - (batch_start_id_per_blk + batch_id_per_wave) * outgrad_batch_stride + + (batch_start_id_per_blk + batch_id + batch_id_per_wave) * outgrad_batch_stride + out_y * outgrad_row_stride + out_x * outgrad_col_stride; reinterpret_cast(p_gradOut)[outgrad_offset / ElementPerFP4] = @@ -1386,17 +1392,31 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 { if(stride_y == 1 && stride_x == 1 && pad_y == 1 && pad_x == 1) { - return kernel_grouped_conv_bwd_data_optimized; + // return kernel_grouped_conv_bwd_data_optimized; + return kernel_grouped_conv_bwd_data_optimized_v2; } else if(stride_y == 2 && stride_x == 2 && pad_y == 1 && pad_x == 1) {