diff --git a/frame/2/gemv/bli_gemv_unf_var1.c b/frame/2/gemv/bli_gemv_unf_var1.c
index e468587d4..838ea577b 100644
--- a/frame/2/gemv/bli_gemv_unf_var1.c
+++ b/frame/2/gemv/bli_gemv_unf_var1.c
@@ -34,7 +34,6 @@
  */
 
 #include "blis.h"
-#define BLIS_DGEMV_VAR1_FUSE 8
 
 #undef  GENTFUNC
 #define GENTFUNC( ctype, ch, varname ) \
@@ -121,30 +120,30 @@ void bli_dgemv_unf_var1
      )
 {
-    double* A1;
-    double* y1;
-    dim_t   i;
-    dim_t   f;
-    dim_t   n_elem, n_iter;
-    inc_t   rs_at, cs_at;
-    conj_t  conja;
+    double *A1;
+    double *y1;
+    dim_t   i;
+    dim_t   f;
+    dim_t   n_elem, n_iter;
+    inc_t   rs_at, cs_at;
+    conj_t  conja;
 
     //memory pool declarations for packing vector X.
-    mem_t  mem_bufX;
-    rntm_t rntm;
-    double *x_buf = x;
-    inc_t  buf_incx = incx;
+    mem_t  mem_bufX;
+    rntm_t rntm;
+    double *x_buf = x;
+    inc_t  buf_incx = incx;
 
     bli_init_once();
 
-    if( cntx == NULL ) cntx = bli_gks_query_cntx();
+    if (cntx == NULL)
+        cntx = bli_gks_query_cntx();
 
-    bli_set_dims_incs_with_trans( transa,
-                                  m, n, rs_a, cs_a,
-                                  &n_iter, &n_elem, &rs_at, &cs_at );
+    bli_set_dims_incs_with_trans(transa,
+                                 m, n, rs_a, cs_a,
+                                 &n_iter, &n_elem, &rs_at, &cs_at);
 
-    conja = bli_extract_conj( transa );
+    conja = bli_extract_conj(transa);
 
-
     // When dynamic dispatch is enabled i.e. library is built for 'amdzen' configuration.
     // This function is invoked on all architectures including 'generic'.
     // Invoke architecture specific kernels only if we are sure that we are running on zen,
     // zen2 or zen3 otherwise fall back to reference kernels (via framework and context).
@@ -193,88 +192,154 @@ void bli_dgemv_unf_var1
         AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3);
         return;
     }
-
+
     if (incx > 1)
     {
-      /*
+        /*
           Initialize mem pool buffer to NULL and size to 0
           "buf" and "size" fields are assigned once memory
           is allocated from the pool in bli_membrk_acquire_m().
           This will ensure bli_mem_is_alloc() will be passed on
          an allocated memory if created or a NULL .
-      */
-      mem_bufX.pblk.buf = NULL;   mem_bufX.pblk.block_size = 0;
-      mem_bufX.buf_type = 0;      mem_bufX.size = 0;
-      mem_bufX.pool = NULL;
+        */
 
-      /* In order to get the buffer from pool via rntm access to memory broker
+        mem_bufX.pblk.buf = NULL;
+        mem_bufX.pblk.block_size = 0;
+        mem_bufX.buf_type = 0;
+        mem_bufX.size = 0;
+        mem_bufX.pool = NULL;
+
+        /* In order to get the buffer from pool via rntm access to memory broker
           is needed.Following are initializations for rntm */
 
-      bli_rntm_init_from_global( &rntm );
-      bli_rntm_set_num_threads_only( 1, &rntm );
-      bli_membrk_rntm_set_membrk( &rntm );
+        bli_rntm_init_from_global(&rntm);
+        bli_rntm_set_num_threads_only(1, &rntm);
+        bli_membrk_rntm_set_membrk(&rntm);
 
-      //calculate the size required for n_elem double elements in vector X.
-      size_t buffer_size = n_elem * sizeof(double);
+        //calculate the size required for n_elem double elements in vector X.
+        size_t buffer_size = n_elem * sizeof(double);
 
-      #ifdef BLIS_ENABLE_MEM_TRACING
-      printf( "bli_dgemv_unf_var1(): get mem pool block\n" );
-      #endif
+#ifdef BLIS_ENABLE_MEM_TRACING
+        printf("bli_dgemv_unf_var1(): get mem pool block\n");
+#endif
 
-      /*acquire a Buffer(n_elem*size(double)) from the memory broker
-        and save the associated mem_t entry to mem_bufX.*/
-      bli_membrk_acquire_m(&rntm,
-                           buffer_size,
-                           BLIS_BUFFER_FOR_B_PANEL,
-                           &mem_bufX);
+        /*acquire a Buffer(n_elem*size(double)) from the memory broker
+          and save the associated mem_t entry to mem_bufX.*/
+        bli_membrk_acquire_m(&rntm,
+                             buffer_size,
+                             BLIS_BUFFER_FOR_B_PANEL,
+                             &mem_bufX);
 
-      /*Continue packing X if buffer memory is allocated*/
-      if ((bli_mem_is_alloc( &mem_bufX )))
-      {
-          x_buf = bli_mem_buffer(&mem_bufX);
-
-          //pack X vector with non-unit stride to a temp buffer x_buf with unit stride
-          for(dim_t x_index = 0 ; x_index < n_elem ; x_index++)
-          {
-              *(x_buf + x_index) = *(x + (x_index * incx)) ;
-          }
-          // stride of vector x_buf =1
-          buf_incx = 1;
-      }
-    }
-
-    for ( i = 0; i < n_iter; i += f )
+        /*Continue packing X if buffer memory is allocated*/
+        if ((bli_mem_is_alloc(&mem_bufX)))
         {
-        f = bli_determine_blocksize_dim_f( i, n_iter, BLIS_DGEMV_VAR1_FUSE );
+            x_buf = bli_mem_buffer(&mem_bufX);
 
-        A1 = a + (i  )*rs_at + (0  )*cs_at;
-        y1 = y + (i  )*incy;
+            //pack X vector with non-unit stride to a temp buffer x_buf with unit stride
+            for (dim_t x_index = 0; x_index < n_elem; x_index++)
+            {
+                *(x_buf + x_index) = *(x + (x_index * incx));
+            }
+            // stride of vector x_buf = 1
+            buf_incx = 1;
+        }
+    }
 
-        /* y1 = beta * y1 + alpha * A1 * x; */
-        bli_ddotxf_zen_int_8
-        (
+    dim_t fuse_factor = 8;
+    dim_t f_temp = 0;
+
+    if (n < 4)
+    {
+        fuse_factor = 2;
+    }
+    else if (n < 8)
+    {
+        fuse_factor = 4;
+    }
+
+    for (i = 0; i < n_iter; i += f)
+    {
+        f = bli_determine_blocksize_dim_f(i, n_iter, fuse_factor);
+
+        // A1 = a + i * row_increment + 0 * column_increment
+        A1 = a + (i)*rs_at;
+        y1 = y + (i)*incy;
+
+        /* y1 = beta * y1 + alpha * A1 * x; */
+        switch (f)
+        {
+        case 8:
+
+            bli_ddotxf_zen_int_8(
                 conja,
                 conjx,
                 n_elem,
                 f,
                 alpha,
-                A1, cs_at, rs_at,
-                x_buf, buf_incx,
+                A1, cs_at, rs_at,
+                x_buf, buf_incx,
                 beta,
-                y1, incy,
-                cntx
-        );
+                y1, incy,
+                cntx);
+            break;
+        default:
+
+            if (f < 4)
+            {
+                bli_ddotxf_zen_int_2(
+                    conja,
+                    conjx,
+                    n_elem,
+                    f,
+                    alpha,
+                    A1, cs_at, rs_at,
+                    x_buf, buf_incx,
+                    beta,
+                    y1, incy,
+                    cntx);
+            }
+            else
+            {
+                bli_ddotxf_zen_int_4(
+                    conja,
+                    conjx,
+                    n_elem,
+                    f,
+                    alpha,
+                    A1, cs_at, rs_at,
+                    x_buf, buf_incx,
+                    beta,
+                    y1, incy,
+                    cntx);
+            }
         }
 
-    if ((incx > 1) && bli_mem_is_alloc( &mem_bufX ))
+        f_temp = bli_determine_blocksize_dim_f(i + f, n_iter, fuse_factor);
+
+        if (f_temp < fuse_factor)
         {
-        #ifdef BLIS_ENABLE_MEM_TRACING
-        printf( "bli_dgemv_unf_var1(): releasing mem pool block\n" );
-        #endif
-        // Return the buffer to pool
-        bli_membrk_release(&rntm , &mem_bufX);
+            switch (fuse_factor)
+            {
+            case 8:
+                fuse_factor = 4;
+                break;
+            case 4:
+                fuse_factor = 2;
+                break;
+            }
         }
-        AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3);
+    }
+
+    if ((incx > 1) && bli_mem_is_alloc(&mem_bufX))
+    {
+#ifdef BLIS_ENABLE_MEM_TRACING
+        printf("bli_dgemv_unf_var1(): releasing mem pool block\n");
+#endif
+        // Return the buffer to pool
+        bli_membrk_release(&rntm, &mem_bufX);
+    }
+    AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3);
 }
 
 void bli_sgemv_unf_var1
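The loop above replaces the fixed BLIS_DGEMV_VAR1_FUSE = 8 blocking with a fuse factor that adapts to n and steps down as the remaining tail of rows shrinks, dispatching each block to the dotxf kernel of matching width. The standalone C sketch below (illustrative only, not part of the patch) traces that schedule; blocksize_f() is a stand-in for bli_determine_blocksize_dim_f(), which is assumed here to return min(b, n_iter - i) as in upstream BLIS.

#include <stdio.h>

/* Stand-in for bli_determine_blocksize_dim_f(): min(b, n - i) (assumed). */
static long blocksize_f(long i, long n, long b)
{
    long left = n - i;
    return (left < b) ? left : b;
}

int main(void)
{
    long n_iter = 23;  /* hypothetical number of dot products (rows of A^T) */
    long fuse   = 8;   /* start with the widest kernel, as the patch does */

    for (long i = 0, f; i < n_iter; i += f)
    {
        f = blocksize_f(i, n_iter, fuse);

        /* f == 8 -> int_8; 4 <= f < 8 -> int_4; f < 4 -> int_2,
           matching the switch/default dispatch in the patched loop. */
        printf("rows [%ld, %ld) -> bli_ddotxf_zen_int_%d\n",
               i, i + f, (f == 8) ? 8 : (f >= 4) ? 4 : 2);

        /* Step the fuse factor down once the tail is smaller than it,
           mirroring the switch on fuse_factor above. */
        if (blocksize_f(i + f, n_iter, fuse) < fuse)
            fuse = (fuse == 8) ? 4 : 2;
    }
    return 0;
}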
diff --git a/kernels/zen/1f/bli_dotxf_zen_int_8.c b/kernels/zen/1f/bli_dotxf_zen_int_8.c
index 531a389b5..e25910fb4 100644
--- a/kernels/zen/1f/bli_dotxf_zen_int_8.c
+++ b/kernels/zen/1f/bli_dotxf_zen_int_8.c
@@ -52,6 +52,14 @@ typedef union
     double d[4] __attribute__((aligned(64)));
 } v4df_t;
 
+/* Union data structure to access 128-bit vector registers.
+ * One 128-bit XMM register holds 2 DP elements. */
+typedef union
+{
+    __m128d v;
+    double  d[2] __attribute__((aligned(64)));
+} v2df_t;
+
 // -----------------------------------------------------------------------------
 
 void bli_sdotxf_zen_int_8
@@ -430,49 +438,46 @@ void bli_ddotxf_zen_int_8
       cntx_t* restrict cntx
     )
 {
-    const dim_t fuse_fac = 8;
-    const dim_t n_elem_per_reg = 4;
+    const dim_t fuse_fac       = 8;
+    const dim_t n_elem_per_reg = 4;
 
     // If the b_n dimension is zero, y is empty and there is no computation.
-    if ( bli_zero_dim1( b_n ) ) return;
+    if (bli_zero_dim1(b_n))
+        return;
 
     // If the m dimension is zero, or if alpha is zero, the computation
     // simplifies to updating y.
-    if ( bli_zero_dim1( m ) || PASTEMAC(d,eq0)( *alpha ) )
+    if (bli_zero_dim1(m) || PASTEMAC(d, eq0)(*alpha))
     {
-        bli_dscalv_zen_int10
-        (
-          BLIS_NO_CONJUGATE,
-          b_n,
-          beta,
-          y, incy,
-          cntx
-        );
+        bli_dscalv_zen_int10(
+            BLIS_NO_CONJUGATE,
+            b_n,
+            beta,
+            y, incy,
+            cntx);
         return;
     }
 
     // If b_n is not equal to the fusing factor, then perform the entire
     // operation as a loop over dotxv.
-    if ( b_n != fuse_fac )
+    if (b_n != fuse_fac)
     {
-        for ( dim_t i = 0; i < b_n; ++i )
+        for (dim_t i = 0; i < b_n; ++i)
         {
-            double* a1   = a + (0  )*inca + (i  )*lda;
-            double* x1   = x + (0  )*incx;
-            double* psi1 = y + (i  )*incy;
+            double *a1   = a + (0) * inca + (i)*lda;
+            double *x1   = x + (0) * incx;
+            double *psi1 = y + (i)*incy;
 
-            bli_ddotxv_zen_int
-            (
-              conjat,
-              conjx,
-              m,
-              alpha,
-              a1, inca,
-              x1, incx,
-              beta,
-              psi1,
-              cntx
-            );
+            bli_ddotxv_zen_int(
+                conjat,
+                conjx,
+                m,
+                alpha,
+                a1, inca,
+                x1, incx,
+                beta,
+                psi1,
+                cntx);
         }
         return;
     }
@@ -493,115 +498,113 @@ void bli_ddotxf_zen_int_8
     // distinguishes between (1) and (2).
 
     // Intermediate variables to hold the completed dot products
-    double rho0 = 0, rho1 = 0, rho2 = 0, rho3 = 0,
-           rho4 = 0, rho5 = 0, rho6 = 0, rho7 = 0;
+    double rho0 = 0, rho1 = 0, rho2 = 0, rho3 = 0;
+    double rho4 = 0, rho5 = 0, rho6 = 0, rho7 = 0;
 
-    if ( inca == 1 && incx == 1 )
+    if (inca == 1 && incx == 1)
     {
         const dim_t n_iter_unroll = 1;
 
         // Use the unrolling factor and the number of elements per register
         // to compute the number of vectorized and leftover iterations.
-        dim_t m_viter = ( m ) / ( n_elem_per_reg * n_iter_unroll );
+        dim_t m_viter;
+
+        // Calculate the number of vector iterations that can occur
+        // for the given unroll factors.
+        m_viter = (m) / (n_elem_per_reg * n_iter_unroll);
 
         // Set up pointers for x and the b_n columns of A (rows of A^T).
-        double* restrict x0 = x;
-        double* restrict a0 = a + 0*lda;
-        double* restrict a1 = a + 1*lda;
-        double* restrict a2 = a + 2*lda;
-        double* restrict a3 = a + 3*lda;
-        double* restrict a4 = a + 4*lda;
-        double* restrict a5 = a + 5*lda;
-        double* restrict a6 = a + 6*lda;
-        double* restrict a7 = a + 7*lda;
+        double *restrict x0 = x;
+        double *restrict av[8];
+
+        av[0] = a + 0 * lda;
+        av[1] = a + 1 * lda;
+        av[2] = a + 2 * lda;
+        av[3] = a + 3 * lda;
+        av[4] = a + 4 * lda;
+        av[5] = a + 5 * lda;
+        av[6] = a + 6 * lda;
+        av[7] = a + 7 * lda;
 
         // Initialize b_n rho vector accumulators to zero.
-        v4df_t rho0v; rho0v.v = _mm256_setzero_pd();
-        v4df_t rho1v; rho1v.v = _mm256_setzero_pd();
-        v4df_t rho2v; rho2v.v = _mm256_setzero_pd();
-        v4df_t rho3v; rho3v.v = _mm256_setzero_pd();
-        v4df_t rho4v; rho4v.v = _mm256_setzero_pd();
-        v4df_t rho5v; rho5v.v = _mm256_setzero_pd();
-        v4df_t rho6v; rho6v.v = _mm256_setzero_pd();
-        v4df_t rho7v; rho7v.v = _mm256_setzero_pd();
+        v4df_t rhov[8];
 
-        v4df_t x0v;
-        v4df_t a0v, a1v, a2v, a3v, a4v, a5v, a6v, a7v;
+        rhov[0].v = _mm256_setzero_pd();
+        rhov[1].v = _mm256_setzero_pd();
+        rhov[2].v = _mm256_setzero_pd();
+        rhov[3].v = _mm256_setzero_pd();
+        rhov[4].v = _mm256_setzero_pd();
+        rhov[5].v = _mm256_setzero_pd();
+        rhov[6].v = _mm256_setzero_pd();
+        rhov[7].v = _mm256_setzero_pd();
 
-        // If there are vectorized iterations, perform them with vector
-        // instructions.
-        for ( dim_t i = 0; i < m_viter; ++i )
+        v4df_t xv;
+        v4df_t avec[8];
+
+        for (dim_t i = 0; i < m_viter; ++i)
         {
             // Load the input values.
-            x0v.v = _mm256_loadu_pd( x0 + 0*n_elem_per_reg );
+            xv.v = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg);
 
-            a0v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg );
-            a1v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg );
-            a2v.v = _mm256_loadu_pd( a2 + 0*n_elem_per_reg );
-            a3v.v = _mm256_loadu_pd( a3 + 0*n_elem_per_reg );
-            a4v.v = _mm256_loadu_pd( a4 + 0*n_elem_per_reg );
-            a5v.v = _mm256_loadu_pd( a5 + 0*n_elem_per_reg );
-            a6v.v = _mm256_loadu_pd( a6 + 0*n_elem_per_reg );
-            a7v.v = _mm256_loadu_pd( a7 + 0*n_elem_per_reg );
+            avec[0].v = _mm256_loadu_pd(av[0] + 0 * n_elem_per_reg);
+            avec[1].v = _mm256_loadu_pd(av[1] + 0 * n_elem_per_reg);
+            avec[2].v = _mm256_loadu_pd(av[2] + 0 * n_elem_per_reg);
+            avec[3].v = _mm256_loadu_pd(av[3] + 0 * n_elem_per_reg);
 
             // perform: rho?v += a?v * x0v;
-            rho0v.v = _mm256_fmadd_pd( a0v.v, x0v.v, rho0v.v );
-            rho1v.v = _mm256_fmadd_pd( a1v.v, x0v.v, rho1v.v );
-            rho2v.v = _mm256_fmadd_pd( a2v.v, x0v.v, rho2v.v );
-            rho3v.v = _mm256_fmadd_pd( a3v.v, x0v.v, rho3v.v );
-            rho4v.v = _mm256_fmadd_pd( a4v.v, x0v.v, rho4v.v );
-            rho5v.v = _mm256_fmadd_pd( a5v.v, x0v.v, rho5v.v );
-            rho6v.v = _mm256_fmadd_pd( a6v.v, x0v.v, rho6v.v );
-            rho7v.v = _mm256_fmadd_pd( a7v.v, x0v.v, rho7v.v );
+            rhov[0].v = _mm256_fmadd_pd(avec[0].v, xv.v, rhov[0].v);
+            rhov[1].v = _mm256_fmadd_pd(avec[1].v, xv.v, rhov[1].v);
+            rhov[2].v = _mm256_fmadd_pd(avec[2].v, xv.v, rhov[2].v);
+            rhov[3].v = _mm256_fmadd_pd(avec[3].v, xv.v, rhov[3].v);
+
+            avec[4].v = _mm256_loadu_pd(av[4] + 0 * n_elem_per_reg);
+            avec[5].v = _mm256_loadu_pd(av[5] + 0 * n_elem_per_reg);
+            avec[6].v = _mm256_loadu_pd(av[6] + 0 * n_elem_per_reg);
+            avec[7].v = _mm256_loadu_pd(av[7] + 0 * n_elem_per_reg);
+
+            rhov[4].v = _mm256_fmadd_pd(avec[4].v, xv.v, rhov[4].v);
+            rhov[5].v = _mm256_fmadd_pd(avec[5].v, xv.v, rhov[5].v);
+            rhov[6].v = _mm256_fmadd_pd(avec[6].v, xv.v, rhov[6].v);
+            rhov[7].v = _mm256_fmadd_pd(avec[7].v, xv.v, rhov[7].v);
 
             x0 += n_elem_per_reg * n_iter_unroll;
-            a0 += n_elem_per_reg * n_iter_unroll;
-            a1 += n_elem_per_reg * n_iter_unroll;
-            a2 += n_elem_per_reg * n_iter_unroll;
-            a3 += n_elem_per_reg * n_iter_unroll;
-            a4 += n_elem_per_reg * n_iter_unroll;
-            a5 += n_elem_per_reg * n_iter_unroll;
-            a6 += n_elem_per_reg * n_iter_unroll;
-            a7 += n_elem_per_reg * n_iter_unroll;
+            av[0] += n_elem_per_reg * n_iter_unroll;
+            av[1] += n_elem_per_reg * n_iter_unroll;
+            av[2] += n_elem_per_reg * n_iter_unroll;
+            av[3] += n_elem_per_reg * n_iter_unroll;
+            av[4] += n_elem_per_reg * n_iter_unroll;
+            av[5] += n_elem_per_reg * n_iter_unroll;
+            av[6] += n_elem_per_reg * n_iter_unroll;
+            av[7] += n_elem_per_reg * n_iter_unroll;
         }
 
-#if 0
-        rho0 += rho0v.d[0] + rho0v.d[1] + rho0v.d[2] + rho0v.d[3];
-        rho1 += rho1v.d[0] + rho1v.d[1] + rho1v.d[2] + rho1v.d[3];
-        rho2 += rho2v.d[0] + rho2v.d[1] + rho2v.d[2] + rho2v.d[3];
-        rho3 += rho3v.d[0] + rho3v.d[1] + rho3v.d[2] + rho3v.d[3];
-        rho4 += rho4v.d[0] + rho4v.d[1] + rho4v.d[2] + rho4v.d[3];
-        rho5 += rho5v.d[0] + rho5v.d[1] + rho5v.d[2] + rho5v.d[3];
-        rho6 += rho6v.d[0] + rho6v.d[1] + rho6v.d[2] + rho6v.d[3];
-        rho7 += rho7v.d[0] + rho7v.d[1] + rho7v.d[2] + rho7v.d[3];
-#else
         // Sum the elements of a given rho?v. This computes the sum of
         // elements within lanes and stores the sum to both elements.
-        rho0v.v = _mm256_hadd_pd( rho0v.v, rho0v.v );
-        rho1v.v = _mm256_hadd_pd( rho1v.v, rho1v.v );
-        rho2v.v = _mm256_hadd_pd( rho2v.v, rho2v.v );
-        rho3v.v = _mm256_hadd_pd( rho3v.v, rho3v.v );
-        rho4v.v = _mm256_hadd_pd( rho4v.v, rho4v.v );
-        rho5v.v = _mm256_hadd_pd( rho5v.v, rho5v.v );
-        rho6v.v = _mm256_hadd_pd( rho6v.v, rho6v.v );
-        rho7v.v = _mm256_hadd_pd( rho7v.v, rho7v.v );
+        rhov[0].v = _mm256_hadd_pd(rhov[0].v, rhov[0].v);
+        rhov[1].v = _mm256_hadd_pd(rhov[1].v, rhov[1].v);
+        rhov[2].v = _mm256_hadd_pd(rhov[2].v, rhov[2].v);
+        rhov[3].v = _mm256_hadd_pd(rhov[3].v, rhov[3].v);
+        rhov[4].v = _mm256_hadd_pd(rhov[4].v, rhov[4].v);
+        rhov[5].v = _mm256_hadd_pd(rhov[5].v, rhov[5].v);
+        rhov[6].v = _mm256_hadd_pd(rhov[6].v, rhov[6].v);
+        rhov[7].v = _mm256_hadd_pd(rhov[7].v, rhov[7].v);
 
         // Manually add the results from above to finish the sum.
-        rho0 = rho0v.d[0] + rho0v.d[2];
-        rho1 = rho1v.d[0] + rho1v.d[2];
-        rho2 = rho2v.d[0] + rho2v.d[2];
-        rho3 = rho3v.d[0] + rho3v.d[2];
-        rho4 = rho4v.d[0] + rho4v.d[2];
-        rho5 = rho5v.d[0] + rho5v.d[2];
-        rho6 = rho6v.d[0] + rho6v.d[2];
-        rho7 = rho7v.d[0] + rho7v.d[2];
-#endif
+        rho0 = rhov[0].d[0] + rhov[0].d[2];
+        rho1 = rhov[1].d[0] + rhov[1].d[2];
+        rho2 = rhov[2].d[0] + rhov[2].d[2];
+        rho3 = rhov[3].d[0] + rhov[3].d[2];
+        rho4 = rhov[4].d[0] + rhov[4].d[2];
+        rho5 = rhov[5].d[0] + rhov[5].d[2];
+        rho6 = rhov[6].d[0] + rhov[6].d[2];
+        rho7 = rhov[7].d[0] + rhov[7].d[2];
+
         // Adjust for scalar subproblem.
         m -= n_elem_per_reg * n_iter_unroll * m_viter;
         a += n_elem_per_reg * n_iter_unroll * m_viter /* * inca */;
         x += n_elem_per_reg * n_iter_unroll * m_viter /* * incx */;
-    }
-    else if ( lda == 1 )
+
+    }
+    else if (lda == 1)
     {
         const dim_t n_iter_unroll = 3;
         const dim_t n_reg_per_row = 2; // fuse_fac / n_elem_per_reg;
@@ -672,58 +675,50 @@ void bli_ddotxf_zen_int_8
         a += n_iter_unroll * m_viter * inca;
         x += n_iter_unroll * m_viter * incx;
     }
-    else
+
+    // Initialize pointers for x and the b_n columns of A (rows of A^T).
+    double *restrict x0 = x;
+    double *restrict a0 = a + 0 * lda;
+    double *restrict a1 = a + 1 * lda;
+    double *restrict a2 = a + 2 * lda;
+    double *restrict a3 = a + 3 * lda;
+    double *restrict a4 = a + 4 * lda;
+    double *restrict a5 = a + 5 * lda;
+    double *restrict a6 = a + 6 * lda;
+    double *restrict a7 = a + 7 * lda;
+
+    // If there are leftover iterations, perform them with scalar code.
+    for (dim_t i = 0; i < m; ++i)
     {
-        // No vectorization possible; use scalar iterations for the entire
-        // problem.
-    }
+        const double x0c = *x0;
 
-    // Scalar edge case.
-    {
-        // Initialize pointers for x and the b_n columns of A (rows of A^T).
-        double* restrict x0 = x;
-        double* restrict a0 = a + 0*lda;
-        double* restrict a1 = a + 1*lda;
-        double* restrict a2 = a + 2*lda;
-        double* restrict a3 = a + 3*lda;
-        double* restrict a4 = a + 4*lda;
-        double* restrict a5 = a + 5*lda;
-        double* restrict a6 = a + 6*lda;
-        double* restrict a7 = a + 7*lda;
+        const double a0c = *a0;
+        const double a1c = *a1;
+        const double a2c = *a2;
+        const double a3c = *a3;
+        const double a4c = *a4;
+        const double a5c = *a5;
+        const double a6c = *a6;
+        const double a7c = *a7;
 
-        // If there are leftover iterations, perform them with scalar code.
-        for ( dim_t i = 0; i < m ; ++i )
-        {
-            const double x0c = *x0;
+        rho0 += a0c * x0c;
+        rho1 += a1c * x0c;
+        rho2 += a2c * x0c;
+        rho3 += a3c * x0c;
+        rho4 += a4c * x0c;
+        rho5 += a5c * x0c;
+        rho6 += a6c * x0c;
+        rho7 += a7c * x0c;
 
-            const double a0c = *a0;
-            const double a1c = *a1;
-            const double a2c = *a2;
-            const double a3c = *a3;
-            const double a4c = *a4;
-            const double a5c = *a5;
-            const double a6c = *a6;
-            const double a7c = *a7;
-
-            rho0 += a0c * x0c;
-            rho1 += a1c * x0c;
-            rho2 += a2c * x0c;
-            rho3 += a3c * x0c;
-            rho4 += a4c * x0c;
-            rho5 += a5c * x0c;
-            rho6 += a6c * x0c;
-            rho7 += a7c * x0c;
-
-            x0 += incx;
-            a0 += inca;
-            a1 += inca;
-            a2 += inca;
-            a3 += inca;
-            a4 += inca;
-            a5 += inca;
-            a6 += inca;
-            a7 += inca;
-        }
+        x0 += incx;
+        a0 += inca;
+        a1 += inca;
+        a2 += inca;
+        a3 += inca;
+        a4 += inca;
+        a5 += inca;
+        a6 += inca;
+        a7 += inca;
     }
 
     // Now prepare the final rho values to output/accumulate back into
@@ -742,57 +737,809 @@ void bli_ddotxf_zen_int_8
     rho1v.d[3] = rho7;
 
     // Broadcast the alpha scalar.
-    v4df_t alphav; alphav.v = _mm256_broadcast_sd( alpha );
+    v4df_t alphav;
+    alphav.v = _mm256_broadcast_sd(alpha);
 
     // We know at this point that alpha is nonzero; however, beta may still
     // be zero. If beta is indeed zero, we must overwrite y rather than scale
     // by beta (in case y contains NaN or Inf).
-    if ( PASTEMAC(d,eq0)( *beta ) )
+    if (PASTEMAC(d, eq0)(*beta))
     {
         // Apply alpha to the accumulated dot product in rho:
         //   y := alpha * rho
-        y0v.v = _mm256_mul_pd( alphav.v, rho0v.v );
-        y1v.v = _mm256_mul_pd( alphav.v, rho1v.v );
+        y0v.v = _mm256_mul_pd(alphav.v, rho0v.v);
+        y1v.v = _mm256_mul_pd(alphav.v, rho1v.v);
     }
     else
     {
         // Broadcast the beta scalar.
-        v4df_t betav; betav.v = _mm256_broadcast_sd( beta );
+        v4df_t betav;
+        betav.v = _mm256_broadcast_sd(beta);
 
         // Load y.
-        if ( incy == 1 )
+        if (incy == 1)
         {
-            y0v.v = _mm256_loadu_pd( y + 0*n_elem_per_reg );
-            y1v.v = _mm256_loadu_pd( y + 1*n_elem_per_reg );
+            y0v.v = _mm256_loadu_pd(y + 0 * n_elem_per_reg);
+            y1v.v = _mm256_loadu_pd(y + 1 * n_elem_per_reg);
         }
         else
         {
-            y0v.d[0] = *(y + 0*incy); y0v.d[1] = *(y + 1*incy);
-            y0v.d[2] = *(y + 2*incy); y0v.d[3] = *(y + 3*incy);
-            y1v.d[0] = *(y + 4*incy); y1v.d[1] = *(y + 5*incy);
-            y1v.d[2] = *(y + 6*incy); y1v.d[3] = *(y + 7*incy);
+            y0v.d[0] = *(y + 0 * incy);
+            y0v.d[1] = *(y + 1 * incy);
+            y0v.d[2] = *(y + 2 * incy);
+            y0v.d[3] = *(y + 3 * incy);
+            y1v.d[0] = *(y + 4 * incy);
+            y1v.d[1] = *(y + 5 * incy);
+            y1v.d[2] = *(y + 6 * incy);
+            y1v.d[3] = *(y + 7 * incy);
         }
 
         // Apply beta to y and alpha to the accumulated dot product in rho:
         //   y := beta * y + alpha * rho
-        y0v.v = _mm256_mul_pd( betav.v, y0v.v );
-        y1v.v = _mm256_mul_pd( betav.v, y1v.v );
-        y0v.v = _mm256_fmadd_pd( alphav.v, rho0v.v, y0v.v );
-        y1v.v = _mm256_fmadd_pd( alphav.v, rho1v.v, y1v.v );
+        y0v.v = _mm256_mul_pd(betav.v, y0v.v);
+        y1v.v = _mm256_mul_pd(betav.v, y1v.v);
+        y0v.v = _mm256_fmadd_pd(alphav.v, rho0v.v, y0v.v);
+        y1v.v = _mm256_fmadd_pd(alphav.v, rho1v.v, y1v.v);
     }
 
-    if ( incy == 1 )
+    if (incy == 1)
     {
         // Store the output.
-        _mm256_storeu_pd( (y + 0*n_elem_per_reg), y0v.v );
-        _mm256_storeu_pd( (y + 1*n_elem_per_reg), y1v.v );
+        _mm256_storeu_pd((y + 0 * n_elem_per_reg), y0v.v);
+        _mm256_storeu_pd((y + 1 * n_elem_per_reg), y1v.v);
     }
     else
     {
-        *(y + 0*incy) = y0v.d[0]; *(y + 1*incy) = y0v.d[1];
-        *(y + 2*incy) = y0v.d[2]; *(y + 3*incy) = y0v.d[3];
-        *(y + 4*incy) = y1v.d[0]; *(y + 5*incy) = y1v.d[1];
-        *(y + 6*incy) = y1v.d[2]; *(y + 7*incy) = y1v.d[3];
+        *(y + 0 * incy) = y0v.d[0];
+        *(y + 1 * incy) = y0v.d[1];
+        *(y + 2 * incy) = y0v.d[2];
+        *(y + 3 * incy) = y0v.d[3];
+        *(y + 4 * incy) = y1v.d[0];
+        *(y + 5 * incy) = y1v.d[1];
+        *(y + 6 * incy) = y1v.d[2];
+        *(y + 7 * incy) = y1v.d[3];
     }
 }
+
+void bli_ddotxf_zen_int_4
+    (
+      conj_t           conjat,
+      conj_t           conjx,
+      dim_t            m,
+      dim_t            b_n,
+      double *restrict alpha,
+      double *restrict a, inc_t inca, inc_t lda,
+      double *restrict x, inc_t incx,
+      double *restrict beta,
+      double *restrict y, inc_t incy,
+      cntx_t *restrict cntx
+    )
+{
+    const dim_t fuse_fac       = 4;
+    const dim_t n_elem_per_reg = 4;
+
+    // If the b_n dimension is zero, y is empty and there is no computation.
+    if (bli_zero_dim1(b_n))
+        return;
+
+    // If the m dimension is zero, or if alpha is zero, the computation
+    // simplifies to updating y.
+    if (bli_zero_dim1(m) || PASTEMAC(d, eq0)(*alpha))
+    {
+        bli_dscalv_zen_int10(
+            BLIS_NO_CONJUGATE,
+            b_n,
+            beta,
+            y, incy,
+            cntx);
+        return;
+    }
+
+    // If b_n is not equal to the fusing factor, then perform the entire
+    // operation as a loop over dotxv.
+    if (b_n != fuse_fac)
+    {
+        for (dim_t i = 0; i < b_n; ++i)
+        {
+            double *a1   = a + (0) * inca + (i)*lda;
+            double *x1   = x + (0) * incx;
+            double *psi1 = y + (i)*incy;
+
+            bli_ddotxv_zen_int(
+                conjat,
+                conjx,
+                m,
+                alpha,
+                a1, inca,
+                x1, incx,
+                beta,
+                psi1,
+                cntx);
+        }
+        return;
+    }
+
+    // At this point, we know that b_n is exactly equal to the fusing factor.
+    // However, m may not be a multiple of the number of elements per vector.
+
+    // Going forward, we handle two possible storage formats of A explicitly:
+    // (1) A is stored by columns, or (2) A is stored by rows. Either case is
+    // further split into two subproblems along the m dimension:
+    // (a) a vectorized part, starting at m = 0 and ending at any 0 <= m' <= m.
+    // (b) a scalar part, starting at m' and ending at m. If no vectorization
+    //     is possible then m' == 0 and thus the scalar part is the entire
+    //     problem. If 0 < m', then the a and x pointers and m variable will
+    //     be adjusted accordingly for the second subproblem.
+    // Note: since parts (b) for both (1) and (2) are so similar, they are
+    // factored out into one code block after the following conditional, which
+    // distinguishes between (1) and (2).
+
+    // Intermediate variables to hold the completed dot products
+    double rho0 = 0, rho1 = 0, rho2 = 0, rho3 = 0;
+
+    if (inca == 1 && incx == 1)
+    {
+        const dim_t n_iter_unroll[4] = {4, 3, 2, 1};
+
+        // Use the unrolling factor and the number of elements per register
+        // to compute the number of vectorized and leftover iterations.
+        dim_t m_viter[4], m_left = m, i;
+
+        // Calculate the number of vector iterations that can occur for
+        // various unroll factors.
+        for (i = 0; i < 4; ++i)
+        {
+            m_viter[i] = (m_left) / (n_elem_per_reg * n_iter_unroll[i]);
+            m_left    = (m_left) % (n_elem_per_reg * n_iter_unroll[i]);
+        }
+
+        // Set up pointers for x and the b_n columns of A (rows of A^T).
+        double *restrict x0 = x;
+        double *restrict av[4];
+
+        av[0] = a + 0 * lda;
+        av[1] = a + 1 * lda;
+        av[2] = a + 2 * lda;
+        av[3] = a + 3 * lda;
+
+        // Initialize b_n rho vector accumulators to zero.
+        v4df_t rhov[8];
+
+        rhov[0].v = _mm256_setzero_pd();
+        rhov[1].v = _mm256_setzero_pd();
+        rhov[2].v = _mm256_setzero_pd();
+        rhov[3].v = _mm256_setzero_pd();
+        rhov[4].v = _mm256_setzero_pd();
+        rhov[5].v = _mm256_setzero_pd();
+        rhov[6].v = _mm256_setzero_pd();
+        rhov[7].v = _mm256_setzero_pd();
+
+        v4df_t xv[4];
+        v4df_t avec[16];
+
+        // If there are vectorized iterations, perform them with vector
+        // instructions.
+        for (i = 0; i < m_viter[0]; ++i)
+        {
+            // Load the input values.
+            xv[0].v = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg);
+            xv[1].v = _mm256_loadu_pd(x0 + 1 * n_elem_per_reg);
+            xv[2].v = _mm256_loadu_pd(x0 + 2 * n_elem_per_reg);
+            xv[3].v = _mm256_loadu_pd(x0 + 3 * n_elem_per_reg);
+
+            avec[0].v = _mm256_loadu_pd(av[0] + 0 * n_elem_per_reg);
+            avec[1].v = _mm256_loadu_pd(av[1] + 0 * n_elem_per_reg);
+            avec[2].v = _mm256_loadu_pd(av[2] + 0 * n_elem_per_reg);
+            avec[3].v = _mm256_loadu_pd(av[3] + 0 * n_elem_per_reg);
+
+            // perform: rho?v += a?v * x0v;
+            rhov[0].v = _mm256_fmadd_pd(avec[0].v, xv[0].v, rhov[0].v);
+            rhov[1].v = _mm256_fmadd_pd(avec[1].v, xv[0].v, rhov[1].v);
+            rhov[2].v = _mm256_fmadd_pd(avec[2].v, xv[0].v, rhov[2].v);
+            rhov[3].v = _mm256_fmadd_pd(avec[3].v, xv[0].v, rhov[3].v);
+
+            avec[4].v = _mm256_loadu_pd(av[0] + 1 * n_elem_per_reg);
+            avec[5].v = _mm256_loadu_pd(av[1] + 1 * n_elem_per_reg);
+            avec[6].v = _mm256_loadu_pd(av[2] + 1 * n_elem_per_reg);
+            avec[7].v = _mm256_loadu_pd(av[3] + 1 * n_elem_per_reg);
+
+            rhov[4].v = _mm256_fmadd_pd(avec[4].v, xv[1].v, rhov[4].v);
+            rhov[5].v = _mm256_fmadd_pd(avec[5].v, xv[1].v, rhov[5].v);
+            rhov[6].v = _mm256_fmadd_pd(avec[6].v, xv[1].v, rhov[6].v);
+            rhov[7].v = _mm256_fmadd_pd(avec[7].v, xv[1].v, rhov[7].v);
+
+            avec[8].v  = _mm256_loadu_pd(av[0] + 2 * n_elem_per_reg);
+            avec[9].v  = _mm256_loadu_pd(av[1] + 2 * n_elem_per_reg);
+            avec[10].v = _mm256_loadu_pd(av[2] + 2 * n_elem_per_reg);
+            avec[11].v = _mm256_loadu_pd(av[3] + 2 * n_elem_per_reg);
+
+            rhov[0].v = _mm256_fmadd_pd(avec[8].v,  xv[2].v, rhov[0].v);
+            rhov[1].v = _mm256_fmadd_pd(avec[9].v,  xv[2].v, rhov[1].v);
+            rhov[2].v = _mm256_fmadd_pd(avec[10].v, xv[2].v, rhov[2].v);
+            rhov[3].v = _mm256_fmadd_pd(avec[11].v, xv[2].v, rhov[3].v);
+
+            avec[12].v = _mm256_loadu_pd(av[0] + 3 * n_elem_per_reg);
+            avec[13].v = _mm256_loadu_pd(av[1] + 3 * n_elem_per_reg);
+            avec[14].v = _mm256_loadu_pd(av[2] + 3 * n_elem_per_reg);
+            avec[15].v = _mm256_loadu_pd(av[3] + 3 * n_elem_per_reg);
+
+            rhov[4].v = _mm256_fmadd_pd(avec[12].v, xv[3].v, rhov[4].v);
+            rhov[5].v = _mm256_fmadd_pd(avec[13].v, xv[3].v, rhov[5].v);
+            rhov[6].v = _mm256_fmadd_pd(avec[14].v, xv[3].v, rhov[6].v);
+            rhov[7].v = _mm256_fmadd_pd(avec[15].v, xv[3].v, rhov[7].v);
+
+            x0    += n_elem_per_reg * n_iter_unroll[0];
+            av[0] += n_elem_per_reg * n_iter_unroll[0];
+            av[1] += n_elem_per_reg * n_iter_unroll[0];
+            av[2] += n_elem_per_reg * n_iter_unroll[0];
+            av[3] += n_elem_per_reg * n_iter_unroll[0];
+        }
+
+        for (i = 0; i < m_viter[1]; ++i)
+        {
+            // Load the input values.
+            xv[0].v = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg);
+            xv[1].v = _mm256_loadu_pd(x0 + 1 * n_elem_per_reg);
+            xv[2].v = _mm256_loadu_pd(x0 + 2 * n_elem_per_reg);
+
+            avec[0].v = _mm256_loadu_pd(av[0] + 0 * n_elem_per_reg);
+            avec[1].v = _mm256_loadu_pd(av[1] + 0 * n_elem_per_reg);
+            avec[2].v = _mm256_loadu_pd(av[2] + 0 * n_elem_per_reg);
+            avec[3].v = _mm256_loadu_pd(av[3] + 0 * n_elem_per_reg);
+
+            // perform: rho?v += a?v * x0v;
+            rhov[0].v = _mm256_fmadd_pd(avec[0].v, xv[0].v, rhov[0].v);
+            rhov[1].v = _mm256_fmadd_pd(avec[1].v, xv[0].v, rhov[1].v);
+            rhov[2].v = _mm256_fmadd_pd(avec[2].v, xv[0].v, rhov[2].v);
+            rhov[3].v = _mm256_fmadd_pd(avec[3].v, xv[0].v, rhov[3].v);
+
+            avec[4].v = _mm256_loadu_pd(av[0] + 1 * n_elem_per_reg);
+            avec[5].v = _mm256_loadu_pd(av[1] + 1 * n_elem_per_reg);
+            avec[6].v = _mm256_loadu_pd(av[2] + 1 * n_elem_per_reg);
+            avec[7].v = _mm256_loadu_pd(av[3] + 1 * n_elem_per_reg);
+
+            rhov[4].v = _mm256_fmadd_pd(avec[4].v, xv[1].v, rhov[4].v);
+            rhov[5].v = _mm256_fmadd_pd(avec[5].v, xv[1].v, rhov[5].v);
+            rhov[6].v = _mm256_fmadd_pd(avec[6].v, xv[1].v, rhov[6].v);
+            rhov[7].v = _mm256_fmadd_pd(avec[7].v, xv[1].v, rhov[7].v);
+
+            avec[8].v  = _mm256_loadu_pd(av[0] + 2 * n_elem_per_reg);
+            avec[9].v  = _mm256_loadu_pd(av[1] + 2 * n_elem_per_reg);
+            avec[10].v = _mm256_loadu_pd(av[2] + 2 * n_elem_per_reg);
+            avec[11].v = _mm256_loadu_pd(av[3] + 2 * n_elem_per_reg);
+
+            rhov[0].v = _mm256_fmadd_pd(avec[8].v,  xv[2].v, rhov[0].v);
+            rhov[1].v = _mm256_fmadd_pd(avec[9].v,  xv[2].v, rhov[1].v);
+            rhov[2].v = _mm256_fmadd_pd(avec[10].v, xv[2].v, rhov[2].v);
+            rhov[3].v = _mm256_fmadd_pd(avec[11].v, xv[2].v, rhov[3].v);
+
+            x0    += n_elem_per_reg * n_iter_unroll[1];
+            av[0] += n_elem_per_reg * n_iter_unroll[1];
+            av[1] += n_elem_per_reg * n_iter_unroll[1];
+            av[2] += n_elem_per_reg * n_iter_unroll[1];
+            av[3] += n_elem_per_reg * n_iter_unroll[1];
+        }
+
+        for (i = 0; i < m_viter[2]; ++i)
+        {
+            // Load the input values.
+            xv[0].v = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg);
+            xv[1].v = _mm256_loadu_pd(x0 + 1 * n_elem_per_reg);
+
+            avec[0].v = _mm256_loadu_pd(av[0] + 0 * n_elem_per_reg);
+            avec[1].v = _mm256_loadu_pd(av[1] + 0 * n_elem_per_reg);
+            avec[2].v = _mm256_loadu_pd(av[2] + 0 * n_elem_per_reg);
+            avec[3].v = _mm256_loadu_pd(av[3] + 0 * n_elem_per_reg);
+
+            avec[4].v = _mm256_loadu_pd(av[0] + 1 * n_elem_per_reg);
+            avec[5].v = _mm256_loadu_pd(av[1] + 1 * n_elem_per_reg);
+            avec[6].v = _mm256_loadu_pd(av[2] + 1 * n_elem_per_reg);
+            avec[7].v = _mm256_loadu_pd(av[3] + 1 * n_elem_per_reg);
+
+            // perform: rho?v += a?v * x0v;
+            rhov[0].v = _mm256_fmadd_pd(avec[0].v, xv[0].v, rhov[0].v);
+            rhov[1].v = _mm256_fmadd_pd(avec[1].v, xv[0].v, rhov[1].v);
+            rhov[2].v = _mm256_fmadd_pd(avec[2].v, xv[0].v, rhov[2].v);
+            rhov[3].v = _mm256_fmadd_pd(avec[3].v, xv[0].v, rhov[3].v);
+
+            rhov[4].v = _mm256_fmadd_pd(avec[4].v, xv[1].v, rhov[4].v);
+            rhov[5].v = _mm256_fmadd_pd(avec[5].v, xv[1].v, rhov[5].v);
+            rhov[6].v = _mm256_fmadd_pd(avec[6].v, xv[1].v, rhov[6].v);
+            rhov[7].v = _mm256_fmadd_pd(avec[7].v, xv[1].v, rhov[7].v);
+
+            x0    += n_elem_per_reg * n_iter_unroll[2];
+            av[0] += n_elem_per_reg * n_iter_unroll[2];
+            av[1] += n_elem_per_reg * n_iter_unroll[2];
+            av[2] += n_elem_per_reg * n_iter_unroll[2];
+            av[3] += n_elem_per_reg * n_iter_unroll[2];
+        }
+
+        for (i = 0; i < m_viter[3]; ++i)
+        {
+            // Load the input values.
+            xv[0].v = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg);
+
+            avec[0].v = _mm256_loadu_pd(av[0] + 0 * n_elem_per_reg);
+            avec[1].v = _mm256_loadu_pd(av[1] + 0 * n_elem_per_reg);
+            avec[2].v = _mm256_loadu_pd(av[2] + 0 * n_elem_per_reg);
+            avec[3].v = _mm256_loadu_pd(av[3] + 0 * n_elem_per_reg);
+
+            // perform: rho?v += a?v * x0v;
+            rhov[0].v = _mm256_fmadd_pd(avec[0].v, xv[0].v, rhov[0].v);
+            rhov[1].v = _mm256_fmadd_pd(avec[1].v, xv[0].v, rhov[1].v);
+            rhov[2].v = _mm256_fmadd_pd(avec[2].v, xv[0].v, rhov[2].v);
+            rhov[3].v = _mm256_fmadd_pd(avec[3].v, xv[0].v, rhov[3].v);
+
+            x0    += n_elem_per_reg * n_iter_unroll[3];
+            av[0] += n_elem_per_reg * n_iter_unroll[3];
+            av[1] += n_elem_per_reg * n_iter_unroll[3];
+            av[2] += n_elem_per_reg * n_iter_unroll[3];
+            av[3] += n_elem_per_reg * n_iter_unroll[3];
+        }
+
+        // Sum the elements of a given rho?v. This computes the sum of
+        // elements within lanes and stores the sum to both elements.
+        rhov[0].v = _mm256_add_pd(rhov[0].v, rhov[4].v);
+        rhov[1].v = _mm256_add_pd(rhov[1].v, rhov[5].v);
+        rhov[2].v = _mm256_add_pd(rhov[2].v, rhov[6].v);
+        rhov[3].v = _mm256_add_pd(rhov[3].v, rhov[7].v);
+
+        rhov[0].v = _mm256_hadd_pd(rhov[0].v, rhov[0].v);
+        rhov[1].v = _mm256_hadd_pd(rhov[1].v, rhov[1].v);
+        rhov[2].v = _mm256_hadd_pd(rhov[2].v, rhov[2].v);
+        rhov[3].v = _mm256_hadd_pd(rhov[3].v, rhov[3].v);
+
+        // Manually add the results from above to finish the sum.
+        rho0 = rhov[0].d[0] + rhov[0].d[2];
+        rho1 = rhov[1].d[0] + rhov[1].d[2];
+        rho2 = rhov[2].d[0] + rhov[2].d[2];
+        rho3 = rhov[3].d[0] + rhov[3].d[2];
+
+        // Adjust for scalar subproblem.
+        for (i = 0; i < 4; ++i)
+        {
+            m -= n_elem_per_reg * n_iter_unroll[i] * m_viter[i];
+            a += n_elem_per_reg * n_iter_unroll[i] * m_viter[i] /* * inca */;
+            x += n_elem_per_reg * n_iter_unroll[i] * m_viter[i] /* * incx */;
+        }
+    }
+
+    // Initialize pointers for x and the b_n columns of A (rows of A^T).
+    double *restrict x0 = x;
+    double *restrict a0 = a + 0 * lda;
+    double *restrict a1 = a + 1 * lda;
+    double *restrict a2 = a + 2 * lda;
+    double *restrict a3 = a + 3 * lda;
+
+    // If there are leftover iterations, perform them with scalar code.
+    for (dim_t i = 0; i < m; ++i)
+    {
+        const double x0c = *x0;
+
+        const double a0c = *a0;
+        const double a1c = *a1;
+        const double a2c = *a2;
+        const double a3c = *a3;
+
+        rho0 += a0c * x0c;
+        rho1 += a1c * x0c;
+        rho2 += a2c * x0c;
+        rho3 += a3c * x0c;
+
+        x0 += incx;
+        a0 += inca;
+        a1 += inca;
+        a2 += inca;
+        a3 += inca;
+    }
+
+    // Now prepare the final rho values to output/accumulate back into
+    // the y vector.
+
+    v4df_t rho0v, y0v;
+
+    // Insert the scalar rho values into a single vector.
+    rho0v.d[0] = rho0;
+    rho0v.d[1] = rho1;
+    rho0v.d[2] = rho2;
+    rho0v.d[3] = rho3;
+
+    // Broadcast the alpha scalar.
+    v4df_t alphav;
+    alphav.v = _mm256_broadcast_sd(alpha);
+
+    // We know at this point that alpha is nonzero; however, beta may still
+    // be zero. If beta is indeed zero, we must overwrite y rather than scale
+    // by beta (in case y contains NaN or Inf).
+    if (PASTEMAC(d, eq0)(*beta))
+    {
+        // Apply alpha to the accumulated dot product in rho:
+        //   y := alpha * rho
+        y0v.v = _mm256_mul_pd(alphav.v, rho0v.v);
+    }
+    else
+    {
+        // Broadcast the beta scalar.
+        v4df_t betav;
+        betav.v = _mm256_broadcast_sd(beta);
+
+        // Load y.
+        if (incy == 1)
+        {
+            y0v.v = _mm256_loadu_pd(y + 0 * n_elem_per_reg);
+        }
+        else
+        {
+            y0v.d[0] = *(y + 0 * incy);
+            y0v.d[1] = *(y + 1 * incy);
+            y0v.d[2] = *(y + 2 * incy);
+            y0v.d[3] = *(y + 3 * incy);
+        }
+
+        // Apply beta to y and alpha to the accumulated dot product in rho:
+        //   y := beta * y + alpha * rho
+        y0v.v = _mm256_mul_pd(betav.v, y0v.v);
+        y0v.v = _mm256_fmadd_pd(alphav.v, rho0v.v, y0v.v);
+    }
+
+    if (incy == 1)
+    {
+        // Store the output.
+        _mm256_storeu_pd((y + 0 * n_elem_per_reg), y0v.v);
+    }
+    else
+    {
+        *(y + 0 * incy) = y0v.d[0];
+        *(y + 1 * incy) = y0v.d[1];
+        *(y + 2 * incy) = y0v.d[2];
+        *(y + 3 * incy) = y0v.d[3];
+    }
+}
+
+void bli_ddotxf_zen_int_2
+    (
+      conj_t           conjat,
+      conj_t           conjx,
+      dim_t            m,
+      dim_t            b_n,
+      double *restrict alpha,
+      double *restrict a, inc_t inca, inc_t lda,
+      double *restrict x, inc_t incx,
+      double *restrict beta,
+      double *restrict y, inc_t incy,
+      cntx_t *restrict cntx
+    )
+{
+    const dim_t fuse_fac       = 2;
+    const dim_t n_elem_per_reg = 4;
+
+    // If the b_n dimension is zero, y is empty and there is no computation.
+    if (bli_zero_dim1(b_n))
+        return;
+
+    // If the m dimension is zero, or if alpha is zero, the computation
+    // simplifies to updating y.
+    if (bli_zero_dim1(m) || PASTEMAC(d, eq0)(*alpha))
+    {
+        bli_dscalv_zen_int10(
+            BLIS_NO_CONJUGATE,
+            b_n,
+            beta,
+            y, incy,
+            cntx);
+        return;
+    }
+
+    // If b_n is not equal to the fusing factor, then perform the entire
+    // operation as a loop over dotxv.
+    if (b_n != fuse_fac)
+    {
+        for (dim_t i = 0; i < b_n; ++i)
+        {
+            double *a1   = a + (0) * inca + (i)*lda;
+            double *x1   = x + (0) * incx;
+            double *psi1 = y + (i)*incy;
+
+            bli_ddotxv_zen_int(
+                conjat,
+                conjx,
+                m,
+                alpha,
+                a1, inca,
+                x1, incx,
+                beta,
+                psi1,
+                cntx);
+        }
+        return;
+    }
+
+    // At this point, we know that b_n is exactly equal to the fusing factor.
+    // However, m may not be a multiple of the number of elements per vector.
+
+    // Going forward, we handle two possible storage formats of A explicitly:
+    // (1) A is stored by columns, or (2) A is stored by rows. Either case is
+    // further split into two subproblems along the m dimension:
+    // (a) a vectorized part, starting at m = 0 and ending at any 0 <= m' <= m.
+    // (b) a scalar part, starting at m' and ending at m. If no vectorization
+    //     is possible then m' == 0 and thus the scalar part is the entire
+    //     problem. If 0 < m', then the a and x pointers and m variable will
+    //     be adjusted accordingly for the second subproblem.
+    // Note: since parts (b) for both (1) and (2) are so similar, they are
+    // factored out into one code block after the following conditional, which
+    // distinguishes between (1) and (2).
+
+    // Intermediate variables to hold the completed dot products
+    double rho0 = 0, rho1 = 0;
+
+    if (inca == 1 && incx == 1)
+    {
+        const dim_t n_iter_unroll[4] = {8, 4, 2, 1};
+
+        // Use the unrolling factor and the number of elements per register
+        // to compute the number of vectorized and leftover iterations.
+        dim_t m_viter[4], i, m_left = m;
+
+        for (i = 0; i < 4; ++i)
+        {
+            m_viter[i] = (m_left) / (n_elem_per_reg * n_iter_unroll[i]);
+            m_left    = (m_left) % (n_elem_per_reg * n_iter_unroll[i]);
+        }
+
+        // Set up pointers for x and the b_n columns of A (rows of A^T).
+        double *restrict x0 = x;
+        double *restrict av[2];
+
+        av[0] = a + 0 * lda;
+        av[1] = a + 1 * lda;
+
+        // Initialize b_n rho vector accumulators to zero.
+        v4df_t rhov[8];
+
+        rhov[0].v = _mm256_setzero_pd();
+        rhov[1].v = _mm256_setzero_pd();
+        rhov[2].v = _mm256_setzero_pd();
+        rhov[3].v = _mm256_setzero_pd();
+        rhov[4].v = _mm256_setzero_pd();
+        rhov[5].v = _mm256_setzero_pd();
+        rhov[6].v = _mm256_setzero_pd();
+        rhov[7].v = _mm256_setzero_pd();
+
+        v4df_t xv[4];
+        v4df_t avec[8];
+
+        for (i = 0; i < m_viter[0]; ++i)
+        {
+            // Load the first 16 input values of x and of each column of A.
+            xv[0].v = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg);
+            xv[1].v = _mm256_loadu_pd(x0 + 1 * n_elem_per_reg);
+            xv[2].v = _mm256_loadu_pd(x0 + 2 * n_elem_per_reg);
+            xv[3].v = _mm256_loadu_pd(x0 + 3 * n_elem_per_reg);
+
+            avec[0].v = _mm256_loadu_pd(av[0] + 0 * n_elem_per_reg);
+            avec[1].v = _mm256_loadu_pd(av[1] + 0 * n_elem_per_reg);
+            avec[2].v = _mm256_loadu_pd(av[0] + 1 * n_elem_per_reg);
+            avec[3].v = _mm256_loadu_pd(av[1] + 1 * n_elem_per_reg);
+            avec[4].v = _mm256_loadu_pd(av[0] + 2 * n_elem_per_reg);
+            avec[5].v = _mm256_loadu_pd(av[1] + 2 * n_elem_per_reg);
+            avec[6].v = _mm256_loadu_pd(av[0] + 3 * n_elem_per_reg);
+            avec[7].v = _mm256_loadu_pd(av[1] + 3 * n_elem_per_reg);
+
+            // perform: rho?v += a?v * x0v;
+            rhov[0].v = _mm256_fmadd_pd(avec[0].v, xv[0].v, rhov[0].v);
+            rhov[1].v = _mm256_fmadd_pd(avec[1].v, xv[0].v, rhov[1].v);
+            rhov[2].v = _mm256_fmadd_pd(avec[2].v, xv[1].v, rhov[2].v);
+            rhov[3].v = _mm256_fmadd_pd(avec[3].v, xv[1].v, rhov[3].v);
+            rhov[4].v = _mm256_fmadd_pd(avec[4].v, xv[2].v, rhov[4].v);
+            rhov[5].v = _mm256_fmadd_pd(avec[5].v, xv[2].v, rhov[5].v);
+            rhov[6].v = _mm256_fmadd_pd(avec[6].v, xv[3].v, rhov[6].v);
+            rhov[7].v = _mm256_fmadd_pd(avec[7].v, xv[3].v, rhov[7].v);
+
+            // Load the next 16 input values (offsets 4..7 registers); the
+            // unroll factor is 8 registers, so the second half must not
+            // reread offsets 0..3.
+            xv[0].v = _mm256_loadu_pd(x0 + 4 * n_elem_per_reg);
+            xv[1].v = _mm256_loadu_pd(x0 + 5 * n_elem_per_reg);
+            xv[2].v = _mm256_loadu_pd(x0 + 6 * n_elem_per_reg);
+            xv[3].v = _mm256_loadu_pd(x0 + 7 * n_elem_per_reg);
+
+            avec[0].v = _mm256_loadu_pd(av[0] + 4 * n_elem_per_reg);
+            avec[1].v = _mm256_loadu_pd(av[1] + 4 * n_elem_per_reg);
+            avec[2].v = _mm256_loadu_pd(av[0] + 5 * n_elem_per_reg);
+            avec[3].v = _mm256_loadu_pd(av[1] + 5 * n_elem_per_reg);
+            avec[4].v = _mm256_loadu_pd(av[0] + 6 * n_elem_per_reg);
+            avec[5].v = _mm256_loadu_pd(av[1] + 6 * n_elem_per_reg);
+            avec[6].v = _mm256_loadu_pd(av[0] + 7 * n_elem_per_reg);
+            avec[7].v = _mm256_loadu_pd(av[1] + 7 * n_elem_per_reg);
+
+            // perform: rho?v += a?v * x0v;
+            rhov[0].v = _mm256_fmadd_pd(avec[0].v, xv[0].v, rhov[0].v);
+            rhov[1].v = _mm256_fmadd_pd(avec[1].v, xv[0].v, rhov[1].v);
+            rhov[2].v = _mm256_fmadd_pd(avec[2].v, xv[1].v, rhov[2].v);
+            rhov[3].v = _mm256_fmadd_pd(avec[3].v, xv[1].v, rhov[3].v);
+            rhov[4].v = _mm256_fmadd_pd(avec[4].v, xv[2].v, rhov[4].v);
+            rhov[5].v = _mm256_fmadd_pd(avec[5].v, xv[2].v, rhov[5].v);
+            rhov[6].v = _mm256_fmadd_pd(avec[6].v, xv[3].v, rhov[6].v);
+            rhov[7].v = _mm256_fmadd_pd(avec[7].v, xv[3].v, rhov[7].v);
+
+            x0    += n_elem_per_reg * n_iter_unroll[0];
+            av[0] += n_elem_per_reg * n_iter_unroll[0];
+            av[1] += n_elem_per_reg * n_iter_unroll[0];
+        }
+
+        for (i = 0; i < m_viter[1]; ++i)
+        {
+            // Load the input values.
+            xv[0].v = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg);
+            xv[1].v = _mm256_loadu_pd(x0 + 1 * n_elem_per_reg);
+            xv[2].v = _mm256_loadu_pd(x0 + 2 * n_elem_per_reg);
+            xv[3].v = _mm256_loadu_pd(x0 + 3 * n_elem_per_reg);
+
+            avec[0].v = _mm256_loadu_pd(av[0] + 0 * n_elem_per_reg);
+            avec[1].v = _mm256_loadu_pd(av[1] + 0 * n_elem_per_reg);
+            avec[2].v = _mm256_loadu_pd(av[0] + 1 * n_elem_per_reg);
+            avec[3].v = _mm256_loadu_pd(av[1] + 1 * n_elem_per_reg);
+            avec[4].v = _mm256_loadu_pd(av[0] + 2 * n_elem_per_reg);
+            avec[5].v = _mm256_loadu_pd(av[1] + 2 * n_elem_per_reg);
+            avec[6].v = _mm256_loadu_pd(av[0] + 3 * n_elem_per_reg);
+            avec[7].v = _mm256_loadu_pd(av[1] + 3 * n_elem_per_reg);
+
+            // perform: rho?v += a?v * x0v;
+            rhov[0].v = _mm256_fmadd_pd(avec[0].v, xv[0].v, rhov[0].v);
+            rhov[1].v = _mm256_fmadd_pd(avec[1].v, xv[0].v, rhov[1].v);
+            rhov[2].v = _mm256_fmadd_pd(avec[2].v, xv[1].v, rhov[2].v);
+            rhov[3].v = _mm256_fmadd_pd(avec[3].v, xv[1].v, rhov[3].v);
+            rhov[4].v = _mm256_fmadd_pd(avec[4].v, xv[2].v, rhov[4].v);
+            rhov[5].v = _mm256_fmadd_pd(avec[5].v, xv[2].v, rhov[5].v);
+            rhov[6].v = _mm256_fmadd_pd(avec[6].v, xv[3].v, rhov[6].v);
+            rhov[7].v = _mm256_fmadd_pd(avec[7].v, xv[3].v, rhov[7].v);
+
+            x0    += n_elem_per_reg * n_iter_unroll[1];
+            av[0] += n_elem_per_reg * n_iter_unroll[1];
+            av[1] += n_elem_per_reg * n_iter_unroll[1];
+        }
+
+        rhov[0].v = _mm256_add_pd(rhov[0].v, rhov[4].v);
+        rhov[1].v = _mm256_add_pd(rhov[1].v, rhov[5].v);
+        rhov[2].v = _mm256_add_pd(rhov[2].v, rhov[6].v);
+        rhov[3].v = _mm256_add_pd(rhov[3].v, rhov[7].v);
+
+        for (i = 0; i < m_viter[2]; ++i)
+        {
+            // Load the input values.
+            xv[0].v = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg);
+            xv[1].v = _mm256_loadu_pd(x0 + 1 * n_elem_per_reg);
+
+            avec[0].v = _mm256_loadu_pd(av[0] + 0 * n_elem_per_reg);
+            avec[1].v = _mm256_loadu_pd(av[1] + 0 * n_elem_per_reg);
+            avec[2].v = _mm256_loadu_pd(av[0] + 1 * n_elem_per_reg);
+            avec[3].v = _mm256_loadu_pd(av[1] + 1 * n_elem_per_reg);
+
+            // perform: rho?v += a?v * x0v;
+            rhov[0].v = _mm256_fmadd_pd(avec[0].v, xv[0].v, rhov[0].v);
+            rhov[1].v = _mm256_fmadd_pd(avec[1].v, xv[0].v, rhov[1].v);
+            rhov[2].v = _mm256_fmadd_pd(avec[2].v, xv[1].v, rhov[2].v);
+            rhov[3].v = _mm256_fmadd_pd(avec[3].v, xv[1].v, rhov[3].v);
+
+            x0    += n_elem_per_reg * n_iter_unroll[2];
+            av[0] += n_elem_per_reg * n_iter_unroll[2];
+            av[1] += n_elem_per_reg * n_iter_unroll[2];
+        }
+
+        rhov[0].v = _mm256_add_pd(rhov[0].v, rhov[2].v);
+        rhov[1].v = _mm256_add_pd(rhov[1].v, rhov[3].v);
+
+        for (i = 0; i < m_viter[3]; ++i)
+        {
+            // Load the input values.
+            xv[0].v = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg);
+
+            avec[0].v = _mm256_loadu_pd(av[0] + 0 * n_elem_per_reg);
+            avec[1].v = _mm256_loadu_pd(av[1] + 0 * n_elem_per_reg);
+
+            // perform: rho?v += a?v * x0v;
+            rhov[0].v = _mm256_fmadd_pd(avec[0].v, xv[0].v, rhov[0].v);
+            rhov[1].v = _mm256_fmadd_pd(avec[1].v, xv[0].v, rhov[1].v);
+
+            x0    += n_elem_per_reg * n_iter_unroll[3];
+            av[0] += n_elem_per_reg * n_iter_unroll[3];
+            av[1] += n_elem_per_reg * n_iter_unroll[3];
+        }
+
+        // Sum the elements of a given rho?v. This computes the sum of
+        // elements within lanes and stores the sum to both elements.
+        rhov[0].v = _mm256_hadd_pd(rhov[0].v, rhov[0].v);
+        rhov[1].v = _mm256_hadd_pd(rhov[1].v, rhov[1].v);
+
+        // Manually add the results from above to finish the sum.
+        rho0 = rhov[0].d[0] + rhov[0].d[2];
+        rho1 = rhov[1].d[0] + rhov[1].d[2];
+
+        // Adjust for scalar subproblem.
+        for (i = 0; i < 4; ++i)
+        {
+            m -= n_elem_per_reg * n_iter_unroll[i] * m_viter[i];
+            a += n_elem_per_reg * n_iter_unroll[i] * m_viter[i] /* * inca */;
+            x += n_elem_per_reg * n_iter_unroll[i] * m_viter[i] /* * incx */;
+        }
+    }
+
+    // Initialize pointers for x and the b_n columns of A (rows of A^T).
+    double *restrict x0 = x;
+    double *restrict a0 = a + 0 * lda;
+    double *restrict a1 = a + 1 * lda;
+
+    // If there are leftover iterations, perform them with scalar code.
+    for (dim_t i = 0; i < m; ++i)
+    {
+        const double x0c = *x0;
+
+        const double a0c = *a0;
+        const double a1c = *a1;
+
+        rho0 += a0c * x0c;
+        rho1 += a1c * x0c;
+
+        x0 += incx;
+        a0 += inca;
+        a1 += inca;
+    }
+
+    // Now prepare the final rho values to output/accumulate back into
+    // the y vector.
+
+    v2df_t rho0v, y0v;
+
+    // Insert the scalar rho values into a single vector.
+    rho0v.d[0] = rho0;
+    rho0v.d[1] = rho1;
+
+    // Broadcast the alpha scalar.
+    v2df_t alphav;
+
+    alphav.v = _mm_load1_pd(alpha);
+
+    // We know at this point that alpha is nonzero; however, beta may still
+    // be zero. If beta is indeed zero, we must overwrite y rather than scale
+    // by beta (in case y contains NaN or Inf).
+    if (PASTEMAC(d, eq0)(*beta))
+    {
+        // Apply alpha to the accumulated dot product in rho:
+        //   y := alpha * rho
+        y0v.v = _mm_mul_pd(alphav.v, rho0v.v);
+    }
+    else
+    {
+        // Broadcast the beta scalar.
+        v2df_t betav;
+        betav.v = _mm_load1_pd(beta);
+
+        // Load y.
+        if (incy == 1)
+        {
+            y0v.v = _mm_loadu_pd(y + 0 * 2);
+        }
+        else
+        {
+            y0v.d[0] = *(y + 0 * incy);
+            y0v.d[1] = *(y + 1 * incy);
+        }
+
+        // Apply beta to y and alpha to the accumulated dot product in rho:
+        //   y := beta * y + alpha * rho
+        y0v.v = _mm_mul_pd(betav.v, y0v.v);
+        y0v.v = _mm_fmadd_pd(alphav.v, rho0v.v, y0v.v);
+    }
+
+    if (incy == 1)
+    {
+        // Store the output.
+        _mm_storeu_pd((y + 0 * 2), y0v.v);
+    }
+    else
+    {
+        *(y + 0 * incy) = y0v.d[0];
+        *(y + 1 * incy) = y0v.d[1];
+    }
+}
+
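All three kernels finish their accumulators the same way: one _mm256_hadd_pd() folds adjacent pairs within each 128-bit lane, so the full sum of a 4-element accumulator is d[0] + d[2]. A minimal standalone illustration of that reduction follows (not part of the patch; assumes an AVX2-capable build, e.g. gcc -mavx2 -mfma):

#include <immintrin.h>
#include <stdio.h>

int main(void)
{
    __m256d acc = _mm256_set_pd(4.0, 3.0, 2.0, 1.0); /* d = {1, 2, 3, 4} */

    /* hadd sums adjacent pairs per 128-bit lane: {1+2, 1+2, 3+4, 3+4}. */
    __m256d h = _mm256_hadd_pd(acc, acc);

    double d[4];
    _mm256_storeu_pd(d, h);

    /* Cross-lane finish, exactly as the kernels do: d[0] + d[2]. */
    printf("sum = %.1f\n", d[0] + d[2]); /* prints 10.0 */
    return 0;
}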
diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h
index f3a939b0b..d46164a9c 100644
--- a/kernels/zen/bli_kernels_zen.h
+++ b/kernels/zen/bli_kernels_zen.h
@@ -118,6 +118,8 @@ AXPYF_KER_PROT( dcomplex, z, axpyf_zen_int_4 )
 // dotxf (intrinsics)
 DOTXF_KER_PROT( float,   s, dotxf_zen_int_8 )
 DOTXF_KER_PROT( double,  d, dotxf_zen_int_8 )
+DOTXF_KER_PROT( double,  d, dotxf_zen_int_4 )
+DOTXF_KER_PROT( double,  d, dotxf_zen_int_2 )
 
 // dotxaxpyf (intrinsics)
 DOTXAXPYF_KER_PROT( double, d, dotxaxpyf_zen_int_8 )
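The two new DOTXF_KER_PROT lines register prototypes for the added kernels so the framework can reference them. Hand-expanded, each declares the standard BLIS dotxf signature; the sketch below shows the (double, d, dotxf_zen_int_4) case, matching the definition in bli_dotxf_zen_int_8.c above. Treat this expansion as an approximation: the exact macro output is defined by the BLIS framework headers.

void bli_ddotxf_zen_int_4
     (
       conj_t           conjat,
       conj_t           conjx,
       dim_t            m,
       dim_t            b_n,
       double* restrict alpha,
       double* restrict a, inc_t inca, inc_t lda,
       double* restrict x, inc_t incx,
       double* restrict beta,
       double* restrict y, inc_t incy,
       cntx_t* restrict cntx
     );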