Improved performance of double complex AXPYV kernel

- Increased the unroll factor by reusing the X registers that were
  previously used for the shuffle (see the sketch below).
- Added loops with smaller increment steps for better problem
  decomposition.
- Added X vector and Y vector prefetch to the kernel.
- Removed the redundant code that handled the fringe cases when incx = 1
  and incy = 1. That remainder is now computed by the loop that also
  handles the non-unit stride cases.
- Vectorized loops that handle non-unit stride cases using SSE
  instructions.
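
For reference, every loop in this kernel implements the same
permute-plus-two-FMA complex update. The sketch below is illustrative
only (the helper name and lane comments are not from the kernel) and
shows one register's worth of the BLIS_NO_CONJUGATE case:

#include <immintrin.h>

// Sketch: y := alpha*x + y for two double-complex elements held in one
// __m256d register, laid out as { xR0, xI0, xR1, xI1 }.
static inline void zaxpy_two_elements( const double* x, double* y,
                                       double alphaR, double alphaI )
{
    __m256d alphaRv = _mm256_broadcast_sd( &alphaR );      // aR aR aR aR
    __m256d alphaIv = _mm256_set_pd( alphaI, -alphaI,
                                     alphaI, -alphaI );    // -aI aI -aI aI
    __m256d xv = _mm256_loadu_pd( x );                     // xR0 xI0 xR1 xI1
    __m256d yv = _mm256_loadu_pd( y );

    yv = _mm256_fmadd_pd( xv, alphaRv, yv ); // yR += aR.xR ; yI += aR.xI
    xv = _mm256_permute_pd( xv, 5 );         // swap pairs: xI0 xR0 xI1 xR1
    yv = _mm256_fmadd_pd( xv, alphaIv, yv ); // yR -= aI.xI ; yI += aI.xR

    _mm256_storeu_pd( y, yv );
}

The kernel unrolls this pattern over seven registers (14 complex elements
per iteration) and then falls through 10-, 6-, 4- and 2-element versions
before the SSE remainder loop.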

AMD-Internal: [CPUPL-2773]
Change-Id: Ifb5dc128e17b4e21315789bfaa147e3a7ec976f0
Authored by Harihara Sudhan S on 2023-02-09 16:23:25 +05:30
Committed by HariharaSudhan S
parent 222e00e840
commit d4901f53ce


@@ -4,7 +4,7 @@
An object-based framework for developing high-performance BLAS-like
libraries.
-Copyright (C) 2016 - 2019, Advanced Micro Devices, Inc.
+Copyright (C) 2016 - 2023, Advanced Micro Devices, Inc. All rights reserved.
Copyright (C) 2018 - 2020, The University of Texas at Austin. All rights reserved.
Redistribution and use in source and binary forms, with or without
@@ -927,58 +927,58 @@ void bli_zaxpyv_zen_int5
{
AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4)
-const dim_t n_elem_per_reg = 4;
-dim_t i;
-double* restrict x0;
-double* restrict y0;
-double* restrict alpha0;
-double alphaR, alphaI;
-__m256d alphaRv; // for broadcast vector aR (real part of alpha)
-__m256d alphaIv; // for broadcast vector aI (imaginary part of alpha)
-__m256d xv[5];
-__m256d xShufv[5];
-__m256d yv[5];
-conj_t conjx_use = conjx;
// If the vector dimension is zero, or if alpha is zero, return early.
if ( bli_zero_dim1( n ) || PASTEMAC(z,eq0)( *alpha ) )
{
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4)
return;
}
-// Initialize local pointers.
-x0 = (double*)x;
-y0 = (double*)y;
-alpha0 = (double*)alpha;
-alphaR = alpha->real;
-alphaI = alpha->imag;
+dim_t i = 0;
+// Initialize local pointers.
+double* x0 = (double*)x;
+double* y0 = (double*)y;
+double* alpha0 = (double*)alpha;
+double alphaR = alpha->real;
+double alphaI = alpha->imag;
if ( incx == 1 && incy == 1 )
{
+const dim_t n_elem_per_reg = 4;
+__m256d alphaRv; // for broadcast vector aR (real part of alpha)
+__m256d alphaIv; // for broadcast vector aI (imaginary part of alpha)
+__m256d xv[7]; // Holds the X vector elements
+__m256d xShufv[5]; // Holds the permuted X vector elements
+__m256d yv[7]; // Holds the Y vector elements
+// Prefetch distance used in the kernel, chosen based on cycle count
+// (16 cycles in this case)
+const dim_t distance = 16;
+// Prefetch the X vector into cache, since these elements
+// will be needed in any case
+_mm_prefetch(x0, _MM_HINT_T1);
// Broadcast the alpha scalar to all elements of a vector register.
-if ( !bli_is_conj (conjx) ) // If BLIS_NO_CONJUGATE
+if ( bli_is_noconj(conjx) ) // If BLIS_NO_CONJUGATE
{
alphaRv = _mm256_broadcast_sd(&alphaR);
alphaIv[0] = -alphaI;
alphaIv[1] = alphaI;
alphaIv[2] = -alphaI;
alphaIv[3] = alphaI;
}
else
{
alphaIv = _mm256_broadcast_sd(&alphaI);
alphaRv[0] = alphaR;
alphaRv[1] = -alphaR;
alphaRv[2] = alphaR;
alphaRv[3] = -alphaR;
}
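
As a sanity check on these sign patterns, the scalar model below reproduces
both conjugation cases; it is a standalone illustration (model() is a
hypothetical helper, not kernel code):

#include <assert.h>
#include <math.h>

// Scalar model of the two FMA passes: pass 1 multiplies { xR, xI } by
// alphaRv, pass 2 multiplies the swapped { xI, xR } by alphaIv.
static void model( double aR, double aI, double xR, double xI,
                   double* yR, double* yI, int conjx )
{
    double aRv0 = aR, aRv1 = conjx ? -aR : aR;   // alphaRv lanes
    double aIv0 = conjx ? aI : -aI, aIv1 = aI;   // alphaIv lanes
    *yR += aRv0 * xR + aIv0 * xI;                // real lane
    *yI += aRv1 * xI + aIv1 * xR;                // imaginary lane
}

int main( void )
{
    double aR = 1.5, aI = -2.0, xR = 0.5, xI = 3.0, yR = 0.0, yI = 0.0;
    model( aR, aI, xR, xI, &yR, &yI, 0 );        // y += alpha * x
    assert( fabs( yR - (aR*xR - aI*xI) ) < 1e-12 );
    assert( fabs( yI - (aR*xI + aI*xR) ) < 1e-12 );
    yR = yI = 0.0;
    model( aR, aI, xR, xI, &yR, &yI, 1 );        // y += alpha * conj(x)
    assert( fabs( yR - (aR*xR + aI*xI) ) < 1e-12 );
    assert( fabs( yI - (aI*xR - aR*xI) ) < 1e-12 );
    return 0;
}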
@@ -1023,7 +1023,77 @@ void bli_zaxpyv_zen_int5
// step 3 : fma : yv = ai*xv' + yv (old)
//                yv = ai*xv' + ar*xv + yv
-for ( i = 0; (i + 9) < n; i += 10 )
+for ( i = 0; (i + 13) < n; i += 14 )
{
// 14 elements will be processed per loop; 14 FMAs will run per loop.
// alphaRv = aR aR aR aR
// xv = xR1 xI1 xR2 xI2
xv[0] = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg);
xv[1] = _mm256_loadu_pd(x0 + 1 * n_elem_per_reg);
xv[2] = _mm256_loadu_pd(x0 + 2 * n_elem_per_reg);
xv[3] = _mm256_loadu_pd(x0 + 3 * n_elem_per_reg);
xv[4] = _mm256_loadu_pd(x0 + 4 * n_elem_per_reg);
xv[5] = _mm256_loadu_pd(x0 + 5 * n_elem_per_reg);
xv[6] = _mm256_loadu_pd(x0 + 6 * n_elem_per_reg);
// yv = yR1 yI1 yR2 yI2
yv[0] = _mm256_loadu_pd(y0 + 0 * n_elem_per_reg);
yv[1] = _mm256_loadu_pd(y0 + 1 * n_elem_per_reg);
yv[2] = _mm256_loadu_pd(y0 + 2 * n_elem_per_reg);
yv[3] = _mm256_loadu_pd(y0 + 3 * n_elem_per_reg);
yv[4] = _mm256_loadu_pd(y0 + 4 * n_elem_per_reg);
yv[5] = _mm256_loadu_pd(y0 + 5 * n_elem_per_reg);
yv[6] = _mm256_loadu_pd(y0 + 6 * n_elem_per_reg);
// yv = ar*xv + yv
// = aR.xR1 + yR1, aR.xI1 + yI1, aR.xR2 + yR2, aR.xI2 + yI2, ...
yv[0] = _mm256_fmadd_pd(xv[0], alphaRv, yv[0]);
yv[1] = _mm256_fmadd_pd(xv[1], alphaRv, yv[1]);
yv[2] = _mm256_fmadd_pd(xv[2], alphaRv, yv[2]);
yv[3] = _mm256_fmadd_pd(xv[3], alphaRv, yv[3]);
yv[4] = _mm256_fmadd_pd(xv[4], alphaRv, yv[4]);
yv[5] = _mm256_fmadd_pd(xv[5], alphaRv, yv[5]);
yv[6] = _mm256_fmadd_pd(xv[6], alphaRv, yv[6]);
// xv' = xI1 xR1 xI2 xR2
xv[0] = _mm256_permute_pd(xv[0], 5);
xv[1] = _mm256_permute_pd(xv[1], 5);
xv[2] = _mm256_permute_pd(xv[2], 5);
xv[3] = _mm256_permute_pd(xv[3], 5);
xv[4] = _mm256_permute_pd(xv[4], 5);
xv[5] = _mm256_permute_pd(xv[5], 5);
xv[6] = _mm256_permute_pd(xv[6], 5);
// Prefetch the X and Y vectors into cache
_mm_prefetch(x0 + distance, _MM_HINT_T1);
_mm_prefetch(y0 + distance, _MM_HINT_T1);
// alphaIv = -aI aI -aI aI
// yv = ai*xv' + yv
// = -aI.xI1 + aR.xR1 + yR1, aI.xR1 + aR.xI1 + yI1, ...
yv[0] = _mm256_fmadd_pd(xv[0], alphaIv, yv[0]);
yv[1] = _mm256_fmadd_pd(xv[1], alphaIv, yv[1]);
yv[2] = _mm256_fmadd_pd(xv[2], alphaIv, yv[2]);
yv[3] = _mm256_fmadd_pd(xv[3], alphaIv, yv[3]);
yv[4] = _mm256_fmadd_pd(xv[4], alphaIv, yv[4]);
yv[5] = _mm256_fmadd_pd(xv[5], alphaIv, yv[5]);
yv[6] = _mm256_fmadd_pd(xv[6], alphaIv, yv[6]);
// Store back the result
_mm256_storeu_pd((y0 + 0 * n_elem_per_reg), yv[0]);
_mm256_storeu_pd((y0 + 1 * n_elem_per_reg), yv[1]);
_mm256_storeu_pd((y0 + 2 * n_elem_per_reg), yv[2]);
_mm256_storeu_pd((y0 + 3 * n_elem_per_reg), yv[3]);
_mm256_storeu_pd((y0 + 4 * n_elem_per_reg), yv[4]);
_mm256_storeu_pd((y0 + 5 * n_elem_per_reg), yv[5]);
_mm256_storeu_pd((y0 + 6 * n_elem_per_reg), yv[6]);
x0 += 7 * n_elem_per_reg;
y0 += 7 * n_elem_per_reg;
}
for ( ; (i + 9) < n; i += 10 )
{
// 10 elements will be processed per loop; 10 FMAs will run per loop.
@@ -1079,6 +1149,48 @@ void bli_zaxpyv_zen_int5
y0 += 5*n_elem_per_reg;
}
for (; (i + 5) < n; i += 6)
{
// alphaRv = aR aR aR aR
// xv = xR1 xI1 xR2 xI2
xv[0] = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg);
xv[1] = _mm256_loadu_pd(x0 + 1 * n_elem_per_reg);
xv[2] = _mm256_loadu_pd(x0 + 2 * n_elem_per_reg);
// yv = yR1 yI1 yR2 yI2
yv[0] = _mm256_loadu_pd(y0 + 0 * n_elem_per_reg);
yv[1] = _mm256_loadu_pd(y0 + 1 * n_elem_per_reg);
yv[2] = _mm256_loadu_pd(y0 + 2 * n_elem_per_reg);
// xv' = xI1 xR1 xI2 xR2
xShufv[0] = _mm256_permute_pd(xv[0], 5);
xShufv[1] = _mm256_permute_pd(xv[1], 5);
xShufv[2] = _mm256_permute_pd(xv[2], 5);
// yv = ar*xv + yv
// = aR.xR1 + yR1, aR.xI1 + yI1, aR.xR2 + yR2, aR.xI2 + yI2, ...
yv[0] = _mm256_fmadd_pd(xv[0], alphaRv, yv[0]);
yv[1] = _mm256_fmadd_pd(xv[1], alphaRv, yv[1]);
yv[2] = _mm256_fmadd_pd(xv[2], alphaRv, yv[2]);
// alphaIv = -aI aI -aI aI
// yv = ai*xv' + yv (old)
// = ai*xv' + ar*xv + yv
// = -aI.xI1 + aR.xR1 + yR1, aI.xR1 + aR.xI1 + yI1, ...
yv[0] = _mm256_fmadd_pd(xShufv[0], alphaIv, yv[0]);
yv[1] = _mm256_fmadd_pd(xShufv[1], alphaIv, yv[1]);
yv[2] = _mm256_fmadd_pd(xShufv[2], alphaIv, yv[2]);
// Store back the result
_mm256_storeu_pd((y0 + 0 * n_elem_per_reg), yv[0]);
_mm256_storeu_pd((y0 + 1 * n_elem_per_reg), yv[1]);
_mm256_storeu_pd((y0 + 2 * n_elem_per_reg), yv[2]);
x0 += 3 * n_elem_per_reg;
y0 += 3 * n_elem_per_reg;
}
for ( ; (i + 3) < n; i += 4 )
{
// alphaRv = aR aR aR aR
@@ -1115,7 +1227,7 @@ void bli_zaxpyv_zen_int5
y0 += 2*n_elem_per_reg;
}
-for ( ; (i + 3) < n; i += 2 )
+for ( ; (i + 1) < n; i += 2 )
{
// alphaRv = aR aR aR aR
// xv = xR1 xI1 xR2 xI2
@@ -1151,72 +1263,45 @@ void bli_zaxpyv_zen_int5
// as soon as the n_left cleanup loop below if BLIS is compiled with
// -mfpmath=sse).
_mm256_zeroupper();
-if ( !bli_is_conj(conjx_use) ) // BLIS_NO_CONJUGATE
-{
-for ( ; (i + 0) < n; i += 1 )
-{
-// real part: ( aR.xR - aI.xI + yR )
-*y0 += *alpha0 * (*x0) - (*(alpha0 + 1)) * (*(x0+1));
-// imag part: ( aR.xI + aI.xR + yI )
-*(y0 + 1) += *alpha0 * (*(x0+1)) + (*(alpha0 + 1)) * (*x0);
-x0 += 2;
-y0 += 2;
-}
-}
-else // BLIS_CONJUGATE
-{
-for ( ; (i + 0) < n; i += 1 )
-{
-// real part: ( aR.xR + aI.xI + yR )
-*y0 += *alpha0 * (*x0) + (*(alpha0 + 1)) * (*(x0+1));
-// imag part: ( aI.xR - aR.xI + yI )
-*(y0 + 1) += (*(alpha0 + 1)) * (*x0) - (*alpha0) * (*(x0+1));
-x0 += 2;
-y0 += 2;
-}
-}
}
-else
-{
-const double alphar = *alpha0;
-const double alphai = *(alpha0 + 1);
-if ( !bli_is_conj(conjx_use) ) // BLIS_NO_CONJUGATE
-{
-for ( i = 0; i < n; ++i )
-{
-const double x0c = *x0;
-const double x1c = *( x0+1 );
-*y0 += alphar * x0c - alphai * x1c;
-*(y0 + 1) += alphar * x1c + alphai * x0c;
-x0 += incx * 2;
-y0 += incy * 2;
-}
-}
-else // BLIS_CONJUGATE
-{
-for ( i = 0; i < n; ++i )
-{
-const double x0c = *x0;
-const double x1c = *( x0+1 );
-*y0 += alphar * x0c + alphai * x1c;
-*(y0 + 1) += alphai * x0c - alphar * x1c;
-x0 += incx * 2;
-y0 += incy * 2;
-}
-}
-}
+/* Residual values are calculated here:
+   y0 += (alpha) * (x0) --> BLIS_NO_CONJUGATE
+   y0 += ( aR.xR - aI.xI + yR ) + ( aR.xI + aI.xR + yI )i
+   y0 += (alpha) * conjx(x0) --> BLIS_CONJUGATE
+   y0 += ( aR.xR + aI.xI + yR ) + ( aI.xR - aR.xI + yI )i */
+__m128d alpha_r, alpha_i, x_vec, y_vec;
+// Broadcast the alpha scalar to all elements of a vector register.
+if ( bli_is_noconj(conjx) ) // If BLIS_NO_CONJUGATE
+{
+alpha_r = _mm_set1_pd(alphaR);
+alpha_i[0] = -alphaI;
+alpha_i[1] = alphaI;
+}
+else
+{
+alpha_i = _mm_set1_pd(alphaI);
+alpha_r[0] = alphaR;
+alpha_r[1] = -alphaR;
+}
+/* This loop has two functions:
+   1. Acts as the fringe case when incx == 1 and incy == 1
+   2. Performs the complete computation when incx != 1 or incy != 1 */
+for ( ; i < n; ++i )
+{
+x_vec = _mm_loadu_pd(x0);
+y_vec = _mm_loadu_pd(y0);
+y_vec = _mm_fmadd_pd(x_vec, alpha_r, y_vec);
+x_vec = _mm_permute_pd(x_vec, 0b01);
+y_vec = _mm_fmadd_pd(x_vec, alpha_i, y_vec);
+_mm_storeu_pd(y0, y_vec);
+x0 += incx * 2;
+y0 += incy * 2;
+}
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4)
}
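
For validation, the whole kernel can be checked against a plain scalar
reference of the operation it computes; zaxpyv_ref below is a hypothetical
helper written for this purpose, not part of the commit:

#include <complex.h>

// Reference: y := alpha * (conjx ? conj(x) : x) + y over n double-complex
// elements, with arbitrary strides incx and incy (in complex elements).
static void zaxpyv_ref( int conjx, long n, double complex alpha,
                        const double complex* x, long incx,
                        double complex* y, long incy )
{
    for ( long i = 0; i < n; ++i )
    {
        double complex xi = x[ i * incx ];
        if ( conjx ) xi = conj( xi );
        y[ i * incy ] += alpha * xi;
    }
}

Comparing its output with bli_zaxpyv_zen_int5 for unit and non-unit strides
exercises both the AVX2 main loops and the common SSE fringe loop.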