mirror of
https://github.com/amd/blis.git
synced 2026-05-11 09:39:59 +00:00
Merge branch 'dev'
This commit is contained in:
@@ -97,8 +97,8 @@ void bli_sgemm_bulldozer_asm_8x8_fma4
|
||||
cntx_t* restrict cntx
|
||||
)
|
||||
{
|
||||
uint64_t k_iter = k / 4;
|
||||
uint64_t k_left = k % 4;
|
||||
uint64_t k_iter = k / 4;
|
||||
uint64_t k_left = k % 4;
|
||||
|
||||
__asm__ volatile
|
||||
(
|
||||
@@ -865,8 +865,8 @@ void bli_dgemm_bulldozer_asm_4x6_fma4
|
||||
cntx_t* restrict cntx
|
||||
)
|
||||
{
|
||||
dim_t k_iter = k / 12;
|
||||
dim_t k_left = k % 12;
|
||||
uint64_t k_iter = k / 12;
|
||||
uint64_t k_left = k % 12;
|
||||
|
||||
__asm__
|
||||
(
|
||||
@@ -1076,8 +1076,8 @@ void bli_cgemm_bulldozer_asm_8x4_fma4
|
||||
//void* a_next = bli_auxinfo_next_a( data );
|
||||
void* b_next = bli_auxinfo_next_b( data );
|
||||
|
||||
dim_t k_iter = k / 4;
|
||||
dim_t k_left = k % 4;
|
||||
uint64_t k_iter = k / 4;
|
||||
uint64_t k_left = k % 4;
|
||||
|
||||
__asm__ volatile
|
||||
(
|
||||
@@ -1883,8 +1883,8 @@ void bli_zgemm_bulldozer_asm_4x4_fma4
|
||||
//void* a_next = bli_auxinfo_next_a( data );
|
||||
//void* b_next = bli_auxinfo_next_b( data );
|
||||
|
||||
dim_t k_iter = k / 4;
|
||||
dim_t k_left = k % 4;
|
||||
uint64_t k_iter = k / 4;
|
||||
uint64_t k_left = k % 4;
|
||||
|
||||
__asm__ volatile
|
||||
(
|
||||
|
||||
@@ -265,22 +265,22 @@ ahead*/
|
||||
\
|
||||
PREFETCH_A_L1(n, 0) \
|
||||
\
|
||||
VBROADCASTSS(ZMM(3), MEM(RBX,(12*n+ 0)*8)) \
|
||||
VBROADCASTSS(ZMM(4), MEM(RBX,(12*n+ 1)*8)) \
|
||||
VBROADCASTSS(ZMM(3), MEM(RBX,(12*n+ 0)*4)) \
|
||||
VBROADCASTSS(ZMM(4), MEM(RBX,(12*n+ 1)*4)) \
|
||||
VFMADD231PS(ZMM( 8), ZMM(0), ZMM(3)) \
|
||||
VFMADD231PS(ZMM( 9), ZMM(1), ZMM(3)) \
|
||||
VFMADD231PS(ZMM(10), ZMM(0), ZMM(4)) \
|
||||
VFMADD231PS(ZMM(11), ZMM(1), ZMM(4)) \
|
||||
\
|
||||
VBROADCASTSS(ZMM(3), MEM(RBX,(12*n+ 2)*8)) \
|
||||
VBROADCASTSS(ZMM(4), MEM(RBX,(12*n+ 3)*8)) \
|
||||
VBROADCASTSS(ZMM(3), MEM(RBX,(12*n+ 2)*4)) \
|
||||
VBROADCASTSS(ZMM(4), MEM(RBX,(12*n+ 3)*4)) \
|
||||
VFMADD231PS(ZMM(12), ZMM(0), ZMM(3)) \
|
||||
VFMADD231PS(ZMM(13), ZMM(1), ZMM(3)) \
|
||||
VFMADD231PS(ZMM(14), ZMM(0), ZMM(4)) \
|
||||
VFMADD231PS(ZMM(15), ZMM(1), ZMM(4)) \
|
||||
\
|
||||
VBROADCASTSS(ZMM(3), MEM(RBX,(12*n+ 4)*8)) \
|
||||
VBROADCASTSS(ZMM(4), MEM(RBX,(12*n+ 5)*8)) \
|
||||
VBROADCASTSS(ZMM(3), MEM(RBX,(12*n+ 4)*4)) \
|
||||
VBROADCASTSS(ZMM(4), MEM(RBX,(12*n+ 5)*4)) \
|
||||
VFMADD231PS(ZMM(16), ZMM(0), ZMM(3)) \
|
||||
VFMADD231PS(ZMM(17), ZMM(1), ZMM(3)) \
|
||||
VFMADD231PS(ZMM(18), ZMM(0), ZMM(4)) \
|
||||
@@ -288,29 +288,29 @@ ahead*/
|
||||
\
|
||||
PREFETCH_A_L1(n, 1) \
|
||||
\
|
||||
VBROADCASTSS(ZMM(3), MEM(RBX,(12*n+ 6)*8)) \
|
||||
VBROADCASTSS(ZMM(4), MEM(RBX,(12*n+ 7)*8)) \
|
||||
VBROADCASTSS(ZMM(3), MEM(RBX,(12*n+ 6)*4)) \
|
||||
VBROADCASTSS(ZMM(4), MEM(RBX,(12*n+ 7)*4)) \
|
||||
VFMADD231PS(ZMM(20), ZMM(0), ZMM(3)) \
|
||||
VFMADD231PS(ZMM(21), ZMM(1), ZMM(3)) \
|
||||
VFMADD231PS(ZMM(22), ZMM(0), ZMM(4)) \
|
||||
VFMADD231PS(ZMM(23), ZMM(1), ZMM(4)) \
|
||||
\
|
||||
VBROADCASTSS(ZMM(3), MEM(RBX,(12*n+ 8)*8)) \
|
||||
VBROADCASTSS(ZMM(4), MEM(RBX,(12*n+ 9)*8)) \
|
||||
VBROADCASTSS(ZMM(3), MEM(RBX,(12*n+ 8)*4)) \
|
||||
VBROADCASTSS(ZMM(4), MEM(RBX,(12*n+ 9)*4)) \
|
||||
VFMADD231PS(ZMM(24), ZMM(0), ZMM(3)) \
|
||||
VFMADD231PS(ZMM(25), ZMM(1), ZMM(3)) \
|
||||
VFMADD231PS(ZMM(26), ZMM(0), ZMM(4)) \
|
||||
VFMADD231PS(ZMM(27), ZMM(1), ZMM(4)) \
|
||||
\
|
||||
VBROADCASTSS(ZMM(3), MEM(RBX,(12*n+10)*8)) \
|
||||
VBROADCASTSS(ZMM(4), MEM(RBX,(12*n+11)*8)) \
|
||||
VBROADCASTSS(ZMM(3), MEM(RBX,(12*n+10)*4)) \
|
||||
VBROADCASTSS(ZMM(4), MEM(RBX,(12*n+11)*4)) \
|
||||
VFMADD231PS(ZMM(28), ZMM(0), ZMM(3)) \
|
||||
VFMADD231PS(ZMM(29), ZMM(1), ZMM(3)) \
|
||||
VFMADD231PS(ZMM(30), ZMM(0), ZMM(4)) \
|
||||
VFMADD231PS(ZMM(31), ZMM(1), ZMM(4)) \
|
||||
\
|
||||
VMOVAPD(ZMM(0), MEM(RAX,(16*n+0)*8)) \
|
||||
VMOVAPD(ZMM(1), MEM(RAX,(16*n+8)*8))
|
||||
VMOVAPD(ZMM(0), MEM(RAX,(32*n+0)*4)) \
|
||||
VMOVAPD(ZMM(1), MEM(RAX,(32*n+16)*4))
|
||||
|
||||
//This is an array used for the scatter/gather instructions.
|
||||
static int64_t offsets[16] __attribute__((aligned(64))) =
|
||||
|
||||
@@ -67,42 +67,8 @@ void bli_sdotxf_zen_int_8
|
||||
cntx_t* restrict cntx
|
||||
)
|
||||
{
|
||||
const dim_t fuse_fac = 8;
|
||||
|
||||
const dim_t n_elem_per_reg = 8;
|
||||
const dim_t n_iter_unroll = 1;
|
||||
|
||||
dim_t i;
|
||||
dim_t m_viter;
|
||||
dim_t m_left;
|
||||
|
||||
float* restrict a0;
|
||||
float* restrict a1;
|
||||
float* restrict a2;
|
||||
float* restrict a3;
|
||||
float* restrict a4;
|
||||
float* restrict a5;
|
||||
float* restrict a6;
|
||||
float* restrict a7;
|
||||
|
||||
float* restrict x0;
|
||||
|
||||
float rho0, rho1, rho2, rho3;
|
||||
float rho4, rho5, rho6, rho7;
|
||||
|
||||
v8sf_t a0v, a1v, a2v, a3v;
|
||||
v8sf_t a4v, a5v, a6v, a7v;
|
||||
|
||||
v8sf_t x0v;
|
||||
|
||||
v8sf_t rho0v, rho1v, rho2v, rho3v;
|
||||
v8sf_t rho4v, rho5v, rho6v, rho7v;
|
||||
|
||||
v8sf_t alphav;
|
||||
v8sf_t betav;
|
||||
v8sf_t y0v;
|
||||
|
||||
v8sf_t onev;
|
||||
const dim_t fuse_fac = 8;
|
||||
const dim_t n_elem_per_reg = 8;
|
||||
|
||||
// If the b_n dimension is zero, y is empty and there is no computation.
|
||||
if ( bli_zero_dim1( b_n ) ) return;
|
||||
@@ -130,7 +96,7 @@ void bli_sdotxf_zen_int_8
|
||||
{
|
||||
sdotxv_ft f = bli_cntx_get_l1v_ker_dt( BLIS_FLOAT, BLIS_DOTXV_KER, cntx );
|
||||
|
||||
for ( i = 0; i < b_n; ++i )
|
||||
for ( dim_t i = 0; i < b_n; ++i )
|
||||
{
|
||||
float* a1 = a + (0 )*inca + (i )*lda;
|
||||
float* x1 = x + (0 )*incx;
|
||||
@@ -155,159 +121,265 @@ void bli_sdotxf_zen_int_8
|
||||
// At this point, we know that b_n is exactly equal to the fusing factor.
|
||||
// However, m may not be a multiple of the number of elements per vector.
|
||||
|
||||
// Use the unrolling factor and the number of elements per register
|
||||
// to compute the number of vectorized and leftover iterations.
|
||||
m_viter = ( m ) / ( n_elem_per_reg * n_iter_unroll );
|
||||
m_left = ( m ) % ( n_elem_per_reg * n_iter_unroll );
|
||||
// Going forward, we handle two possible storage formats of A explicitly:
|
||||
// (1) A is stored by columns, or (2) A is stored by rows. Either case is
|
||||
// further split into two subproblems along the m dimension:
|
||||
// (a) a vectorized part, starting at m = 0 and ending at any 0 <= m' <= m.
|
||||
// (b) a scalar part, starting at m' and ending at m. If no vectorization
|
||||
// is possible then m' == 0 and thus the scalar part is the entire
|
||||
// problem. If 0 < m', then the a and x pointers and m variable will
|
||||
// be adjusted accordingly for the second subproblem.
|
||||
// Note: since parts (b) for both (1) and (2) are so similar, they are
|
||||
// factored out into one code block after the following conditional, which
|
||||
// distinguishes between (1) and (2).
|
||||
|
||||
// If there is anything that would interfere with our use of contiguous
|
||||
// vector loads/stores, override m_viter and m_left to use scalar code
|
||||
// for all iterations. (NOTE: vector instructions can be used even if
|
||||
// the increment of the output vector, incy, is nonunit.)
|
||||
if ( inca != 1 || incx != 1 )
|
||||
// Intermediate variables to hold the completed dot products
|
||||
float rho0 = 0, rho1 = 0, rho2 = 0, rho3 = 0,
|
||||
rho4 = 0, rho5 = 0, rho6 = 0, rho7 = 0;
|
||||
|
||||
if ( inca == 1 && incx == 1 )
|
||||
{
|
||||
m_viter = 0;
|
||||
m_left = m;
|
||||
}
|
||||
const dim_t n_iter_unroll = 1;
|
||||
|
||||
// Set up pointers for x and the b_n columns of A (rows of A^T).
|
||||
x0 = x;
|
||||
a0 = a + 0*lda;
|
||||
a1 = a + 1*lda;
|
||||
a2 = a + 2*lda;
|
||||
a3 = a + 3*lda;
|
||||
a4 = a + 4*lda;
|
||||
a5 = a + 5*lda;
|
||||
a6 = a + 6*lda;
|
||||
a7 = a + 7*lda;
|
||||
// Use the unrolling factor and the number of elements per register
|
||||
// to compute the number of vectorized and leftover iterations.
|
||||
dim_t m_viter = ( m ) / ( n_elem_per_reg * n_iter_unroll );
|
||||
|
||||
// Initialize b_n rho vector accumulators to zero.
|
||||
rho0v.v = _mm256_setzero_ps();
|
||||
rho1v.v = _mm256_setzero_ps();
|
||||
rho2v.v = _mm256_setzero_ps();
|
||||
rho3v.v = _mm256_setzero_ps();
|
||||
rho4v.v = _mm256_setzero_ps();
|
||||
rho5v.v = _mm256_setzero_ps();
|
||||
rho6v.v = _mm256_setzero_ps();
|
||||
rho7v.v = _mm256_setzero_ps();
|
||||
// Set up pointers for x and the b_n columns of A (rows of A^T).
|
||||
float* restrict x0 = x;
|
||||
float* restrict a0 = a + 0*lda;
|
||||
float* restrict a1 = a + 1*lda;
|
||||
float* restrict a2 = a + 2*lda;
|
||||
float* restrict a3 = a + 3*lda;
|
||||
float* restrict a4 = a + 4*lda;
|
||||
float* restrict a5 = a + 5*lda;
|
||||
float* restrict a6 = a + 6*lda;
|
||||
float* restrict a7 = a + 7*lda;
|
||||
|
||||
// If there are vectorized iterations, perform them with vector
|
||||
// instructions.
|
||||
for ( i = 0; i < m_viter; ++i )
|
||||
{
|
||||
// Load the input values.
|
||||
x0v.v = _mm256_loadu_ps( x0 + 0*n_elem_per_reg );
|
||||
// Initialize b_n rho vector accumulators to zero.
|
||||
v8sf_t rho0v; rho0v.v = _mm256_setzero_ps();
|
||||
v8sf_t rho1v; rho1v.v = _mm256_setzero_ps();
|
||||
v8sf_t rho2v; rho2v.v = _mm256_setzero_ps();
|
||||
v8sf_t rho3v; rho3v.v = _mm256_setzero_ps();
|
||||
v8sf_t rho4v; rho4v.v = _mm256_setzero_ps();
|
||||
v8sf_t rho5v; rho5v.v = _mm256_setzero_ps();
|
||||
v8sf_t rho6v; rho6v.v = _mm256_setzero_ps();
|
||||
v8sf_t rho7v; rho7v.v = _mm256_setzero_ps();
|
||||
|
||||
a0v.v = _mm256_loadu_ps( a0 + 0*n_elem_per_reg );
|
||||
a1v.v = _mm256_loadu_ps( a1 + 0*n_elem_per_reg );
|
||||
a2v.v = _mm256_loadu_ps( a2 + 0*n_elem_per_reg );
|
||||
a3v.v = _mm256_loadu_ps( a3 + 0*n_elem_per_reg );
|
||||
a4v.v = _mm256_loadu_ps( a4 + 0*n_elem_per_reg );
|
||||
a5v.v = _mm256_loadu_ps( a5 + 0*n_elem_per_reg );
|
||||
a6v.v = _mm256_loadu_ps( a6 + 0*n_elem_per_reg );
|
||||
a7v.v = _mm256_loadu_ps( a7 + 0*n_elem_per_reg );
|
||||
v8sf_t x0v;
|
||||
v8sf_t a0v, a1v, a2v, a3v, a4v, a5v, a6v, a7v;
|
||||
|
||||
// perform : rho?v += a?v * x0v;
|
||||
rho0v.v = _mm256_fmadd_ps( a0v.v, x0v.v, rho0v.v );
|
||||
rho1v.v = _mm256_fmadd_ps( a1v.v, x0v.v, rho1v.v );
|
||||
rho2v.v = _mm256_fmadd_ps( a2v.v, x0v.v, rho2v.v );
|
||||
rho3v.v = _mm256_fmadd_ps( a3v.v, x0v.v, rho3v.v );
|
||||
rho4v.v = _mm256_fmadd_ps( a4v.v, x0v.v, rho4v.v );
|
||||
rho5v.v = _mm256_fmadd_ps( a5v.v, x0v.v, rho5v.v );
|
||||
rho6v.v = _mm256_fmadd_ps( a6v.v, x0v.v, rho6v.v );
|
||||
rho7v.v = _mm256_fmadd_ps( a7v.v, x0v.v, rho7v.v );
|
||||
// If there are vectorized iterations, perform them with vector
|
||||
// instructions.
|
||||
for ( dim_t i = 0; i < m_viter; ++i )
|
||||
{
|
||||
// Load the input values.
|
||||
x0v.v = _mm256_loadu_ps( x0 + 0*n_elem_per_reg );
|
||||
|
||||
x0 += n_elem_per_reg * n_iter_unroll;
|
||||
a0 += n_elem_per_reg * n_iter_unroll;
|
||||
a1 += n_elem_per_reg * n_iter_unroll;
|
||||
a2 += n_elem_per_reg * n_iter_unroll;
|
||||
a3 += n_elem_per_reg * n_iter_unroll;
|
||||
a4 += n_elem_per_reg * n_iter_unroll;
|
||||
a5 += n_elem_per_reg * n_iter_unroll;
|
||||
a6 += n_elem_per_reg * n_iter_unroll;
|
||||
a7 += n_elem_per_reg * n_iter_unroll;
|
||||
}
|
||||
a0v.v = _mm256_loadu_ps( a0 + 0*n_elem_per_reg );
|
||||
a1v.v = _mm256_loadu_ps( a1 + 0*n_elem_per_reg );
|
||||
a2v.v = _mm256_loadu_ps( a2 + 0*n_elem_per_reg );
|
||||
a3v.v = _mm256_loadu_ps( a3 + 0*n_elem_per_reg );
|
||||
a4v.v = _mm256_loadu_ps( a4 + 0*n_elem_per_reg );
|
||||
a5v.v = _mm256_loadu_ps( a5 + 0*n_elem_per_reg );
|
||||
a6v.v = _mm256_loadu_ps( a6 + 0*n_elem_per_reg );
|
||||
a7v.v = _mm256_loadu_ps( a7 + 0*n_elem_per_reg );
|
||||
|
||||
// perform: rho?v += a?v * x0v;
|
||||
rho0v.v = _mm256_fmadd_ps( a0v.v, x0v.v, rho0v.v );
|
||||
rho1v.v = _mm256_fmadd_ps( a1v.v, x0v.v, rho1v.v );
|
||||
rho2v.v = _mm256_fmadd_ps( a2v.v, x0v.v, rho2v.v );
|
||||
rho3v.v = _mm256_fmadd_ps( a3v.v, x0v.v, rho3v.v );
|
||||
rho4v.v = _mm256_fmadd_ps( a4v.v, x0v.v, rho4v.v );
|
||||
rho5v.v = _mm256_fmadd_ps( a5v.v, x0v.v, rho5v.v );
|
||||
rho6v.v = _mm256_fmadd_ps( a6v.v, x0v.v, rho6v.v );
|
||||
rho7v.v = _mm256_fmadd_ps( a7v.v, x0v.v, rho7v.v );
|
||||
|
||||
x0 += n_elem_per_reg * n_iter_unroll;
|
||||
a0 += n_elem_per_reg * n_iter_unroll;
|
||||
a1 += n_elem_per_reg * n_iter_unroll;
|
||||
a2 += n_elem_per_reg * n_iter_unroll;
|
||||
a3 += n_elem_per_reg * n_iter_unroll;
|
||||
a4 += n_elem_per_reg * n_iter_unroll;
|
||||
a5 += n_elem_per_reg * n_iter_unroll;
|
||||
a6 += n_elem_per_reg * n_iter_unroll;
|
||||
a7 += n_elem_per_reg * n_iter_unroll;
|
||||
}
|
||||
|
||||
#if 0
|
||||
rho0 += rho0v.f[0] + rho0v.f[1] + rho0v.f[2] + rho0v.f[3] +
|
||||
rho0v.f[4] + rho0v.f[5] + rho0v.f[6] + rho0v.f[7];
|
||||
rho1 += rho1v.f[0] + rho1v.f[1] + rho1v.f[2] + rho1v.f[3] +
|
||||
rho1v.f[4] + rho1v.f[5] + rho1v.f[6] + rho1v.f[7];
|
||||
rho2 += rho2v.f[0] + rho2v.f[1] + rho2v.f[2] + rho2v.f[3] +
|
||||
rho2v.f[4] + rho2v.f[5] + rho2v.f[6] + rho2v.f[7];
|
||||
rho3 += rho3v.f[0] + rho3v.f[1] + rho3v.f[2] + rho3v.f[3] +
|
||||
rho3v.f[4] + rho3v.f[5] + rho3v.f[6] + rho3v.f[7];
|
||||
rho4 += rho4v.f[0] + rho4v.f[1] + rho4v.f[2] + rho4v.f[3] +
|
||||
rho4v.f[4] + rho4v.f[5] + rho4v.f[6] + rho4v.f[7];
|
||||
rho5 += rho5v.f[0] + rho5v.f[1] + rho5v.f[2] + rho5v.f[3] +
|
||||
rho5v.f[4] + rho5v.f[5] + rho5v.f[6] + rho5v.f[7];
|
||||
rho6 += rho6v.f[0] + rho6v.f[1] + rho6v.f[2] + rho6v.f[3] +
|
||||
rho6v.f[4] + rho6v.f[5] + rho6v.f[6] + rho6v.f[7];
|
||||
rho7 += rho7v.f[0] + rho7v.f[1] + rho7v.f[2] + rho7v.f[3] +
|
||||
rho7v.f[4] + rho7v.f[5] + rho7v.f[6] + rho7v.f[7];
|
||||
rho0 += rho0v.f[0] + rho0v.f[1] + rho0v.f[2] + rho0v.f[3] +
|
||||
rho0v.f[4] + rho0v.f[5] + rho0v.f[6] + rho0v.f[7];
|
||||
rho1 += rho1v.f[0] + rho1v.f[1] + rho1v.f[2] + rho1v.f[3] +
|
||||
rho1v.f[4] + rho1v.f[5] + rho1v.f[6] + rho1v.f[7];
|
||||
rho2 += rho2v.f[0] + rho2v.f[1] + rho2v.f[2] + rho2v.f[3] +
|
||||
rho2v.f[4] + rho2v.f[5] + rho2v.f[6] + rho2v.f[7];
|
||||
rho3 += rho3v.f[0] + rho3v.f[1] + rho3v.f[2] + rho3v.f[3] +
|
||||
rho3v.f[4] + rho3v.f[5] + rho3v.f[6] + rho3v.f[7];
|
||||
rho4 += rho4v.f[0] + rho4v.f[1] + rho4v.f[2] + rho4v.f[3] +
|
||||
rho4v.f[4] + rho4v.f[5] + rho4v.f[6] + rho4v.f[7];
|
||||
rho5 += rho5v.f[0] + rho5v.f[1] + rho5v.f[2] + rho5v.f[3] +
|
||||
rho5v.f[4] + rho5v.f[5] + rho5v.f[6] + rho5v.f[7];
|
||||
rho6 += rho6v.f[0] + rho6v.f[1] + rho6v.f[2] + rho6v.f[3] +
|
||||
rho6v.f[4] + rho6v.f[5] + rho6v.f[6] + rho6v.f[7];
|
||||
rho7 += rho7v.f[0] + rho7v.f[1] + rho7v.f[2] + rho7v.f[3] +
|
||||
rho7v.f[4] + rho7v.f[5] + rho7v.f[6] + rho7v.f[7];
|
||||
#else
|
||||
// Now we need to sum the elements within each vector.
|
||||
// Now we need to sum the elements within each vector.
|
||||
|
||||
onev.v = _mm256_set1_ps( 1.0f );
|
||||
v8sf_t onev; onev.v = _mm256_set1_ps( 1.0f );
|
||||
|
||||
// Sum the elements of a given rho?v by dotting it with 1. The '1' in
|
||||
// '0xf1' stores the sum of the upper four and lower four values to
|
||||
// the low elements of each lane: elements 4 and 0, respectively. (The
|
||||
// 'f' in '0xf1' means include all four elements of each lane in the
|
||||
// summation.)
|
||||
rho0v.v = _mm256_dp_ps( rho0v.v, onev.v, 0xf1 );
|
||||
rho1v.v = _mm256_dp_ps( rho1v.v, onev.v, 0xf1 );
|
||||
rho2v.v = _mm256_dp_ps( rho2v.v, onev.v, 0xf1 );
|
||||
rho3v.v = _mm256_dp_ps( rho3v.v, onev.v, 0xf1 );
|
||||
rho4v.v = _mm256_dp_ps( rho4v.v, onev.v, 0xf1 );
|
||||
rho5v.v = _mm256_dp_ps( rho5v.v, onev.v, 0xf1 );
|
||||
rho6v.v = _mm256_dp_ps( rho6v.v, onev.v, 0xf1 );
|
||||
rho7v.v = _mm256_dp_ps( rho7v.v, onev.v, 0xf1 );
|
||||
// Sum the elements of a given rho?v by dotting it with 1. The '1' in
|
||||
// '0xf1' stores the sum of the upper four and lower four values to
|
||||
// the low elements of each lane: elements 4 and 0, respectively. (The
|
||||
// 'f' in '0xf1' means include all four elements of each lane in the
|
||||
// summation.)
|
||||
rho0v.v = _mm256_dp_ps( rho0v.v, onev.v, 0xf1 );
|
||||
rho1v.v = _mm256_dp_ps( rho1v.v, onev.v, 0xf1 );
|
||||
rho2v.v = _mm256_dp_ps( rho2v.v, onev.v, 0xf1 );
|
||||
rho3v.v = _mm256_dp_ps( rho3v.v, onev.v, 0xf1 );
|
||||
rho4v.v = _mm256_dp_ps( rho4v.v, onev.v, 0xf1 );
|
||||
rho5v.v = _mm256_dp_ps( rho5v.v, onev.v, 0xf1 );
|
||||
rho6v.v = _mm256_dp_ps( rho6v.v, onev.v, 0xf1 );
|
||||
rho7v.v = _mm256_dp_ps( rho7v.v, onev.v, 0xf1 );
|
||||
|
||||
// Manually add the results from above to finish the sum.
|
||||
rho0 = rho0v.f[0] + rho0v.f[4];
|
||||
rho1 = rho1v.f[0] + rho1v.f[4];
|
||||
rho2 = rho2v.f[0] + rho2v.f[4];
|
||||
rho3 = rho3v.f[0] + rho3v.f[4];
|
||||
rho4 = rho4v.f[0] + rho4v.f[4];
|
||||
rho5 = rho5v.f[0] + rho5v.f[4];
|
||||
rho6 = rho6v.f[0] + rho6v.f[4];
|
||||
rho7 = rho7v.f[0] + rho7v.f[4];
|
||||
// Manually add the results from above to finish the sum.
|
||||
rho0 = rho0v.f[0] + rho0v.f[4];
|
||||
rho1 = rho1v.f[0] + rho1v.f[4];
|
||||
rho2 = rho2v.f[0] + rho2v.f[4];
|
||||
rho3 = rho3v.f[0] + rho3v.f[4];
|
||||
rho4 = rho4v.f[0] + rho4v.f[4];
|
||||
rho5 = rho5v.f[0] + rho5v.f[4];
|
||||
rho6 = rho6v.f[0] + rho6v.f[4];
|
||||
rho7 = rho7v.f[0] + rho7v.f[4];
|
||||
#endif
|
||||
|
||||
// If there are leftover iterations, perform them with scalar code.
|
||||
for ( i = 0; i < m_left ; ++i )
|
||||
{
|
||||
const float x0c = *x0;
|
||||
|
||||
const float a0c = *a0;
|
||||
const float a1c = *a1;
|
||||
const float a2c = *a2;
|
||||
const float a3c = *a3;
|
||||
const float a4c = *a4;
|
||||
const float a5c = *a5;
|
||||
const float a6c = *a6;
|
||||
const float a7c = *a7;
|
||||
|
||||
rho0 += a0c * x0c;
|
||||
rho1 += a1c * x0c;
|
||||
rho2 += a2c * x0c;
|
||||
rho3 += a3c * x0c;
|
||||
rho4 += a4c * x0c;
|
||||
rho5 += a5c * x0c;
|
||||
rho6 += a6c * x0c;
|
||||
rho7 += a7c * x0c;
|
||||
|
||||
x0 += incx;
|
||||
a0 += inca;
|
||||
a1 += inca;
|
||||
a2 += inca;
|
||||
a3 += inca;
|
||||
a4 += inca;
|
||||
a5 += inca;
|
||||
a6 += inca;
|
||||
a7 += inca;
|
||||
// Adjust for scalar subproblem.
|
||||
m -= n_elem_per_reg * n_iter_unroll * m_viter;
|
||||
a += n_elem_per_reg * n_iter_unroll * m_viter /* * inca */;
|
||||
x += n_elem_per_reg * n_iter_unroll * m_viter /* * incx */;
|
||||
}
|
||||
else if ( lda == 1 )
|
||||
{
|
||||
const dim_t n_iter_unroll = 4;
|
||||
|
||||
// Use the unrolling factor and the number of elements per register
|
||||
// to compute the number of vectorized and leftover iterations.
|
||||
dim_t m_viter = ( m ) / ( n_iter_unroll );
|
||||
|
||||
// Initialize pointers for x and A.
|
||||
float* restrict x0 = x;
|
||||
float* restrict a0 = a;
|
||||
|
||||
// Initialize rho vector accumulators to zero.
|
||||
v8sf_t rho0v; rho0v.v = _mm256_setzero_ps();
|
||||
v8sf_t rho1v; rho1v.v = _mm256_setzero_ps();
|
||||
v8sf_t rho2v; rho2v.v = _mm256_setzero_ps();
|
||||
v8sf_t rho3v; rho3v.v = _mm256_setzero_ps();
|
||||
|
||||
v8sf_t x0v, x1v, x2v, x3v;
|
||||
v8sf_t a0v, a1v, a2v, a3v;
|
||||
|
||||
for ( dim_t i = 0; i < m_viter; ++i )
|
||||
{
|
||||
// Load the input values.
|
||||
a0v.v = _mm256_loadu_ps( a0 + 0*inca );
|
||||
a1v.v = _mm256_loadu_ps( a0 + 1*inca );
|
||||
a2v.v = _mm256_loadu_ps( a0 + 2*inca );
|
||||
a3v.v = _mm256_loadu_ps( a0 + 3*inca );
|
||||
|
||||
x0v.v = _mm256_broadcast_ss( x0 + 0*incx );
|
||||
x1v.v = _mm256_broadcast_ss( x0 + 1*incx );
|
||||
x2v.v = _mm256_broadcast_ss( x0 + 2*incx );
|
||||
x3v.v = _mm256_broadcast_ss( x0 + 3*incx );
|
||||
|
||||
// perform : rho?v += a?v * x?v;
|
||||
rho0v.v = _mm256_fmadd_ps( a0v.v, x0v.v, rho0v.v );
|
||||
rho1v.v = _mm256_fmadd_ps( a1v.v, x1v.v, rho1v.v );
|
||||
rho2v.v = _mm256_fmadd_ps( a2v.v, x2v.v, rho2v.v );
|
||||
rho3v.v = _mm256_fmadd_ps( a3v.v, x3v.v, rho3v.v );
|
||||
|
||||
x0 += incx * n_iter_unroll;
|
||||
a0 += inca * n_iter_unroll;
|
||||
}
|
||||
|
||||
// Combine the 8 accumulators into one vector register.
|
||||
rho0v.v = _mm256_add_ps( rho0v.v, rho1v.v );
|
||||
rho2v.v = _mm256_add_ps( rho2v.v, rho3v.v );
|
||||
rho0v.v = _mm256_add_ps( rho0v.v, rho2v.v );
|
||||
|
||||
// Write vector components to scalar values.
|
||||
rho0 = rho0v.f[0];
|
||||
rho1 = rho0v.f[1];
|
||||
rho2 = rho0v.f[2];
|
||||
rho3 = rho0v.f[3];
|
||||
rho4 = rho0v.f[4];
|
||||
rho5 = rho0v.f[5];
|
||||
rho6 = rho0v.f[6];
|
||||
rho7 = rho0v.f[7];
|
||||
|
||||
// Adjust for scalar subproblem.
|
||||
m -= n_iter_unroll * m_viter;
|
||||
a += n_iter_unroll * m_viter * inca;
|
||||
x += n_iter_unroll * m_viter * incx;
|
||||
}
|
||||
else
|
||||
{
|
||||
// No vectorization possible; use scalar iterations for the entire
|
||||
// problem.
|
||||
}
|
||||
|
||||
// Scalar edge case.
|
||||
{
|
||||
// Initialize pointers for x and the b_n columns of A (rows of A^T).
|
||||
float* restrict x0 = x;
|
||||
float* restrict a0 = a + 0*lda;
|
||||
float* restrict a1 = a + 1*lda;
|
||||
float* restrict a2 = a + 2*lda;
|
||||
float* restrict a3 = a + 3*lda;
|
||||
float* restrict a4 = a + 4*lda;
|
||||
float* restrict a5 = a + 5*lda;
|
||||
float* restrict a6 = a + 6*lda;
|
||||
float* restrict a7 = a + 7*lda;
|
||||
|
||||
// If there are leftover iterations, perform them with scalar code.
|
||||
for ( dim_t i = 0; i < m ; ++i )
|
||||
{
|
||||
const float x0c = *x0;
|
||||
|
||||
const float a0c = *a0;
|
||||
const float a1c = *a1;
|
||||
const float a2c = *a2;
|
||||
const float a3c = *a3;
|
||||
const float a4c = *a4;
|
||||
const float a5c = *a5;
|
||||
const float a6c = *a6;
|
||||
const float a7c = *a7;
|
||||
|
||||
rho0 += a0c * x0c;
|
||||
rho1 += a1c * x0c;
|
||||
rho2 += a2c * x0c;
|
||||
rho3 += a3c * x0c;
|
||||
rho4 += a4c * x0c;
|
||||
rho5 += a5c * x0c;
|
||||
rho6 += a6c * x0c;
|
||||
rho7 += a7c * x0c;
|
||||
|
||||
x0 += incx;
|
||||
a0 += inca;
|
||||
a1 += inca;
|
||||
a2 += inca;
|
||||
a3 += inca;
|
||||
a4 += inca;
|
||||
a5 += inca;
|
||||
a6 += inca;
|
||||
a7 += inca;
|
||||
}
|
||||
}
|
||||
|
||||
// Now prepare the final rho values to output/accumulate back into
|
||||
// the y vector.
|
||||
|
||||
v8sf_t rho0v, y0v;
|
||||
|
||||
// Insert the scalar rho values into a single vector.
|
||||
rho0v.f[0] = rho0;
|
||||
@@ -320,7 +392,7 @@ void bli_sdotxf_zen_int_8
|
||||
rho0v.f[7] = rho7;
|
||||
|
||||
// Broadcast the alpha scalar.
|
||||
alphav.v = _mm256_broadcast_ss( alpha );
|
||||
v8sf_t alphav; alphav.v = _mm256_broadcast_ss( alpha );
|
||||
|
||||
// We know at this point that alpha is nonzero; however, beta may still
|
||||
// be zero. If beta is indeed zero, we must overwrite y rather than scale
|
||||
@@ -334,7 +406,7 @@ void bli_sdotxf_zen_int_8
|
||||
else
|
||||
{
|
||||
// Broadcast the beta scalar.
|
||||
betav.v = _mm256_broadcast_ss( beta );
|
||||
v8sf_t betav; betav.v = _mm256_broadcast_ss( beta );
|
||||
|
||||
// Load y.
|
||||
if ( incy == 1 )
|
||||
@@ -386,39 +458,7 @@ void bli_ddotxf_zen_int_8
|
||||
)
|
||||
{
|
||||
const dim_t fuse_fac = 8;
|
||||
|
||||
const dim_t n_elem_per_reg = 4;
|
||||
const dim_t n_iter_unroll = 1;
|
||||
|
||||
dim_t i;
|
||||
dim_t m_viter;
|
||||
dim_t m_left;
|
||||
|
||||
double* restrict a0;
|
||||
double* restrict a1;
|
||||
double* restrict a2;
|
||||
double* restrict a3;
|
||||
double* restrict a4;
|
||||
double* restrict a5;
|
||||
double* restrict a6;
|
||||
double* restrict a7;
|
||||
|
||||
double* restrict x0;
|
||||
|
||||
double rho0, rho1, rho2, rho3;
|
||||
double rho4, rho5, rho6, rho7;
|
||||
|
||||
v4df_t a0v, a1v, a2v, a3v;
|
||||
v4df_t a4v, a5v, a6v, a7v;
|
||||
|
||||
v4df_t x0v;
|
||||
|
||||
v4df_t rho0v, rho1v, rho2v, rho3v;
|
||||
v4df_t rho4v, rho5v, rho6v, rho7v;
|
||||
|
||||
v4df_t alphav;
|
||||
v4df_t betav;
|
||||
v4df_t y0v, y1v;
|
||||
|
||||
// If the b_n dimension is zero, y is empty and there is no computation.
|
||||
if ( bli_zero_dim1( b_n ) ) return;
|
||||
@@ -446,7 +486,7 @@ void bli_ddotxf_zen_int_8
|
||||
{
|
||||
ddotxv_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DOUBLE, BLIS_DOTXV_KER, cntx );
|
||||
|
||||
for ( i = 0; i < b_n; ++i )
|
||||
for ( dim_t i = 0; i < b_n; ++i )
|
||||
{
|
||||
double* a1 = a + (0 )*inca + (i )*lda;
|
||||
double* x1 = x + (0 )*incx;
|
||||
@@ -471,144 +511,256 @@ void bli_ddotxf_zen_int_8
|
||||
// At this point, we know that b_n is exactly equal to the fusing factor.
|
||||
// However, m may not be a multiple of the number of elements per vector.
|
||||
|
||||
// Use the unrolling factor and the number of elements per register
|
||||
// to compute the number of vectorized and leftover iterations.
|
||||
m_viter = ( m ) / ( n_elem_per_reg * n_iter_unroll );
|
||||
m_left = ( m ) % ( n_elem_per_reg * n_iter_unroll );
|
||||
// Going forward, we handle two possible storage formats of A explicitly:
|
||||
// (1) A is stored by columns, or (2) A is stored by rows. Either case is
|
||||
// further split into two subproblems along the m dimension:
|
||||
// (a) a vectorized part, starting at m = 0 and ending at any 0 <= m' <= m.
|
||||
// (b) a scalar part, starting at m' and ending at m. If no vectorization
|
||||
// is possible then m' == 0 and thus the scalar part is the entire
|
||||
// problem. If 0 < m', then the a and x pointers and m variable will
|
||||
// be adjusted accordingly for the second subproblem.
|
||||
// Note: since parts (b) for both (1) and (2) are so similar, they are
|
||||
// factored out into one code block after the following conditional, which
|
||||
// distinguishes between (1) and (2).
|
||||
|
||||
// If there is anything that would interfere with our use of contiguous
|
||||
// vector loads/stores, override m_viter and m_left to use scalar code
|
||||
// for all iterations. (NOTE: vector instructions can be used even if
|
||||
// the increment of the output vector, incy, is nonunit.)
|
||||
if ( inca != 1 || incx != 1 )
|
||||
// Intermediate variables to hold the completed dot products
|
||||
double rho0 = 0, rho1 = 0, rho2 = 0, rho3 = 0,
|
||||
rho4 = 0, rho5 = 0, rho6 = 0, rho7 = 0;
|
||||
|
||||
if ( inca == 1 && incx == 1 )
|
||||
{
|
||||
m_viter = 0;
|
||||
m_left = m;
|
||||
}
|
||||
const dim_t n_iter_unroll = 1;
|
||||
|
||||
// Set up pointers for x and the b_n columns of A (rows of A^T).
|
||||
x0 = x;
|
||||
a0 = a + 0*lda;
|
||||
a1 = a + 1*lda;
|
||||
a2 = a + 2*lda;
|
||||
a3 = a + 3*lda;
|
||||
a4 = a + 4*lda;
|
||||
a5 = a + 5*lda;
|
||||
a6 = a + 6*lda;
|
||||
a7 = a + 7*lda;
|
||||
// Use the unrolling factor and the number of elements per register
|
||||
// to compute the number of vectorized and leftover iterations.
|
||||
dim_t m_viter = ( m ) / ( n_elem_per_reg * n_iter_unroll );
|
||||
|
||||
// Initialize b_n rho vector accumulators to zero.
|
||||
rho0v.v = _mm256_setzero_pd();
|
||||
rho1v.v = _mm256_setzero_pd();
|
||||
rho2v.v = _mm256_setzero_pd();
|
||||
rho3v.v = _mm256_setzero_pd();
|
||||
rho4v.v = _mm256_setzero_pd();
|
||||
rho5v.v = _mm256_setzero_pd();
|
||||
rho6v.v = _mm256_setzero_pd();
|
||||
rho7v.v = _mm256_setzero_pd();
|
||||
// Set up pointers for x and the b_n columns of A (rows of A^T).
|
||||
double* restrict x0 = x;
|
||||
double* restrict a0 = a + 0*lda;
|
||||
double* restrict a1 = a + 1*lda;
|
||||
double* restrict a2 = a + 2*lda;
|
||||
double* restrict a3 = a + 3*lda;
|
||||
double* restrict a4 = a + 4*lda;
|
||||
double* restrict a5 = a + 5*lda;
|
||||
double* restrict a6 = a + 6*lda;
|
||||
double* restrict a7 = a + 7*lda;
|
||||
|
||||
// If there are vectorized iterations, perform them with vector
|
||||
// instructions.
|
||||
for ( i = 0; i < m_viter; ++i )
|
||||
{
|
||||
// Load the input values.
|
||||
x0v.v = _mm256_loadu_pd( x0 + 0*n_elem_per_reg );
|
||||
// Initialize b_n rho vector accumulators to zero.
|
||||
v4df_t rho0v; rho0v.v = _mm256_setzero_pd();
|
||||
v4df_t rho1v; rho1v.v = _mm256_setzero_pd();
|
||||
v4df_t rho2v; rho2v.v = _mm256_setzero_pd();
|
||||
v4df_t rho3v; rho3v.v = _mm256_setzero_pd();
|
||||
v4df_t rho4v; rho4v.v = _mm256_setzero_pd();
|
||||
v4df_t rho5v; rho5v.v = _mm256_setzero_pd();
|
||||
v4df_t rho6v; rho6v.v = _mm256_setzero_pd();
|
||||
v4df_t rho7v; rho7v.v = _mm256_setzero_pd();
|
||||
|
||||
a0v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg );
|
||||
a1v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg );
|
||||
a2v.v = _mm256_loadu_pd( a2 + 0*n_elem_per_reg );
|
||||
a3v.v = _mm256_loadu_pd( a3 + 0*n_elem_per_reg );
|
||||
a4v.v = _mm256_loadu_pd( a4 + 0*n_elem_per_reg );
|
||||
a5v.v = _mm256_loadu_pd( a5 + 0*n_elem_per_reg );
|
||||
a6v.v = _mm256_loadu_pd( a6 + 0*n_elem_per_reg );
|
||||
a7v.v = _mm256_loadu_pd( a7 + 0*n_elem_per_reg );
|
||||
v4df_t x0v;
|
||||
v4df_t a0v, a1v, a2v, a3v, a4v, a5v, a6v, a7v;
|
||||
|
||||
// perform : rho?v += a?v * x0v;
|
||||
rho0v.v = _mm256_fmadd_pd( a0v.v, x0v.v, rho0v.v );
|
||||
rho1v.v = _mm256_fmadd_pd( a1v.v, x0v.v, rho1v.v );
|
||||
rho2v.v = _mm256_fmadd_pd( a2v.v, x0v.v, rho2v.v );
|
||||
rho3v.v = _mm256_fmadd_pd( a3v.v, x0v.v, rho3v.v );
|
||||
rho4v.v = _mm256_fmadd_pd( a4v.v, x0v.v, rho4v.v );
|
||||
rho5v.v = _mm256_fmadd_pd( a5v.v, x0v.v, rho5v.v );
|
||||
rho6v.v = _mm256_fmadd_pd( a6v.v, x0v.v, rho6v.v );
|
||||
rho7v.v = _mm256_fmadd_pd( a7v.v, x0v.v, rho7v.v );
|
||||
// If there are vectorized iterations, perform them with vector
|
||||
// instructions.
|
||||
for ( dim_t i = 0; i < m_viter; ++i )
|
||||
{
|
||||
// Load the input values.
|
||||
x0v.v = _mm256_loadu_pd( x0 + 0*n_elem_per_reg );
|
||||
|
||||
x0 += n_elem_per_reg * n_iter_unroll;
|
||||
a0 += n_elem_per_reg * n_iter_unroll;
|
||||
a1 += n_elem_per_reg * n_iter_unroll;
|
||||
a2 += n_elem_per_reg * n_iter_unroll;
|
||||
a3 += n_elem_per_reg * n_iter_unroll;
|
||||
a4 += n_elem_per_reg * n_iter_unroll;
|
||||
a5 += n_elem_per_reg * n_iter_unroll;
|
||||
a6 += n_elem_per_reg * n_iter_unroll;
|
||||
a7 += n_elem_per_reg * n_iter_unroll;
|
||||
}
|
||||
a0v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg );
|
||||
a1v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg );
|
||||
a2v.v = _mm256_loadu_pd( a2 + 0*n_elem_per_reg );
|
||||
a3v.v = _mm256_loadu_pd( a3 + 0*n_elem_per_reg );
|
||||
a4v.v = _mm256_loadu_pd( a4 + 0*n_elem_per_reg );
|
||||
a5v.v = _mm256_loadu_pd( a5 + 0*n_elem_per_reg );
|
||||
a6v.v = _mm256_loadu_pd( a6 + 0*n_elem_per_reg );
|
||||
a7v.v = _mm256_loadu_pd( a7 + 0*n_elem_per_reg );
|
||||
|
||||
// perform: rho?v += a?v * x0v;
|
||||
rho0v.v = _mm256_fmadd_pd( a0v.v, x0v.v, rho0v.v );
|
||||
rho1v.v = _mm256_fmadd_pd( a1v.v, x0v.v, rho1v.v );
|
||||
rho2v.v = _mm256_fmadd_pd( a2v.v, x0v.v, rho2v.v );
|
||||
rho3v.v = _mm256_fmadd_pd( a3v.v, x0v.v, rho3v.v );
|
||||
rho4v.v = _mm256_fmadd_pd( a4v.v, x0v.v, rho4v.v );
|
||||
rho5v.v = _mm256_fmadd_pd( a5v.v, x0v.v, rho5v.v );
|
||||
rho6v.v = _mm256_fmadd_pd( a6v.v, x0v.v, rho6v.v );
|
||||
rho7v.v = _mm256_fmadd_pd( a7v.v, x0v.v, rho7v.v );
|
||||
|
||||
x0 += n_elem_per_reg * n_iter_unroll;
|
||||
a0 += n_elem_per_reg * n_iter_unroll;
|
||||
a1 += n_elem_per_reg * n_iter_unroll;
|
||||
a2 += n_elem_per_reg * n_iter_unroll;
|
||||
a3 += n_elem_per_reg * n_iter_unroll;
|
||||
a4 += n_elem_per_reg * n_iter_unroll;
|
||||
a5 += n_elem_per_reg * n_iter_unroll;
|
||||
a6 += n_elem_per_reg * n_iter_unroll;
|
||||
a7 += n_elem_per_reg * n_iter_unroll;
|
||||
}
|
||||
|
||||
#if 0
|
||||
rho0 += rho0v.d[0] + rho0v.d[1] + rho0v.d[2] + rho0v.d[3];
|
||||
rho1 += rho1v.d[0] + rho1v.d[1] + rho1v.d[2] + rho1v.d[3];
|
||||
rho2 += rho2v.d[0] + rho2v.d[1] + rho2v.d[2] + rho2v.d[3];
|
||||
rho3 += rho3v.d[0] + rho3v.d[1] + rho3v.d[2] + rho3v.d[3];
|
||||
rho4 += rho4v.d[0] + rho4v.d[1] + rho4v.d[2] + rho4v.d[3];
|
||||
rho5 += rho5v.d[0] + rho5v.d[1] + rho5v.d[2] + rho5v.d[3];
|
||||
rho6 += rho6v.d[0] + rho6v.d[1] + rho6v.d[2] + rho6v.d[3];
|
||||
rho7 += rho7v.d[0] + rho7v.d[1] + rho7v.d[2] + rho7v.d[3];
|
||||
rho0 += rho0v.d[0] + rho0v.d[1] + rho0v.d[2] + rho0v.d[3];
|
||||
rho1 += rho1v.d[0] + rho1v.d[1] + rho1v.d[2] + rho1v.d[3];
|
||||
rho2 += rho2v.d[0] + rho2v.d[1] + rho2v.d[2] + rho2v.d[3];
|
||||
rho3 += rho3v.d[0] + rho3v.d[1] + rho3v.d[2] + rho3v.d[3];
|
||||
rho4 += rho4v.d[0] + rho4v.d[1] + rho4v.d[2] + rho4v.d[3];
|
||||
rho5 += rho5v.d[0] + rho5v.d[1] + rho5v.d[2] + rho5v.d[3];
|
||||
rho6 += rho6v.d[0] + rho6v.d[1] + rho6v.d[2] + rho6v.d[3];
|
||||
rho7 += rho7v.d[0] + rho7v.d[1] + rho7v.d[2] + rho7v.d[3];
|
||||
#else
|
||||
// Sum the elements of a given rho?v. This computes the sum of
|
||||
// elements within lanes and stores the sum to both elements.
|
||||
rho0v.v = _mm256_hadd_pd( rho0v.v, rho0v.v );
|
||||
rho1v.v = _mm256_hadd_pd( rho1v.v, rho1v.v );
|
||||
rho2v.v = _mm256_hadd_pd( rho2v.v, rho2v.v );
|
||||
rho3v.v = _mm256_hadd_pd( rho3v.v, rho3v.v );
|
||||
rho4v.v = _mm256_hadd_pd( rho4v.v, rho4v.v );
|
||||
rho5v.v = _mm256_hadd_pd( rho5v.v, rho5v.v );
|
||||
rho6v.v = _mm256_hadd_pd( rho6v.v, rho6v.v );
|
||||
rho7v.v = _mm256_hadd_pd( rho7v.v, rho7v.v );
|
||||
// Sum the elements of a given rho?v. This computes the sum of
|
||||
// elements within lanes and stores the sum to both elements.
|
||||
rho0v.v = _mm256_hadd_pd( rho0v.v, rho0v.v );
|
||||
rho1v.v = _mm256_hadd_pd( rho1v.v, rho1v.v );
|
||||
rho2v.v = _mm256_hadd_pd( rho2v.v, rho2v.v );
|
||||
rho3v.v = _mm256_hadd_pd( rho3v.v, rho3v.v );
|
||||
rho4v.v = _mm256_hadd_pd( rho4v.v, rho4v.v );
|
||||
rho5v.v = _mm256_hadd_pd( rho5v.v, rho5v.v );
|
||||
rho6v.v = _mm256_hadd_pd( rho6v.v, rho6v.v );
|
||||
rho7v.v = _mm256_hadd_pd( rho7v.v, rho7v.v );
|
||||
|
||||
// Manually add the results from above to finish the sum.
|
||||
rho0 = rho0v.d[0] + rho0v.d[2];
|
||||
rho1 = rho1v.d[0] + rho1v.d[2];
|
||||
rho2 = rho2v.d[0] + rho2v.d[2];
|
||||
rho3 = rho3v.d[0] + rho3v.d[2];
|
||||
rho4 = rho4v.d[0] + rho4v.d[2];
|
||||
rho5 = rho5v.d[0] + rho5v.d[2];
|
||||
rho6 = rho6v.d[0] + rho6v.d[2];
|
||||
rho7 = rho7v.d[0] + rho7v.d[2];
|
||||
// Manually add the results from above to finish the sum.
|
||||
rho0 = rho0v.d[0] + rho0v.d[2];
|
||||
rho1 = rho1v.d[0] + rho1v.d[2];
|
||||
rho2 = rho2v.d[0] + rho2v.d[2];
|
||||
rho3 = rho3v.d[0] + rho3v.d[2];
|
||||
rho4 = rho4v.d[0] + rho4v.d[2];
|
||||
rho5 = rho5v.d[0] + rho5v.d[2];
|
||||
rho6 = rho6v.d[0] + rho6v.d[2];
|
||||
rho7 = rho7v.d[0] + rho7v.d[2];
|
||||
#endif
|
||||
|
||||
// If there are leftover iterations, perform them with scalar code.
|
||||
for ( i = 0; i < m_left ; ++i )
|
||||
{
|
||||
const double x0c = *x0;
|
||||
|
||||
const double a0c = *a0;
|
||||
const double a1c = *a1;
|
||||
const double a2c = *a2;
|
||||
const double a3c = *a3;
|
||||
const double a4c = *a4;
|
||||
const double a5c = *a5;
|
||||
const double a6c = *a6;
|
||||
const double a7c = *a7;
|
||||
|
||||
rho0 += a0c * x0c;
|
||||
rho1 += a1c * x0c;
|
||||
rho2 += a2c * x0c;
|
||||
rho3 += a3c * x0c;
|
||||
rho4 += a4c * x0c;
|
||||
rho5 += a5c * x0c;
|
||||
rho6 += a6c * x0c;
|
||||
rho7 += a7c * x0c;
|
||||
|
||||
x0 += incx;
|
||||
a0 += inca;
|
||||
a1 += inca;
|
||||
a2 += inca;
|
||||
a3 += inca;
|
||||
a4 += inca;
|
||||
a5 += inca;
|
||||
a6 += inca;
|
||||
a7 += inca;
|
||||
// Adjust for scalar subproblem.
|
||||
m -= n_elem_per_reg * n_iter_unroll * m_viter;
|
||||
a += n_elem_per_reg * n_iter_unroll * m_viter /* * inca */;
|
||||
x += n_elem_per_reg * n_iter_unroll * m_viter /* * incx */;
|
||||
}
|
||||
else if ( lda == 1 )
|
||||
{
|
||||
const dim_t n_iter_unroll = 3;
|
||||
const dim_t n_reg_per_row = 2; // fuse_fac / n_elem_per_reg;
|
||||
|
||||
// Use the unrolling factor and the number of elements per register
|
||||
// to compute the number of vectorized and leftover iterations.
|
||||
dim_t m_viter = ( m ) / ( n_reg_per_row * n_iter_unroll );
|
||||
|
||||
// Initialize pointers for x and A.
|
||||
double* restrict x0 = x;
|
||||
double* restrict a0 = a;
|
||||
|
||||
// Initialize rho vector accumulators to zero.
|
||||
v4df_t rho0v; rho0v.v = _mm256_setzero_pd();
|
||||
v4df_t rho1v; rho1v.v = _mm256_setzero_pd();
|
||||
v4df_t rho2v; rho2v.v = _mm256_setzero_pd();
|
||||
v4df_t rho3v; rho3v.v = _mm256_setzero_pd();
|
||||
v4df_t rho4v; rho4v.v = _mm256_setzero_pd();
|
||||
v4df_t rho5v; rho5v.v = _mm256_setzero_pd();
|
||||
|
||||
v4df_t x0v, x1v, x2v;
|
||||
v4df_t a0v, a1v, a2v, a3v, a4v, a5v;
|
||||
|
||||
for ( dim_t i = 0; i < m_viter; ++i )
|
||||
{
|
||||
// Load the input values.
|
||||
a0v.v = _mm256_loadu_pd( a0 + 0*inca + 0*n_elem_per_reg );
|
||||
a1v.v = _mm256_loadu_pd( a0 + 0*inca + 1*n_elem_per_reg );
|
||||
a2v.v = _mm256_loadu_pd( a0 + 1*inca + 0*n_elem_per_reg );
|
||||
a3v.v = _mm256_loadu_pd( a0 + 1*inca + 1*n_elem_per_reg );
|
||||
a4v.v = _mm256_loadu_pd( a0 + 2*inca + 0*n_elem_per_reg );
|
||||
a5v.v = _mm256_loadu_pd( a0 + 2*inca + 1*n_elem_per_reg );
|
||||
|
||||
x0v.v = _mm256_broadcast_sd( x0 + 0*incx );
|
||||
x1v.v = _mm256_broadcast_sd( x0 + 1*incx );
|
||||
x2v.v = _mm256_broadcast_sd( x0 + 2*incx );
|
||||
|
||||
// perform : rho?v += a?v * x?v;
|
||||
rho0v.v = _mm256_fmadd_pd( a0v.v, x0v.v, rho0v.v );
|
||||
rho1v.v = _mm256_fmadd_pd( a1v.v, x0v.v, rho1v.v );
|
||||
rho2v.v = _mm256_fmadd_pd( a2v.v, x1v.v, rho2v.v );
|
||||
rho3v.v = _mm256_fmadd_pd( a3v.v, x1v.v, rho3v.v );
|
||||
rho4v.v = _mm256_fmadd_pd( a4v.v, x2v.v, rho4v.v );
|
||||
rho5v.v = _mm256_fmadd_pd( a5v.v, x2v.v, rho5v.v );
|
||||
|
||||
x0 += incx * n_iter_unroll;
|
||||
a0 += inca * n_iter_unroll;
|
||||
}
|
||||
|
||||
// Combine the 8 accumulators into one vector register.
|
||||
rho0v.v = _mm256_add_pd( rho0v.v, rho2v.v );
|
||||
rho0v.v = _mm256_add_pd( rho0v.v, rho4v.v );
|
||||
rho1v.v = _mm256_add_pd( rho1v.v, rho3v.v );
|
||||
rho1v.v = _mm256_add_pd( rho1v.v, rho5v.v );
|
||||
|
||||
// Write vector components to scalar values.
|
||||
rho0 = rho0v.d[0];
|
||||
rho1 = rho0v.d[1];
|
||||
rho2 = rho0v.d[2];
|
||||
rho3 = rho0v.d[3];
|
||||
rho4 = rho1v.d[0];
|
||||
rho5 = rho1v.d[1];
|
||||
rho6 = rho1v.d[2];
|
||||
rho7 = rho1v.d[3];
|
||||
|
||||
// Adjust for scalar subproblem.
|
||||
m -= n_iter_unroll * m_viter;
|
||||
a += n_iter_unroll * m_viter * inca;
|
||||
x += n_iter_unroll * m_viter * incx;
|
||||
}
|
||||
else
|
||||
{
|
||||
// No vectorization possible; use scalar iterations for the entire
|
||||
// problem.
|
||||
}
|
||||
|
||||
// Scalar edge case.
|
||||
{
|
||||
// Initialize pointers for x and the b_n columns of A (rows of A^T).
|
||||
double* restrict x0 = x;
|
||||
double* restrict a0 = a + 0*lda;
|
||||
double* restrict a1 = a + 1*lda;
|
||||
double* restrict a2 = a + 2*lda;
|
||||
double* restrict a3 = a + 3*lda;
|
||||
double* restrict a4 = a + 4*lda;
|
||||
double* restrict a5 = a + 5*lda;
|
||||
double* restrict a6 = a + 6*lda;
|
||||
double* restrict a7 = a + 7*lda;
|
||||
|
||||
// If there are leftover iterations, perform them with scalar code.
|
||||
for ( dim_t i = 0; i < m ; ++i )
|
||||
{
|
||||
const double x0c = *x0;
|
||||
|
||||
const double a0c = *a0;
|
||||
const double a1c = *a1;
|
||||
const double a2c = *a2;
|
||||
const double a3c = *a3;
|
||||
const double a4c = *a4;
|
||||
const double a5c = *a5;
|
||||
const double a6c = *a6;
|
||||
const double a7c = *a7;
|
||||
|
||||
rho0 += a0c * x0c;
|
||||
rho1 += a1c * x0c;
|
||||
rho2 += a2c * x0c;
|
||||
rho3 += a3c * x0c;
|
||||
rho4 += a4c * x0c;
|
||||
rho5 += a5c * x0c;
|
||||
rho6 += a6c * x0c;
|
||||
rho7 += a7c * x0c;
|
||||
|
||||
x0 += incx;
|
||||
a0 += inca;
|
||||
a1 += inca;
|
||||
a2 += inca;
|
||||
a3 += inca;
|
||||
a4 += inca;
|
||||
a5 += inca;
|
||||
a6 += inca;
|
||||
a7 += inca;
|
||||
}
|
||||
}
|
||||
|
||||
// Now prepare the final rho values to output/accumulate back into
|
||||
// the y vector.
|
||||
|
||||
v4df_t rho0v, rho1v, y0v, y1v;
|
||||
|
||||
// Insert the scalar rho values into a single vector.
|
||||
rho0v.d[0] = rho0;
|
||||
@@ -621,7 +773,7 @@ void bli_ddotxf_zen_int_8
|
||||
rho1v.d[3] = rho7;
|
||||
|
||||
// Broadcast the alpha scalar.
|
||||
alphav.v = _mm256_broadcast_sd( alpha );
|
||||
v4df_t alphav; alphav.v = _mm256_broadcast_sd( alpha );
|
||||
|
||||
// We know at this point that alpha is nonzero; however, beta may still
|
||||
// be zero. If beta is indeed zero, we must overwrite y rather than scale
|
||||
@@ -636,7 +788,7 @@ void bli_ddotxf_zen_int_8
|
||||
else
|
||||
{
|
||||
// Broadcast the beta scalar.
|
||||
betav.v = _mm256_broadcast_sd( beta );
|
||||
v4df_t betav; betav.v = _mm256_broadcast_sd( beta );
|
||||
|
||||
// Load y.
|
||||
if ( incy == 1 )
|
||||
|
||||
Reference in New Issue
Block a user