Beta Zero Checks for sgemm_small

Change-Id: I111b66ad54a27b1977d155904738a55a351e6689
This commit is contained in:
Nallani Bhaskar
2020-03-06 16:29:30 +05:30
committed by Nallani Bhaskar
parent cc98047fd6
commit e0c95d77e1

View File

@@ -174,7 +174,6 @@ static err_t bli_sgemm_small
gint_t K = bli_obj_width( a ); // number of columns of OP(A), will be updated if OP(A) is Transpose(A) .
gint_t L = M * N;
// printf("alpha_cast = %f beta_cast = %f [ Trans = %d %d], [stride = %d %d %d] [m,n,k = %d %d %d]\n",*alpha_cast,*beta_cast, bli_obj_has_trans( a ), bli_obj_has_trans( b ), lda, ldb,ldc, M,N,K);
if ((((L) < (BLIS_SMALL_MATRIX_THRES * BLIS_SMALL_MATRIX_THRES))
|| ((M < BLIS_SMALL_M_RECT_MATRIX_THRES) && (K < BLIS_SMALL_K_RECT_MATRIX_THRES))) && ((L!=0) && (K!=0)))
{
@@ -187,12 +186,13 @@ static err_t bli_sgemm_small
float *C = c->buffer; // pointer to elements of Matrix C
float *tA = A, *tB = B, *tC = C;//, *tA_pack;
float *tA_packed; // temprorary pointer to hold packed A memory pointer
float *tA_packed; // temporary pointer to hold packed A memory pointer
guint_t row_idx_packed; //packed A memory row index
guint_t lda_packed; //lda of packed A
guint_t col_idx_start; //starting index after A matrix is packed.
dim_t tb_inc_row = 1; // row stride of matrix B
dim_t tb_inc_col = ldb; // column stride of matrix B
__m256 ymm4, ymm5, ymm6, ymm7;
__m256 ymm8, ymm9, ymm10, ymm11;
__m256 ymm12, ymm13, ymm14, ymm15;
@@ -208,7 +208,13 @@ static err_t bli_sgemm_small
mem_t local_mem_buf_A_s;
float *A_pack = NULL;
rntm_t rntm;
/*Beta Zero Check*/
guint_t isbetanonzero=0;
if ( !bli_obj_equals( beta, &BLIS_ZERO ) ){
isbetanonzero = 1;
}
// when N is equal to 1 call GEMV instead of GEMM
if (N == 1)
{
@@ -224,8 +230,7 @@ static err_t bli_sgemm_small
}
//update the pointer math if matrix B needs to be transposed.
if (bli_obj_has_trans( b ))
{
if (bli_obj_has_trans( b )) {
tb_inc_col = 1; //switch row and column strides
tb_inc_row = ldb;
}
@@ -252,17 +257,15 @@ static err_t bli_sgemm_small
bli_rntm_set_num_threads_only( 1, &rntm );
bli_membrk_rntm_set_membrk( &rntm );
// Get the current size of the buffer pool for A block packing.
// We will use the same size to avoid pool re-initliazaton
siz_t buffer_size = bli_pool_block_size(
bli_membrk_pool(bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK),
bli_rntm_membrk(&rntm)));
// We will use the same size to avoid pool re-initialization
siz_t buffer_size = bli_pool_block_size(bli_membrk_pool(bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK),
bli_rntm_membrk(&rntm)));
// Based on the available memory in the buffer we will decide if
// we want to do packing or not.
//
// This kernel assumes that "A" will be unpackged if N <= 3.
// This kernel assumes that "A" will be un-packaged if N <= 3.
// Usually this range (N <= 3) is handled by SUP, however,
// if SUP is disabled or for any other condition if we do
// enter this kernel with N <= 3, we want to make sure that
@@ -299,7 +302,6 @@ static err_t bli_sgemm_small
// Process MR rows of C matrix at a time.
for (row_idx = 0; (row_idx + (MR - 1)) < M; row_idx += MR)
{
col_idx_start = 0;
tA_packed = A;
row_idx_packed = row_idx;
@@ -394,7 +396,6 @@ static err_t bli_sgemm_small
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_ss(alpha_cast);
ymm1 = _mm256_broadcast_ss(beta_cast);
//multiply A*B by alpha.
ymm4 = _mm256_mul_ps(ymm4, ymm0);
@@ -410,15 +411,39 @@ static err_t bli_sgemm_small
ymm14 = _mm256_mul_ps(ymm14, ymm0);
ymm15 = _mm256_mul_ps(ymm15, ymm0);
// multiply C by beta and accumulate col 1.
ymm2 = _mm256_loadu_ps(tC);
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
ymm2 = _mm256_loadu_ps(tC + 24);
ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
if(isbetanonzero)
{
ymm1 = _mm256_broadcast_ss(beta_cast);
// multiply C by beta and accumulate col 1.
ymm2 = _mm256_loadu_ps(tC);
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
ymm2 = _mm256_loadu_ps(tC + 24);
ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
float* ttC = tC +ldc;
ymm2 = _mm256_loadu_ps(ttC);
ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
ymm2 = _mm256_loadu_ps(ttC + 8);
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
ymm2 = _mm256_loadu_ps(ttC + 16);
ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
ymm2 = _mm256_loadu_ps(ttC + 24);
ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11);
ttC += ldc;
ymm2 = _mm256_loadu_ps(ttC);
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
ymm2 = _mm256_loadu_ps(ttC + 8);
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
ymm2 = _mm256_loadu_ps(ttC + 16);
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
ymm2 = _mm256_loadu_ps(ttC + 24);
ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);
}
_mm256_storeu_ps(tC, ymm4);
_mm256_storeu_ps(tC + 8, ymm5);
_mm256_storeu_ps(tC + 16, ymm6);
@@ -426,14 +451,6 @@ static err_t bli_sgemm_small
// multiply C by beta and accumulate, col 2.
tC += ldc;
ymm2 = _mm256_loadu_ps(tC);
ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
ymm2 = _mm256_loadu_ps(tC + 24);
ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11);
_mm256_storeu_ps(tC, ymm8);
_mm256_storeu_ps(tC + 8, ymm9);
_mm256_storeu_ps(tC + 16, ymm10);
@@ -441,14 +458,6 @@ static err_t bli_sgemm_small
// multiply C by beta and accumulate, col 3.
tC += ldc;
ymm2 = _mm256_loadu_ps(tC);
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
ymm2 = _mm256_loadu_ps(tC + 24);
ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);
_mm256_storeu_ps(tC, ymm12);
_mm256_storeu_ps(tC + 8, ymm13);
_mm256_storeu_ps(tC + 16, ymm14);
@@ -538,7 +547,6 @@ static err_t bli_sgemm_small
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_ss(alpha_cast);
ymm1 = _mm256_broadcast_ss(beta_cast);
//multiply A*B by alpha.
ymm4 = _mm256_mul_ps(ymm4, ymm0);
@@ -554,15 +562,37 @@ static err_t bli_sgemm_small
ymm14 = _mm256_mul_ps(ymm14, ymm0);
ymm15 = _mm256_mul_ps(ymm15, ymm0);
// multiply C by beta and accumulate col 1.
ymm2 = _mm256_loadu_ps(tC);
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
ymm2 = _mm256_loadu_ps(tC + 24);
ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
if(isbetanonzero)
{
ymm1 = _mm256_broadcast_ss(beta_cast);
// multiply C by beta and accumulate col 1.
ymm2 = _mm256_loadu_ps(tC);
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
ymm2 = _mm256_loadu_ps(tC + 24);
ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
float* ttC = tC +ldc;
ymm2 = _mm256_loadu_ps(ttC);
ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
ymm2 = _mm256_loadu_ps(ttC + 8);
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
ymm2 = _mm256_loadu_ps(ttC + 16);
ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
ymm2 = _mm256_loadu_ps(ttC + 24);
ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11);
ttC = ttC +ldc;
ymm2 = _mm256_loadu_ps(ttC);
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
ymm2 = _mm256_loadu_ps(ttC + 8);
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
ymm2 = _mm256_loadu_ps(ttC + 16);
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
ymm2 = _mm256_loadu_ps(ttC + 24);
ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);
}
_mm256_storeu_ps(tC, ymm4);
_mm256_storeu_ps(tC + 8, ymm5);
_mm256_storeu_ps(tC + 16, ymm6);
@@ -570,14 +600,6 @@ static err_t bli_sgemm_small
// multiply C by beta and accumulate, col 2.
tC += ldc;
ymm2 = _mm256_loadu_ps(tC);
ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
ymm2 = _mm256_loadu_ps(tC + 24);
ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11);
_mm256_storeu_ps(tC, ymm8);
_mm256_storeu_ps(tC + 8, ymm9);
_mm256_storeu_ps(tC + 16, ymm10);
@@ -585,14 +607,6 @@ static err_t bli_sgemm_small
// multiply C by beta and accumulate, col 3.
tC += ldc;
ymm2 = _mm256_loadu_ps(tC);
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
ymm2 = _mm256_loadu_ps(tC + 24);
ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);
_mm256_storeu_ps(tC, ymm12);
_mm256_storeu_ps(tC + 8, ymm13);
_mm256_storeu_ps(tC + 16, ymm14);
@@ -651,7 +665,6 @@ static err_t bli_sgemm_small
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_ss(alpha_cast);
ymm1 = _mm256_broadcast_ss(beta_cast);
//multiply A*B by alpha.
ymm8 = _mm256_mul_ps(ymm8, ymm0);
@@ -664,29 +677,34 @@ static err_t bli_sgemm_small
ymm15 = _mm256_mul_ps(ymm15, ymm0);
// multiply C by beta and accumulate, col 1.
ymm2 = _mm256_loadu_ps(tC + 0);
ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
ymm2 = _mm256_loadu_ps(tC + 24);
ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11);
_mm256_storeu_ps(tC + 0, ymm8);
if(isbetanonzero)
{
ymm1 = _mm256_broadcast_ss(beta_cast);
ymm2 = _mm256_loadu_ps(tC);
ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
ymm2 = _mm256_loadu_ps(tC + 24);
ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11);
float* ttC = tC +ldc;
// multiply C by beta and accumulate, col 2.
ymm2 = _mm256_loadu_ps(ttC);
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
ymm2 = _mm256_loadu_ps(ttC + 8);
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
ymm2 = _mm256_loadu_ps(ttC + 16);
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
ymm2 = _mm256_loadu_ps(ttC + 24);
ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);
}
_mm256_storeu_ps(tC, ymm8);
_mm256_storeu_ps(tC + 8, ymm9);
_mm256_storeu_ps(tC + 16, ymm10);
_mm256_storeu_ps(tC + 24, ymm11);
// multiply C by beta and accumulate, col 2.
tC += ldc;
ymm2 = _mm256_loadu_ps(tC);
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
ymm2 = _mm256_loadu_ps(tC + 24);
ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);
_mm256_storeu_ps(tC, ymm12);
_mm256_storeu_ps(tC + 8, ymm13);
_mm256_storeu_ps(tC + 16, ymm14);
@@ -735,7 +753,6 @@ static err_t bli_sgemm_small
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_ss(alpha_cast);
ymm1 = _mm256_broadcast_ss(beta_cast);
//multiply A*B by alpha.
ymm12 = _mm256_mul_ps(ymm12, ymm0);
@@ -743,15 +760,19 @@ static err_t bli_sgemm_small
ymm14 = _mm256_mul_ps(ymm14, ymm0);
ymm15 = _mm256_mul_ps(ymm15, ymm0);
// multiply C by beta and accumulate.
ymm2 = _mm256_loadu_ps(tC + 0);
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
ymm2 = _mm256_loadu_ps(tC + 24);
ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);
if(isbetanonzero)
{
ymm1 = _mm256_broadcast_ss(beta_cast);
// multiply C by beta and accumulate.
ymm2 = _mm256_loadu_ps(tC + 0);
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
ymm2 = _mm256_loadu_ps(tC + 24);
ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);
}
_mm256_storeu_ps(tC + 0, ymm12);
_mm256_storeu_ps(tC + 8, ymm13);
@@ -823,7 +844,6 @@ static err_t bli_sgemm_small
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_ss(alpha_cast);
ymm1 = _mm256_broadcast_ss(beta_cast);
//multiply A*B by alpha.
ymm4 = _mm256_mul_ps(ymm4, ymm0);
@@ -836,37 +856,43 @@ static err_t bli_sgemm_small
ymm13 = _mm256_mul_ps(ymm13, ymm0);
ymm14 = _mm256_mul_ps(ymm14, ymm0);
// multiply C by beta and accumulate.
ymm2 = _mm256_loadu_ps(tC);
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
if(isbetanonzero)
{
ymm1 = _mm256_broadcast_ss(beta_cast);
// multiply C by beta and accumulate.
ymm2 = _mm256_loadu_ps(tC);
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
float* ttC = tC +ldc;
ymm2 = _mm256_loadu_ps(ttC);
ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
ymm2 = _mm256_loadu_ps(ttC + 8);
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
ymm2 = _mm256_loadu_ps(ttC + 16);
ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
ttC += ldc;
ymm2 = _mm256_loadu_ps(ttC);
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
ymm2 = _mm256_loadu_ps(ttC + 8);
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
ymm2 = _mm256_loadu_ps(ttC + 16);
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
}
_mm256_storeu_ps(tC, ymm4);
_mm256_storeu_ps(tC + 8, ymm5);
_mm256_storeu_ps(tC + 16, ymm6);
// multiply C by beta and accumulate.
tC += ldc;
ymm2 = _mm256_loadu_ps(tC);
ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
_mm256_storeu_ps(tC, ymm8);
_mm256_storeu_ps(tC + 8, ymm9);
_mm256_storeu_ps(tC + 16, ymm10);
// multiply C by beta and accumulate.
tC += ldc;
ymm2 = _mm256_loadu_ps(tC);
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
_mm256_storeu_ps(tC, ymm12);
_mm256_storeu_ps(tC + 8, ymm13);
_mm256_storeu_ps(tC + 16, ymm14);
@@ -917,7 +943,6 @@ static err_t bli_sgemm_small
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_ss(alpha_cast);
ymm1 = _mm256_broadcast_ss(beta_cast);
//multiply A*B by alpha.
ymm8 = _mm256_mul_ps(ymm8, ymm0);
@@ -927,25 +952,33 @@ static err_t bli_sgemm_small
ymm13 = _mm256_mul_ps(ymm13, ymm0);
ymm14 = _mm256_mul_ps(ymm14, ymm0);
// multiply C by beta and accumulate.
ymm2 = _mm256_loadu_ps(tC + 0);
ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
_mm256_storeu_ps(tC + 0, ymm8);
if(isbetanonzero)
{
ymm1 = _mm256_broadcast_ss(beta_cast);
// multiply C by beta and accumulate.
ymm2 = _mm256_loadu_ps(tC);
ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
float* ttC = tC +ldc;
// multiply C by beta and accumulate.
ymm2 = _mm256_loadu_ps(ttC);
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
ymm2 = _mm256_loadu_ps(ttC + 8);
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
ymm2 = _mm256_loadu_ps(ttC + 16);
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
}
_mm256_storeu_ps(tC, ymm8);
_mm256_storeu_ps(tC + 8, ymm9);
_mm256_storeu_ps(tC + 16, ymm10);
// multiply C by beta and accumulate.
tC += ldc;
ymm2 = _mm256_loadu_ps(tC);
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
_mm256_storeu_ps(tC, ymm12);
_mm256_storeu_ps(tC + 8, ymm13);
_mm256_storeu_ps(tC + 16, ymm14);
@@ -989,13 +1022,15 @@ static err_t bli_sgemm_small
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_ss(alpha_cast);
ymm1 = _mm256_broadcast_ss(beta_cast);
//multiply A*B by alpha.
ymm12 = _mm256_mul_ps(ymm12, ymm0);
ymm13 = _mm256_mul_ps(ymm13, ymm0);
ymm14 = _mm256_mul_ps(ymm14, ymm0);
if(isbetanonzero)
{
ymm1 = _mm256_broadcast_ss(beta_cast);
// multiply C by beta and accumulate.
ymm2 = _mm256_loadu_ps(tC + 0);
ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
@@ -1003,7 +1038,7 @@ static err_t bli_sgemm_small
ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
ymm2 = _mm256_loadu_ps(tC + 16);
ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
}
_mm256_storeu_ps(tC + 0, ymm12);
_mm256_storeu_ps(tC + 8, ymm13);
_mm256_storeu_ps(tC + 16, ymm14);
@@ -1056,7 +1091,6 @@ static err_t bli_sgemm_small
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_ss(alpha_cast);
ymm1 = _mm256_broadcast_ss(beta_cast);
//multiply A*B by alpha.
ymm4 = _mm256_mul_ps(ymm4, ymm0);
@@ -1066,29 +1100,35 @@ static err_t bli_sgemm_small
ymm8 = _mm256_mul_ps(ymm8, ymm0);
ymm9 = _mm256_mul_ps(ymm9, ymm0);
// multiply C by beta and accumulate.
ymm2 = _mm256_loadu_ps(tC);
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
if(isbetanonzero)
{
ymm1 = _mm256_broadcast_ss(beta_cast);
// multiply C by beta and accumulate.
ymm2 = _mm256_loadu_ps(tC);
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
float* ttC = tC + ldc;
ymm2 = _mm256_loadu_ps(ttC);
ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
ymm2 = _mm256_loadu_ps(ttC + 8);
ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
ttC += ldc;
ymm2 = _mm256_loadu_ps(ttC);
ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
ymm2 = _mm256_loadu_ps(ttC + 8);
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
}
_mm256_storeu_ps(tC, ymm4);
_mm256_storeu_ps(tC + 8, ymm5);
// multiply C by beta and accumulate.
tC += ldc;
ymm2 = _mm256_loadu_ps(tC);
ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
_mm256_storeu_ps(tC, ymm6);
_mm256_storeu_ps(tC + 8, ymm7);
// multiply C by beta and accumulate.
tC += ldc;
ymm2 = _mm256_loadu_ps(tC);
ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
_mm256_storeu_ps(tC, ymm8);
_mm256_storeu_ps(tC + 8, ymm9);
@@ -1131,7 +1171,6 @@ static err_t bli_sgemm_small
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_ss(alpha_cast);
ymm1 = _mm256_broadcast_ss(beta_cast);
//multiply A*B by alpha.
ymm4 = _mm256_mul_ps(ymm4, ymm0);
@@ -1139,20 +1178,25 @@ static err_t bli_sgemm_small
ymm6 = _mm256_mul_ps(ymm6, ymm0);
ymm7 = _mm256_mul_ps(ymm7, ymm0);
// multiply C by beta and accumulate.
ymm2 = _mm256_loadu_ps(tC);
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
if(isbetanonzero)
{
ymm1 = _mm256_broadcast_ss(beta_cast);
// multiply C by beta and accumulate.
ymm2 = _mm256_loadu_ps(tC);
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
float* ttC = tC + ldc;
ymm2 = _mm256_loadu_ps(ttC);
ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
ymm2 = _mm256_loadu_ps(ttC + 8);
ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
}
_mm256_storeu_ps(tC, ymm4);
_mm256_storeu_ps(tC + 8, ymm5);
// multiply C by beta and accumulate.
tC += ldc;
ymm2 = _mm256_loadu_ps(tC);
ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
_mm256_storeu_ps(tC, ymm6);
_mm256_storeu_ps(tC + 8, ymm7);
@@ -1190,16 +1234,19 @@ static err_t bli_sgemm_small
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_ss(alpha_cast);
ymm1 = _mm256_broadcast_ss(beta_cast);
ymm4 = _mm256_mul_ps(ymm4, ymm0);
ymm5 = _mm256_mul_ps(ymm5, ymm0);
// multiply C by beta and accumulate.
ymm2 = _mm256_loadu_ps(tC);
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
if(isbetanonzero)
{
ymm1 = _mm256_broadcast_ss(beta_cast);
ymm2 = _mm256_loadu_ps(tC);
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
ymm2 = _mm256_loadu_ps(tC + 8);
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
}
_mm256_storeu_ps(tC, ymm4);
_mm256_storeu_ps(tC + 8, ymm5);
@@ -1244,28 +1291,30 @@ static err_t bli_sgemm_small
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_ss(alpha_cast);
ymm1 = _mm256_broadcast_ss(beta_cast);
//multiply A*B by alpha.
ymm4 = _mm256_mul_ps(ymm4, ymm0);
ymm5 = _mm256_mul_ps(ymm5, ymm0);
ymm6 = _mm256_mul_ps(ymm6, ymm0);
// multiply C by beta and accumulate.
ymm2 = _mm256_loadu_ps(tC);
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
if(isbetanonzero)
{
ymm1 = _mm256_broadcast_ss(beta_cast);
ymm2 = _mm256_loadu_ps(tC);
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
ymm2 = _mm256_loadu_ps(tC + ldc);
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
ymm2 = _mm256_loadu_ps(tC + 2*ldc);
ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
}
_mm256_storeu_ps(tC, ymm4);
// multiply C by beta and accumulate.
tC += ldc;
ymm2 = _mm256_loadu_ps(tC);
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
_mm256_storeu_ps(tC, ymm5);
// multiply C by beta and accumulate.
tC += ldc;
ymm2 = _mm256_loadu_ps(tC);
ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
_mm256_storeu_ps(tC, ymm6);
}
n_remainder = N - col_idx;
@@ -1299,21 +1348,23 @@ static err_t bli_sgemm_small
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_ss(alpha_cast);
ymm1 = _mm256_broadcast_ss(beta_cast);
//multiply A*B by alpha.
ymm4 = _mm256_mul_ps(ymm4, ymm0);
ymm5 = _mm256_mul_ps(ymm5, ymm0);
// multiply C by beta and accumulate.
ymm2 = _mm256_loadu_ps(tC);
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
if(isbetanonzero)
{
ymm1 = _mm256_broadcast_ss(beta_cast);
// multiply C by beta and accumulate.
ymm2 = _mm256_loadu_ps(tC);
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
ymm2 = _mm256_loadu_ps(tC + ldc);
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
}
_mm256_storeu_ps(tC, ymm4);
// multiply C by beta and accumulate.
tC += ldc;
ymm2 = _mm256_loadu_ps(tC);
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
_mm256_storeu_ps(tC, ymm5);
col_idx += 2;
@@ -1346,13 +1397,15 @@ static err_t bli_sgemm_small
}
// alpha, beta multiplication.
ymm0 = _mm256_broadcast_ss(alpha_cast);
ymm1 = _mm256_broadcast_ss(beta_cast);
ymm4 = _mm256_mul_ps(ymm4, ymm0);
// multiply C by beta and accumulate.
ymm2 = _mm256_loadu_ps(tC);
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
if(isbetanonzero)
{
ymm1 = _mm256_broadcast_ss(beta_cast);
// multiply C by beta and accumulate.
ymm2 = _mm256_loadu_ps(tC);
ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
}
_mm256_storeu_ps(tC, ymm4);
}
@@ -1426,7 +1479,8 @@ static err_t bli_sgemm_small
f_temp[i] = tC[i];
}
ymm2 = _mm256_loadu_ps(f_temp);
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
if(isbetanonzero)
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
_mm256_storeu_ps(f_temp, ymm5);
for (int i = 0; i < m_remainder; i++)
{
@@ -1439,7 +1493,8 @@ static err_t bli_sgemm_small
f_temp[i] = tC[i];
}
ymm2 = _mm256_loadu_ps(f_temp);
ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
if(isbetanonzero)
ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
_mm256_storeu_ps(f_temp, ymm7);
for (int i = 0; i < m_remainder; i++)
{
@@ -1452,7 +1507,8 @@ static err_t bli_sgemm_small
f_temp[i] = tC[i];
}
ymm2 = _mm256_loadu_ps(f_temp);
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
if(isbetanonzero)
ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
_mm256_storeu_ps(f_temp, ymm9);
for (int i = 0; i < m_remainder; i++)
{
@@ -1510,7 +1566,8 @@ static err_t bli_sgemm_small
f_temp[i] = tC[i];
}
ymm2 = _mm256_loadu_ps(f_temp);
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
if(isbetanonzero)
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
_mm256_storeu_ps(f_temp, ymm5);
for (int i = 0; i < m_remainder; i++)
{
@@ -1523,7 +1580,8 @@ static err_t bli_sgemm_small
f_temp[i] = tC[i];
}
ymm2 = _mm256_loadu_ps(f_temp);
ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
if(isbetanonzero)
ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
_mm256_storeu_ps(f_temp, ymm7);
for (int i = 0; i < m_remainder; i++)
{
@@ -1565,7 +1623,6 @@ static err_t bli_sgemm_small
ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
ymm0 = _mm256_broadcast_ss(alpha_cast);
ymm1 = _mm256_broadcast_ss(beta_cast);
// multiply C by beta and accumulate.
ymm5 = _mm256_mul_ps(ymm5, ymm0);
@@ -1575,7 +1632,10 @@ static err_t bli_sgemm_small
f_temp[i] = tC[i];
}
ymm2 = _mm256_loadu_ps(f_temp);
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
if(isbetanonzero){
ymm1 = _mm256_broadcast_ss(beta_cast);
ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
}
_mm256_storeu_ps(f_temp, ymm5);
for (int i = 0; i < m_remainder; i++)
{
@@ -1606,7 +1666,11 @@ static err_t bli_sgemm_small
}
result *= (*alpha_cast);
(*tC) = (*tC) * (*beta_cast) + result;
if(isbetanonzero){
(*tC) = (*tC) * (*beta_cast) + result;
}else{
(*tC) = result;
}
}
}
}
@@ -3089,7 +3153,7 @@ static err_t bli_dgemm_small
}
// Return the buffer to pool
if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )) {
if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )) {
#ifdef BLIS_ENABLE_MEM_TRACING
printf( "bli_dgemm_small(): releasing mem pool block\n" );
#endif
@@ -3139,6 +3203,13 @@ static err_t bli_sgemm_small_atbn
alpha_cast = (alpha->buffer);
beta_cast = (beta->buffer);
/*Beta Zero Check*/
guint_t isbetanonzero=0;
if ( !bli_obj_equals( beta, &BLIS_ZERO ) ){
isbetanonzero = 1;
}
// The non-copy version of the A^T GEMM gives better performance for the small M cases.
// The threshold is controlled by BLIS_ATBN_M_THRES
if (M <= BLIS_ATBN_M_THRES)
@@ -3250,28 +3321,44 @@ static err_t bli_sgemm_small_atbn
_mm256_storeu_ps(scratch, ymm4);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
if(isbetanonzero){
tC[0] = result + tC[0] * (*beta_cast);
}else{
tC[0] = result;
}
ymm7 = _mm256_hadd_ps(ymm7, ymm7);
ymm7 = _mm256_hadd_ps(ymm7, ymm7);
_mm256_storeu_ps(scratch, ymm7);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
if(isbetanonzero){
tC[1] = result + tC[1] * (*beta_cast);
}else{
tC[1] = result;
}
ymm10 = _mm256_hadd_ps(ymm10, ymm10);
ymm10 = _mm256_hadd_ps(ymm10, ymm10);
_mm256_storeu_ps(scratch, ymm10);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
if(isbetanonzero){
tC[2] = result + tC[2] * (*beta_cast);
}else{
tC[2] = result;
}
ymm13 = _mm256_hadd_ps(ymm13, ymm13);
ymm13 = _mm256_hadd_ps(ymm13, ymm13);
_mm256_storeu_ps(scratch, ymm13);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
if(isbetanonzero){
tC[3] = result + tC[3] * (*beta_cast);
}else{
tC[3] = result;
}
tC += ldc;
ymm5 = _mm256_hadd_ps(ymm5, ymm5);
@@ -3279,28 +3366,44 @@ static err_t bli_sgemm_small_atbn
_mm256_storeu_ps(scratch, ymm5);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
if(isbetanonzero){
tC[0] = result + tC[0] * (*beta_cast);
}else{
tC[0] = result;
}
ymm8 = _mm256_hadd_ps(ymm8, ymm8);
ymm8 = _mm256_hadd_ps(ymm8, ymm8);
_mm256_storeu_ps(scratch, ymm8);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
if(isbetanonzero){
tC[1] = result + tC[1] * (*beta_cast);
}else{
tC[1] = result;
}
ymm11 = _mm256_hadd_ps(ymm11, ymm11);
ymm11 = _mm256_hadd_ps(ymm11, ymm11);
_mm256_storeu_ps(scratch, ymm11);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
if(isbetanonzero){
tC[2] = result + tC[2] * (*beta_cast);
}else{
tC[2] = result;
}
ymm14 = _mm256_hadd_ps(ymm14, ymm14);
ymm14 = _mm256_hadd_ps(ymm14, ymm14);
_mm256_storeu_ps(scratch, ymm14);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
if(isbetanonzero){
tC[3] = result + tC[3] * (*beta_cast);
}else{
tC[3] = result;
}
tC += ldc;
ymm6 = _mm256_hadd_ps(ymm6, ymm6);
@@ -3308,28 +3411,44 @@ static err_t bli_sgemm_small_atbn
_mm256_storeu_ps(scratch, ymm6);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
if(isbetanonzero){
tC[0] = result + tC[0] * (*beta_cast);
}else{
tC[0] = result;
}
ymm9 = _mm256_hadd_ps(ymm9, ymm9);
ymm9 = _mm256_hadd_ps(ymm9, ymm9);
_mm256_storeu_ps(scratch, ymm9);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
if(isbetanonzero){
tC[1] = result + tC[1] * (*beta_cast);
}else{
tC[1] = result;
}
ymm12 = _mm256_hadd_ps(ymm12, ymm12);
ymm12 = _mm256_hadd_ps(ymm12, ymm12);
_mm256_storeu_ps(scratch, ymm12);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
if(isbetanonzero){
tC[2] = result + tC[2] * (*beta_cast);
}else{
tC[2] = result;
}
ymm15 = _mm256_hadd_ps(ymm15, ymm15);
ymm15 = _mm256_hadd_ps(ymm15, ymm15);
_mm256_storeu_ps(scratch, ymm15);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
if(isbetanonzero){
tC[3] = result + tC[3] * (*beta_cast);
}else{
tC[3] = result;
}
}
}
@@ -3413,29 +3532,44 @@ static err_t bli_sgemm_small_atbn
_mm256_storeu_ps(scratch, ymm4);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
if(isbetanonzero){
tC[0] = result + tC[0] * (*beta_cast);
}else{
tC[0] = result;
}
ymm7 = _mm256_hadd_ps(ymm7, ymm7);
ymm7 = _mm256_hadd_ps(ymm7, ymm7);
_mm256_storeu_ps(scratch, ymm7);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
if(isbetanonzero){
tC[1] = result + tC[1] * (*beta_cast);
}else{
tC[1] = result;
}
ymm10 = _mm256_hadd_ps(ymm10, ymm10);
ymm10 = _mm256_hadd_ps(ymm10, ymm10);
_mm256_storeu_ps(scratch, ymm10);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
if(isbetanonzero){
tC[2] = result + tC[2] * (*beta_cast);
}else{
tC[2] = result;
}
ymm13 = _mm256_hadd_ps(ymm13, ymm13);
ymm13 = _mm256_hadd_ps(ymm13, ymm13);
_mm256_storeu_ps(scratch, ymm13);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
if(isbetanonzero){
tC[3] = result + tC[3] * (*beta_cast);
}else{
tC[3] = result;
}
}
}
processed_row = row_idx;
@@ -3485,7 +3619,11 @@ static err_t bli_sgemm_small_atbn
_mm256_storeu_ps(scratch, ymm4);
result = scratch[0] + scratch[4];
result *= (*alpha_cast);
if(isbetanonzero){
tC[0] = result + tC[0] * (*beta_cast);
}else{
tC[0] = result;
}
}
}