From 4fe245e8d57770cc92c53747a96c3ac484de0aa0 Mon Sep 17 00:00:00 2001 From: joye Date: Fri, 6 Jun 2025 18:35:13 +0800 Subject: [PATCH] update kernel; not correctly --- ...nv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp | 208 +++++++++++++++++- 1 file changed, 197 insertions(+), 11 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index 1902d9564a..27a787ffe5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -369,12 +369,196 @@ __global__ void } } -// template -// __global__ void kernel_grouped_conv_bwd_data_optimized_v2(Argument& arg) -// { +/* + one wave implement one batch 32 group; one group implements tile_x * tile_y output; what about the + size not divided by 4? -// } + one wave implements 32 x 4 x 4; one thread can fetch 8 groups h0w0 data; 4 threads will load 32 groups' + h0w0; 64 threads can load 4x4; + tid 0 and 32 work on group 0; tid 1 and 33 work on group 1; tid 2 and 34 work on + group 2; tid 3 and 35 work on group 3; etc. 
+ + wave 1 in the same block goes through the batch direction + gridDim(ceiling(InWidth / TileOutW), ceiling(InHeight / TileOutH), (WholeBatchNum / BatchPerBlock) + * (WholeGroupNum / GroupPerBlk)) + + BlockDim(warpSize * warpNum, 1, 1) + when forward, up means dilate, down means stride + when backward, up means stride, down means dilate +*/ +enum DepthwiseConv2dDirection +{ + DIRECTION_FORWARD, + DIRECTION_BACKWARD +}; + +template +__global__ void kernel_grouped_conv_bwd_data_optimized_v2(Argument& arg) +{ + using ABDataType = typename Argument::ABDataType; + using EDataType = typename Argument::EDataType; + + constexpr index_t ElementPerFP4 = 16 / sizeof(ABDataType); + + static_assert(GroupPerBlk == 32, "Currently only support GroupPerWave == 32"); + constexpr index_t TileInW = ((TileOutW - 1) * down_w + kernelW - 1) / up_w + 1; + constexpr index_t TileInH = ((TileOutH - 1) * down_h + kernelH - 1) / up_h + 1; + + constexpr index_t WaveNum = BlockSize / warpSize; + __shared__ volatile ABDataType shmem_k[GroupPerBlk * kernelH * kernelW]; // layout : H->W->G + __shared__ volatile ABDataType shmem_x[2][WaveNum * GroupPerBlk * TileInH * TileInW]; + // layout : B->H->W->G will use double buffer to go through BatchPerBlock + // when backward data, will be gradOut, when forward, will be x + + const int output_tile_w = blockIdx.x * TileOutW; + const int output_tile_h = blockIdx.y * TileOutH; + + if(output_tile_w >= outWidth || output_tile_h >= outHeight) + { + return; // out of bound + } + + const int GroupBatchNum = WholeGroupNum / GroupPerBlk; + + // NHWGK todo use the stride + const int outgrad_group_stride = 1; + const int outgrad_batch_stride = group_num * out_height * out_width; + const int outgrad_row_stride = group_num * out_width; + const int outgrad_col_stride = group_num; + // NHWGC + const int ingrad_group_stride = 1; + const int ingrad_batch_stride = group_num * in_height * in_width; + const int ingrad_row_stride = group_num * in_width; + const int 
ingrad_col_stride = group_num; + + int tid = threadIdx.x; + + // this offset is used to calculate the start offset of the group and batch + const int group_start_id_per_blk = (blockIdx.z % GroupBatchNum) * GroupPerBlk; + const int batch_start_id_per_blk = (blockIdx.z / GroupBatchNum) * BatchPerBlk; + + int wave_id = __builtin_amdgcn_readfirstlane(tid / warpSize); + + int tile_mid_w = tileOutW * down_w + up_w - 1 - pad_w; + int tile_mid_h = tileOutH * down_h + up_h - 1 - pad_h; + int tile_in_x = tile_mid_w / up_w; + int tile_in_y = tile_mid_h / up_h; + + // WaveNum * GroupPerBlk * TileInH * TileInW + constexpr index_t GroupPerBlockInFP4 = GroupPerBlk / ElementPerFP4; + + int group_id = tid % GroupPerBlockInFP4; + int rel_in_w = (tid / GroupPerBlockInFP4) % TileInW; + int rel_in_h = (tid / (GroupPerBlockInFP4 * TileInW)) % TileInH; + + int in_x = rel_in_w + tile_in_x; + int in_y = rel_in_h + tile_in_y; + + int local_batch_id = wave_id; + int ingrad_offset = (group_start_id_per_blk + group_id * ElementPerFP4) * ingrad_group_stride + + (batch_start_id_per_blk + local_batch_id) * ingrad_batch_stride + + in_y * ingrad_row_stride + in_x * ingrad_col_stride; + + int shmem_offset = wave_id * GroupPerBlockInFP4 * TileInH * TileInW + + rel_in_h * TileInW * GroupPerBlockInFP4 + rel_in_w * GroupPerBlockInFP4 + + group_id * ElementPerFP4; + + bool is_in_bound = (in_x >= 0 && in_x < inWidth) && (in_y >= 0 && in_y < inHeight); + + // static_assert( + // WaveNum * GroupPerBlockInFP4 * TileInH * TileInW % BlockSize == 0, + // "WaveNum * GroupPerBlockInFP4 * TileInH * TileInW must be divisible by BlockSize"); + constexpr int InLoopNum = WaveNum * GroupPerBlockInFP4 * TileInH * TileInW; + +#pragma unroll + for(int i = tid; i < InLoopNum; i += BlockSize) + { + float4_t v_0{0.f, 0.f, 0.f, 0.f}; + if(is_in_bound) + { + v_0 = reinterpret_cast(p_in)[ingrad_offset / ElementPerFP4]; + } + + reinterpret_cast(shmem_x)[shmem_offset] = v; + } + + // load weight to shared memory + // global 
weight layout : GKYXC; shared weight layout : Y->X->G + for(int i = tid; i < GroupPerBlk * kernelH * kernelW; i += BlockSize) + { + int local_group_id = i / (kernelH * kernelW); + int glb_group_id = local_group_id + group_start_id_per_blk; + int kernel_h = (i % (kernelH * kernelW)) / kernelW; + int kernel_w = i % kernelW; + + shmem_k[kernel_h * kernelW * GroupPerBlk + kernel_w * GroupPerBlk + local_group_id] = + p_weight[glb_group_id * kernelH * kernelW + kernel_h * kernelW + kernel_w]; + } + + int ping = 0; + for(int i = 1; i < BatchPerBlk; i++) + { + ingrad_offset += ingrad_batch_stride / ElementPerFP4; + + block_sync_lds(); +#pragma unroll + for(int i = tid; i < InLoopNum; i += BlockSize) + { + float4_t v{0.f, 0.f, 0.f, 0.f}; + if(is_in_bound) + { + v = reinterpret_cast(p_in)[ingrad_offset / ElementPerFP4]; + } + reinterpret_cast(shmem_x[1 - ping])[shmem_offset] = v; + } + constexpr int OutLoopNum = WaveNum * GroupPerBlockInFP4 * TileOutW * TileOutH; + + for(int out_idx = tid; out_idx < OutLoopNum; out_idx += BlockSize) + { + int rel_out_y = out_idx / TileOutW; + int rel_out_x = out_idx - rel_out_y * TileOutW; + int out_y = rel_out_y + tile_out_y; + int out_x = rel_out_x + tile_out_x; + + int mid_x = tile_mid_x + rel_out_x * down_x; + int mid_y = tile_mid_y + rel_out_y * down_y; + int in_x = floor_div(mid_x, up_x); + int in_y = floor_div(mid_y, up_y); + int rel_in_x = in_x - tile_in_x; + int rel_in_y = in_y - tile_in_y; + int kernel_x = (in_x + 1) * up_x - mid_x - 1; + int kernel_y = (in_y + 1) * up_y - mid_y - 1; + if(in_x < 0 || in_x >= inWidth || in_y < 0 || in_y >= inHeight) + continue; + +#pragma unroll + for(int y = 0; y < kernel_h / up_y; y++) +#pragma unroll + for(int x = 0; x < kernel_w / up_x; x++) + { + v += sx[rel_in_y + y][rel_in_x + x] * + sk[kernel_y + y * up_y][kernel_x + x * up_x]; + } + + if(out_x < p.out_w & out_y < p.out_h) + { + out[((major_idx * p.out_h + out_y) * p.out_w + out_x) + minor_idx] = v; + } + } + + ping = 1 - ping; + } } // 
namespace // Conv backward data multiple D: @@ -1151,7 +1335,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 throw std::runtime_error("wrong! device_op has invalid setting"); } - // const index_t gdx = arg.block_2_etile_map_container_[i].CalculateGridSize( + // const index_t gdx = + // arg.block_2_etile_map_container_[i].CalculateGridSize( // arg.e_grid_desc_m_n_container_[i]); // const auto GemmK = arg.a_grid_desc_m_k_container_[i].GetLength(I1); @@ -1192,7 +1377,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 arg.a_g_n_k_wos_lengths_[1], arg.b_g_k_c_xs_lengths_[NDimSpatial + 1] * arg.b_g_k_c_xs_lengths_[NDimSpatial + 2]); - // const auto kernel = kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle< + // const auto kernel = + // kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle< // GridwiseGemm, // ADataType, // TODO: distiguish A/B datatype // typename GridwiseGemm::DsGridPointer, @@ -1571,11 +1757,11 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { - std::cout - << "Warning: Workspace for " - "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1::Argument is not " - "allocated, use SetWorkSpacePointer." - << std::endl; + std::cout << "Warning: Workspace for " + "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1::" + "Argument is not " + "allocated, use SetWorkSpacePointer." + << std::endl; } return false; }