From f23b8e636b337b5ea5901e82cce9ae75d0f44aca Mon Sep 17 00:00:00 2001 From: Vignesh Balasubramanian Date: Mon, 29 Jul 2024 17:23:14 +0530 Subject: [PATCH] AVX2 and AVX512 optimizations for DAXPYV - Removed some of the unrolling factors that affected the performance of AVX2 DAXPYV kernel. In addition to improving the current performance on sizes compatible to single-threaded runs, this will now perform better for tiny sizes as well since the overhead to reach the computation is less. - Updated the vector partitioning logic, by using bli_thread_range_sub( ... ), which ensures that there is no false sharing among multiple threads. - Updated the AOCL-DYNAMIC logic for the API, to include thresholds for zen4 and zen5 micro-architectures. AMD-Internal: [CPUPL-5514] Change-Id: Iee9edddac685334213cd6694421ab3df3547e930 --- frame/base/bli_rntm.c | 38 +++++++ frame/compat/bla_axpy_amd.c | 27 +++-- kernels/zen/1/bli_axpyv_zen_int10.c | 156 ++-------------------------- 3 files changed, 63 insertions(+), 158 deletions(-) diff --git a/frame/base/bli_rntm.c b/frame/base/bli_rntm.c index 51f6fe5ed..2c7d6019c 100644 --- a/frame/base/bli_rntm.c +++ b/frame/base/bli_rntm.c @@ -1779,7 +1779,45 @@ BLIS_INLINE void aocl_daxpyv_dynamic switch (arch_id) { case BLIS_ARCH_ZEN5: + + if ( n_elem <= 34000 ) + *nt_ideal = 1; + else if ( n_elem <= 82000 ) + *nt_ideal = 4; + else if ( n_elem <= 2330000 ) + *nt_ideal = 8; + else if ( n_elem <= 4250000 ) + *nt_ideal = 16; + else if ( n_elem <= 7000000 ) + *nt_ideal = 32; + else if ( n_elem <= 21300000 ) + *nt_ideal = 64; + else + // For sizes in this range, AOCL dynamic does not make any change + *nt_ideal = -1; + + break; + case BLIS_ARCH_ZEN4: + + if ( n_elem <= 11000 ) + *nt_ideal = 1; + else if ( n_elem <= 130000 ) + *nt_ideal = 4; + else if ( n_elem <= 2230000 ) + *nt_ideal = 8; + else if ( n_elem <= 3400000 ) + *nt_ideal = 16; + else if ( n_elem <= 9250000 ) + *nt_ideal = 32; + else if ( n_elem <= 15800000 ) + *nt_ideal = 64; + else + // For 
sizes in this range, AOCL dynamic does not make any change + *nt_ideal = -1; + + break; + case BLIS_ARCH_ZEN: case BLIS_ARCH_ZEN2: case BLIS_ARCH_ZEN3: diff --git a/frame/compat/bla_axpy_amd.c b/frame/compat/bla_axpy_amd.c index 49cd8a1e7..325c89fdb 100644 --- a/frame/compat/bla_axpy_amd.c +++ b/frame/compat/bla_axpy_amd.c @@ -397,26 +397,37 @@ void daxpy_blis_impl _Pragma("omp parallel num_threads(nt)") { - dim_t start, length; + dim_t start, end, length; + thrinfo_t thrinfo_vec; - // Get the thread ID - dim_t thread_id = omp_get_thread_num(); + // The block size is the minimum factor, whose multiple will ensure that only + // the vector code section is executed. Furthermore, for double datatype it corresponds + // to one cacheline size. + dim_t block_size = 8; // Get the actual number of threads spawned - dim_t nt_use = omp_get_num_threads(); + thrinfo_vec.n_way = omp_get_num_threads(); + + // Get the thread ID + thrinfo_vec.work_id = omp_get_thread_num(); /* Calculate the compute range for the current thread based on the actual number of threads spawned */ - bli_thread_vector_partition + + bli_thread_range_sub ( + &thrinfo_vec, n_elem, - nt_use, - &start, &length, - thread_id + block_size, + FALSE, + &start, + &end ); + length = end - start; + // Adjust the local pointer for computation double *x_thread_local = x0 + (start * incx0); double *y_thread_local = y0 + (start * incy0); diff --git a/kernels/zen/1/bli_axpyv_zen_int10.c b/kernels/zen/1/bli_axpyv_zen_int10.c index f557a95b6..691e1c111 100644 --- a/kernels/zen/1/bli_axpyv_zen_int10.c +++ b/kernels/zen/1/bli_axpyv_zen_int10.c @@ -360,9 +360,9 @@ BLIS_EXPORT_BLIS void bli_daxpyv_zen_int10 double* restrict y0 = y; __m256d alphav; - __m256d xv[13]; - __m256d yv[13]; - __m256d zv[13]; + __m256d xv[4]; + __m256d yv[4]; + __m256d zv[4]; // If the vector dimension is zero, or if alpha is zero, return early. 
if ( bli_zero_dim1( n ) || PASTEMAC(d,eq0)( *alpha ) ) @@ -380,151 +380,7 @@ BLIS_EXPORT_BLIS void bli_daxpyv_zen_int10 // Broadcast the alpha scalar to all elements of a vector register. alphav = _mm256_broadcast_sd( alpha ); - for (i = 0; (i + 51) < n; i += 52) - { - // 52 elements will be processed per loop; 13 FMAs will run per loop. - xv[0] = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg); - xv[1] = _mm256_loadu_pd(x0 + 1 * n_elem_per_reg); - xv[2] = _mm256_loadu_pd(x0 + 2 * n_elem_per_reg); - xv[3] = _mm256_loadu_pd(x0 + 3 * n_elem_per_reg); - xv[4] = _mm256_loadu_pd(x0 + 4 * n_elem_per_reg); - xv[5] = _mm256_loadu_pd(x0 + 5 * n_elem_per_reg); - xv[6] = _mm256_loadu_pd(x0 + 6 * n_elem_per_reg); - xv[7] = _mm256_loadu_pd(x0 + 7 * n_elem_per_reg); - xv[8] = _mm256_loadu_pd(x0 + 8 * n_elem_per_reg); - xv[9] = _mm256_loadu_pd(x0 + 9 * n_elem_per_reg); - xv[10] = _mm256_loadu_pd(x0 + 10 * n_elem_per_reg); - xv[11] = _mm256_loadu_pd(x0 + 11 * n_elem_per_reg); - xv[12] = _mm256_loadu_pd(x0 + 12 * n_elem_per_reg); - - yv[0] = _mm256_loadu_pd(y0 + 0 * n_elem_per_reg); - yv[1] = _mm256_loadu_pd(y0 + 1 * n_elem_per_reg); - yv[2] = _mm256_loadu_pd(y0 + 2 * n_elem_per_reg); - yv[3] = _mm256_loadu_pd(y0 + 3 * n_elem_per_reg); - yv[4] = _mm256_loadu_pd(y0 + 4 * n_elem_per_reg); - yv[5] = _mm256_loadu_pd(y0 + 5 * n_elem_per_reg); - yv[6] = _mm256_loadu_pd(y0 + 6 * n_elem_per_reg); - yv[7] = _mm256_loadu_pd(y0 + 7 * n_elem_per_reg); - yv[8] = _mm256_loadu_pd(y0 + 8 * n_elem_per_reg); - yv[9] = _mm256_loadu_pd(y0 + 9 * n_elem_per_reg); - yv[10] = _mm256_loadu_pd(y0 + 10 * n_elem_per_reg); - yv[11] = _mm256_loadu_pd(y0 + 11 * n_elem_per_reg); - yv[12] = _mm256_loadu_pd(y0 + 12 * n_elem_per_reg); - - zv[0] = _mm256_fmadd_pd(xv[0], alphav, yv[0]); - zv[1] = _mm256_fmadd_pd(xv[1], alphav, yv[1]); - zv[2] = _mm256_fmadd_pd(xv[2], alphav, yv[2]); - zv[3] = _mm256_fmadd_pd(xv[3], alphav, yv[3]); - zv[4] = _mm256_fmadd_pd(xv[4], alphav, yv[4]); - zv[5] = _mm256_fmadd_pd(xv[5], alphav, 
yv[5]); - zv[6] = _mm256_fmadd_pd(xv[6], alphav, yv[6]); - zv[7] = _mm256_fmadd_pd(xv[7], alphav, yv[7]); - zv[8] = _mm256_fmadd_pd(xv[8], alphav, yv[8]); - zv[9] = _mm256_fmadd_pd(xv[9], alphav, yv[9]); - zv[10] = _mm256_fmadd_pd(xv[10], alphav, yv[10]); - zv[11] = _mm256_fmadd_pd(xv[11], alphav, yv[11]); - zv[12] = _mm256_fmadd_pd(xv[12], alphav, yv[12]); - - _mm256_storeu_pd((y0 + 0 * n_elem_per_reg), zv[0]); - _mm256_storeu_pd((y0 + 1 * n_elem_per_reg), zv[1]); - _mm256_storeu_pd((y0 + 2 * n_elem_per_reg), zv[2]); - _mm256_storeu_pd((y0 + 3 * n_elem_per_reg), zv[3]); - _mm256_storeu_pd((y0 + 4 * n_elem_per_reg), zv[4]); - _mm256_storeu_pd((y0 + 5 * n_elem_per_reg), zv[5]); - _mm256_storeu_pd((y0 + 6 * n_elem_per_reg), zv[6]); - _mm256_storeu_pd((y0 + 7 * n_elem_per_reg), zv[7]); - _mm256_storeu_pd((y0 + 8 * n_elem_per_reg), zv[8]); - _mm256_storeu_pd((y0 + 9 * n_elem_per_reg), zv[9]); - _mm256_storeu_pd((y0 + 10 * n_elem_per_reg), zv[10]); - _mm256_storeu_pd((y0 + 11 * n_elem_per_reg), zv[11]); - _mm256_storeu_pd((y0 + 12 * n_elem_per_reg), zv[12]); - - x0 += 13 * n_elem_per_reg; - y0 += 13 * n_elem_per_reg; - } - - for ( ; (i + 39) < n; i += 40 ) - { - // 40 elements will be processed per loop; 10 FMAs will run per loop. 
- xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); - xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); - xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); - xv[4] = _mm256_loadu_pd( x0 + 4*n_elem_per_reg ); - xv[5] = _mm256_loadu_pd( x0 + 5*n_elem_per_reg ); - xv[6] = _mm256_loadu_pd( x0 + 6*n_elem_per_reg ); - xv[7] = _mm256_loadu_pd( x0 + 7*n_elem_per_reg ); - xv[8] = _mm256_loadu_pd( x0 + 8*n_elem_per_reg ); - xv[9] = _mm256_loadu_pd( x0 + 9*n_elem_per_reg ); - - yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); - yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); - yv[2] = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); - yv[3] = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); - yv[4] = _mm256_loadu_pd( y0 + 4*n_elem_per_reg ); - yv[5] = _mm256_loadu_pd( y0 + 5*n_elem_per_reg ); - yv[6] = _mm256_loadu_pd( y0 + 6*n_elem_per_reg ); - yv[7] = _mm256_loadu_pd( y0 + 7*n_elem_per_reg ); - yv[8] = _mm256_loadu_pd( y0 + 8*n_elem_per_reg ); - yv[9] = _mm256_loadu_pd( y0 + 9*n_elem_per_reg ); - - zv[0] = _mm256_fmadd_pd( xv[0], alphav, yv[0] ); - zv[1] = _mm256_fmadd_pd( xv[1], alphav, yv[1] ); - zv[2] = _mm256_fmadd_pd( xv[2], alphav, yv[2] ); - zv[3] = _mm256_fmadd_pd( xv[3], alphav, yv[3] ); - zv[4] = _mm256_fmadd_pd( xv[4], alphav, yv[4] ); - zv[5] = _mm256_fmadd_pd( xv[5], alphav, yv[5] ); - zv[6] = _mm256_fmadd_pd( xv[6], alphav, yv[6] ); - zv[7] = _mm256_fmadd_pd( xv[7], alphav, yv[7] ); - zv[8] = _mm256_fmadd_pd( xv[8], alphav, yv[8] ); - zv[9] = _mm256_fmadd_pd( xv[9], alphav, yv[9] ); - - _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), zv[0] ); - _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), zv[1] ); - _mm256_storeu_pd( (y0 + 2*n_elem_per_reg), zv[2] ); - _mm256_storeu_pd( (y0 + 3*n_elem_per_reg), zv[3] ); - _mm256_storeu_pd( (y0 + 4*n_elem_per_reg), zv[4] ); - _mm256_storeu_pd( (y0 + 5*n_elem_per_reg), zv[5] ); - _mm256_storeu_pd( (y0 + 6*n_elem_per_reg), zv[6] ); - _mm256_storeu_pd( (y0 + 7*n_elem_per_reg), zv[7] ); - 
_mm256_storeu_pd( (y0 + 8*n_elem_per_reg), zv[8] ); - _mm256_storeu_pd( (y0 + 9*n_elem_per_reg), zv[9] ); - - x0 += 10*n_elem_per_reg; - y0 += 10*n_elem_per_reg; - } - - for ( ; (i + 19) < n; i += 20 ) - { - xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); - xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); - xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); - xv[4] = _mm256_loadu_pd( x0 + 4*n_elem_per_reg ); - - yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); - yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); - yv[2] = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); - yv[3] = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); - yv[4] = _mm256_loadu_pd( y0 + 4*n_elem_per_reg ); - - zv[0] = _mm256_fmadd_pd( xv[0], alphav, yv[0] ); - zv[1] = _mm256_fmadd_pd( xv[1], alphav, yv[1] ); - zv[2] = _mm256_fmadd_pd( xv[2], alphav, yv[2] ); - zv[3] = _mm256_fmadd_pd( xv[3], alphav, yv[3] ); - zv[4] = _mm256_fmadd_pd( xv[4], alphav, yv[4] ); - - _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), zv[0] ); - _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), zv[1] ); - _mm256_storeu_pd( (y0 + 2*n_elem_per_reg), zv[2] ); - _mm256_storeu_pd( (y0 + 3*n_elem_per_reg), zv[3] ); - _mm256_storeu_pd( (y0 + 4*n_elem_per_reg), zv[4] ); - - x0 += 5*n_elem_per_reg; - y0 += 5*n_elem_per_reg; - } - - for ( ; (i + 15) < n; i += 16 ) + for ( i = 0; ( i + 15 ) < n; i += 16 ) { xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); @@ -550,7 +406,7 @@ BLIS_EXPORT_BLIS void bli_daxpyv_zen_int10 y0 += 4*n_elem_per_reg; } - for ( ; i + 7 < n; i += 8 ) + for ( ; ( i + 7 ) < n; i += 8 ) { xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); @@ -568,7 +424,7 @@ BLIS_EXPORT_BLIS void bli_daxpyv_zen_int10 y0 += 2*n_elem_per_reg; } - for ( ; i + 3 < n; i += 4 ) + for ( ; ( i + 3 ) < n; i += 4 ) { xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg );