From d4901f53ce84df85a98383cbe26de06f75fa5c17 Mon Sep 17 00:00:00 2001
From: Harihara Sudhan S
Date: Thu, 9 Feb 2023 16:23:25 +0530
Subject: [PATCH] Improved performance of the double complex AXPYV kernel

- Increased the unroll factor by reusing the X registers that were
  previously used only for performing the shuffle.
- Added loops with smaller increment steps for better problem
  decomposition.
- Added X vector and Y vector prefetches to the kernel.
- Removed the redundant code that handled the fringe case for incx = 1
  and incy = 1. This remainder is now computed by the loop that handles
  the non-unit stride cases.
- Vectorized the loops that handle the non-unit stride cases using SSE
  instructions.

AMD-Internal: [CPUPL-2773]

Change-Id: Ifb5dc128e17b4e21315789bfaa147e3a7ec976f0
---
 kernels/zen/1/bli_axpyv_zen_int10.c | 273 ++++++++++++++++++----------
 1 file changed, 179 insertions(+), 94 deletions(-)

diff --git a/kernels/zen/1/bli_axpyv_zen_int10.c b/kernels/zen/1/bli_axpyv_zen_int10.c
index 4ef6981cd..a43ebb26f 100644
--- a/kernels/zen/1/bli_axpyv_zen_int10.c
+++ b/kernels/zen/1/bli_axpyv_zen_int10.c
@@ -4,7 +4,7 @@
    An object-based framework for developing high-performance BLAS-like
    libraries.
 
-   Copyright (C) 2016 - 2019, Advanced Micro Devices, Inc.
+   Copyright (C) 2016 - 2023, Advanced Micro Devices, Inc. All rights reserved.
    Copyright (C) 2018 - 2020, The University of Texas at Austin. All rights reserved.
 
    Redistribution and use in source and binary forms, with or without
@@ -927,58 +927,58 @@ void bli_zaxpyv_zen_int5
 {
     AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4)
 
-    const dim_t n_elem_per_reg = 4;
-
-    dim_t i;
-
-    double* restrict x0;
-    double* restrict y0;
-    double* restrict alpha0;
-
-    double alphaR, alphaI;
-
-    __m256d alphaRv;   // for braodcast vector aR (real part of alpha)
-    __m256d alphaIv;   // for braodcast vector aI (imaginary part of alpha)
-    __m256d xv[5];
-    __m256d xShufv[5];
-    __m256d yv[5];
-
-    conj_t conjx_use = conjx;
-
-    // If the vector dimension is zero, or if alpha is zero, return early.
+    // If the vector dimension is zero, or if alpha is zero, return early.
     if ( bli_zero_dim1( n ) || PASTEMAC(z,eq0)( *alpha ) )
     {
         AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4)
         return;
     }
 
-    // Initialize local pointers.
-    x0     = (double*)x;
-    y0     = (double*)y;
-    alpha0 = (double*)alpha;
+    dim_t i = 0;
 
-    alphaR = alpha->real;
-    alphaI = alpha->imag;
+    // Initialize local pointers.
+    double* x0     = (double*)x;
+    double* y0     = (double*)y;
+    double* alpha0 = (double*)alpha;
+
+    double alphaR = alpha->real;
+    double alphaI = alpha->imag;
 
     if ( incx == 1 && incy == 1 )
     {
+        const dim_t n_elem_per_reg = 4;
+
+        __m256d alphaRv;   // for broadcast vector aR (real part of alpha)
+        __m256d alphaIv;   // for broadcast vector aI (imaginary part of alpha)
+        __m256d xv[7];     // Holds the X vector elements
+        __m256d xShufv[5]; // Holds the permuted X vector elements
+        __m256d yv[7];     // Holds the Y vector elements
+
+        // Prefetch distance used in the kernel, based on the number of
+        // cycles; in this case, 16 cycles.
+        const dim_t distance = 16;
+
+        // Prefetch the X vector into the cache (T1 hint),
+        // as these elements will be needed anyway.
+        _mm_prefetch(x0, _MM_HINT_T1);
+
         // Broadcast the alpha scalar to all elements of a vector register.
-        if ( !bli_is_conj (conjx) ) // If BLIS_NO_CONJUGATE
+        if (bli_is_noconj(conjx)) // If BLIS_NO_CONJUGATE
         {
-            alphaRv = _mm256_broadcast_sd( &alphaR );
+            alphaRv = _mm256_broadcast_sd(&alphaR);
 
             alphaIv[0] = -alphaI;
-            alphaIv[1] = alphaI;
+            alphaIv[1] =  alphaI;
             alphaIv[2] = -alphaI;
-            alphaIv[3] = alphaI;
+            alphaIv[3] =  alphaI;
         }
         else
         {
-            alphaIv = _mm256_broadcast_sd( &alphaI );
+            alphaIv = _mm256_broadcast_sd(&alphaI);
 
-            alphaRv[0] = alphaR;
+            alphaRv[0] =  alphaR;
             alphaRv[1] = -alphaR;
-            alphaRv[2] = alphaR;
+            alphaRv[2] =  alphaR;
             alphaRv[3] = -alphaR;
         }
@@ -1023,7 +1023,77 @@ void bli_zaxpyv_zen_int5
             // step 3 : fma : yv = ai*xv' + yv (old)
             //                yv = ai*xv' + ar*xv + yv
 
-        for ( i = 0; (i + 9) < n; i += 10 )
+        for (i = 0; (i + 13) < n; i += 14)
+        {
+            // 14 elements will be processed per loop; 14 FMAs will run per loop.
+
+            // alphaRv = aR aR aR aR
+            // xv      = xR1 xI1 xR2 xI2
+            xv[0] = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg);
+            xv[1] = _mm256_loadu_pd(x0 + 1 * n_elem_per_reg);
+            xv[2] = _mm256_loadu_pd(x0 + 2 * n_elem_per_reg);
+            xv[3] = _mm256_loadu_pd(x0 + 3 * n_elem_per_reg);
+            xv[4] = _mm256_loadu_pd(x0 + 4 * n_elem_per_reg);
+            xv[5] = _mm256_loadu_pd(x0 + 5 * n_elem_per_reg);
+            xv[6] = _mm256_loadu_pd(x0 + 6 * n_elem_per_reg);
+
+            // yv = yR1 yI1 yR2 yI2
+            yv[0] = _mm256_loadu_pd(y0 + 0 * n_elem_per_reg);
+            yv[1] = _mm256_loadu_pd(y0 + 1 * n_elem_per_reg);
+            yv[2] = _mm256_loadu_pd(y0 + 2 * n_elem_per_reg);
+            yv[3] = _mm256_loadu_pd(y0 + 3 * n_elem_per_reg);
+            yv[4] = _mm256_loadu_pd(y0 + 4 * n_elem_per_reg);
+            yv[5] = _mm256_loadu_pd(y0 + 5 * n_elem_per_reg);
+            yv[6] = _mm256_loadu_pd(y0 + 6 * n_elem_per_reg);
+
+            // yv = ar*xv + yv
+            //    = aR.xR1 + yR1, aR.xI1 + yI1, aR.xR2 + yR2, aR.xI2 + yI2, ...
+            yv[0] = _mm256_fmadd_pd(xv[0], alphaRv, yv[0]);
+            yv[1] = _mm256_fmadd_pd(xv[1], alphaRv, yv[1]);
+            yv[2] = _mm256_fmadd_pd(xv[2], alphaRv, yv[2]);
+            yv[3] = _mm256_fmadd_pd(xv[3], alphaRv, yv[3]);
+            yv[4] = _mm256_fmadd_pd(xv[4], alphaRv, yv[4]);
+            yv[5] = _mm256_fmadd_pd(xv[5], alphaRv, yv[5]);
+            yv[6] = _mm256_fmadd_pd(xv[6], alphaRv, yv[6]);
+
+            // xv' = xI1 xR1 xI2 xR2
+            xv[0] = _mm256_permute_pd(xv[0], 5);
+            xv[1] = _mm256_permute_pd(xv[1], 5);
+            xv[2] = _mm256_permute_pd(xv[2], 5);
+            xv[3] = _mm256_permute_pd(xv[3], 5);
+            xv[4] = _mm256_permute_pd(xv[4], 5);
+            xv[5] = _mm256_permute_pd(xv[5], 5);
+            xv[6] = _mm256_permute_pd(xv[6], 5);
+
+            // Prefetch the X and Y vectors into the cache (T1 hint)
+            _mm_prefetch(x0 + distance, _MM_HINT_T1);
+            _mm_prefetch(y0 + distance, _MM_HINT_T1);
+
+            // alphaIv = -aI aI -aI aI
+
+            // yv = ai*xv' + yv (old)
+            //    = ai*xv' + ar*xv + yv
+            //    = -aI.xI1 + aR.xR1 + yR1, aI.xR1 + aR.xI1 + yI1, ...
+            yv[0] = _mm256_fmadd_pd(xv[0], alphaIv, yv[0]);
+            yv[1] = _mm256_fmadd_pd(xv[1], alphaIv, yv[1]);
+            yv[2] = _mm256_fmadd_pd(xv[2], alphaIv, yv[2]);
+            yv[3] = _mm256_fmadd_pd(xv[3], alphaIv, yv[3]);
+            yv[4] = _mm256_fmadd_pd(xv[4], alphaIv, yv[4]);
+            yv[5] = _mm256_fmadd_pd(xv[5], alphaIv, yv[5]);
+            yv[6] = _mm256_fmadd_pd(xv[6], alphaIv, yv[6]);
+
+            // Store back the result
+            _mm256_storeu_pd((y0 + 0 * n_elem_per_reg), yv[0]);
+            _mm256_storeu_pd((y0 + 1 * n_elem_per_reg), yv[1]);
+            _mm256_storeu_pd((y0 + 2 * n_elem_per_reg), yv[2]);
+            _mm256_storeu_pd((y0 + 3 * n_elem_per_reg), yv[3]);
+            _mm256_storeu_pd((y0 + 4 * n_elem_per_reg), yv[4]);
+            _mm256_storeu_pd((y0 + 5 * n_elem_per_reg), yv[5]);
+            _mm256_storeu_pd((y0 + 6 * n_elem_per_reg), yv[6]);
+
+            x0 += 7 * n_elem_per_reg;
+            y0 += 7 * n_elem_per_reg;
+        }
+
+        for ( ; (i + 9) < n; i += 10 )
         {
             // 10 elements will be processed per loop; 10 FMAs will run per loop.
@@ -1079,6 +1149,48 @@ void bli_zaxpyv_zen_int5
             y0 += 5*n_elem_per_reg;
         }
 
+        for (; (i + 5) < n; i += 6)
+        {
+            // alphaRv = aR aR aR aR
+            // xv      = xR1 xI1 xR2 xI2
+            xv[0] = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg);
+            xv[1] = _mm256_loadu_pd(x0 + 1 * n_elem_per_reg);
+            xv[2] = _mm256_loadu_pd(x0 + 2 * n_elem_per_reg);
+
+            // yv = yR1 yI1 yR2 yI2
+            yv[0] = _mm256_loadu_pd(y0 + 0 * n_elem_per_reg);
+            yv[1] = _mm256_loadu_pd(y0 + 1 * n_elem_per_reg);
+            yv[2] = _mm256_loadu_pd(y0 + 2 * n_elem_per_reg);
+
+            // xv' = xI1 xR1 xI2 xR2
+            xShufv[0] = _mm256_permute_pd(xv[0], 5);
+            xShufv[1] = _mm256_permute_pd(xv[1], 5);
+            xShufv[2] = _mm256_permute_pd(xv[2], 5);
+
+            // alphaIv = -aI aI -aI aI
+
+            // yv = ar*xv + yv
+            //    = aR.xR1 + yR1, aR.xI1 + yI1, aR.xR2 + yR2, aR.xI2 + yI2, ...
+            yv[0] = _mm256_fmadd_pd(xv[0], alphaRv, yv[0]);
+            yv[1] = _mm256_fmadd_pd(xv[1], alphaRv, yv[1]);
+            yv[2] = _mm256_fmadd_pd(xv[2], alphaRv, yv[2]);
+
+            // yv = ai*xv' + yv (old)
+            //    = ai*xv' + ar*xv + yv
+            //    = -aI.xI1 + aR.xR1 + yR1, aI.xR1 + aR.xI1 + yI1, ...
+            yv[0] = _mm256_fmadd_pd(xShufv[0], alphaIv, yv[0]);
+            yv[1] = _mm256_fmadd_pd(xShufv[1], alphaIv, yv[1]);
+            yv[2] = _mm256_fmadd_pd(xShufv[2], alphaIv, yv[2]);
+
+            // Store back the result
+            _mm256_storeu_pd((y0 + 0 * n_elem_per_reg), yv[0]);
+            _mm256_storeu_pd((y0 + 1 * n_elem_per_reg), yv[1]);
+            _mm256_storeu_pd((y0 + 2 * n_elem_per_reg), yv[2]);
+
+            x0 += 3 * n_elem_per_reg;
+            y0 += 3 * n_elem_per_reg;
+        }
+
         for ( ; (i + 3) < n; i += 4 )
         {
             // alphaRv = aR aR aR aR
@@ -1115,7 +1227,7 @@ void bli_zaxpyv_zen_int5
             y0 += 2*n_elem_per_reg;
         }
 
-        for ( ; (i + 3) < n; i += 2 )
+        for ( ; (i + 1) < n; i += 2 )
         {
             // alphaRv = aR aR aR aR
             // xv      = xR1 xI1 xR2 xI2
@@ -1151,72 +1263,45 @@ void bli_zaxpyv_zen_int5
         // as soon as the n_left cleanup loop below if BLIS is compiled with
         // -mfpmath=sse).
         _mm256_zeroupper();
+    }
 
-        /* Residual values are calculated here
-            y0 += (alpha) * (x0); --> BLIS_NO_CONJUGATE
-            y0 += ( aR.xR - aIxI + yR ) + ( aR.xI + aI.xR + yI )i
+    __m128d alpha_r, alpha_i, x_vec, y_vec;
 
-            y0 += (alpha) * conjx(x0); --> BLIS_CONJUGATE
-            y0 = ( aR.xR + aIxI + yR ) + (aI.xR - aR.xI + yI)i */
+    // Broadcast the alpha scalar to all elements of a vector register.
+    if (bli_is_noconj(conjx)) // If BLIS_NO_CONJUGATE
+    {
+        alpha_r = _mm_set1_pd(alphaR);
 
-        if ( !bli_is_conj(conjx_use) ) // BLIS_NO_CONJUGATE
-        {
-            for ( ; (i + 0) < n; i += 1 )
-            {
-                // real part: ( aR.xR - aIxI + yR )
-                *y0 += *alpha0 * (*x0) - (*(alpha0 + 1)) * (*(x0+1));
-                // img part: ( aR.xI + aI.xR + yI )
-                *(y0 + 1) += *alpha0 * (*(x0+1)) + (*(alpha0 + 1)) * (*x0);
-                x0 += 2;
-                y0 += 2;
-            }
-        }
+        alpha_i[0] = -alphaI;
+        alpha_i[1] =  alphaI;
-        else // BLIS_CONJUGATE
-        {
-            for ( ; (i + 0) < n; i += 1 )
-            {
-                // real part: ( aR.xR + aIxI + yR )
-                *y0 += *alpha0 * (*x0) + (*(alpha0 + 1)) * (*(x0+1));
-                // img part: ( aI.xR - aR.xI + yI )
-                *(y0 + 1) += (*(alpha0 + 1)) * (*x0) - (*alpha0) * (*(x0+1));
-                x0 += 2;
-                y0 += 2;
-            }
-        }
+    }
     else
     {
-        const double alphar = *alpha0;
-        const double alphai = *(alpha0 + 1);
+        alpha_i = _mm_set1_pd(alphaI);
 
-        if ( !bli_is_conj(conjx_use) ) // BLIS_NO_CONJUGATE
-        {
-            for ( i = 0; i < n; ++i )
-            {
-                const double x0c = *x0;
-                const double x1c = *( x0+1 );
-
-                *y0       += alphar * x0c - alphai * x1c;
-                *(y0 + 1) += alphar * x1c + alphai * x0c;
-
-                x0 += incx * 2;
-                y0 += incy * 2;
-            }
-        }
+        alpha_r[0] =  alphaR;
+        alpha_r[1] = -alphaR;
-        else // BLIS_CONJUGATE
-        {
-            for ( i = 0; i < n; ++i )
-            {
-                const double x0c = *x0;
-                const double x1c = *( x0+1 );
-
-                *y0       += alphar * x0c + alphai * x1c;
-                *(y0 + 1) += alphai * x0c - alphar * x1c;
-
-                x0 += incx * 2;
-                y0 += incy * 2;
-            }
-        }
     }
+
+    /* This loop serves two purposes:
+        1. Acts as the fringe case when incx == 1 and incy == 1
+        2. Performs the complete computation when incx != 1 or incy != 1
+    */
+    for (; i < n; ++i)
+    {
+        x_vec = _mm_loadu_pd(x0);
+        y_vec = _mm_loadu_pd(y0);
+
+        y_vec = _mm_fmadd_pd(x_vec, alpha_r, y_vec);
+        x_vec = _mm_permute_pd(x_vec, 0b01);
+        y_vec = _mm_fmadd_pd(x_vec, alpha_i, y_vec);
+
+        _mm_storeu_pd(y0, y_vec);
+
+        x0 += incx * 2;
+        y0 += incy * 2;
+    }
+
     AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4)
 }
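
Reviewer note (not part of the patch): both the AVX2 main loops and the new SSE fringe loop build the complex multiply-accumulate out of two FMAs per register by folding the sign of alpha's imaginary part into a { -aI, aI } vector, so no subtract or addsub instruction is needed. The standalone sketch below checks that scheme against a scalar reference for the BLIS_NO_CONJUGATE case. It is a minimal illustration, not BLIS code: the file name, test values, and main() harness are assumptions, while the intrinsics mirror the ones used in the patch (assumes AVX and FMA, e.g. built with gcc -O2 -mavx -mfma).

    /* zaxpy_trick_demo.c (hypothetical name): sketch of the
       permute/FMA scheme used by the kernel for y += alpha * x. */
    #include <immintrin.h>
    #include <stdio.h>

    int main(void)
    {
        // One double-complex element of x and y, stored as { real, imag }.
        double x[2] = { 1.5, -2.0 };
        double y[2] = { 0.5,  3.0 };
        double aR = 1.25, aI = -0.75;   // alpha = aR + aI*i

        // Scalar reference (BLIS_NO_CONJUGATE):
        //   yR += aR.xR - aI.xI,  yI += aR.xI + aI.xR
        double yR_ref = y[0] + aR * x[0] - aI * x[1];
        double yI_ref = y[1] + aR * x[1] + aI * x[0];

        // Vector version, mirroring the kernel's two FMAs per register.
        __m128d alpha_r = _mm_set1_pd(aR);       // {  aR, aR }
        __m128d alpha_i = _mm_setr_pd(-aI, aI);  // { -aI, aI }: sign folded in

        __m128d x_vec = _mm_loadu_pd(x);         // { xR, xI }
        __m128d y_vec = _mm_loadu_pd(y);         // { yR, yI }

        y_vec = _mm_fmadd_pd(x_vec, alpha_r, y_vec); // { aR.xR + yR, aR.xI + yI }
        x_vec = _mm_permute_pd(x_vec, 0b01);         // { xI, xR }
        y_vec = _mm_fmadd_pd(x_vec, alpha_i, y_vec); // { -aI.xI + ..., aI.xR + ... }

        double y_out[2];
        _mm_storeu_pd(y_out, y_vec);

        // Both lines should print the same pair.
        printf("scalar: (%f, %f)\nvector: (%f, %f)\n",
               yR_ref, yI_ref, y_out[0], y_out[1]);
        return 0;
    }

For the BLIS_CONJUGATE case the signs move from the alpha_i vector to the alpha_r vector ({ aR, -aR }), exactly as in the patch's setup branches.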