diff --git a/kernels/bulldozer/3/bli_gemm_bulldozer_asm_d4x6_fma4.c b/kernels/bulldozer/3/bli_gemm_bulldozer_asm_d4x6_fma4.c index 489456aeb..62dff91da 100644 --- a/kernels/bulldozer/3/bli_gemm_bulldozer_asm_d4x6_fma4.c +++ b/kernels/bulldozer/3/bli_gemm_bulldozer_asm_d4x6_fma4.c @@ -97,8 +97,8 @@ void bli_sgemm_bulldozer_asm_8x8_fma4 cntx_t* restrict cntx ) { - uint64_t k_iter = k / 4; - uint64_t k_left = k % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; __asm__ volatile ( @@ -865,8 +865,8 @@ void bli_dgemm_bulldozer_asm_4x6_fma4 cntx_t* restrict cntx ) { - dim_t k_iter = k / 12; - dim_t k_left = k % 12; + uint64_t k_iter = k / 12; + uint64_t k_left = k % 12; __asm__ ( @@ -1076,8 +1076,8 @@ void bli_cgemm_bulldozer_asm_8x4_fma4 //void* a_next = bli_auxinfo_next_a( data ); void* b_next = bli_auxinfo_next_b( data ); - dim_t k_iter = k / 4; - dim_t k_left = k % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; __asm__ volatile ( @@ -1883,8 +1883,8 @@ void bli_zgemm_bulldozer_asm_4x4_fma4 //void* a_next = bli_auxinfo_next_a( data ); //void* b_next = bli_auxinfo_next_b( data ); - dim_t k_iter = k / 4; - dim_t k_left = k % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; __asm__ volatile ( diff --git a/kernels/skx/3/bli_sgemm_skx_asm_32x12_l2.c b/kernels/skx/3/bli_sgemm_skx_asm_32x12_l2.c index a69b92086..8cd9bc683 100644 --- a/kernels/skx/3/bli_sgemm_skx_asm_32x12_l2.c +++ b/kernels/skx/3/bli_sgemm_skx_asm_32x12_l2.c @@ -265,22 +265,22 @@ ahead*/ \ PREFETCH_A_L1(n, 0) \ \ - VBROADCASTSS(ZMM(3), MEM(RBX,(12*n+ 0)*8)) \ - VBROADCASTSS(ZMM(4), MEM(RBX,(12*n+ 1)*8)) \ + VBROADCASTSS(ZMM(3), MEM(RBX,(12*n+ 0)*4)) \ + VBROADCASTSS(ZMM(4), MEM(RBX,(12*n+ 1)*4)) \ VFMADD231PS(ZMM( 8), ZMM(0), ZMM(3)) \ VFMADD231PS(ZMM( 9), ZMM(1), ZMM(3)) \ VFMADD231PS(ZMM(10), ZMM(0), ZMM(4)) \ VFMADD231PS(ZMM(11), ZMM(1), ZMM(4)) \ \ - VBROADCASTSS(ZMM(3), MEM(RBX,(12*n+ 2)*8)) \ - VBROADCASTSS(ZMM(4), MEM(RBX,(12*n+ 3)*8)) \ + VBROADCASTSS(ZMM(3), MEM(RBX,(12*n+ 2)*4)) \ + VBROADCASTSS(ZMM(4), MEM(RBX,(12*n+ 3)*4)) \ VFMADD231PS(ZMM(12), ZMM(0), ZMM(3)) \ VFMADD231PS(ZMM(13), ZMM(1), ZMM(3)) \ VFMADD231PS(ZMM(14), ZMM(0), ZMM(4)) \ VFMADD231PS(ZMM(15), ZMM(1), ZMM(4)) \ \ - VBROADCASTSS(ZMM(3), MEM(RBX,(12*n+ 4)*8)) \ - VBROADCASTSS(ZMM(4), MEM(RBX,(12*n+ 5)*8)) \ + VBROADCASTSS(ZMM(3), MEM(RBX,(12*n+ 4)*4)) \ + VBROADCASTSS(ZMM(4), MEM(RBX,(12*n+ 5)*4)) \ VFMADD231PS(ZMM(16), ZMM(0), ZMM(3)) \ VFMADD231PS(ZMM(17), ZMM(1), ZMM(3)) \ VFMADD231PS(ZMM(18), ZMM(0), ZMM(4)) \ @@ -288,29 +288,29 @@ ahead*/ \ PREFETCH_A_L1(n, 1) \ \ - VBROADCASTSS(ZMM(3), MEM(RBX,(12*n+ 6)*8)) \ - VBROADCASTSS(ZMM(4), MEM(RBX,(12*n+ 7)*8)) \ + VBROADCASTSS(ZMM(3), MEM(RBX,(12*n+ 6)*4)) \ + VBROADCASTSS(ZMM(4), MEM(RBX,(12*n+ 7)*4)) \ VFMADD231PS(ZMM(20), ZMM(0), ZMM(3)) \ VFMADD231PS(ZMM(21), ZMM(1), ZMM(3)) \ VFMADD231PS(ZMM(22), ZMM(0), ZMM(4)) \ VFMADD231PS(ZMM(23), ZMM(1), ZMM(4)) \ \ - VBROADCASTSS(ZMM(3), MEM(RBX,(12*n+ 8)*8)) \ - VBROADCASTSS(ZMM(4), MEM(RBX,(12*n+ 9)*8)) \ + VBROADCASTSS(ZMM(3), MEM(RBX,(12*n+ 8)*4)) \ + VBROADCASTSS(ZMM(4), MEM(RBX,(12*n+ 9)*4)) \ VFMADD231PS(ZMM(24), ZMM(0), ZMM(3)) \ VFMADD231PS(ZMM(25), ZMM(1), ZMM(3)) \ VFMADD231PS(ZMM(26), ZMM(0), ZMM(4)) \ VFMADD231PS(ZMM(27), ZMM(1), ZMM(4)) \ \ - VBROADCASTSS(ZMM(3), MEM(RBX,(12*n+10)*8)) \ - VBROADCASTSS(ZMM(4), MEM(RBX,(12*n+11)*8)) \ + VBROADCASTSS(ZMM(3), MEM(RBX,(12*n+10)*4)) \ + VBROADCASTSS(ZMM(4), MEM(RBX,(12*n+11)*4)) \ VFMADD231PS(ZMM(28), ZMM(0), ZMM(3)) \ VFMADD231PS(ZMM(29), ZMM(1), ZMM(3)) \ VFMADD231PS(ZMM(30), 
ZMM(0), ZMM(4)) \
    VFMADD231PS(ZMM(31), ZMM(1), ZMM(4)) \
    \
-    VMOVAPD(ZMM(0), MEM(RAX,(16*n+0)*8)) \
-    VMOVAPD(ZMM(1), MEM(RAX,(16*n+8)*8))
+    VMOVAPD(ZMM(0), MEM(RAX,(32*n+0)*4)) \
+    VMOVAPD(ZMM(1), MEM(RAX,(32*n+16)*4))
 
 //This is an array used for the scatter/gather instructions.
 static int64_t offsets[16] __attribute__((aligned(64))) =
diff --git a/kernels/zen/1f/bli_dotxf_zen_int_8.c b/kernels/zen/1f/bli_dotxf_zen_int_8.c
index 1375135ab..7bbc14ae0 100644
--- a/kernels/zen/1f/bli_dotxf_zen_int_8.c
+++ b/kernels/zen/1f/bli_dotxf_zen_int_8.c
@@ -67,42 +67,8 @@ void bli_sdotxf_zen_int_8
        cntx_t* restrict cntx
      )
 {
-	const dim_t fuse_fac       = 8;
-
-	const dim_t n_elem_per_reg = 8;
-	const dim_t n_iter_unroll  = 1;
-
-	dim_t i;
-	dim_t m_viter;
-	dim_t m_left;
-
-	float* restrict a0;
-	float* restrict a1;
-	float* restrict a2;
-	float* restrict a3;
-	float* restrict a4;
-	float* restrict a5;
-	float* restrict a6;
-	float* restrict a7;
-
-	float* restrict x0;
-
-	float rho0, rho1, rho2, rho3;
-	float rho4, rho5, rho6, rho7;
-
-	v8sf_t a0v, a1v, a2v, a3v;
-	v8sf_t a4v, a5v, a6v, a7v;
-
-	v8sf_t x0v;
-
-	v8sf_t rho0v, rho1v, rho2v, rho3v;
-	v8sf_t rho4v, rho5v, rho6v, rho7v;
-
-	v8sf_t alphav;
-	v8sf_t betav;
-	v8sf_t y0v;
-
-	v8sf_t onev;
+	const dim_t fuse_fac       = 8;
+	const dim_t n_elem_per_reg = 8;
 
 	// If the b_n dimension is zero, y is empty and there is no computation.
 	if ( bli_zero_dim1( b_n ) ) return;
@@ -130,7 +96,7 @@ void bli_sdotxf_zen_int_8
 	{
 		sdotxv_ft f = bli_cntx_get_l1v_ker_dt( BLIS_FLOAT, BLIS_DOTXV_KER, cntx );
 
-		for ( i = 0; i < b_n; ++i )
+		for ( dim_t i = 0; i < b_n; ++i )
 		{
 			float* a1 = a + (0 )*inca + (i )*lda;
 			float* x1 = x + (0 )*incx;
@@ -155,159 +121,265 @@ void bli_sdotxf_zen_int_8
 	// At this point, we know that b_n is exactly equal to the fusing factor.
 	// However, m may not be a multiple of the number of elements per vector.
 
-	// Use the unrolling factor and the number of elements per register
-	// to compute the number of vectorized and leftover iterations.
-	m_viter = ( m ) / ( n_elem_per_reg * n_iter_unroll );
-	m_left  = ( m ) % ( n_elem_per_reg * n_iter_unroll );
+	// Going forward, we handle two possible storage formats of A explicitly:
+	// (1) A is stored by columns, or (2) A is stored by rows. Either case is
+	// further split into two subproblems along the m dimension:
+	// (a) a vectorized part, starting at m = 0 and ending at any 0 <= m' <= m.
+	// (b) a scalar part, starting at m' and ending at m. If no vectorization
+	//     is possible then m' == 0 and thus the scalar part is the entire
+	//     problem. If 0 < m', then the a and x pointers and m variable will
+	//     be adjusted accordingly for the second subproblem.
+	// Note: since parts (b) for both (1) and (2) are so similar, they are
+	// factored out into one code block after the following conditional, which
+	// distinguishes between (1) and (2).
 
-	// If there is anything that would interfere with our use of contiguous
-	// vector loads/stores, override m_viter and m_left to use scalar code
-	// for all iterations. (NOTE: vector instructions can be used even if
-	// the increment of the output vector, incy, is nonunit.)
-	if ( inca != 1 || incx != 1 )
+	// Intermediate variables to hold the completed dot products
+	float rho0 = 0, rho1 = 0, rho2 = 0, rho3 = 0,
+	      rho4 = 0, rho5 = 0, rho6 = 0, rho7 = 0;
+
+	if ( inca == 1 && incx == 1 )
 	{
-		m_viter = 0;
-		m_left  = m;
-	}
+		const dim_t n_iter_unroll = 1;
 
-	// Set up pointers for x and the b_n columns of A (rows of A^T).
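The decomposition described in the comment block added above can be summarized in isolation. The following is a minimal illustrative sketch of that control flow only, not the kernel itself: the name sdotxf_sketch and the rho array are hypothetical, the vector loops are elided, and the BLIS types dim_t/inc_t are assumed to come from blis.h.

	#include "blis.h"

	// Sketch: pick a vectorized subproblem based on how A is stored, then
	// fall through to one shared scalar loop for whatever remains of m.
	static void sdotxf_sketch( dim_t m, const float* a, inc_t inca, inc_t lda,
	                           const float* x, inc_t incx, float rho[8] )
	{
		if ( inca == 1 && incx == 1 )  // (1) A is stored by columns.
		{
			dim_t m_viter = m / 8;      // vectorized iterations, 8 rows each
			// ... contiguous vector loads over 8*m_viter rows ...
			m -= 8 * m_viter; a += 8 * m_viter; x += 8 * m_viter;
		}
		else if ( lda == 1 )            // (2) A is stored by rows.
		{
			dim_t m_viter = m / 4;      // vectorized iterations, 4 rows each
			// ... row loads plus broadcasts over 4*m_viter rows ...
			m -= 4 * m_viter; a += 4 * m_viter * inca; x += 4 * m_viter * incx;
		}
		// Shared scalar part (b): the remaining m rows, any storage format.
		for ( dim_t i = 0; i < m; ++i )
			for ( dim_t j = 0; j < 8; ++j )
				rho[ j ] += a[ i*inca + j*lda ] * x[ i*incx ];
	}

Note that the pointer/m adjustments at the end of each vectorized branch are exactly what lets both branches share the single scalar loop.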
-	x0 = x;
-	a0 = a + 0*lda;
-	a1 = a + 1*lda;
-	a2 = a + 2*lda;
-	a3 = a + 3*lda;
-	a4 = a + 4*lda;
-	a5 = a + 5*lda;
-	a6 = a + 6*lda;
-	a7 = a + 7*lda;
+		// Use the unrolling factor and the number of elements per register
+		// to compute the number of vectorized and leftover iterations.
+		dim_t m_viter = ( m ) / ( n_elem_per_reg * n_iter_unroll );
 
-	// Initialize b_n rho vector accumulators to zero.
-	rho0v.v = _mm256_setzero_ps();
-	rho1v.v = _mm256_setzero_ps();
-	rho2v.v = _mm256_setzero_ps();
-	rho3v.v = _mm256_setzero_ps();
-	rho4v.v = _mm256_setzero_ps();
-	rho5v.v = _mm256_setzero_ps();
-	rho6v.v = _mm256_setzero_ps();
-	rho7v.v = _mm256_setzero_ps();
+		// Set up pointers for x and the b_n columns of A (rows of A^T).
+		float* restrict x0 = x;
+		float* restrict a0 = a + 0*lda;
+		float* restrict a1 = a + 1*lda;
+		float* restrict a2 = a + 2*lda;
+		float* restrict a3 = a + 3*lda;
+		float* restrict a4 = a + 4*lda;
+		float* restrict a5 = a + 5*lda;
+		float* restrict a6 = a + 6*lda;
+		float* restrict a7 = a + 7*lda;
 
-	// If there are vectorized iterations, perform them with vector
-	// instructions.
-	for ( i = 0; i < m_viter; ++i )
-	{
-		// Load the input values.
-		x0v.v = _mm256_loadu_ps( x0 + 0*n_elem_per_reg );
+		// Initialize b_n rho vector accumulators to zero.
+		v8sf_t rho0v; rho0v.v = _mm256_setzero_ps();
+		v8sf_t rho1v; rho1v.v = _mm256_setzero_ps();
+		v8sf_t rho2v; rho2v.v = _mm256_setzero_ps();
+		v8sf_t rho3v; rho3v.v = _mm256_setzero_ps();
+		v8sf_t rho4v; rho4v.v = _mm256_setzero_ps();
+		v8sf_t rho5v; rho5v.v = _mm256_setzero_ps();
+		v8sf_t rho6v; rho6v.v = _mm256_setzero_ps();
+		v8sf_t rho7v; rho7v.v = _mm256_setzero_ps();
 
-		a0v.v = _mm256_loadu_ps( a0 + 0*n_elem_per_reg );
-		a1v.v = _mm256_loadu_ps( a1 + 0*n_elem_per_reg );
-		a2v.v = _mm256_loadu_ps( a2 + 0*n_elem_per_reg );
-		a3v.v = _mm256_loadu_ps( a3 + 0*n_elem_per_reg );
-		a4v.v = _mm256_loadu_ps( a4 + 0*n_elem_per_reg );
-		a5v.v = _mm256_loadu_ps( a5 + 0*n_elem_per_reg );
-		a6v.v = _mm256_loadu_ps( a6 + 0*n_elem_per_reg );
-		a7v.v = _mm256_loadu_ps( a7 + 0*n_elem_per_reg );
+		v8sf_t x0v;
+		v8sf_t a0v, a1v, a2v, a3v, a4v, a5v, a6v, a7v;
 
-		// perform : rho?v += a?v * x0v;
-		rho0v.v = _mm256_fmadd_ps( a0v.v, x0v.v, rho0v.v );
-		rho1v.v = _mm256_fmadd_ps( a1v.v, x0v.v, rho1v.v );
-		rho2v.v = _mm256_fmadd_ps( a2v.v, x0v.v, rho2v.v );
-		rho3v.v = _mm256_fmadd_ps( a3v.v, x0v.v, rho3v.v );
-		rho4v.v = _mm256_fmadd_ps( a4v.v, x0v.v, rho4v.v );
-		rho5v.v = _mm256_fmadd_ps( a5v.v, x0v.v, rho5v.v );
-		rho6v.v = _mm256_fmadd_ps( a6v.v, x0v.v, rho6v.v );
-		rho7v.v = _mm256_fmadd_ps( a7v.v, x0v.v, rho7v.v );
+		// If there are vectorized iterations, perform them with vector
+		// instructions.
+		for ( dim_t i = 0; i < m_viter; ++i )
+		{
+			// Load the input values.
+ x0v.v = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); - x0 += n_elem_per_reg * n_iter_unroll; - a0 += n_elem_per_reg * n_iter_unroll; - a1 += n_elem_per_reg * n_iter_unroll; - a2 += n_elem_per_reg * n_iter_unroll; - a3 += n_elem_per_reg * n_iter_unroll; - a4 += n_elem_per_reg * n_iter_unroll; - a5 += n_elem_per_reg * n_iter_unroll; - a6 += n_elem_per_reg * n_iter_unroll; - a7 += n_elem_per_reg * n_iter_unroll; - } + a0v.v = _mm256_loadu_ps( a0 + 0*n_elem_per_reg ); + a1v.v = _mm256_loadu_ps( a1 + 0*n_elem_per_reg ); + a2v.v = _mm256_loadu_ps( a2 + 0*n_elem_per_reg ); + a3v.v = _mm256_loadu_ps( a3 + 0*n_elem_per_reg ); + a4v.v = _mm256_loadu_ps( a4 + 0*n_elem_per_reg ); + a5v.v = _mm256_loadu_ps( a5 + 0*n_elem_per_reg ); + a6v.v = _mm256_loadu_ps( a6 + 0*n_elem_per_reg ); + a7v.v = _mm256_loadu_ps( a7 + 0*n_elem_per_reg ); + + // perform: rho?v += a?v * x0v; + rho0v.v = _mm256_fmadd_ps( a0v.v, x0v.v, rho0v.v ); + rho1v.v = _mm256_fmadd_ps( a1v.v, x0v.v, rho1v.v ); + rho2v.v = _mm256_fmadd_ps( a2v.v, x0v.v, rho2v.v ); + rho3v.v = _mm256_fmadd_ps( a3v.v, x0v.v, rho3v.v ); + rho4v.v = _mm256_fmadd_ps( a4v.v, x0v.v, rho4v.v ); + rho5v.v = _mm256_fmadd_ps( a5v.v, x0v.v, rho5v.v ); + rho6v.v = _mm256_fmadd_ps( a6v.v, x0v.v, rho6v.v ); + rho7v.v = _mm256_fmadd_ps( a7v.v, x0v.v, rho7v.v ); + + x0 += n_elem_per_reg * n_iter_unroll; + a0 += n_elem_per_reg * n_iter_unroll; + a1 += n_elem_per_reg * n_iter_unroll; + a2 += n_elem_per_reg * n_iter_unroll; + a3 += n_elem_per_reg * n_iter_unroll; + a4 += n_elem_per_reg * n_iter_unroll; + a5 += n_elem_per_reg * n_iter_unroll; + a6 += n_elem_per_reg * n_iter_unroll; + a7 += n_elem_per_reg * n_iter_unroll; + } #if 0 - rho0 += rho0v.f[0] + rho0v.f[1] + rho0v.f[2] + rho0v.f[3] + - rho0v.f[4] + rho0v.f[5] + rho0v.f[6] + rho0v.f[7]; - rho1 += rho1v.f[0] + rho1v.f[1] + rho1v.f[2] + rho1v.f[3] + - rho1v.f[4] + rho1v.f[5] + rho1v.f[6] + rho1v.f[7]; - rho2 += rho2v.f[0] + rho2v.f[1] + rho2v.f[2] + rho2v.f[3] + - rho2v.f[4] + rho2v.f[5] + rho2v.f[6] + rho2v.f[7]; - rho3 += rho3v.f[0] + rho3v.f[1] + rho3v.f[2] + rho3v.f[3] + - rho3v.f[4] + rho3v.f[5] + rho3v.f[6] + rho3v.f[7]; - rho4 += rho4v.f[0] + rho4v.f[1] + rho4v.f[2] + rho4v.f[3] + - rho4v.f[4] + rho4v.f[5] + rho4v.f[6] + rho4v.f[7]; - rho5 += rho5v.f[0] + rho5v.f[1] + rho5v.f[2] + rho5v.f[3] + - rho5v.f[4] + rho5v.f[5] + rho5v.f[6] + rho5v.f[7]; - rho6 += rho6v.f[0] + rho6v.f[1] + rho6v.f[2] + rho6v.f[3] + - rho6v.f[4] + rho6v.f[5] + rho6v.f[6] + rho6v.f[7]; - rho7 += rho7v.f[0] + rho7v.f[1] + rho7v.f[2] + rho7v.f[3] + - rho7v.f[4] + rho7v.f[5] + rho7v.f[6] + rho7v.f[7]; + rho0 += rho0v.f[0] + rho0v.f[1] + rho0v.f[2] + rho0v.f[3] + + rho0v.f[4] + rho0v.f[5] + rho0v.f[6] + rho0v.f[7]; + rho1 += rho1v.f[0] + rho1v.f[1] + rho1v.f[2] + rho1v.f[3] + + rho1v.f[4] + rho1v.f[5] + rho1v.f[6] + rho1v.f[7]; + rho2 += rho2v.f[0] + rho2v.f[1] + rho2v.f[2] + rho2v.f[3] + + rho2v.f[4] + rho2v.f[5] + rho2v.f[6] + rho2v.f[7]; + rho3 += rho3v.f[0] + rho3v.f[1] + rho3v.f[2] + rho3v.f[3] + + rho3v.f[4] + rho3v.f[5] + rho3v.f[6] + rho3v.f[7]; + rho4 += rho4v.f[0] + rho4v.f[1] + rho4v.f[2] + rho4v.f[3] + + rho4v.f[4] + rho4v.f[5] + rho4v.f[6] + rho4v.f[7]; + rho5 += rho5v.f[0] + rho5v.f[1] + rho5v.f[2] + rho5v.f[3] + + rho5v.f[4] + rho5v.f[5] + rho5v.f[6] + rho5v.f[7]; + rho6 += rho6v.f[0] + rho6v.f[1] + rho6v.f[2] + rho6v.f[3] + + rho6v.f[4] + rho6v.f[5] + rho6v.f[6] + rho6v.f[7]; + rho7 += rho7v.f[0] + rho7v.f[1] + rho7v.f[2] + rho7v.f[3] + + rho7v.f[4] + rho7v.f[5] + rho7v.f[6] + rho7v.f[7]; #else - // Now we need to sum the elements 
within each vector.
+		// Now we need to sum the elements within each vector.
 
-	onev.v = _mm256_set1_ps( 1.0f );
+		v8sf_t onev; onev.v = _mm256_set1_ps( 1.0f );
 
-	// Sum the elements of a given rho?v by dotting it with 1. The '1' in
-	// '0xf1' stores the sum of the upper four and lower four values to
-	// the low elements of each lane: elements 4 and 0, respectively. (The
-	// 'f' in '0xf1' means include all four elements of each lane in the
-	// summation.)
-	rho0v.v = _mm256_dp_ps( rho0v.v, onev.v, 0xf1 );
-	rho1v.v = _mm256_dp_ps( rho1v.v, onev.v, 0xf1 );
-	rho2v.v = _mm256_dp_ps( rho2v.v, onev.v, 0xf1 );
-	rho3v.v = _mm256_dp_ps( rho3v.v, onev.v, 0xf1 );
-	rho4v.v = _mm256_dp_ps( rho4v.v, onev.v, 0xf1 );
-	rho5v.v = _mm256_dp_ps( rho5v.v, onev.v, 0xf1 );
-	rho6v.v = _mm256_dp_ps( rho6v.v, onev.v, 0xf1 );
-	rho7v.v = _mm256_dp_ps( rho7v.v, onev.v, 0xf1 );
+		// Sum the elements of a given rho?v by dotting it with 1. The '1' in
+		// '0xf1' stores the sum of the upper four and lower four values to
+		// the low elements of each lane: elements 4 and 0, respectively. (The
+		// 'f' in '0xf1' means include all four elements of each lane in the
+		// summation.)
+		rho0v.v = _mm256_dp_ps( rho0v.v, onev.v, 0xf1 );
+		rho1v.v = _mm256_dp_ps( rho1v.v, onev.v, 0xf1 );
+		rho2v.v = _mm256_dp_ps( rho2v.v, onev.v, 0xf1 );
+		rho3v.v = _mm256_dp_ps( rho3v.v, onev.v, 0xf1 );
+		rho4v.v = _mm256_dp_ps( rho4v.v, onev.v, 0xf1 );
+		rho5v.v = _mm256_dp_ps( rho5v.v, onev.v, 0xf1 );
+		rho6v.v = _mm256_dp_ps( rho6v.v, onev.v, 0xf1 );
+		rho7v.v = _mm256_dp_ps( rho7v.v, onev.v, 0xf1 );
 
-	// Manually add the results from above to finish the sum.
-	rho0 = rho0v.f[0] + rho0v.f[4];
-	rho1 = rho1v.f[0] + rho1v.f[4];
-	rho2 = rho2v.f[0] + rho2v.f[4];
-	rho3 = rho3v.f[0] + rho3v.f[4];
-	rho4 = rho4v.f[0] + rho4v.f[4];
-	rho5 = rho5v.f[0] + rho5v.f[4];
-	rho6 = rho6v.f[0] + rho6v.f[4];
-	rho7 = rho7v.f[0] + rho7v.f[4];
+		// Manually add the results from above to finish the sum.
+		rho0 = rho0v.f[0] + rho0v.f[4];
+		rho1 = rho1v.f[0] + rho1v.f[4];
+		rho2 = rho2v.f[0] + rho2v.f[4];
+		rho3 = rho3v.f[0] + rho3v.f[4];
+		rho4 = rho4v.f[0] + rho4v.f[4];
+		rho5 = rho5v.f[0] + rho5v.f[4];
+		rho6 = rho6v.f[0] + rho6v.f[4];
+		rho7 = rho7v.f[0] + rho7v.f[4];
 #endif
 
-	// If there are leftover iterations, perform them with scalar code.
-	for ( i = 0; i < m_left ; ++i )
-	{
-		const float x0c = *x0;
-
-		const float a0c = *a0;
-		const float a1c = *a1;
-		const float a2c = *a2;
-		const float a3c = *a3;
-		const float a4c = *a4;
-		const float a5c = *a5;
-		const float a6c = *a6;
-		const float a7c = *a7;
-
-		rho0 += a0c * x0c;
-		rho1 += a1c * x0c;
-		rho2 += a2c * x0c;
-		rho3 += a3c * x0c;
-		rho4 += a4c * x0c;
-		rho5 += a5c * x0c;
-		rho6 += a6c * x0c;
-		rho7 += a7c * x0c;
-
-		x0 += incx;
-		a0 += inca;
-		a1 += inca;
-		a2 += inca;
-		a3 += inca;
-		a4 += inca;
-		a5 += inca;
-		a6 += inca;
-		a7 += inca;
+		// Adjust for scalar subproblem.
+		m -= n_elem_per_reg * n_iter_unroll * m_viter;
+		a += n_elem_per_reg * n_iter_unroll * m_viter /* * inca */;
+		x += n_elem_per_reg * n_iter_unroll * m_viter /* * incx */;
 	}
+	else if ( lda == 1 )
+	{
+		const dim_t n_iter_unroll = 4;
+
+		// Use the unrolling factor and the number of elements per register
+		// to compute the number of vectorized and leftover iterations.
+		dim_t m_viter = ( m ) / ( n_iter_unroll );
+
+		// Initialize pointers for x and A.
+		float* restrict x0 = x;
+		float* restrict a0 = a;
+
+		// Initialize rho vector accumulators to zero.
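For reference, the '0xf1' dot-product reduction described earlier in this hunk, applied to a single register. This is a self-contained illustration only; the helper name hsum8_dp is not from the patch.

	#include <immintrin.h>

	// Horizontal sum of one __m256 via VDPPS with immediate 0xf1: within
	// each 128-bit lane, all four elements ('f') are dotted with 1.0f and
	// the result is written to element 0 of that lane ('1').
	static inline float hsum8_dp( __m256 v )
	{
		__m256 one = _mm256_set1_ps( 1.0f );
		__m256 s   = _mm256_dp_ps( v, one, 0xf1 );
		// s[0] = v[0]+v[1]+v[2]+v[3] and s[4] = v[4]+v[5]+v[6]+v[7].
		return _mm_cvtss_f32( _mm256_castps256_ps128( s ) )
		     + _mm_cvtss_f32( _mm256_extractf128_ps( s, 1 ) );
	}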
+		v8sf_t rho0v; rho0v.v = _mm256_setzero_ps();
+		v8sf_t rho1v; rho1v.v = _mm256_setzero_ps();
+		v8sf_t rho2v; rho2v.v = _mm256_setzero_ps();
+		v8sf_t rho3v; rho3v.v = _mm256_setzero_ps();
+
+		v8sf_t x0v, x1v, x2v, x3v;
+		v8sf_t a0v, a1v, a2v, a3v;
+
+		for ( dim_t i = 0; i < m_viter; ++i )
+		{
+			// Load the input values.
+			a0v.v = _mm256_loadu_ps( a0 + 0*inca );
+			a1v.v = _mm256_loadu_ps( a0 + 1*inca );
+			a2v.v = _mm256_loadu_ps( a0 + 2*inca );
+			a3v.v = _mm256_loadu_ps( a0 + 3*inca );
+
+			x0v.v = _mm256_broadcast_ss( x0 + 0*incx );
+			x1v.v = _mm256_broadcast_ss( x0 + 1*incx );
+			x2v.v = _mm256_broadcast_ss( x0 + 2*incx );
+			x3v.v = _mm256_broadcast_ss( x0 + 3*incx );
+
+			// perform: rho?v += a?v * x?v;
+			rho0v.v = _mm256_fmadd_ps( a0v.v, x0v.v, rho0v.v );
+			rho1v.v = _mm256_fmadd_ps( a1v.v, x1v.v, rho1v.v );
+			rho2v.v = _mm256_fmadd_ps( a2v.v, x2v.v, rho2v.v );
+			rho3v.v = _mm256_fmadd_ps( a3v.v, x3v.v, rho3v.v );
+
+			x0 += incx * n_iter_unroll;
+			a0 += inca * n_iter_unroll;
+		}
+
+		// Combine the four vector accumulators into one vector register.
+		rho0v.v = _mm256_add_ps( rho0v.v, rho1v.v );
+		rho2v.v = _mm256_add_ps( rho2v.v, rho3v.v );
+		rho0v.v = _mm256_add_ps( rho0v.v, rho2v.v );
+
+		// Write vector components to scalar values.
+		rho0 = rho0v.f[0];
+		rho1 = rho0v.f[1];
+		rho2 = rho0v.f[2];
+		rho3 = rho0v.f[3];
+		rho4 = rho0v.f[4];
+		rho5 = rho0v.f[5];
+		rho6 = rho0v.f[6];
+		rho7 = rho0v.f[7];
+
+		// Adjust for scalar subproblem.
+		m -= n_iter_unroll * m_viter;
+		a += n_iter_unroll * m_viter * inca;
+		x += n_iter_unroll * m_viter * incx;
+	}
+	else
+	{
+		// No vectorization possible; use scalar iterations for the entire
+		// problem.
+	}
+
+	// Scalar edge case.
+	{
+		// Initialize pointers for x and the b_n columns of A (rows of A^T).
+		float* restrict x0 = x;
+		float* restrict a0 = a + 0*lda;
+		float* restrict a1 = a + 1*lda;
+		float* restrict a2 = a + 2*lda;
+		float* restrict a3 = a + 3*lda;
+		float* restrict a4 = a + 4*lda;
+		float* restrict a5 = a + 5*lda;
+		float* restrict a6 = a + 6*lda;
+		float* restrict a7 = a + 7*lda;
+
+		// If there are leftover iterations, perform them with scalar code.
+		for ( dim_t i = 0; i < m; ++i )
+		{
+			const float x0c = *x0;
+
+			const float a0c = *a0;
+			const float a1c = *a1;
+			const float a2c = *a2;
+			const float a3c = *a3;
+			const float a4c = *a4;
+			const float a5c = *a5;
+			const float a6c = *a6;
+			const float a7c = *a7;
+
+			rho0 += a0c * x0c;
+			rho1 += a1c * x0c;
+			rho2 += a2c * x0c;
+			rho3 += a3c * x0c;
+			rho4 += a4c * x0c;
+			rho5 += a5c * x0c;
+			rho6 += a6c * x0c;
+			rho7 += a7c * x0c;
+
+			x0 += incx;
+			a0 += inca;
+			a1 += inca;
+			a2 += inca;
+			a3 += inca;
+			a4 += inca;
+			a5 += inca;
+			a6 += inca;
+			a7 += inca;
+		}
+	}
+
+	// Now prepare the final rho values to output/accumulate back into
+	// the y vector.
+
+	v8sf_t rho0v, y0v;
 
 	// Insert the scalar rho values into a single vector.
 	rho0v.f[0] = rho0;
@@ -320,7 +392,7 @@ void bli_sdotxf_zen_int_8
 	rho0v.f[7] = rho7;
 
 	// Broadcast the alpha scalar.
-	alphav.v = _mm256_broadcast_ss( alpha );
+	v8sf_t alphav; alphav.v = _mm256_broadcast_ss( alpha );
 
 	// We know at this point that alpha is nonzero; however, beta may still
 	// be zero. If beta is indeed zero, we must overwrite y rather than scale
@@ -334,7 +406,7 @@ void bli_sdotxf_zen_int_8
 	else
 	{
 		// Broadcast the beta scalar.
-		betav.v = _mm256_broadcast_ss( beta );
+		v8sf_t betav; betav.v = _mm256_broadcast_ss( beta );
 
 		// Load y.
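The code that follows applies y := beta*y + alpha*rho with the overwrite-on-zero-beta convention noted in the comment above. A minimal standalone illustration of that convention for eight contiguous floats — update_y8 is a hypothetical helper written for this annotation, not the kernel's own code:

	#include <immintrin.h>

	// Illustration: y := beta*y + alpha*rho for 8 contiguous floats. When
	// beta == 0, y is overwritten instead of scaled, so a NaN or Inf already
	// present in y cannot propagate through 0*y.
	static void update_y8( float alpha, float beta, const float* rho, float* y )
	{
		__m256 alphav = _mm256_broadcast_ss( &alpha );
		__m256 rhov   = _mm256_loadu_ps( rho );
		__m256 y0v;

		if ( beta == 0.0f )
			y0v = _mm256_mul_ps( alphav, rhov );                 // y = alpha*rho
		else
		{
			__m256 betav = _mm256_broadcast_ss( &beta );
			y0v = _mm256_mul_ps( betav, _mm256_loadu_ps( y ) );  // beta*y
			y0v = _mm256_fmadd_ps( alphav, rhov, y0v );          // + alpha*rho
		}
		_mm256_storeu_ps( y, y0v );
	}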
 		if ( incy == 1 )
@@ -386,39 +458,8 @@ void bli_ddotxf_zen_int_8
        cntx_t* restrict cntx
      )
 {
 	const dim_t fuse_fac       = 8;
 	const dim_t n_elem_per_reg = 4;
-	const dim_t n_iter_unroll  = 1;
-
-	dim_t i;
-	dim_t m_viter;
-	dim_t m_left;
-
-	double* restrict a0;
-	double* restrict a1;
-	double* restrict a2;
-	double* restrict a3;
-	double* restrict a4;
-	double* restrict a5;
-	double* restrict a6;
-	double* restrict a7;
-
-	double* restrict x0;
-
-	double rho0, rho1, rho2, rho3;
-	double rho4, rho5, rho6, rho7;
-
-	v4df_t a0v, a1v, a2v, a3v;
-	v4df_t a4v, a5v, a6v, a7v;
-
-	v4df_t x0v;
-
-	v4df_t rho0v, rho1v, rho2v, rho3v;
-	v4df_t rho4v, rho5v, rho6v, rho7v;
-
-	v4df_t alphav;
-	v4df_t betav;
-	v4df_t y0v, y1v;
 
 	// If the b_n dimension is zero, y is empty and there is no computation.
 	if ( bli_zero_dim1( b_n ) ) return;
@@ -446,7 +487,7 @@ void bli_ddotxf_zen_int_8
 	{
 		ddotxv_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DOUBLE, BLIS_DOTXV_KER, cntx );
 
-		for ( i = 0; i < b_n; ++i )
+		for ( dim_t i = 0; i < b_n; ++i )
 		{
 			double* a1 = a + (0 )*inca + (i )*lda;
 			double* x1 = x + (0 )*incx;
@@ -471,144 +512,256 @@ void bli_ddotxf_zen_int_8
 	// At this point, we know that b_n is exactly equal to the fusing factor.
 	// However, m may not be a multiple of the number of elements per vector.
 
-	// Use the unrolling factor and the number of elements per register
-	// to compute the number of vectorized and leftover iterations.
-	m_viter = ( m ) / ( n_elem_per_reg * n_iter_unroll );
-	m_left  = ( m ) % ( n_elem_per_reg * n_iter_unroll );
+	// Going forward, we handle two possible storage formats of A explicitly:
+	// (1) A is stored by columns, or (2) A is stored by rows. Either case is
+	// further split into two subproblems along the m dimension:
+	// (a) a vectorized part, starting at m = 0 and ending at any 0 <= m' <= m.
+	// (b) a scalar part, starting at m' and ending at m. If no vectorization
+	//     is possible then m' == 0 and thus the scalar part is the entire
+	//     problem. If 0 < m', then the a and x pointers and m variable will
+	//     be adjusted accordingly for the second subproblem.
+	// Note: since parts (b) for both (1) and (2) are so similar, they are
+	// factored out into one code block after the following conditional, which
+	// distinguishes between (1) and (2).
 
-	// If there is anything that would interfere with our use of contiguous
-	// vector loads/stores, override m_viter and m_left to use scalar code
-	// for all iterations. (NOTE: vector instructions can be used even if
-	// the increment of the output vector, incy, is nonunit.)
-	if ( inca != 1 || incx != 1 )
+	// Intermediate variables to hold the completed dot products
+	double rho0 = 0, rho1 = 0, rho2 = 0, rho3 = 0,
+	       rho4 = 0, rho5 = 0, rho6 = 0, rho7 = 0;
+
+	if ( inca == 1 && incx == 1 )
 	{
-		m_viter = 0;
-		m_left  = m;
-	}
+		const dim_t n_iter_unroll = 1;
 
-	// Set up pointers for x and the b_n columns of A (rows of A^T).
-	x0 = x;
-	a0 = a + 0*lda;
-	a1 = a + 1*lda;
-	a2 = a + 2*lda;
-	a3 = a + 3*lda;
-	a4 = a + 4*lda;
-	a5 = a + 5*lda;
-	a6 = a + 6*lda;
-	a7 = a + 7*lda;
+		// Use the unrolling factor and the number of elements per register
+		// to compute the number of vectorized and leftover iterations.
+		dim_t m_viter = ( m ) / ( n_elem_per_reg * n_iter_unroll );
 
-	// Initialize b_n rho vector accumulators to zero.
-	rho0v.v = _mm256_setzero_pd();
-	rho1v.v = _mm256_setzero_pd();
-	rho2v.v = _mm256_setzero_pd();
-	rho3v.v = _mm256_setzero_pd();
-	rho4v.v = _mm256_setzero_pd();
-	rho5v.v = _mm256_setzero_pd();
-	rho6v.v = _mm256_setzero_pd();
-	rho7v.v = _mm256_setzero_pd();
+		// Set up pointers for x and the b_n columns of A (rows of A^T).
+		double* restrict x0 = x;
+		double* restrict a0 = a + 0*lda;
+		double* restrict a1 = a + 1*lda;
+		double* restrict a2 = a + 2*lda;
+		double* restrict a3 = a + 3*lda;
+		double* restrict a4 = a + 4*lda;
+		double* restrict a5 = a + 5*lda;
+		double* restrict a6 = a + 6*lda;
+		double* restrict a7 = a + 7*lda;
 
-	// If there are vectorized iterations, perform them with vector
-	// instructions.
-	for ( i = 0; i < m_viter; ++i )
-	{
-		// Load the input values.
-		x0v.v = _mm256_loadu_pd( x0 + 0*n_elem_per_reg );
+		// Initialize b_n rho vector accumulators to zero.
+		v4df_t rho0v; rho0v.v = _mm256_setzero_pd();
+		v4df_t rho1v; rho1v.v = _mm256_setzero_pd();
+		v4df_t rho2v; rho2v.v = _mm256_setzero_pd();
+		v4df_t rho3v; rho3v.v = _mm256_setzero_pd();
+		v4df_t rho4v; rho4v.v = _mm256_setzero_pd();
+		v4df_t rho5v; rho5v.v = _mm256_setzero_pd();
+		v4df_t rho6v; rho6v.v = _mm256_setzero_pd();
+		v4df_t rho7v; rho7v.v = _mm256_setzero_pd();
 
-		a0v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg );
-		a1v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg );
-		a2v.v = _mm256_loadu_pd( a2 + 0*n_elem_per_reg );
-		a3v.v = _mm256_loadu_pd( a3 + 0*n_elem_per_reg );
-		a4v.v = _mm256_loadu_pd( a4 + 0*n_elem_per_reg );
-		a5v.v = _mm256_loadu_pd( a5 + 0*n_elem_per_reg );
-		a6v.v = _mm256_loadu_pd( a6 + 0*n_elem_per_reg );
-		a7v.v = _mm256_loadu_pd( a7 + 0*n_elem_per_reg );
+		v4df_t x0v;
+		v4df_t a0v, a1v, a2v, a3v, a4v, a5v, a6v, a7v;
 
-		// perform : rho?v += a?v * x0v;
-		rho0v.v = _mm256_fmadd_pd( a0v.v, x0v.v, rho0v.v );
-		rho1v.v = _mm256_fmadd_pd( a1v.v, x0v.v, rho1v.v );
-		rho2v.v = _mm256_fmadd_pd( a2v.v, x0v.v, rho2v.v );
-		rho3v.v = _mm256_fmadd_pd( a3v.v, x0v.v, rho3v.v );
-		rho4v.v = _mm256_fmadd_pd( a4v.v, x0v.v, rho4v.v );
-		rho5v.v = _mm256_fmadd_pd( a5v.v, x0v.v, rho5v.v );
-		rho6v.v = _mm256_fmadd_pd( a6v.v, x0v.v, rho6v.v );
-		rho7v.v = _mm256_fmadd_pd( a7v.v, x0v.v, rho7v.v );
+		// If there are vectorized iterations, perform them with vector
+		// instructions.
+		for ( dim_t i = 0; i < m_viter; ++i )
+		{
+			// Load the input values.
+ x0v.v = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); - x0 += n_elem_per_reg * n_iter_unroll; - a0 += n_elem_per_reg * n_iter_unroll; - a1 += n_elem_per_reg * n_iter_unroll; - a2 += n_elem_per_reg * n_iter_unroll; - a3 += n_elem_per_reg * n_iter_unroll; - a4 += n_elem_per_reg * n_iter_unroll; - a5 += n_elem_per_reg * n_iter_unroll; - a6 += n_elem_per_reg * n_iter_unroll; - a7 += n_elem_per_reg * n_iter_unroll; - } + a0v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg ); + a1v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg ); + a2v.v = _mm256_loadu_pd( a2 + 0*n_elem_per_reg ); + a3v.v = _mm256_loadu_pd( a3 + 0*n_elem_per_reg ); + a4v.v = _mm256_loadu_pd( a4 + 0*n_elem_per_reg ); + a5v.v = _mm256_loadu_pd( a5 + 0*n_elem_per_reg ); + a6v.v = _mm256_loadu_pd( a6 + 0*n_elem_per_reg ); + a7v.v = _mm256_loadu_pd( a7 + 0*n_elem_per_reg ); + + // perform: rho?v += a?v * x0v; + rho0v.v = _mm256_fmadd_pd( a0v.v, x0v.v, rho0v.v ); + rho1v.v = _mm256_fmadd_pd( a1v.v, x0v.v, rho1v.v ); + rho2v.v = _mm256_fmadd_pd( a2v.v, x0v.v, rho2v.v ); + rho3v.v = _mm256_fmadd_pd( a3v.v, x0v.v, rho3v.v ); + rho4v.v = _mm256_fmadd_pd( a4v.v, x0v.v, rho4v.v ); + rho5v.v = _mm256_fmadd_pd( a5v.v, x0v.v, rho5v.v ); + rho6v.v = _mm256_fmadd_pd( a6v.v, x0v.v, rho6v.v ); + rho7v.v = _mm256_fmadd_pd( a7v.v, x0v.v, rho7v.v ); + + x0 += n_elem_per_reg * n_iter_unroll; + a0 += n_elem_per_reg * n_iter_unroll; + a1 += n_elem_per_reg * n_iter_unroll; + a2 += n_elem_per_reg * n_iter_unroll; + a3 += n_elem_per_reg * n_iter_unroll; + a4 += n_elem_per_reg * n_iter_unroll; + a5 += n_elem_per_reg * n_iter_unroll; + a6 += n_elem_per_reg * n_iter_unroll; + a7 += n_elem_per_reg * n_iter_unroll; + } #if 0 - rho0 += rho0v.d[0] + rho0v.d[1] + rho0v.d[2] + rho0v.d[3]; - rho1 += rho1v.d[0] + rho1v.d[1] + rho1v.d[2] + rho1v.d[3]; - rho2 += rho2v.d[0] + rho2v.d[1] + rho2v.d[2] + rho2v.d[3]; - rho3 += rho3v.d[0] + rho3v.d[1] + rho3v.d[2] + rho3v.d[3]; - rho4 += rho4v.d[0] + rho4v.d[1] + rho4v.d[2] + rho4v.d[3]; - rho5 += rho5v.d[0] + rho5v.d[1] + rho5v.d[2] + rho5v.d[3]; - rho6 += rho6v.d[0] + rho6v.d[1] + rho6v.d[2] + rho6v.d[3]; - rho7 += rho7v.d[0] + rho7v.d[1] + rho7v.d[2] + rho7v.d[3]; + rho0 += rho0v.d[0] + rho0v.d[1] + rho0v.d[2] + rho0v.d[3]; + rho1 += rho1v.d[0] + rho1v.d[1] + rho1v.d[2] + rho1v.d[3]; + rho2 += rho2v.d[0] + rho2v.d[1] + rho2v.d[2] + rho2v.d[3]; + rho3 += rho3v.d[0] + rho3v.d[1] + rho3v.d[2] + rho3v.d[3]; + rho4 += rho4v.d[0] + rho4v.d[1] + rho4v.d[2] + rho4v.d[3]; + rho5 += rho5v.d[0] + rho5v.d[1] + rho5v.d[2] + rho5v.d[3]; + rho6 += rho6v.d[0] + rho6v.d[1] + rho6v.d[2] + rho6v.d[3]; + rho7 += rho7v.d[0] + rho7v.d[1] + rho7v.d[2] + rho7v.d[3]; #else - // Sum the elements of a given rho?v. This computes the sum of - // elements within lanes and stores the sum to both elements. - rho0v.v = _mm256_hadd_pd( rho0v.v, rho0v.v ); - rho1v.v = _mm256_hadd_pd( rho1v.v, rho1v.v ); - rho2v.v = _mm256_hadd_pd( rho2v.v, rho2v.v ); - rho3v.v = _mm256_hadd_pd( rho3v.v, rho3v.v ); - rho4v.v = _mm256_hadd_pd( rho4v.v, rho4v.v ); - rho5v.v = _mm256_hadd_pd( rho5v.v, rho5v.v ); - rho6v.v = _mm256_hadd_pd( rho6v.v, rho6v.v ); - rho7v.v = _mm256_hadd_pd( rho7v.v, rho7v.v ); + // Sum the elements of a given rho?v. This computes the sum of + // elements within lanes and stores the sum to both elements. 
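To make the lane behavior concrete: for t = _mm256_hadd_pd( r, r ), the result is t = { r[0]+r[1], r[0]+r[1], r[2]+r[3], r[2]+r[3] }, so the full sum is t[0] + t[2], exactly as computed just below. A self-contained sketch of the same idiom for one register — hsum4_pd is an illustrative name, not part of the patch:

	#include <immintrin.h>

	// Horizontal sum of one __m256d using the VHADDPD idiom.
	static inline double hsum4_pd( __m256d r )
	{
		__m256d t = _mm256_hadd_pd( r, r ); // { r0+r1, r0+r1, r2+r3, r2+r3 }
		return _mm_cvtsd_f64( _mm256_castpd256_pd128( t ) )
		     + _mm_cvtsd_f64( _mm256_extractf128_pd( t, 1 ) );
	}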
+		rho0v.v = _mm256_hadd_pd( rho0v.v, rho0v.v );
+		rho1v.v = _mm256_hadd_pd( rho1v.v, rho1v.v );
+		rho2v.v = _mm256_hadd_pd( rho2v.v, rho2v.v );
+		rho3v.v = _mm256_hadd_pd( rho3v.v, rho3v.v );
+		rho4v.v = _mm256_hadd_pd( rho4v.v, rho4v.v );
+		rho5v.v = _mm256_hadd_pd( rho5v.v, rho5v.v );
+		rho6v.v = _mm256_hadd_pd( rho6v.v, rho6v.v );
+		rho7v.v = _mm256_hadd_pd( rho7v.v, rho7v.v );
 
-	// Manually add the results from above to finish the sum.
-	rho0 = rho0v.d[0] + rho0v.d[2];
-	rho1 = rho1v.d[0] + rho1v.d[2];
-	rho2 = rho2v.d[0] + rho2v.d[2];
-	rho3 = rho3v.d[0] + rho3v.d[2];
-	rho4 = rho4v.d[0] + rho4v.d[2];
-	rho5 = rho5v.d[0] + rho5v.d[2];
-	rho6 = rho6v.d[0] + rho6v.d[2];
-	rho7 = rho7v.d[0] + rho7v.d[2];
+		// Manually add the results from above to finish the sum.
+		rho0 = rho0v.d[0] + rho0v.d[2];
+		rho1 = rho1v.d[0] + rho1v.d[2];
+		rho2 = rho2v.d[0] + rho2v.d[2];
+		rho3 = rho3v.d[0] + rho3v.d[2];
+		rho4 = rho4v.d[0] + rho4v.d[2];
+		rho5 = rho5v.d[0] + rho5v.d[2];
+		rho6 = rho6v.d[0] + rho6v.d[2];
+		rho7 = rho7v.d[0] + rho7v.d[2];
 #endif
-
-	// If there are leftover iterations, perform them with scalar code.
-	for ( i = 0; i < m_left ; ++i )
-	{
-		const double x0c = *x0;
-
-		const double a0c = *a0;
-		const double a1c = *a1;
-		const double a2c = *a2;
-		const double a3c = *a3;
-		const double a4c = *a4;
-		const double a5c = *a5;
-		const double a6c = *a6;
-		const double a7c = *a7;
-
-		rho0 += a0c * x0c;
-		rho1 += a1c * x0c;
-		rho2 += a2c * x0c;
-		rho3 += a3c * x0c;
-		rho4 += a4c * x0c;
-		rho5 += a5c * x0c;
-		rho6 += a6c * x0c;
-		rho7 += a7c * x0c;
-
-		x0 += incx;
-		a0 += inca;
-		a1 += inca;
-		a2 += inca;
-		a3 += inca;
-		a4 += inca;
-		a5 += inca;
-		a6 += inca;
-		a7 += inca;
+		// Adjust for scalar subproblem.
+		m -= n_elem_per_reg * n_iter_unroll * m_viter;
+		a += n_elem_per_reg * n_iter_unroll * m_viter /* * inca */;
+		x += n_elem_per_reg * n_iter_unroll * m_viter /* * incx */;
 	}
+	else if ( lda == 1 )
+	{
+		const dim_t n_iter_unroll = 3;
+		const dim_t n_reg_per_row = 2; // fuse_fac / n_elem_per_reg;
+
+		// Use the unrolling factor and the number of elements per register
+		// to compute the number of vectorized and leftover iterations.
+		dim_t m_viter = ( m ) / ( n_iter_unroll );
+
+		// Initialize pointers for x and A.
+		double* restrict x0 = x;
+		double* restrict a0 = a;
+
+		// Initialize rho vector accumulators to zero.
+		v4df_t rho0v; rho0v.v = _mm256_setzero_pd();
+		v4df_t rho1v; rho1v.v = _mm256_setzero_pd();
+		v4df_t rho2v; rho2v.v = _mm256_setzero_pd();
+		v4df_t rho3v; rho3v.v = _mm256_setzero_pd();
+		v4df_t rho4v; rho4v.v = _mm256_setzero_pd();
+		v4df_t rho5v; rho5v.v = _mm256_setzero_pd();
+
+		v4df_t x0v, x1v, x2v;
+		v4df_t a0v, a1v, a2v, a3v, a4v, a5v;
+
+		for ( dim_t i = 0; i < m_viter; ++i )
+		{
+			// Load the input values.
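In this row-stored double-precision loop, each unrolled iteration covers three rows of eight doubles, so every row spans two 4-wide registers (n_reg_per_row). The layout of the loads that follow, spelled out as an annotation using the loop's own names:

	// One vectorized iteration touches rows r = 0, 1, 2:
	//   a(2r)v   <- 4 doubles at a0 + r*inca + 0   (columns 0..3 of row r)
	//   a(2r+1)v <- 4 doubles at a0 + r*inca + 4   (columns 4..7 of row r)
	// each multiplied by the broadcast scalar x0[ r*incx ].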
+			a0v.v = _mm256_loadu_pd( a0 + 0*inca + 0*n_elem_per_reg );
+			a1v.v = _mm256_loadu_pd( a0 + 0*inca + 1*n_elem_per_reg );
+			a2v.v = _mm256_loadu_pd( a0 + 1*inca + 0*n_elem_per_reg );
+			a3v.v = _mm256_loadu_pd( a0 + 1*inca + 1*n_elem_per_reg );
+			a4v.v = _mm256_loadu_pd( a0 + 2*inca + 0*n_elem_per_reg );
+			a5v.v = _mm256_loadu_pd( a0 + 2*inca + 1*n_elem_per_reg );
+
+			x0v.v = _mm256_broadcast_sd( x0 + 0*incx );
+			x1v.v = _mm256_broadcast_sd( x0 + 1*incx );
+			x2v.v = _mm256_broadcast_sd( x0 + 2*incx );
+
+			// perform: rho?v += a?v * x?v;
+			rho0v.v = _mm256_fmadd_pd( a0v.v, x0v.v, rho0v.v );
+			rho1v.v = _mm256_fmadd_pd( a1v.v, x0v.v, rho1v.v );
+			rho2v.v = _mm256_fmadd_pd( a2v.v, x1v.v, rho2v.v );
+			rho3v.v = _mm256_fmadd_pd( a3v.v, x1v.v, rho3v.v );
+			rho4v.v = _mm256_fmadd_pd( a4v.v, x2v.v, rho4v.v );
+			rho5v.v = _mm256_fmadd_pd( a5v.v, x2v.v, rho5v.v );
+
+			x0 += incx * n_iter_unroll;
+			a0 += inca * n_iter_unroll;
+		}
+
+		// Combine the six vector accumulators into two vector registers.
+		rho0v.v = _mm256_add_pd( rho0v.v, rho2v.v );
+		rho0v.v = _mm256_add_pd( rho0v.v, rho4v.v );
+		rho1v.v = _mm256_add_pd( rho1v.v, rho3v.v );
+		rho1v.v = _mm256_add_pd( rho1v.v, rho5v.v );
+
+		// Write vector components to scalar values.
+		rho0 = rho0v.d[0];
+		rho1 = rho0v.d[1];
+		rho2 = rho0v.d[2];
+		rho3 = rho0v.d[3];
+		rho4 = rho1v.d[0];
+		rho5 = rho1v.d[1];
+		rho6 = rho1v.d[2];
+		rho7 = rho1v.d[3];
+
+		// Adjust for scalar subproblem.
+		m -= n_iter_unroll * m_viter;
+		a += n_iter_unroll * m_viter * inca;
+		x += n_iter_unroll * m_viter * incx;
+	}
+	else
+	{
+		// No vectorization possible; use scalar iterations for the entire
+		// problem.
+	}
+
+	// Scalar edge case.
+	{
+		// Initialize pointers for x and the b_n columns of A (rows of A^T).
+		double* restrict x0 = x;
+		double* restrict a0 = a + 0*lda;
+		double* restrict a1 = a + 1*lda;
+		double* restrict a2 = a + 2*lda;
+		double* restrict a3 = a + 3*lda;
+		double* restrict a4 = a + 4*lda;
+		double* restrict a5 = a + 5*lda;
+		double* restrict a6 = a + 6*lda;
+		double* restrict a7 = a + 7*lda;
+
+		// If there are leftover iterations, perform them with scalar code.
+		for ( dim_t i = 0; i < m; ++i )
+		{
+			const double x0c = *x0;
+
+			const double a0c = *a0;
+			const double a1c = *a1;
+			const double a2c = *a2;
+			const double a3c = *a3;
+			const double a4c = *a4;
+			const double a5c = *a5;
+			const double a6c = *a6;
+			const double a7c = *a7;
+
+			rho0 += a0c * x0c;
+			rho1 += a1c * x0c;
+			rho2 += a2c * x0c;
+			rho3 += a3c * x0c;
+			rho4 += a4c * x0c;
+			rho5 += a5c * x0c;
+			rho6 += a6c * x0c;
+			rho7 += a7c * x0c;
+
+			x0 += incx;
+			a0 += inca;
+			a1 += inca;
+			a2 += inca;
+			a3 += inca;
+			a4 += inca;
+			a5 += inca;
+			a6 += inca;
+			a7 += inca;
+		}
+	}
+
+	// Now prepare the final rho values to output/accumulate back into
+	// the y vector.
+
+	v4df_t rho0v, rho1v, y0v, y1v;
 
 	// Insert the scalar rho values into a single vector.
 	rho0v.d[0] = rho0;
@@ -621,7 +774,7 @@ void bli_ddotxf_zen_int_8
 	rho1v.d[3] = rho7;
 
 	// Broadcast the alpha scalar.
-	alphav.v = _mm256_broadcast_sd( alpha );
+	v4df_t alphav; alphav.v = _mm256_broadcast_sd( alpha );
 
 	// We know at this point that alpha is nonzero; however, beta may still
 	// be zero. If beta is indeed zero, we must overwrite y rather than scale
@@ -636,7 +789,7 @@ void bli_ddotxf_zen_int_8
 	else
 	{
 		// Broadcast the beta scalar.
-		betav.v = _mm256_broadcast_sd( beta );
+		v4df_t betav; betav.v = _mm256_broadcast_sd( beta );
 
 		// Load y.
 		if ( incy == 1 )
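For checking either kernel in this diff, the fused operation they implement can be written down directly from the comments above: y[j] := beta*y[j] + alpha * sum over i of A(i,j)*x(i), for the b_n fused columns. Below is a plain-C reference sketch under the same conventions (beta == 0 overwrites y); ddotxf_ref is an illustrative name, and the types dim_t/inc_t are assumed from blis.h.

	#include "blis.h"

	// Reference semantics (illustrative sketch, not the BLIS API):
	//   y[j] := beta*y[j] + alpha * sum_{i<m} A(i,j) * x(i),  j = 0 .. b_n-1
	static void ddotxf_ref( dim_t m, dim_t b_n, double alpha,
	                        const double* a, inc_t inca, inc_t lda,
	                        const double* x, inc_t incx,
	                        double beta, double* y, inc_t incy )
	{
		for ( dim_t j = 0; j < b_n; ++j )
		{
			double rho = 0.0;
			for ( dim_t i = 0; i < m; ++i )
				rho += a[ i*inca + j*lda ] * x[ i*incx ];

			// beta == 0 overwrites y, matching the kernels' NaN/Inf policy.
			if ( beta == 0.0 ) y[ j*incy ] = alpha * rho;
			else               y[ j*incy ] = beta * y[ j*incy ] + alpha * rho;
		}
	}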