mirror of
https://github.com/amd/blis.git
synced 2026-05-12 01:59:59 +00:00
Beta Zero Checks for sgemm_small
Change-Id: I111b66ad54a27b1977d155904738a55a351e6689
This commit is contained in:
committed by
Nallani Bhaskar
parent
cc98047fd6
commit
e0c95d77e1
@@ -174,7 +174,6 @@ static err_t bli_sgemm_small
|
||||
gint_t K = bli_obj_width( a ); // number of columns of OP(A), will be updated if OP(A) is Transpose(A) .
|
||||
gint_t L = M * N;
|
||||
|
||||
// printf("alpha_cast = %f beta_cast = %f [ Trans = %d %d], [stride = %d %d %d] [m,n,k = %d %d %d]\n",*alpha_cast,*beta_cast, bli_obj_has_trans( a ), bli_obj_has_trans( b ), lda, ldb,ldc, M,N,K);
|
||||
if ((((L) < (BLIS_SMALL_MATRIX_THRES * BLIS_SMALL_MATRIX_THRES))
|
||||
|| ((M < BLIS_SMALL_M_RECT_MATRIX_THRES) && (K < BLIS_SMALL_K_RECT_MATRIX_THRES))) && ((L!=0) && (K!=0)))
|
||||
{
|
||||
@@ -187,12 +186,13 @@ static err_t bli_sgemm_small
|
||||
float *C = c->buffer; // pointer to elements of Matrix C
|
||||
|
||||
float *tA = A, *tB = B, *tC = C;//, *tA_pack;
|
||||
float *tA_packed; // temprorary pointer to hold packed A memory pointer
|
||||
float *tA_packed; // temporary pointer to hold packed A memory pointer
|
||||
guint_t row_idx_packed; //packed A memory row index
|
||||
guint_t lda_packed; //lda of packed A
|
||||
guint_t col_idx_start; //starting index after A matrix is packed.
|
||||
dim_t tb_inc_row = 1; // row stride of matrix B
|
||||
dim_t tb_inc_col = ldb; // column stride of matrix B
|
||||
|
||||
__m256 ymm4, ymm5, ymm6, ymm7;
|
||||
__m256 ymm8, ymm9, ymm10, ymm11;
|
||||
__m256 ymm12, ymm13, ymm14, ymm15;
|
||||
@@ -208,7 +208,13 @@ static err_t bli_sgemm_small
|
||||
mem_t local_mem_buf_A_s;
|
||||
float *A_pack = NULL;
|
||||
rntm_t rntm;
|
||||
|
||||
|
||||
/*Beta Zero Check*/
|
||||
guint_t isbetanonzero=0;
|
||||
if ( !bli_obj_equals( beta, &BLIS_ZERO ) ){
|
||||
isbetanonzero = 1;
|
||||
}
|
||||
|
||||
// when N is equal to 1 call GEMV instead of GEMM
|
||||
if (N == 1)
|
||||
{
|
||||
@@ -224,8 +230,7 @@ static err_t bli_sgemm_small
|
||||
}
|
||||
|
||||
//update the pointer math if matrix B needs to be transposed.
|
||||
if (bli_obj_has_trans( b ))
|
||||
{
|
||||
if (bli_obj_has_trans( b )) {
|
||||
tb_inc_col = 1; //switch row and column strides
|
||||
tb_inc_row = ldb;
|
||||
}
|
||||
@@ -252,17 +257,15 @@ static err_t bli_sgemm_small
|
||||
bli_rntm_set_num_threads_only( 1, &rntm );
|
||||
bli_membrk_rntm_set_membrk( &rntm );
|
||||
|
||||
|
||||
// Get the current size of the buffer pool for A block packing.
|
||||
// We will use the same size to avoid pool re-initliazaton
|
||||
siz_t buffer_size = bli_pool_block_size(
|
||||
bli_membrk_pool(bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK),
|
||||
bli_rntm_membrk(&rntm)));
|
||||
// We will use the same size to avoid pool re-initialization
|
||||
siz_t buffer_size = bli_pool_block_size(bli_membrk_pool(bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK),
|
||||
bli_rntm_membrk(&rntm)));
|
||||
|
||||
// Based on the available memory in the buffer we will decide if
|
||||
// we want to do packing or not.
|
||||
//
|
||||
// This kernel assumes that "A" will be unpackged if N <= 3.
|
||||
// This kernel assumes that "A" will be un-packged if N <= 3.
|
||||
// Usually this range (N <= 3) is handled by SUP, however,
|
||||
// if SUP is disabled or for any other condition if we do
|
||||
// enter this kernel with N <= 3, we want to make sure that
|
||||
@@ -299,7 +302,6 @@ static err_t bli_sgemm_small
|
||||
// Process MR rows of C matrix at a time.
|
||||
for (row_idx = 0; (row_idx + (MR - 1)) < M; row_idx += MR)
|
||||
{
|
||||
|
||||
col_idx_start = 0;
|
||||
tA_packed = A;
|
||||
row_idx_packed = row_idx;
|
||||
@@ -394,7 +396,6 @@ static err_t bli_sgemm_small
|
||||
}
|
||||
// alpha, beta multiplication.
|
||||
ymm0 = _mm256_broadcast_ss(alpha_cast);
|
||||
ymm1 = _mm256_broadcast_ss(beta_cast);
|
||||
|
||||
//multiply A*B by alpha.
|
||||
ymm4 = _mm256_mul_ps(ymm4, ymm0);
|
||||
@@ -410,15 +411,39 @@ static err_t bli_sgemm_small
|
||||
ymm14 = _mm256_mul_ps(ymm14, ymm0);
|
||||
ymm15 = _mm256_mul_ps(ymm15, ymm0);
|
||||
|
||||
// multiply C by beta and accumulate col 1.
|
||||
ymm2 = _mm256_loadu_ps(tC);
|
||||
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
|
||||
ymm2 = _mm256_loadu_ps(tC + 8);
|
||||
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
|
||||
ymm2 = _mm256_loadu_ps(tC + 16);
|
||||
ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
|
||||
ymm2 = _mm256_loadu_ps(tC + 24);
|
||||
ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
|
||||
if(isbetanonzero)
|
||||
{
|
||||
ymm1 = _mm256_broadcast_ss(beta_cast);
|
||||
// multiply C by beta and accumulate col 1.
|
||||
ymm2 = _mm256_loadu_ps(tC);
|
||||
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
|
||||
ymm2 = _mm256_loadu_ps(tC + 8);
|
||||
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
|
||||
ymm2 = _mm256_loadu_ps(tC + 16);
|
||||
ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
|
||||
ymm2 = _mm256_loadu_ps(tC + 24);
|
||||
ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
|
||||
|
||||
float* ttC = tC +ldc;
|
||||
ymm2 = _mm256_loadu_ps(ttC);
|
||||
ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
|
||||
ymm2 = _mm256_loadu_ps(ttC + 8);
|
||||
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
|
||||
ymm2 = _mm256_loadu_ps(ttC + 16);
|
||||
ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
|
||||
ymm2 = _mm256_loadu_ps(ttC + 24);
|
||||
ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11);
|
||||
|
||||
ttC += ldc;
|
||||
ymm2 = _mm256_loadu_ps(ttC);
|
||||
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
|
||||
ymm2 = _mm256_loadu_ps(ttC + 8);
|
||||
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
|
||||
ymm2 = _mm256_loadu_ps(ttC + 16);
|
||||
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
|
||||
ymm2 = _mm256_loadu_ps(ttC + 24);
|
||||
ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);
|
||||
}
|
||||
_mm256_storeu_ps(tC, ymm4);
|
||||
_mm256_storeu_ps(tC + 8, ymm5);
|
||||
_mm256_storeu_ps(tC + 16, ymm6);
|
||||
@@ -426,14 +451,6 @@ static err_t bli_sgemm_small
|
||||
|
||||
// multiply C by beta and accumulate, col 2.
|
||||
tC += ldc;
|
||||
ymm2 = _mm256_loadu_ps(tC);
|
||||
ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
|
||||
ymm2 = _mm256_loadu_ps(tC + 8);
|
||||
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
|
||||
ymm2 = _mm256_loadu_ps(tC + 16);
|
||||
ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
|
||||
ymm2 = _mm256_loadu_ps(tC + 24);
|
||||
ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11);
|
||||
_mm256_storeu_ps(tC, ymm8);
|
||||
_mm256_storeu_ps(tC + 8, ymm9);
|
||||
_mm256_storeu_ps(tC + 16, ymm10);
|
||||
@@ -441,14 +458,6 @@ static err_t bli_sgemm_small
|
||||
|
||||
// multiply C by beta and accumulate, col 3.
|
||||
tC += ldc;
|
||||
ymm2 = _mm256_loadu_ps(tC);
|
||||
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
|
||||
ymm2 = _mm256_loadu_ps(tC + 8);
|
||||
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
|
||||
ymm2 = _mm256_loadu_ps(tC + 16);
|
||||
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
|
||||
ymm2 = _mm256_loadu_ps(tC + 24);
|
||||
ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);
|
||||
_mm256_storeu_ps(tC, ymm12);
|
||||
_mm256_storeu_ps(tC + 8, ymm13);
|
||||
_mm256_storeu_ps(tC + 16, ymm14);
|
||||
@@ -538,7 +547,6 @@ static err_t bli_sgemm_small
|
||||
}
|
||||
// alpha, beta multiplication.
|
||||
ymm0 = _mm256_broadcast_ss(alpha_cast);
|
||||
ymm1 = _mm256_broadcast_ss(beta_cast);
|
||||
|
||||
//multiply A*B by alpha.
|
||||
ymm4 = _mm256_mul_ps(ymm4, ymm0);
|
||||
@@ -554,15 +562,37 @@ static err_t bli_sgemm_small
|
||||
ymm14 = _mm256_mul_ps(ymm14, ymm0);
|
||||
ymm15 = _mm256_mul_ps(ymm15, ymm0);
|
||||
|
||||
// multiply C by beta and accumulate col 1.
|
||||
ymm2 = _mm256_loadu_ps(tC);
|
||||
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
|
||||
ymm2 = _mm256_loadu_ps(tC + 8);
|
||||
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
|
||||
ymm2 = _mm256_loadu_ps(tC + 16);
|
||||
ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
|
||||
ymm2 = _mm256_loadu_ps(tC + 24);
|
||||
ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
|
||||
if(isbetanonzero)
|
||||
{
|
||||
ymm1 = _mm256_broadcast_ss(beta_cast);
|
||||
// multiply C by beta and accumulate col 1.
|
||||
ymm2 = _mm256_loadu_ps(tC);
|
||||
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
|
||||
ymm2 = _mm256_loadu_ps(tC + 8);
|
||||
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
|
||||
ymm2 = _mm256_loadu_ps(tC + 16);
|
||||
ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
|
||||
ymm2 = _mm256_loadu_ps(tC + 24);
|
||||
ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
|
||||
float* ttC = tC +ldc;
|
||||
ymm2 = _mm256_loadu_ps(ttC);
|
||||
ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
|
||||
ymm2 = _mm256_loadu_ps(ttC + 8);
|
||||
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
|
||||
ymm2 = _mm256_loadu_ps(ttC + 16);
|
||||
ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
|
||||
ymm2 = _mm256_loadu_ps(ttC + 24);
|
||||
ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11);
|
||||
ttC = ttC +ldc;
|
||||
ymm2 = _mm256_loadu_ps(ttC);
|
||||
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
|
||||
ymm2 = _mm256_loadu_ps(ttC + 8);
|
||||
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
|
||||
ymm2 = _mm256_loadu_ps(ttC + 16);
|
||||
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
|
||||
ymm2 = _mm256_loadu_ps(ttC + 24);
|
||||
ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);
|
||||
}
|
||||
_mm256_storeu_ps(tC, ymm4);
|
||||
_mm256_storeu_ps(tC + 8, ymm5);
|
||||
_mm256_storeu_ps(tC + 16, ymm6);
|
||||
@@ -570,14 +600,6 @@ static err_t bli_sgemm_small
|
||||
|
||||
// multiply C by beta and accumulate, col 2.
|
||||
tC += ldc;
|
||||
ymm2 = _mm256_loadu_ps(tC);
|
||||
ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
|
||||
ymm2 = _mm256_loadu_ps(tC + 8);
|
||||
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
|
||||
ymm2 = _mm256_loadu_ps(tC + 16);
|
||||
ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
|
||||
ymm2 = _mm256_loadu_ps(tC + 24);
|
||||
ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11);
|
||||
_mm256_storeu_ps(tC, ymm8);
|
||||
_mm256_storeu_ps(tC + 8, ymm9);
|
||||
_mm256_storeu_ps(tC + 16, ymm10);
|
||||
@@ -585,14 +607,6 @@ static err_t bli_sgemm_small
|
||||
|
||||
// multiply C by beta and accumulate, col 3.
|
||||
tC += ldc;
|
||||
ymm2 = _mm256_loadu_ps(tC);
|
||||
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
|
||||
ymm2 = _mm256_loadu_ps(tC + 8);
|
||||
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
|
||||
ymm2 = _mm256_loadu_ps(tC + 16);
|
||||
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
|
||||
ymm2 = _mm256_loadu_ps(tC + 24);
|
||||
ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);
|
||||
_mm256_storeu_ps(tC, ymm12);
|
||||
_mm256_storeu_ps(tC + 8, ymm13);
|
||||
_mm256_storeu_ps(tC + 16, ymm14);
|
||||
@@ -651,7 +665,6 @@ static err_t bli_sgemm_small
|
||||
}
|
||||
// alpha, beta multiplication.
|
||||
ymm0 = _mm256_broadcast_ss(alpha_cast);
|
||||
ymm1 = _mm256_broadcast_ss(beta_cast);
|
||||
|
||||
//multiply A*B by alpha.
|
||||
ymm8 = _mm256_mul_ps(ymm8, ymm0);
|
||||
@@ -664,29 +677,34 @@ static err_t bli_sgemm_small
|
||||
ymm15 = _mm256_mul_ps(ymm15, ymm0);
|
||||
|
||||
// multiply C by beta and accumulate, col 1.
|
||||
ymm2 = _mm256_loadu_ps(tC + 0);
|
||||
ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
|
||||
ymm2 = _mm256_loadu_ps(tC + 8);
|
||||
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
|
||||
ymm2 = _mm256_loadu_ps(tC + 16);
|
||||
ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
|
||||
ymm2 = _mm256_loadu_ps(tC + 24);
|
||||
ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11);
|
||||
_mm256_storeu_ps(tC + 0, ymm8);
|
||||
if(isbetanonzero)
|
||||
{
|
||||
ymm1 = _mm256_broadcast_ss(beta_cast);
|
||||
ymm2 = _mm256_loadu_ps(tC);
|
||||
ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
|
||||
ymm2 = _mm256_loadu_ps(tC + 8);
|
||||
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
|
||||
ymm2 = _mm256_loadu_ps(tC + 16);
|
||||
ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
|
||||
ymm2 = _mm256_loadu_ps(tC + 24);
|
||||
ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11);
|
||||
|
||||
float* ttC = tC +ldc;
|
||||
// multiply C by beta and accumulate, col 2.
|
||||
ymm2 = _mm256_loadu_ps(ttC);
|
||||
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
|
||||
ymm2 = _mm256_loadu_ps(ttC + 8);
|
||||
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
|
||||
ymm2 = _mm256_loadu_ps(ttC + 16);
|
||||
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
|
||||
ymm2 = _mm256_loadu_ps(ttC + 24);
|
||||
ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);
|
||||
}
|
||||
_mm256_storeu_ps(tC, ymm8);
|
||||
_mm256_storeu_ps(tC + 8, ymm9);
|
||||
_mm256_storeu_ps(tC + 16, ymm10);
|
||||
_mm256_storeu_ps(tC + 24, ymm11);
|
||||
|
||||
// multiply C by beta and accumulate, col 2.
|
||||
tC += ldc;
|
||||
ymm2 = _mm256_loadu_ps(tC);
|
||||
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
|
||||
ymm2 = _mm256_loadu_ps(tC + 8);
|
||||
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
|
||||
ymm2 = _mm256_loadu_ps(tC + 16);
|
||||
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
|
||||
ymm2 = _mm256_loadu_ps(tC + 24);
|
||||
ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);
|
||||
_mm256_storeu_ps(tC, ymm12);
|
||||
_mm256_storeu_ps(tC + 8, ymm13);
|
||||
_mm256_storeu_ps(tC + 16, ymm14);
|
||||
@@ -735,7 +753,6 @@ static err_t bli_sgemm_small
|
||||
}
|
||||
// alpha, beta multiplication.
|
||||
ymm0 = _mm256_broadcast_ss(alpha_cast);
|
||||
ymm1 = _mm256_broadcast_ss(beta_cast);
|
||||
|
||||
//multiply A*B by alpha.
|
||||
ymm12 = _mm256_mul_ps(ymm12, ymm0);
|
||||
@@ -743,15 +760,19 @@ static err_t bli_sgemm_small
|
||||
ymm14 = _mm256_mul_ps(ymm14, ymm0);
|
||||
ymm15 = _mm256_mul_ps(ymm15, ymm0);
|
||||
|
||||
// multiply C by beta and accumulate.
|
||||
ymm2 = _mm256_loadu_ps(tC + 0);
|
||||
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
|
||||
ymm2 = _mm256_loadu_ps(tC + 8);
|
||||
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
|
||||
ymm2 = _mm256_loadu_ps(tC + 16);
|
||||
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
|
||||
ymm2 = _mm256_loadu_ps(tC + 24);
|
||||
ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);
|
||||
if(isbetanonzero)
|
||||
{
|
||||
ymm1 = _mm256_broadcast_ss(beta_cast);
|
||||
// multiply C by beta and accumulate.
|
||||
ymm2 = _mm256_loadu_ps(tC + 0);
|
||||
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
|
||||
ymm2 = _mm256_loadu_ps(tC + 8);
|
||||
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
|
||||
ymm2 = _mm256_loadu_ps(tC + 16);
|
||||
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
|
||||
ymm2 = _mm256_loadu_ps(tC + 24);
|
||||
ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);
|
||||
}
|
||||
|
||||
_mm256_storeu_ps(tC + 0, ymm12);
|
||||
_mm256_storeu_ps(tC + 8, ymm13);
|
||||
@@ -823,7 +844,6 @@ static err_t bli_sgemm_small
|
||||
}
|
||||
// alpha, beta multiplication.
|
||||
ymm0 = _mm256_broadcast_ss(alpha_cast);
|
||||
ymm1 = _mm256_broadcast_ss(beta_cast);
|
||||
|
||||
//multiply A*B by alpha.
|
||||
ymm4 = _mm256_mul_ps(ymm4, ymm0);
|
||||
@@ -836,37 +856,43 @@ static err_t bli_sgemm_small
|
||||
ymm13 = _mm256_mul_ps(ymm13, ymm0);
|
||||
ymm14 = _mm256_mul_ps(ymm14, ymm0);
|
||||
|
||||
// multiply C by beta and accumulate.
|
||||
ymm2 = _mm256_loadu_ps(tC);
|
||||
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
|
||||
ymm2 = _mm256_loadu_ps(tC + 8);
|
||||
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
|
||||
ymm2 = _mm256_loadu_ps(tC + 16);
|
||||
ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
|
||||
if(isbetanonzero)
|
||||
{
|
||||
ymm1 = _mm256_broadcast_ss(beta_cast);
|
||||
// multiply C by beta and accumulate.
|
||||
ymm2 = _mm256_loadu_ps(tC);
|
||||
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
|
||||
ymm2 = _mm256_loadu_ps(tC + 8);
|
||||
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
|
||||
ymm2 = _mm256_loadu_ps(tC + 16);
|
||||
ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
|
||||
float* ttC = tC +ldc;
|
||||
ymm2 = _mm256_loadu_ps(ttC);
|
||||
ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
|
||||
ymm2 = _mm256_loadu_ps(ttC + 8);
|
||||
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
|
||||
ymm2 = _mm256_loadu_ps(ttC + 16);
|
||||
ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
|
||||
ttC += ldc;
|
||||
ymm2 = _mm256_loadu_ps(ttC);
|
||||
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
|
||||
ymm2 = _mm256_loadu_ps(ttC + 8);
|
||||
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
|
||||
ymm2 = _mm256_loadu_ps(ttC + 16);
|
||||
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
|
||||
}
|
||||
_mm256_storeu_ps(tC, ymm4);
|
||||
_mm256_storeu_ps(tC + 8, ymm5);
|
||||
_mm256_storeu_ps(tC + 16, ymm6);
|
||||
|
||||
// multiply C by beta and accumulate.
|
||||
tC += ldc;
|
||||
ymm2 = _mm256_loadu_ps(tC);
|
||||
ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
|
||||
ymm2 = _mm256_loadu_ps(tC + 8);
|
||||
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
|
||||
ymm2 = _mm256_loadu_ps(tC + 16);
|
||||
ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
|
||||
_mm256_storeu_ps(tC, ymm8);
|
||||
_mm256_storeu_ps(tC + 8, ymm9);
|
||||
_mm256_storeu_ps(tC + 16, ymm10);
|
||||
|
||||
// multiply C by beta and accumulate.
|
||||
tC += ldc;
|
||||
ymm2 = _mm256_loadu_ps(tC);
|
||||
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
|
||||
ymm2 = _mm256_loadu_ps(tC + 8);
|
||||
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
|
||||
ymm2 = _mm256_loadu_ps(tC + 16);
|
||||
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
|
||||
_mm256_storeu_ps(tC, ymm12);
|
||||
_mm256_storeu_ps(tC + 8, ymm13);
|
||||
_mm256_storeu_ps(tC + 16, ymm14);
|
||||
@@ -917,7 +943,6 @@ static err_t bli_sgemm_small
|
||||
}
|
||||
// alpha, beta multiplication.
|
||||
ymm0 = _mm256_broadcast_ss(alpha_cast);
|
||||
ymm1 = _mm256_broadcast_ss(beta_cast);
|
||||
|
||||
//multiply A*B by alpha.
|
||||
ymm8 = _mm256_mul_ps(ymm8, ymm0);
|
||||
@@ -927,25 +952,33 @@ static err_t bli_sgemm_small
|
||||
ymm13 = _mm256_mul_ps(ymm13, ymm0);
|
||||
ymm14 = _mm256_mul_ps(ymm14, ymm0);
|
||||
|
||||
// multiply C by beta and accumulate.
|
||||
ymm2 = _mm256_loadu_ps(tC + 0);
|
||||
ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
|
||||
ymm2 = _mm256_loadu_ps(tC + 8);
|
||||
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
|
||||
ymm2 = _mm256_loadu_ps(tC + 16);
|
||||
ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
|
||||
_mm256_storeu_ps(tC + 0, ymm8);
|
||||
if(isbetanonzero)
|
||||
{
|
||||
ymm1 = _mm256_broadcast_ss(beta_cast);
|
||||
// multiply C by beta and accumulate.
|
||||
ymm2 = _mm256_loadu_ps(tC);
|
||||
ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
|
||||
ymm2 = _mm256_loadu_ps(tC + 8);
|
||||
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
|
||||
ymm2 = _mm256_loadu_ps(tC + 16);
|
||||
ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
|
||||
|
||||
float* ttC = tC +ldc;
|
||||
// multiply C by beta and accumulate.
|
||||
ymm2 = _mm256_loadu_ps(ttC);
|
||||
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
|
||||
ymm2 = _mm256_loadu_ps(ttC + 8);
|
||||
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
|
||||
ymm2 = _mm256_loadu_ps(ttC + 16);
|
||||
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
|
||||
}
|
||||
|
||||
_mm256_storeu_ps(tC, ymm8);
|
||||
_mm256_storeu_ps(tC + 8, ymm9);
|
||||
_mm256_storeu_ps(tC + 16, ymm10);
|
||||
|
||||
// multiply C by beta and accumulate.
|
||||
tC += ldc;
|
||||
ymm2 = _mm256_loadu_ps(tC);
|
||||
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
|
||||
ymm2 = _mm256_loadu_ps(tC + 8);
|
||||
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
|
||||
ymm2 = _mm256_loadu_ps(tC + 16);
|
||||
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
|
||||
|
||||
_mm256_storeu_ps(tC, ymm12);
|
||||
_mm256_storeu_ps(tC + 8, ymm13);
|
||||
_mm256_storeu_ps(tC + 16, ymm14);
|
||||
@@ -989,13 +1022,15 @@ static err_t bli_sgemm_small
|
||||
}
|
||||
// alpha, beta multiplication.
|
||||
ymm0 = _mm256_broadcast_ss(alpha_cast);
|
||||
ymm1 = _mm256_broadcast_ss(beta_cast);
|
||||
|
||||
//multiply A*B by alpha.
|
||||
ymm12 = _mm256_mul_ps(ymm12, ymm0);
|
||||
ymm13 = _mm256_mul_ps(ymm13, ymm0);
|
||||
ymm14 = _mm256_mul_ps(ymm14, ymm0);
|
||||
|
||||
if(isbetanonzero)
|
||||
{
|
||||
ymm1 = _mm256_broadcast_ss(beta_cast);
|
||||
// multiply C by beta and accumulate.
|
||||
ymm2 = _mm256_loadu_ps(tC + 0);
|
||||
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
|
||||
@@ -1003,7 +1038,7 @@ static err_t bli_sgemm_small
|
||||
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
|
||||
ymm2 = _mm256_loadu_ps(tC + 16);
|
||||
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
|
||||
|
||||
}
|
||||
_mm256_storeu_ps(tC + 0, ymm12);
|
||||
_mm256_storeu_ps(tC + 8, ymm13);
|
||||
_mm256_storeu_ps(tC + 16, ymm14);
|
||||
@@ -1056,7 +1091,6 @@ static err_t bli_sgemm_small
|
||||
}
|
||||
// alpha, beta multiplication.
|
||||
ymm0 = _mm256_broadcast_ss(alpha_cast);
|
||||
ymm1 = _mm256_broadcast_ss(beta_cast);
|
||||
|
||||
//multiply A*B by alpha.
|
||||
ymm4 = _mm256_mul_ps(ymm4, ymm0);
|
||||
@@ -1066,29 +1100,35 @@ static err_t bli_sgemm_small
|
||||
ymm8 = _mm256_mul_ps(ymm8, ymm0);
|
||||
ymm9 = _mm256_mul_ps(ymm9, ymm0);
|
||||
|
||||
// multiply C by beta and accumulate.
|
||||
ymm2 = _mm256_loadu_ps(tC);
|
||||
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
|
||||
ymm2 = _mm256_loadu_ps(tC + 8);
|
||||
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
|
||||
if(isbetanonzero)
|
||||
{
|
||||
ymm1 = _mm256_broadcast_ss(beta_cast);
|
||||
// multiply C by beta and accumulate.
|
||||
ymm2 = _mm256_loadu_ps(tC);
|
||||
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
|
||||
ymm2 = _mm256_loadu_ps(tC + 8);
|
||||
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
|
||||
float* ttC = tC + ldc;
|
||||
ymm2 = _mm256_loadu_ps(ttC);
|
||||
ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
|
||||
ymm2 = _mm256_loadu_ps(ttC + 8);
|
||||
ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
|
||||
ttC += ldc;
|
||||
ymm2 = _mm256_loadu_ps(ttC);
|
||||
ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
|
||||
ymm2 = _mm256_loadu_ps(ttC + 8);
|
||||
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
|
||||
}
|
||||
_mm256_storeu_ps(tC, ymm4);
|
||||
_mm256_storeu_ps(tC + 8, ymm5);
|
||||
|
||||
// multiply C by beta and accumulate.
|
||||
tC += ldc;
|
||||
ymm2 = _mm256_loadu_ps(tC);
|
||||
ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
|
||||
ymm2 = _mm256_loadu_ps(tC + 8);
|
||||
ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
|
||||
_mm256_storeu_ps(tC, ymm6);
|
||||
_mm256_storeu_ps(tC + 8, ymm7);
|
||||
|
||||
// multiply C by beta and accumulate.
|
||||
tC += ldc;
|
||||
ymm2 = _mm256_loadu_ps(tC);
|
||||
ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
|
||||
ymm2 = _mm256_loadu_ps(tC + 8);
|
||||
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
|
||||
_mm256_storeu_ps(tC, ymm8);
|
||||
_mm256_storeu_ps(tC + 8, ymm9);
|
||||
|
||||
@@ -1131,7 +1171,6 @@ static err_t bli_sgemm_small
|
||||
}
|
||||
// alpha, beta multiplication.
|
||||
ymm0 = _mm256_broadcast_ss(alpha_cast);
|
||||
ymm1 = _mm256_broadcast_ss(beta_cast);
|
||||
|
||||
//multiply A*B by alpha.
|
||||
ymm4 = _mm256_mul_ps(ymm4, ymm0);
|
||||
@@ -1139,20 +1178,25 @@ static err_t bli_sgemm_small
|
||||
ymm6 = _mm256_mul_ps(ymm6, ymm0);
|
||||
ymm7 = _mm256_mul_ps(ymm7, ymm0);
|
||||
|
||||
// multiply C by beta and accumulate.
|
||||
ymm2 = _mm256_loadu_ps(tC);
|
||||
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
|
||||
ymm2 = _mm256_loadu_ps(tC + 8);
|
||||
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
|
||||
if(isbetanonzero)
|
||||
{
|
||||
ymm1 = _mm256_broadcast_ss(beta_cast);
|
||||
// multiply C by beta and accumulate.
|
||||
ymm2 = _mm256_loadu_ps(tC);
|
||||
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
|
||||
ymm2 = _mm256_loadu_ps(tC + 8);
|
||||
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
|
||||
float* ttC = tC + ldc;
|
||||
ymm2 = _mm256_loadu_ps(ttC);
|
||||
ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
|
||||
ymm2 = _mm256_loadu_ps(ttC + 8);
|
||||
ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
|
||||
}
|
||||
_mm256_storeu_ps(tC, ymm4);
|
||||
_mm256_storeu_ps(tC + 8, ymm5);
|
||||
|
||||
// multiply C by beta and accumulate.
|
||||
tC += ldc;
|
||||
ymm2 = _mm256_loadu_ps(tC);
|
||||
ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
|
||||
ymm2 = _mm256_loadu_ps(tC + 8);
|
||||
ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
|
||||
_mm256_storeu_ps(tC, ymm6);
|
||||
_mm256_storeu_ps(tC + 8, ymm7);
|
||||
|
||||
@@ -1190,16 +1234,19 @@ static err_t bli_sgemm_small
|
||||
}
|
||||
// alpha, beta multiplication.
|
||||
ymm0 = _mm256_broadcast_ss(alpha_cast);
|
||||
ymm1 = _mm256_broadcast_ss(beta_cast);
|
||||
|
||||
ymm4 = _mm256_mul_ps(ymm4, ymm0);
|
||||
ymm5 = _mm256_mul_ps(ymm5, ymm0);
|
||||
|
||||
// multiply C by beta and accumulate.
|
||||
ymm2 = _mm256_loadu_ps(tC);
|
||||
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
|
||||
ymm2 = _mm256_loadu_ps(tC + 8);
|
||||
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
|
||||
if(isbetanonzero)
|
||||
{
|
||||
ymm1 = _mm256_broadcast_ss(beta_cast);
|
||||
ymm2 = _mm256_loadu_ps(tC);
|
||||
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
|
||||
ymm2 = _mm256_loadu_ps(tC + 8);
|
||||
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
|
||||
}
|
||||
_mm256_storeu_ps(tC, ymm4);
|
||||
_mm256_storeu_ps(tC + 8, ymm5);
|
||||
|
||||
@@ -1244,28 +1291,30 @@ static err_t bli_sgemm_small
|
||||
}
|
||||
// alpha, beta multiplication.
|
||||
ymm0 = _mm256_broadcast_ss(alpha_cast);
|
||||
ymm1 = _mm256_broadcast_ss(beta_cast);
|
||||
|
||||
//multiply A*B by alpha.
|
||||
ymm4 = _mm256_mul_ps(ymm4, ymm0);
|
||||
ymm5 = _mm256_mul_ps(ymm5, ymm0);
|
||||
ymm6 = _mm256_mul_ps(ymm6, ymm0);
|
||||
|
||||
// multiply C by beta and accumulate.
|
||||
ymm2 = _mm256_loadu_ps(tC);
|
||||
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
|
||||
if(isbetanonzero)
|
||||
{
|
||||
ymm1 = _mm256_broadcast_ss(beta_cast);
|
||||
ymm2 = _mm256_loadu_ps(tC);
|
||||
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
|
||||
ymm2 = _mm256_loadu_ps(tC + ldc);
|
||||
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
|
||||
ymm2 = _mm256_loadu_ps(tC + 2*ldc);
|
||||
ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
|
||||
}
|
||||
_mm256_storeu_ps(tC, ymm4);
|
||||
|
||||
// multiply C by beta and accumulate.
|
||||
tC += ldc;
|
||||
ymm2 = _mm256_loadu_ps(tC);
|
||||
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
|
||||
_mm256_storeu_ps(tC, ymm5);
|
||||
|
||||
// multiply C by beta and accumulate.
|
||||
tC += ldc;
|
||||
ymm2 = _mm256_loadu_ps(tC);
|
||||
ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
|
||||
_mm256_storeu_ps(tC, ymm6);
|
||||
}
|
||||
n_remainder = N - col_idx;
|
||||
@@ -1299,21 +1348,23 @@ static err_t bli_sgemm_small
|
||||
}
|
||||
// alpha, beta multiplication.
|
||||
ymm0 = _mm256_broadcast_ss(alpha_cast);
|
||||
ymm1 = _mm256_broadcast_ss(beta_cast);
|
||||
|
||||
//multiply A*B by alpha.
|
||||
ymm4 = _mm256_mul_ps(ymm4, ymm0);
|
||||
ymm5 = _mm256_mul_ps(ymm5, ymm0);
|
||||
|
||||
// multiply C by beta and accumulate.
|
||||
ymm2 = _mm256_loadu_ps(tC);
|
||||
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
|
||||
if(isbetanonzero)
|
||||
{
|
||||
ymm1 = _mm256_broadcast_ss(beta_cast);
|
||||
// multiply C by beta and accumulate.
|
||||
ymm2 = _mm256_loadu_ps(tC);
|
||||
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
|
||||
ymm2 = _mm256_loadu_ps(tC + ldc);
|
||||
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
|
||||
}
|
||||
_mm256_storeu_ps(tC, ymm4);
|
||||
|
||||
// multiply C by beta and accumulate.
|
||||
tC += ldc;
|
||||
ymm2 = _mm256_loadu_ps(tC);
|
||||
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
|
||||
_mm256_storeu_ps(tC, ymm5);
|
||||
|
||||
col_idx += 2;
|
||||
@@ -1346,13 +1397,15 @@ static err_t bli_sgemm_small
|
||||
}
|
||||
// alpha, beta multiplication.
|
||||
ymm0 = _mm256_broadcast_ss(alpha_cast);
|
||||
ymm1 = _mm256_broadcast_ss(beta_cast);
|
||||
|
||||
ymm4 = _mm256_mul_ps(ymm4, ymm0);
|
||||
|
||||
// multiply C by beta and accumulate.
|
||||
ymm2 = _mm256_loadu_ps(tC);
|
||||
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
|
||||
if(isbetanonzero)
|
||||
{
|
||||
ymm1 = _mm256_broadcast_ss(beta_cast);
|
||||
// multiply C by beta and accumulate.
|
||||
ymm2 = _mm256_loadu_ps(tC);
|
||||
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
|
||||
}
|
||||
_mm256_storeu_ps(tC, ymm4);
|
||||
|
||||
}
|
||||
@@ -1426,7 +1479,8 @@ static err_t bli_sgemm_small
|
||||
f_temp[i] = tC[i];
|
||||
}
|
||||
ymm2 = _mm256_loadu_ps(f_temp);
|
||||
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
|
||||
if(isbetanonzero)
|
||||
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
|
||||
_mm256_storeu_ps(f_temp, ymm5);
|
||||
for (int i = 0; i < m_remainder; i++)
|
||||
{
|
||||
@@ -1439,7 +1493,8 @@ static err_t bli_sgemm_small
|
||||
f_temp[i] = tC[i];
|
||||
}
|
||||
ymm2 = _mm256_loadu_ps(f_temp);
|
||||
ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
|
||||
if(isbetanonzero)
|
||||
ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
|
||||
_mm256_storeu_ps(f_temp, ymm7);
|
||||
for (int i = 0; i < m_remainder; i++)
|
||||
{
|
||||
@@ -1452,7 +1507,8 @@ static err_t bli_sgemm_small
|
||||
f_temp[i] = tC[i];
|
||||
}
|
||||
ymm2 = _mm256_loadu_ps(f_temp);
|
||||
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
|
||||
if(isbetanonzero)
|
||||
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
|
||||
_mm256_storeu_ps(f_temp, ymm9);
|
||||
for (int i = 0; i < m_remainder; i++)
|
||||
{
|
||||
@@ -1510,7 +1566,8 @@ static err_t bli_sgemm_small
|
||||
f_temp[i] = tC[i];
|
||||
}
|
||||
ymm2 = _mm256_loadu_ps(f_temp);
|
||||
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
|
||||
if(isbetanonzero)
|
||||
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
|
||||
_mm256_storeu_ps(f_temp, ymm5);
|
||||
for (int i = 0; i < m_remainder; i++)
|
||||
{
|
||||
@@ -1523,7 +1580,8 @@ static err_t bli_sgemm_small
|
||||
f_temp[i] = tC[i];
|
||||
}
|
||||
ymm2 = _mm256_loadu_ps(f_temp);
|
||||
ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
|
||||
if(isbetanonzero)
|
||||
ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
|
||||
_mm256_storeu_ps(f_temp, ymm7);
|
||||
for (int i = 0; i < m_remainder; i++)
|
||||
{
|
||||
@@ -1565,7 +1623,6 @@ static err_t bli_sgemm_small
|
||||
ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
|
||||
|
||||
ymm0 = _mm256_broadcast_ss(alpha_cast);
|
||||
ymm1 = _mm256_broadcast_ss(beta_cast);
|
||||
|
||||
// multiply C by beta and accumulate.
|
||||
ymm5 = _mm256_mul_ps(ymm5, ymm0);
|
||||
@@ -1575,7 +1632,10 @@ static err_t bli_sgemm_small
|
||||
f_temp[i] = tC[i];
|
||||
}
|
||||
ymm2 = _mm256_loadu_ps(f_temp);
|
||||
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
|
||||
if(isbetanonzero){
|
||||
ymm1 = _mm256_broadcast_ss(beta_cast);
|
||||
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
|
||||
}
|
||||
_mm256_storeu_ps(f_temp, ymm5);
|
||||
for (int i = 0; i < m_remainder; i++)
|
||||
{
|
||||
@@ -1606,7 +1666,11 @@ static err_t bli_sgemm_small
|
||||
}
|
||||
|
||||
result *= (*alpha_cast);
|
||||
(*tC) = (*tC) * (*beta_cast) + result;
|
||||
if(isbetanonzero){
|
||||
(*tC) = (*tC) * (*beta_cast) + result;
|
||||
}else{
|
||||
(*tC) = result;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3089,7 +3153,7 @@ static err_t bli_dgemm_small
|
||||
}
|
||||
|
||||
// Return the buffer to pool
|
||||
if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )) {
|
||||
if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )) {
|
||||
#ifdef BLIS_ENABLE_MEM_TRACING
|
||||
printf( "bli_dgemm_small(): releasing mem pool block\n" );
|
||||
#endif
|
||||
@@ -3139,6 +3203,13 @@ static err_t bli_sgemm_small_atbn
|
||||
alpha_cast = (alpha->buffer);
|
||||
beta_cast = (beta->buffer);
|
||||
|
||||
/*Beta Zero Check*/
|
||||
guint_t isbetanonzero=0;
|
||||
if ( !bli_obj_equals( beta, &BLIS_ZERO ) ){
|
||||
isbetanonzero = 1;
|
||||
}
|
||||
|
||||
|
||||
// The non-copy version of the A^T GEMM gives better performance for the small M cases.
|
||||
// The threshold is controlled by BLIS_ATBN_M_THRES
|
||||
if (M <= BLIS_ATBN_M_THRES)
|
||||
@@ -3250,28 +3321,44 @@ static err_t bli_sgemm_small_atbn
|
||||
_mm256_storeu_ps(scratch, ymm4);
|
||||
result = scratch[0] + scratch[4];
|
||||
result *= (*alpha_cast);
|
||||
if(isbetanonzero){
|
||||
tC[0] = result + tC[0] * (*beta_cast);
|
||||
}else{
|
||||
tC[0] = result;
|
||||
}
|
||||
|
||||
ymm7 = _mm256_hadd_ps(ymm7, ymm7);
|
||||
ymm7 = _mm256_hadd_ps(ymm7, ymm7);
|
||||
_mm256_storeu_ps(scratch, ymm7);
|
||||
result = scratch[0] + scratch[4];
|
||||
result *= (*alpha_cast);
|
||||
if(isbetanonzero){
|
||||
tC[1] = result + tC[1] * (*beta_cast);
|
||||
}else{
|
||||
tC[1] = result;
|
||||
}
|
||||
|
||||
ymm10 = _mm256_hadd_ps(ymm10, ymm10);
|
||||
ymm10 = _mm256_hadd_ps(ymm10, ymm10);
|
||||
_mm256_storeu_ps(scratch, ymm10);
|
||||
result = scratch[0] + scratch[4];
|
||||
result *= (*alpha_cast);
|
||||
if(isbetanonzero){
|
||||
tC[2] = result + tC[2] * (*beta_cast);
|
||||
}else{
|
||||
tC[2] = result;
|
||||
}
|
||||
|
||||
ymm13 = _mm256_hadd_ps(ymm13, ymm13);
|
||||
ymm13 = _mm256_hadd_ps(ymm13, ymm13);
|
||||
_mm256_storeu_ps(scratch, ymm13);
|
||||
result = scratch[0] + scratch[4];
|
||||
result *= (*alpha_cast);
|
||||
if(isbetanonzero){
|
||||
tC[3] = result + tC[3] * (*beta_cast);
|
||||
}else{
|
||||
tC[3] = result;
|
||||
}
|
||||
|
||||
tC += ldc;
|
||||
ymm5 = _mm256_hadd_ps(ymm5, ymm5);
|
||||
@@ -3279,28 +3366,44 @@ static err_t bli_sgemm_small_atbn
|
||||
_mm256_storeu_ps(scratch, ymm5);
|
||||
result = scratch[0] + scratch[4];
|
||||
result *= (*alpha_cast);
|
||||
if(isbetanonzero){
|
||||
tC[0] = result + tC[0] * (*beta_cast);
|
||||
}else{
|
||||
tC[0] = result;
|
||||
}
|
||||
|
||||
ymm8 = _mm256_hadd_ps(ymm8, ymm8);
|
||||
ymm8 = _mm256_hadd_ps(ymm8, ymm8);
|
||||
_mm256_storeu_ps(scratch, ymm8);
|
||||
result = scratch[0] + scratch[4];
|
||||
result *= (*alpha_cast);
|
||||
if(isbetanonzero){
|
||||
tC[1] = result + tC[1] * (*beta_cast);
|
||||
}else{
|
||||
tC[1] = result;
|
||||
}
|
||||
|
||||
ymm11 = _mm256_hadd_ps(ymm11, ymm11);
|
||||
ymm11 = _mm256_hadd_ps(ymm11, ymm11);
|
||||
_mm256_storeu_ps(scratch, ymm11);
|
||||
result = scratch[0] + scratch[4];
|
||||
result *= (*alpha_cast);
|
||||
if(isbetanonzero){
|
||||
tC[2] = result + tC[2] * (*beta_cast);
|
||||
}else{
|
||||
tC[2] = result;
|
||||
}
|
||||
|
||||
ymm14 = _mm256_hadd_ps(ymm14, ymm14);
|
||||
ymm14 = _mm256_hadd_ps(ymm14, ymm14);
|
||||
_mm256_storeu_ps(scratch, ymm14);
|
||||
result = scratch[0] + scratch[4];
|
||||
result *= (*alpha_cast);
|
||||
if(isbetanonzero){
|
||||
tC[3] = result + tC[3] * (*beta_cast);
|
||||
}else{
|
||||
tC[3] = result;
|
||||
}
|
||||
|
||||
tC += ldc;
|
||||
ymm6 = _mm256_hadd_ps(ymm6, ymm6);
|
||||
@@ -3308,28 +3411,44 @@ static err_t bli_sgemm_small_atbn
|
||||
_mm256_storeu_ps(scratch, ymm6);
|
||||
result = scratch[0] + scratch[4];
|
||||
result *= (*alpha_cast);
|
||||
if(isbetanonzero){
|
||||
tC[0] = result + tC[0] * (*beta_cast);
|
||||
}else{
|
||||
tC[0] = result;
|
||||
}
|
||||
|
||||
ymm9 = _mm256_hadd_ps(ymm9, ymm9);
|
||||
ymm9 = _mm256_hadd_ps(ymm9, ymm9);
|
||||
_mm256_storeu_ps(scratch, ymm9);
|
||||
result = scratch[0] + scratch[4];
|
||||
result *= (*alpha_cast);
|
||||
if(isbetanonzero){
|
||||
tC[1] = result + tC[1] * (*beta_cast);
|
||||
}else{
|
||||
tC[1] = result;
|
||||
}
|
||||
|
||||
ymm12 = _mm256_hadd_ps(ymm12, ymm12);
|
||||
ymm12 = _mm256_hadd_ps(ymm12, ymm12);
|
||||
_mm256_storeu_ps(scratch, ymm12);
|
||||
result = scratch[0] + scratch[4];
|
||||
result *= (*alpha_cast);
|
||||
if(isbetanonzero){
|
||||
tC[2] = result + tC[2] * (*beta_cast);
|
||||
}else{
|
||||
tC[2] = result;
|
||||
}
|
||||
|
||||
ymm15 = _mm256_hadd_ps(ymm15, ymm15);
|
||||
ymm15 = _mm256_hadd_ps(ymm15, ymm15);
|
||||
_mm256_storeu_ps(scratch, ymm15);
|
||||
result = scratch[0] + scratch[4];
|
||||
result *= (*alpha_cast);
|
||||
if(isbetanonzero){
|
||||
tC[3] = result + tC[3] * (*beta_cast);
|
||||
}else{
|
||||
tC[3] = result;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3413,29 +3532,44 @@ static err_t bli_sgemm_small_atbn
|
||||
_mm256_storeu_ps(scratch, ymm4);
|
||||
result = scratch[0] + scratch[4];
|
||||
result *= (*alpha_cast);
|
||||
if(isbetanonzero){
|
||||
tC[0] = result + tC[0] * (*beta_cast);
|
||||
}else{
|
||||
tC[0] = result;
|
||||
}
|
||||
|
||||
ymm7 = _mm256_hadd_ps(ymm7, ymm7);
|
||||
ymm7 = _mm256_hadd_ps(ymm7, ymm7);
|
||||
_mm256_storeu_ps(scratch, ymm7);
|
||||
result = scratch[0] + scratch[4];
|
||||
result *= (*alpha_cast);
|
||||
if(isbetanonzero){
|
||||
tC[1] = result + tC[1] * (*beta_cast);
|
||||
}else{
|
||||
tC[1] = result;
|
||||
}
|
||||
|
||||
ymm10 = _mm256_hadd_ps(ymm10, ymm10);
|
||||
ymm10 = _mm256_hadd_ps(ymm10, ymm10);
|
||||
_mm256_storeu_ps(scratch, ymm10);
|
||||
result = scratch[0] + scratch[4];
|
||||
result *= (*alpha_cast);
|
||||
if(isbetanonzero){
|
||||
tC[2] = result + tC[2] * (*beta_cast);
|
||||
}else{
|
||||
tC[2] = result;
|
||||
}
|
||||
|
||||
ymm13 = _mm256_hadd_ps(ymm13, ymm13);
|
||||
ymm13 = _mm256_hadd_ps(ymm13, ymm13);
|
||||
_mm256_storeu_ps(scratch, ymm13);
|
||||
result = scratch[0] + scratch[4];
|
||||
result *= (*alpha_cast);
|
||||
if(isbetanonzero){
|
||||
tC[3] = result + tC[3] * (*beta_cast);
|
||||
|
||||
}else{
|
||||
tC[3] = result;
|
||||
}
|
||||
}
|
||||
}
|
||||
processed_row = row_idx;
|
||||
@@ -3485,7 +3619,11 @@ static err_t bli_sgemm_small_atbn
|
||||
_mm256_storeu_ps(scratch, ymm4);
|
||||
result = scratch[0] + scratch[4];
|
||||
result *= (*alpha_cast);
|
||||
if(isbetanonzero){
|
||||
tC[0] = result + tC[0] * (*beta_cast);
|
||||
}else{
|
||||
tC[0] = result;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user