From 664f2b6251dc6a4a18ae2f2fa3f2451e396192a0 Mon Sep 17 00:00:00 2001 From: Shaojie WANG Date: Fri, 3 Jun 2022 03:06:42 +0800 Subject: [PATCH] use old ctile to avoid conv2d fwd bias relu add compute error (#271) [ROCm/composable_kernel commit: 1c5d06f270e1d091e1831a16c3e94ee425e15293] --- .../conv2d_fwd_xdl_bias_relu_add.cpp | 8 +++---- ...fle_bias_activation_add_nhwc_kyxc_nhwk.hpp | 23 +++++++------------ .../gpu/grid/gridwise_gemm_xdlops_v3r3.hpp | 2 +- 3 files changed, 13 insertions(+), 20 deletions(-) diff --git a/example/07_conv2d_fwd_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp b/example/07_conv2d_fwd_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp index 53d882778a..1a234ea851 100644 --- a/example/07_conv2d_fwd_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp +++ b/example/07_conv2d_fwd_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp @@ -224,10 +224,10 @@ int main(int argc, char* argv[]) { case 0: break; case 1: - input.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - weights.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - bias.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - residual.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + input.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + weights.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + bias.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + residual.GenerateTensorValue(GeneratorTensor_2{-2, 2}); break; default: input.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); diff --git a/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp index 85063443c1..cc1c2cb2ca 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp @@ -460,6 +460,8 @@ struct using C0GridDesc_M_N = remove_cvref_t; using C1GridDesc_M_N = remove_cvref_t; + using Block2CTileMap = BlockToCTileMap_M00_N0_M01; + // GridwiseGemm using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3< BlockSize, @@ -522,8 +524,6 @@ struct std::vector conv_filter_dilations, std::vector input_left_pads, std::vector input_right_pads, - ck::index_t M01, - ck::index_t N01, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op) @@ -540,10 +540,7 @@ struct c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, - block_2_ctile_map_{ - GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)}, - M01_{M01}, - N01_{N01}, + block_2_ctile_map_{}, in_element_op_{in_element_op}, wei_element_op_{wei_element_op}, out_element_op_{out_element_op}, @@ -576,6 +573,8 @@ struct c0_grid_desc_m_n_ = descs[I3]; c1_grid_desc_m_n_ = descs[I4]; + block_2_ctile_map_ = Block2CTileMap{c_grid_desc_m_n_}; + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, @@ -618,9 +617,7 @@ struct typename GridwiseGemm:: C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; - typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; - index_t M01_; - index_t N01_; + Block2CTileMap block_2_ctile_map_; InElementwiseOperation in_element_op_; WeiElementwiseOperation wei_element_op_; OutElementwiseOperation out_element_op_; @@ -723,7 +720,7 @@ struct InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation, - remove_reference_t, + Block2CTileMap, true>; ave_time = launch_and_time_kernel( @@ -767,7 +764,7 @@ struct InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation, - remove_reference_t, + Block2CTileMap, false>; ave_time = launch_and_time_kernel( @@ -894,8 +891,6 @@ struct conv_filter_dilations, input_left_pads, input_right_pads, - 1, - 1, in_element_op, wei_element_op, out_element_op}; @@ -938,8 +933,6 @@ struct conv_filter_dilations, input_left_pads, input_right_pads, - 1, - 1, in_element_op, wei_element_op, out_element_op); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp index 745dfde0ba..2e324faf13 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp @@ -340,7 +340,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 using DefaultBlock2CTileMap = remove_cvref_t; - template + template __device__ static void Run(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid,