Bug fix for C/ZAXPBY Kernels

- Complex AXPBY kernels gave incorrect output when both alpha and
  beta had non-zero imaginary parts.
- Previously, the scalar code (used to compute the remainder
  elements and the non-unit increment cases) read and updated the
  y vector directly through its pointer. The real part of y was
  overwritten before the imaginary part was computed from it,
  resulting in incorrect output.
  It now operates on a local copy of the current y element and
  stores the final result through the y pointer.
- Also, the vectorized code now accumulates alpha*x in an
  intermediate vector, which is then added to the y vector.

AMD-Internal: [CPUPL-3037]
Change-Id: Iddbd3000dcb1505b444b0ad41ab881b055842e1c
This commit is contained in:
Arnav Sharma
2023-03-14 17:36:31 +05:30
committed by Arnav Sharma
parent c1766e312a
commit 0bfb0393fd

View File

@@ -83,7 +83,7 @@ void bli_saxpbyv_zen_int
/* if the vector dimension is zero, or if alpha & beta are zero,
return early. */
if ( bli_zero_dim1( n ) ||
if ( bli_zero_dim1( n ) ||
( PASTEMAC( s, eq0 )( *alpha ) && PASTEMAC( s, eq0 )( *beta ) ) )
return;
@@ -114,7 +114,7 @@ void bli_saxpbyv_zen_int
// y := y' + alpha * x
y0v.v = _mm256_fmadd_ps
(
(
alphav.v,
_mm256_loadu_ps( x0 + 0*n_elem_per_reg ),
y0v.v
@@ -181,7 +181,7 @@ void bli_saxpbyv_zen_int
/**
* daxpbyv kernel performs the axpbyv operation.
* y := beta * y + alpha * conjx(x)
* where,
* where,
* x & y are double precision vectors of length n.
* alpha & beta are scalars.
*/
@@ -211,7 +211,7 @@ void bli_daxpbyv_zen_int
/* if the vector dimension is zero, or if alpha & beta are zero,
return early. */
if ( bli_zero_dim1( n ) ||
if ( bli_zero_dim1( n ) ||
( PASTEMAC( s, eq0 )( *alpha ) && PASTEMAC( s, eq0 )( *beta ) ) )
{
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4)
@@ -312,7 +312,7 @@ void bli_daxpbyv_zen_int
/**
* caxpbyv kernel performs the axpbyv operation.
* y := beta * y + alpha * conjx(x)
* where,
* where,
* x & y are simple complex vectors of length n.
* alpha & beta are scalars.
*/
@@ -349,7 +349,7 @@ void bli_caxpbyv_zen_int
/* if the vector dimension is zero, or if alpha & beta are zero,
return early. */
if ( bli_zero_dim1( n ) ||
if ( bli_zero_dim1( n ) ||
( PASTEMAC( c, eq0 )( *alpha ) && PASTEMAC( c, eq0 )( *beta ) ) )
{
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4)
@@ -371,7 +371,7 @@ void bli_caxpbyv_zen_int
// y = beta*y + alpha*x
// y = ( bR + ibI ) * ( yR + iyI ) + ( aR + iaI ) * ( xR + ixI )
// y = bR.yR + ibR.yI + ibI.yR - bI.yI + aR.xR + iaR.xI + iaI.xR - aI.xI
// y = ( bR.yR - bI.yI + aR.xR - aI.xI ) +
// y = ( bR.yR - bI.yI + aR.xR - aI.xI ) +
// i ( bR.yI + bI.yR + aR.xI + aI.xR )
// SIMD Algorithm BLIS_NO_CONJUGATE
@@ -424,8 +424,8 @@ void bli_caxpbyv_zen_int
// betaIv = -bI bI -bI bI -bI bI -bI bI
alphaRv = _mm256_broadcast_ss( &alphaR );
alphaIv = _mm256_set_ps
(
alphaI, -alphaI, alphaI, -alphaI,
(
alphaI, -alphaI, alphaI, -alphaI,
alphaI, -alphaI, alphaI, -alphaI
);
betaRv = _mm256_broadcast_ss( &betaR );
@@ -505,10 +505,15 @@ void bli_caxpbyv_zen_int
// yv = alphaIv * xv + yv
// = yR1.bR - yR1.bI - xR1.aI, yI1.bR + yI1.bI + xI1.aI, ...
yv[0] = _mm256_fmadd_ps( alphaIv, xv[0], yv[0] );
yv[1] = _mm256_fmadd_ps( alphaIv, xv[1], yv[1] );
yv[2] = _mm256_fmadd_ps( alphaIv, xv[2], yv[2] );
yv[3] = _mm256_fmadd_ps( alphaIv, xv[3], yv[3] );
iv[0] = _mm256_fmadd_ps( alphaIv, xv[0], iv[0] );
iv[1] = _mm256_fmadd_ps( alphaIv, xv[1], iv[1] );
iv[2] = _mm256_fmadd_ps( alphaIv, xv[2], iv[2] );
iv[3] = _mm256_fmadd_ps( alphaIv, xv[3], iv[3] );
yv[0] = _mm256_add_ps( yv[0], iv[0] );
yv[1] = _mm256_add_ps( yv[1], iv[1] );
yv[2] = _mm256_add_ps( yv[2], iv[2] );
yv[3] = _mm256_add_ps( yv[3], iv[3] );
_mm256_storeu_ps( (y0 + 0*n_elem_per_reg), yv[0] );
_mm256_storeu_ps( (y0 + 1*n_elem_per_reg), yv[1] );
@@ -562,9 +567,13 @@ void bli_caxpbyv_zen_int
// yv = alphaIv * xv + yv
// = yR1.bR - yR1.bI - xR1.aI, yI1.bR + yI1.bI + xI1.aI, ...
yv[0] = _mm256_fmadd_ps( alphaIv, xv[0], yv[0] );
yv[1] = _mm256_fmadd_ps( alphaIv, xv[1], yv[1] );
yv[2] = _mm256_fmadd_ps( alphaIv, xv[2], yv[2] );
iv[0] = _mm256_fmadd_ps( alphaIv, xv[0], iv[0] );
iv[1] = _mm256_fmadd_ps( alphaIv, xv[1], iv[1] );
iv[2] = _mm256_fmadd_ps( alphaIv, xv[2], iv[2] );
yv[0] = _mm256_add_ps( yv[0], iv[0] );
yv[1] = _mm256_add_ps( yv[1], iv[1] );
yv[2] = _mm256_add_ps( yv[2], iv[2] );
_mm256_storeu_ps( (y0 + 0*n_elem_per_reg), yv[0] );
_mm256_storeu_ps( (y0 + 1*n_elem_per_reg), yv[1] );
@@ -610,8 +619,11 @@ void bli_caxpbyv_zen_int
// yv = alphaIv * xv + yv
// = yR1.bR - yR1.bI - xR1.aI, yI1.bR + yI1.bI + xI1.aI, ...
yv[0] = _mm256_fmadd_ps( alphaIv, xv[0], yv[0] );
yv[1] = _mm256_fmadd_ps( alphaIv, xv[1], yv[1] );
iv[0] = _mm256_fmadd_ps( alphaIv, xv[0], iv[0] );
iv[1] = _mm256_fmadd_ps( alphaIv, xv[1], iv[1] );
yv[0] = _mm256_add_ps( yv[0], iv[0] );
yv[1] = _mm256_add_ps( yv[1], iv[1] );
_mm256_storeu_ps( (y0 + 0*n_elem_per_reg), yv[0] );
_mm256_storeu_ps( (y0 + 1*n_elem_per_reg), yv[1] );
@@ -631,9 +643,12 @@ void bli_caxpbyv_zen_int
{
for ( ; i < n ; ++i )
{
*y0 = ( betaR * (*y0) ) - ( betaI * (*(y0 + 1)) ) +
const float yRc = *y0;
const float yIc = *( y0 + 1 );
*y0 = ( betaR * yRc ) - ( betaI * yIc ) +
( alphaR * (*x0) ) - ( alphaI * (*(x0 + 1)) );
*(y0 + 1) = ( betaR * (*(y0 + 1)) ) + ( betaI * (*y0) ) +
*(y0 + 1) = ( betaR * yIc ) + ( betaI * yRc ) +
( alphaR * (*(x0 + 1)) ) + ( alphaI * (*x0) );
x0 += 2;
@@ -644,9 +659,12 @@ void bli_caxpbyv_zen_int
{
for ( ; i < n ; ++i )
{
*y0 = ( betaR * (*y0) ) - ( betaI * (*(y0 + 1)) ) +
const float yRc = *y0;
const float yIc = *( y0 + 1 );
*y0 = ( betaR * yRc ) - ( betaI * yIc ) +
( alphaR * (*x0) ) + ( alphaI * (*(x0 + 1)) );
*(y0 + 1) = ( betaR * (*(y0 + 1)) ) + ( betaI * (*y0) ) -
*(y0 + 1) = ( betaR * yIc ) + ( betaI * yRc ) -
( alphaR * (*(x0 + 1)) ) + ( alphaI * (*x0) );
x0 += 2;
@@ -661,11 +679,14 @@ void bli_caxpbyv_zen_int
{
for ( i = 0; i < n ; ++i )
{
const float yRc = *y0;
const float yIc = *( y0 + 1 );
// yReal = ( bR.yR - bI.yI + aR.xR - aI.xI )
*y0 = ( betaR * (*y0) ) - ( betaI * (*(y0 + 1)) ) +
*y0 = ( betaR * yRc ) - ( betaI * yIc ) +
( alphaR * (*x0) ) - ( alphaI * (*(x0 + 1)) );
// yImag = ( bR.yI + bI.yR + aR.xI + aI.xR )
*(y0 + 1) = ( betaR * (*(y0 + 1)) ) + ( betaI * (*y0) ) +
*(y0 + 1) = ( betaR * yIc ) + ( betaI * yRc ) +
( alphaR * (*(x0 + 1)) ) + ( alphaI * (*x0) );
x0 += incx * 2;
@@ -676,11 +697,14 @@ void bli_caxpbyv_zen_int
{
for ( i = 0; i < n ; ++i )
{
const float yRc = *y0;
const float yIc = *( y0 + 1 );
// yReal = ( bR.yR - bI.yI + aR.xR + aI.xI )
*y0 = ( betaR * (*y0) ) - ( betaI * (*(y0 + 1)) ) +
*y0 = ( betaR * yRc ) - ( betaI * yIc ) +
( alphaR * (*x0) ) + ( alphaI * (*(x0 + 1)) );
// yImag = ( bR.yI + bI.yR - aR.xI + aI.xR )
*(y0 + 1) = ( betaR * (*(y0 + 1)) ) + ( betaI * (*y0) ) -
*(y0 + 1) = ( betaR * yIc ) + ( betaI * yRc ) -
( alphaR * (*(x0 + 1)) ) + ( alphaI * (*x0) );
x0 += incx * 2;
@@ -694,7 +718,7 @@ void bli_caxpbyv_zen_int
/**
* zaxpbyv kernel performs the axpbyv operation.
* y := beta * y + alpha * conjx(x)
* where,
* where,
* x & y are double complex vectors of length n.
* alpha & beta are scalars.
*/
@@ -731,7 +755,7 @@ void bli_zaxpbyv_zen_int
/* if the vector dimension is zero, or if alpha & beta are zero,
return early. */
if ( bli_zero_dim1( n ) ||
if ( bli_zero_dim1( n ) ||
( PASTEMAC( c, eq0 )( *alpha ) && PASTEMAC( c, eq0 )( *beta ) ) )
{
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4)
@@ -753,7 +777,7 @@ void bli_zaxpbyv_zen_int
// y = beta*y + alpha*x
// y = ( bR + ibI ) * ( yR + iyI ) + ( aR + iaI ) * ( xR + ixI )
// y = bR.yR + ibR.yI + ibI.yR - bI.yI + aR.xR + iaR.xI + iaI.xR - aI.xI
// y = ( bR.yR - bI.yI + aR.xR - aI.xI ) +
// y = ( bR.yR - bI.yI + aR.xR - aI.xI ) +
// i ( bR.yI + bI.yR + aR.xI + aI.xR )
// SIMD Algorithm BLIS_NO_CONJUGATE
@@ -761,10 +785,10 @@ void bli_zaxpbyv_zen_int
// yv' = yI1 yR1 yI2 yR2
// xv = xR1 xI1 xR2 xI2
// xv' = xI1 xR1 xI2 xR2
// arv = aR aR aR aR
// aiv = -aI aI -aI aI
// brv = bR bR bR bR
// biv = -bI bI -bI bI
// arv = aR aR aR aR
// aiv = -aI aI -aI aI
// brv = bR bR bR bR
// biv = -bI bI -bI bI
//
// step 1: iv = brv * iv
// step 2: shuffle yv -> yv'
@@ -785,10 +809,10 @@ void bli_zaxpbyv_zen_int
// yv' = yI1 yR1 yI2 yR2
// xv = xR1 xI1 xR2 xI2
// xv' = xI1 xR1 xI2 xR2
// arv = aR -aR aR -aR
// aiv = aI aI aI aI
// brv = bR bR bR bR
// biv = -bI bI -bI bI
// arv = aR -aR aR -aR
// aiv = aI aI aI aI
// brv = bR bR bR bR
// biv = -bI bI -bI bI
//
// step 1: iv = brv * iv
// step 2: shuffle yv -> yv'
@@ -871,10 +895,15 @@ void bli_zaxpbyv_zen_int
// yv = alphaIv * xv + yv
// = yR1.bR - yR1.bI - xR1.aI, yI1.bR + yI1.bI + xI1.aI, ...
yv[0] = _mm256_fmadd_pd( alphaIv, xv[0], yv[0] );
yv[1] = _mm256_fmadd_pd( alphaIv, xv[1], yv[1] );
yv[2] = _mm256_fmadd_pd( alphaIv, xv[2], yv[2] );
yv[3] = _mm256_fmadd_pd( alphaIv, xv[3], yv[3] );
iv[0] = _mm256_fmadd_pd( alphaIv, xv[0], iv[0] );
iv[1] = _mm256_fmadd_pd( alphaIv, xv[1], iv[1] );
iv[2] = _mm256_fmadd_pd( alphaIv, xv[2], iv[2] );
iv[3] = _mm256_fmadd_pd( alphaIv, xv[3], iv[3] );
yv[0] = _mm256_add_pd( yv[0], iv[0] );
yv[1] = _mm256_add_pd( yv[1], iv[1] );
yv[2] = _mm256_add_pd( yv[2], iv[2] );
yv[3] = _mm256_add_pd( yv[3], iv[3] );
_mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0] );
_mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1] );
@@ -928,9 +957,13 @@ void bli_zaxpbyv_zen_int
// yv = alphaIv * xv + yv
// = yR1.bR - yR1.bI - xR1.aI, yI1.bR + yI1.bI + xI1.aI, ...
yv[0] = _mm256_fmadd_pd( alphaIv, xv[0], yv[0] );
yv[1] = _mm256_fmadd_pd( alphaIv, xv[1], yv[1] );
yv[2] = _mm256_fmadd_pd( alphaIv, xv[2], yv[2] );
iv[0] = _mm256_fmadd_pd( alphaIv, xv[0], iv[0] );
iv[1] = _mm256_fmadd_pd( alphaIv, xv[1], iv[1] );
iv[2] = _mm256_fmadd_pd( alphaIv, xv[2], iv[2] );
yv[0] = _mm256_add_pd( yv[0], iv[0] );
yv[1] = _mm256_add_pd( yv[1], iv[1] );
yv[2] = _mm256_add_pd( yv[2], iv[2] );
_mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0] );
_mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1] );
@@ -976,8 +1009,11 @@ void bli_zaxpbyv_zen_int
// yv = alphaIv * xv + yv
// = yR1.bR - yR1.bI - xR1.aI, yI1.bR + yI1.bI + xI1.aI, ...
yv[0] = _mm256_fmadd_pd( alphaIv, xv[0], yv[0] );
yv[1] = _mm256_fmadd_pd( alphaIv, xv[1], yv[1] );
iv[0] = _mm256_fmadd_pd( alphaIv, xv[0], iv[0] );
iv[1] = _mm256_fmadd_pd( alphaIv, xv[1], iv[1] );
yv[0] = _mm256_add_pd( yv[0], iv[0] );
yv[1] = _mm256_add_pd( yv[1], iv[1] );
_mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0] );
_mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1] );
@@ -1015,7 +1051,9 @@ void bli_zaxpbyv_zen_int
// yv = alphaIv * xv + yv
// = yR1.bR - yR1.bI - xR1.aI, yI1.bR + yI1.bI + xI1.aI, ...
yv[0] = _mm256_fmadd_pd( alphaIv, xv[0], yv[0] );
iv[0] = _mm256_fmadd_pd( alphaIv, xv[0], iv[0] );
yv[0] = _mm256_add_pd( yv[0], iv[0] );
_mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0] );
@@ -1034,11 +1072,14 @@ void bli_zaxpbyv_zen_int
{
for ( ; i < n ; ++i )
{
const double yRc = *y0;
const double yIc = *( y0 + 1 );
// yReal = ( bR.yR - bI.yI + aR.xR - aI.xI )
*y0 = ( betaR * (*y0) ) - ( betaI * (*(y0 + 1)) ) +
*y0 = ( betaR * yRc ) - ( betaI * yIc ) +
( alphaR * (*x0) ) - ( alphaI * (*(x0 + 1)) );
// yImag = ( bR.yI + bI.yR + aR.xI + aI.xR )
*(y0 + 1) = ( betaR * (*(y0 + 1)) ) + ( betaI * (*y0) ) +
*(y0 + 1) = ( betaR * yIc ) + ( betaI * yRc ) +
( alphaR * (*(x0 + 1)) ) + ( alphaI * (*x0) );
x0 += 2;
@@ -1049,11 +1090,14 @@ void bli_zaxpbyv_zen_int
{
for ( ; i < n ; ++i )
{
const double yRc = *y0;
const double yIc = *( y0 + 1 );
// yReal = ( bR.yR - bI.yI + aR.xR + aI.xI )
*y0 = ( betaR * (*y0) ) - ( betaI * (*(y0 + 1)) ) +
*y0 = ( betaR * yRc ) - ( betaI * yIc ) +
( alphaR * (*x0) ) + ( alphaI * (*(x0 + 1)) );
// yImag = ( bR.yI + bI.yR - aR.xI + aI.xR )
*(y0 + 1) = ( betaR * (*(y0 + 1)) ) + ( betaI * (*y0) ) -
*(y0 + 1) = ( betaR * yIc ) + ( betaI * yRc ) -
( alphaR * (*(x0 + 1)) ) + ( alphaI * (*x0) );
x0 += 2;
@@ -1068,11 +1112,14 @@ void bli_zaxpbyv_zen_int
{
for ( i = 0; i < n ; ++i )
{
const double yRc = *y0;
const double yIc = *( y0 + 1 );
// yReal = ( bR.yR - bI.yI + aR.xR - aI.xI )
*y0 = ( betaR * (*y0) ) - ( betaI * (*(y0 + 1)) ) +
*y0 = ( betaR * yRc ) - ( betaI * yIc ) +
( alphaR * (*x0) ) - ( alphaI * (*(x0 + 1)) );
// yImag = ( bR.yI + bI.yR + aR.xI + aI.xR )
*(y0 + 1) = ( betaR * (*(y0 + 1)) ) + ( betaI * (*y0) ) +
*(y0 + 1) = ( betaR * yIc ) + ( betaI * yRc ) +
( alphaR * (*(x0 + 1)) ) + ( alphaI * (*x0) );
x0 += incx * 2;
@@ -1083,11 +1130,14 @@ void bli_zaxpbyv_zen_int
{
for ( i = 0; i < n ; ++i )
{
const double yRc = *y0;
const double yIc = *( y0 + 1 );
// yReal = ( bR.yR - bI.yI + aR.xR + aI.xI )
*y0 = ( betaR * (*y0) ) - ( betaI * (*(y0 + 1)) ) +
*y0 = ( betaR * yRc ) - ( betaI * yIc ) +
( alphaR * (*x0) ) + ( alphaI * (*(x0 + 1)) );
// yImag = ( bR.yI + bI.yR - aR.xI + aI.xR )
*(y0 + 1) = ( betaR * (*(y0 + 1)) ) + ( betaI * (*y0) ) -
*(y0 + 1) = ( betaR * yIc ) + ( betaI * yRc ) -
( alphaR * (*(x0 + 1)) ) + ( alphaI * (*x0) );
x0 += incx * 2;