diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index 260c1ff160..5ae2ee1182 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -197,12 +197,14 @@ __global__ void const index_t out_height, const index_t in_width, const index_t in_height, - const index_t group_num) + const index_t group_num, + const index_t whole_batch_num, + const index_t filter_size) { - constexpr int blockNumPerGroup = 128 / BatchPerBlock; - int grp_idx = GroupPerBlock * (blockIdx.x / blockNumPerGroup); - const ABDataType* weight_ptr = p_weight + grp_idx * 25; - int tid = threadIdx.x; + const int blockNumPerGroup = whole_batch_num / BatchPerBlock; + int grp_idx = GroupPerBlock * (blockIdx.x / blockNumPerGroup); + const ABDataType* weight_ptr = p_weight + grp_idx * filter_size; + int tid = threadIdx.x; const int filter_height = filter_y; const int filter_width = filter_x; @@ -224,11 +226,11 @@ __global__ void const int ingrad_row_stride = group_num * in_width; const int ingrad_col_stride = group_num; - __shared__ ABDataType shmem_weight[64 * 5 * 5]; + extern __shared__ ABDataType shmem_weight[]; constexpr index_t ElementPerFP4 = 16 / sizeof(ABDataType); - for(int index = tid; index < 64 * 5 * 5 / ElementPerFP4; index += BlockSize) + for(int index = tid; index < 64 * filter_size / ElementPerFP4; index += BlockSize) { reinterpret_cast(shmem_weight)[index] = reinterpret_cast(weight_ptr)[index]; @@ -247,7 +249,7 @@ __global__ void batch_id_in_glb_mem * outgrad_batch_stride; const int base_ingrad_offset = (grp_idx + local_grp_id) * ingrad_group_stride + batch_id_in_glb_mem * ingrad_batch_stride; - const int base_filter_offset = local_grp_id * 25; + const int base_filter_offset = local_grp_id * filter_size; for(int h_idx = 0; h_idx < in_height; h_idx += 2) { @@ -291,8 +293,10 @@ __global__ void bool col_in_axis0 = (out_col <= out_col_end_0); bool col_in_axis1 = (out_col >= out_col_start_1); - const int filter_offset0 = base_filter_offset + filter_row_0 * 5 + filter_col_0; - const int filter_offset1 = base_filter_offset + filter_row_1 * 5 + filter_col_0; + const int filter_offset0 = + base_filter_offset + filter_row_0 * filter_width + filter_col_0; + const int filter_offset1 = + base_filter_offset + filter_row_1 * filter_width + filter_col_0; sum[0] += ((row_in_axis0 && col_in_axis0) ? shmem_weight[filter_offset0] * gradOut @@ -314,15 +318,23 @@ __global__ void #pragma unroll for(int j = 0; j < 2; j++) { - const int output_offset = base_ingrad_offset + (h_idx + i) * ingrad_row_stride + - (w_idx + j) * ingrad_col_stride; - p_gradIn[output_offset] = __float2half(sum[i * 2 + j]); + if((h_idx + i < in_height) && (w_idx + j < in_width)) + { + const int output_offset = base_ingrad_offset + + (h_idx + i) * ingrad_row_stride + + (w_idx + j) * ingrad_col_stride; + p_gradIn[output_offset] = __float2half(sum[i * 2 + j]); + } } } } } } +// __global__ void kernel_grouped_conv_bwd_data_optimized_small() +// { + +// } } // namespace // Conv backward data multiple D: @@ -1114,26 +1126,32 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 GroupPerBlock, BatchPerBlock, BlockDim>; - - return launch_and_time_kernel(stream_config, - kernel, - dim3(1344 / GroupPerBlock * 16, 1, 1), - dim3(BlockDim), - 0, - p_a_grid, - p_b_grid, - p_e_grid, - 5, - 5, - 1, - 1, - 2, - 2, - 14, - 14, - 14, - 14, - 1344); + return launch_and_time_kernel( + stream_config, + kernel, + dim3(arg.e_g_n_c_wis_lengths_[0] / GroupPerBlock * 128 / BatchPerBlock, + 1, + 1), + dim3(BlockDim), + GroupPerBlock * arg.b_g_k_c_xs_lengths_[NDimSpatial + 1] * + arg.b_g_k_c_xs_lengths_[NDimSpatial + 2] * sizeof(ADataType), + p_a_grid, + p_b_grid, + p_e_grid, + arg.b_g_k_c_xs_lengths_[NDimSpatial + 1], + arg.b_g_k_c_xs_lengths_[NDimSpatial + 2], + arg.conv_filter_strides_[0], + arg.conv_filter_strides_[1], + arg.input_left_pads_[0], + arg.input_left_pads_[1], + arg.a_g_n_k_wos_lengths_[NDimSpatial + 1], + arg.a_g_n_k_wos_lengths_[NDimSpatial + 2], + arg.e_g_n_c_wis_lengths_[NDimSpatial + 1], + arg.e_g_n_c_wis_lengths_[NDimSpatial + 2], + arg.a_g_n_k_wos_lengths_[0], + arg.a_g_n_k_wos_lengths_[1], + arg.b_g_k_c_xs_lengths_[NDimSpatial + 1] * + arg.b_g_k_c_xs_lengths_[NDimSpatial + 2]); // const auto kernel = kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle< // GridwiseGemm, // ADataType, // TODO: distiguish A/B datatype