diff --git a/kernels/zen/3/bli_gemm_small.c b/kernels/zen/3/bli_gemm_small.c
index aced37b75..f678f745f 100644
--- a/kernels/zen/3/bli_gemm_small.c
+++ b/kernels/zen/3/bli_gemm_small.c
@@ -174,7 +174,6 @@ static err_t bli_sgemm_small
     gint_t K = bli_obj_width( a ); // number of columns of OP(A), will be updated if OP(A) is Transpose(A) .
     gint_t L = M * N;
 
-    // printf("alpha_cast = %f beta_cast = %f [ Trans = %d %d], [stride = %d %d %d] [m,n,k = %d %d %d]\n",*alpha_cast,*beta_cast, bli_obj_has_trans( a ), bli_obj_has_trans( b ), lda, ldb,ldc, M,N,K);
     if ((((L) < (BLIS_SMALL_MATRIX_THRES * BLIS_SMALL_MATRIX_THRES)) ||
         ((M < BLIS_SMALL_M_RECT_MATRIX_THRES) && (K < BLIS_SMALL_K_RECT_MATRIX_THRES))) && ((L!=0) && (K!=0)))
     {
@@ -187,12 +186,13 @@ static err_t bli_sgemm_small
        float *C = c->buffer; // pointer to elements of Matrix C
 
        float *tA = A, *tB = B, *tC = C;//, *tA_pack;
-       float *tA_packed; // temprorary pointer to hold packed A memory pointer
+       float *tA_packed; // temporary pointer to hold packed A memory pointer
        guint_t row_idx_packed; //packed A memory row index
        guint_t lda_packed; //lda of packed A
        guint_t col_idx_start; //starting index after A matrix is packed.
        dim_t tb_inc_row = 1; // row stride of matrix B
        dim_t tb_inc_col = ldb; // column stride of matrix B
+
        __m256 ymm4, ymm5, ymm6, ymm7;
        __m256 ymm8, ymm9, ymm10, ymm11;
        __m256 ymm12, ymm13, ymm14, ymm15;
@@ -208,7 +208,13 @@ static err_t bli_sgemm_small
        mem_t local_mem_buf_A_s;
        float *A_pack = NULL;
        rntm_t rntm;
-
+
+       /*Beta Zero Check*/
+       guint_t isbetanonzero = 0;
+       if ( !bli_obj_equals( beta, &BLIS_ZERO ) ){
+           isbetanonzero = 1;
+       }
+
        // when N is equal to 1 call GEMV instead of GEMM
        if (N == 1)
        {
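A note on the guard introduced above: BLAS semantics for C = alpha*A*B + beta*C require that C be overwritten without being read when beta == 0, since the output buffer may be uninitialized and 0 * NaN would otherwise propagate into the result; skipping the loads of C also saves memory traffic. A minimal scalar sketch of the semantics the `isbetanonzero` flag preserves (illustrative names, not BLIS API):

    /* C = alpha*A*B + beta*C, column-major; when beta == 0, C is
     * written but never read, mirroring the kernel's isbetanonzero. */
    static void gemm_update_sketch(int m, int n, float alpha, float beta,
                                   const float *ab, int ldab, /* ab = A*B */
                                   float *c, int ldc)
    {
        for (int j = 0; j < n; ++j)
            for (int i = 0; i < m; ++i)
            {
                float t = alpha * ab[i + j * ldab];
                if (beta != 0.0f)            /* the isbetanonzero guard */
                    t += beta * c[i + j * ldc];
                c[i + j * ldc] = t;
            }
    }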
@@ -224,8 +230,7 @@ static err_t bli_sgemm_small
        }
 
        //update the pointer math if matrix B needs to be transposed.
-       if (bli_obj_has_trans( b ))
-       {
+       if (bli_obj_has_trans( b )) {
            tb_inc_col = 1; //switch row and column strides
            tb_inc_row = ldb;
        }
@@ -252,17 +257,15 @@ static err_t bli_sgemm_small
        bli_rntm_set_num_threads_only( 1, &rntm );
        bli_membrk_rntm_set_membrk( &rntm );
 
-       // Get the current size of the buffer pool for A block packing.
-       // We will use the same size to avoid pool re-initliazaton
-       siz_t buffer_size = bli_pool_block_size(
-           bli_membrk_pool(bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK),
-           bli_rntm_membrk(&rntm)));
+       // We will use the same size to avoid pool re-initialization
+       siz_t buffer_size = bli_pool_block_size(bli_membrk_pool(bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK),
+                                               bli_rntm_membrk(&rntm)));
 
        // Based on the available memory in the buffer we will decide if
        // we want to do packing or not.
        //
-       // This kernel assumes that "A" will be unpackged if N <= 3.
+       // This kernel assumes that "A" will be unpacked if N <= 3.
        // Usually this range (N <= 3) is handled by SUP, however,
        // if SUP is disabled or for any other condition if we do
        // enter this kernel with N <= 3, we want to make sure that
@@ -299,7 +302,6 @@ static err_t bli_sgemm_small
        // Process MR rows of C matrix at a time.
        for (row_idx = 0; (row_idx + (MR - 1)) < M; row_idx += MR)
        {
-
            col_idx_start = 0;
            tA_packed = A;
            row_idx_packed = row_idx;
@@ -394,7 +396,6 @@ static err_t bli_sgemm_small
                }
                // alpha, beta multiplication.
                ymm0 = _mm256_broadcast_ss(alpha_cast);
-               ymm1 = _mm256_broadcast_ss(beta_cast);
 
                //multiply A*B by alpha.
                ymm4 = _mm256_mul_ps(ymm4, ymm0);
@@ -410,15 +411,39 @@ static err_t bli_sgemm_small
                ymm14 = _mm256_mul_ps(ymm14, ymm0);
                ymm15 = _mm256_mul_ps(ymm15, ymm0);
 
-               // multiply C by beta and accumulate col 1.
-               ymm2 = _mm256_loadu_ps(tC);
-               ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
-               ymm2 = _mm256_loadu_ps(tC + 8);
-               ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
-               ymm2 = _mm256_loadu_ps(tC + 16);
-               ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
-               ymm2 = _mm256_loadu_ps(tC + 24);
-               ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
+               if(isbetanonzero)
+               {
+                   ymm1 = _mm256_broadcast_ss(beta_cast);
+                   // multiply C by beta and accumulate col 1.
+                   ymm2 = _mm256_loadu_ps(tC);
+                   ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
+                   ymm2 = _mm256_loadu_ps(tC + 8);
+                   ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
+                   ymm2 = _mm256_loadu_ps(tC + 16);
+                   ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
+                   ymm2 = _mm256_loadu_ps(tC + 24);
+                   ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
+
+                   float* ttC = tC + ldc;
+                   ymm2 = _mm256_loadu_ps(ttC);
+                   ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
+                   ymm2 = _mm256_loadu_ps(ttC + 8);
+                   ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
+                   ymm2 = _mm256_loadu_ps(ttC + 16);
+                   ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
+                   ymm2 = _mm256_loadu_ps(ttC + 24);
+                   ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11);
+
+                   ttC += ldc;
+                   ymm2 = _mm256_loadu_ps(ttC);
+                   ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
+                   ymm2 = _mm256_loadu_ps(ttC + 8);
+                   ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
+                   ymm2 = _mm256_loadu_ps(ttC + 16);
+                   ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
+                   ymm2 = _mm256_loadu_ps(ttC + 24);
+                   ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);
+               }
                _mm256_storeu_ps(tC, ymm4);
                _mm256_storeu_ps(tC + 8, ymm5);
                _mm256_storeu_ps(tC + 16, ymm6);
@@ -426,14 +451,6 @@ static err_t bli_sgemm_small
 
                // multiply C by beta and accumulate, col 2.
                tC += ldc;
-               ymm2 = _mm256_loadu_ps(tC);
-               ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
-               ymm2 = _mm256_loadu_ps(tC + 8);
-               ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
-               ymm2 = _mm256_loadu_ps(tC + 16);
-               ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
-               ymm2 = _mm256_loadu_ps(tC + 24);
-               ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11);
                _mm256_storeu_ps(tC, ymm8);
                _mm256_storeu_ps(tC + 8, ymm9);
                _mm256_storeu_ps(tC + 16, ymm10);
@@ -441,14 +458,6 @@ static err_t bli_sgemm_small
 
                // multiply C by beta and accumulate, col 3.
                tC += ldc;
-               ymm2 = _mm256_loadu_ps(tC);
-               ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
-               ymm2 = _mm256_loadu_ps(tC + 8);
-               ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
-               ymm2 = _mm256_loadu_ps(tC + 16);
-               ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
-               ymm2 = _mm256_loadu_ps(tC + 24);
-               ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);
                _mm256_storeu_ps(tC, ymm12);
                _mm256_storeu_ps(tC + 8, ymm13);
                _mm256_storeu_ps(tC + 16, ymm14);
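For readers unfamiliar with the intrinsic: _mm256_fmadd_ps(x, y, z) computes x*y + z per 32-bit lane, so each load/FMA pair above folds beta*C into an accumulator that already holds alpha*A*B. A self-contained sketch of one 8-float strip of the guarded update (assumes AVX2/FMA; names are illustrative, not the kernel's):

    #include <immintrin.h>

    /* acc arrives holding alpha*(A*B) for 8 consecutive rows of one
     * column of C; beta*C is folded in only when beta is nonzero. */
    static inline __m256 strip_update(__m256 acc, const float *c,
                                      float beta, int beta_nonzero)
    {
        if (beta_nonzero)
        {
            __m256 vbeta = _mm256_broadcast_ss(&beta);
            __m256 vc    = _mm256_loadu_ps(c);      /* load C strip   */
            acc = _mm256_fmadd_ps(vc, vbeta, acc);  /* vc*vbeta + acc */
        }
        return acc;                                 /* caller stores  */
    }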
@@ -538,7 +547,6 @@ static err_t bli_sgemm_small
                }
                // alpha, beta multiplication.
                ymm0 = _mm256_broadcast_ss(alpha_cast);
-               ymm1 = _mm256_broadcast_ss(beta_cast);
 
                //multiply A*B by alpha.
                ymm4 = _mm256_mul_ps(ymm4, ymm0);
@@ -554,15 +562,37 @@ static err_t bli_sgemm_small
                ymm14 = _mm256_mul_ps(ymm14, ymm0);
                ymm15 = _mm256_mul_ps(ymm15, ymm0);
 
-               // multiply C by beta and accumulate col 1.
-               ymm2 = _mm256_loadu_ps(tC);
-               ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
-               ymm2 = _mm256_loadu_ps(tC + 8);
-               ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
-               ymm2 = _mm256_loadu_ps(tC + 16);
-               ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
-               ymm2 = _mm256_loadu_ps(tC + 24);
-               ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
+               if(isbetanonzero)
+               {
+                   ymm1 = _mm256_broadcast_ss(beta_cast);
+                   // multiply C by beta and accumulate col 1.
+                   ymm2 = _mm256_loadu_ps(tC);
+                   ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
+                   ymm2 = _mm256_loadu_ps(tC + 8);
+                   ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
+                   ymm2 = _mm256_loadu_ps(tC + 16);
+                   ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
+                   ymm2 = _mm256_loadu_ps(tC + 24);
+                   ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
+                   float* ttC = tC + ldc;
+                   ymm2 = _mm256_loadu_ps(ttC);
+                   ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
+                   ymm2 = _mm256_loadu_ps(ttC + 8);
+                   ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
+                   ymm2 = _mm256_loadu_ps(ttC + 16);
+                   ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
+                   ymm2 = _mm256_loadu_ps(ttC + 24);
+                   ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11);
+                   ttC = ttC + ldc;
+                   ymm2 = _mm256_loadu_ps(ttC);
+                   ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
+                   ymm2 = _mm256_loadu_ps(ttC + 8);
+                   ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
+                   ymm2 = _mm256_loadu_ps(ttC + 16);
+                   ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
+                   ymm2 = _mm256_loadu_ps(ttC + 24);
+                   ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);
+               }
                _mm256_storeu_ps(tC, ymm4);
                _mm256_storeu_ps(tC + 8, ymm5);
                _mm256_storeu_ps(tC + 16, ymm6);
@@ -570,14 +600,6 @@ static err_t bli_sgemm_small
 
                // multiply C by beta and accumulate, col 2.
                tC += ldc;
-               ymm2 = _mm256_loadu_ps(tC);
-               ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
-               ymm2 = _mm256_loadu_ps(tC + 8);
-               ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
-               ymm2 = _mm256_loadu_ps(tC + 16);
-               ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
-               ymm2 = _mm256_loadu_ps(tC + 24);
-               ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11);
                _mm256_storeu_ps(tC, ymm8);
                _mm256_storeu_ps(tC + 8, ymm9);
                _mm256_storeu_ps(tC + 16, ymm10);
@@ -585,14 +607,6 @@ static err_t bli_sgemm_small
 
                // multiply C by beta and accumulate, col 3.
                tC += ldc;
-               ymm2 = _mm256_loadu_ps(tC);
-               ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
-               ymm2 = _mm256_loadu_ps(tC + 8);
-               ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
-               ymm2 = _mm256_loadu_ps(tC + 16);
-               ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
-               ymm2 = _mm256_loadu_ps(tC + 24);
-               ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);
                _mm256_storeu_ps(tC, ymm12);
                _mm256_storeu_ps(tC + 8, ymm13);
                _mm256_storeu_ps(tC + 16, ymm14);
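The variant above reads A through A_pack/tA_packed with leading dimension lda_packed. The idea, sketched here under the assumption of a plain column-major MR x K panel copy (the kernel's actual packed layout may differ), is to stage the panel contiguously so the inner loop streams A at a fixed cache-friendly stride:

    /* Hypothetical panel copy: a is column-major with leading dimension
     * lda; a_pack receives an mr x k panel, so lda_packed becomes mr. */
    static void pack_a_panel_sketch(const float *a, int lda,
                                    float *a_pack, int mr, int k)
    {
        for (int p = 0; p < k; ++p)          /* one column at a time */
            for (int i = 0; i < mr; ++i)
                a_pack[p * mr + i] = a[p * lda + i];
    }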
@@ -651,7 +665,6 @@ static err_t bli_sgemm_small
                }
                // alpha, beta multiplication.
                ymm0 = _mm256_broadcast_ss(alpha_cast);
-               ymm1 = _mm256_broadcast_ss(beta_cast);
 
                //multiply A*B by alpha.
                ymm8 = _mm256_mul_ps(ymm8, ymm0);
@@ -664,29 +677,34 @@ static err_t bli_sgemm_small
                ymm15 = _mm256_mul_ps(ymm15, ymm0);
 
                // multiply C by beta and accumulate, col 1.
-               ymm2 = _mm256_loadu_ps(tC + 0);
-               ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
-               ymm2 = _mm256_loadu_ps(tC + 8);
-               ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
-               ymm2 = _mm256_loadu_ps(tC + 16);
-               ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
-               ymm2 = _mm256_loadu_ps(tC + 24);
-               ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11);
-               _mm256_storeu_ps(tC + 0, ymm8);
+               if(isbetanonzero)
+               {
+                   ymm1 = _mm256_broadcast_ss(beta_cast);
+                   ymm2 = _mm256_loadu_ps(tC);
+                   ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
+                   ymm2 = _mm256_loadu_ps(tC + 8);
+                   ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
+                   ymm2 = _mm256_loadu_ps(tC + 16);
+                   ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
+                   ymm2 = _mm256_loadu_ps(tC + 24);
+                   ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11);
+
+                   float* ttC = tC + ldc;
+                   // multiply C by beta and accumulate, col 2.
+                   ymm2 = _mm256_loadu_ps(ttC);
+                   ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
+                   ymm2 = _mm256_loadu_ps(ttC + 8);
+                   ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
+                   ymm2 = _mm256_loadu_ps(ttC + 16);
+                   ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
+                   ymm2 = _mm256_loadu_ps(ttC + 24);
+                   ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);
+               }
+               _mm256_storeu_ps(tC, ymm8);
                _mm256_storeu_ps(tC + 8, ymm9);
                _mm256_storeu_ps(tC + 16, ymm10);
                _mm256_storeu_ps(tC + 24, ymm11);
-
-               // multiply C by beta and accumulate, col 2.
                tC += ldc;
-               ymm2 = _mm256_loadu_ps(tC);
-               ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
-               ymm2 = _mm256_loadu_ps(tC + 8);
-               ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
-               ymm2 = _mm256_loadu_ps(tC + 16);
-               ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
-               ymm2 = _mm256_loadu_ps(tC + 24);
-               ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);
                _mm256_storeu_ps(tC, ymm12);
                _mm256_storeu_ps(tC + 8, ymm13);
                _mm256_storeu_ps(tC + 16, ymm14);
@@ -735,7 +753,6 @@ static err_t bli_sgemm_small
                }
                // alpha, beta multiplication.
                ymm0 = _mm256_broadcast_ss(alpha_cast);
-               ymm1 = _mm256_broadcast_ss(beta_cast);
 
                //multiply A*B by alpha.
                ymm12 = _mm256_mul_ps(ymm12, ymm0);
@@ -743,15 +760,19 @@ static err_t bli_sgemm_small
                ymm13 = _mm256_mul_ps(ymm13, ymm0);
                ymm14 = _mm256_mul_ps(ymm14, ymm0);
                ymm15 = _mm256_mul_ps(ymm15, ymm0);
 
-               // multiply C by beta and accumulate.
-               ymm2 = _mm256_loadu_ps(tC + 0);
-               ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
-               ymm2 = _mm256_loadu_ps(tC + 8);
-               ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
-               ymm2 = _mm256_loadu_ps(tC + 16);
-               ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
-               ymm2 = _mm256_loadu_ps(tC + 24);
-               ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);
+               if(isbetanonzero)
+               {
+                   ymm1 = _mm256_broadcast_ss(beta_cast);
+                   // multiply C by beta and accumulate.
+                   ymm2 = _mm256_loadu_ps(tC + 0);
+                   ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
+                   ymm2 = _mm256_loadu_ps(tC + 8);
+                   ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
+                   ymm2 = _mm256_loadu_ps(tC + 16);
+                   ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
+                   ymm2 = _mm256_loadu_ps(tC + 24);
+                   ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);
+               }
 
                _mm256_storeu_ps(tC + 0, ymm12);
                _mm256_storeu_ps(tC + 8, ymm13);
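The hunks above and below all instantiate the same tiling scheme at different fringe sizes; a control-flow skeleton may help orient the reader (illustrative only; the real kernel fully inlines each case):

    /* Rows of C are walked in chunks of MR = 32 floats (four ymm
     * registers), columns in groups of 3; 24-, 16- and 8-row versions
     * of the same column loop handle the row fringe, and a scalar loop
     * handles the final m_remainder rows. */
    static void tile_walk_sketch(long M, long N)
    {
        for (long row_idx = 0; row_idx + 31 < M; row_idx += 32)
        {
            long col_idx;
            for (col_idx = 0; col_idx + 2 < N; col_idx += 3)
            {
                /* 32x3 micro-tile: twelve ymm accumulators hold
                 * alpha*A*B; beta*C is added only if isbetanonzero. */
            }
            /* n_remainder = N - col_idx handled as 2- and 1-column tiles */
        }
    }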
@@ -823,7 +844,6 @@ static err_t bli_sgemm_small
                }
                // alpha, beta multiplication.
                ymm0 = _mm256_broadcast_ss(alpha_cast);
-               ymm1 = _mm256_broadcast_ss(beta_cast);
 
                //multiply A*B by alpha.
                ymm4 = _mm256_mul_ps(ymm4, ymm0);
@@ -836,37 +856,43 @@ static err_t bli_sgemm_small
                ymm13 = _mm256_mul_ps(ymm13, ymm0);
                ymm14 = _mm256_mul_ps(ymm14, ymm0);
 
-               // multiply C by beta and accumulate.
-               ymm2 = _mm256_loadu_ps(tC);
-               ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
-               ymm2 = _mm256_loadu_ps(tC + 8);
-               ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
-               ymm2 = _mm256_loadu_ps(tC + 16);
-               ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
+               if(isbetanonzero)
+               {
+                   ymm1 = _mm256_broadcast_ss(beta_cast);
+                   // multiply C by beta and accumulate.
+                   ymm2 = _mm256_loadu_ps(tC);
+                   ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
+                   ymm2 = _mm256_loadu_ps(tC + 8);
+                   ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
+                   ymm2 = _mm256_loadu_ps(tC + 16);
+                   ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
+                   float* ttC = tC + ldc;
+                   ymm2 = _mm256_loadu_ps(ttC);
+                   ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
+                   ymm2 = _mm256_loadu_ps(ttC + 8);
+                   ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
+                   ymm2 = _mm256_loadu_ps(ttC + 16);
+                   ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
+                   ttC += ldc;
+                   ymm2 = _mm256_loadu_ps(ttC);
+                   ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
+                   ymm2 = _mm256_loadu_ps(ttC + 8);
+                   ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
+                   ymm2 = _mm256_loadu_ps(ttC + 16);
+                   ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
+               }
                _mm256_storeu_ps(tC, ymm4);
                _mm256_storeu_ps(tC + 8, ymm5);
                _mm256_storeu_ps(tC + 16, ymm6);
 
                // multiply C by beta and accumulate.
                tC += ldc;
-               ymm2 = _mm256_loadu_ps(tC);
-               ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
-               ymm2 = _mm256_loadu_ps(tC + 8);
-               ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
-               ymm2 = _mm256_loadu_ps(tC + 16);
-               ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
                _mm256_storeu_ps(tC, ymm8);
                _mm256_storeu_ps(tC + 8, ymm9);
                _mm256_storeu_ps(tC + 16, ymm10);
 
                // multiply C by beta and accumulate.
                tC += ldc;
-               ymm2 = _mm256_loadu_ps(tC);
-               ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
-               ymm2 = _mm256_loadu_ps(tC + 8);
-               ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
-               ymm2 = _mm256_loadu_ps(tC + 16);
-               ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
                _mm256_storeu_ps(tC, ymm12);
                _mm256_storeu_ps(tC + 8, ymm13);
                _mm256_storeu_ps(tC + 16, ymm14);
@@ -917,7 +943,6 @@ static err_t bli_sgemm_small
                }
                // alpha, beta multiplication.
                ymm0 = _mm256_broadcast_ss(alpha_cast);
-               ymm1 = _mm256_broadcast_ss(beta_cast);
 
                //multiply A*B by alpha.
                ymm8 = _mm256_mul_ps(ymm8, ymm0);
@@ -927,25 +952,33 @@ static err_t bli_sgemm_small
                ymm13 = _mm256_mul_ps(ymm13, ymm0);
                ymm14 = _mm256_mul_ps(ymm14, ymm0);
 
-               // multiply C by beta and accumulate.
-               ymm2 = _mm256_loadu_ps(tC + 0);
-               ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
-               ymm2 = _mm256_loadu_ps(tC + 8);
-               ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
-               ymm2 = _mm256_loadu_ps(tC + 16);
-               ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
-               _mm256_storeu_ps(tC + 0, ymm8);
+               if(isbetanonzero)
+               {
+                   ymm1 = _mm256_broadcast_ss(beta_cast);
+                   // multiply C by beta and accumulate.
+                   ymm2 = _mm256_loadu_ps(tC);
+                   ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
+                   ymm2 = _mm256_loadu_ps(tC + 8);
+                   ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
+                   ymm2 = _mm256_loadu_ps(tC + 16);
+                   ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
+
+                   float* ttC = tC + ldc;
+                   // multiply C by beta and accumulate.
+                   ymm2 = _mm256_loadu_ps(ttC);
+                   ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
+                   ymm2 = _mm256_loadu_ps(ttC + 8);
+                   ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
+                   ymm2 = _mm256_loadu_ps(ttC + 16);
+                   ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
+               }
+
+               _mm256_storeu_ps(tC, ymm8);
                _mm256_storeu_ps(tC + 8, ymm9);
                _mm256_storeu_ps(tC + 16, ymm10);
 
-               // multiply C by beta and accumulate.
                tC += ldc;
-               ymm2 = _mm256_loadu_ps(tC);
-               ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
-               ymm2 = _mm256_loadu_ps(tC + 8);
-               ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
-               ymm2 = _mm256_loadu_ps(tC + 16);
-               ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
+
                _mm256_storeu_ps(tC, ymm12);
                _mm256_storeu_ps(tC + 8, ymm13);
                _mm256_storeu_ps(tC + 16, ymm14);
@@ -989,13 +1022,15 @@ static err_t bli_sgemm_small
                }
                // alpha, beta multiplication.
                ymm0 = _mm256_broadcast_ss(alpha_cast);
-               ymm1 = _mm256_broadcast_ss(beta_cast);
 
                //multiply A*B by alpha.
                ymm12 = _mm256_mul_ps(ymm12, ymm0);
                ymm13 = _mm256_mul_ps(ymm13, ymm0);
                ymm14 = _mm256_mul_ps(ymm14, ymm0);
 
+               if(isbetanonzero)
+               {
+                   ymm1 = _mm256_broadcast_ss(beta_cast);
                // multiply C by beta and accumulate.
                ymm2 = _mm256_loadu_ps(tC + 0);
                ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
@@ -1003,7 +1038,7 @@ static err_t bli_sgemm_small
                ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
                ymm2 = _mm256_loadu_ps(tC + 16);
                ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
-
+               }
                _mm256_storeu_ps(tC + 0, ymm12);
                _mm256_storeu_ps(tC + 8, ymm13);
                _mm256_storeu_ps(tC + 16, ymm14);
@@ -1056,7 +1091,6 @@ static err_t bli_sgemm_small
                }
                // alpha, beta multiplication.
                ymm0 = _mm256_broadcast_ss(alpha_cast);
-               ymm1 = _mm256_broadcast_ss(beta_cast);
 
                //multiply A*B by alpha.
                ymm4 = _mm256_mul_ps(ymm4, ymm0);
@@ -1066,29 +1100,35 @@ static err_t bli_sgemm_small
                ymm8 = _mm256_mul_ps(ymm8, ymm0);
                ymm9 = _mm256_mul_ps(ymm9, ymm0);
 
-               // multiply C by beta and accumulate.
-               ymm2 = _mm256_loadu_ps(tC);
-               ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
-               ymm2 = _mm256_loadu_ps(tC + 8);
-               ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
+               if(isbetanonzero)
+               {
+                   ymm1 = _mm256_broadcast_ss(beta_cast);
+                   // multiply C by beta and accumulate.
+                   ymm2 = _mm256_loadu_ps(tC);
+                   ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
+                   ymm2 = _mm256_loadu_ps(tC + 8);
+                   ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
+                   float* ttC = tC + ldc;
+                   ymm2 = _mm256_loadu_ps(ttC);
+                   ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
+                   ymm2 = _mm256_loadu_ps(ttC + 8);
+                   ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
+                   ttC += ldc;
+                   ymm2 = _mm256_loadu_ps(ttC);
+                   ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
+                   ymm2 = _mm256_loadu_ps(ttC + 8);
+                   ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
+               }
                _mm256_storeu_ps(tC, ymm4);
                _mm256_storeu_ps(tC + 8, ymm5);
 
                // multiply C by beta and accumulate.
                tC += ldc;
-               ymm2 = _mm256_loadu_ps(tC);
-               ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
-               ymm2 = _mm256_loadu_ps(tC + 8);
-               ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
                _mm256_storeu_ps(tC, ymm6);
                _mm256_storeu_ps(tC + 8, ymm7);
 
                // multiply C by beta and accumulate.
                tC += ldc;
-               ymm2 = _mm256_loadu_ps(tC);
-               ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
-               ymm2 = _mm256_loadu_ps(tC + 8);
-               ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
                _mm256_storeu_ps(tC, ymm8);
                _mm256_storeu_ps(tC + 8, ymm9);
@@ -1131,7 +1171,6 @@ static err_t bli_sgemm_small
                }
                // alpha, beta multiplication.
                ymm0 = _mm256_broadcast_ss(alpha_cast);
-               ymm1 = _mm256_broadcast_ss(beta_cast);
 
                //multiply A*B by alpha.
                ymm4 = _mm256_mul_ps(ymm4, ymm0);
@@ -1139,20 +1178,25 @@ static err_t bli_sgemm_small
                ymm5 = _mm256_mul_ps(ymm5, ymm0);
                ymm6 = _mm256_mul_ps(ymm6, ymm0);
                ymm7 = _mm256_mul_ps(ymm7, ymm0);
 
-               // multiply C by beta and accumulate.
-               ymm2 = _mm256_loadu_ps(tC);
-               ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
-               ymm2 = _mm256_loadu_ps(tC + 8);
-               ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
+               if(isbetanonzero)
+               {
+                   ymm1 = _mm256_broadcast_ss(beta_cast);
+                   // multiply C by beta and accumulate.
+                   ymm2 = _mm256_loadu_ps(tC);
+                   ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
+                   ymm2 = _mm256_loadu_ps(tC + 8);
+                   ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
+                   float* ttC = tC + ldc;
+                   ymm2 = _mm256_loadu_ps(ttC);
+                   ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
+                   ymm2 = _mm256_loadu_ps(ttC + 8);
+                   ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
+               }
                _mm256_storeu_ps(tC, ymm4);
                _mm256_storeu_ps(tC + 8, ymm5);
 
                // multiply C by beta and accumulate.
                tC += ldc;
-               ymm2 = _mm256_loadu_ps(tC);
-               ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
-               ymm2 = _mm256_loadu_ps(tC + 8);
-               ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
                _mm256_storeu_ps(tC, ymm6);
                _mm256_storeu_ps(tC + 8, ymm7);
@@ -1190,16 +1234,19 @@ static err_t bli_sgemm_small
                }
                // alpha, beta multiplication.
                ymm0 = _mm256_broadcast_ss(alpha_cast);
-               ymm1 = _mm256_broadcast_ss(beta_cast);
 
                ymm4 = _mm256_mul_ps(ymm4, ymm0);
                ymm5 = _mm256_mul_ps(ymm5, ymm0);
 
                // multiply C by beta and accumulate.
-               ymm2 = _mm256_loadu_ps(tC);
-               ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
-               ymm2 = _mm256_loadu_ps(tC + 8);
-               ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
+               if(isbetanonzero)
+               {
+                   ymm1 = _mm256_broadcast_ss(beta_cast);
+                   ymm2 = _mm256_loadu_ps(tC);
+                   ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
+                   ymm2 = _mm256_loadu_ps(tC + 8);
+                   ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
+               }
                _mm256_storeu_ps(tC, ymm4);
                _mm256_storeu_ps(tC + 8, ymm5);
@@ -1244,28 +1291,30 @@ static err_t bli_sgemm_small
                }
                // alpha, beta multiplication.
                ymm0 = _mm256_broadcast_ss(alpha_cast);
-               ymm1 = _mm256_broadcast_ss(beta_cast);
 
                //multiply A*B by alpha.
                ymm4 = _mm256_mul_ps(ymm4, ymm0);
                ymm5 = _mm256_mul_ps(ymm5, ymm0);
                ymm6 = _mm256_mul_ps(ymm6, ymm0);
 
-               // multiply C by beta and accumulate.
-               ymm2 = _mm256_loadu_ps(tC);
-               ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
+               if(isbetanonzero)
+               {
+                   ymm1 = _mm256_broadcast_ss(beta_cast);
+                   ymm2 = _mm256_loadu_ps(tC);
+                   ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
+                   ymm2 = _mm256_loadu_ps(tC + ldc);
+                   ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
+                   ymm2 = _mm256_loadu_ps(tC + 2*ldc);
+                   ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
+               }
                _mm256_storeu_ps(tC, ymm4);
 
                // multiply C by beta and accumulate.
                tC += ldc;
-               ymm2 = _mm256_loadu_ps(tC);
-               ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
                _mm256_storeu_ps(tC, ymm5);
 
                // multiply C by beta and accumulate.
                tC += ldc;
-               ymm2 = _mm256_loadu_ps(tC);
-               ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
                _mm256_storeu_ps(tC, ymm6);
            }
            n_remainder = N - col_idx;
@@ -1299,21 +1348,23 @@ static err_t bli_sgemm_small
                }
                // alpha, beta multiplication.
                ymm0 = _mm256_broadcast_ss(alpha_cast);
-               ymm1 = _mm256_broadcast_ss(beta_cast);
 
                //multiply A*B by alpha.
                ymm4 = _mm256_mul_ps(ymm4, ymm0);
                ymm5 = _mm256_mul_ps(ymm5, ymm0);
 
-               // multiply C by beta and accumulate.
-               ymm2 = _mm256_loadu_ps(tC);
-               ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
+               if(isbetanonzero)
+               {
+                   ymm1 = _mm256_broadcast_ss(beta_cast);
+                   // multiply C by beta and accumulate.
+                   ymm2 = _mm256_loadu_ps(tC);
+                   ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
+                   ymm2 = _mm256_loadu_ps(tC + ldc);
+                   ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
+               }
                _mm256_storeu_ps(tC, ymm4);
 
-               // multiply C by beta and accumulate.
                tC += ldc;
-               ymm2 = _mm256_loadu_ps(tC);
-               ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
                _mm256_storeu_ps(tC, ymm5);
 
                col_idx += 2;
@@ -1346,13 +1397,15 @@ static err_t bli_sgemm_small
                }
                // alpha, beta multiplication.
                ymm0 = _mm256_broadcast_ss(alpha_cast);
-               ymm1 = _mm256_broadcast_ss(beta_cast);
 
                ymm4 = _mm256_mul_ps(ymm4, ymm0);
 
-               // multiply C by beta and accumulate.
-               ymm2 = _mm256_loadu_ps(tC);
-               ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
+               if(isbetanonzero)
+               {
+                   ymm1 = _mm256_broadcast_ss(beta_cast);
+                   // multiply C by beta and accumulate.
+                   ymm2 = _mm256_loadu_ps(tC);
+                   ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
+               }
                _mm256_storeu_ps(tC, ymm4);
            }
@@ -1426,7 +1479,8 @@ static err_t bli_sgemm_small
                    f_temp[i] = tC[i];
                }
                ymm2 = _mm256_loadu_ps(f_temp);
-               ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
+               if(isbetanonzero)
+                   ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
                _mm256_storeu_ps(f_temp, ymm5);
                for (int i = 0; i < m_remainder; i++)
                {
@@ -1439,7 +1493,8 @@ static err_t bli_sgemm_small
                    f_temp[i] = tC[i];
                }
                ymm2 = _mm256_loadu_ps(f_temp);
-               ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
+               if(isbetanonzero)
+                   ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
                _mm256_storeu_ps(f_temp, ymm7);
                for (int i = 0; i < m_remainder; i++)
                {
@@ -1452,7 +1507,8 @@ static err_t bli_sgemm_small
                    f_temp[i] = tC[i];
                }
                ymm2 = _mm256_loadu_ps(f_temp);
-               ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
+               if(isbetanonzero)
+                   ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
                _mm256_storeu_ps(f_temp, ymm9);
                for (int i = 0; i < m_remainder; i++)
                {
@@ -1510,7 +1566,8 @@ static err_t bli_sgemm_small
                    f_temp[i] = tC[i];
                }
                ymm2 = _mm256_loadu_ps(f_temp);
-               ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
+               if(isbetanonzero)
+                   ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
                _mm256_storeu_ps(f_temp, ymm5);
                for (int i = 0; i < m_remainder; i++)
                {
@@ -1523,7 +1580,8 @@ static err_t bli_sgemm_small
                    f_temp[i] = tC[i];
                }
                ymm2 = _mm256_loadu_ps(f_temp);
-               ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
+               if(isbetanonzero)
+                   ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
                _mm256_storeu_ps(f_temp, ymm7);
                for (int i = 0; i < m_remainder; i++)
                {
@@ -1565,7 +1623,6 @@ static err_t bli_sgemm_small
                ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
 
                ymm0 = _mm256_broadcast_ss(alpha_cast);
-               ymm1 = _mm256_broadcast_ss(beta_cast);
 
                // multiply C by beta and accumulate.
                ymm5 = _mm256_mul_ps(ymm5, ymm0);
@@ -1575,7 +1632,10 @@ static err_t bli_sgemm_small
                    f_temp[i] = tC[i];
                }
                ymm2 = _mm256_loadu_ps(f_temp);
-               ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
+               if(isbetanonzero){
+                   ymm1 = _mm256_broadcast_ss(beta_cast);
+                   ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
+               }
                _mm256_storeu_ps(f_temp, ymm5);
                for (int i = 0; i < m_remainder; i++)
                {
@@ -1606,7 +1666,11 @@ static err_t bli_sgemm_small
                }
                result *= (*alpha_cast);
-               (*tC) = (*tC) * (*beta_cast) + result;
+               if(isbetanonzero){
+                   (*tC) = (*tC) * (*beta_cast) + result;
+               }else{
+                   (*tC) = result;
+               }
            }
        }
    }
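The scratch-buffer pattern above is worth one standalone illustration: when fewer than 8 rows remain, the kernel stages them through an 8-float temporary so the full-width load/FMA/store path can be reused without touching memory past the end of C. A sketch under the same assumptions (AVX2/FMA; illustrative names):

    #include <immintrin.h>

    /* acc holds alpha*(A*B) for the tail rows; m_remainder < 8. */
    static void tail_update_sketch(float *c, int m_remainder, __m256 acc,
                                   float beta, int beta_nonzero)
    {
        float f_temp[8] = { 0 };
        for (int i = 0; i < m_remainder; i++)   /* gather valid rows */
            f_temp[i] = c[i];
        if (beta_nonzero)
        {
            __m256 vbeta = _mm256_broadcast_ss(&beta);
            acc = _mm256_fmadd_ps(_mm256_loadu_ps(f_temp), vbeta, acc);
        }
        _mm256_storeu_ps(f_temp, acc);
        for (int i = 0; i < m_remainder; i++)   /* scatter back      */
            c[i] = f_temp[i];
    }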
@@ -3089,7 +3153,7 @@ static err_t bli_dgemm_small
    }
 
    // Return the buffer to pool
-   if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )) {
+   if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )) {
 #ifdef BLIS_ENABLE_MEM_TRACING
        printf( "bli_dgemm_small(): releasing mem pool block\n" );
 #endif
@@ -3139,6 +3203,13 @@ static err_t bli_sgemm_small_atbn
    alpha_cast = (alpha->buffer);
    beta_cast = (beta->buffer);
 
+   /*Beta Zero Check*/
+   guint_t isbetanonzero = 0;
+   if ( !bli_obj_equals( beta, &BLIS_ZERO ) ){
+       isbetanonzero = 1;
+   }
+
+
    // The non-copy version of the A^T GEMM gives better performance for the small M cases.
    // The threshold is controlled by BLIS_ATBN_M_THRES
    if (M <= BLIS_ATBN_M_THRES)
@@ -3250,28 +3321,44 @@ static err_t bli_sgemm_small_atbn
                _mm256_storeu_ps(scratch, ymm4);
                result = scratch[0] + scratch[4];
                result *= (*alpha_cast);
+               if(isbetanonzero){
                tC[0] = result + tC[0] * (*beta_cast);
+               }else{
+                   tC[0] = result;
+               }
 
                ymm7 = _mm256_hadd_ps(ymm7, ymm7);
                ymm7 = _mm256_hadd_ps(ymm7, ymm7);
                _mm256_storeu_ps(scratch, ymm7);
                result = scratch[0] + scratch[4];
                result *= (*alpha_cast);
+               if(isbetanonzero){
                tC[1] = result + tC[1] * (*beta_cast);
+               }else{
+                   tC[1] = result;
+               }
 
                ymm10 = _mm256_hadd_ps(ymm10, ymm10);
                ymm10 = _mm256_hadd_ps(ymm10, ymm10);
                _mm256_storeu_ps(scratch, ymm10);
                result = scratch[0] + scratch[4];
                result *= (*alpha_cast);
+               if(isbetanonzero){
                tC[2] = result + tC[2] * (*beta_cast);
+               }else{
+                   tC[2] = result;
+               }
 
                ymm13 = _mm256_hadd_ps(ymm13, ymm13);
                ymm13 = _mm256_hadd_ps(ymm13, ymm13);
                _mm256_storeu_ps(scratch, ymm13);
                result = scratch[0] + scratch[4];
                result *= (*alpha_cast);
+               if(isbetanonzero){
                tC[3] = result + tC[3] * (*beta_cast);
+               }else{
+                   tC[3] = result;
+               }
 
                tC += ldc;
                ymm5 = _mm256_hadd_ps(ymm5, ymm5);
@@ -3279,28 +3366,44 @@ static err_t bli_sgemm_small_atbn
                _mm256_storeu_ps(scratch, ymm5);
                result = scratch[0] + scratch[4];
                result *= (*alpha_cast);
+               if(isbetanonzero){
                tC[0] = result + tC[0] * (*beta_cast);
+               }else{
+                   tC[0] = result;
+               }
 
                ymm8 = _mm256_hadd_ps(ymm8, ymm8);
                ymm8 = _mm256_hadd_ps(ymm8, ymm8);
                _mm256_storeu_ps(scratch, ymm8);
                result = scratch[0] + scratch[4];
                result *= (*alpha_cast);
+               if(isbetanonzero){
                tC[1] = result + tC[1] * (*beta_cast);
+               }else{
+                   tC[1] = result;
+               }
 
                ymm11 = _mm256_hadd_ps(ymm11, ymm11);
                ymm11 = _mm256_hadd_ps(ymm11, ymm11);
                _mm256_storeu_ps(scratch, ymm11);
                result = scratch[0] + scratch[4];
                result *= (*alpha_cast);
+               if(isbetanonzero){
                tC[2] = result + tC[2] * (*beta_cast);
+               }else{
+                   tC[2] = result;
+               }
 
                ymm14 = _mm256_hadd_ps(ymm14, ymm14);
                ymm14 = _mm256_hadd_ps(ymm14, ymm14);
                _mm256_storeu_ps(scratch, ymm14);
                result = scratch[0] + scratch[4];
                result *= (*alpha_cast);
+               if(isbetanonzero){
                tC[3] = result + tC[3] * (*beta_cast);
+               }else{
+                   tC[3] = result;
+               }
 
                tC += ldc;
                ymm6 = _mm256_hadd_ps(ymm6, ymm6);
@@ -3308,28 +3411,44 @@ static err_t bli_sgemm_small_atbn
                _mm256_storeu_ps(scratch, ymm6);
                result = scratch[0] + scratch[4];
                result *= (*alpha_cast);
+               if(isbetanonzero){
                tC[0] = result + tC[0] * (*beta_cast);
+               }else{
+                   tC[0] = result;
+               }
 
                ymm9 = _mm256_hadd_ps(ymm9, ymm9);
                ymm9 = _mm256_hadd_ps(ymm9, ymm9);
                _mm256_storeu_ps(scratch, ymm9);
                result = scratch[0] + scratch[4];
                result *= (*alpha_cast);
+               if(isbetanonzero){
                tC[1] = result + tC[1] * (*beta_cast);
+               }else{
+                   tC[1] = result;
+               }
 
                ymm12 = _mm256_hadd_ps(ymm12, ymm12);
                ymm12 = _mm256_hadd_ps(ymm12, ymm12);
                _mm256_storeu_ps(scratch, ymm12);
                result = scratch[0] + scratch[4];
                result *= (*alpha_cast);
+               if(isbetanonzero){
                tC[2] = result + tC[2] * (*beta_cast);
+               }else{
+                   tC[2] = result;
+               }
 
                ymm15 = _mm256_hadd_ps(ymm15, ymm15);
                ymm15 = _mm256_hadd_ps(ymm15, ymm15);
                _mm256_storeu_ps(scratch, ymm15);
                result = scratch[0] + scratch[4];
                result *= (*alpha_cast);
+               if(isbetanonzero){
                tC[3] = result + tC[3] * (*beta_cast);
+               }else{
+                   tC[3] = result;
+               }
            }
        }
@@ -3413,29 +3532,44 @@ static err_t bli_sgemm_small_atbn
                _mm256_storeu_ps(scratch, ymm4);
                result = scratch[0] + scratch[4];
                result *= (*alpha_cast);
+               if(isbetanonzero){
                tC[0] = result + tC[0] * (*beta_cast);
+               }else{
+                   tC[0] = result;
+               }
 
                ymm7 = _mm256_hadd_ps(ymm7, ymm7);
                ymm7 = _mm256_hadd_ps(ymm7, ymm7);
                _mm256_storeu_ps(scratch, ymm7);
                result = scratch[0] + scratch[4];
                result *= (*alpha_cast);
+               if(isbetanonzero){
                tC[1] = result + tC[1] * (*beta_cast);
+               }else{
+                   tC[1] = result;
+               }
 
                ymm10 = _mm256_hadd_ps(ymm10, ymm10);
                ymm10 = _mm256_hadd_ps(ymm10, ymm10);
                _mm256_storeu_ps(scratch, ymm10);
                result = scratch[0] + scratch[4];
                result *= (*alpha_cast);
+               if(isbetanonzero){
                tC[2] = result + tC[2] * (*beta_cast);
+               }else{
+                   tC[2] = result;
+               }
 
                ymm13 = _mm256_hadd_ps(ymm13, ymm13);
                ymm13 = _mm256_hadd_ps(ymm13, ymm13);
                _mm256_storeu_ps(scratch, ymm13);
                result = scratch[0] + scratch[4];
                result *= (*alpha_cast);
+               if(isbetanonzero){
                tC[3] = result + tC[3] * (*beta_cast);
-
+               }else{
+                   tC[3] = result;
+               }
            }
        }
        processed_row = row_idx;
@@ -3485,7 +3619,11 @@ static err_t bli_sgemm_small_atbn
            _mm256_storeu_ps(scratch, ymm4);
            result = scratch[0] + scratch[4];
            result *= (*alpha_cast);
+           if(isbetanonzero){
            tC[0] = result + tC[0] * (*beta_cast);
+           }else{
+               tC[0] = result;
+           }
        }
    }
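A closing note on the reduction idiom used throughout these A^T kernels: _mm256_hadd_ps adds adjacent pairs within each 128-bit lane, so two rounds leave the sum of lane 0 in element 0 and the sum of lane 1 in element 4, which is why the code computes result = scratch[0] + scratch[4]. As a standalone sketch (assumes AVX; not the kernel's exact code):

    #include <immintrin.h>

    /* Horizontal sum of all 8 floats in v. */
    static float hsum256_sketch(__m256 v)
    {
        float scratch[8];
        v = _mm256_hadd_ps(v, v); /* [a0+a1, a2+a3, a0+a1, a2+a3,
                                     a4+a5, a6+a7, a4+a5, a6+a7] */
        v = _mm256_hadd_ps(v, v); /* elem 0 = a0+a1+a2+a3,
                                     elem 4 = a4+a5+a6+a7 */
        _mm256_storeu_ps(scratch, v);
        return scratch[0] + scratch[4];
    }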