Improved DGEMV performance for column-major cases

- Altered the framework to use two more fused kernels for
  better problem decomposition (see the worked example below)
- Increased the unroll factor in the AXPYF5 and AXPYF8 kernels
  to improve register usage

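As a hedged illustration of the decomposition (the column count is
hypothetical, not taken from this change): for an n = 13 panel the new
variant-2 loop would proceed roughly as

    fuse_factor = 8
    i = 0 : f = 8 -> bli_daxpyf_zen_int_8   (next chunk is 5, fuse_factor -> 5)
    i = 8 : f = 5 -> bli_daxpyf_zen_int_5

so all 13 columns are covered by two fused kernel calls (8 + 5), which is
the better problem decomposition the first bullet refers to.
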
AMD-Internal: [CPUPL-1970]

Change-Id: I79750235d9554466def5ff93898f832834990343
Harihara Sudhan S
2022-01-28 11:44:38 +05:30
committed by HariharaSudhan S
parent 6d1edca727
commit 6696f91f41
3 changed files with 652 additions and 246 deletions

@@ -313,27 +313,87 @@ void bli_dgemv_unf_var2
}
}
for ( i = 0; i < n_iter; i += f )
dim_t fuse_factor = 8;
dim_t f_temp = 0;
// Change the fuse factor based on
// input size and available kernels.
// This ensures that fusing is possible when the number of
// leftover columns is small (better problem decomposition).
if (n < 5) fuse_factor = 4;
else if (n < 8) fuse_factor = 5;
for (i = 0; i < n_iter; i += f)
{
f = bli_determine_blocksize_dim_f( i, n_iter, BLIS_DGEMV_VAR2_FUSE );
f = bli_determine_blocksize_dim_f(i, n_iter, fuse_factor);
A1 = a + (0 )*rs_at + (i )*cs_at;
x1 = x + (i )*incx;
A1 = a + (i)*cs_at;
x1 = x + (i)*incx;
/* y = y + alpha * A1 * x1; */
bli_daxpyf_zen_int_16x4
(
conja,
conjx,
n_elem,
f,
alpha,
A1, rs_at, cs_at,
x1, incx,
y_buf, buf_incy,
cntx
);
// Pick kernel based on problem size
switch (f)
{
case 8:
bli_daxpyf_zen_int_8(
conja,
conjx,
n_elem,
f,
alpha,
A1, rs_at, cs_at,
x1, incx,
y_buf, buf_incy,
cntx);
break;
default:
if (f < 5)
{
bli_daxpyf_zen_int_16x4(
conja,
conjx,
n_elem,
f,
alpha,
A1, rs_at, cs_at,
x1, incx,
y_buf, buf_incy,
cntx);
}
else
{
bli_daxpyf_zen_int_5(
conja,
conjx,
n_elem,
f,
alpha,
A1, rs_at, cs_at,
x1, incx,
y_buf, buf_incy,
cntx);
}
}
// Calculate the next problem size
f_temp = bli_determine_blocksize_dim_f(i + f, n_iter, fuse_factor);
// Change fuse factor based on the next problem size
if (f_temp < fuse_factor)
{
if (f_temp < 5)
{
fuse_factor = 4;
}
else
{
fuse_factor = 5;
}
}
}
if ((incy > 1) && bli_mem_is_alloc( &mem_bufY ))
{
//store the result from unit strided y_buf to non-unit strided Y

@@ -329,27 +329,13 @@ void bli_daxpyf_zen_int_5
dim_t i;
double* restrict a0;
double* restrict a1;
double* restrict a2;
double* restrict a3;
double* restrict a4;
double* restrict av[5] __attribute__((aligned(64)));
double* restrict y0;
v4df_t chi0v, chi1v, chi2v, chi3v;
v4df_t chi4v;
v4df_t a00v, a01v, a02v, a03v;
v4df_t a04v;
v4df_t a10v, a11v, a12v, a13v;
v4df_t a14v;
v4df_t y0v, y1v;
double chi0, chi1, chi2, chi3;
double chi4;
v4df_t chiv[5], a_vec[20], yv[4];
double chi[5];
// If either dimension is zero, or if alpha is zero, return early.
if ( bli_zero_dim2( m, b_n ) || bli_deq0( *alpha ) ) return;
@@ -385,117 +371,241 @@ void bli_daxpyf_zen_int_5
}
// At this point, we know that b_n is exactly equal to the fusing factor.
a0 = a + 0*lda;
a1 = a + 1*lda;
a2 = a + 2*lda;
a3 = a + 3*lda;
a4 = a + 4*lda;
// av points to the 5 columns under consideration
av[0] = a + 0*lda;
av[1] = a + 1*lda;
av[2] = a + 2*lda;
av[3] = a + 3*lda;
av[4] = a + 4*lda;
y0 = y;
chi0 = *( x + 0*incx );
chi1 = *( x + 1*incx );
chi2 = *( x + 2*incx );
chi3 = *( x + 3*incx );
chi4 = *( x + 4*incx );
chi[0] = *( x + 0*incx );
chi[1] = *( x + 1*incx );
chi[2] = *( x + 2*incx );
chi[3] = *( x + 3*incx );
chi[4] = *( x + 4*incx );
// Scale each chi scalar by alpha.
bli_dscals( *alpha, chi0 );
bli_dscals( *alpha, chi1 );
bli_dscals( *alpha, chi2 );
bli_dscals( *alpha, chi3 );
bli_dscals( *alpha, chi4 );
bli_dscals( *alpha, chi[0] );
bli_dscals( *alpha, chi[1] );
bli_dscals( *alpha, chi[2] );
bli_dscals( *alpha, chi[3] );
bli_dscals( *alpha, chi[4] );
// Broadcast the (alpha*chi?) scalars to all elements of vector registers.
chi0v.v = _mm256_broadcast_sd( &chi0 );
chi1v.v = _mm256_broadcast_sd( &chi1 );
chi2v.v = _mm256_broadcast_sd( &chi2 );
chi3v.v = _mm256_broadcast_sd( &chi3 );
chi4v.v = _mm256_broadcast_sd( &chi4 );
chiv[0].v = _mm256_broadcast_sd( &chi[0] );
chiv[1].v = _mm256_broadcast_sd( &chi[1] );
chiv[2].v = _mm256_broadcast_sd( &chi[2] );
chiv[3].v = _mm256_broadcast_sd( &chi[3] );
chiv[4].v = _mm256_broadcast_sd( &chi[4] );
// If there are vectorized iterations, perform them with vector
// instructions.
if ( inca == 1 && incy == 1 )
{
for ( i = 0; (i + 7) < m; i += 8 )
// 16 elements of the result are computed per iteration
for ( i = 0; (i + 15) < m; i += 16 )
{
// Load the input values.
y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg );
y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg );
yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg );
yv[1].v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg );
yv[2].v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg );
yv[3].v = _mm256_loadu_pd( y0 + 3*n_elem_per_reg );
a00v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg );
a10v.v = _mm256_loadu_pd( a0 + 1*n_elem_per_reg );
a_vec[0].v = _mm256_loadu_pd( av[0] + 0*n_elem_per_reg );
a_vec[1].v = _mm256_loadu_pd( av[1] + 0*n_elem_per_reg );
a_vec[2].v = _mm256_loadu_pd( av[2] + 0*n_elem_per_reg );
a_vec[3].v = _mm256_loadu_pd( av[3] + 0*n_elem_per_reg );
a_vec[4].v = _mm256_loadu_pd( av[4] + 0*n_elem_per_reg );
a01v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg );
a11v.v = _mm256_loadu_pd( a1 + 1*n_elem_per_reg );
a_vec[5].v = _mm256_loadu_pd( av[0] + 1*n_elem_per_reg );
a_vec[6].v = _mm256_loadu_pd( av[1] + 1*n_elem_per_reg );
a_vec[7].v = _mm256_loadu_pd( av[2] + 1*n_elem_per_reg );
a_vec[8].v = _mm256_loadu_pd( av[3] + 1*n_elem_per_reg );
a_vec[9].v = _mm256_loadu_pd( av[4] + 1*n_elem_per_reg );
a02v.v = _mm256_loadu_pd( a2 + 0*n_elem_per_reg );
a12v.v = _mm256_loadu_pd( a2 + 1*n_elem_per_reg );
a_vec[10].v = _mm256_loadu_pd( av[0] + 2*n_elem_per_reg );
a_vec[11].v = _mm256_loadu_pd( av[1] + 2*n_elem_per_reg );
a_vec[12].v = _mm256_loadu_pd( av[2] + 2*n_elem_per_reg );
a_vec[13].v = _mm256_loadu_pd( av[3] + 2*n_elem_per_reg );
a_vec[14].v = _mm256_loadu_pd( av[4] + 2*n_elem_per_reg );
a03v.v = _mm256_loadu_pd( a3 + 0*n_elem_per_reg );
a13v.v = _mm256_loadu_pd( a3 + 1*n_elem_per_reg );
a04v.v = _mm256_loadu_pd( a4 + 0*n_elem_per_reg );
a14v.v = _mm256_loadu_pd( a4 + 1*n_elem_per_reg );
a_vec[15].v = _mm256_loadu_pd( av[0] + 3*n_elem_per_reg );
a_vec[16].v = _mm256_loadu_pd( av[1] + 3*n_elem_per_reg );
a_vec[17].v = _mm256_loadu_pd( av[2] + 3*n_elem_per_reg );
a_vec[18].v = _mm256_loadu_pd( av[3] + 3*n_elem_per_reg );
a_vec[19].v = _mm256_loadu_pd( av[4] + 3*n_elem_per_reg );
// perform : y += alpha * x;
y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v );
y1v.v = _mm256_fmadd_pd( a10v.v, chi0v.v, y1v.v );
yv[0].v = _mm256_fmadd_pd( a_vec[0].v, chiv[0].v, yv[0].v );
yv[0].v = _mm256_fmadd_pd( a_vec[1].v, chiv[1].v, yv[0].v );
yv[0].v = _mm256_fmadd_pd( a_vec[2].v, chiv[2].v, yv[0].v );
yv[0].v = _mm256_fmadd_pd( a_vec[3].v, chiv[3].v, yv[0].v );
yv[0].v = _mm256_fmadd_pd( a_vec[4].v, chiv[4].v, yv[0].v );
y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v );
y1v.v = _mm256_fmadd_pd( a11v.v, chi1v.v, y1v.v );
yv[1].v = _mm256_fmadd_pd( a_vec[5].v, chiv[0].v, yv[1].v );
yv[1].v = _mm256_fmadd_pd( a_vec[6].v, chiv[1].v, yv[1].v );
yv[1].v = _mm256_fmadd_pd( a_vec[7].v, chiv[2].v, yv[1].v );
yv[1].v = _mm256_fmadd_pd( a_vec[8].v, chiv[3].v, yv[1].v );
yv[1].v = _mm256_fmadd_pd( a_vec[9].v, chiv[4].v, yv[1].v );
y0v.v = _mm256_fmadd_pd( a02v.v, chi2v.v, y0v.v );
y1v.v = _mm256_fmadd_pd( a12v.v, chi2v.v, y1v.v );
y0v.v = _mm256_fmadd_pd( a03v.v, chi3v.v, y0v.v );
y1v.v = _mm256_fmadd_pd( a13v.v, chi3v.v, y1v.v );
y0v.v = _mm256_fmadd_pd( a04v.v, chi4v.v, y0v.v );
y1v.v = _mm256_fmadd_pd( a14v.v, chi4v.v, y1v.v );
yv[2].v = _mm256_fmadd_pd( a_vec[10].v, chiv[0].v, yv[2].v );
yv[2].v = _mm256_fmadd_pd( a_vec[11].v, chiv[1].v, yv[2].v );
yv[2].v = _mm256_fmadd_pd( a_vec[12].v, chiv[2].v, yv[2].v );
yv[2].v = _mm256_fmadd_pd( a_vec[13].v, chiv[3].v, yv[2].v );
yv[2].v = _mm256_fmadd_pd( a_vec[14].v, chiv[4].v, yv[2].v );
yv[3].v = _mm256_fmadd_pd( a_vec[15].v, chiv[0].v, yv[3].v );
yv[3].v = _mm256_fmadd_pd( a_vec[16].v, chiv[1].v, yv[3].v );
yv[3].v = _mm256_fmadd_pd( a_vec[17].v, chiv[2].v, yv[3].v );
yv[3].v = _mm256_fmadd_pd( a_vec[18].v, chiv[3].v, yv[3].v );
yv[3].v = _mm256_fmadd_pd( a_vec[19].v, chiv[4].v, yv[3].v );
// Store the output.
_mm256_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y0v.v );
_mm256_storeu_pd( (double *)(y0 + 1*n_elem_per_reg), y1v.v );
_mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0].v );
_mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1].v );
_mm256_storeu_pd( (y0 + 2*n_elem_per_reg), yv[2].v );
_mm256_storeu_pd( (y0 + 3*n_elem_per_reg), yv[3].v );
y0 += n_iter_unroll * n_elem_per_reg;
a0 += n_iter_unroll * n_elem_per_reg;
a1 += n_iter_unroll * n_elem_per_reg;
a2 += n_iter_unroll * n_elem_per_reg;
a3 += n_iter_unroll * n_elem_per_reg;
a4 += n_iter_unroll * n_elem_per_reg;
y0 += n_elem_per_reg * 4;
av[0] += n_elem_per_reg * 4;
av[1] += n_elem_per_reg * 4;
av[2] += n_elem_per_reg * 4;
av[3] += n_elem_per_reg * 4;
av[4] += n_elem_per_reg * 4;
}
// 12 elements of the result are computed per iteration
for ( ; (i + 11) < m; i += 12 )
{
// Load the input values.
yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg );
yv[1].v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg );
yv[2].v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg );
a_vec[0].v = _mm256_loadu_pd( av[0] + 0*n_elem_per_reg );
a_vec[1].v = _mm256_loadu_pd( av[1] + 0*n_elem_per_reg );
a_vec[2].v = _mm256_loadu_pd( av[2] + 0*n_elem_per_reg );
a_vec[3].v = _mm256_loadu_pd( av[3] + 0*n_elem_per_reg );
a_vec[4].v = _mm256_loadu_pd( av[4] + 0*n_elem_per_reg );
a_vec[5].v = _mm256_loadu_pd( av[0] + 1*n_elem_per_reg );
a_vec[6].v = _mm256_loadu_pd( av[1] + 1*n_elem_per_reg );
a_vec[7].v = _mm256_loadu_pd( av[2] + 1*n_elem_per_reg );
a_vec[8].v = _mm256_loadu_pd( av[3] + 1*n_elem_per_reg );
a_vec[9].v = _mm256_loadu_pd( av[4] + 1*n_elem_per_reg );
a_vec[10].v = _mm256_loadu_pd( av[0] + 2*n_elem_per_reg );
a_vec[11].v = _mm256_loadu_pd( av[1] + 2*n_elem_per_reg );
a_vec[12].v = _mm256_loadu_pd( av[2] + 2*n_elem_per_reg );
a_vec[13].v = _mm256_loadu_pd( av[3] + 2*n_elem_per_reg );
a_vec[14].v = _mm256_loadu_pd( av[4] + 2*n_elem_per_reg );
// perform : y += alpha * x;
yv[0].v = _mm256_fmadd_pd( a_vec[0].v, chiv[0].v, yv[0].v );
yv[0].v = _mm256_fmadd_pd( a_vec[1].v, chiv[1].v, yv[0].v );
yv[0].v = _mm256_fmadd_pd( a_vec[2].v, chiv[2].v, yv[0].v );
yv[0].v = _mm256_fmadd_pd( a_vec[3].v, chiv[3].v, yv[0].v );
yv[0].v = _mm256_fmadd_pd( a_vec[4].v, chiv[4].v, yv[0].v );
yv[1].v = _mm256_fmadd_pd( a_vec[5].v, chiv[0].v, yv[1].v );
yv[1].v = _mm256_fmadd_pd( a_vec[6].v, chiv[1].v, yv[1].v );
yv[1].v = _mm256_fmadd_pd( a_vec[7].v, chiv[2].v, yv[1].v );
yv[1].v = _mm256_fmadd_pd( a_vec[8].v, chiv[3].v, yv[1].v );
yv[1].v = _mm256_fmadd_pd( a_vec[9].v, chiv[4].v, yv[1].v );
yv[2].v = _mm256_fmadd_pd( a_vec[10].v, chiv[0].v, yv[2].v );
yv[2].v = _mm256_fmadd_pd( a_vec[11].v, chiv[1].v, yv[2].v );
yv[2].v = _mm256_fmadd_pd( a_vec[12].v, chiv[2].v, yv[2].v );
yv[2].v = _mm256_fmadd_pd( a_vec[13].v, chiv[3].v, yv[2].v );
yv[2].v = _mm256_fmadd_pd( a_vec[14].v, chiv[4].v, yv[2].v );
// Store the output.
_mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0].v );
_mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1].v );
_mm256_storeu_pd( (y0 + 2*n_elem_per_reg), yv[2].v );
y0 += n_elem_per_reg * 3;
av[0] += n_elem_per_reg * 3;
av[1] += n_elem_per_reg * 3;
av[2] += n_elem_per_reg * 3;
av[3] += n_elem_per_reg * 3;
av[4] += n_elem_per_reg * 3;
}
// 8 elements of the result are computed per iteration
for (; (i + 7) < m; i += 8 )
{
// Load the input values.
yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg );
yv[1].v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg );
a_vec[0].v = _mm256_loadu_pd( av[0] + 0*n_elem_per_reg );
a_vec[1].v = _mm256_loadu_pd( av[1] + 0*n_elem_per_reg );
a_vec[2].v = _mm256_loadu_pd( av[2] + 0*n_elem_per_reg );
a_vec[3].v = _mm256_loadu_pd( av[3] + 0*n_elem_per_reg );
a_vec[4].v = _mm256_loadu_pd( av[4] + 0*n_elem_per_reg );
a_vec[5].v = _mm256_loadu_pd( av[0] + 1*n_elem_per_reg );
a_vec[6].v = _mm256_loadu_pd( av[1] + 1*n_elem_per_reg );
a_vec[7].v = _mm256_loadu_pd( av[2] + 1*n_elem_per_reg );
a_vec[8].v = _mm256_loadu_pd( av[3] + 1*n_elem_per_reg );
a_vec[9].v = _mm256_loadu_pd( av[4] + 1*n_elem_per_reg );
// perform : y += alpha * x;
yv[0].v = _mm256_fmadd_pd( a_vec[0].v, chiv[0].v, yv[0].v );
yv[0].v = _mm256_fmadd_pd( a_vec[1].v, chiv[1].v, yv[0].v );
yv[0].v = _mm256_fmadd_pd( a_vec[2].v, chiv[2].v, yv[0].v );
yv[0].v = _mm256_fmadd_pd( a_vec[3].v, chiv[3].v, yv[0].v );
yv[0].v = _mm256_fmadd_pd( a_vec[4].v, chiv[4].v, yv[0].v );
yv[1].v = _mm256_fmadd_pd( a_vec[5].v, chiv[0].v, yv[1].v );
yv[1].v = _mm256_fmadd_pd( a_vec[6].v, chiv[1].v, yv[1].v );
yv[1].v = _mm256_fmadd_pd( a_vec[7].v, chiv[2].v, yv[1].v );
yv[1].v = _mm256_fmadd_pd( a_vec[8].v, chiv[3].v, yv[1].v );
yv[1].v = _mm256_fmadd_pd( a_vec[9].v, chiv[4].v, yv[1].v );
// Store the output.
_mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0].v );
_mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1].v );
y0 += n_elem_per_reg * 2;
av[0] += n_elem_per_reg * 2;
av[1] += n_elem_per_reg * 2;
av[2] += n_elem_per_reg * 2;
av[3] += n_elem_per_reg * 2;
av[4] += n_elem_per_reg * 2;
}
// 4 elements of the result are computed per iteration
for( ; (i + 3) < m; i += 4 )
{
// Load the input values.
y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg );
a00v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg );
a01v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg );
a02v.v = _mm256_loadu_pd( a2 + 0*n_elem_per_reg );
a03v.v = _mm256_loadu_pd( a3 + 0*n_elem_per_reg );
a04v.v = _mm256_loadu_pd( a4 + 0*n_elem_per_reg );
yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg );
a_vec[0].v = _mm256_loadu_pd( av[0] + 0*n_elem_per_reg );
a_vec[1].v = _mm256_loadu_pd( av[1] + 0*n_elem_per_reg );
a_vec[2].v = _mm256_loadu_pd( av[2] + 0*n_elem_per_reg );
a_vec[3].v = _mm256_loadu_pd( av[3] + 0*n_elem_per_reg );
a_vec[4].v = _mm256_loadu_pd( av[4] + 0*n_elem_per_reg );
// perform : y += alpha * x;
y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v );
y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v );
y0v.v = _mm256_fmadd_pd( a02v.v, chi2v.v, y0v.v );
y0v.v = _mm256_fmadd_pd( a03v.v, chi3v.v, y0v.v );
y0v.v = _mm256_fmadd_pd( a04v.v, chi4v.v, y0v.v );
yv[0].v = _mm256_fmadd_pd( a_vec[0].v, chiv[0].v, yv[0].v );
yv[0].v = _mm256_fmadd_pd( a_vec[1].v, chiv[1].v, yv[0].v );
yv[0].v = _mm256_fmadd_pd( a_vec[2].v, chiv[2].v, yv[0].v );
yv[0].v = _mm256_fmadd_pd( a_vec[3].v, chiv[3].v, yv[0].v );
yv[0].v = _mm256_fmadd_pd( a_vec[4].v, chiv[4].v, yv[0].v );
// Store the output.
_mm256_storeu_pd( (y0 + 0*n_elem_per_reg), y0v.v );
_mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0].v );
y0 += n_elem_per_reg;
a0 += n_elem_per_reg;
a1 += n_elem_per_reg;
a2 += n_elem_per_reg;
a3 += n_elem_per_reg;
a4 += n_elem_per_reg;
av[0] += n_elem_per_reg;
av[1] += n_elem_per_reg;
av[2] += n_elem_per_reg;
av[3] += n_elem_per_reg;
av[4] += n_elem_per_reg;
}
// If there are leftover iterations, perform them with scalar code.
@@ -503,25 +613,25 @@ void bli_daxpyf_zen_int_5
{
double y0c = *y0;
const double a0c = *a0;
const double a1c = *a1;
const double a2c = *a2;
const double a3c = *a3;
const double a4c = *a4;
const double a0c = *av[0];
const double a1c = *av[1];
const double a2c = *av[2];
const double a3c = *av[3];
const double a4c = *av[4];
y0c += chi0 * a0c;
y0c += chi1 * a1c;
y0c += chi2 * a2c;
y0c += chi3 * a3c;
y0c += chi4 * a4c;
y0c += chi[0] * a0c;
y0c += chi[1] * a1c;
y0c += chi[2] * a2c;
y0c += chi[3] * a3c;
y0c += chi[4] * a4c;
*y0 = y0c;
a0 += 1;
a1 += 1;
a2 += 1;
a3 += 1;
a4 += 1;
av[0] += 1;
av[1] += 1;
av[2] += 1;
av[3] += 1;
av[4] += 1;
y0 += 1;
}
}
@@ -531,25 +641,25 @@ void bli_daxpyf_zen_int_5
{
double y0c = *y0;
const double a0c = *a0;
const double a1c = *a1;
const double a2c = *a2;
const double a3c = *a3;
const double a4c = *a4;
const double a0c = *av[0];
const double a1c = *av[1];
const double a2c = *av[2];
const double a3c = *av[3];
const double a4c = *av[4];
y0c += chi0 * a0c;
y0c += chi1 * a1c;
y0c += chi2 * a2c;
y0c += chi3 * a3c;
y0c += chi4 * a4c;
y0c += chi[0] * a0c;
y0c += chi[1] * a1c;
y0c += chi[2] * a2c;
y0c += chi[3] * a3c;
y0c += chi[4] * a4c;
*y0 = y0c;
a0 += inca;
a1 += inca;
a2 += inca;
a3 += inca;
a4 += inca;
av[0] += inca;
av[1] += inca;
av[2] += inca;
av[3] += inca;
av[4] += inca;
y0 += incy;
}
@@ -1153,7 +1263,7 @@ void bli_daxpyf_zen_int_16x4
a2 += n_elem_per_reg;
a3 += n_elem_per_reg;
}
#if 1
for ( ; (i + 1) < m; i += 2)
{
@@ -1186,7 +1296,7 @@ void bli_daxpyf_zen_int_16x4
a2 += 2;
a3 += 2;
}
#endif
// If there are leftover iterations, perform them with scalar code.
for ( ; (i + 0) < m ; ++i )
{

@@ -5,7 +5,7 @@
libraries.
Copyright (C) 2018, The University of Texas at Austin
Copyright (C) 2016 - 2018, Advanced Micro Devices, Inc.
Copyright (C) 2016 - 2022, Advanced Micro Devices, Inc.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -279,32 +279,19 @@ void bli_daxpyf_zen_int_8
const dim_t fuse_fac = 8;
const dim_t n_elem_per_reg = 4;
const dim_t n_iter_unroll = 1;
const dim_t n_iter_unroll[4] = {4, 3, 2, 1};
dim_t i;
dim_t m_viter;
dim_t m_left;
dim_t m_viter[4];
dim_t m_left = m;
double* restrict a0;
double* restrict a1;
double* restrict a2;
double* restrict a3;
double* restrict a4;
double* restrict a5;
double* restrict a6;
double* restrict a7;
double* restrict av[8] __attribute__((aligned(64)));
double* restrict y0;
v4df_t chi0v, chi1v, chi2v, chi3v;
v4df_t chi4v, chi5v, chi6v, chi7v;
v4df_t chiv[8], a_vec[32], yv[4];
v4df_t a0v, a1v, a2v, a3v;
v4df_t a4v, a5v, a6v, a7v;
v4df_t y0v;
double chi0, chi1, chi2, chi3;
double chi4, chi5, chi6, chi7;
double chi[8] __attribute__((aligned(64)));
// If either dimension is zero, or if alpha is zero, return early.
if ( bli_zero_dim2( m, b_n ) || PASTEMAC(d,eq0)( *alpha ) ) return;
@@ -343,94 +330,343 @@ void bli_daxpyf_zen_int_8
// Use the unrolling factor and the number of elements per register
// to compute the number of vectorized and leftover iterations.
m_viter = ( m ) / ( n_elem_per_reg * n_iter_unroll );
m_left = ( m ) % ( n_elem_per_reg * n_iter_unroll );
m_viter[0] = ( m_left ) / ( n_elem_per_reg * n_iter_unroll[0] );
m_left = ( m_left ) % ( n_elem_per_reg * n_iter_unroll[0] );
m_viter[1] = ( m_left ) / ( n_elem_per_reg * n_iter_unroll[1] );
m_left = ( m_left ) % ( n_elem_per_reg * n_iter_unroll[1] );
m_viter[2] = ( m_left ) / ( n_elem_per_reg * n_iter_unroll[2] );
m_left = ( m_left ) % ( n_elem_per_reg * n_iter_unroll[2] );
m_viter[3] = ( m_left ) / ( n_elem_per_reg * n_iter_unroll[3] );
m_left = ( m_left ) % ( n_elem_per_reg * n_iter_unroll[3] );
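// A hedged illustration of this split (the m value below is hypothetical,
// not part of this change): for m = 30 with n_elem_per_reg = 4,
//   m_viter[0] = 30 / 16 = 1, leaving 14
//   m_viter[1] = 14 / 12 = 1, leaving  2
//   m_viter[2] =  2 /  8 = 0
//   m_viter[3] =  2 /  4 = 0, m_left = 2
// so 30 rows are handled by one 16-element pass, one 12-element pass, and
// two scalar iterations.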
// If there is anything that would interfere with our use of contiguous
// vector loads/stores, override m_viter and m_left to use scalar code
// for all iterations.
if ( inca != 1 || incy != 1 )
{
m_viter = 0;
m_viter[0] = m_viter[1] = m_viter[2] = m_viter[3] = 0;
m_left = m;
}
a0 = a + 0*lda;
a1 = a + 1*lda;
a2 = a + 2*lda;
a3 = a + 3*lda;
a4 = a + 4*lda;
a5 = a + 5*lda;
a6 = a + 6*lda;
a7 = a + 7*lda;
// av points to the 8 columns under consideration
av[0] = a + 0*lda;
av[1] = a + 1*lda;
av[2] = a + 2*lda;
av[3] = a + 3*lda;
av[4] = a + 4*lda;
av[5] = a + 5*lda;
av[6] = a + 6*lda;
av[7] = a + 7*lda;
y0 = y;
chi0 = *( x + 0*incx );
chi1 = *( x + 1*incx );
chi2 = *( x + 2*incx );
chi3 = *( x + 3*incx );
chi4 = *( x + 4*incx );
chi5 = *( x + 5*incx );
chi6 = *( x + 6*incx );
chi7 = *( x + 7*incx );
chi[0] = *( x + 0*incx );
chi[1] = *( x + 1*incx );
chi[2] = *( x + 2*incx );
chi[3] = *( x + 3*incx );
chi[4] = *( x + 4*incx );
chi[5] = *( x + 5*incx );
chi[6] = *( x + 6*incx );
chi[7] = *( x + 7*incx );
// Scale each chi scalar by alpha.
PASTEMAC(d,scals)( *alpha, chi0 );
PASTEMAC(d,scals)( *alpha, chi1 );
PASTEMAC(d,scals)( *alpha, chi2 );
PASTEMAC(d,scals)( *alpha, chi3 );
PASTEMAC(d,scals)( *alpha, chi4 );
PASTEMAC(d,scals)( *alpha, chi5 );
PASTEMAC(d,scals)( *alpha, chi6 );
PASTEMAC(d,scals)( *alpha, chi7 );
PASTEMAC(d,scals)( *alpha, chi[0] );
PASTEMAC(d,scals)( *alpha, chi[1] );
PASTEMAC(d,scals)( *alpha, chi[2] );
PASTEMAC(d,scals)( *alpha, chi[3] );
PASTEMAC(d,scals)( *alpha, chi[4] );
PASTEMAC(d,scals)( *alpha, chi[5] );
PASTEMAC(d,scals)( *alpha, chi[6] );
PASTEMAC(d,scals)( *alpha, chi[7] );
// Broadcast the (alpha*chi?) scalars to all elements of vector registers.
chi0v.v = _mm256_broadcast_sd( &chi0 );
chi1v.v = _mm256_broadcast_sd( &chi1 );
chi2v.v = _mm256_broadcast_sd( &chi2 );
chi3v.v = _mm256_broadcast_sd( &chi3 );
chi4v.v = _mm256_broadcast_sd( &chi4 );
chi5v.v = _mm256_broadcast_sd( &chi5 );
chi6v.v = _mm256_broadcast_sd( &chi6 );
chi7v.v = _mm256_broadcast_sd( &chi7 );
chiv[0].v = _mm256_broadcast_sd( &chi[0] );
chiv[1].v = _mm256_broadcast_sd( &chi[1] );
chiv[2].v = _mm256_broadcast_sd( &chi[2] );
chiv[3].v = _mm256_broadcast_sd( &chi[3] );
chiv[4].v = _mm256_broadcast_sd( &chi[4] );
chiv[5].v = _mm256_broadcast_sd( &chi[5] );
chiv[6].v = _mm256_broadcast_sd( &chi[6] );
chiv[7].v = _mm256_broadcast_sd( &chi[7] );
// If there are vectorized iterations, perform them with vector
// instructions.
for ( i = 0; i < m_viter; ++i )
// 16 elements of the result are computed per iteration
for ( i = 0; i < m_viter[0]; ++i )
{
// Load the input values.
y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg );
a0v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg );
a1v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg );
a2v.v = _mm256_loadu_pd( a2 + 0*n_elem_per_reg );
a3v.v = _mm256_loadu_pd( a3 + 0*n_elem_per_reg );
a4v.v = _mm256_loadu_pd( a4 + 0*n_elem_per_reg );
a5v.v = _mm256_loadu_pd( a5 + 0*n_elem_per_reg );
a6v.v = _mm256_loadu_pd( a6 + 0*n_elem_per_reg );
a7v.v = _mm256_loadu_pd( a7 + 0*n_elem_per_reg );
yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg );
yv[1].v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg );
yv[2].v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg );
yv[3].v = _mm256_loadu_pd( y0 + 3*n_elem_per_reg );
a_vec[0].v = _mm256_loadu_pd( av[0] + 0*n_elem_per_reg );
a_vec[1].v = _mm256_loadu_pd( av[1] + 0*n_elem_per_reg );
a_vec[2].v = _mm256_loadu_pd( av[2] + 0*n_elem_per_reg );
a_vec[3].v = _mm256_loadu_pd( av[3] + 0*n_elem_per_reg );
a_vec[4].v = _mm256_loadu_pd( av[4] + 0*n_elem_per_reg );
a_vec[5].v = _mm256_loadu_pd( av[5] + 0*n_elem_per_reg );
a_vec[6].v = _mm256_loadu_pd( av[6] + 0*n_elem_per_reg );
a_vec[7].v = _mm256_loadu_pd( av[7] + 0*n_elem_per_reg );
a_vec[8].v = _mm256_loadu_pd( av[0] + 1*n_elem_per_reg );
a_vec[9].v = _mm256_loadu_pd( av[1] + 1*n_elem_per_reg );
a_vec[10].v = _mm256_loadu_pd( av[2] + 1*n_elem_per_reg );
a_vec[11].v = _mm256_loadu_pd( av[3] + 1*n_elem_per_reg );
a_vec[12].v = _mm256_loadu_pd( av[4] + 1*n_elem_per_reg );
a_vec[13].v = _mm256_loadu_pd( av[5] + 1*n_elem_per_reg );
a_vec[14].v = _mm256_loadu_pd( av[6] + 1*n_elem_per_reg );
a_vec[15].v = _mm256_loadu_pd( av[7] + 1*n_elem_per_reg );
a_vec[16].v = _mm256_loadu_pd( av[0] + 2*n_elem_per_reg );
a_vec[17].v = _mm256_loadu_pd( av[1] + 2*n_elem_per_reg );
a_vec[18].v = _mm256_loadu_pd( av[2] + 2*n_elem_per_reg );
a_vec[19].v = _mm256_loadu_pd( av[3] + 2*n_elem_per_reg );
a_vec[20].v = _mm256_loadu_pd( av[4] + 2*n_elem_per_reg );
a_vec[21].v = _mm256_loadu_pd( av[5] + 2*n_elem_per_reg );
a_vec[22].v = _mm256_loadu_pd( av[6] + 2*n_elem_per_reg );
a_vec[23].v = _mm256_loadu_pd( av[7] + 2*n_elem_per_reg );
a_vec[24].v = _mm256_loadu_pd( av[0] + 3*n_elem_per_reg );
a_vec[25].v = _mm256_loadu_pd( av[1] + 3*n_elem_per_reg );
a_vec[26].v = _mm256_loadu_pd( av[2] + 3*n_elem_per_reg );
a_vec[27].v = _mm256_loadu_pd( av[3] + 3*n_elem_per_reg );
a_vec[28].v = _mm256_loadu_pd( av[4] + 3*n_elem_per_reg );
a_vec[29].v = _mm256_loadu_pd( av[5] + 3*n_elem_per_reg );
a_vec[30].v = _mm256_loadu_pd( av[6] + 3*n_elem_per_reg );
a_vec[31].v = _mm256_loadu_pd( av[7] + 3*n_elem_per_reg );
// perform : y += alpha * x;
y0v.v = _mm256_fmadd_pd( a0v.v, chi0v.v, y0v.v );
y0v.v = _mm256_fmadd_pd( a1v.v, chi1v.v, y0v.v );
y0v.v = _mm256_fmadd_pd( a2v.v, chi2v.v, y0v.v );
y0v.v = _mm256_fmadd_pd( a3v.v, chi3v.v, y0v.v );
y0v.v = _mm256_fmadd_pd( a4v.v, chi4v.v, y0v.v );
y0v.v = _mm256_fmadd_pd( a5v.v, chi5v.v, y0v.v );
y0v.v = _mm256_fmadd_pd( a6v.v, chi6v.v, y0v.v );
y0v.v = _mm256_fmadd_pd( a7v.v, chi7v.v, y0v.v );
yv[0].v = _mm256_fmadd_pd( a_vec[0].v, chiv[0].v, yv[0].v );
yv[0].v = _mm256_fmadd_pd( a_vec[1].v, chiv[1].v, yv[0].v );
yv[0].v = _mm256_fmadd_pd( a_vec[2].v, chiv[2].v, yv[0].v );
yv[0].v = _mm256_fmadd_pd( a_vec[3].v, chiv[3].v, yv[0].v );
yv[0].v = _mm256_fmadd_pd( a_vec[4].v, chiv[4].v, yv[0].v );
yv[0].v = _mm256_fmadd_pd( a_vec[5].v, chiv[5].v, yv[0].v );
yv[0].v = _mm256_fmadd_pd( a_vec[6].v, chiv[6].v, yv[0].v );
yv[0].v = _mm256_fmadd_pd( a_vec[7].v, chiv[7].v, yv[0].v );
yv[1].v = _mm256_fmadd_pd( a_vec[8].v, chiv[0].v, yv[1].v );
yv[1].v = _mm256_fmadd_pd( a_vec[9].v, chiv[1].v, yv[1].v );
yv[1].v = _mm256_fmadd_pd( a_vec[10].v, chiv[2].v, yv[1].v );
yv[1].v = _mm256_fmadd_pd( a_vec[11].v, chiv[3].v, yv[1].v );
yv[1].v = _mm256_fmadd_pd( a_vec[12].v, chiv[4].v, yv[1].v );
yv[1].v = _mm256_fmadd_pd( a_vec[13].v, chiv[5].v, yv[1].v );
yv[1].v = _mm256_fmadd_pd( a_vec[14].v, chiv[6].v, yv[1].v );
yv[1].v = _mm256_fmadd_pd( a_vec[15].v, chiv[7].v, yv[1].v );
yv[2].v = _mm256_fmadd_pd( a_vec[16].v, chiv[0].v, yv[2].v );
yv[2].v = _mm256_fmadd_pd( a_vec[17].v, chiv[1].v, yv[2].v );
yv[2].v = _mm256_fmadd_pd( a_vec[18].v, chiv[2].v, yv[2].v );
yv[2].v = _mm256_fmadd_pd( a_vec[19].v, chiv[3].v, yv[2].v );
yv[2].v = _mm256_fmadd_pd( a_vec[20].v, chiv[4].v, yv[2].v );
yv[2].v = _mm256_fmadd_pd( a_vec[21].v, chiv[5].v, yv[2].v );
yv[2].v = _mm256_fmadd_pd( a_vec[22].v, chiv[6].v, yv[2].v );
yv[2].v = _mm256_fmadd_pd( a_vec[23].v, chiv[7].v, yv[2].v );
yv[3].v = _mm256_fmadd_pd( a_vec[24].v, chiv[0].v, yv[3].v );
yv[3].v = _mm256_fmadd_pd( a_vec[25].v, chiv[1].v, yv[3].v );
yv[3].v = _mm256_fmadd_pd( a_vec[26].v, chiv[2].v, yv[3].v );
yv[3].v = _mm256_fmadd_pd( a_vec[27].v, chiv[3].v, yv[3].v );
yv[3].v = _mm256_fmadd_pd( a_vec[28].v, chiv[4].v, yv[3].v );
yv[3].v = _mm256_fmadd_pd( a_vec[29].v, chiv[5].v, yv[3].v );
yv[3].v = _mm256_fmadd_pd( a_vec[30].v, chiv[6].v, yv[3].v );
yv[3].v = _mm256_fmadd_pd( a_vec[31].v, chiv[7].v, yv[3].v );
// Store the output.
_mm256_storeu_pd( (y0 + 0*n_elem_per_reg), y0v.v );
_mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0].v );
_mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1].v );
_mm256_storeu_pd( (y0 + 2*n_elem_per_reg), yv[2].v );
_mm256_storeu_pd( (y0 + 3*n_elem_per_reg), yv[3].v );
y0 += n_elem_per_reg * n_iter_unroll[0];
av[0] += n_elem_per_reg * n_iter_unroll[0];
av[1] += n_elem_per_reg * n_iter_unroll[0];
av[2] += n_elem_per_reg * n_iter_unroll[0];
av[3] += n_elem_per_reg * n_iter_unroll[0];
av[4] += n_elem_per_reg * n_iter_unroll[0];
av[5] += n_elem_per_reg * n_iter_unroll[0];
av[6] += n_elem_per_reg * n_iter_unroll[0];
av[7] += n_elem_per_reg * n_iter_unroll[0];
}
// 12 elements of the result are computed per iteration
for ( i = 0; i < m_viter[1]; ++i )
{
// Load the input values.
yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg );
yv[1].v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg );
yv[2].v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg );
a_vec[0].v = _mm256_loadu_pd( av[0] + 0*n_elem_per_reg );
a_vec[1].v = _mm256_loadu_pd( av[1] + 0*n_elem_per_reg );
a_vec[2].v = _mm256_loadu_pd( av[2] + 0*n_elem_per_reg );
a_vec[3].v = _mm256_loadu_pd( av[3] + 0*n_elem_per_reg );
a_vec[4].v = _mm256_loadu_pd( av[4] + 0*n_elem_per_reg );
a_vec[5].v = _mm256_loadu_pd( av[5] + 0*n_elem_per_reg );
a_vec[6].v = _mm256_loadu_pd( av[6] + 0*n_elem_per_reg );
a_vec[7].v = _mm256_loadu_pd( av[7] + 0*n_elem_per_reg );
a_vec[8].v = _mm256_loadu_pd( av[0] + 1*n_elem_per_reg );
a_vec[9].v = _mm256_loadu_pd( av[1] + 1*n_elem_per_reg );
a_vec[10].v = _mm256_loadu_pd( av[2] + 1*n_elem_per_reg );
a_vec[11].v = _mm256_loadu_pd( av[3] + 1*n_elem_per_reg );
a_vec[12].v = _mm256_loadu_pd( av[4] + 1*n_elem_per_reg );
a_vec[13].v = _mm256_loadu_pd( av[5] + 1*n_elem_per_reg );
a_vec[14].v = _mm256_loadu_pd( av[6] + 1*n_elem_per_reg );
a_vec[15].v = _mm256_loadu_pd( av[7] + 1*n_elem_per_reg );
a_vec[16].v = _mm256_loadu_pd( av[0] + 2*n_elem_per_reg );
a_vec[17].v = _mm256_loadu_pd( av[1] + 2*n_elem_per_reg );
a_vec[18].v = _mm256_loadu_pd( av[2] + 2*n_elem_per_reg );
a_vec[19].v = _mm256_loadu_pd( av[3] + 2*n_elem_per_reg );
a_vec[20].v = _mm256_loadu_pd( av[4] + 2*n_elem_per_reg );
a_vec[21].v = _mm256_loadu_pd( av[5] + 2*n_elem_per_reg );
a_vec[22].v = _mm256_loadu_pd( av[6] + 2*n_elem_per_reg );
a_vec[23].v = _mm256_loadu_pd( av[7] + 2*n_elem_per_reg );
// perform : y += alpha * x;
yv[0].v = _mm256_fmadd_pd( a_vec[0].v, chiv[0].v, yv[0].v );
yv[0].v = _mm256_fmadd_pd( a_vec[1].v, chiv[1].v, yv[0].v );
yv[0].v = _mm256_fmadd_pd( a_vec[2].v, chiv[2].v, yv[0].v );
yv[0].v = _mm256_fmadd_pd( a_vec[3].v, chiv[3].v, yv[0].v );
yv[0].v = _mm256_fmadd_pd( a_vec[4].v, chiv[4].v, yv[0].v );
yv[0].v = _mm256_fmadd_pd( a_vec[5].v, chiv[5].v, yv[0].v );
yv[0].v = _mm256_fmadd_pd( a_vec[6].v, chiv[6].v, yv[0].v );
yv[0].v = _mm256_fmadd_pd( a_vec[7].v, chiv[7].v, yv[0].v );
yv[1].v = _mm256_fmadd_pd( a_vec[8].v, chiv[0].v, yv[1].v );
yv[1].v = _mm256_fmadd_pd( a_vec[9].v, chiv[1].v, yv[1].v );
yv[1].v = _mm256_fmadd_pd( a_vec[10].v, chiv[2].v, yv[1].v );
yv[1].v = _mm256_fmadd_pd( a_vec[11].v, chiv[3].v, yv[1].v );
yv[1].v = _mm256_fmadd_pd( a_vec[12].v, chiv[4].v, yv[1].v );
yv[1].v = _mm256_fmadd_pd( a_vec[13].v, chiv[5].v, yv[1].v );
yv[1].v = _mm256_fmadd_pd( a_vec[14].v, chiv[6].v, yv[1].v );
yv[1].v = _mm256_fmadd_pd( a_vec[15].v, chiv[7].v, yv[1].v );
yv[2].v = _mm256_fmadd_pd( a_vec[16].v, chiv[0].v, yv[2].v );
yv[2].v = _mm256_fmadd_pd( a_vec[17].v, chiv[1].v, yv[2].v );
yv[2].v = _mm256_fmadd_pd( a_vec[18].v, chiv[2].v, yv[2].v );
yv[2].v = _mm256_fmadd_pd( a_vec[19].v, chiv[3].v, yv[2].v );
yv[2].v = _mm256_fmadd_pd( a_vec[20].v, chiv[4].v, yv[2].v );
yv[2].v = _mm256_fmadd_pd( a_vec[21].v, chiv[5].v, yv[2].v );
yv[2].v = _mm256_fmadd_pd( a_vec[22].v, chiv[6].v, yv[2].v );
yv[2].v = _mm256_fmadd_pd( a_vec[23].v, chiv[7].v, yv[2].v );
// Store the output.
_mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0].v );
_mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1].v );
_mm256_storeu_pd( (y0 + 2*n_elem_per_reg), yv[2].v );
y0 += n_elem_per_reg * n_iter_unroll[1];
av[0] += n_elem_per_reg * n_iter_unroll[1];
av[1] += n_elem_per_reg * n_iter_unroll[1];
av[2] += n_elem_per_reg * n_iter_unroll[1];
av[3] += n_elem_per_reg * n_iter_unroll[1];
av[4] += n_elem_per_reg * n_iter_unroll[1];
av[5] += n_elem_per_reg * n_iter_unroll[1];
av[6] += n_elem_per_reg * n_iter_unroll[1];
av[7] += n_elem_per_reg * n_iter_unroll[1];
}
// 8 elements of the result are computed per iteration
for ( i = 0; i < m_viter[2]; ++i )
{
// Load the input values.
yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg );
yv[1].v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg );
a_vec[0].v = _mm256_loadu_pd( av[0] + 0*n_elem_per_reg );
a_vec[1].v = _mm256_loadu_pd( av[1] + 0*n_elem_per_reg );
a_vec[2].v = _mm256_loadu_pd( av[2] + 0*n_elem_per_reg );
a_vec[3].v = _mm256_loadu_pd( av[3] + 0*n_elem_per_reg );
a_vec[4].v = _mm256_loadu_pd( av[4] + 0*n_elem_per_reg );
a_vec[5].v = _mm256_loadu_pd( av[5] + 0*n_elem_per_reg );
a_vec[6].v = _mm256_loadu_pd( av[6] + 0*n_elem_per_reg );
a_vec[7].v = _mm256_loadu_pd( av[7] + 0*n_elem_per_reg );
a_vec[8].v = _mm256_loadu_pd( av[0] + 1*n_elem_per_reg );
a_vec[9].v = _mm256_loadu_pd( av[1] + 1*n_elem_per_reg );
a_vec[10].v = _mm256_loadu_pd( av[2] + 1*n_elem_per_reg );
a_vec[11].v = _mm256_loadu_pd( av[3] + 1*n_elem_per_reg );
a_vec[12].v = _mm256_loadu_pd( av[4] + 1*n_elem_per_reg );
a_vec[13].v = _mm256_loadu_pd( av[5] + 1*n_elem_per_reg );
a_vec[14].v = _mm256_loadu_pd( av[6] + 1*n_elem_per_reg );
a_vec[15].v = _mm256_loadu_pd( av[7] + 1*n_elem_per_reg );
// perform : y += alpha * x;
yv[0].v = _mm256_fmadd_pd( a_vec[0].v, chiv[0].v, yv[0].v );
yv[0].v = _mm256_fmadd_pd( a_vec[1].v, chiv[1].v, yv[0].v );
yv[0].v = _mm256_fmadd_pd( a_vec[2].v, chiv[2].v, yv[0].v );
yv[0].v = _mm256_fmadd_pd( a_vec[3].v, chiv[3].v, yv[0].v );
yv[0].v = _mm256_fmadd_pd( a_vec[4].v, chiv[4].v, yv[0].v );
yv[0].v = _mm256_fmadd_pd( a_vec[5].v, chiv[5].v, yv[0].v );
yv[0].v = _mm256_fmadd_pd( a_vec[6].v, chiv[6].v, yv[0].v );
yv[0].v = _mm256_fmadd_pd( a_vec[7].v, chiv[7].v, yv[0].v );
yv[1].v = _mm256_fmadd_pd( a_vec[8].v, chiv[0].v, yv[1].v );
yv[1].v = _mm256_fmadd_pd( a_vec[9].v, chiv[1].v, yv[1].v );
yv[1].v = _mm256_fmadd_pd( a_vec[10].v, chiv[2].v, yv[1].v );
yv[1].v = _mm256_fmadd_pd( a_vec[11].v, chiv[3].v, yv[1].v );
yv[1].v = _mm256_fmadd_pd( a_vec[12].v, chiv[4].v, yv[1].v );
yv[1].v = _mm256_fmadd_pd( a_vec[13].v, chiv[5].v, yv[1].v );
yv[1].v = _mm256_fmadd_pd( a_vec[14].v, chiv[6].v, yv[1].v );
yv[1].v = _mm256_fmadd_pd( a_vec[15].v, chiv[7].v, yv[1].v );
// Store the output.
_mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0].v );
_mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1].v );
y0 += n_elem_per_reg * n_iter_unroll[2];
av[0] += n_elem_per_reg * n_iter_unroll[2];
av[1] += n_elem_per_reg * n_iter_unroll[2];
av[2] += n_elem_per_reg * n_iter_unroll[2];
av[3] += n_elem_per_reg * n_iter_unroll[2];
av[4] += n_elem_per_reg * n_iter_unroll[2];
av[5] += n_elem_per_reg * n_iter_unroll[2];
av[6] += n_elem_per_reg * n_iter_unroll[2];
av[7] += n_elem_per_reg * n_iter_unroll[2];
}
// 4 elements of the result are computed per iteration
for ( i = 0; i < m_viter[3]; ++i )
{
// Load the input values.
yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg );
a_vec[0].v = _mm256_loadu_pd( av[0] + 0*n_elem_per_reg );
a_vec[1].v = _mm256_loadu_pd( av[1] + 0*n_elem_per_reg );
a_vec[2].v = _mm256_loadu_pd( av[2] + 0*n_elem_per_reg );
a_vec[3].v = _mm256_loadu_pd( av[3] + 0*n_elem_per_reg );
a_vec[4].v = _mm256_loadu_pd( av[4] + 0*n_elem_per_reg );
a_vec[5].v = _mm256_loadu_pd( av[5] + 0*n_elem_per_reg );
a_vec[6].v = _mm256_loadu_pd( av[6] + 0*n_elem_per_reg );
a_vec[7].v = _mm256_loadu_pd( av[7] + 0*n_elem_per_reg );
// perform : y += alpha * x;
yv[0].v = _mm256_fmadd_pd( a_vec[0].v, chiv[0].v, yv[0].v );
yv[0].v = _mm256_fmadd_pd( a_vec[1].v, chiv[1].v, yv[0].v );
yv[0].v = _mm256_fmadd_pd( a_vec[2].v, chiv[2].v, yv[0].v );
yv[0].v = _mm256_fmadd_pd( a_vec[3].v, chiv[3].v, yv[0].v );
yv[0].v = _mm256_fmadd_pd( a_vec[4].v, chiv[4].v, yv[0].v );
yv[0].v = _mm256_fmadd_pd( a_vec[5].v, chiv[5].v, yv[0].v );
yv[0].v = _mm256_fmadd_pd( a_vec[6].v, chiv[6].v, yv[0].v );
yv[0].v = _mm256_fmadd_pd( a_vec[7].v, chiv[7].v, yv[0].v );
// Store the output.
_mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0].v );
y0 += n_elem_per_reg;
a0 += n_elem_per_reg;
a1 += n_elem_per_reg;
a2 += n_elem_per_reg;
a3 += n_elem_per_reg;
a4 += n_elem_per_reg;
a5 += n_elem_per_reg;
a6 += n_elem_per_reg;
a7 += n_elem_per_reg;
av[0] += n_elem_per_reg;
av[1] += n_elem_per_reg;
av[2] += n_elem_per_reg;
av[3] += n_elem_per_reg;
av[4] += n_elem_per_reg;
av[5] += n_elem_per_reg;
av[6] += n_elem_per_reg;
av[7] += n_elem_per_reg;
}
// If there are leftover iterations, perform them with scalar code.
@@ -438,34 +674,34 @@ void bli_daxpyf_zen_int_8
{
double y0c = *y0;
const double a0c = *a0;
const double a1c = *a1;
const double a2c = *a2;
const double a3c = *a3;
const double a4c = *a4;
const double a5c = *a5;
const double a6c = *a6;
const double a7c = *a7;
const double a0c = *av[0];
const double a1c = *av[1];
const double a2c = *av[2];
const double a3c = *av[3];
const double a4c = *av[4];
const double a5c = *av[5];
const double a6c = *av[6];
const double a7c = *av[7];
y0c += chi0 * a0c;
y0c += chi1 * a1c;
y0c += chi2 * a2c;
y0c += chi3 * a3c;
y0c += chi4 * a4c;
y0c += chi5 * a5c;
y0c += chi6 * a6c;
y0c += chi7 * a7c;
y0c += chi[0] * a0c;
y0c += chi[1] * a1c;
y0c += chi[2] * a2c;
y0c += chi[3] * a3c;
y0c += chi[4] * a4c;
y0c += chi[5] * a5c;
y0c += chi[6] * a6c;
y0c += chi[7] * a7c;
*y0 = y0c;
a0 += inca;
a1 += inca;
a2 += inca;
a3 += inca;
a4 += inca;
a5 += inca;
a6 += inca;
a7 += inca;
av[0] += inca;
av[1] += inca;
av[2] += inca;
av[3] += inca;
av[4] += inca;
av[5] += inca;
av[6] += inca;
av[7] += inca;
y0 += incy;
}
}