mirror of
https://github.com/amd/blis.git
synced 2026-05-11 09:39:59 +00:00
Improved performance of double complex AXPYV kernel
- Increased unroll by reusing the X registers that were previously used for performing the shuffle. - Added loops with smaller increment steps for better problem decomposition. - Added X vector and Y vector prefetch to the kernel. - Removed redundant code that handles the fringe case when incx = 1 and incy = 1; this remainder is now handled by the loop that covers non-unit stride cases. - Vectorized the loops that handle non-unit stride cases using SSE instructions. AMD-Internal: [CPUPL-2773] Change-Id: Ifb5dc128e17b4e21315789bfaa147e3a7ec976f0
This commit is contained in:
committed by
HariharaSudhan S
parent
222e00e840
commit
d4901f53ce
@@ -4,7 +4,7 @@
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2016 - 2019, Advanced Micro Devices, Inc.
|
||||
Copyright (C) 2016 - 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
Copyright (C) 2018 - 2020, The University of Texas at Austin. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
@@ -927,58 +927,58 @@ void bli_zaxpyv_zen_int5
|
||||
{
|
||||
AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4)
|
||||
|
||||
const dim_t n_elem_per_reg = 4;
|
||||
|
||||
dim_t i;
|
||||
|
||||
double* restrict x0;
|
||||
double* restrict y0;
|
||||
double* restrict alpha0;
|
||||
|
||||
double alphaR, alphaI;
|
||||
|
||||
__m256d alphaRv; // for braodcast vector aR (real part of alpha)
|
||||
__m256d alphaIv; // for braodcast vector aI (imaginary part of alpha)
|
||||
__m256d xv[5];
|
||||
__m256d xShufv[5];
|
||||
__m256d yv[5];
|
||||
|
||||
conj_t conjx_use = conjx;
|
||||
|
||||
// If the vector dimension is zero, or if alpha is zero, return early.
|
||||
// If the vector dimension is zero, or if alpha is zero, return early.
|
||||
if ( bli_zero_dim1( n ) || PASTEMAC(z,eq0)( *alpha ) )
|
||||
{
|
||||
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4)
|
||||
return;
|
||||
}
|
||||
|
||||
// Initialize local pointers.
|
||||
x0 = (double*)x;
|
||||
y0 = (double*)y;
|
||||
alpha0 = (double*)alpha;
|
||||
dim_t i = 0;
|
||||
|
||||
alphaR = alpha->real;
|
||||
alphaI = alpha->imag;
|
||||
// Initialize local pointers.
|
||||
double* x0 = (double*)x;
|
||||
double* y0 = (double*)y;
|
||||
double* alpha0 = (double*)alpha;
|
||||
|
||||
double alphaR = alpha->real;
|
||||
double alphaI = alpha->imag;
|
||||
|
||||
if ( incx == 1 && incy == 1 )
|
||||
{
|
||||
const dim_t n_elem_per_reg = 4;
|
||||
|
||||
__m256d alphaRv; // for braodcast vector aR (real part of alpha)
|
||||
__m256d alphaIv; // for braodcast vector aI (imaginary part of alpha)
|
||||
__m256d xv[7]; // Holds the X vector elements
|
||||
__m256d xShufv[5]; // Holds the permuted X vector elements
|
||||
__m256d yv[7]; // Holds the y vector elements
|
||||
|
||||
// Prefetch distance used in the kernel based on number of cycles
|
||||
// In this case, 16 cycles
|
||||
const dim_t distance = 16;
|
||||
|
||||
// Prefetch X vector to the L1 cache
|
||||
// as these elements will be need anyway
|
||||
_mm_prefetch(x0, _MM_HINT_T1);
|
||||
|
||||
// Broadcast the alpha scalar to all elements of a vector register.
|
||||
if ( !bli_is_conj (conjx) ) // If BLIS_NO_CONJUGATE
|
||||
if (bli_is_noconj(conjx)) // If BLIS_NO_CONJUGATE
|
||||
{
|
||||
alphaRv = _mm256_broadcast_sd( &alphaR );
|
||||
alphaRv = _mm256_broadcast_sd(&alphaR);
|
||||
|
||||
alphaIv[0] = -alphaI;
|
||||
alphaIv[1] = alphaI;
|
||||
alphaIv[1] = alphaI;
|
||||
alphaIv[2] = -alphaI;
|
||||
alphaIv[3] = alphaI;
|
||||
alphaIv[3] = alphaI;
|
||||
}
|
||||
else
|
||||
{
|
||||
alphaIv = _mm256_broadcast_sd( &alphaI );
|
||||
alphaIv = _mm256_broadcast_sd(&alphaI);
|
||||
|
||||
alphaRv[0] = alphaR;
|
||||
alphaRv[0] = alphaR;
|
||||
alphaRv[1] = -alphaR;
|
||||
alphaRv[2] = alphaR;
|
||||
alphaRv[2] = alphaR;
|
||||
alphaRv[3] = -alphaR;
|
||||
}
|
||||
|
||||
@@ -1023,7 +1023,77 @@ void bli_zaxpyv_zen_int5
|
||||
// step 3 : fma :yv = ai*xv' + yv (old)
|
||||
// yv = ai*xv' + ar*xv + yv
|
||||
|
||||
for ( i = 0; (i + 9) < n; i += 10 )
|
||||
for (i = 0; (i + 13) < n; i += 14)
|
||||
{
|
||||
// 14 elements will be processed per loop; 14 FMAs will run per loop.
|
||||
|
||||
// alphaRv = aR aR aR aR
|
||||
// xv = xR1 xI1 xR2 xI2
|
||||
xv[0] = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg);
|
||||
xv[1] = _mm256_loadu_pd(x0 + 1 * n_elem_per_reg);
|
||||
xv[2] = _mm256_loadu_pd(x0 + 2 * n_elem_per_reg);
|
||||
xv[3] = _mm256_loadu_pd(x0 + 3 * n_elem_per_reg);
|
||||
xv[4] = _mm256_loadu_pd(x0 + 4 * n_elem_per_reg);
|
||||
xv[5] = _mm256_loadu_pd(x0 + 5 * n_elem_per_reg);
|
||||
xv[6] = _mm256_loadu_pd(x0 + 6 * n_elem_per_reg);
|
||||
|
||||
// yv = yR1 yI1 yR2 yI2
|
||||
yv[0] = _mm256_loadu_pd(y0 + 0 * n_elem_per_reg);
|
||||
yv[1] = _mm256_loadu_pd(y0 + 1 * n_elem_per_reg);
|
||||
yv[2] = _mm256_loadu_pd(y0 + 2 * n_elem_per_reg);
|
||||
yv[3] = _mm256_loadu_pd(y0 + 3 * n_elem_per_reg);
|
||||
yv[4] = _mm256_loadu_pd(y0 + 4 * n_elem_per_reg);
|
||||
yv[5] = _mm256_loadu_pd(y0 + 5 * n_elem_per_reg);
|
||||
yv[6] = _mm256_loadu_pd(y0 + 6 * n_elem_per_reg);
|
||||
|
||||
// yv = ar*xv + yv
|
||||
// = aR.xR1 + yR1, aR.xI1 + yI1, aR.xR2 + yR2, aR.xI2 + yI2, ...
|
||||
yv[0] = _mm256_fmadd_pd(xv[0], alphaRv, yv[0]);
|
||||
yv[1] = _mm256_fmadd_pd(xv[1], alphaRv, yv[1]);
|
||||
yv[2] = _mm256_fmadd_pd(xv[2], alphaRv, yv[2]);
|
||||
yv[3] = _mm256_fmadd_pd(xv[3], alphaRv, yv[3]);
|
||||
yv[4] = _mm256_fmadd_pd(xv[4], alphaRv, yv[4]);
|
||||
yv[5] = _mm256_fmadd_pd(xv[5], alphaRv, yv[5]);
|
||||
yv[6] = _mm256_fmadd_pd(xv[6], alphaRv, yv[6]);
|
||||
|
||||
// xv' = xI1 xRI xI2 xR2
|
||||
xv[0] = _mm256_permute_pd(xv[0], 5);
|
||||
xv[1] = _mm256_permute_pd(xv[1], 5);
|
||||
xv[2] = _mm256_permute_pd(xv[2], 5);
|
||||
xv[3] = _mm256_permute_pd(xv[3], 5);
|
||||
xv[4] = _mm256_permute_pd(xv[4], 5);
|
||||
xv[5] = _mm256_permute_pd(xv[5], 5);
|
||||
xv[6] = _mm256_permute_pd(xv[6], 5);
|
||||
|
||||
// Prefetch X and Y vectors to the L1 cache
|
||||
_mm_prefetch(x0 + distance, _MM_HINT_T1);
|
||||
_mm_prefetch(y0 + distance, _MM_HINT_T1);
|
||||
// alphaIv = -aI aI -aI aI
|
||||
|
||||
// yv = ar*xv + yv
|
||||
// = aR.xR1 + yR1, aR.xI1 + yI1, aR.xR2 + yR2, aR.xI2 + yI2, ...
|
||||
yv[0] = _mm256_fmadd_pd(xv[0], alphaIv, yv[0]);
|
||||
yv[1] = _mm256_fmadd_pd(xv[1], alphaIv, yv[1]);
|
||||
yv[2] = _mm256_fmadd_pd(xv[2], alphaIv, yv[2]);
|
||||
yv[3] = _mm256_fmadd_pd(xv[3], alphaIv, yv[3]);
|
||||
yv[4] = _mm256_fmadd_pd(xv[4], alphaIv, yv[4]);
|
||||
yv[5] = _mm256_fmadd_pd(xv[5], alphaIv, yv[5]);
|
||||
yv[6] = _mm256_fmadd_pd(xv[6], alphaIv, yv[6]);
|
||||
|
||||
// Store back the result
|
||||
_mm256_storeu_pd((y0 + 0 * n_elem_per_reg), yv[0]);
|
||||
_mm256_storeu_pd((y0 + 1 * n_elem_per_reg), yv[1]);
|
||||
_mm256_storeu_pd((y0 + 2 * n_elem_per_reg), yv[2]);
|
||||
_mm256_storeu_pd((y0 + 3 * n_elem_per_reg), yv[3]);
|
||||
_mm256_storeu_pd((y0 + 4 * n_elem_per_reg), yv[4]);
|
||||
_mm256_storeu_pd((y0 + 5 * n_elem_per_reg), yv[5]);
|
||||
_mm256_storeu_pd((y0 + 6 * n_elem_per_reg), yv[6]);
|
||||
|
||||
x0 += 7 * n_elem_per_reg;
|
||||
y0 += 7 * n_elem_per_reg;
|
||||
}
|
||||
|
||||
for ( ; (i + 9) < n; i += 10 )
|
||||
{
|
||||
// 10 elements will be processed per loop; 10 FMAs will run per loop.
|
||||
|
||||
@@ -1079,6 +1149,48 @@ void bli_zaxpyv_zen_int5
|
||||
y0 += 5*n_elem_per_reg;
|
||||
}
|
||||
|
||||
for (; (i + 5) < n; i += 6)
|
||||
{
|
||||
// alphaRv = aR aR aR aR
|
||||
// xv = xR1 xI1 xR2 xI2
|
||||
xv[0] = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg);
|
||||
xv[1] = _mm256_loadu_pd(x0 + 1 * n_elem_per_reg);
|
||||
xv[2] = _mm256_loadu_pd(x0 + 2 * n_elem_per_reg);
|
||||
|
||||
// yv = yR1 yI1 yR2 yI2
|
||||
yv[0] = _mm256_loadu_pd(y0 + 0 * n_elem_per_reg);
|
||||
yv[1] = _mm256_loadu_pd(y0 + 1 * n_elem_per_reg);
|
||||
yv[2] = _mm256_loadu_pd(y0 + 2 * n_elem_per_reg);
|
||||
|
||||
// xv' = xI1 xRI xI2 xR2
|
||||
xShufv[0] = _mm256_permute_pd(xv[0], 5);
|
||||
xShufv[1] = _mm256_permute_pd(xv[1], 5);
|
||||
xShufv[2] = _mm256_permute_pd(xv[2], 5);
|
||||
|
||||
// alphaIv = -aI aI -aI aI
|
||||
|
||||
// yv = ar*xv + yv
|
||||
// = aR.xR1 + yR1, aR.xI1 + yI1, aR.xR2 + yR2, aR.xI2 + yI2, ...
|
||||
yv[0] = _mm256_fmadd_pd(xv[0], alphaRv, yv[0]);
|
||||
yv[1] = _mm256_fmadd_pd(xv[1], alphaRv, yv[1]);
|
||||
yv[2] = _mm256_fmadd_pd(xv[2], alphaRv, yv[2]);
|
||||
|
||||
// yv = ai*xv' + yv (old)
|
||||
// yv = ai*xv' + ar*xv + yv
|
||||
// = -aI*xI1 + aR.xR1 + yR1, aI.xR1 + aR.xI1 + yI1, .........
|
||||
yv[0] = _mm256_fmadd_pd(xShufv[0], alphaIv, yv[0]);
|
||||
yv[1] = _mm256_fmadd_pd(xShufv[1], alphaIv, yv[1]);
|
||||
yv[2] = _mm256_fmadd_pd(xShufv[2], alphaIv, yv[2]);
|
||||
|
||||
// Store back the result
|
||||
_mm256_storeu_pd((y0 + 0 * n_elem_per_reg), yv[0]);
|
||||
_mm256_storeu_pd((y0 + 1 * n_elem_per_reg), yv[1]);
|
||||
_mm256_storeu_pd((y0 + 2 * n_elem_per_reg), yv[2]);
|
||||
|
||||
x0 += 3 * n_elem_per_reg;
|
||||
y0 += 3 * n_elem_per_reg;
|
||||
}
|
||||
|
||||
for ( ; (i + 3) < n; i += 4 )
|
||||
{
|
||||
// alphaRv = aR aR aR aR
|
||||
@@ -1115,7 +1227,7 @@ void bli_zaxpyv_zen_int5
|
||||
y0 += 2*n_elem_per_reg;
|
||||
}
|
||||
|
||||
for ( ; (i + 3) < n; i += 2 )
|
||||
for ( ; (i + 1) < n; i += 2 )
|
||||
{
|
||||
// alphaRv = aR aR aR aR
|
||||
// xv = xR1 xI1 xR2 xI2
|
||||
@@ -1151,72 +1263,45 @@ void bli_zaxpyv_zen_int5
|
||||
// as soon as the n_left cleanup loop below if BLIS is compiled with
|
||||
// -mfpmath=sse).
|
||||
_mm256_zeroupper();
|
||||
}
|
||||
|
||||
/* Residual values are calculated here
|
||||
y0 += (alpha) * (x0); --> BLIS_NO_CONJUGATE
|
||||
y0 += ( aR.xR - aIxI + yR ) + ( aR.xI + aI.xR + yI )i
|
||||
__m128d alpha_r, alpha_i, x_vec, y_vec;
|
||||
|
||||
y0 += (alpha) * conjx(x0); --> BLIS_CONJUGATE
|
||||
y0 = ( aR.xR + aIxI + yR ) + (aI.xR - aR.xI + yI)i */
|
||||
// Broadcast the alpha scalar to all elements of a vector register.
|
||||
if (bli_is_noconj(conjx)) // If BLIS_NO_CONJUGATE
|
||||
{
|
||||
alpha_r = _mm_set1_pd(alphaR);
|
||||
|
||||
if ( !bli_is_conj(conjx_use) ) // BLIS_NO_CONJUGATE
|
||||
{
|
||||
for ( ; (i + 0) < n; i += 1 )
|
||||
{
|
||||
// real part: ( aR.xR - aIxI + yR )
|
||||
*y0 += *alpha0 * (*x0) - (*(alpha0 + 1)) * (*(x0+1));
|
||||
// img part: ( aR.xI + aI.xR + yI )
|
||||
*(y0 + 1) += *alpha0 * (*(x0+1)) + (*(alpha0 + 1)) * (*x0);
|
||||
x0 += 2;
|
||||
y0 += 2;
|
||||
}
|
||||
}
|
||||
else // BLIS_CONJUGATE
|
||||
{
|
||||
for ( ; (i + 0) < n; i += 1 )
|
||||
{
|
||||
// real part: ( aR.xR + aIxI + yR )
|
||||
*y0 += *alpha0 * (*x0) + (*(alpha0 + 1)) * (*(x0+1));
|
||||
// img part: ( aI.xR - aR.xI + yI )
|
||||
*(y0 + 1) += (*(alpha0 + 1)) * (*x0) - (*alpha0) * (*(x0+1));
|
||||
x0 += 2;
|
||||
y0 += 2;
|
||||
}
|
||||
}
|
||||
alpha_i[0] = -alphaI;
|
||||
alpha_i[1] = alphaI;
|
||||
}
|
||||
else
|
||||
{
|
||||
const double alphar = *alpha0;
|
||||
const double alphai = *(alpha0 + 1);
|
||||
alpha_i = _mm_set1_pd(alphaI);
|
||||
|
||||
if ( !bli_is_conj(conjx_use) ) // BLIS_NO_CONJUGATE
|
||||
{
|
||||
for ( i = 0; i < n; ++i )
|
||||
{
|
||||
const double x0c = *x0;
|
||||
const double x1c = *( x0+1 );
|
||||
|
||||
*y0 += alphar * x0c - alphai * x1c;
|
||||
*(y0 + 1) += alphar * x1c + alphai * x0c;
|
||||
|
||||
x0 += incx * 2;
|
||||
y0 += incy * 2;
|
||||
}
|
||||
}
|
||||
else // BLIS_CONJUGATE
|
||||
{
|
||||
for ( i = 0; i < n; ++i )
|
||||
{
|
||||
const double x0c = *x0;
|
||||
const double x1c = *( x0+1 );
|
||||
|
||||
*y0 += alphar * x0c + alphai * x1c;
|
||||
*(y0 + 1) += alphai * x0c - alphar * x1c;
|
||||
|
||||
x0 += incx * 2;
|
||||
y0 += incy * 2;
|
||||
}
|
||||
}
|
||||
alpha_r[0] = alphaR;
|
||||
alpha_r[1] = -alphaR;
|
||||
}
|
||||
|
||||
/* This loop has two functions:
|
||||
1. Acts as the the fringe case when incx == 1 and incy == 1
|
||||
2. Performs the complete computation when incx != 1 or incy != 1
|
||||
*/
|
||||
for (; i < n; ++i)
|
||||
{
|
||||
|
||||
x_vec = _mm_loadu_pd(x0);
|
||||
y_vec = _mm_loadu_pd(y0);
|
||||
|
||||
y_vec = _mm_fmadd_pd(x_vec, alpha_r, y_vec);
|
||||
x_vec = _mm_permute_pd(x_vec, 0b01);
|
||||
y_vec = _mm_fmadd_pd(x_vec, alpha_i, y_vec);
|
||||
|
||||
_mm_storeu_pd(y0, y_vec);
|
||||
|
||||
x0 += incx * 2;
|
||||
y0 += incy * 2;
|
||||
}
|
||||
|
||||
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user