diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index bd23c55a0b..114f2c43a1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -175,7 +175,11 @@ __global__ void // grp3 grp4 grp5 grp6 grp7 // load weight: 64*5*5 half needs 200 requests when use float4 -template +template __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(512, CK_MIN_BLOCK_PER_CU) @@ -207,7 +211,8 @@ __global__ void const int pad_width = pad_x; constexpr int batch_num = BatchPerBlock; - + static_assert(batch_num == BlockSize / warpSize, + "currently only support one wave implement one batch"); // NHWGK todo use the stride const int outgrad_group_stride = 1; const int outgrad_batch_stride = group_num * out_height * out_width; @@ -223,8 +228,7 @@ __global__ void constexpr index_t ElementPerFP4 = 16 / sizeof(ABDataType); -#pragma unroll - for(int index = tid; index < 64 * 5 * 5 / ElementPerFP4; index += blockDim.x) + for(int index = tid; index < 64 * 5 * 5 / ElementPerFP4; index += BlockSize) { reinterpret_cast(shmem_weight)[index] = reinterpret_cast(weight_ptr)[index]; @@ -232,100 +236,94 @@ __global__ void block_sync_lds(); - int batch_iter = blockDim.x / warpSize; int local_grp_id = tid % warpSize; int glb_batch_offset = (blockIdx.x % blockNumPerGroup) * BatchPerBlock; int wave_id = __builtin_amdgcn_readfirstlane(tid / warpSize); - for(int batch_idx = 0; batch_idx < batch_num; batch_idx += batch_iter) + int batch_id_in_glb_mem = wave_id + glb_batch_offset; + const int base_outgrad_offset = (grp_idx + local_grp_id) * outgrad_group_stride + + batch_id_in_glb_mem * outgrad_batch_stride; + const int base_ingrad_offset = + (grp_idx + local_grp_id) * ingrad_group_stride + batch_id_in_glb_mem * ingrad_batch_stride; + const int base_filter_offset = local_grp_id * 25; + + for(int h_idx = 0; h_idx < in_height; h_idx += 2) { - int batch_id_in_glb_mem = wave_id + batch_idx + glb_batch_offset; - const int base_outgrad_offset = (grp_idx + local_grp_id) * outgrad_group_stride + - batch_id_in_glb_mem * outgrad_batch_stride; - const int base_ingrad_offset = (grp_idx + local_grp_id) * ingrad_group_stride + - batch_id_in_glb_mem * ingrad_batch_stride; - const int base_filter_offset = local_grp_id * 25; - - for(int h_idx = 0; h_idx < in_height; h_idx += 2) + for(int w_idx = 0; w_idx < in_width; w_idx += 2) { - for(int w_idx = 0; w_idx < in_width; w_idx += 2) + float sum[4]{0.f}; + const int out_row_start_0 = __builtin_amdgcn_readfirstlane( + max(0, (h_idx - filter_height + pad_height + stride_y) / stride_y)); + const int out_row_end_0 = __builtin_amdgcn_readfirstlane( + min(out_height - 1, (h_idx + pad_height) / stride_y)); + const int out_row_start_1 = __builtin_amdgcn_readfirstlane( + max(0, (h_idx + 1 - filter_height + pad_height + stride_y) / stride_y)); + const int out_row_end_1 = __builtin_amdgcn_readfirstlane( + min(out_height - 1, (h_idx + 1 + pad_height) / stride_y)); + const int out_col_start_0 = __builtin_amdgcn_readfirstlane( + max(0, (w_idx - filter_width + pad_width + stride_x) / stride_x)); + const int out_col_start_1 = __builtin_amdgcn_readfirstlane( + max(0, (w_idx + 1 - filter_width + pad_width + stride_x) / stride_x)); + const int out_col_end_0 = + __builtin_amdgcn_readfirstlane(min(out_width - 1, (w_idx + pad_width) / stride_x)); + const int out_col_end_1 = __builtin_amdgcn_readfirstlane( + min(out_width - 1, (w_idx + 1 + pad_width) / stride_x)); + + for(int out_row = out_row_start_0; out_row <= out_row_end_1; ++out_row) { - float sum[4]{0.f}; - const int out_row_start_0 = __builtin_amdgcn_readfirstlane( - max(0, (h_idx - filter_height + pad_height + stride_y) / stride_y)); - const int out_row_end_0 = __builtin_amdgcn_readfirstlane( - min(out_height - 1, (h_idx + pad_height) / stride_y)); - const int out_row_start_1 = __builtin_amdgcn_readfirstlane( - max(0, (h_idx + 1 - filter_height + pad_height + stride_y) / stride_y)); - const int out_row_end_1 = __builtin_amdgcn_readfirstlane( - min(out_height - 1, (h_idx + 1 + pad_height) / stride_y)); - const int out_col_start_0 = __builtin_amdgcn_readfirstlane( - max(0, (w_idx - filter_width + pad_width + stride_x) / stride_x)); - const int out_col_start_1 = __builtin_amdgcn_readfirstlane( - max(0, (w_idx + 1 - filter_width + pad_width + stride_x) / stride_x)); - const int out_col_end_0 = __builtin_amdgcn_readfirstlane( - min(out_width - 1, (w_idx + pad_width) / stride_x)); - const int out_col_end_1 = __builtin_amdgcn_readfirstlane( - min(out_width - 1, (w_idx + 1 + pad_width) / stride_x)); - - for(int out_row = out_row_start_0; out_row <= out_row_end_1; ++out_row) + const int filter_row_0 = + __builtin_amdgcn_readfirstlane(h_idx + pad_height - out_row * stride_y); + const int filter_row_1 = + __builtin_amdgcn_readfirstlane(h_idx + 1 + pad_height - out_row * stride_y); + for(int out_col = out_col_start_0; out_col <= out_col_end_1; ++out_col) { - const int filter_row_0 = - __builtin_amdgcn_readfirstlane(h_idx + pad_height - out_row * stride_y); - const int filter_row_1 = - __builtin_amdgcn_readfirstlane(h_idx + 1 + pad_height - out_row * stride_y); - for(int out_col = out_col_start_0; out_col <= out_col_end_1; ++out_col) - { - const int filter_col_0 = - __builtin_amdgcn_readfirstlane(w_idx + pad_width - out_col * stride_x); - const int filter_col_1 = __builtin_amdgcn_readfirstlane( - w_idx + 1 + pad_width - out_col * stride_x); - const int outgrad_offset = base_outgrad_offset + - out_row * outgrad_row_stride + - out_col * outgrad_col_stride; + const int filter_col_0 = + __builtin_amdgcn_readfirstlane(w_idx + pad_width - out_col * stride_x); + const int filter_col_1 = + __builtin_amdgcn_readfirstlane(w_idx + 1 + pad_width - out_col * stride_x); + const int outgrad_offset = base_outgrad_offset + out_row * outgrad_row_stride + + out_col * outgrad_col_stride; - ABDataType gradOut = p_gradOut[outgrad_offset]; + ABDataType gradOut = p_gradOut[outgrad_offset]; - bool row_in_axis0 = (out_row <= out_row_end_0); - bool row_in_axis1 = (out_row >= out_row_start_1); - bool col_in_axis0 = (out_col <= out_col_end_0); - bool col_in_axis1 = (out_col >= out_col_start_1); + bool row_in_axis0 = (out_row <= out_row_end_0); + bool row_in_axis1 = (out_row >= out_row_start_1); + bool col_in_axis0 = (out_col <= out_col_end_0); + bool col_in_axis1 = (out_col >= out_col_start_1); - sum[0] += ((row_in_axis0 && col_in_axis0) - ? shmem_weight[base_filter_offset + filter_row_0 * 5 + - filter_col_0] * - gradOut - : 0.f); - sum[1] += ((row_in_axis0 && col_in_axis1) - ? shmem_weight[base_filter_offset + filter_row_0 * 5 + - filter_col_1] * - gradOut - : 0.f); - sum[2] += ((row_in_axis1 && col_in_axis0) - ? shmem_weight[base_filter_offset + filter_row_1 * 5 + - filter_col_0] * - gradOut - : 0.f); - sum[3] += ((row_in_axis1 && col_in_axis1) - ? shmem_weight[base_filter_offset + filter_row_1 * 5 + - filter_col_1] * - gradOut - : 0.f); - } + sum[0] += + ((row_in_axis0 && col_in_axis0) + ? shmem_weight[base_filter_offset + filter_row_0 * 5 + filter_col_0] * + gradOut + : 0.f); + sum[1] += + ((row_in_axis0 && col_in_axis1) + ? shmem_weight[base_filter_offset + filter_row_0 * 5 + filter_col_1] * + gradOut + : 0.f); + sum[2] += + ((row_in_axis1 && col_in_axis0) + ? shmem_weight[base_filter_offset + filter_row_1 * 5 + filter_col_0] * + gradOut + : 0.f); + sum[3] += + ((row_in_axis1 && col_in_axis1) + ? shmem_weight[base_filter_offset + filter_row_1 * 5 + filter_col_1] * + gradOut + : 0.f); } + } #pragma unroll - for(int i = 0; i < 2; i++) + for(int i = 0; i < 2; i++) + { +#pragma unroll + for(int j = 0; j < 2; j++) { -#pragma unroll - for(int j = 0; j < 2; j++) - { - const int output_offset = base_ingrad_offset + - (h_idx + i) * ingrad_row_stride + - (w_idx + j) * ingrad_col_stride; - p_gradIn[output_offset] = __float2half(sum[i * 2 + j]); - } + const int output_offset = base_ingrad_offset + (h_idx + i) * ingrad_row_stride + + (w_idx + j) * ingrad_col_stride; + p_gradIn[output_offset] = __float2half(sum[i * 2 + j]); } } } @@ -1117,15 +1115,17 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 // constexpr bool has_main_loop = has_main_k_block_loop.value; constexpr index_t GroupPerBlock = 64; constexpr index_t BatchPerBlock = 8; + constexpr index_t BlockDim = 512; const auto kernel = kernel_grouped_conv_bwd_data_optimized; + BatchPerBlock, + BlockDim>; return launch_and_time_kernel(stream_config, kernel, dim3(1344 / GroupPerBlock * 16, 1, 1), - dim3(512), + dim3(BlockDim), 0, p_a_grid, p_b_grid,