diff --git a/driver/include/device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw.hpp index ab41e325b3..c10c54793d 100644 --- a/driver/include/device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw.hpp @@ -55,13 +55,13 @@ void device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw(InDesc, #if 1 // 1x1 filter, 8x8 image - constexpr index_t N1 = 2; + constexpr index_t N0 = 1; constexpr index_t N2 = 1; - constexpr index_t Ho1 = 8; + constexpr index_t Ho0 = 1; constexpr index_t Ho2 = 1; - constexpr index_t Wo1 = 1; + constexpr index_t Wo0 = 2; constexpr index_t Wo2 = 4; constexpr index_t BlockSize = 256; @@ -105,6 +105,10 @@ void device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw(InDesc, constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; #endif + constexpr index_t N1 = N / (N0 * N2); + constexpr index_t Ho1 = Ho / (Ho0 * Ho2); + constexpr index_t Wo1 = Wo / (Wo0 * Wo2); + constexpr index_t B = N1 * Ho1 * Wo1; constexpr index_t GridSize =