diff --git a/frame/2/gemv/bli_gemv_unf_var2_amd.c b/frame/2/gemv/bli_gemv_unf_var2_amd.c index d7f5145e3..831d906ca 100644 --- a/frame/2/gemv/bli_gemv_unf_var2_amd.c +++ b/frame/2/gemv/bli_gemv_unf_var2_amd.c @@ -313,27 +313,87 @@ void bli_dgemv_unf_var2 } } - for ( i = 0; i < n_iter; i += f ) + dim_t fuse_factor = 8; + dim_t f_temp = 0; + + // Change the fuse factor based on + // Input size and available kernels + // This ensures that fusing is possible when the number of + // left over colums is less (better problem decomposition) + if (n < 5) fuse_factor = 4; + else if (n < 8) fuse_factor = 5; + + for (i = 0; i < n_iter; i += f) { - f = bli_determine_blocksize_dim_f( i, n_iter, BLIS_DGEMV_VAR2_FUSE ); + f = bli_determine_blocksize_dim_f(i, n_iter, fuse_factor); - A1 = a + (0 )*rs_at + (i )*cs_at; - x1 = x + (i )*incx; + A1 = a + (i)*cs_at; + x1 = x + (i)*incx; - /* y = y + alpha * A1 * x1; */ - bli_daxpyf_zen_int_16x4 - ( - conja, - conjx, - n_elem, - f, - alpha, - A1, rs_at, cs_at, - x1, incx, - y_buf, buf_incy, - cntx - ); + // Pick kernel based on problem size + switch (f) + { + case 8: + + bli_daxpyf_zen_int_8( + conja, + conjx, + n_elem, + f, + alpha, + A1, rs_at, cs_at, + x1, incx, + y_buf, buf_incy, + cntx); + + break; + default: + + if (f < 5) + { + bli_daxpyf_zen_int_16x4( + conja, + conjx, + n_elem, + f, + alpha, + A1, rs_at, cs_at, + x1, incx, + y_buf, buf_incy, + cntx); + } + else + { + bli_daxpyf_zen_int_5( + conja, + conjx, + n_elem, + f, + alpha, + A1, rs_at, cs_at, + x1, incx, + y_buf, buf_incy, + cntx); + } + } + + // Calculate the next problem size + f_temp = bli_determine_blocksize_dim_f(i + f, n_iter, fuse_factor); + + // Change fuse factor based on the next problem size + if (f_temp < fuse_factor) + { + if (f_temp < 5) + { + fuse_factor = 4; + } + else + { + fuse_factor = 5; + } + } } + if ((incy > 1) && bli_mem_is_alloc( &mem_bufY )) { //store the result from unit strided y_buf to non-unit strided Y diff --git a/kernels/zen/1f/bli_axpyf_zen_int_5.c b/kernels/zen/1f/bli_axpyf_zen_int_5.c index d09a85f57..8b1f697ce 100644 --- a/kernels/zen/1f/bli_axpyf_zen_int_5.c +++ b/kernels/zen/1f/bli_axpyf_zen_int_5.c @@ -329,27 +329,13 @@ void bli_daxpyf_zen_int_5 dim_t i; - double* restrict a0; - double* restrict a1; - double* restrict a2; - double* restrict a3; - double* restrict a4; + double* restrict av[5] __attribute__((aligned(64))); double* restrict y0; - v4df_t chi0v, chi1v, chi2v, chi3v; - v4df_t chi4v; - - v4df_t a00v, a01v, a02v, a03v; - v4df_t a04v; - - v4df_t a10v, a11v, a12v, a13v; - v4df_t a14v; - - v4df_t y0v, y1v; - - double chi0, chi1, chi2, chi3; - double chi4; + v4df_t chiv[5], a_vec[20], yv[4]; + + double chi[5]; // If either dimension is zero, or if alpha is zero, return early. if ( bli_zero_dim2( m, b_n ) || bli_deq0( *alpha ) ) return; @@ -385,117 +371,241 @@ void bli_daxpyf_zen_int_5 } // At this point, we know that b_n is exactly equal to the fusing factor. - - a0 = a + 0*lda; - a1 = a + 1*lda; - a2 = a + 2*lda; - a3 = a + 3*lda; - a4 = a + 4*lda; + // av points to the 5 columns under consideration + av[0] = a + 0*lda; + av[1] = a + 1*lda; + av[2] = a + 2*lda; + av[3] = a + 3*lda; + av[4] = a + 4*lda; y0 = y; - chi0 = *( x + 0*incx ); - chi1 = *( x + 1*incx ); - chi2 = *( x + 2*incx ); - chi3 = *( x + 3*incx ); - chi4 = *( x + 4*incx ); + chi[0] = *( x + 0*incx ); + chi[1] = *( x + 1*incx ); + chi[2] = *( x + 2*incx ); + chi[3] = *( x + 3*incx ); + chi[4] = *( x + 4*incx ); // Scale each chi scalar by alpha. 
- bli_dscals( *alpha, chi0 ); - bli_dscals( *alpha, chi1 ); - bli_dscals( *alpha, chi2 ); - bli_dscals( *alpha, chi3 ); - bli_dscals( *alpha, chi4 ); + bli_dscals( *alpha, chi[0] ); + bli_dscals( *alpha, chi[1] ); + bli_dscals( *alpha, chi[2] ); + bli_dscals( *alpha, chi[3] ); + bli_dscals( *alpha, chi[4] ); // Broadcast the (alpha*chi?) scalars to all elements of vector registers. - chi0v.v = _mm256_broadcast_sd( &chi0 ); - chi1v.v = _mm256_broadcast_sd( &chi1 ); - chi2v.v = _mm256_broadcast_sd( &chi2 ); - chi3v.v = _mm256_broadcast_sd( &chi3 ); - chi4v.v = _mm256_broadcast_sd( &chi4 ); + chiv[0].v = _mm256_broadcast_sd( &chi[0] ); + chiv[1].v = _mm256_broadcast_sd( &chi[1] ); + chiv[2].v = _mm256_broadcast_sd( &chi[2] ); + chiv[3].v = _mm256_broadcast_sd( &chi[3] ); + chiv[4].v = _mm256_broadcast_sd( &chi[4] ); // If there are vectorized iterations, perform them with vector // instructions. if ( inca == 1 && incy == 1 ) { - for ( i = 0; (i + 7) < m; i += 8 ) + // 16 elements of the result are computed per iteration + for ( i = 0; (i + 15) < m; i += 16 ) { // Load the input values. - y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); - y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + yv[2].v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + yv[3].v = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); - a00v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg ); - a10v.v = _mm256_loadu_pd( a0 + 1*n_elem_per_reg ); + a_vec[0].v = _mm256_loadu_pd( av[0] + 0*n_elem_per_reg ); + a_vec[1].v = _mm256_loadu_pd( av[1] + 0*n_elem_per_reg ); + a_vec[2].v = _mm256_loadu_pd( av[2] + 0*n_elem_per_reg ); + a_vec[3].v = _mm256_loadu_pd( av[3] + 0*n_elem_per_reg ); + a_vec[4].v = _mm256_loadu_pd( av[4] + 0*n_elem_per_reg ); - a01v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg ); - a11v.v = _mm256_loadu_pd( a1 + 1*n_elem_per_reg ); + a_vec[5].v = _mm256_loadu_pd( av[0] + 1*n_elem_per_reg ); + a_vec[6].v = _mm256_loadu_pd( av[1] + 1*n_elem_per_reg ); + a_vec[7].v = _mm256_loadu_pd( av[2] + 1*n_elem_per_reg ); + a_vec[8].v = _mm256_loadu_pd( av[3] + 1*n_elem_per_reg ); + a_vec[9].v = _mm256_loadu_pd( av[4] + 1*n_elem_per_reg ); - a02v.v = _mm256_loadu_pd( a2 + 0*n_elem_per_reg ); - a12v.v = _mm256_loadu_pd( a2 + 1*n_elem_per_reg ); + a_vec[10].v = _mm256_loadu_pd( av[0] + 2*n_elem_per_reg ); + a_vec[11].v = _mm256_loadu_pd( av[1] + 2*n_elem_per_reg ); + a_vec[12].v = _mm256_loadu_pd( av[2] + 2*n_elem_per_reg ); + a_vec[13].v = _mm256_loadu_pd( av[3] + 2*n_elem_per_reg ); + a_vec[14].v = _mm256_loadu_pd( av[4] + 2*n_elem_per_reg ); - a03v.v = _mm256_loadu_pd( a3 + 0*n_elem_per_reg ); - a13v.v = _mm256_loadu_pd( a3 + 1*n_elem_per_reg ); - - a04v.v = _mm256_loadu_pd( a4 + 0*n_elem_per_reg ); - a14v.v = _mm256_loadu_pd( a4 + 1*n_elem_per_reg ); + a_vec[15].v = _mm256_loadu_pd( av[0] + 3*n_elem_per_reg ); + a_vec[16].v = _mm256_loadu_pd( av[1] + 3*n_elem_per_reg ); + a_vec[17].v = _mm256_loadu_pd( av[2] + 3*n_elem_per_reg ); + a_vec[18].v = _mm256_loadu_pd( av[3] + 3*n_elem_per_reg ); + a_vec[19].v = _mm256_loadu_pd( av[4] + 3*n_elem_per_reg ); // perform : y += alpha * x; - y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v ); - y1v.v = _mm256_fmadd_pd( a10v.v, chi0v.v, y1v.v ); + yv[0].v = _mm256_fmadd_pd( a_vec[0].v, chiv[0].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[1].v, chiv[1].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[2].v, chiv[2].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[3].v, chiv[3].v, yv[0].v ); + 
yv[0].v = _mm256_fmadd_pd( a_vec[4].v, chiv[4].v, yv[0].v ); - y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v ); - y1v.v = _mm256_fmadd_pd( a11v.v, chi1v.v, y1v.v ); + yv[1].v = _mm256_fmadd_pd( a_vec[5].v, chiv[0].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[6].v, chiv[1].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[7].v, chiv[2].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[8].v, chiv[3].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[9].v, chiv[4].v, yv[1].v ); - y0v.v = _mm256_fmadd_pd( a02v.v, chi2v.v, y0v.v ); - y1v.v = _mm256_fmadd_pd( a12v.v, chi2v.v, y1v.v ); - - y0v.v = _mm256_fmadd_pd( a03v.v, chi3v.v, y0v.v ); - y1v.v = _mm256_fmadd_pd( a13v.v, chi3v.v, y1v.v ); - - y0v.v = _mm256_fmadd_pd( a04v.v, chi4v.v, y0v.v ); - y1v.v = _mm256_fmadd_pd( a14v.v, chi4v.v, y1v.v ); + yv[2].v = _mm256_fmadd_pd( a_vec[10].v, chiv[0].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[11].v, chiv[1].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[12].v, chiv[2].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[13].v, chiv[3].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[14].v, chiv[4].v, yv[2].v ); + yv[3].v = _mm256_fmadd_pd( a_vec[15].v, chiv[0].v, yv[3].v ); + yv[3].v = _mm256_fmadd_pd( a_vec[16].v, chiv[1].v, yv[3].v ); + yv[3].v = _mm256_fmadd_pd( a_vec[17].v, chiv[2].v, yv[3].v ); + yv[3].v = _mm256_fmadd_pd( a_vec[18].v, chiv[3].v, yv[3].v ); + yv[3].v = _mm256_fmadd_pd( a_vec[19].v, chiv[4].v, yv[3].v ); // Store the output. - _mm256_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y0v.v ); - _mm256_storeu_pd( (double *)(y0 + 1*n_elem_per_reg), y1v.v ); + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0].v ); + _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1].v ); + _mm256_storeu_pd( (y0 + 2*n_elem_per_reg), yv[2].v ); + _mm256_storeu_pd( (y0 + 3*n_elem_per_reg), yv[3].v ); - y0 += n_iter_unroll * n_elem_per_reg; - a0 += n_iter_unroll * n_elem_per_reg; - a1 += n_iter_unroll * n_elem_per_reg; - a2 += n_iter_unroll * n_elem_per_reg; - a3 += n_iter_unroll * n_elem_per_reg; - a4 += n_iter_unroll * n_elem_per_reg; + y0 += n_elem_per_reg * 4; + av[0] += n_elem_per_reg * 4; + av[1] += n_elem_per_reg * 4; + av[2] += n_elem_per_reg * 4; + av[3] += n_elem_per_reg * 4; + av[4] += n_elem_per_reg * 4; } + // 12 elements of the result are computed per iteration + for ( ; (i + 11) < m; i += 12 ) + { + // Load the input values. 
+ yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + yv[2].v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + + a_vec[0].v = _mm256_loadu_pd( av[0] + 0*n_elem_per_reg ); + a_vec[1].v = _mm256_loadu_pd( av[1] + 0*n_elem_per_reg ); + a_vec[2].v = _mm256_loadu_pd( av[2] + 0*n_elem_per_reg ); + a_vec[3].v = _mm256_loadu_pd( av[3] + 0*n_elem_per_reg ); + a_vec[4].v = _mm256_loadu_pd( av[4] + 0*n_elem_per_reg ); + + a_vec[5].v = _mm256_loadu_pd( av[0] + 1*n_elem_per_reg ); + a_vec[6].v = _mm256_loadu_pd( av[1] + 1*n_elem_per_reg ); + a_vec[7].v = _mm256_loadu_pd( av[2] + 1*n_elem_per_reg ); + a_vec[8].v = _mm256_loadu_pd( av[3] + 1*n_elem_per_reg ); + a_vec[9].v = _mm256_loadu_pd( av[4] + 1*n_elem_per_reg ); + + a_vec[10].v = _mm256_loadu_pd( av[0] + 2*n_elem_per_reg ); + a_vec[11].v = _mm256_loadu_pd( av[1] + 2*n_elem_per_reg ); + a_vec[12].v = _mm256_loadu_pd( av[2] + 2*n_elem_per_reg ); + a_vec[13].v = _mm256_loadu_pd( av[3] + 2*n_elem_per_reg ); + a_vec[14].v = _mm256_loadu_pd( av[4] + 2*n_elem_per_reg ); + + // perform : y += alpha * x; + yv[0].v = _mm256_fmadd_pd( a_vec[0].v, chiv[0].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[1].v, chiv[1].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[2].v, chiv[2].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[3].v, chiv[3].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[4].v, chiv[4].v, yv[0].v ); + + yv[1].v = _mm256_fmadd_pd( a_vec[5].v, chiv[0].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[6].v, chiv[1].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[7].v, chiv[2].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[8].v, chiv[3].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[9].v, chiv[4].v, yv[1].v ); + + yv[2].v = _mm256_fmadd_pd( a_vec[10].v, chiv[0].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[11].v, chiv[1].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[12].v, chiv[2].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[13].v, chiv[3].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[14].v, chiv[4].v, yv[2].v ); + + // Store the output. + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0].v ); + _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1].v ); + _mm256_storeu_pd( (y0 + 2*n_elem_per_reg), yv[2].v ); + + y0 += n_elem_per_reg * 3; + av[0] += n_elem_per_reg * 3; + av[1] += n_elem_per_reg * 3; + av[2] += n_elem_per_reg * 3; + av[3] += n_elem_per_reg * 3; + av[4] += n_elem_per_reg * 3; + } + + // 8 elements of the result are computed per iteration + for (; (i + 7) < m; i += 8 ) + { + // Load the input values. 
+ yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + + a_vec[0].v = _mm256_loadu_pd( av[0] + 0*n_elem_per_reg ); + a_vec[1].v = _mm256_loadu_pd( av[1] + 0*n_elem_per_reg ); + a_vec[2].v = _mm256_loadu_pd( av[2] + 0*n_elem_per_reg ); + a_vec[3].v = _mm256_loadu_pd( av[3] + 0*n_elem_per_reg ); + a_vec[4].v = _mm256_loadu_pd( av[4] + 0*n_elem_per_reg ); + + a_vec[5].v = _mm256_loadu_pd( av[0] + 1*n_elem_per_reg ); + a_vec[6].v = _mm256_loadu_pd( av[1] + 1*n_elem_per_reg ); + a_vec[7].v = _mm256_loadu_pd( av[2] + 1*n_elem_per_reg ); + a_vec[8].v = _mm256_loadu_pd( av[3] + 1*n_elem_per_reg ); + a_vec[9].v = _mm256_loadu_pd( av[4] + 1*n_elem_per_reg ); + + // perform : y += alpha * x; + yv[0].v = _mm256_fmadd_pd( a_vec[0].v, chiv[0].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[1].v, chiv[1].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[2].v, chiv[2].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[3].v, chiv[3].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[4].v, chiv[4].v, yv[0].v ); + + yv[1].v = _mm256_fmadd_pd( a_vec[5].v, chiv[0].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[6].v, chiv[1].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[7].v, chiv[2].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[8].v, chiv[3].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[9].v, chiv[4].v, yv[1].v ); + + // Store the output. + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0].v ); + _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1].v ); + + y0 += n_elem_per_reg * 2; + av[0] += n_elem_per_reg * 2; + av[1] += n_elem_per_reg * 2; + av[2] += n_elem_per_reg * 2; + av[3] += n_elem_per_reg * 2; + av[4] += n_elem_per_reg * 2; + } + + // 4 elements of the result are computed per iteration for( ; (i + 3) < m; i += 4 ) { // Load the input values. - y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); - - a00v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg ); - a01v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg ); - a02v.v = _mm256_loadu_pd( a2 + 0*n_elem_per_reg ); - a03v.v = _mm256_loadu_pd( a3 + 0*n_elem_per_reg ); - a04v.v = _mm256_loadu_pd( a4 + 0*n_elem_per_reg ); + yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + a_vec[0].v = _mm256_loadu_pd( av[0] + 0*n_elem_per_reg ); + a_vec[1].v = _mm256_loadu_pd( av[1] + 0*n_elem_per_reg ); + a_vec[2].v = _mm256_loadu_pd( av[2] + 0*n_elem_per_reg ); + a_vec[3].v = _mm256_loadu_pd( av[3] + 0*n_elem_per_reg ); + a_vec[4].v = _mm256_loadu_pd( av[4] + 0*n_elem_per_reg ); // perform : y += alpha * x; - y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a02v.v, chi2v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a03v.v, chi3v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a04v.v, chi4v.v, y0v.v ); + yv[0].v = _mm256_fmadd_pd( a_vec[0].v, chiv[0].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[1].v, chiv[1].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[2].v, chiv[2].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[3].v, chiv[3].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[4].v, chiv[4].v, yv[0].v ); // Store the output. 
- _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), y0v.v ); + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0].v ); y0 += n_elem_per_reg; - a0 += n_elem_per_reg; - a1 += n_elem_per_reg; - a2 += n_elem_per_reg; - a3 += n_elem_per_reg; - a4 += n_elem_per_reg; + av[0] += n_elem_per_reg; + av[1] += n_elem_per_reg; + av[2] += n_elem_per_reg; + av[3] += n_elem_per_reg; + av[4] += n_elem_per_reg; } // If there are leftover iterations, perform them with scalar code. @@ -503,25 +613,25 @@ void bli_daxpyf_zen_int_5 { double y0c = *y0; - const double a0c = *a0; - const double a1c = *a1; - const double a2c = *a2; - const double a3c = *a3; - const double a4c = *a4; + const double a0c = *av[0]; + const double a1c = *av[1]; + const double a2c = *av[2]; + const double a3c = *av[3]; + const double a4c = *av[4]; - y0c += chi0 * a0c; - y0c += chi1 * a1c; - y0c += chi2 * a2c; - y0c += chi3 * a3c; - y0c += chi4 * a4c; + y0c += chi[0] * a0c; + y0c += chi[1] * a1c; + y0c += chi[2] * a2c; + y0c += chi[3] * a3c; + y0c += chi[4] * a4c; *y0 = y0c; - a0 += 1; - a1 += 1; - a2 += 1; - a3 += 1; - a4 += 1; + av[0] += 1; + av[1] += 1; + av[2] += 1; + av[3] += 1; + av[4] += 1; y0 += 1; } } @@ -531,25 +641,25 @@ void bli_daxpyf_zen_int_5 { double y0c = *y0; - const double a0c = *a0; - const double a1c = *a1; - const double a2c = *a2; - const double a3c = *a3; - const double a4c = *a4; + const double a0c = *av[0]; + const double a1c = *av[1]; + const double a2c = *av[2]; + const double a3c = *av[3]; + const double a4c = *av[4]; - y0c += chi0 * a0c; - y0c += chi1 * a1c; - y0c += chi2 * a2c; - y0c += chi3 * a3c; - y0c += chi4 * a4c; + y0c += chi[0] * a0c; + y0c += chi[1] * a1c; + y0c += chi[2] * a2c; + y0c += chi[3] * a3c; + y0c += chi[4] * a4c; *y0 = y0c; - a0 += inca; - a1 += inca; - a2 += inca; - a3 += inca; - a4 += inca; + av[0] += inca; + av[1] += inca; + av[2] += inca; + av[3] += inca; + av[4] += inca; y0 += incy; } @@ -1153,7 +1263,7 @@ void bli_daxpyf_zen_int_16x4 a2 += n_elem_per_reg; a3 += n_elem_per_reg; } -#if 1 + for ( ; (i + 1) < m; i += 2) { @@ -1186,7 +1296,7 @@ void bli_daxpyf_zen_int_16x4 a2 += 2; a3 += 2; } -#endif + // If there are leftover iterations, perform them with scalar code. for ( ; (i + 0) < m ; ++i ) { diff --git a/kernels/zen/1f/bli_axpyf_zen_int_8.c b/kernels/zen/1f/bli_axpyf_zen_int_8.c index b958600ce..27dafb28f 100644 --- a/kernels/zen/1f/bli_axpyf_zen_int_8.c +++ b/kernels/zen/1f/bli_axpyf_zen_int_8.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2018, The University of Texas at Austin - Copyright (C) 2016 - 2018, Advanced Micro Devices, Inc. + Copyright (C) 2016 - 2022, Advanced Micro Devices, Inc. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -279,32 +279,19 @@ void bli_daxpyf_zen_int_8 const dim_t fuse_fac = 8; const dim_t n_elem_per_reg = 4; - const dim_t n_iter_unroll = 1; + const dim_t n_iter_unroll[4] = {4, 3, 2, 1}; dim_t i; - dim_t m_viter; - dim_t m_left; + dim_t m_viter[4]; + dim_t m_left = m; - double* restrict a0; - double* restrict a1; - double* restrict a2; - double* restrict a3; - double* restrict a4; - double* restrict a5; - double* restrict a6; - double* restrict a7; + double* restrict av[8] __attribute__((aligned(64))); double* restrict y0; - v4df_t chi0v, chi1v, chi2v, chi3v; - v4df_t chi4v, chi5v, chi6v, chi7v; + v4df_t chiv[8], a_vec[32], yv[4]; - v4df_t a0v, a1v, a2v, a3v; - v4df_t a4v, a5v, a6v, a7v; - v4df_t y0v; - - double chi0, chi1, chi2, chi3; - double chi4, chi5, chi6, chi7; + double chi[8] __attribute__((aligned(64))); // If either dimension is zero, or if alpha is zero, return early. if ( bli_zero_dim2( m, b_n ) || PASTEMAC(d,eq0)( *alpha ) ) return; @@ -343,94 +330,343 @@ void bli_daxpyf_zen_int_8 // Use the unrolling factor and the number of elements per register // to compute the number of vectorized and leftover iterations. - m_viter = ( m ) / ( n_elem_per_reg * n_iter_unroll ); - m_left = ( m ) % ( n_elem_per_reg * n_iter_unroll ); + m_viter[0] = ( m_left ) / ( n_elem_per_reg * n_iter_unroll[0] ); + m_left = ( m_left ) % ( n_elem_per_reg * n_iter_unroll[0] ); + + m_viter[1] = ( m_left ) / ( n_elem_per_reg * n_iter_unroll[1] ); + m_left = ( m_left ) % ( n_elem_per_reg * n_iter_unroll[1] ); + + m_viter[2] = ( m_left ) / ( n_elem_per_reg * n_iter_unroll[2] ); + m_left = ( m_left ) % ( n_elem_per_reg * n_iter_unroll[2] ); + + m_viter[3] = ( m_left ) / ( n_elem_per_reg * n_iter_unroll[3] ); + m_left = ( m_left ) % ( n_elem_per_reg * n_iter_unroll[3] ); // If there is anything that would interfere with our use of contiguous // vector loads/stores, override m_viter and m_left to use scalar code // for all iterations. if ( inca != 1 || incy != 1 ) { - m_viter = 0; + m_viter[0] = m_viter[1] = m_viter[2] = m_viter[3] = 0; m_left = m; } - a0 = a + 0*lda; - a1 = a + 1*lda; - a2 = a + 2*lda; - a3 = a + 3*lda; - a4 = a + 4*lda; - a5 = a + 5*lda; - a6 = a + 6*lda; - a7 = a + 7*lda; + // av points to the 8 columns under consideration + av[0] = a + 0*lda; + av[1] = a + 1*lda; + av[2] = a + 2*lda; + av[3] = a + 3*lda; + av[4] = a + 4*lda; + av[5] = a + 5*lda; + av[6] = a + 6*lda; + av[7] = a + 7*lda; y0 = y; - chi0 = *( x + 0*incx ); - chi1 = *( x + 1*incx ); - chi2 = *( x + 2*incx ); - chi3 = *( x + 3*incx ); - chi4 = *( x + 4*incx ); - chi5 = *( x + 5*incx ); - chi6 = *( x + 6*incx ); - chi7 = *( x + 7*incx ); + chi[0] = *( x + 0*incx ); + chi[1] = *( x + 1*incx ); + chi[2] = *( x + 2*incx ); + chi[3] = *( x + 3*incx ); + chi[4] = *( x + 4*incx ); + chi[5] = *( x + 5*incx ); + chi[6] = *( x + 6*incx ); + chi[7] = *( x + 7*incx ); // Scale each chi scalar by alpha. 
- PASTEMAC(d,scals)( *alpha, chi0 ); - PASTEMAC(d,scals)( *alpha, chi1 ); - PASTEMAC(d,scals)( *alpha, chi2 ); - PASTEMAC(d,scals)( *alpha, chi3 ); - PASTEMAC(d,scals)( *alpha, chi4 ); - PASTEMAC(d,scals)( *alpha, chi5 ); - PASTEMAC(d,scals)( *alpha, chi6 ); - PASTEMAC(d,scals)( *alpha, chi7 ); + PASTEMAC(d,scals)( *alpha, chi[0] ); + PASTEMAC(d,scals)( *alpha, chi[1] ); + PASTEMAC(d,scals)( *alpha, chi[2] ); + PASTEMAC(d,scals)( *alpha, chi[3] ); + PASTEMAC(d,scals)( *alpha, chi[4] ); + PASTEMAC(d,scals)( *alpha, chi[5] ); + PASTEMAC(d,scals)( *alpha, chi[6] ); + PASTEMAC(d,scals)( *alpha, chi[7] ); // Broadcast the (alpha*chi?) scalars to all elements of vector registers. - chi0v.v = _mm256_broadcast_sd( &chi0 ); - chi1v.v = _mm256_broadcast_sd( &chi1 ); - chi2v.v = _mm256_broadcast_sd( &chi2 ); - chi3v.v = _mm256_broadcast_sd( &chi3 ); - chi4v.v = _mm256_broadcast_sd( &chi4 ); - chi5v.v = _mm256_broadcast_sd( &chi5 ); - chi6v.v = _mm256_broadcast_sd( &chi6 ); - chi7v.v = _mm256_broadcast_sd( &chi7 ); + chiv[0].v = _mm256_broadcast_sd( &chi[0] ); + chiv[1].v = _mm256_broadcast_sd( &chi[1] ); + chiv[2].v = _mm256_broadcast_sd( &chi[2] ); + chiv[3].v = _mm256_broadcast_sd( &chi[3] ); + chiv[4].v = _mm256_broadcast_sd( &chi[4] ); + chiv[5].v = _mm256_broadcast_sd( &chi[5] ); + chiv[6].v = _mm256_broadcast_sd( &chi[6] ); + chiv[7].v = _mm256_broadcast_sd( &chi[7] ); // If there are vectorized iterations, perform them with vector // instructions. - for ( i = 0; i < m_viter; ++i ) + // 16 elements of the result are computed per iteration + for ( i = 0; i < m_viter[0]; ++i ) { // Load the input values. - y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); - a0v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg ); - a1v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg ); - a2v.v = _mm256_loadu_pd( a2 + 0*n_elem_per_reg ); - a3v.v = _mm256_loadu_pd( a3 + 0*n_elem_per_reg ); - a4v.v = _mm256_loadu_pd( a4 + 0*n_elem_per_reg ); - a5v.v = _mm256_loadu_pd( a5 + 0*n_elem_per_reg ); - a6v.v = _mm256_loadu_pd( a6 + 0*n_elem_per_reg ); - a7v.v = _mm256_loadu_pd( a7 + 0*n_elem_per_reg ); + yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + yv[2].v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + yv[3].v = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); + + a_vec[0].v = _mm256_loadu_pd( av[0] + 0*n_elem_per_reg ); + a_vec[1].v = _mm256_loadu_pd( av[1] + 0*n_elem_per_reg ); + a_vec[2].v = _mm256_loadu_pd( av[2] + 0*n_elem_per_reg ); + a_vec[3].v = _mm256_loadu_pd( av[3] + 0*n_elem_per_reg ); + a_vec[4].v = _mm256_loadu_pd( av[4] + 0*n_elem_per_reg ); + a_vec[5].v = _mm256_loadu_pd( av[5] + 0*n_elem_per_reg ); + a_vec[6].v = _mm256_loadu_pd( av[6] + 0*n_elem_per_reg ); + a_vec[7].v = _mm256_loadu_pd( av[7] + 0*n_elem_per_reg ); + + a_vec[8].v = _mm256_loadu_pd( av[0] + 1*n_elem_per_reg ); + a_vec[9].v = _mm256_loadu_pd( av[1] + 1*n_elem_per_reg ); + a_vec[10].v = _mm256_loadu_pd( av[2] + 1*n_elem_per_reg ); + a_vec[11].v = _mm256_loadu_pd( av[3] + 1*n_elem_per_reg ); + a_vec[12].v = _mm256_loadu_pd( av[4] + 1*n_elem_per_reg ); + a_vec[13].v = _mm256_loadu_pd( av[5] + 1*n_elem_per_reg ); + a_vec[14].v = _mm256_loadu_pd( av[6] + 1*n_elem_per_reg ); + a_vec[15].v = _mm256_loadu_pd( av[7] + 1*n_elem_per_reg ); + + a_vec[16].v = _mm256_loadu_pd( av[0] + 2*n_elem_per_reg ); + a_vec[17].v = _mm256_loadu_pd( av[1] + 2*n_elem_per_reg ); + a_vec[18].v = _mm256_loadu_pd( av[2] + 2*n_elem_per_reg ); + a_vec[19].v = _mm256_loadu_pd( av[3] + 2*n_elem_per_reg ); + a_vec[20].v = 
_mm256_loadu_pd( av[4] + 2*n_elem_per_reg ); + a_vec[21].v = _mm256_loadu_pd( av[5] + 2*n_elem_per_reg ); + a_vec[22].v = _mm256_loadu_pd( av[6] + 2*n_elem_per_reg ); + a_vec[23].v = _mm256_loadu_pd( av[7] + 2*n_elem_per_reg ); + + a_vec[24].v = _mm256_loadu_pd( av[0] + 3*n_elem_per_reg ); + a_vec[25].v = _mm256_loadu_pd( av[1] + 3*n_elem_per_reg ); + a_vec[26].v = _mm256_loadu_pd( av[2] + 3*n_elem_per_reg ); + a_vec[27].v = _mm256_loadu_pd( av[3] + 3*n_elem_per_reg ); + a_vec[28].v = _mm256_loadu_pd( av[4] + 3*n_elem_per_reg ); + a_vec[29].v = _mm256_loadu_pd( av[5] + 3*n_elem_per_reg ); + a_vec[30].v = _mm256_loadu_pd( av[6] + 3*n_elem_per_reg ); + a_vec[31].v = _mm256_loadu_pd( av[7] + 3*n_elem_per_reg ); // perform : y += alpha * x; - y0v.v = _mm256_fmadd_pd( a0v.v, chi0v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a1v.v, chi1v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a2v.v, chi2v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a3v.v, chi3v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a4v.v, chi4v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a5v.v, chi5v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a6v.v, chi6v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a7v.v, chi7v.v, y0v.v ); + yv[0].v = _mm256_fmadd_pd( a_vec[0].v, chiv[0].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[1].v, chiv[1].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[2].v, chiv[2].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[3].v, chiv[3].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[4].v, chiv[4].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[5].v, chiv[5].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[6].v, chiv[6].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[7].v, chiv[7].v, yv[0].v ); + + yv[1].v = _mm256_fmadd_pd( a_vec[8].v, chiv[0].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[9].v, chiv[1].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[10].v, chiv[2].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[11].v, chiv[3].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[12].v, chiv[4].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[13].v, chiv[5].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[14].v, chiv[6].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[15].v, chiv[7].v, yv[1].v ); + + yv[2].v = _mm256_fmadd_pd( a_vec[16].v, chiv[0].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[17].v, chiv[1].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[18].v, chiv[2].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[19].v, chiv[3].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[20].v, chiv[4].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[21].v, chiv[5].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[22].v, chiv[6].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[23].v, chiv[7].v, yv[2].v ); + + yv[3].v = _mm256_fmadd_pd( a_vec[24].v, chiv[0].v, yv[3].v ); + yv[3].v = _mm256_fmadd_pd( a_vec[25].v, chiv[1].v, yv[3].v ); + yv[3].v = _mm256_fmadd_pd( a_vec[26].v, chiv[2].v, yv[3].v ); + yv[3].v = _mm256_fmadd_pd( a_vec[27].v, chiv[3].v, yv[3].v ); + yv[3].v = _mm256_fmadd_pd( a_vec[28].v, chiv[4].v, yv[3].v ); + yv[3].v = _mm256_fmadd_pd( a_vec[29].v, chiv[5].v, yv[3].v ); + yv[3].v = _mm256_fmadd_pd( a_vec[30].v, chiv[6].v, yv[3].v ); + yv[3].v = _mm256_fmadd_pd( a_vec[31].v, chiv[7].v, yv[3].v ); // Store the output. 
- _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), y0v.v ); + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0].v ); + _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1].v ); + _mm256_storeu_pd( (y0 + 2*n_elem_per_reg), yv[2].v ); + _mm256_storeu_pd( (y0 + 3*n_elem_per_reg), yv[3].v ); + + y0 += n_elem_per_reg * n_iter_unroll[0]; + av[0] += n_elem_per_reg * n_iter_unroll[0]; + av[1] += n_elem_per_reg * n_iter_unroll[0]; + av[2] += n_elem_per_reg * n_iter_unroll[0]; + av[3] += n_elem_per_reg * n_iter_unroll[0]; + av[4] += n_elem_per_reg * n_iter_unroll[0]; + av[5] += n_elem_per_reg * n_iter_unroll[0]; + av[6] += n_elem_per_reg * n_iter_unroll[0]; + av[7] += n_elem_per_reg * n_iter_unroll[0]; + } + + // 12 elements of the result are computed per iteration + for ( i = 0; i < m_viter[1]; ++i ) + { + // Load the input values. + yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + yv[2].v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + + a_vec[0].v = _mm256_loadu_pd( av[0] + 0*n_elem_per_reg ); + a_vec[1].v = _mm256_loadu_pd( av[1] + 0*n_elem_per_reg ); + a_vec[2].v = _mm256_loadu_pd( av[2] + 0*n_elem_per_reg ); + a_vec[3].v = _mm256_loadu_pd( av[3] + 0*n_elem_per_reg ); + a_vec[4].v = _mm256_loadu_pd( av[4] + 0*n_elem_per_reg ); + a_vec[5].v = _mm256_loadu_pd( av[5] + 0*n_elem_per_reg ); + a_vec[6].v = _mm256_loadu_pd( av[6] + 0*n_elem_per_reg ); + a_vec[7].v = _mm256_loadu_pd( av[7] + 0*n_elem_per_reg ); + + a_vec[8].v = _mm256_loadu_pd( av[0] + 1*n_elem_per_reg ); + a_vec[9].v = _mm256_loadu_pd( av[1] + 1*n_elem_per_reg ); + a_vec[10].v = _mm256_loadu_pd( av[2] + 1*n_elem_per_reg ); + a_vec[11].v = _mm256_loadu_pd( av[3] + 1*n_elem_per_reg ); + a_vec[12].v = _mm256_loadu_pd( av[4] + 1*n_elem_per_reg ); + a_vec[13].v = _mm256_loadu_pd( av[5] + 1*n_elem_per_reg ); + a_vec[14].v = _mm256_loadu_pd( av[6] + 1*n_elem_per_reg ); + a_vec[15].v = _mm256_loadu_pd( av[7] + 1*n_elem_per_reg ); + + a_vec[16].v = _mm256_loadu_pd( av[0] + 2*n_elem_per_reg ); + a_vec[17].v = _mm256_loadu_pd( av[1] + 2*n_elem_per_reg ); + a_vec[18].v = _mm256_loadu_pd( av[2] + 2*n_elem_per_reg ); + a_vec[19].v = _mm256_loadu_pd( av[3] + 2*n_elem_per_reg ); + a_vec[20].v = _mm256_loadu_pd( av[4] + 2*n_elem_per_reg ); + a_vec[21].v = _mm256_loadu_pd( av[5] + 2*n_elem_per_reg ); + a_vec[22].v = _mm256_loadu_pd( av[6] + 2*n_elem_per_reg ); + a_vec[23].v = _mm256_loadu_pd( av[7] + 2*n_elem_per_reg ); + + // perform : y += alpha * x; + yv[0].v = _mm256_fmadd_pd( a_vec[0].v, chiv[0].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[1].v, chiv[1].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[2].v, chiv[2].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[3].v, chiv[3].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[4].v, chiv[4].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[5].v, chiv[5].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[6].v, chiv[6].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[7].v, chiv[7].v, yv[0].v ); + + yv[1].v = _mm256_fmadd_pd( a_vec[8].v, chiv[0].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[9].v, chiv[1].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[10].v, chiv[2].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[11].v, chiv[3].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[12].v, chiv[4].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[13].v, chiv[5].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[14].v, chiv[6].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[15].v, chiv[7].v, yv[1].v ); + + yv[2].v = _mm256_fmadd_pd( a_vec[16].v, chiv[0].v, 
yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[17].v, chiv[1].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[18].v, chiv[2].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[19].v, chiv[3].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[20].v, chiv[4].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[21].v, chiv[5].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[22].v, chiv[6].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[23].v, chiv[7].v, yv[2].v ); + + // Store the output. + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0].v ); + _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1].v ); + _mm256_storeu_pd( (y0 + 2*n_elem_per_reg), yv[2].v ); + + y0 += n_elem_per_reg * n_iter_unroll[1]; + av[0] += n_elem_per_reg * n_iter_unroll[1]; + av[1] += n_elem_per_reg * n_iter_unroll[1]; + av[2] += n_elem_per_reg * n_iter_unroll[1]; + av[3] += n_elem_per_reg * n_iter_unroll[1]; + av[4] += n_elem_per_reg * n_iter_unroll[1]; + av[5] += n_elem_per_reg * n_iter_unroll[1]; + av[6] += n_elem_per_reg * n_iter_unroll[1]; + av[7] += n_elem_per_reg * n_iter_unroll[1]; + } + + // 8 elements of the result are computed per iteration + for ( i = 0; i < m_viter[2]; ++i ) + { + // Load the input values. + yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + + a_vec[0].v = _mm256_loadu_pd( av[0] + 0*n_elem_per_reg ); + a_vec[1].v = _mm256_loadu_pd( av[1] + 0*n_elem_per_reg ); + a_vec[2].v = _mm256_loadu_pd( av[2] + 0*n_elem_per_reg ); + a_vec[3].v = _mm256_loadu_pd( av[3] + 0*n_elem_per_reg ); + a_vec[4].v = _mm256_loadu_pd( av[4] + 0*n_elem_per_reg ); + a_vec[5].v = _mm256_loadu_pd( av[5] + 0*n_elem_per_reg ); + a_vec[6].v = _mm256_loadu_pd( av[6] + 0*n_elem_per_reg ); + a_vec[7].v = _mm256_loadu_pd( av[7] + 0*n_elem_per_reg ); + + a_vec[8].v = _mm256_loadu_pd( av[0] + 1*n_elem_per_reg ); + a_vec[9].v = _mm256_loadu_pd( av[1] + 1*n_elem_per_reg ); + a_vec[10].v = _mm256_loadu_pd( av[2] + 1*n_elem_per_reg ); + a_vec[11].v = _mm256_loadu_pd( av[3] + 1*n_elem_per_reg ); + a_vec[12].v = _mm256_loadu_pd( av[4] + 1*n_elem_per_reg ); + a_vec[13].v = _mm256_loadu_pd( av[5] + 1*n_elem_per_reg ); + a_vec[14].v = _mm256_loadu_pd( av[6] + 1*n_elem_per_reg ); + a_vec[15].v = _mm256_loadu_pd( av[7] + 1*n_elem_per_reg ); + + // perform : y += alpha * x; + yv[0].v = _mm256_fmadd_pd( a_vec[0].v, chiv[0].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[1].v, chiv[1].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[2].v, chiv[2].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[3].v, chiv[3].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[4].v, chiv[4].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[5].v, chiv[5].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[6].v, chiv[6].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[7].v, chiv[7].v, yv[0].v ); + + yv[1].v = _mm256_fmadd_pd( a_vec[8].v, chiv[0].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[9].v, chiv[1].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[10].v, chiv[2].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[11].v, chiv[3].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[12].v, chiv[4].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[13].v, chiv[5].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[14].v, chiv[6].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[15].v, chiv[7].v, yv[1].v ); + + // Store the output. 
+ _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0].v ); + _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1].v ); + + y0 += n_elem_per_reg * n_iter_unroll[2]; + av[0] += n_elem_per_reg * n_iter_unroll[2]; + av[1] += n_elem_per_reg * n_iter_unroll[2]; + av[2] += n_elem_per_reg * n_iter_unroll[2]; + av[3] += n_elem_per_reg * n_iter_unroll[2]; + av[4] += n_elem_per_reg * n_iter_unroll[2]; + av[5] += n_elem_per_reg * n_iter_unroll[2]; + av[6] += n_elem_per_reg * n_iter_unroll[2]; + av[7] += n_elem_per_reg * n_iter_unroll[2]; + } + + // 4 elements of the result are computed per iteration + for ( i = 0; i < m_viter[3]; ++i ) + { + // Load the input values. + yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + + a_vec[0].v = _mm256_loadu_pd( av[0] + 0*n_elem_per_reg ); + a_vec[1].v = _mm256_loadu_pd( av[1] + 0*n_elem_per_reg ); + a_vec[2].v = _mm256_loadu_pd( av[2] + 0*n_elem_per_reg ); + a_vec[3].v = _mm256_loadu_pd( av[3] + 0*n_elem_per_reg ); + a_vec[4].v = _mm256_loadu_pd( av[4] + 0*n_elem_per_reg ); + a_vec[5].v = _mm256_loadu_pd( av[5] + 0*n_elem_per_reg ); + a_vec[6].v = _mm256_loadu_pd( av[6] + 0*n_elem_per_reg ); + a_vec[7].v = _mm256_loadu_pd( av[7] + 0*n_elem_per_reg ); + + // perform : y += alpha * x; + yv[0].v = _mm256_fmadd_pd( a_vec[0].v, chiv[0].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[1].v, chiv[1].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[2].v, chiv[2].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[3].v, chiv[3].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[4].v, chiv[4].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[5].v, chiv[5].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[6].v, chiv[6].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[7].v, chiv[7].v, yv[0].v ); + + // Store the output. + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0].v ); y0 += n_elem_per_reg; - a0 += n_elem_per_reg; - a1 += n_elem_per_reg; - a2 += n_elem_per_reg; - a3 += n_elem_per_reg; - a4 += n_elem_per_reg; - a5 += n_elem_per_reg; - a6 += n_elem_per_reg; - a7 += n_elem_per_reg; + av[0] += n_elem_per_reg; + av[1] += n_elem_per_reg; + av[2] += n_elem_per_reg; + av[3] += n_elem_per_reg; + av[4] += n_elem_per_reg; + av[5] += n_elem_per_reg; + av[6] += n_elem_per_reg; + av[7] += n_elem_per_reg; } // If there are leftover iterations, perform them with scalar code. @@ -438,34 +674,34 @@ void bli_daxpyf_zen_int_8 { double y0c = *y0; - const double a0c = *a0; - const double a1c = *a1; - const double a2c = *a2; - const double a3c = *a3; - const double a4c = *a4; - const double a5c = *a5; - const double a6c = *a6; - const double a7c = *a7; + const double a0c = *av[0]; + const double a1c = *av[1]; + const double a2c = *av[2]; + const double a3c = *av[3]; + const double a4c = *av[4]; + const double a5c = *av[5]; + const double a6c = *av[6]; + const double a7c = *av[7]; - y0c += chi0 * a0c; - y0c += chi1 * a1c; - y0c += chi2 * a2c; - y0c += chi3 * a3c; - y0c += chi4 * a4c; - y0c += chi5 * a5c; - y0c += chi6 * a6c; - y0c += chi7 * a7c; + y0c += chi[0] * a0c; + y0c += chi[1] * a1c; + y0c += chi[2] * a2c; + y0c += chi[3] * a3c; + y0c += chi[4] * a4c; + y0c += chi[5] * a5c; + y0c += chi[6] * a6c; + y0c += chi[7] * a7c; *y0 = y0c; - a0 += inca; - a1 += inca; - a2 += inca; - a3 += inca; - a4 += inca; - a5 += inca; - a6 += inca; - a7 += inca; + av[0] += inca; + av[1] += inca; + av[2] += inca; + av[3] += inca; + av[4] += inca; + av[5] += inca; + av[6] += inca; + av[7] += inca; y0 += incy; } }
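
Note on the driver change: bli_dgemv_unf_var2 now picks an axpyf kernel per column block and shrinks the fuse factor as the trailing columns run out, so the leftover columns still land on a fused kernel instead of degenerating into narrow calls. The standalone sketch below (not part of the patch) walks that decision logic. block_dim_f is a hypothetical stand-in for bli_determine_blocksize_dim_f, assumed here to simply return min(fuse_factor, n_iter - i), and n_iter = 21 is an arbitrary example column count.

    #include <stdio.h>

    /* Hypothetical stand-in for bli_determine_blocksize_dim_f(): assumed
       here to return the width of the current block, i.e.
       min( b_alg, n_iter - i ). */
    static long block_dim_f( long i, long n_iter, long b_alg )
    {
        long left = n_iter - i;
        return ( left < b_alg ) ? left : b_alg;
    }

    int main( void )
    {
        long n_iter      = 21;   /* number of columns of A (example value) */
        long fuse_factor = 8;    /* start with the widest fused kernel */

        /* Same initial choice as the patch: small problems start narrower. */
        if      ( n_iter < 5 ) fuse_factor = 4;
        else if ( n_iter < 8 ) fuse_factor = 5;

        for ( long i = 0; i < n_iter; )
        {
            long f = block_dim_f( i, n_iter, fuse_factor );

            /* Dispatch mirrors the switch in bli_dgemv_unf_var2:
               f == 8 -> axpyf_8, 5 <= f < 8 -> axpyf_5, f < 5 -> axpyf_16x4. */
            printf( "i=%2ld  f=%ld  kernel=%s\n", i, f,
                    f == 8 ? "axpyf_8" : ( f >= 5 ? "axpyf_5" : "axpyf_16x4" ) );

            i += f;

            /* Shrink the fuse factor when the remaining tail is narrower,
               so the next pass still uses a fused kernel. */
            long f_next = block_dim_f( i, n_iter, fuse_factor );
            if ( f_next < fuse_factor )
                fuse_factor = ( f_next < 5 ) ? 4 : 5;
        }
        return 0;
    }

For n_iter = 21 this dispatches two 8-column blocks to axpyf_8 and the remaining 5 columns to axpyf_5, which is the "better problem decomposition" the in-patch comment refers to.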
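
Note on the kernel changes: both reworked kernels split the m dimension greedily into 16-, 12-, 8- and 4-element vector passes before falling back to scalar code. bli_daxpyf_zen_int_8 computes the pass counts up front via m_viter[], while bli_daxpyf_zen_int_5 achieves the same split with cascaded loop bounds. A minimal sketch of that bookkeeping, with m = 27 as an arbitrary example row count:

    #include <stdio.h>

    int main( void )
    {
        /* n_elem_per_reg = 4 doubles per YMM register; n_iter_unroll gives
           the unroll factor of each pass, so the block sizes are 16, 12, 8, 4. */
        const long n_elem_per_reg   = 4;
        const long n_iter_unroll[4] = { 4, 3, 2, 1 };

        long m      = 27;   /* example row count */
        long m_left = m;
        long m_viter[4];

        for ( int k = 0; k < 4; ++k )
        {
            long blk   = n_elem_per_reg * n_iter_unroll[k];
            m_viter[k] = m_left / blk;   /* how many passes of this width */
            m_left     = m_left % blk;   /* remainder handed to narrower passes */
        }

        /* For m = 27: one 16-element pass, no 12-element pass, one 8-element
           pass, no 4-element pass, and 3 scalar leftover iterations. */
        printf( "m=%ld -> viter16=%ld viter12=%ld viter8=%ld viter4=%ld left=%ld\n",
                m, m_viter[0], m_viter[1], m_viter[2], m_viter[3], m_left );
        return 0;
    }

As in the patch, when inca != 1 or incy != 1 all vector pass counts are forced to zero and the whole length m is handled by the scalar loop.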