diff --git a/frame/3/trsm/bli_trsm_front.c b/frame/3/trsm/bli_trsm_front.c index ff6264fa7..0c3bb11d2 100644 --- a/frame/3/trsm/bli_trsm_front.c +++ b/frame/3/trsm/bli_trsm_front.c @@ -55,50 +55,6 @@ void bli_trsm_front obj_t b_local; obj_t c_local; -#ifdef PRINT_SMALL_TRSM_INFO - printf("Side:: %c\n", side ? 'R' : 'L'); - if (bli_obj_datatype(*a) == BLIS_FLOAT) - printf("Alpha:: %9.2e\n", *((float *)bli_obj_buffer_for_const(BLIS_FLOAT, *alpha))); - else if (bli_obj_datatype(*a) == BLIS_DOUBLE) - printf("Alpha is double:: %9.2e\n", *((double *)bli_obj_buffer_for_const(BLIS_DOUBLE, *alpha))); - else - printf("Unsupported datatype for Alpha\n"); - - printf("A:: M = %d, N = %d, elem_size = %d, row_off = %ld, col_off = %ld, rs = %d, cs = %d, trans = %c, TRIANG = %c, unit diag = %c\n", a->dim[0], a->dim[1], bli_obj_elem_size(*a ), bli_obj_row_off(*a), bli_obj_col_off(*a), a->rs, a->cs, bli_obj_has_trans(*a) ? 'Y' : 'N', bli_obj_is_upper(*a) ? 'U' : bli_obj_is_lower(*a) ? 'L' : 'N', bli_obj_has_unit_diag(*a) ? 'Y' : 'N'); -#ifdef PRINT_SMALL_TRSM - //bli_printm("a", a, "%4.1f", ""); -#endif - printf("B:: M = %d, N = %d, elem_size = %d, row_off = %ld, col_off = %ld, rs = %d, cs = %d, trans = %c\n", b->dim[0], b->dim[1], bli_obj_elem_size(*a ), bli_obj_row_off(*a), bli_obj_col_off(*a), b->rs, b->cs, bli_obj_has_trans(*b) ? 'Y' : 'N'); -#ifdef PRINT_SMALL_TRSM - //bli_printm("b", b, "%4.1f", ""); -#endif - fflush(stdout); -#endif -#if 0 -for (i = 0; i < m; i++) //no. of cols of B -{ - for (j = 0; j < n; j++) //no. of rows of B - { - B[i*n + j] = 1001 + j + (i*n); - } -} -for (i = 0; i < m; i++) //no. of cols of B -{ - for (j = i; j < m; j++) //no. of rows of B - { - L[i*m + j] = 2001 + j + (i*m); - } -} -#endif -#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM - gint_t status = bli_trsm_small( side, alpha, a, b, cntx, cntl ); - if ( status == BLIS_SUCCESS ) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); - return; - } -#endif - // Check parameters. 
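[Editor's note] The hunk above deletes the PRINT_SMALL_TRSM_INFO debug scaffolding, the dead `#if 0` initialization loops, and the early bli_trsm_small() dispatch from the operation front end; the equivalent dispatch reappears below in the BLAS compatibility layer (the new dtrsm_ in bla_trsm.c). A minimal sketch of that control flow, using only names that appear in this patch; the wrapper function itself is hypothetical and the error checks are elided:

```c
#include "blis.h"

/* Hypothetical helper (illustration only, not part of the patch): try the
 * small-matrix TRSM kernel first and fall back to the native path when it
 * declines, mirroring the dispatch this patch moves into dtrsm_(). */
static void try_small_then_native( side_t side, obj_t* alpha,
                                   obj_t* a, obj_t* b )
{
#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM
    /* bli_trsm_small() returns BLIS_SUCCESS only when it handled the
     * problem; any other status means "use the native path instead". */
    if ( bli_trsm_small( side, alpha, a, b, NULL, NULL ) == BLIS_SUCCESS )
        return;
#endif
    bli_trsmnat( side, alpha, a, b, NULL, NULL );
}
```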
if ( bli_error_checking_is_enabled() ) bli_trsm_check( side, alpha, a, b, &BLIS_ZERO, b, cntx ); diff --git a/frame/compat/bla_trsm.c b/frame/compat/bla_trsm.c index fff59a351..f3edca62a 100644 --- a/frame/compat/bla_trsm.c +++ b/frame/compat/bla_trsm.c @@ -229,6 +229,7 @@ void PASTEF77(ch,blasname) \ (ftype*)b, rs_b, \ NULL \ ); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ return; \ } \ else if(bli_is_trans(blis_transa)) \ @@ -244,6 +245,7 @@ void PASTEF77(ch,blasname) \ (ftype*)b, rs_b, \ NULL \ ); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ return; \ } \ } \ @@ -268,6 +270,7 @@ void PASTEF77(ch,blasname) \ PASTEMAC(ch,invscals)( a_conj, b[indx] ); \ } \ }\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ return; \ } \ } \ @@ -290,6 +293,7 @@ void PASTEF77(ch,blasname) \ (ftype*)a, cs_a, rs_a, \ (ftype*)b, cs_b, \ NULL); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ return; \ } \ else if(bli_is_trans(blis_transa)) \ @@ -307,6 +311,7 @@ void PASTEF77(ch,blasname) \ (ftype*)a, cs_a, rs_a, \ (ftype*)b, cs_b, \ NULL); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ return; \ } \ } \ @@ -331,6 +336,7 @@ void PASTEF77(ch,blasname) \ PASTEMAC(ch,invscals)( a_conj, b[indx*cs_b] ); \ }\ } \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ return; \ } \ } \ @@ -374,6 +380,265 @@ void PASTEF77(ch,blasname) \ #endif #ifdef BLIS_ENABLE_BLAS -INSERT_GENTFUNC_BLAS( trsm, trsm ) +#ifdef BLIS_CONFIG_EPYC +void dtrsm_ +( + const f77_char* side, + const f77_char* uploa, + const f77_char* transa, + const f77_char* diaga, + const f77_int* m, + const f77_int* n, + const double* alpha, + const double* a, const f77_int* lda, + double* b, const f77_int* ldb +) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) + AOCL_DTL_LOG_TRSM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(d), + *side, *uploa,*transa, *diaga, *m, *n, + (void*)alpha,*lda, *ldb); + + side_t blis_side; + uplo_t blis_uploa; + trans_t blis_transa; + diag_t blis_diaga; + dim_t m0, n0; + conj_t conja = BLIS_NO_CONJUGATE ; + + /* Initialize BLIS. */ + bli_init_auto(); + + /* Perform BLAS parameter checking. */ + PASTEBLACHK(trsm) + ( + MKSTR(d), + MKSTR(trsm), + side, + uploa, + transa, + diaga, + m, + n, + lda, + ldb + ); + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_side( *side, &blis_side ); + bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); + bli_param_map_netlib_to_blis_diag( *diaga, &blis_diaga ); + + /* Typecast BLAS integers to BLIS integers. */ + bli_convert_blas_dim1( *m, m0 ); + bli_convert_blas_dim1( *n, n0 ); + + /* Set the row and column strides of the matrix operands. 
*/ + const inc_t rs_a = 1; + const inc_t cs_a = *lda; + const inc_t rs_b = 1; + const inc_t cs_b = *ldb; + const num_t dt = BLIS_DOUBLE; + + if( n0 == 1 ) + { + if( blis_side == BLIS_LEFT ) + { + if(bli_is_notrans(blis_transa)) + { + bli_dtrsv_unf_var2 + ( + blis_uploa, + blis_transa, + blis_diaga, + m0, + (double*)alpha, + (double*)a, rs_a, cs_a, + (double*)b, rs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + else if(bli_is_trans(blis_transa)) + { + bli_dtrsv_unf_var1 + ( + blis_uploa, + blis_transa, + blis_diaga, + m0, + (double*)alpha, + (double*)a, rs_a, cs_a, + (double*)b, rs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + else if( ( blis_side == BLIS_RIGHT ) && ( m0 != 1 ) ) + { + /* b = alpha * b; */ + bli_dscalv_ex + ( + conja, + m0, + (double*)alpha, + b, rs_b, + NULL, + NULL + ); + if(blis_diaga == BLIS_NONUNIT_DIAG) + { + double inva = 1.0/ *a; + for(int indx = 0; indx < m0; indx ++) + { + b[indx] = ( inva * b[indx] ); + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + else if( m0 == 1 ) + { + if(blis_side == BLIS_RIGHT) + { + if(bli_is_notrans(blis_transa)) + { + if(blis_uploa == BLIS_UPPER) + blis_uploa = BLIS_LOWER; + else + blis_uploa = BLIS_UPPER; + + bli_dtrsv_unf_var1 + ( + blis_uploa, + blis_transa, + blis_diaga, + n0, + (double*)alpha, + (double*)a, cs_a, rs_a, + (double*)b, cs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + else if(bli_is_trans(blis_transa)) + { + if(blis_uploa == BLIS_UPPER) + blis_uploa = BLIS_LOWER; + else + blis_uploa = BLIS_UPPER; + + bli_dtrsv_unf_var2 + ( + blis_uploa, + blis_transa, + blis_diaga, + n0, + (double*)alpha, + (double*)a, cs_a, rs_a, + (double*)b, cs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + else if(( blis_side == BLIS_LEFT ) && ( n0 != 1 )) + { + /* b = alpha * b; */ + bli_dscalv_ex + ( + conja, + n0, + (double*)alpha, + b, cs_b, + NULL, + NULL + ); + if(blis_diaga == BLIS_NONUNIT_DIAG) + { + double inva = 1.0/ *a; + for(int indx = 0; indx < n0; indx ++) + { + b[indx*cs_b] = (inva * b[indx*cs_b] ); + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + + const struc_t struca = BLIS_TRIANGULAR; + + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t ao = BLIS_OBJECT_INITIALIZER; + obj_t bo = BLIS_OBJECT_INITIALIZER; + + dim_t mn0_a; + + bli_set_dim_with_side( blis_side, m0, n0, &mn0_a ); + + bli_obj_init_finish_1x1( dt, (double*)alpha, &alphao ); + + bli_obj_init_finish( dt, mn0_a, mn0_a, (double*)a, rs_a, cs_a, &ao ); + bli_obj_init_finish( dt, m0, n0, (double*)b, rs_b, cs_b, &bo ); + + bli_obj_set_uplo( blis_uploa, &ao ); + bli_obj_set_diag( blis_diaga, &ao ); + bli_obj_set_conjtrans( blis_transa, &ao ); + + bli_obj_set_struc( struca, &ao ); + +#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM + /* Irrespective of num threads single thread bli_dtrsm_small + * is performing better than other implementations for [m,n]<=128 */ + /* ToDo: This condition will be tunned for single thread */ + if(m0 <=128 && n0<=128) + { + err_t status; + status = bli_trsm_small + ( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL + ); + if (status == BLIS_SUCCESS) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + } #endif + bli_trsmnat + ( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL + ); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) + /* Finalize BLIS. 
*/ + bli_finalize_auto(); +} + +GENTFUNC( float, s, trsm, trsm ) +INSERT_GENTFUNC_BLAS_CZ( trsm, trsm ) +#else +INSERT_GENTFUNC_BLAS( trsm, trsm ) +#endif +#endif diff --git a/kernels/zen/3/bli_trsm_small.c b/kernels/zen/3/bli_trsm_small.c index c6ea0d12b..195819647 100644 --- a/kernels/zen/3/bli_trsm_small.c +++ b/kernels/zen/3/bli_trsm_small.c @@ -1,703 +1,387 @@ /* -BLIS -An object-based framework for developing high-performance BLAS-like -libraries. + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. -Copyright (C) 2018-2019, Advanced Micro Devices, Inc. + Copyright (C) 2018-2021, Advanced Micro Devices, Inc. All rights reserved. -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: -- Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. -- Redistributions in binary form must reproduce the above copyright -notice, this list of conditions and the following disclaimer in the -documentation and/or other materials provided with the distribution. -- Neither the name of The University of Texas at Austin nor the names -of its contributors may be used to endorse or promote products -derived from this software without specific prior written permission. + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ #include "blis.h" #ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM #include "immintrin.h" -#define GEMM_BLK_V1 8 //Block size to perform gemm and apply trsm -#define GEMM_ACCUM_A 1 //Peform B1=B1-(B0*A0) operation instead of B1'=(B0*A0) and then B1=B1-B1' -#define OPT_CACHE_BLOCKING_L1 1 //Perform trsm block-wise in blocks of GEMM_BLK_V1 instead of all columns of B together. -#define REARRANGE_SHFL 0 //Rearrange operations using blend or shuffle -#define BLI_AlXB_M_SP 16 -#define BLI_XAltB_N_SP 128 -#define BLI_AutXB_M_SP 64 -#define BLI_AutXB_N_SP 128 -// XA = B; A is lower-traingular; No transpose; double precision; non-unit diagonal -static err_t bli_dtrsm_small_XAlB( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ); +#define BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL -//XA = B; A is lower triabgular; No transpose; double precision; unit-diagonal -static err_t bli_dtrsm_small_XAlB_unitDiag( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ); +/* + trsm kernels function pointer +*/ +typedef err_t (*trsmsmall_ker_ft) +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +); -//XA = B; A is lower-triangular; A is transposed; double precision; non-unit-diagonal -static err_t bli_dtrsm_small_XAltB( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ); +//AX = B; A is lower triangular; No transpose; +//double precision; non-unit diagonal +static err_t bli_dtrsm_small_AlXB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +); -//XA = B; A is lower-triangular; A is transposed; double precision; unit-diagonal -static err_t bli_dtrsm_small_XAltB_unitDiag( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ); +/* TRSM for the case AX = alpha * B, Double precision + * A is upper-triangular, non-transpose, non-unit diagonal + * dimensions A: mxm X: mxn B: mxn +*/ +static err_t bli_dtrsm_small_AuXB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +); -// XA = B; A is upper triangular; No transpose; double presicion; non-unit diagonal -static err_t bli_dtrsm_small_XAuB - ( - side_t side, - obj_t* alpha, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ); +//AX = B; A is lower triangular; transpose; +//double precision; non-unit diagonal +static err_t dtrsm_AltXB_ref +( + double *A, + double *B, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb, + bool is_unitdiag +); -//XA = B; A is upper triangular; No transpose; double precision; unit-diagonal -static err_t bli_dtrsm_small_XAuB_unitDiag( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ); - -//XA = B; A is upper-triangular; A is transposed; double precision; non-unit diagonal -static err_t bli_dtrsm_small_XAutB( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ); - -//XA = B; A is upper-triangular; A is transposed; 
double precision; unit diagonal -static err_t bli_dtrsm_small_XAutB_unitDiag( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ); - -//AX = B; A is lower triangular; No transpose; double precision; non-unit diagonal -static err_t bli_dtrsm_small_AlXB( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ); - -//AX = B; A is lower triangular; No transpose; double precision; unit diagonal -static err_t bli_dtrsm_small_AlXB_unitDiag( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ); - - - -static void (*fp_blis_strsm_microkernel)( float *ptr_l, - float *ptr_b, - int numRows_lb, - int numCols_b, - int rs_l, - int rs_b, - int cs_l, - int cs_b - ); -static void blis_strsm_microkernel( float *ptr_l, - float *ptr_b, - int numRows_lb, - int numCols_b, - int rs_l, - int rs_b, - int cs_l, - int cs_b - ); -static void blis_strsm_microkernel_alpha( float *ptr_l, - float *ptr_b, - int numRows_lb, - int numCols_b, - int rs_l, - int rs_b, - int cs_l, - int cs_b, - float alphaVal - ); -static void blis_strsm_microkernel_unitDiag( float *ptr_l, - float *ptr_b, - int numRows_lb, - int numCols_b, - int rs_l, - int rs_b, - int cs_l, - int cs_b - ); -static void blis_strsm_microkernel_alpha_unitDiag( float *ptr_l, - float *ptr_b, - int numRows_lb, - int numCols_b, - int rs_l, - int rs_b, - int cs_l, - int cs_b, - float alphaVal - ); -static void trsm_XAtB_block_allSmallSizedMatrices(float *ptr_l, - float *ptr_b, - int numRows_lb, - int numCols_b, - int rs_l, - int rs_b, - int cs_l, - int cs_b); -static void trsm_XAtB_block_allSmallSizedMatrices_alpha(float *ptr_l, - float *ptr_b, - int numRows_lb, - int numCols_b, - int rs_l, - int rs_b, - int cs_l, - int cs_b, - float alphaVal); -static void trsm_XAtB_block_allSmallSizedMatrices_unitDiag(float *ptr_l, - float *ptr_b, - int numRows_lb, - int numCols_b, - int rs_l, - int rs_b, - int cs_l, - int cs_b); -static void trsm_XAtB_block_allSmallSizedMatrices_alpha_unitDiag(float *ptr_l, - float *ptr_b, - int numRows_lb, - int numCols_b, - int rs_l, - int rs_b, - int cs_l, - int cs_b, - float alphaVal); - - -static void blis_dtrsm_microkernel( double *ptr_l, - double *ptr_b, - int numRows_lb, - int numCols_b, - int rs_l, - int rs_b, - int cs_l, - int cs_b - ); - -static void blis_dtrsm_microkernel_alpha( double *ptr_l, - double *ptr_b, - int numRows_lb, - int numCols_b, - int rs_l, - int rs_b, - int cs_l, - int cs_b, - double alphaVal - ); - -static void blis_dtrsm_microkernel_unitDiag( double *ptr_l, - double *ptr_b, - int numRows_lb, - int numCols_b, - int rs_l, - int rs_b, - int cs_l, - int cs_b - ); - -static void blis_dtrsm_microkernel_alpha_unitDiag( double *ptr_l, - double *ptr_b, - int numRows_lb, - int numCols_b, - int rs_l, - int rs_b, - int cs_l, - int cs_b, - double alphaVal - ); - -static void dtrsm_XAtB_block_allSmallSizedMatrices(double *ptr_l, - double *ptr_b, - int numRows_lb, - int numCols_b, - int rs_l, - int rs_b, - int cs_l, - int cs_b); -static void dtrsm_XAtB_block_allSmallSizedMatrices_alpha(double *ptr_l, - double *ptr_b, - int numRows_lb, - int numCols_b, - int rs_l, - int rs_b, - int cs_l, - int cs_b, - double alphaVal); -static void dtrsm_XAtB_block_allSmallSizedMatrices_unitDiag(double *ptr_l, - double *ptr_b, - int numRows_lb, - int numCols_b, - int rs_l, - int rs_b, - int cs_l, - int cs_b); -static void dtrsm_XAtB_block_allSmallSizedMatrices_alpha_unitDiag(double *ptr_l, - double *ptr_b, - int numRows_lb, - int 
numCols_b, - int rs_l, - int rs_b, - int cs_l, - int cs_b, - double alphaVal); -static void trsm_AutXB_block_allSmallSizedMatrices(float *ptr_l, - float *ptr_b, - int numRows_lb, - int numCols_b, - int rs_l, - int rs_b, - int cs_l, - int cs_b); -static void trsm_AutXB_block_allSmallSizedMatrices_alpha(float *ptr_l, - float *ptr_b, - int numRows_lb, - int numCols_b, - int rs_l, - int rs_b, - int cs_l, - int cs_b, - float alpha); -static void trsm_AutXB_block_allSmallSizedMatrices_unitDiag(float *ptr_l, - float *ptr_b, - int numRows_lb, - int numCols_b, - int rs_l, - int rs_b, - int cs_l, - int cs_b); -static void trsm_AutXB_block_allSmallSizedMatrices_alpha_unitDiag(float *ptr_l, - float *ptr_b, - int numRows_lb, - int numCols_b, - int rs_l, - int rs_b, - int cs_l, - int cs_b, - float alpha); - -//AX = B; A is lower triangular; No transpose; single precision -static err_t bli_strsm_small_AlXB - ( - side_t side, - obj_t* alpha, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ); -//A.'X = B; A is upper triangular; A has to be transposed; single precision -static err_t bli_strsm_small_AutXB - ( - side_t side, - obj_t* alpha, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ); - -//XA.' = B; A is lower triangular; A has to be transposed; single precision -static err_t bli_strsm_small_XAltB - ( - side_t side, - obj_t* alpha, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ); - -//A.'X = B; A is upper triangular; A has to be transposed; double precision +//A.'X = B; A is upper triangular; +//A has to be transposed; double precision static err_t bli_dtrsm_small_AutXB - ( - side_t side, - obj_t* alpha, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ); +( + obj_t* alpha, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +); + +//AX = B; A is lower triangular; transpose; double precision +static err_t bli_dtrsm_small_AltXB +( + obj_t* alpha, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +); /* -* The bli_trsm_small implements unpacked version of TRSM -* Currently only column-major is supported, A & B are column-major -* Input: A: MxM (triangular matrix) -* B: MxN matrix -* Output: X: MxN matrix such that AX = alpha*B or XA = alpha*B or A'X = alpha*B or XA' = alpha*B -* Here the output X is stored in B -* The custom-kernel will be called only when M*(M+N)* sizeof(Matrix Elements) < L3 cache + * Reference implementations + * ToDo: We can combine all these reference implementation + into a macro */ -err_t bli_trsm_small - ( - side_t side, - obj_t* alpha, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ) +//A'X = B; A is upper triangular; transpose; +//non-unitDiagonal double precision +static err_t dtrsm_AutXB_ref +( + double *A, + double *B, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb, + bool unitDiagonal +) { -#ifdef BLIS_ENABLE_MULTITHREADING - return BLIS_NOT_YET_IMPLEMENTED; -#endif - - dim_t m = bli_obj_length(b); - dim_t n = bli_obj_width(b); - - if(!(m && n)) - return BLIS_SUCCESS; - - - // If alpha is zero, B matrix will become zero after scaling & hence solution is also zero matrix - if (bli_obj_equals(alpha, &BLIS_ZERO)) + dim_t i, j, k; + for (k = 0; k < M; k++) { - return BLIS_NOT_YET_IMPLEMENTED; // scale B by alpha - } - // We have to call matrix scaling if alpha != 1.0 - - // if row major format return. Check this again. 
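[Editor's note] The dtrsm_*_ref() routines introduced in this file are plain scalar substitution kernels that share one pattern: for each pivot k, scale the pivot entry of B by 1/A[k][k] (or by 1.0 for a unit diagonal), then eliminate it from the remaining entries. A self-contained sketch of the AlXB shape (lower-triangular A, no transpose, column-major, forward substitution); `alxb_ref` and the `dim_t` stand-in are local to this example, not part of the patch:

```c
#include <stdio.h>
#include <stdbool.h>

typedef long dim_t;   /* local stand-in for the BLIS typedef */

/* Same shape as dtrsm_AlXB_ref in this diff: solve A*X = B in place,
 * A lower-triangular, column-major, no transpose. */
static void alxb_ref( const double *A, double *B, dim_t M, dim_t N,
                      dim_t lda, dim_t ldb, bool is_unitdiag )
{
    for ( dim_t k = 0; k < M; k++ )
    {
        /* Hoist the reciprocal of the pivot out of the inner loops. */
        double lkk_inv = is_unitdiag ? 1.0 : 1.0 / A[k + k*lda];
        for ( dim_t j = 0; j < N; j++ )
        {
            B[k + j*ldb] *= lkk_inv;
            for ( dim_t i = k + 1; i < M; i++ )
                B[i + j*ldb] -= A[i + k*lda] * B[k + j*ldb];
        }
    }
}

int main( void )
{
    /* A = [ 2 0 ; 1 4 ] stored column-major, B = one RHS column [2, 9]^T. */
    double A[4] = { 2.0, 1.0, 0.0, 4.0 };
    double B[2] = { 2.0, 9.0 };
    alxb_ref( A, B, 2, 1, 2, 2, false );
    printf( "x = [%g, %g]\n", B[0], B[1] );  /* prints x = [1, 2] */
    return 0;
}
```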
- if ((bli_obj_row_stride(a) != 1) || - (bli_obj_row_stride(b) != 1)) + double lkk_inv = 1.0; + if(!unitDiagonal) lkk_inv = 1.0/A[k+k*lda]; + for (j = 0; j < N; j++) + { + B[k + j*ldb] *= lkk_inv; + for (i = k+1; i < M; i++) + { + B[i + j*ldb] -= A[i*lda + k] * B[k + j*ldb]; + } + } + }// k -loop + return BLIS_SUCCESS; +}// end of function + +/* TRSM scalar code for the case AX = alpha * B + * A is upper-triangular, non-unit-diagonal + * Dimensions: A: mxm X: mxn B:mxn + */ +static err_t dtrsm_AuXB_ref +( + double *A, + double *B, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb, + bool is_unitdiag +) +{ + dim_t i, j, k; + for (k = M-1; k >= 0; k--) { - return BLIS_INVALID_ROW_STRIDE; - } - - num_t dt = ((*b).info & (0x7 << 0)); - - // only float and double datatypes are supported as of now. - if (dt != BLIS_DOUBLE && dt != BLIS_FLOAT) - { - return BLIS_EXPECTED_REAL_DATATYPE; - } - - // A is expected to be triangular in trsm - if (!bli_obj_is_upper_or_lower (a)) - { - return BLIS_EXPECTED_TRIANGULAR_OBJECT; - } - - // can use other control structs - even can use array of function pointers, - // indexed by a number with bits formed by f('side', 'uplo', 'transa', dt). - // In the below implementation, based on the number of finally implemented - // cases, can move the checks with more cases higher up. - - if(side == BLIS_LEFT) - { - if(bli_obj_has_trans(a)) - { - if(dt == BLIS_DOUBLE) - { - if(bli_obj_is_upper(a)) - { - //return bli_dtrsm_small_AutXB(side, alpha, a, b, cntx, cntl); - return BLIS_NOT_YET_IMPLEMENTED; - } - else - { - //return bli_dtrsm_small_AltXB(side, alpha, a, b, cntx, cntl); - return BLIS_NOT_YET_IMPLEMENTED; - } - } - else - { - if(bli_obj_is_upper(a)) - { - return bli_strsm_small_AutXB(side, alpha, a, b, cntx, cntl); - } - else - { - //return bli_strsm_small_AltXB(side, alpha, a, b, cntx, cntl); - return BLIS_NOT_YET_IMPLEMENTED; - } - - } - } - else - { - if(dt == BLIS_DOUBLE) - { - if(bli_obj_is_upper(a)) - { - //return bli_dtrsm_small_AuXB(side, alpha, a, b, cntx, cntl); - return BLIS_NOT_YET_IMPLEMENTED; - } - else - { - if(bli_obj_has_unit_diag(a)) - return bli_dtrsm_small_AlXB_unitDiag(side, alpha, a, b, cntx, cntl); - else - return bli_dtrsm_small_AlXB(side, alpha, a, b, cntx, cntl); - } - } - else - { - if(bli_obj_is_upper(a)) - { - //return bli_strsm_small_AuXB(side, alpha, a, b, cntx, cntl); - return BLIS_NOT_YET_IMPLEMENTED; - } - else - { - return bli_strsm_small_AlXB(side, alpha, a, b, cntx, cntl); - } - - } - - } - } - else - { - if(bli_obj_has_trans(a)) - { - if(dt == BLIS_DOUBLE) - { - if(bli_obj_is_upper(a)) - { - if(bli_obj_has_unit_diag(a)) - return bli_dtrsm_small_XAutB_unitDiag(side, alpha, a, b, cntx, cntl); - else - return bli_dtrsm_small_XAutB(side, alpha, a, b, cntx, cntl); - } - else - { - if(bli_obj_has_unit_diag(a)) - return bli_dtrsm_small_XAltB_unitDiag(side, alpha, a, b, cntx, cntl); - else - return bli_dtrsm_small_XAltB(side, alpha, a, b, cntx, cntl); - } - } - else - { - if(bli_obj_is_upper(a)) - { - //return bli_strsm_small_XAutB(side, alpha, a, b, cntx, cntl); - return BLIS_NOT_YET_IMPLEMENTED; - } - else - { - return bli_strsm_small_XAltB(side, alpha, a, b, cntx, cntl); - } - - } - } - else - { - if(dt == BLIS_DOUBLE) - { - if(bli_obj_is_upper(a)) - { - if(bli_obj_has_unit_diag(a)) - return bli_dtrsm_small_XAuB_unitDiag(side, alpha, a, b, cntx, cntl); - else - return bli_dtrsm_small_XAuB(side, alpha, a, b, cntx, cntl); - } - else - { - if(bli_obj_has_unit_diag(a)) - return bli_dtrsm_small_XAlB_unitDiag(side, alpha, a, b, cntx, cntl); - 
else - return bli_dtrsm_small_XAlB(side, alpha, a, b, cntx, cntl); - } - } - else - { - if(bli_obj_is_upper(a)) - { - //return bli_strsm_small_XAuB(side, alpha, a, b, cntx, cntl); - return BLIS_NOT_YET_IMPLEMENTED; - } - else - { - //return bli_strsm_small_XAlB(side, alpha, a, b, cntx, cntl); - return BLIS_NOT_YET_IMPLEMENTED; - } - - } - - } - } - return BLIS_NOT_YET_IMPLEMENTED; -}; + double lkk_inv = 1.0; + if(!is_unitdiag) lkk_inv = 1.0/A[k+k*lda]; + for (j = N -1; j >= 0; j--) + { + B[k + j*ldb] *= lkk_inv; + for (i = k-1; i >=0; i--) + { + B[i + j*ldb] -= A[i + k*lda] * B[k + j*ldb]; + } + } + }// k -loop + return BLIS_SUCCESS; +}// end of function /* TRSM scalar code for the case AX = alpha * B * A is lower-triangular, non-unit-diagonal, no transpose * Dimensions: A: mxm X: mxn B:mxn */ - -static err_t dtrsm_small_AlXB ( - double *A, - double *B, - dim_t M, - dim_t N, - dim_t lda, - dim_t ldb - ) +static err_t dtrsm_AlXB_ref +( + double *A, + double *B, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb, + bool is_unitdiag +) { - - dim_t i, j, k; - - for (k = 0; k < M; k++) - { - double lkk_inv = 1.0/A[k+k*lda]; - for (j = 0; j < N; j++) + dim_t i, j, k; + for (k = 0; k < M; k++) { - B[k + j*ldb] *= lkk_inv; - for (i = k+1; i < M; i++) + double lkk_inv = 1.0; + if(!is_unitdiag) lkk_inv = 1.0/A[k+k*lda]; + for (j = 0; j < N; j++) { - B[i + j*ldb] -= A[i + k*lda] * B[k + j*ldb]; + B[k + j*ldb] *= lkk_inv; + for (i = k+1; i < M; i++) + { + B[i + j*ldb] -= A[i + k*lda] * B[k + j*ldb]; + } } - } - }// k -loop - return BLIS_SUCCESS; + }// k -loop + return BLIS_SUCCESS; }// end of function /* TRSM scalar code for the case AX = alpha * B - * A is lower-triangular, unit-diagonal, no transpose + * A is lower-triangular, non-unit-diagonal, transpose * Dimensions: A: mxm X: mxn B:mxn */ - -static err_t dtrsm_small_AlXB_unitDiag ( - double *A, - double *B, - dim_t M, - dim_t N, - dim_t lda, - dim_t ldb - ) +static err_t dtrsm_AltXB_ref +( + double *A, + double *B, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb, + bool is_unitdiag +) { - - dim_t i, j, k; - - for (k = 0; k < M; k++) - { - for (j = 0; j < N; j++) - { - for (i = k+1; i < M; i++) + dim_t i, j, k; + for (k = M-1; k >= 0; k--) + { + double lkk_inv = 1.0; + if(!is_unitdiag) lkk_inv = 1.0/A[k+k*lda]; + for (j = N -1; j >= 0; j--) { - B[i + j*ldb] -= A[i + k*lda] * B[k + j*ldb]; + B[k + j*ldb] *= lkk_inv; + for (i = k-1; i >=0; i--) + { + B[i + j*ldb] -= A[i*lda + k] * B[k + j*ldb]; + } } - } - } - return BLIS_SUCCESS; + }// k -loop + return BLIS_SUCCESS; }// end of function +// XA = B; A is lower-traingular; No transpose; +//double precision; non-unit diagonal +static err_t bli_dtrsm_small_XAlB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +); + +//XA = B; A is lower triabgular; No transpose; +//double precision; unit-diagonal +static err_t bli_dtrsm_small_XAlB_unitDiag +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +); + +//XA' = B; A is lower-triangular; A is transposed; +// double precision; non-unit-diagonal +static err_t bli_dtrsm_small_XAltB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +); + +//XA' = B; A is lower-triangular; A is transposed; +//double precision; unit-diagonal +static err_t bli_dtrsm_small_XAltB_unitDiag +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +); + +// XA = B; A is upper triangular; No transpose; +//double presicion; non-unit diagonal +static err_t bli_dtrsm_small_XAuB +( + obj_t* alpha, + 
obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +); + +//XA = B; A is upper triangular; No transpose; +//double precision; unit-diagonal +static err_t bli_dtrsm_XAuB_unitDiag_ref +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +); + +//XA' = B; A is upper-triangular; A is transposed; +//double precision; non-unit diagonal +static err_t bli_dtrsm_small_XAutB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +); + +//XA' = B; A is upper-triangular; A is transposed; +//double precision; unit diagonal +static err_t bli_dtrsm_small_XAutB_unitDiag +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +); + + /* TRSM scalar code for the case XA = alpha * B * A is upper-triangular, non-unit-diagonal no transpose * Dimensions: X:mxn A:nxn B:mxn */ -static err_t dtrsm_small_XAuB ( - double *A, - double *B, - dim_t M, - dim_t N, - dim_t lda, - dim_t ldb +static err_t dtrsm_XAuB_ref +( + double *A, + double *B, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb ) { - - dim_t i, j, k; - for(k = 0; k < N; k++) - { - double lkk_inv = 1.0/A[k+k*lda]; - for(i = 0; i < M; i++) - { - B[i+k*ldb] *= lkk_inv; - for(j = k+1; j < N; j++) - { - B[i+j*ldb] -= B[i+k*ldb] * A[k+j*lda]; - } - } - + dim_t i, j, k; + for(k = 0; k < N; k++) + { + double lkk_inv = 1.0/A[k+k*lda]; + for(i = 0; i < M; i++) + { + B[i+k*ldb] *= lkk_inv; + for(j = k+1; j < N; j++) + { + B[i+j*ldb] -= B[i+k*ldb] * A[k+j*lda]; + } + } + } -return BLIS_SUCCESS; + return BLIS_SUCCESS; } /* TRSM scalar code for the case XA = alpha * B * A is lower-triangular, non-unit triangular, no transpose * Dimensions: X:mxn A:nxn B:mxn */ - -static err_t dtrsm_small_XAlB ( - double *A, - double *B, - double alpha, - dim_t M, - dim_t N, - dim_t lda, - dim_t ldb +static err_t dtrsm_XAlB_ref +( + double *A, + double *B, + double alpha, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb ) { - dim_t i, j, k; for(j = 0; j < N; j++) + { for(i = 0; i < M; i++) + { B[i+j*ldb] *= alpha; + } + } for(k = N;k--;) { @@ -711,32 +395,36 @@ static err_t dtrsm_small_XAlB ( } } } -return BLIS_SUCCESS; + return BLIS_SUCCESS; } /* TRSM scalar code for the case XA = alpha * B * A is lower-triangular, unit-diagonal, no transpose *Dimensions: X:mxn A:nxn B:mxn */ -static err_t dtrsm_small_XAlB_unitDiag( - double *A, - double *B, - double alpha, - dim_t M, - dim_t N, - dim_t lda, - dim_t ldb +static err_t dtrsm_XAlB_unitDiag_ref +( + double *A, + double *B, + double alpha, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb ) { - dim_t i, j, k; - for(j = 0 ; j < N; j++) + { for(i = 0; i < M; i++) + { B[i+j*ldb] *= alpha; + } + } + double A_k_j; - for(k = N; k--;) - { + for(k = N; k--;) + { for(j = k; j--;) { A_k_j = A[(k)+(j)*lda]; @@ -746,31 +434,32 @@ static err_t dtrsm_small_XAlB_unitDiag( } } } - - -return BLIS_SUCCESS; + return BLIS_SUCCESS; } /* TRSM scalar code for the case XA = alpha * B *A is upper-triangular, non-unit-diagonal, A is transposed * Dimensions: X:mxn A:nxn B:mxn */ -static err_t dtrsm_small_XAutB ( - double *A, - double *B, - double alpha, - dim_t M, - dim_t N, - dim_t lda, - dim_t ldb +static err_t dtrsm_XAutB_ref +( + double *A, + double *B, + double alpha, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb ) { - dim_t i, j, k; - for(j = 0; j < N; j++) + { for(i = 0; i < M; i++) + { B[i+j*ldb] *=alpha; + } + } for(k = N; k--;) { @@ -784,33 +473,37 @@ static err_t dtrsm_small_XAutB ( } } } -return BLIS_SUCCESS; + return BLIS_SUCCESS; } /* TRSM scalar code for the case XA = alpha * B * A is 
upper-triangular, unit-diagonal, A has to be transposed * Dimensions: X:mxn A:nxn B:mxn */ -static err_t dtrsm_small_XAutB_unitDiag( - double *A, - double *B, - double alpha, - dim_t M, - dim_t N, - dim_t lda, - dim_t ldb +static err_t dtrsm_XAutB_unitDiag_ref +( + double *A, + double *B, + double alpha, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb ) { - dim_t i, j, k; double A_k_j; for(j = 0; j< N; j++) + { for(i = 0; i< M; i++) + { B[i+j*ldb] *= alpha; + } + } - for(k = N; k--;) - { + for(k = N; k--;) + { for(j = k; j--;) { A_k_j = A[(j)+(k)*lda]; @@ -821,25 +514,24 @@ static err_t dtrsm_small_XAutB_unitDiag( } } } -return BLIS_SUCCESS; + return BLIS_SUCCESS; } /* TRSM scalar code for the case XA = alpha * B * A is lower-triangular, non-unit-diagonal, A has to be transposed * Dimensions: X:mxn A:nxn B:mxn */ -static err_t dtrsm_small_XAltB ( - double *A, - double *B, - dim_t M, - dim_t N, - dim_t lda, - dim_t ldb +static err_t dtrsm_XAltB_ref +( + double *A, + double *B, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb ) { - dim_t i, j, k; - for(k = 0; k < N; k++) { double lkk_inv = 1.0/A[k+k*lda]; @@ -852,14 +544,14 @@ static err_t dtrsm_small_XAltB ( } } } -return BLIS_SUCCESS; + return BLIS_SUCCESS; } /* TRSM scalar code for XA = alpha * B * A is lower-triangular, unit-diagonal, A has to be transposed * Dimensions: X:mxn A:nxn B:mxn */ -static err_t dtrsm_small_XAltB_unitDiag( +static err_t dtrsm_XAltB_unitDiag_ref( double *A, double *B, dim_t M, @@ -868,9 +560,7 @@ static err_t dtrsm_small_XAltB_unitDiag( dim_t ldb ) { - dim_t i, j, k; - for(k = 0; k < N; k++) { for(i = 0; i < M; i++) @@ -881,25 +571,24 @@ static err_t dtrsm_small_XAltB_unitDiag( } } } -return BLIS_SUCCESS; + return BLIS_SUCCESS; } /* TRSM scalar code for the case XA = alpha * B * A is upper-triangular, unit-diagonal, no transpose * Dimensions: X:mxn A:nxn B:mxn */ -static err_t dtrsm_small_XAuB_unitDiag ( - double *A, - double *B, - dim_t M, - dim_t N, - dim_t lda, - dim_t ldb +static err_t dtrsm_XAuB_unitDiag_ref +( + double *A, + double *B, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb ) { - dim_t i, j, k; - for(k = 0; k < N; k++) { for(i = 0; i < M; i++) @@ -910,13 +599,125 @@ static err_t dtrsm_small_XAuB_unitDiag ( } } } -return BLIS_SUCCESS; + return BLIS_SUCCESS; } +/* + * Kernels Table +*/ +trsmsmall_ker_ft ker_fps[16] = +{ + bli_dtrsm_small_AlXB, + bli_dtrsm_small_AltXB, + bli_dtrsm_small_AuXB, + bli_dtrsm_small_AutXB, + bli_dtrsm_small_AlXB, + bli_dtrsm_small_AltXB, + bli_dtrsm_small_AuXB, + bli_dtrsm_small_AutXB, + bli_dtrsm_small_XAlB, + bli_dtrsm_small_XAltB, + bli_dtrsm_small_XAuB, + bli_dtrsm_small_XAutB, + bli_dtrsm_small_XAlB_unitDiag, + bli_dtrsm_small_XAltB_unitDiag, + bli_dtrsm_XAuB_unitDiag_ref, + bli_dtrsm_small_XAutB_unitDiag +}; + +/* +* The bli_trsm_small implements a version of TRSM where A is packed and reused +* +* Input: A: MxM (triangular matrix) +* B: MxN matrix +* Output: X: MxN matrix such that + AX = alpha*B or XA = alpha*B or A'X = alpha*B or XA' = alpha*B +* Here the output X is stored in B +* +* Note: Currently only dtrsm is supported when A & B are column-major +*/ +err_t bli_trsm_small +( + side_t side, + obj_t* alpha, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) +{ + err_t err; + dim_t m = bli_obj_length(b); + dim_t n = bli_obj_width(b); + + if(!(m && n)) { + return BLIS_SUCCESS; + } + + bool unitdiag = bli_obj_has_unit_diag(a); + bool uplo = bli_obj_is_upper(a); + bool transa = bli_obj_has_trans(a); + + /* ToDo: Temporary threshold condition for trsm single 
thread.
+     * It will be updated with an arch-based threshold function which reads
+     * tuned thresholds for all 64 (datatype, side, uplo, transa, unitdiag) trsm
+     * combinations
+     */
+    if(m > 128 || n > 128) {
+      return BLIS_NOT_YET_IMPLEMENTED;
+    }
+
+    /* If alpha is zero, the B matrix will become zero after scaling,
+       hence the solution is also a zero matrix */
+    if (bli_obj_equals(alpha, &BLIS_ZERO)) {
+      return BLIS_NOT_YET_IMPLEMENTED; // scale B by alpha
+    }
+
+    // Return if inputs are row major as currently
+    // we are supporting col major only
+    if ((bli_obj_row_stride(a) != 1) ||
+        (bli_obj_row_stride(b) != 1)) {
+      return BLIS_INVALID_ROW_STRIDE;
+    }
+
+    // Currently optimized for the double data type only
+    num_t dt = bli_obj_dt(a);
+    if (dt != BLIS_DOUBLE) {
+      return BLIS_NOT_YET_IMPLEMENTED;
+    }
+
+    // A is expected to be triangular in trsm
+    if (!bli_obj_is_upper_or_lower (a)) {
+      return BLIS_EXPECTED_TRIANGULAR_OBJECT;
+    }
+
+    /*
+     * Compose kernel index based on inputs:
+     * bit 3 = side, bit 2 = unitdiag, bit 1 = uplo, bit 0 = transa.
+     * E.g. right side, non-unit, upper, no transpose gives
+     * 0b1010 = 10, which selects bli_dtrsm_small_XAuB in ker_fps[].
+     */
+    dim_t keridx = ( (( side & 0x1) << 3) | ((unitdiag & 0x1) << 2) |
+                     (( uplo & 0x1) << 1) | ( transa & 0x1) );
+
+    trsmsmall_ker_ft ker_fp = ker_fps[ keridx ];
+
+    /* Call the kernel */
+    err = ker_fp
+    (
+      alpha,
+      a,
+      b,
+      cntx,
+      cntl
+    );
+
+    return err;
+};
+
 /* TRSM for the case AX = alpha * B, Double precision
 * A is lower-triangular, no-transpose, non-unit diagonal
 * dimensions A: mxm X: mxn B: mxn
-
+ b01--->
 * ***************** ** * * * * *
@@ -936,38 +737,2931 @@
 a10 ****** b11 ***************** **************** ***************** a11--->
 */
-static err_t bli_dtrsm_small_AlXB(
-                                   side_t side,
-                                   obj_t* AlphaObj,
-                                   obj_t* a,
-                                   obj_t* b,
-                                   cntx_t* cntx,
-                                   cntl_t* cntl
-                                 )
+static err_t bli_dtrsm_small_AlXB
+(
+    obj_t* AlphaObj,
+    obj_t* a,
+    obj_t* b,
+    cntx_t* cntx,
+    cntl_t* cntl
+)
 {
-
-  dim_t D_MR = 4;   //size of block along 'M' dimpension
-  dim_t D_NR = 8;   //size of block along 'N' dimension
+  dim_t D_MR = 8;   //size of block along 'M' dimension
+  dim_t D_NR = 6;   //size of block along 'N' dimension
 
   dim_t m = bli_obj_length(b);  // number of rows of matrix B
   dim_t n = bli_obj_width(b);   // number of columns of matrix B
+  dim_t cs_a = bli_obj_col_stride(a); // column stride of A
+  dim_t cs_b = bli_obj_col_stride(b); // column stride of B
 
-#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME
-  if((m>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME_ROW_PANEL_M && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME)
-  || (m> D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME_COLUMN_PANEL_N)
-  || (m>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME_COLUMN_PANEL_M && n D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_NAPLES)
-  {
-    return BLIS_NOT_YET_IMPLEMENTED;
-  }
-#endif
-
-  dim_t m_remainder = m & 3;  //number of remainder rows
-  dim_t n_remainder = n & 7;  //number of remainder columns
+  dim_t i, j, k;    //loop variables
+  dim_t k_iter;     //number of times GEMM to be performed
+
+  double AlphaVal = *(double *)AlphaObj->buffer;  //value of alpha
+  double *L = a->buffer;      //pointer to matrix A
+  double *B = b->buffer;      //pointer to matrix B
+
+  //pointers that point to blocks for GEMM and TRSM
+  double *a10, *a11, *b01, *b11;
+  double *ptr_b01_dup;
+
+  double ones = 1.0;
+  bool is_unitdiag = bli_obj_has_unit_diag(a);
+
+  //scratch registers
+  __m256d ymm0, ymm1, ymm2, ymm3;
+  __m256d ymm4, ymm5, ymm6, ymm7;
+  __m256d ymm8, ymm9, ymm10, ymm11;
+  __m256d ymm12, ymm13, ymm14, ymm15;
+  __m256d ymm16, ymm17, ymm18, ymm19;
+  __m256d ymm20;
+
+  gint_t required_packing_A = 1;
+  mem_t local_mem_buf_A_s;
+  double *D_A_pack = NULL;
+  double d11_pack[D_MR] __attribute__((aligned(64)));
+  rntm_t rntm;
+
+  bli_rntm_init_from_global( &rntm );
+  bli_rntm_set_num_threads_only( 1, &rntm );
+  bli_membrk_rntm_set_membrk( &rntm );
+
+  siz_t
buffer_size = bli_pool_block_size(
+                  bli_membrk_pool(
+                    bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK),
+                    bli_rntm_membrk(&rntm)));
+
+  if (required_packing_A == 1)
+  {
+    // Get the buffer from the pool.
+    bli_membrk_acquire_m(&rntm,
+                         buffer_size,
+                         BLIS_BITVAL_BUFFER_FOR_A_BLOCK,
+                         &local_mem_buf_A_s);
+
+    D_A_pack = bli_mem_buffer(&local_mem_buf_A_s);
+  }
+
+  if ( (D_MR * m * sizeof(double)) > buffer_size)
+    return BLIS_NOT_YET_IMPLEMENTED;
+
+  /*
+    Perform TRSM for 8 columns at a time, from 0 to m/8 in steps of D_MR
+    a. Load and pack A (a10 block); the size of the packing grows from 8x6 to 8x(m-8)
+       For the first block there will be no GEMM and no packing of a10
+       because it is only TRSM
+    b. Using the packed a10 block and the b01 block, perform the GEMM operation
+    c. Use the GEMM outputs, perform the TRSM operation using a11, b11 and update B
+    d. Repeat b,c for n rows of B in steps of D_NR
+  */
+  for(i = 0;(i+D_MR-1) < m; i += D_MR)  //loop along 'M' dimension
+  {
+    a10 = L + (i);          //pointer to block of A to be used for GEMM
+    a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM
+    double *ptr_a10_dup = D_A_pack;
+
+    dim_t p_lda = D_MR; // packed leading dimension
+
+    /*
+      Pack the current A block (a10) into the packed buffer memory D_A_pack
+      a. This a10 block is used in the GEMM portion only, and its size
+         increases by D_MR every iteration until it reaches 8x(m-8),
+         which is the maximum GEMM-only block size in A
+      b. This packed buffer is reused to calculate all n rows of B matrix
+    */
+    for(dim_t x =0;x < i;x++)
+    {
+      ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * x));
+      _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * x), ymm16);
+      ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * x + 4));
+      _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * x + 4), ymm16);
+    }
+
+    /*
+      Pack the 8 diagonal elements of the A block into an array
+      a. This helps utilize cache lines efficiently in the TRSM operation
+      b. Store ones when the input is unit diagonal
+    */
+    ymm4 = _mm256_broadcast_sd((double const *)&ones);
+    if(is_unitdiag)
+    {
+      _mm256_storeu_pd((double *)(d11_pack), ymm4);
+      _mm256_storeu_pd((double *)(d11_pack + 4), ymm4);
+    }else
+    {
+      //broadcast diagonal elements of A11
+      ymm0 = _mm256_broadcast_sd((double const *)(a11));
+      ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a +1));
+      ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a*2 + 2));
+      ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a*3 + 3));
+
+      //Pick one element from each column to create a 4-element vector, then store
+      ymm0 = _mm256_unpacklo_pd(ymm0, ymm1);
+      ymm1 = _mm256_unpacklo_pd(ymm2, ymm3);
+      ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C);
+      ymm0 = _mm256_div_pd(ymm4, ymm1);
+      _mm256_storeu_pd((double *)(d11_pack), ymm0);
+
+      //broadcast diagonal elements of A11
+      ymm0 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4 + 4));
+      ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5 + 5));
+      ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6 + 6));
+      ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7 + 7));
+
+      //Pick one element from each column to create a 4-element vector, then store
+      ymm0 = _mm256_unpacklo_pd(ymm0, ymm1);
+      ymm1 = _mm256_unpacklo_pd(ymm2, ymm3);
+      ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C);
+      ymm0 = _mm256_div_pd(ymm4, ymm1);
+      _mm256_storeu_pd((double *)(d11_pack + 4), ymm0);
+    }
+
+    /*
+      a. Perform GEMM using a10, b01.
+      b.
Perform TRSM on a11, b11
+      c. This GEMM+TRSM loop operates with an 8x6 block size
+         along the n dimension for every D_NR columns of B01, where the
+         packed A buffer is reused in computing all n rows of B.
+      d. The same approach is used in the remaining fringe cases.
+    */
+    for(j = 0; (j+D_NR-1) < n; j += D_NR)   //loop along 'N' dimension
+    {
+      a10 = D_A_pack;
+      a11 = L + i + (i*cs_a);   //pointer to block of A to be used for TRSM
+      b01 = B + j*cs_b;         //pointer to block of B to be used for GEMM
+      b11 = B + i + j* cs_b;    //pointer to block of B to be used for TRSM
+
+      k_iter = i >>3;   //number of times GEMM to be performed
+
+  #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL
+      _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0);
+      _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0);
+      _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0);
+      _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0);
+      _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0);
+      _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0);
+  #endif
+
+      ymm0 = _mm256_setzero_pd();
+      ymm1 = _mm256_setzero_pd();
+      ymm2 = _mm256_setzero_pd();
+      ymm3 = _mm256_setzero_pd();
+      ymm4 = _mm256_setzero_pd();
+      ymm5 = _mm256_setzero_pd();
+      ymm6 = _mm256_setzero_pd();
+      ymm7 = _mm256_setzero_pd();
+      ymm8 = _mm256_setzero_pd();
+      ymm9 = _mm256_setzero_pd();
+      ymm10 = _mm256_setzero_pd();
+      ymm11 = _mm256_setzero_pd();
+      ymm12 = _mm256_setzero_pd();
+      ymm13 = _mm256_setzero_pd();
+      ymm14 = _mm256_setzero_pd();
+      ymm15 = _mm256_setzero_pd();
+
+      /*
+        Perform GEMM between the a10 and b01 blocks.
+        For the first iteration there will be no GEMM operation,
+        since k_iter is zero.
+      */
+
+      ptr_b01_dup = b01;
+      ptr_a10_dup = a10;
+      for(k = 0; k< k_iter*4; k++)    //loop for number of GEMM operations
+      {
+        ymm0 = _mm256_loadu_pd((double const *)(a10));
+        ymm1 = _mm256_loadu_pd((double const *)(a10 + 4));
+
+        ymm2 = _mm256_broadcast_sd((double const *)(b01));
+        ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8);
+        ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
+
+        ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1));
+        ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);
+        ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
+
+        ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2));
+        ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10);
+        ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);
+
+        ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3));
+        ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11);
+        ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15);
+
+        ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4));
+        ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4);
+        ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);
+
+        ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5));
+        ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);
+        ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);
+
+        ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4));
+        ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4 + 4));
+
+        double *b01_temp = b01+ 4;
+
+        ymm2 = _mm256_broadcast_sd((double const *)(b01_temp));
+        ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8);
+        ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
+
+        ymm2 = _mm256_broadcast_sd((double const *)(b01_temp + cs_b * 1));
+        ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);
+        ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
+
+        ymm2 = _mm256_broadcast_sd((double const *)(b01_temp + cs_b * 2));
+        ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10);
+        ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);
+
+        ymm2 = _mm256_broadcast_sd((double const *)(b01_temp + cs_b * 3));
+        ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11);
+        ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15);
+
+        ymm2 =
_mm256_broadcast_sd((double const *)(b01_temp + cs_b * 4)); + ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); + + ymm2 = _mm256_broadcast_sd((double const *)(b01_temp + cs_b * 5)); + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); + ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7); + + b01 += 1; //move to next row of B + a10 += p_lda; //pointer math to calculate next block of A for GEMM + if(!((k+1)&3)) + { + b01 = ptr_b01_dup + D_MR; + ptr_b01_dup = b01; + a10 = ptr_a10_dup + D_MR*p_lda; + ptr_a10_dup = a10; + } + } + ///GEMM code end/// + + /* + Load b11 of size 6x8 and multiply with alpha + Add the GEMM output and perform inregister transose of b11 + to peform TRSM operation. + */ + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] + + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm0 = _mm256_loadu_pd((double const *)(b11 + 4)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3 + 4)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm12); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm13); + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm14); + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm15); + + ymm13 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm15 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm15 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *4 + 4)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *5 + 4)); + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); + + ymm16 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_unpacklo_pd(ymm0, ymm1); 
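+        /* [Editor's note, illustration only, not part of the patch]
+         * The unpack/permute sequence above is the standard AVX 4x4
+         * transpose of packed doubles. With columns c0..c3 of b11 held
+         * in ymm0..ymm3:
+         *   _mm256_unpacklo_pd(c0,c1) -> { c0[0], c1[0], c0[2], c1[2] }
+         *   _mm256_unpackhi_pd(c0,c1) -> { c0[1], c1[1], c0[3], c1[3] }
+         * and _mm256_permute2f128_pd(x,y,0x20) / (x,y,0x31) then stitch
+         * the low / high 128-bit lanes together, producing rows r0..r3.
+         * For the two trailing columns (cs_b*4, cs_b*5) the second
+         * permute operand is the 'ones' broadcast, so only half of each
+         * output register carries payload data. */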
+ ymm4 = _mm256_permute2f128_pd(ymm7,ymm16,0x20); + ymm6 = _mm256_permute2f128_pd(ymm7,ymm16,0x31); + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm5 = _mm256_permute2f128_pd(ymm0,ymm16,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm16,0x31); + ymm18 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm17 = _mm256_permute2f128_pd(ymm18,ymm16,0x20); + ymm19 = _mm256_permute2f128_pd(ymm18,ymm16,0x31); + ymm20 = _mm256_unpackhi_pd(ymm2, ymm3); + ymm18 = _mm256_permute2f128_pd(ymm20,ymm16,0x20); + ymm20 = _mm256_permute2f128_pd(ymm20,ymm16,0x31); + //b11 transpose end + + /* + Compute 8x6 TRSM block by using GEMM block output in register + a. The 8x6 input (gemm outputs) are stored in combinations of ymm registers + 1. ymm8, ymm4 2. ymm9, ymm5 3. ymm10, ymm6, 4. ymm11, ymm7 + 5. ymm12, ymm17 6. ymm13,ymm18, 7. ymm14,ymm19 8. ymm15, ymm20 + where ymm8-ymm15 holds 8x4 data and reaming 8x2 will be hold by + other registers + b. Towards the end do in regiser transpose of TRSM output and store in b11 + */ + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); + ymm4 = _mm256_mul_pd(ymm4, ymm1); + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + //(ROw1): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm4, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm8, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm4, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm4, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm8, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm4, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm8, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm4, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm8, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm4, ymm19); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm8, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm4, ymm20); + + //perform mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); + ymm5 = _mm256_mul_pd(ymm5, ymm1); + + a11 += cs_a; + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm9, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm5, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm9, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm5, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm9, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm5, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm9, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm5, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm9, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm5, ymm19); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm9, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm5, ymm20); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + ymm6 = _mm256_mul_pd(ymm6, ymm1); 
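+        /* [Editor's note, illustration only, not part of the patch]
+         * The pattern above repeats once per row k of the 8x8 a11 block:
+         * multiply the solved row k by the precomputed reciprocal
+         * d11_pack[k] (1.0 for unit diagonal), then use broadcast+fnmadd
+         * to subtract a11[i,k] * row_k from every pending row i > k.
+         * This is forward substitution with the divisions hoisted out of
+         * the inner loop. */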
+ + a11 += cs_a; + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw5): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm10, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm6, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm10, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm6, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm10, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm6, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm10, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm6, ymm19); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm10, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm6, ymm20); + + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); + ymm7 = _mm256_mul_pd(ymm7, ymm1); + + a11 += cs_a; + + //extract a44 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + //(ROw4): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm11, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm7, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm11, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm7, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm11, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm7, ymm19); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm11, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm7, ymm20); + + //perform mul operation + ymm12 = _mm256_mul_pd(ymm12, ymm1); + ymm17 = _mm256_mul_pd(ymm17, ymm1); + + a11 += cs_a; + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + //(ROw5): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm17, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm12, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm17, ymm19); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm17, ymm20); + + //perform mul operation + ymm13 = _mm256_mul_pd(ymm13, ymm1); + ymm18 = _mm256_mul_pd(ymm18, ymm1); + + a11 += cs_a; + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + + //(ROw6): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm13, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm18, ymm19); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm13, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm18, ymm20); + + //perform mul operation + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm19 = _mm256_mul_pd(ymm19, ymm1); + + a11 += cs_a; + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + + //(ROw7): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm14, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm19, ymm20); + + //perform mul operation + ymm15 = _mm256_mul_pd(ymm15, ymm1); + ymm20 = _mm256_mul_pd(ymm20, ymm1); + + a11 += cs_a; + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + 
ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] + ymm3 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + + ///unpack high/// + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm3 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + + _mm256_storeu_pd((double *)(b11 + 4), ymm0); //store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm1); //store B11[5][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm2); //store B11[6][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3 + 4), ymm3); //store B11[7][0-3] + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm4, ymm5); + ymm3 = _mm256_unpacklo_pd(ymm6, ymm7); + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); + + ///unpack high/// + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); + + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm17, ymm18); + ymm3 = _mm256_unpacklo_pd(ymm19, ymm20); + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); + + ///unpack high/// + ymm17 = _mm256_unpackhi_pd(ymm17, ymm18); + ymm18 = _mm256_unpackhi_pd(ymm19, ymm20); + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm17, ymm18, 0x20); + + _mm256_storeu_pd((double *)(b11 + cs_b * 4 + 4), ymm0); + _mm256_storeu_pd((double *)(b11 + cs_b * 5 + 4), ymm1); + } + + dim_t n_rem = n-j; + if(n_rem >= 4) + { + a10 = D_A_pack; + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + + k_iter = i / D_MR; //number of times GEMM to be performed + + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 
= _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter*4; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1 + 4)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2 + 4)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3 + 4)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + + b01 += 1; //move to next row of B + a10 += p_lda; + + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + D_MR; + ptr_b01_dup = b01; + a10 = ptr_a10_dup + D_MR*p_lda; + ptr_a10_dup = a10; + } + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); + ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b *3 + 4)); + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); + ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); + ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); + + ///implement TRSM/// + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); + ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); + + //rearrange low 
elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(ROw1): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm8, ymm10); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm8, ymm12); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm8, ymm13); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm8, ymm14); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm8, ymm15); + + //perform mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); + + a11 += cs_a; + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm9, ymm10); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm9, ymm11); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm9, ymm12); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm9, ymm13); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm9, ymm14); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm9, ymm15); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + + a11 += cs_a; + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw5): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm10, ymm11); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm10, ymm12); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm10, ymm13); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm10, ymm14); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm10, ymm15); + + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); + + a11 += cs_a; + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 
4)); + + //(ROw4): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm11, ymm12); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm11, ymm13); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm11, ymm14); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm11, ymm15); + + //perform mul operation + ymm12 = _mm256_mul_pd(ymm12, ymm1); + + a11 += cs_a; + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + //(ROw5): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm12, ymm14); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); + + //perform mul operation + ymm13 = _mm256_mul_pd(ymm13, ymm1); + + a11 += cs_a; + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + + //(ROw6): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm13, ymm14); + ymm2 = _mm256_broadcast_sd((double const *)(a11 +7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm13, ymm15); + + //perform mul operation + ymm14 = _mm256_mul_pd(ymm14, ymm1); + + a11 += cs_a; + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + + //(ROw7): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm14, ymm15); + + //perform mul operation + ymm15 = _mm256_mul_pd(ymm15, ymm1); + + a11 += cs_a; + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); + ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); + ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); + ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); + _mm256_storeu_pd((double *)(b11 + cs_b * 3 + 4), ymm7); + + n_rem -=4; + j +=4; + + } + + if(n_rem) + { + a10 = D_A_pack; + a11 = L + i + 
(i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + + k_iter = i / D_MR; //number of times GEMM to be performed + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(3 == n_rem) + { + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter*4; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1 + 4)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2 + 4)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + b01 += 1; //move to next row of B + a10 += p_lda; + + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + D_MR; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +D_MR*p_lda; + ptr_a10_dup = a10; + } + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); + ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(2 == n_rem) + { + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter*4; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 
+ cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1 + 4)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + b01 += 1; //move to next row of B + a10 += p_lda; + + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + D_MR; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +D_MR*p_lda; + ptr_a10_dup = a10; + } + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); + ymm6 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(1 == n_rem) + { + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter*4; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + b01 += 1; //move to next row of B + a10 += p_lda; + + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + D_MR; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +D_MR*p_lda; + ptr_a10_dup = a10; + } + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_broadcast_sd((double const *)(&ones)); + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); + ymm5 = _mm256_broadcast_sd((double const *)(&ones)); + ymm6 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + } + ///implement TRSM/// + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] 
B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); + ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(ROw1): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm8, ymm10); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm8, ymm12); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm8, ymm13); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm8, ymm14); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm8, ymm15); + + //perform mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); + + a11 += cs_a; + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm9, ymm10); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm9, ymm11); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm9, ymm12); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm9, ymm13); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm9, ymm14); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm9, ymm15); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + + a11 += cs_a; + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw5): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm10, ymm11); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm10, ymm12); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm10, ymm13); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm10, ymm14); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 
7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm10, ymm15); + + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); + + a11 += cs_a; + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + //(ROw4): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm11, ymm12); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm11, ymm13); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm11, ymm14); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm11, ymm15); + + //perform mul operation + ymm12 = _mm256_mul_pd(ymm12, ymm1); + + a11 += cs_a; + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + //(ROw5): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm12, ymm14); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); + + //perform mul operation + ymm13 = _mm256_mul_pd(ymm13, ymm1); + + a11 += cs_a; + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + + //(ROw6): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm13, ymm14); + ymm2 = _mm256_broadcast_sd((double const *)(a11 +7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm13, ymm15); + + //perform mul operation + ymm14 = _mm256_mul_pd(ymm14, ymm1); + + a11 += cs_a; + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + + //(ROw7): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm14, ymm15); + + //perform mul operation + ymm15 = _mm256_mul_pd(ymm15, ymm1); + + a11 += cs_a; + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); + ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); + ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); + ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); + + if(3 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b * 
2 + 4), ymm6);
+ }
+ else if(2 == n_rem)
+ {
+ _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3]
+ _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3]
+ _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4);
+ _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5);
+ }
+ else if(1 == n_rem)
+ {
+ _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3]
+ _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4);
+ }
+ }
+ }
+
+ /*
+ Remainder cases start here:
+ a. The same logic and code flow used to compute the full (8x6) block
+ above holds for the remainder cases too.
+ */
+
+ dim_t m_rem = m-i;
+ //implementation for remainder rows (when 'M' is not a multiple of D_MR)
+ if(m_rem>=4)
+ {
+ a10 = L + (i); //pointer to block of A to be used for GEMM
+ a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM
+ double *ptr_a10_dup = D_A_pack;
+ double *ptr_a11_dup = a11;
+
+ dim_t p_lda = 4; // packed leading dimension
+ for(dim_t x =0;x < i;x++)
+ {
+ ymm0 = _mm256_loadu_pd((double const *)(a10 + cs_a * x));
+ _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * x), ymm0);
+ }
+
+ ymm4 = _mm256_broadcast_sd((double const *)&ones);
+ if(!is_unitdiag)
+ {
+ //broadcast diagonal elements of A11
+ ymm0 = _mm256_broadcast_sd((double const *)(a11));
+ ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a +1));
+ ymm2 = _mm256_broadcast_sd((double const *)(a11+cs_a*2 + 2));
+ ymm3 = _mm256_broadcast_sd((double const *)(a11+cs_a*3 + 3));
+
+ //Pick one element each column and create a 4 element vector and store
+ ymm0 = _mm256_unpacklo_pd(ymm0, ymm1);
+ ymm1 = _mm256_unpacklo_pd(ymm2, ymm3);
+ ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C);
+ ymm4 = _mm256_div_pd(ymm4, ymm1);
+ }
+ _mm256_storeu_pd((double *)(d11_pack), ymm4);
+
+ for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' dimension
+ {
+ a10 = D_A_pack; //pointer to block of A to be used for GEMM
+ a11 = ptr_a11_dup; //pointer to block of A to be used for TRSM
+ b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM
+ b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM
+
+ k_iter = i; //number of times GEMM operation to be done
+
+ #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL
+ _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0);
+ #endif
+
+ ymm4 = _mm256_setzero_pd();
+ ymm5 = _mm256_setzero_pd();
+ ymm6 = _mm256_setzero_pd();
+ ymm7 = _mm256_setzero_pd();
+ ymm8 = _mm256_setzero_pd();
+ ymm9 = _mm256_setzero_pd();
+ ymm10 = _mm256_setzero_pd();
+ ymm11 = _mm256_setzero_pd();
+ ymm12 = _mm256_setzero_pd();
+ ymm13 = _mm256_setzero_pd();
+ ymm14 = _mm256_setzero_pd();
+ ymm15 = _mm256_setzero_pd();
+
+ ///GEMM code begins///
+ for(k = 0; k< k_iter; k++) //loop for number of GEMM operations
+ {
+ ymm0 = _mm256_loadu_pd((double const *)(a10));
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0));
+ ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1));
+ ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2));
+ ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3));
+ ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11);
+
+ ymm2 = _mm256_broadcast_sd((double
const *)(b01 + cs_b * 4)); + ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + + + ymm16 = _mm256_broadcast_sd((double const *)(&ones)); + + ////unpacklow//// + ymm7 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + + //rearrange low elements + ymm4 = _mm256_permute2f128_pd(ymm7,ymm16,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm6 = _mm256_permute2f128_pd(ymm7,ymm16,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + + //rearrange high elements + ymm5 = _mm256_permute2f128_pd(ymm0,ymm16,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm7 = _mm256_permute2f128_pd(ymm0,ymm16,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + //b11 transpose end + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); + ymm4 = _mm256_mul_pd(ymm4, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(ROw1): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm4, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm8, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm4, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, 
ymm4, ymm7); + + //perform mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); + ymm5 = _mm256_mul_pd(ymm5, ymm1); + + a11 += cs_a; + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm9, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm5, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm9, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm5, ymm7); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + ymm6 = _mm256_mul_pd(ymm6, ymm1); + + a11 += cs_a; + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw5): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm10, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm6, ymm7); + + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); + ymm7 = _mm256_mul_pd(ymm7, ymm1); + + a11 += cs_a; + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ///unpack high/// + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store B11[1][0-3] + } + dim_t n_rem = n-j; + if(n_rem >= 4) + { + a10 = D_A_pack; + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + + k_iter = i; //number of times GEMM to be performed + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = 
_mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(ROw1): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm8, ymm10); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); + + //perform mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); + + a11 += cs_a; + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm9, ymm10); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm9, ymm11); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + + a11 += cs_a; + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw5): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm10, ymm11); + + //perform mul 
operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); + + a11 += cs_a; + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + + n_rem -= 4; + j += 4; + } + if(n_rem) + { + a10 = D_A_pack; + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + + k_iter = i; //number of times GEMM to be performed + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(3 == n_rem) + { + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(2 == n_rem) + { + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + b01 += 1; //move to next row of B + a10 += 
p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(1 == n_rem) + { + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_broadcast_sd((double const *)(&ones)); + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(ROw1): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm8, ymm10); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); + + //perform mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); + + a11 += cs_a; + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm9, ymm10); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm9, ymm11); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + + a11 += cs_a; + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw5): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm10, ymm11); + + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); + + a11 += cs_a; + + //unpacklow// + ymm1 = 
_mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2]
+ ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2]
+
+ //rearrange low elements
+ ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
+ ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
+
+ ///unpack high///
+ ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3]
+ ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3]
+
+ //rearrange high elements
+ ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
+ ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
+
+ if(3 == n_rem)
+ {
+ _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3]
+ _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3]
+ _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3]
+ }
+ else if(2 == n_rem)
+ {
+ _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3]
+ _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3]
+ }
+ else if(1 == n_rem)
+ {
+ _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3]
+ }
+ }
+ m_rem -=4;
+ i +=4;
+ }
+
+ if(m_rem)
+ {
+ a10 = L + (i); //pointer to block of A to be used for GEMM
+ // Do transpose for a10 & store in D_A_pack
+ double *ptr_a10_dup = D_A_pack;
+ if(3 == m_rem) // Repetitive A blocks will be 3*3
+ {
+ dim_t p_lda = 4; // packed leading dimension
+ for(dim_t x=0;x<i;x++)
+ {
+ //pack 4-row panels of a10, as in the m_rem >= 4 case above
+ ymm0 = _mm256_loadu_pd((double const *)(a10 + cs_a * x));
+ _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * x), ymm0);
+ }
+
+ //loop along 'N' dimension over full D_NR-wide blocks first,
+ //following the same GEMM + TRSM flow as the m_rem >= 4 case above
+ for(j = 0; (j+D_NR-1) < n; j += D_NR)
+ {
+ // ... (GEMM update and triangular solve of the m_rem x D_NR block) ...
+ }
+
+ dim_t n_rem = n-j;
+ if((n_rem >= 4))
+ {
+ a10 = D_A_pack; //pointer to block of A to be used for GEMM
+ a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM
+ b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM
+ b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM
+
+ k_iter = i; //number of times GEMM to be performed
+
+ #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL
+ _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0);
+ #endif
+
+ ymm4 = _mm256_setzero_pd();
+ ymm5 = _mm256_setzero_pd();
+ ymm6 = _mm256_setzero_pd();
+ ymm7 = _mm256_setzero_pd();
+ ymm8 = _mm256_setzero_pd();
+ ymm9 = _mm256_setzero_pd();
+ ymm10 = _mm256_setzero_pd();
+ ymm11 = _mm256_setzero_pd();
+ ymm12 = _mm256_setzero_pd();
+ ymm13 = _mm256_setzero_pd();
+ ymm14 = _mm256_setzero_pd();
+ ymm15 = _mm256_setzero_pd();
+
+ ///GEMM code begins///
+ for(k = 0; k< k_iter; k++) //loop for number of GEMM operations
+ {
+ ymm0 = _mm256_loadu_pd((double const *)(a10));
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0));
+ ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1));
+ ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2));
+ ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3));
+ ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11);
+
+ b01 += 1; //move to next row of B
+ a10 += p_lda;
+ }
+
+ ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha
+
+ ///implement TRSM///
+
+ ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));
+ ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b
*1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + + dtrsm_AlXB_ref(a11, b11, m_rem, 4, cs_a, cs_b, is_unitdiag); + n_rem -= 4; + j +=4; + } + if(n_rem) + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = i; //number of times GEMM to be performed + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(3 == n_rem) + { + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + + dtrsm_AlXB_ref(a11, b11, m_rem, 3, cs_a, cs_b, is_unitdiag); + } + else if(2 == n_rem) + { + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double 
const *)(b01 + cs_b * 0));
+ ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1));
+ ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);
+
+ b01 += 1; //move to next row of B
+ a10 += p_lda;
+ }
+
+ ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha
+
+ ///implement TRSM///
+
+ ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));
+ ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));
+
+ ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);
+ ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9);
+
+ ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08);
+ ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08);
+
+ _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0])
+ _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1])
+
+ dtrsm_AlXB_ref(a11, b11, m_rem, 2, cs_a, cs_b, is_unitdiag);
+ }
+ else if(1 == n_rem)
+ {
+ ///GEMM code begins///
+ for(k = 0; k< k_iter; k++) //loop for number of GEMM operations
+ {
+ ptr_b01_dup = b01;
+
+ ymm0 = _mm256_loadu_pd((double const *)(a10));
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0));
+ ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8);
+
+ b01 += 1; //move to next row of B
+ a10 += p_lda;
+ }
+
+ ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha
+
+ ///implement TRSM///
+
+ ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));
+
+ ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);
+
+ ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08);
+
+ _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0])
+
+ dtrsm_AlXB_ref(a11, b11, m_rem, 1, cs_a, cs_b, is_unitdiag);
+ }
+ }
+ }
+ else if(2 == m_rem) // Repetitive A blocks will be 2*2
+ {
+ dim_t p_lda = 4; // packed leading dimension
+ for(dim_t x=0;x<i;x++)
+ {
+ //pack 4-row panels of a10, as in the m_rem >= 4 case above
+ ymm0 = _mm256_loadu_pd((double const *)(a10 + cs_a * x));
+ _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * x), ymm0);
+ }
+
+ //loop along 'N' dimension over full D_NR-wide blocks first,
+ //following the same GEMM + TRSM flow as the m_rem >= 4 case above
+ for(j = 0; (j+D_NR-1) < n; j += D_NR)
+ {
+ // ... (GEMM update and triangular solve of the m_rem x D_NR block) ...
+ }
+
+ dim_t n_rem = n-j;
+ if((n_rem >= 4))
+ {
+ a10 = D_A_pack; //pointer to block of A to be used for GEMM
+ a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM
+ b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM
+ b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM
+
+ k_iter = i; //number of times GEMM to be performed
+
+ #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL
+ _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0);
+ #endif
+
+ ymm4 = _mm256_setzero_pd();
+ ymm5 = _mm256_setzero_pd();
+ ymm6 = _mm256_setzero_pd();
+ ymm7 = _mm256_setzero_pd();
+ ymm8 = _mm256_setzero_pd();
+ ymm9 = _mm256_setzero_pd();
+ ymm10 = _mm256_setzero_pd();
+ ymm11 = _mm256_setzero_pd();
+ ymm12 = _mm256_setzero_pd();
+ ymm13 = _mm256_setzero_pd();
+ ymm14 = _mm256_setzero_pd();
+ ymm15 = _mm256_setzero_pd();
+
+ ///GEMM code begins///
+ for(k = 0; k< k_iter; k++) //loop for number of GEMM operations
+ {
+ ymm0 = _mm256_loadu_pd((double const *)(a10));
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0));
+ ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1));
+ ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2));
+ ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3));
+ ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11);
+
+ b01 += 1; //move to next row of B
+ a10 += p_lda;
+ }
+
+ ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha
+
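+ /*
+ Note on the masked write-back in the blocks below: only m_rem rows of
+ each 4-wide ymm vector are valid here, so _mm256_blend_pd is used with
+ mask 0x0C (and 0x08 in the 3-row case above) to keep the untouched upper
+ lanes from the original b11 load and take only the computed lower lanes;
+ the actual m_rem x n_rem triangular solve is then delegated to the
+ reference kernel dtrsm_AlXB_ref().
+ */
+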
///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0C); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + + dtrsm_AlXB_ref(a11, b11, m_rem, 4, cs_a, cs_b, is_unitdiag); + n_rem -= 4; + j +=4; + } + if(n_rem) + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = i; //number of times GEMM to be performed + + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(3 == n_rem) + { + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + + dtrsm_AlXB_ref(a11, b11, m_rem, 3, cs_a, cs_b, is_unitdiag); + } + else if(2 == n_rem) + { + ///GEMM code begins/// + for(k = 0; k< 
k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); + + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + + dtrsm_AlXB_ref(a11, b11, m_rem, 2, cs_a, cs_b, is_unitdiag); + } + else if(1 == n_rem) + { + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + + dtrsm_AlXB_ref(a11, b11, m_rem, 1, cs_a, cs_b, is_unitdiag); + } + + } + m_rem -=2; + i+=2; + } + else if(1 == m_rem) // Repetitive A blocks will be 1*1 + { + dim_t p_lda = 4; // packed leading dimension + for(dim_t x=0;x= 4)) + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = i; //number of times GEMM to be performed + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + 
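+ /* Scalar view of one pass of this GEMM loop (illustration only; acc[c] + stands in for the ymm8-ymm11 accumulators): each pass is a rank-1 update, + acc[c][0:4] += b01[c*cs_b] * a10[0:4] for the four columns c of the current + b01 panel, after which b01 and a10 advance below. */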
+ b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + + dtrsm_AlXB_ref(a11, b11, m_rem, 4, cs_a, cs_b, is_unitdiag); + n_rem -= 4; + j+=4; + } + + if(n_rem) + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = i; //number of times GEMM to be performed + + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(3 == n_rem) + { + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); 
//store(B11[0-3][2]) + + dtrsm_AlXB_ref(a11, b11, m_rem, 3, cs_a, cs_b, is_unitdiag); + } + else if(2 == n_rem) + { + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + + dtrsm_AlXB_ref(a11, b11, m_rem, 2, cs_a, cs_b, is_unitdiag); + } + else if(1 == n_rem) + { + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + + dtrsm_AlXB_ref(a11, b11, m_rem, 1, cs_a, cs_b, is_unitdiag); + } + } + m_rem -=1; + i+=1; + } + } + + if ((required_packing_A == 1) && + bli_mem_is_alloc( &local_mem_buf_A_s )) + { + bli_membrk_release(&rntm, &local_mem_buf_A_s); + } + + return BLIS_SUCCESS; +} + +/* TRSM for the Left Upper case AX = alpha * B, Double precision + * A is Left side, upper-triangular, transpose, non-unit diagonal + * dimensions A: mxm X: mxn B: mxn + a10 ----> b11---> +*********** ***************** +* * * * *b01*b11* * * + **a10 * * a11 b11 * * * * * + ********* | | ***************** + *a11* * | | * * * * * + * * * | | * * * * * + ****** v v ***************** + * * * * * * * + * * * * * * * + * * ***************** + * + a11---> +*/ +static err_t bli_dtrsm_small_AutXB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) +{ + dim_t D_MR = 8; //size of block along 'M' dimension + dim_t D_NR = 6; //size of block along 'N' dimension + + dim_t m = bli_obj_length(b); // number of rows of matrix B + dim_t n = bli_obj_width(b); // number of columns of matrix B dim_t cs_a = bli_obj_col_stride(a); // column stride of A dim_t cs_b = bli_obj_col_stride(b); // column stride of B @@ -982,31 +3676,2605 @@ static err_t bli_dtrsm_small_AlXB( double *a10, *a11, *b01, *b11; //pointers that point to blocks for GEMM and TRSM double *ptr_b01_dup; - double f_t[4] __attribute__((aligned(64)));//buffer to store corner column when m_remainder !=0 - double* f_temp; - double ones = 1.0; + bool is_unitdiag = bli_obj_has_unit_diag(a); //scratch registers __m256d ymm0, ymm1, ymm2, ymm3; __m256d ymm4, ymm5, ymm6, ymm7; __m256d ymm8, ymm9, ymm10, ymm11; __m256d ymm12, ymm13, ymm14, ymm15; - __m256d ymm16; + __m256d 
ymm16, ymm17, ymm18, ymm19; + __m256d ymm20; + gint_t required_packing_A = 1; + mem_t local_mem_buf_A_s; + double *D_A_pack = NULL; + double d11_pack[D_MR] __attribute__((aligned(64))); + rntm_t rntm; + bli_rntm_init_from_global( &rntm ); + bli_rntm_set_num_threads_only( 1, &rntm ); + bli_membrk_rntm_set_membrk( &rntm ); - for(j = 0; j+D_NR-1 < n; j += D_NR) //loop along 'N' dimension + siz_t buffer_size = bli_pool_block_size( + bli_membrk_pool( + bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), + bli_rntm_membrk(&rntm))); + + if ( (D_MR * m * sizeof(double)) > buffer_size) + return BLIS_NOT_YET_IMPLEMENTED; + + if (required_packing_A == 1) { - for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' dimension + // Get the buffer from the pool. + bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &local_mem_buf_A_s); + + D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); + } + + /* + Performs the TRSM solve for 8 columns at a time from 0 to m/8 in steps of D_MR + a. Load, transpose and pack A (a10 block); the packing size grows from 8x6 to 8x(m-8) + For the first block there will be no GEMM and no packing of a10 because it is TRSM only + b. Using the packed a10 block and the b01 block, perform the GEMM operation + c. Use the GEMM outputs, perform the TRSM operation using a11, b11 and update B + d. Repeat b,c for n rows of B in steps of D_NR + */ + for(i = 0;(i+D_MR-1) < m; i += D_MR) //loop along 'M' dimension + { + a10 = L + (i*cs_a); //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); + double *ptr_a10_dup = D_A_pack; + dim_t p_lda = i; // packed leading dimension + + /* + Load, transpose and pack the current A block (a10) into packed buffer memory D_A_pack + a. This a10 block is used in the GEMM portion only, and its size + increases by D_MR on every iteration + until it reaches 8x(m-8), the maximum GEMM-only block size in A + b. 
This packed buffer is reused to calculate all n rows of B matrix + */ + + for(dim_t x =0;x < i;x+=D_MR) { - a10 = L +i; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm10 = _mm256_loadu_pd((double const *)(a10 + 4)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); + ymm11 = _mm256_loadu_pd((double const *)(a10 + 4 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); + ymm12 = _mm256_loadu_pd((double const *)(a10 + 4 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); + ymm13 = _mm256_loadu_pd((double const *)(a10 + 4 + cs_a * 3)); - k_iter = i / D_MR; //number of times GEMM to be performed(in blocks of 4x4) + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + + + ymm4 = _mm256_unpacklo_pd(ymm10, ymm11); + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm10, ymm11); + ymm1 = _mm256_unpackhi_pd(ymm12, ymm13); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 4), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 5), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 6), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 7), ymm9); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + cs_a * 4)); + ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 4 + 4)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a * 5)); + ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 5 + 4)); + ymm2 = _mm256_loadu_pd((double const *)(a10 + cs_a * 6)); + ymm12 = _mm256_loadu_pd((double const *)(a10 + cs_a * 6 + 4)); + ymm3 = _mm256_loadu_pd((double const *)(a10 + cs_a * 7)); + ymm13 = _mm256_loadu_pd((double const *)(a10 + cs_a * 7 + 4)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup + 4), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), ymm9); + + + ymm4 = _mm256_unpacklo_pd(ymm10, ymm11); + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm10, ymm11); + ymm1 = _mm256_unpackhi_pd(ymm12, ymm13); + + ymm7 = 
_mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 4 + 4), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 5 + 4), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 6 + 4), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 7 + 4), ymm9); + + a10 += D_MR; + ptr_a10_dup += D_MR; + } + + /* + Pack the 8 diagonal elements of the A block into an array + a. This helps utilize cache lines efficiently in the TRSM operation + b. Store ones when the input is unit-diagonal + */ + ymm4 = _mm256_broadcast_sd((double const *)&ones); + if(is_unitdiag) + { + _mm256_storeu_pd((double *)(d11_pack), ymm4); + _mm256_storeu_pd((double *)(d11_pack + 4), ymm4); + } else + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a +1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a*2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a*3 + 3)); + + //Pick one element from each column, create a 4 element vector and store + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); + ymm0 = _mm256_div_pd(ymm4, ymm1); + _mm256_storeu_pd((double *)(d11_pack), ymm0); + + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11 + 4 + cs_a*4)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5 + cs_a*5)); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6 + cs_a*6)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + 7 + cs_a*7)); + + //Pick one element from each column, create a 4 element vector and store + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); + ymm0 = _mm256_div_pd(ymm4, ymm1); + _mm256_storeu_pd((double *)(d11_pack + 4), ymm0); + } + + /* + a. Perform GEMM using a10, b01. + b. Perform TRSM on a11, b11 + c. This GEMM+TRSM loop operates with an 8x6 block size + along the n dimension for every D_NR rows of b01, where the + packed A buffer is reused in computing all n rows of B. + d. The same approach is used in the remaining fringe cases. 
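+ e. In effect, for each column c of B this computes the forward substitution + x[r] = (alpha*b[r] - sum over t<r of A[t][r]*x[t]) * d11_pack[r], + where the A[t][r] are read (transposed) from the triangular panel a11 and + d11_pack[r] holds 1/A[r][r] (ones when A is unit-diagonal).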
+ */ + dim_t temp = n - D_NR + 1; + for(j = 0; j < temp; j += D_NR) //loop along 'N' dimension + { + a10 = D_A_pack; + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + + k_iter = i >>3; //number of times GEMM to be performed(in blocks of 8x8) + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + /* + Perform GEMM between the a10 and b01 blocks. + For the first iteration there will be no GEMM operation, + since k_iter is zero. + */ + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter*4; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); + ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); + ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4 + 4)); + + double *b01_temp = b01+ 4; + + ymm2 = _mm256_broadcast_sd((double const *)(b01_temp)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01_temp + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01_temp + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *)(b01_temp + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + + ymm2 = _mm256_broadcast_sd((double const *)(b01_temp + cs_b * 4)); + ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); + + ymm2 = _mm256_broadcast_sd((double const *)(b01_temp + cs_b * 5)); + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); + ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7); + + b01 += 1; //move to next row of B + a10 
+= p_lda; //pointer math to calculate next block of A for GEMM + if(!((k+1)&3)) + { + b01 = ptr_b01_dup + D_MR; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +D_MR; + ptr_a10_dup = a10; + } + } + ///GEMM code end/// + + /* + Load b11 of size 6x8 and multiply with alpha. + Subtract the GEMM output and perform an in-register transpose of b11 + to perform the TRSM operation. + */ + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] + + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm0 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3 + 4)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm14); //B11[0-3][6] * alpha -= B01[0-3][6] + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm15); //B11[0-3][7] * alpha -= B01[0-3][7] + + ymm13 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][4] B11[0][5] B11[2][4] B11[2][5] + ymm15 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][6] B11[0][7] B11[2][6] B11[2][7] + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][4] B11[1][5] B11[3][4] B11[3][5] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][6] B11[1][7] B11[3][6] B11[3][7] + ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] + ymm15 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + ymm2 = 
_mm256_loadu_pd((double const *)(b11 + cs_b *4 + 4)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *5 + 4)); + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); + + ymm16 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm4 = _mm256_permute2f128_pd(ymm7,ymm16,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm6 = _mm256_permute2f128_pd(ymm7,ymm16,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm5 = _mm256_permute2f128_pd(ymm0,ymm16,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm7 = _mm256_permute2f128_pd(ymm0,ymm16,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + ymm18 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + + ymm17 = _mm256_permute2f128_pd(ymm18,ymm16,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm19 = _mm256_permute2f128_pd(ymm18,ymm16,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] + ymm20 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm18 = _mm256_permute2f128_pd(ymm20,ymm16,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm20 = _mm256_permute2f128_pd(ymm20,ymm16,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + /* + Compute the 8x6 TRSM block by using the GEMM block output in registers + a. The 8x6 input (gemm outputs) are stored in combinations of ymm registers + 1. ymm8, ymm4 2. ymm9, ymm5 3. ymm10, ymm6, 4. ymm11, ymm7 + 5. ymm12, ymm17 6. ymm13,ymm18, 7. ymm14,ymm19 8. ymm15, ymm20 + where ymm8-ymm15 hold the 8x4 data and the remaining 8x2 are held by + other registers + b. Towards the end, do an in-register transpose of the TRSM output and store it in b11 + */ + ////extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); + ymm4 = _mm256_mul_pd(ymm4, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(ROw1): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm4, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm8, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm4, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm4, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm8, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm4, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm8, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm4, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm8, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm4, ymm19); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm8, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm4, ymm20); + + + //perform mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); + ymm5 = _mm256_mul_pd(ymm5, ymm1); + + a11 += 1; + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm9, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm5, ymm6); + ymm2 = 
_mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm9, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm5, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm9, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm5, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm9, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm5, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm9, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm5, ymm19); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm9, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm5, ymm20); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + ymm6 = _mm256_mul_pd(ymm6, ymm1); + + a11 += 1; + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw3): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm10, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm6, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm10, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm6, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm10, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm6, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm10, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm6, ymm19); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm10, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm6, ymm20); + + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); + ymm7 = _mm256_mul_pd(ymm7, ymm1); + + a11 += 1; + + //extract a44 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + //(ROw4): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm11, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm7, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm11, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm7, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm11, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm7, ymm19); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm11, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm7, ymm20); + + //perform mul operation + ymm12 = _mm256_mul_pd(ymm12, ymm1); + ymm17 = _mm256_mul_pd(ymm17, ymm1); + + a11 += 1; + + //extract a55 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + //(ROw5): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm17, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm12, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm17, ymm19); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm17, ymm20); + + //perform mul operation + ymm13 = _mm256_mul_pd(ymm13, ymm1); + ymm18 = _mm256_mul_pd(ymm18, ymm1); + + a11 += 1; + + //extract a66 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + + //(ROw6): FMA 
operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm13, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm18, ymm19); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm13, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm18, ymm20); + + //perform mul operation + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm19 = _mm256_mul_pd(ymm19, ymm1); + + a11 += 1; + + //extract a77 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + + //(ROw7): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm14, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm19, ymm20); + + //perform mul operation + ymm15 = _mm256_mul_pd(ymm15, ymm1); + ymm20 = _mm256_mul_pd(ymm20, ymm1); + + a11 += 1; + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] + ymm3 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + + ///unpack high/// + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm3 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + + _mm256_storeu_pd((double *)(b11 + 4), ymm0); //store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm1); //store B11[5][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm2); //store B11[6][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3 + 4), ymm3); //store B11[7][0-3] + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ///unpack high/// + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + 
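+ /* The unpacklo/unpackhi pairs above and the permute2f128 steps below form + the usual 4x4 double-precision in-register transpose: unpack* swaps + elements within each 128-bit half, and permute2f128 swaps the halves, + turning rows of B11 back into columns for the stores. */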
//rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store B11[1][0-3] + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm17, ymm18); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] + ymm3 = _mm256_unpacklo_pd(ymm19, ymm20); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + + ///unpack high/// + ymm17 = _mm256_unpackhi_pd(ymm17, ymm18); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] + ymm18 = _mm256_unpackhi_pd(ymm19, ymm20); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm17, ymm18, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + + _mm256_storeu_pd((double *)(b11 + cs_b * 4 + 4), ymm0); //store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 5 + 4), ymm1); //store B11[5][0-3] + } + + dim_t n_rem = n-j; + if(n_rem >= 4) + { + a10 = D_A_pack; + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + + k_iter = i / D_MR; //number of times GEMM to be performed(in blocks of 4x4) + + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + ///GEMM code begins/// + + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1 + 4)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2 + 4)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3 + 4)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = 
_mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 5)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 5 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1 + 4)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2 + 4)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3 + 4)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 6)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 6 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1 + 4)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2 + 4)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3 + 4)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + 
ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 7)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 7 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1 + 4)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2 + 4)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3 + 4)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + + b01 += 1; //move to next row of B + + a10 += D_MR; //pointer math to calculate next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to calculate next block of B for GEMM + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] + ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b *3 + 4)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] + ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] * alpha -= B01[0-3][6] + ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); //B11[0-3][7] * alpha -= B01[0-3][7] + + ///implement TRSM/// + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[2][4] B11[2][5] + ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[2][6] B11[2][7] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = 
_mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[1][4] B11[1][5] B11[3][4] B11[3][5] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[1][6] B11[1][7] B11[3][6] B11[3][7] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] + ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] + + ymm0 = _mm256_broadcast_sd((double const *)&ones); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += 1; + + //(ROw1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm5, ymm8, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm8, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm8, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm8, ymm15); + + //perform mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); + + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += 1; + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw2): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm5, ymm9, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm9, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm9, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm9, ymm15); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += 1; + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw3): FMA operations + ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm5, ymm10, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm10, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm10, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm10, ymm15); + + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); + + ymm0 = _mm256_broadcast_sd((double const *)&ones); + + //extract a44 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + ymm5 = _mm256_broadcast_sd((double
const *)(a11 + cs_a*4)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += 1; + + //(ROw4): FMA operations + ymm12 = _mm256_fnmadd_pd(ymm5, ymm11, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm11, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm11, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm11, ymm15); + + //perform mul operation + ymm12 = _mm256_mul_pd(ymm12, ymm1); + + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += 1; + + //extract a55 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + + //(ROw5): FMA operations + ymm13 = _mm256_fnmadd_pd(ymm6, ymm12, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm12, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm12, ymm15); + + //perform mul operation + ymm13 = _mm256_mul_pd(ymm13, ymm1); + + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 +cs_a*7)); + + a11 += 1; + + //extract a66 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + + //(ROw6): FMA operations + ymm14 = _mm256_fnmadd_pd(ymm7, ymm13, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm13, ymm15); + + //perform mul operation + ymm14 = _mm256_mul_pd(ymm14, ymm1); + + //extract a77 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += 1; + //(ROw7): FMA operations + ymm15 = _mm256_fnmadd_pd(ymm16, ymm14, ymm15); + + //perform mul operation + ymm15 = _mm256_mul_pd(ymm15, ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] + ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 
3), ymm3); //store B11[3][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[5][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); //store B11[6][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3 + 4), ymm7); //store B11[7][0-3] + + n_rem -=4; + j +=4; + + } + if(n_rem) + { + a10 = D_A_pack; + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + + k_iter = i / D_MR; //number of times GEMM to be performed(in blocks of 4x4) + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(3 == n_rem) + { + ///GEMM code begins/// + + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1 + 4)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2 + 4)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 5)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 5 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1 + 4)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2 + 4)); + ymm10 = 
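+ /*
+  * GEMM phase for the n_rem == 3 tail: each k iteration appears to walk one
+  * D_MR-row strip of the packed A panel (two 4-double loads per packed row
+  * block) against broadcast scalars from the three remaining columns of B01,
+  * accumulating the A10*B01 contribution into ymm8-ymm10 and ymm12-ymm14 via
+  * FMA. A rough scalar equivalent of one step (illustrative only, the array
+  * names are hypothetical):
+  *
+  *     for (dim_t r = 0; r < 8; r++)        // D_MR rows of the A10 strip
+  *         for (dim_t c = 0; c < 3; c++)    // n_rem columns of B01
+  *             acc[c][r] += a10_packed[r] * b01[c * cs_b];
+  */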
_mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 6)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 6 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1 + 4)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2 + 4)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 7)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 7 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1 + 4)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2 + 4)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + b01 += 1; //move to next row of B + + a10 += D_MR; //pointer math to calculate next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to calculate next block of B for GEMM + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= 
B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] + ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] * alpha -= B01[0-3][6] + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(2 == n_rem) + { + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1 + 4)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 5)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 5 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1 + 4)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 6)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 6 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1 + 4)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3 + 
4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 7)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 7 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1 + 4)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + b01 += 1; //move to next row of B + + a10 += D_MR; //pointer math to calculate next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to calculate next block of B for GEMM + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] + ymm6 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(1 == n_rem) + { + ///GEMM code begins/// + + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 5)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 5 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = 
_mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 6)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 6 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 7)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 7 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + b01 += 1; //move to next row of B + + a10 += D_MR; //pointer math to calculate next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to calculate next block of B for GEMM + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_broadcast_sd((double const *)(&ones)); + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_broadcast_sd((double const *)(&ones)); + ymm6 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + } + ///implement TRSM/// + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[2][4] B11[2][5] + ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[2][6] B11[2][7] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[1][4] B11[1][5] B11[3][4] B11[3][5] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[1][6] B11[1][7] B11[3][6] B11[3][7] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] + ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] + + ymm0 = 
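+ /*
+  * TRSM phase: with B11 transposed into registers, the solve below proceeds
+  * as a column-oriented substitution over the 8x8 triangular block. d11_pack
+  * appears to hold precomputed reciprocals of the A11 diagonal (or ones for
+  * a unit-diagonal matrix), so each row is finished with a multiply instead
+  * of a divide. Roughly, per vector of right-hand sides (sketch only, x is a
+  * hypothetical name for the transposed B11 rows):
+  *
+  *     for (dim_t c = 0; c < 8; c++) {
+  *         x[c] = x[c] * d11_pack[c];            // scale by 1/a(c,c)
+  *         for (dim_t r = c + 1; r < 8; r++)
+  *             x[r] -= a11[c + cs_a*r] * x[c];   // the fnmadd eliminations
+  *     }
+  */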
_mm256_broadcast_sd((double const *)&ones); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += 1; + + //(ROw1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm5, ymm8, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm8, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm8, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm8, ymm15); + + //perform mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); + + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += 1; + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw2): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm5, ymm9, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm9, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm9, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm9, ymm15); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += 1; + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw5): FMA operations + ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm5, ymm10, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm10, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm10, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm10, ymm15); + + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); + + ymm0 = _mm256_broadcast_sd((double const *)&ones); + + //extract a44 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += 1; + + //(ROw4): FMA operations + ymm12 = _mm256_fnmadd_pd(ymm5, ymm11, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm11, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm11, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm11, ymm15); + + //perform mul operation + ymm12 = _mm256_mul_pd(ymm12, ymm1); + + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = 
_mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += 1; + + //extract a55 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + + //(ROw5): FMA operations + ymm13 = _mm256_fnmadd_pd(ymm6, ymm12, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm12, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm12, ymm15); + + //perform mul operation + ymm13 = _mm256_mul_pd(ymm13, ymm1); + + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 +cs_a*7)); + + a11 += 1; + + //extract a66 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + + + //(ROw6): FMA operations + ymm14 = _mm256_fnmadd_pd(ymm7, ymm13, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm13, ymm15); + + //perform mul operation + ymm14 = _mm256_mul_pd(ymm14, ymm1); + + //extract a77 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += 1; + //(ROw7): FMA operations + ymm15 = _mm256_fnmadd_pd(ymm16, ymm14, ymm15); + + //perform mul operation + ymm15 = _mm256_mul_pd(ymm15, ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] + ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + + if(3 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[5][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); //store B11[6][0-3] + } + else if(2 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[5][0-3] + } + else if(1 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] 
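+ /*
+  * After the back-transpose, only n_rem columns of B11 are written out
+  * (3, 2 or 1), so the 4-wide stores never run past the last valid column;
+  * each stored column still spans all eight rows via a pair of 4-double
+  * stores.
+  */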
+ _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] + } + } + + } + //======================M remainder cases================================ + dim_t m_rem = m-i; + if(m_rem>=4) //implementation for remainder rows (when 'M' is not a multiple of D_MR) + { + a10 = L + (i*cs_a); //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); + double *ptr_a10_dup = D_A_pack; + dim_t p_lda = i; // packed leading dimension + for(dim_t x =0;x < i;x+=4) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + + a10 += 4; + ptr_a10_dup += 4; + } + + ymm4 = _mm256_broadcast_sd((double const *)&ones); + if(!is_unitdiag) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a +1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11+cs_a*2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11+cs_a*3 + 3)); + + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); + ymm4 = _mm256_div_pd(ymm4, ymm1); + } + _mm256_storeu_pd((double *)(d11_pack), ymm4); + + for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' dimension + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = i / 4; //number of times GEMM operation to be done(in blocks of 4x4) + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = 
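+ /*
+  * m_rem >= 4 tail: this appears to reuse the pack -> GEMM -> TRSM pipeline
+  * of the main loop on a single 4-row strip, with up to six columns of B01
+  * (cs_b*0..5) accumulated per k iteration into ymm8-ymm11 and ymm4/ymm5.
+  */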
_mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); + ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); + ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); + ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); + ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); + + b01 += 1; //move to next row of B + + a10 += 4; //pointer math to calculate next block of A for GEMM + b01 = ptr_b01_dup + 4; //pointer math to calculate next block of B for GEMM + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, 
ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + + + ymm16 = _mm256_broadcast_sd((double const *)(&ones)); + + ////unpacklow//// + ymm7 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + //ymm16; + + //rearrange low elements + ymm4 = _mm256_permute2f128_pd(ymm7,ymm16,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm6 = _mm256_permute2f128_pd(ymm7,ymm16,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + //ymm16; + + //rearrange high elements + ymm5 = _mm256_permute2f128_pd(ymm0,ymm16,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm7 = _mm256_permute2f128_pd(ymm0,ymm16,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + //b11 transpose end + + ////extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); + ymm4 = _mm256_mul_pd(ymm4, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(ROw1): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm4, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm8, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm4, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm4, ymm7); + + + //perform mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); + ymm5 = _mm256_mul_pd(ymm5, ymm1); + + a11 += 1; + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm9, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm5, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm9, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm5, ymm7); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + ymm6 = _mm256_mul_pd(ymm6, ymm1); + + a11 += 1; + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw5): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm10, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, 
ymm6, ymm7); + + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); + ymm7 = _mm256_mul_pd(ymm7, ymm1); + + a11 += 1; + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ///unpack high/// + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store B11[1][0-3] + } + dim_t n_rem = n-j; + if(n_rem >= 4) + { + a10 = D_A_pack; + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + + k_iter = i / 4; //number of times GEMM to be performed(in blocks of 4x4) + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, 
ymm11); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + b01 += 1; //move to next row of B + + a10 += 4; //pointer math to calculate next block of A for GEMM + b01 = ptr_b01_dup + 4; //pointer math to calculate next block of B for GEMM + } + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm0 = _mm256_broadcast_sd((double const *)&ones); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); + ymm3 = 
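+ /*
+  * The 4x4 TRSM block below is the 8x8 substitution scheme truncated to four
+  * rows: row c is finished with a multiply by d11_pack[c], then eliminated
+  * from rows c+1..3 with fnmadds against broadcast elements of A11.
+  */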
_mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + + a11 += 1; + + //(ROw1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); + + //perform mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); + + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + + a11 += 1; + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw2): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + + a11 += 1; + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw5): FMA operations + ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); + + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + + n_rem -= 4; + j += 4; + } + if(n_rem) + { + a10 = D_A_pack; + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + + k_iter = i / 4; //number of times GEMM to be performed(in blocks of 4x4) + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(3 == n_rem) + { + ///GEMM code begins/// + + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda)); + + ymm2 = _mm256_broadcast_sd((double const 
*)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + b01 += 1; //move to next row of B + + a10 += 4; //pointer math to calculate next block of A for GEMM + b01 = ptr_b01_dup + 4; //pointer math to calculate next block of B for GEMM + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(2 == n_rem) + { + ///GEMM code begins/// + + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + b01 += 1; //move to next row of B + + a10 += 4; //pointer math to calculate next block of A for GEMM + b01 = ptr_b01_dup + 4; //pointer 
math to calculate next block of B for GEMM + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(1 == n_rem) + { + ///GEMM code begins/// + + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + b01 += 1; //move to next row of B + + a10 += 4; //pointer math to calculate next block of A for GEMM + b01 = ptr_b01_dup + 4; //pointer math to calculate next block of B for GEMM + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_broadcast_sd((double const *)(&ones)); + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm0 = _mm256_broadcast_sd((double const *)&ones); + + ////extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + + a11 += 1; + + //(ROw1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, 
ymm9); + ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); + + //perform mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); + + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + + a11 += 1; + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw2): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + + a11 += 1; + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw5): FMA operations + ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); + + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + if(3 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + } + else if(2 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + } + else if(1 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + } + + } + m_rem -=4; + i +=4; + } + + if(m_rem) + { + a10 = L + (i*cs_a); //pointer to block of A to be used for GEMM + // Do transpose for a10 & store in D_A_pack + double *ptr_a10_dup = D_A_pack; + if(3 == m_rem) // Repetitive A blocks will be 3*3 + { + dim_t p_lda = i; // packed leading dimension + for(dim_t x=0;x= 4)) + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - k_iter = i / D_MR; //number of times GEMM operation to be done(in blocks of 4x4) + k_iter = i / 4; //number of times GEMM to be performed(in blocks of 4x4) - int iter; - - if((j+D_NR) == n) - { - for(iter = 0; iter < m_remainder; iter++) - f_t[iter] = (b11 + cs_b * 7)[iter]; - f_temp = f_t; - } - else - f_temp = (b11 + cs_b * 7); + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + 
cs_b*4), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); + #endif + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); ymm8 = _mm256_setzero_pd(); ymm9 = _mm256_setzero_pd(); ymm10 = _mm256_setzero_pd(); @@ -1321,1748 +6449,2174 @@ static err_t bli_dtrsm_small_AlXB( ymm14 = _mm256_setzero_pd(); ymm15 = _mm256_setzero_pd(); - ///GEMM code Begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + ///GEMM code begins/// + + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations { + ptr_b01_dup = b01; + + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + b01 += 1; //move to next row of B + + a10 += 4; //pointer math to calculate next block of A for GEMM + b01 = ptr_b01_dup + 4; //pointer math to calculate next block of B for GEMM + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + 
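+ /*
+  * For the partial row block the 0x08 blends appear to put the originally
+  * loaded fourth lane of each B11 column back into the result, so the full
+  * 4-double stores rewrite that element unchanged; the triangular solve for
+  * this m_rem x 4 tile is then handed off to the dtrsm_AutXB_ref fallback.
+  */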
ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + + dtrsm_AutXB_ref(a11, b11, m_rem, 4, cs_a, cs_b,is_unitdiag); + n_rem -= 4; + j +=4; + } + if(n_rem) + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = i / 4; //number of times GEMM to be performed(in blocks of 4x4) + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(3 == n_rem) + { + ///GEMM code begins/// + + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { ptr_b01_dup = b01; - ymm16 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] + ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[0][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[0][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[0][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[0][7] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - b01 += 1; //move to next row of B + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0] ) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[0][4]*A10[0][0] B01[0][4]*A10[1][0] B01[0][4]*A10[2][0] B01[0][4]*A10[3][0]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); 
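Note the operand order of the fmsub lines above: _mm256_fmsub_pd(b, alpha, acc) computes alpha*B11 - acc in a single instruction, so alpha is applied to the right-hand side exactly once, at the point where the GEMM contribution is folded in. In scalar terms (a sketch; the helper is illustrative):

    /* Per lane: out = AlphaVal * b11[r + c*cs_b] - acc[r][c], i.e. the
     * scaled RHS minus the accumulated A10*B01 contribution. */
    static void alpha_fold_sketch(double AlphaVal, const double *b11, long cs_b,
                                  const double acc[4][4], double out[4][4])
    {
        for (long c = 0; c < 4; ++c)
            for (long r = 0; r < 4; ++r)
                out[r][c] = AlphaVal * b11[r + c * cs_b] - acc[r][c];
    }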
//ymm13 += (B01[0][5]*A10[0][0] B01[0][5]*A10[1][0] B01[0][5]*A10[2][0] B01[0][5]*A10[3][0]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[0][6]*A10[0][0] B01[0][6]*A10[1][0] B01[0][6]*A10[2][0] B01[0][6]*A10[3][0]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm16 += (B01[0][7]*A10[0][0] B01[0][7]*A10[1][0] B01[0][7]*A10[2][0] B01[0][7]*A10[3][0]) + b01 += 1; //move to next row of B - ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 1)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda)); - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[1][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[1][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[1][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[1][7] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - b01 += 1; //move to next row of B01 + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[1][4]*A10[0][1] B01[1][4]*A10[1][1] B01[1][4]*A10[2][1] B01[1][4]*A10[3][1]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[1][5]*A10[0][1] B01[1][5]*A10[1][1] B01[1][5]*A10[2][1] B01[1][5]*A10[3][1]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[1][6]*A10[0][1] B01[1][6]*A10[1][1] B01[1][6]*A10[2][1] B01[1][6]*A10[3][1]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[1][7]*A10[0][1] B01[1][7]*A10[1][1] B01[1][7]*A10[2][1] B01[1][7]*A10[3][1]) + b01 += 1; //move to next row of B - ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] //A10[1][2] A10[2][2] A10[3][2] + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2)); - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[2][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[2][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[2][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b 
* 7)); //B01[2][7] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - b01 += 1; //move to next row of B + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[2][4]*A10[0][2] B01[2][4]*A10[1][2] B01[2][4]*A10[2][2] B01[2][0]*A10[3][2]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[2][5]*A10[0][2] B01[2][5]*A10[1][2] B01[2][5]*A10[2][2] B01[2][1]*A10[3][2]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[2][6]*A10[0][2] B01[2][6]*A10[1][2] B01[2][6]*A10[2][2] B01[2][2]*A10[3][2]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[2][7]*A10[0][2] B01[2][7]*A10[1][2] B01[2][7]*A10[2][2] B01[2][3]*A10[3][2]) + b01 += 1; //move to next row of B - ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3)); - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[3][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[3][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[3][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[3][7] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - b01 += 1; //move to next row of B + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm8 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm8 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm8 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm8 += (B01[3][0]*A10[0][3] B01[3][4]*A10[1][3] B01[3][4]*A10[2][3] B01[3][4]*A10[3][3]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm8 += (B01[3][1]*A10[0][3] B01[3][5]*A10[1][3] B01[3][5]*A10[2][3] B01[3][5]*A10[3][3]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm8 += (B01[3][2]*A10[0][3] B01[3][6]*A10[1][3] B01[3][6]*A10[2][3] B01[3][6]*A10[3][3]) - ymm15 = 
_mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm8 += (B01[3][3]*A10[0][3] B01[3][7]*A10[1][3] B01[3][7]*A10[2][3] B01[3][7]*A10[3][3]) + b01 += 1; //move to next row of B - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + a10 += 4; //pointer math to calculate next block of A for GEMM + b01 = ptr_b01_dup + 4; //pointer math to calculate next block of B for GEMM + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + + dtrsm_AutXB_ref(a11, b11, m_rem, 3, cs_a, cs_b,is_unitdiag); + } + else if(2 == n_rem) + { + ///GEMM code begins/// + + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + b01 += 1; //move to next row of B + + a10 += 4; //pointer math to calculate next block of A for GEMM + b01 = ptr_b01_dup + 4; //pointer math to calculate next block of B for GEMM + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + + dtrsm_AutXB_ref(a11, b11, m_rem, 2, cs_a, cs_b,is_unitdiag); + } + else if(1 == n_rem) + { + ///GEMM code begins/// + + for(k = 0; 
k< k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + b01 += 1; //move to next row of B + + a10 += 4; //pointer math to calculate next block of A for GEMM + b01 = ptr_b01_dup + 4; //pointer math to calculate next block of B for GEMM + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + + dtrsm_AutXB_ref(a11, b11, m_rem, 1, cs_a, cs_b, is_unitdiag); + } + } + } + else if(2 == m_rem) // Repetitive A blocks will be 2*2 + { + dim_t p_lda = i; // packed leading dimension + for(dim_t x=0;x= 4)) + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - k_iter = i / D_MR; //number of times GEMM to be performed(in block of 4) - ///GEMM for previously calculated values /// - - //load 4x4 block from b11 - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + k_iter = i / 4; //number of times GEMM to be performed(in blocks of 4x4) + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); + #endif ymm4 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm6 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + ///GEMM code begins/// + + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations { - ptr_b01_dup = b01; - ymm8 = _mm256_loadu_pd((double const
*)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a*2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + ptr_b01_dup = b01; - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] + ymm0 = _mm256_loadu_pd((double const *)(a10)); - b01 += 1; //move to next row of B + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - b01 += 1; + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) + b01 += 1; //move to next row of B - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda)); - b01 += 1; + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[1][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[2][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) 
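The substitution visible here is the point of the hunk: the deleted loads walked A at its original column stride cs_a and advanced by D_MR * cs_a per iteration, while the new code reads the D_A_pack buffer at the fixed packed leading dimension p_lda and advances by 4. A generic pack that would produce such a buffer, sketched under the assumption of a plain column-major copy (the kernel's actual pack loop lives earlier in the file and may differ):

    /* Hypothetical pack: copy an m x k panel of A into a contiguous
     * buffer so the GEMM loop can address it as buf[r + c * p_lda]
     * regardless of A's original row/column strides. */
    static void pack_panel_sketch(double *buf, long p_lda,
                                  const double *a, long rs_a, long cs_a,
                                  long m, long k)
    {
        for (long c = 0; c < k; ++c)
            for (long r = 0; r < m; ++r)
                buf[r + c * p_lda] = a[r * rs_a + c * cs_a];
    }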
- ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[3][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - b01 += 1; + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) + b01 += 1; //move to next row of B + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2)); - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + b01 += 1; //move to next row of B + + a10 += 4; //pointer math to calculate next block of A for GEMM + b01 = ptr_b01_dup + 4; //pointer math to calculate next block of B for GEMM } - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha - - ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 - ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B01[0-3][1] *alpha -= ymm5 - ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B01[0-3][2] *alpha -= ymm6 - ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B01[0-3][3] *alpha -= ymm7 + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha ///implement TRSM/// - //1st col - ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] - ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0] - ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][0] - //2nd col - a11 += cs_a; - ymm8 = 
_mm256_broadcast_sd((double const *)(a11 + 1)); //A11[1][1] - ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][1] + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - //3rd col - a11 += cs_a; - ymm11 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][2] - ymm12 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][2] + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0C); - //4th col - a11 += cs_a; - ymm13 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][3] - //compute reciprocals of L(i,i) and broadcast in registers - ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[0][0] A11[2][2] A11[2][2] - ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[1][1] A11[1][1] A11[3][3] A11[3][3] + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) - ymm14 = _mm256_broadcast_sd((double const *)&ones); + dtrsm_AutXB_ref(a11, b11, m_rem, 4, cs_a, cs_b, is_unitdiag); + n_rem -= 4; + j +=4; + } + if(n_rem) + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - ymm4 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] - ymm14 = _mm256_div_pd(ymm14, ymm4); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] - - ////unpacklow//// - ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - //rearrange low elements - ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - //rearrange high elements - ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - //extract a00 - ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0] - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - ymm4 = _mm256_mul_pd(ymm4, ymm15); //B11[0][0-3] /= A11[0][0] - - //extract diag a11 from a - ymm15 = _mm256_permute_pd(ymm14, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[3][3] 1/A11[3][3] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[1][1] 
1/A11[1][1] 1/A11[1][1] 1/A11[1][1] - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) - ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]*B11[0][0-3] - ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0]*B11[0][0-3] - ymm13 = _mm256_fnmadd_pd(ymm7, ymm4, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][0]*B11[0][0-3] - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - ymm8 = _mm256_mul_pd(ymm8, ymm15); //B11[1][0-3] /= A11[1][1] - - - //extract diag a22 from a - ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2] - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1]*B11[1][0-3] - ymm13 = _mm256_fnmadd_pd(ymm10, ymm8, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][1]*B11[1][0-3] - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - ymm11 = _mm256_mul_pd(ymm11, ymm15); //B11[2][0-3] /= A11[2][2] - - //extract diag a33 from a - ymm15 = _mm256_permute_pd(ymm14, 0x0C); //1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3] - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - ymm13 = _mm256_fnmadd_pd(ymm12, ymm11, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][2]*B11[2][0-3] - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - ymm13 = _mm256_mul_pd(ymm13, ymm15); //B11[3][0-3] /= A11[3][3] - - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ////unpackhigh//// - ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - - ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[0-3][3]) - - } - if(m_remainder) //implementation for remainder rows(when 'M' is not a multiple of D_MR) - { - a10 = L +i; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - - k_iter = i / D_MR; //number of GEMM operations to be performed(in blocks of 4x4) - - dim_t iter; - - if((j+4) == n) - { - f_temp = f_t; - for(iter = 0; iter < m_remainder; iter++) - f_temp[iter] = (b11 + cs_b * 
3)[iter]; - } - else - f_temp = (b11 + cs_b * 3); - ///GEMM for previously calculated values /// - - //load 4x4 block from b11 - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + k_iter = i / 4; //number of times GEMM to be performed(in blocks of 4x4) + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); + #endif ymm4 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm6 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); - for(k = 0; k < k_iter; k++) //looop for number of GEMM operations + if(3 == n_rem) { + ///GEMM code begins/// + + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { ptr_b01_dup = b01; - ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - b01 += 1; + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] + b01 += 1; //move to next row of B - b01 += 1; + ymm0 = _mm256_loadu_pd((double const *)(a10 + 
p_lda)); - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - b01 += 1; + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) + b01 += 1; //move to next row of B - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2)); - b01 += 1; + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double 
const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + b01 += 1; //move to next row of B + + a10 += 4; //pointer math to calculate next block of A for GEMM + b01 = ptr_b01_dup + 4; //pointer math to calculate next block of B for GEMM + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + + dtrsm_AutXB_ref(a11, b11, m_rem, 3, cs_a, cs_b, is_unitdiag); } - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha - - ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 - ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] *alpha -= ymm5 - ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[0-3][2] *alpha -= ymm6 - ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[0-3][3] *alpha -= ymm7 - - - if(3 == m_remainder) + else if(2 == n_rem) { - ///implement TRSM/// - //1st col - ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] - ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0] + ///GEMM code begins/// - //2nd col - a11 += cs_a; - ymm8 = _mm256_broadcast_sd((double const *)(a11 + 1)); //A11[1][1] - ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1] + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; - //3rd col - a11 += cs_a; - ymm11 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][2] + ymm0 = _mm256_loadu_pd((double const *)(a10)); - //4th col - a11 += cs_a; - ymm13 = _mm256_broadcast_sd((double const *)(&ones)); //A11[3][3] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - //compute reciprocals of L(i,i) and broadcast in registers - ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[0][0] A11[1][1] A11[1][1] - ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[2][2] A11[3][3] A11[3][3] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - ymm14 = _mm256_broadcast_sd((double const *)&ones); + b01 += 1; //move to next row of B - ymm4 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] - ymm14 = _mm256_div_pd(ymm14, ymm4); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda)); - ////unpacklow//// - ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - //rearrange low elements - ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] 
B11[2][2] B11[2][3] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + b01 += 1; //move to next row of B - //rearrange high elements - ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2)); - //extract a00 - ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00);//1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - ymm4 = _mm256_mul_pd(ymm4, ymm15); //B11[0][0-3] /= A11[0][0] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - //extract diag a11 from a - ymm15 = _mm256_permute_pd(ymm14, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[3][3] 1/A11[3][3] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[][] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1] + b01 += 1; //move to next row of B - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) - ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]* B11[0][0-3] - ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0]* B11[0][0-3] + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3)); - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - ymm8 = _mm256_mul_pd(ymm8, ymm15); //B11[1][0-3] /= A11[1][1] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - //extract diag a22 from a - ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1]* B11[1][0-3] - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - ymm11 = _mm256_mul_pd(ymm11, ymm15); //B11[2][0-3] /= A11[2][2] + b01 += 1; //move to next row of B - ymm13 = _mm256_broadcast_sd((double const *)(&ones)); + a10 += 4; //pointer math to calculate next block of A for GEMM + b01 = ptr_b01_dup + 4; //pointer math to calculate next block of B for GEMM + } - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ///implement TRSM/// - 
////unpackhigh//// - ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); - //load 4x4 block from b11 - ymm4 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][2] B11[3][3] + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - //determine correct values to store - ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x08); - ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x08); - ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x08); - ymm3 = _mm256_blend_pd(ymm3, ymm7, 0x08); + + dtrsm_AutXB_ref(a11, b11, m_rem, 2, cs_a, cs_b, is_unitdiag); } - else if( 2 == m_remainder ) + else if(1 == n_rem) { - ///implement TRSM/// - //1st col - ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] + ///GEMM code begins/// - //2nd col - a11 += cs_a; - ymm8 = _mm256_broadcast_sd((double const *)(a11 + 1)); //A11[1][1] + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; - //compute reciprocals of L(i,i) and broadcast in registers - ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[0][0] A11[1][1] A11[1][1] + ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm14 = _mm256_broadcast_sd((double const *)&ones); + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm4 = _mm256_blend_pd(ymm4, ymm14, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] - ymm14 = _mm256_div_pd(ymm14, ymm4); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + b01 += 1; //move to next row of B - ////unpacklow//// - ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda)); - //rearrange low elements - ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + b01 += 1; //move to next row of B - //rearrange high elements - ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + ymm0 = 
_mm256_loadu_pd((double const *)(a10 + p_lda * 2)); - //extract a00 - ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00);//1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - ymm4 = _mm256_mul_pd(ymm4, ymm15); //B11[0][0-3] /= A11[0][0] + b01 += 1; //move to next row of B - //extract diag a11 from a - ymm15 = _mm256_permute_pd(ymm14, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[3][3] 1/A11[3][3] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[][] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1] + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3)); - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) - ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]* B11[0][0-3] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - ymm8 = _mm256_mul_pd(ymm8, ymm15); //B11[1][0-3] /= A11[1][1] + b01 += 1; //move to next row of B - ymm11 = _mm256_broadcast_sd((double const *)(&ones)); - ymm13 = _mm256_broadcast_sd((double const *)(&ones)); + a10 += 4; //pointer math to calculate next block of A for GEMM + b01 = ptr_b01_dup + 4; //pointer math to calculate next block of B for GEMM + } - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ///implement TRSM/// - ////unpackhigh//// - ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - //load 4x4 block from b11 - ymm4 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][2] B11[3][3] - - //determine correct values to store - ymm0 = _mm256_permute2f128_pd(ymm0, ymm4,0x30); - ymm1 = _mm256_permute2f128_pd(ymm1, ymm5,0x30); - ymm2 = _mm256_permute2f128_pd(ymm2, ymm6,0x30); - ymm3 = _mm256_permute2f128_pd(ymm3, ymm7,0x30); - - } - else if(1 == m_remainder) - { - ///implement TRSM/// - //1st col - ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); 
//A11[0][0] - - ymm14 = _mm256_broadcast_sd((double const *)&ones); - ymm14 = _mm256_div_pd(ymm14, ymm4); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] - - ////unpacklow//// - ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - //rearrange low elements - ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - //rearrange high elements - ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - //extract a00 - ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00);//1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0] - - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - ymm4 = _mm256_mul_pd(ymm4, ymm15); //B11[0][0-3] /= A11[0][0] - - ymm8 = _mm256_broadcast_sd((double const *)(&ones)); - ymm11 = _mm256_broadcast_sd((double const *)(&ones)); - ymm13 = _mm256_broadcast_sd((double const *)(&ones)); - - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ////unpackhigh//// - ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - - ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - //load 4x4 block from b11 - ymm4 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][2] B11[3][3] - - //determine correct values to store - ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x0E); - ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x0E); - ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x0E); - ymm3 = _mm256_blend_pd(ymm3, ymm7, 0x0E); + dtrsm_AutXB_ref(a11, b11, m_rem, 1, cs_a, cs_b, is_unitdiag); } - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(f_temp), ymm3); //store(B11[0-3][3]) - - if((j+4) == n) - { - for(iter = 0; iter < m_remainder; iter++) - (b11 + cs_b * 3)[iter] = f_temp[iter]; - } + } + m_rem -=2; + i+=2; } + else if(1 == m_rem) // Repetitive A blocks will be 1*1 + { +
dim_t p_lda = i; // packed leading dimension + for(dim_t x=0;x= 4)) + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = i / 4; //number of times GEMM to be performed(in blocks of 4x4) + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + ///GEMM code begins/// + + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + b01 += 1; //move to next row of B + + a10 += 4; //pointer math to calculate next block of A for GEMM + b01 = ptr_b01_dup + 4; //pointer math to calculate next 
block of B for GEMM + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + + dtrsm_AutXB_ref(a11, b11, m_rem, 4, cs_a, cs_b, is_unitdiag); + n_rem -= 4; + j+=4; + } + + if(n_rem) + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = i / 4; //number of times GEMM to be performed(in blocks of 4x4) + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(3 == n_rem) + { + ///GEMM code begins/// + + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + 
ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + b01 += 1; //move to next row of B + + a10 += 4; //pointer math to calculate next block of A for GEMM + b01 = ptr_b01_dup + 4; //pointer math to calculate next block of B for GEMM + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + + dtrsm_AutXB_ref(a11, b11, m_rem, 3, cs_a, cs_b, is_unitdiag); + } + else if(2 == n_rem) + { + ///GEMM code begins/// + + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + b01 += 1; //move to next row of B + + a10 += 4; //pointer math to calculate next block of A for GEMM + b01 = ptr_b01_dup + 4; //pointer math to calculate next block of B for GEMM + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + 
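/*
    Editor's note (illustrative only): the 0x0E blends above commit lane 0 of
    the freshly computed TRSM input and keep lanes 1-3 of the previously
    loaded B values; _mm256_blend_pd(a, b, imm) selects lane i from b when
    bit i of imm is set. A scalar sketch of blend with mask 0x0E:
*/
#if 0
static inline void blend_pd_0x0E(const double a[4], const double b[4], double r[4])
{
    // imm 0x0E = binary 1110: lane 0 comes from a, lanes 1..3 from b
    r[0] = a[0];
    r[1] = b[1];
    r[2] = b[2];
    r[3] = b[3];
}
#endif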
_mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + + dtrsm_AutXB_ref(a11, b11, m_rem, 2, cs_a, cs_b, is_unitdiag); + } + else if(1 == n_rem) + { + ///GEMM code begins/// + + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + b01 += 1; //move to next row of B + + a10 += 4; //pointer math to calculate next block of A for GEMM + b01 = ptr_b01_dup + 4; //pointer math to calculate next block of B for GEMM + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + + dtrsm_AutXB_ref(a11, b11, m_rem, 1, cs_a, cs_b, is_unitdiag); + } + } + m_rem -=1; + i+=1; + } } - if(n_remainder) //implementation fo remaining columns(when 'N' is not a multiple of D_NR) + if ((required_packing_A == 1) && + bli_mem_is_alloc( &local_mem_buf_A_s )) { - for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' direction - { - a10 = L +i; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - - k_iter = i / D_MR; //number of GEMM operations to be performed(in blocks of 4x4) - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM for previously calculated values /// - - //load 4x4 block from b11 - if(3 == n_remainder) - { - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - for(k = 0; k < k_iter; k++) - { - ptr_b01_dup = b01; - ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += 
(B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) - - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM - } - - ///GEMM code ends/// - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value - - ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 - ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] *alpha -= ymm5 - ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[0-3][2] *alpha -= ymm6 - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][3] *alpha -= ymm7 - } - else if(2 == n_remainder) - { - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - - for(k = 0; k < k_iter; k++) - { - ptr_b01_dup = b01; - ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); 
//A10[0][3] A10[1][3] A10[2][3] A10[3][3] - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM - } - ///GEMM code ends/// - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value - - ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 - ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] *alpha -= ymm5 - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][2] *alpha -= ymm6 - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][3] *alpha -= ymm7 - } - else if(1 == n_remainder) - { - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - - for(k = 0; k < k_iter; k++) - { - ptr_b01_dup = b01; - - ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] - - b01 += 1; - - 
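/*
    Editor's sketch (not part of the patch): the broadcast/FMA chains in this
    removed kernel implement a plain rank-4 update, acc(:,c) += A10(:,r) * B01(r,c),
    accumulated before the TRSM step. A scalar reference for one k_iter step;
    "gemm_4x4_step" is a hypothetical name used only for illustration:
*/
#if 0
static void gemm_4x4_step(const double *a10, dim_t cs_a,
                          const double *b01, dim_t cs_b,
                          double acc[4][4], dim_t ncols)
{
    for (dim_t r = 0; r < 4; r++)          // rows of B01 (columns of A10)
        for (dim_t c = 0; c < ncols; c++)  // columns of B01 / B11
            for (dim_t e = 0; e < 4; e++)  // vector lanes (rows of A10)
                acc[c][e] += b01[r + c*cs_b] * a10[e + r*cs_a];
}
#endif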
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM - } - ///GEMM code ends/// - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value - - ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 - ymm1 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][1] *alpha -= ymm5 - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][2] *alpha -= ymm6 - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][3] *alpha -= ymm7 - - } - - ///implement TRSM/// - //1st col - ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] - ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0] - ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][0] - - //2nd col - a11 += cs_a; - ymm8 = _mm256_broadcast_sd((double const *)(a11 + 1)); //A11[1][1] - ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][1] - - //3rd col - a11 += cs_a; - ymm11 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][2] - ymm12 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][2] - - //4th col - a11 += cs_a; - ymm13 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][3] - //compute reciprocals of L(i,i) and broadcast in registers - ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[0][0] A11[1][1] A11[1][1] - ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[2][2] A11[3][3] A11[3][3] - - ymm14 = _mm256_broadcast_sd((double const *)&ones); - - ymm4 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] - ymm14 = _mm256_div_pd(ymm14, ymm4); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] - - ////unpacklow//// - ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - //rearrange low elements - ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - //rearrange high elements - ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - //extract a00 - ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0] - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - ymm4 = _mm256_mul_pd(ymm4, ymm15); //B11[0][0-3] /= A11[0][0] - - //extract diag a11 from a - ymm15 = _mm256_permute_pd(ymm14, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2] - ymm15 = 
_mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1] - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) - ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0] * B11[0][0-3] - ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0] * B11[0][0-3] - ymm13 = _mm256_fnmadd_pd(ymm7, ymm4, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][0] * B11[0][0-3] - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - ymm8 = _mm256_mul_pd(ymm8, ymm15); //B11[1][0-3] /= A11[1][1] - - - //extract diag a22 from a - ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2] - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1] * B11[1][0-3] - ymm13 = _mm256_fnmadd_pd(ymm10, ymm8, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][1] * B11[1][0-3] - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - ymm11 = _mm256_mul_pd(ymm11, ymm15); //B11[2][0-3] /= A11[2][2] - - //extract diag a33 from a - ymm15 = _mm256_permute_pd(ymm14, 0x0C); //1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3] - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - ymm13 = _mm256_fnmadd_pd(ymm12, ymm11, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][2] * B11[2][0-3] - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - ymm13 = _mm256_mul_pd(ymm13, ymm15); //B11[3][0-3] /= A11[3][3] - - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ////unpackhigh//// - ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - - ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - if(3 == n_remainder) - { - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - - } - else if(2 == n_remainder) - { - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) - - } - else if(1 == n_remainder) - { - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - } - - } - if(m_remainder) //implementation for remainder rows(when 'M' is not a multiple of D_MR) - { - a10 = L +i; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be 
used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - - - k_iter = i / D_MR; //number of times GEMM operations to be performed - - dim_t iter; - if((j+n_remainder) == n) - { - f_temp = f_t; - for(iter = 0; iter < m_remainder; iter++) - f_temp[iter] = (b11 + cs_b * (n_remainder -1))[iter]; - } - else - f_temp = (b11 + cs_b * (n_remainder -1)); - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM for previously calculated values /// - - - //load 4x4 block from b11 - if(3 == n_remainder) - { - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - ptr_b01_dup = b01; - ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[0][2] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[1][2] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[2][2] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0] - ymm13 = _mm256_broadcast_sd((double 
const *)(b01 + cs_b * 1)); //B10[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[3][2] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) - - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM - } - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha value - - ymm8 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] * alpha -= ymm4 - ymm9 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] * alpha -= ymm5 - ymm10 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[0-3][2] * alpha -= ymm6 - - ///implement TRSM/// - //determine correct values to store - if(3 == m_remainder) - { - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); - } - else if(2 == m_remainder) - { - ymm0 = _mm256_permute2f128_pd(ymm8, ymm0, 0x30); - ymm1 = _mm256_permute2f128_pd(ymm9, ymm1, 0x30); - ymm2 = _mm256_permute2f128_pd(ymm10, ymm2, 0x30); - } - else if(1 == m_remainder) - { - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); - } - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(f_temp), ymm2); //store(B11[0-3][2]) - } - if(2 == n_remainder) - { - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - ptr_b01_dup = b01; - - ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[0][1] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[1][1] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0] - 
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[2][1] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[3][1] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM - } - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha value - - ymm8 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] * alpha -= ymm4 - ymm9 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] * alpha -= ymm5 - - ///implement TRSM/// - //determine correct values to store - if(3 == m_remainder) - { - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - } - else if(2 == m_remainder) - { - ymm0 = _mm256_permute2f128_pd(ymm8, ymm0, 0x30); - ymm1 = _mm256_permute2f128_pd(ymm9, ymm1, 0x30); - } - else if(1 == m_remainder) - { - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - } - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(f_temp), ymm1); //store(B11[0-3][1]) - } - if(n_remainder == 1) - { - ymm0 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - ptr_b01_dup = b01; - - ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = 
ptr_b01_dup + D_MR;    //pointer math to find next block of B for GEMM
-
-            }
-            ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal);    //register to hold alpha value
-
-            ymm8 = _mm256_fmsub_pd(ymm0, ymm15, ymm4);    //B11[0-3][0] * alpha -= ymm4
-
-            ///implement TRSM///
-            //determine correct values to store
-            if(3 == m_remainder)
-            {
-                ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08);
-            }
-            else if(2 == m_remainder)
-            {
-                ymm0 = _mm256_permute2f128_pd(ymm8, ymm0, 0x30);
-            }
-            else if(1 == m_remainder)
-            {
-                ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E);
-            }
-            _mm256_storeu_pd((double *)(f_temp), ymm0);    //store(B11[0-3][0])
-        }
-
-        if((j+n_remainder) == n)
-        {
-            for(iter = 0; iter < m_remainder; iter++)
-                (b11 + cs_b * (n_remainder-1))[iter] = f_temp[iter];
-        }
-
-        ///scalar code for trsm without alpha///
-        dtrsm_small_AlXB(a11, b11, m_remainder, n_remainder, cs_a, cs_b);
-    }
+        bli_membrk_release(&rntm, &local_mem_buf_A_s);
     }
     return BLIS_SUCCESS;
 }

 /* TRSM for the case AX = alpha * B, Double precision
- * A is lower-triangular, no-transpose, unit diagonal
+ * A is lower-triangular, transpose, non-unit diagonal
  * dimensions A: mxm X: mxn B: mxn
-
-      b01--->
-      (diagram: the b01/b11 panels of B and the a10/a11 blocks of A;
-       the arrows show the direction in which each block pointer advances)
-      a11--->   */
-
-static err_t bli_dtrsm_small_AlXB_unitDiag(
-                                            side_t side,
-                                            obj_t* AlphaObj,
-                                            obj_t* a,
-                                            obj_t* b,
-                                            cntx_t* cntx,
-                                            cntl_t* cntl
-                                          )
+static err_t bli_dtrsm_small_AltXB
+(
+    obj_t* AlphaObj,
+    obj_t* a,
+    obj_t* b,
+    cntx_t* cntx,
+    cntl_t* cntl
+)
 {
+    dim_t D_MR = 8;    //size of block along 'M' dimension
+    dim_t D_NR = 6;    //size of block along 'N' dimension
-    dim_t D_MR = 4;    //size of block along 'M' dimension
-    dim_t D_NR = 8;    //size of block along 'N' dimension

+    dim_t m = bli_obj_length(b);    // number of rows of matrix B
+    dim_t n = bli_obj_width(b);     // number of columns of matrix B
-    dim_t m = bli_obj_length(b);    // number of rows of matrix B
-    dim_t n = bli_obj_width(b);     // number of columns of matrix B

+    dim_t cs_a = bli_obj_col_stride(a);    // column stride of A
+    dim_t cs_b = bli_obj_col_stride(b);    // column stride of B

-#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME
-    if((m>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME_ROW_PANEL_M && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME)
-      || (m> D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME_COLUMN_PANEL_N)
-      || (m>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME_COLUMN_PANEL_M && n D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_NAPLES)
-    {
-        return BLIS_NOT_YET_IMPLEMENTED;
-    }
-#endif
-
-    dim_t m_remainder = m & (3);    //number of remainder rows
-    dim_t n_remainder = n & (7);    //number of remainder columns
-
-    dim_t cs_a = bli_obj_col_stride(a);    // column stride of A
-    dim_t cs_b = bli_obj_col_stride(b);    // column stride of B
-
-    dim_t i, j, k;    //loop variables
-    dim_t k_iter;     //number of times GEMM to be performed
+    dim_t i, j, k;    //loop variables
+    dim_t k_iter;     //number of times GEMM is to be performed

     double AlphaVal = *(double *)AlphaObj->buffer;    //value of alpha
-    double *L = a->buffer;    //pointer to matrix A
-    double *B = b->buffer;    //pointer to matrix B
+    double *L = a->buffer;    //pointer to matrix A
+    double *B = b->buffer;    //pointer to matrix B

-    double *a10, *a11, *b01, *b11;    //pointers that point to blocks for GEMM and TRSM
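/*
    Editor's note (not part of the patch): the new AltXB path below manages
    its pack buffer through the BLIS memory broker. A condensed sketch of
    that acquire/use/release lifecycle, using only calls that appear in
    this patch:
*/
#if 0
    rntm_t  rntm;
    mem_t   local_mem_buf_A_s;
    double *D_A_pack = NULL;

    bli_rntm_init_from_global( &rntm );          // snapshot global runtime state
    bli_rntm_set_num_threads_only( 1, &rntm );   // small-matrix path is single threaded
    bli_membrk_rntm_set_membrk( &rntm );         // attach the memory broker

    siz_t buffer_size = bli_pool_block_size(
                            bli_membrk_pool(
                                bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK),
                                bli_rntm_membrk(&rntm)));

    bli_membrk_acquire_m(&rntm, buffer_size,
                         BLIS_BITVAL_BUFFER_FOR_A_BLOCK, &local_mem_buf_A_s);
    D_A_pack = bli_mem_buffer(&local_mem_buf_A_s);

    // ... pack a10 into D_A_pack and run the GEMM+TRSM loops ...

    if (bli_mem_is_alloc(&local_mem_buf_A_s))
        bli_membrk_release(&rntm, &local_mem_buf_A_s);
#endif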
+    //pointers that point to blocks for GEMM and TRSM
+    double *a10, *a11, *b01, *b11;
     double *ptr_b01_dup;
-
-    double f_t[4] __attribute__((aligned(64)));    //buffer to store corner column when m_remainder != 0
-    double* f_temp;
+    double *ptr_a10_dup;

     double ones = 1.0;
-
+    bool is_unitdiag = bli_obj_has_unit_diag(a);

     //scratch registers
     __m256d ymm0, ymm1, ymm2, ymm3;
     __m256d ymm4, ymm5, ymm6, ymm7;
     __m256d ymm8, ymm9, ymm10, ymm11;
     __m256d ymm12, ymm13, ymm14, ymm15;
-    __m256d ymm16;
+    __m256d ymm16, ymm17, ymm18, ymm19;
+    __m256d ymm20;

+    gint_t required_packing_A = 1;
+    mem_t local_mem_buf_A_s;
+    double *D_A_pack = NULL;
+    double d11_pack[D_MR] __attribute__((aligned(64)));
+    rntm_t rntm;

+    bli_rntm_init_from_global( &rntm );
+    bli_rntm_set_num_threads_only( 1, &rntm );
+    bli_membrk_rntm_set_membrk( &rntm );

-    for(j = 0; j+D_NR-1 < n; j += D_NR)    //loop along 'N' dimension
+    siz_t buffer_size = bli_pool_block_size(
+                            bli_membrk_pool(
+                                bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK),
+                                bli_rntm_membrk(&rntm)));
+
+    if((D_MR * m * sizeof(double)) > buffer_size)
+        return BLIS_NOT_YET_IMPLEMENTED;
+
+    if (required_packing_A == 1)
     {
-        for(i = 0;i+D_MR-1 < m; i += D_MR)    //loop along 'M' dimension
+        // Get the buffer from the pool.
+        bli_membrk_acquire_m(&rntm,
+                             buffer_size,
+                             BLIS_BITVAL_BUFFER_FOR_A_BLOCK,
+                             &local_mem_buf_A_s);
+
+        D_A_pack = bli_mem_buffer(&local_mem_buf_A_s);
+    }
+
+    /*
+        Performs TRSM for 8 columns at a time, from 0 to m/8 in steps of D_MR
+        a. Load, transpose and pack A (the a10 block); the packed size grows from 8x6 up to 8x(m-8)
+           For the first iteration there is no GEMM and no packing of a10 because it is TRSM only
+        b. Using the packed a10 block and the b01 block, perform the GEMM operation
+        c. Using the GEMM outputs, perform the TRSM operation on a11, b11 and update B
+        d. Repeat b,c for n rows of B in steps of D_NR
+    */
+    for(i = (m - D_MR); (i + 1) > 0; i -= D_MR)
+    {
+        a10 = L + (i*cs_a) + i + D_MR;    //pointer to block of A to be used for GEMM
+        a11 = L + (i*cs_a) + i;           //pointer to block of A to be used for TRSM
+
+        // Do transpose for a10 & store in D_A_pack
+        ptr_a10_dup = D_A_pack;
+        dim_t p_lda = (m-i-D_MR);    // packed leading dimension
+        /*
+            Load, transpose and pack the current A block (a10) into the packed buffer D_A_pack
+            a. This a10 block is used in the GEMM portion only, and the
+               a10 block size increases by D_MR with every iteration
+               until it reaches 8x(m-8), the maximum GEMM-only block size in A
+            b.
This packed buffer is reused to calculate all n rows of B matrix + */ + for(dim_t x =0;x < (m-i-D_MR);x+=D_MR) { - a10 = L +i; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm10 = _mm256_loadu_pd((double const *)(a10 + 4)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); + ymm11 = _mm256_loadu_pd((double const *)(a10 + 4 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); + ymm12 = _mm256_loadu_pd((double const *)(a10 + 4 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); + ymm13 = _mm256_loadu_pd((double const *)(a10 + 4 + cs_a * 3)); - k_iter = i / D_MR; //number of times GEMM to be performed(in blocks of 4x4) + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + + ymm4 = _mm256_unpacklo_pd(ymm10, ymm11); + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm10, ymm11); + ymm1 = _mm256_unpackhi_pd(ymm12, ymm13); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 4), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 5), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 6), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 7), ymm9); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + cs_a * 4)); + ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 4 + 4)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a * 5)); + ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 5 + 4)); + ymm2 = _mm256_loadu_pd((double const *)(a10 + cs_a * 6)); + ymm12 = _mm256_loadu_pd((double const *)(a10 + cs_a * 6 + 4)); + ymm3 = _mm256_loadu_pd((double const *)(a10 + cs_a * 7)); + ymm13 = _mm256_loadu_pd((double const *)(a10 + cs_a * 7 + 4)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup + 4), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), ymm9); + + ymm4 = _mm256_unpacklo_pd(ymm10, ymm11); + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm10, ymm11); + ymm1 = _mm256_unpackhi_pd(ymm12, ymm13); + + ymm7 
= _mm256_permute2f128_pd(ymm0,ymm1,0x20);
+            ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31);
+
+            _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 4 + 4), ymm6);
+            _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 5 + 4), ymm7);
+            _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 6 + 4), ymm8);
+            _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 7 + 4), ymm9);
+
+            a10 += D_MR;
+            ptr_a10_dup += D_MR;
+        }
+
+        /*
+            Pack the 8 diagonal elements of the A block into an array
+            a. This helps utilize cache lines efficiently in the TRSM operation
+            b. Store ones when the input is unit-diagonal
+        */
+        ymm4 = _mm256_broadcast_sd((double const *)&ones);
+        if(is_unitdiag)
+        {
+            _mm256_storeu_pd((double *)(d11_pack), ymm4);
+            _mm256_storeu_pd((double *)(d11_pack + 4), ymm4);
+        }
+        else
+        {
+            //broadcast diagonal elements of A11
+            ymm0 = _mm256_broadcast_sd((double const *)(a11));
+            ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a +1));
+            ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a*2 + 2));
+            ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a*3 + 3));
+
+            //Pick one element from each column, build a 4-element vector and store
+            ymm0 = _mm256_unpacklo_pd(ymm0, ymm1);
+            ymm1 = _mm256_unpacklo_pd(ymm2, ymm3);
+            ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C);
+            ymm0 = _mm256_div_pd(ymm4, ymm1);
+            _mm256_storeu_pd((double *)(d11_pack), ymm0);
+
+            //broadcast diagonal elements of A11
+            ymm0 = _mm256_broadcast_sd((double const *)(a11 + 4 + cs_a*4));
+            ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5 + cs_a*5));
+            ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6 + cs_a*6));
+            ymm3 = _mm256_broadcast_sd((double const *)(a11 + 7 + cs_a*7));
+
+            //Pick one element from each column, build a 4-element vector and store
+            ymm0 = _mm256_unpacklo_pd(ymm0, ymm1);
+            ymm1 = _mm256_unpacklo_pd(ymm2, ymm3);
+            ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C);
+            ymm0 = _mm256_div_pd(ymm4, ymm1);
+            _mm256_storeu_pd((double *)(d11_pack + 4), ymm0);
+        }
+
+        /*
+            a. Perform GEMM using a10, b01.
+            b. Perform TRSM on a11, b11
+            c. This GEMM+TRSM loop operates with an 8x6 block size
+               along the n dimension for every D_NR rows of b01, and the
+               packed A buffer is reused in computing all n rows of B.
+            d. The same approach is used in the remaining fringe cases.
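+            e. In outline (an editor's sketch; the helper names below are
+               illustrative placeholders, not functions defined in this patch):
+
+                   for(i = (m - D_MR); (i + 1) > 0; i -= D_MR)      // 8-row blocks of A, bottom-up
+                   {
+                       pack_transpose_a10();                        // step (a): a10 -> D_A_pack
+                       pack_diag_reciprocals();                     // a11 diagonal -> d11_pack
+                       for(j = (n - D_NR); (j + 1) > 0; j -= D_NR)  // 6-column panels of B
+                       {
+                           gemm_8x6();                              // step (b): accumulate a10*b01
+                           trsm_8x6();                              // step (c): solve with a11, store b11
+                       }
+                       // fringe panels (n_remainder) are handled the same way
+                   }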
+ */ + for(j = (n - D_NR); (j + 1) > 0; j -= D_NR) + { + a10 = D_A_pack; + b01 = B + (j*cs_b) + i + D_MR; //pointer to block of B to be used for GEMM + b11 = B + (j* cs_b) + i; //pointer to block of B to be used for TRSM + + k_iter = (m - i - D_MR) / D_MR; //number of times GEMM to be performed + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + /* + Peform GEMM between a10 and b01 blocks + For first itteration there will be no GEMM operation + where k_iter are zero + */ + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter*4; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); + ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); + ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1 + 4)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2 + 4)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3 + 4)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4 + 4)); + ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5 + 4)); + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); + ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7); + + b01 += 1; + a10 += p_lda; + + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + D_MR; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +D_MR; + ptr_a10_dup = a10; + } + 
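/*
    Editor's note: the ((k+1) & 0x03) test above fires once every four
    iterations, so the loop body runs k_iter*4 times while the pointers hop
    to the next D_MR-wide block of packed A and of b01 only on every fourth
    pass. An equivalent two-level form of the same k-loop (editor's sketch):
*/
#if 0
    for (dim_t kb = 0; kb < k_iter; kb++)   // one D_MR-wide block per kb
    {
        for (dim_t kk = 0; kk < 4; kk++)    // four broadcast/FMA rows per block
        {
            // ... same FMA body as above, with a10 += p_lda and b01 += 1 ...
        }
        a10 = ptr_a10_dup + D_MR;  ptr_a10_dup = a10;   // hop to next packed block
        b01 = ptr_b01_dup + D_MR;  ptr_b01_dup = b01;
    }
#endif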
} + //GEMM block end here + + /* + Load b11 of size 6x8 and multiply with alpha + Add the GEMM output and perform inregister transose of b11 + to peform TRSM operation. + */ + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3 + 4)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm12); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm13); + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm14); + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm15); + + ymm13 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm15 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm15 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *4 + 4)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *5 + 4)); + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); + + ymm16 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm4 = _mm256_permute2f128_pd(ymm7,ymm16,0x20); + ymm6 = _mm256_permute2f128_pd(ymm7,ymm16,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm5 = _mm256_permute2f128_pd(ymm0,ymm16,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm16,0x31); + ymm18 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm17 = _mm256_permute2f128_pd(ymm18,ymm16,0x20); + ymm19 = _mm256_permute2f128_pd(ymm18,ymm16,0x31); + + ////unpackhigh//// + ymm20 = _mm256_unpackhi_pd(ymm2, ymm3); + + //rearrange high elements + ymm18 = _mm256_permute2f128_pd(ymm20,ymm16,0x20); + ymm20 = _mm256_permute2f128_pd(ymm20,ymm16,0x31); + + /* + Compute 8x6 TRSM block by using GEMM block output in register + a. The 8x6 input (gemm outputs) are stored in combinations of ymm registers + 1. ymm15, ymm20 2. ymm14, ymm19 3. ymm13, ymm18 , 4. ymm12, ymm17 + 5. ymm11, ymm7 6. ymm10, ymm6, 7.ymm9, ymm5 8. 
ymm8, ymm4 + where ymm15-ymm8 holds 8x4 data and reaming 8x2 will be hold by + other registers + b. Towards the end do in regiser transpose of TRSM output and store in b11 + */ + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + + //perform mul operation + ymm15 = _mm256_mul_pd(ymm15, ymm1); + ymm20 = _mm256_mul_pd(ymm20, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + + //(ROw7): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6*cs_a + 7)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm15, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm20, ymm19); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 7)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm15, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm20, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 7)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm15, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm20, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 7)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm15, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm20, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 7)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm15, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm20, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 7)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm15, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm20, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm15, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm20, ymm4); + + //perform mul operation + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm19 = _mm256_mul_pd(ymm19, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + //(ROw6): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 6)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm14, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm19, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 6)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm14, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm19, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 6)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm14, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm19, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 6)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm14, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm19, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 6)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm14, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm19, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm14, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm19, ymm4); + + //perform mul operation + ymm13 = _mm256_mul_pd(ymm13, ymm1); + ymm18 = _mm256_mul_pd(ymm18, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + //(ROw5): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 5)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm13, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm18, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 5)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm13, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm18, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 5)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm13, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm18, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 5)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm13, ymm9); + ymm5 = 
_mm256_fnmadd_pd(ymm2, ymm18, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm13, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm18, ymm4); + + //perform mul operation + ymm12 = _mm256_mul_pd(ymm12, ymm1); + ymm17 = _mm256_mul_pd(ymm17, ymm1); + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw4): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 4)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm12, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm17, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 4)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm12, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm17, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 4)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm12, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm17, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm12, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm17, ymm4); + + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); + ymm7 = _mm256_mul_pd(ymm7, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw3): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm11, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm7, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm11, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm7, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm11, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm7, ymm4); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + ymm6 = _mm256_mul_pd(ymm6, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm10, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm6, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm10, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm6, ymm4); + + //perform mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); + ymm5 = _mm256_mul_pd(ymm5, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm9, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm5, ymm4); + + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); + ymm4 = _mm256_mul_pd(ymm4, ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm12, ymm13); + ymm3 = _mm256_unpacklo_pd(ymm14, ymm15); + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); + ymm2 = 
_mm256_permute2f128_pd(ymm1, ymm3, 0x31); + + ///unpack high/// + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); + ymm3 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); + + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm0); + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm1); + _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm2); + _mm256_storeu_pd((double *)(b11 + cs_b * 3 + 4), ymm3); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm4, ymm5); + ymm3 = _mm256_unpacklo_pd(ymm6, ymm7); + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); + + ///unpack high/// + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); + + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm17, ymm18); + ymm3 = _mm256_unpacklo_pd(ymm19, ymm20); + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); + + ///unpack high/// + ymm17 = _mm256_unpackhi_pd(ymm17, ymm18); + ymm18 = _mm256_unpackhi_pd(ymm19, ymm20); + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm17, ymm18, 0x20); + + _mm256_storeu_pd((double *)(b11 + cs_b * 4 + 4), ymm0); + _mm256_storeu_pd((double *)(b11 + cs_b * 5 + 4), ymm1); + + } + + dim_t n_remainder = j + D_NR; + if(n_remainder >= 4) + { + a10 = D_A_pack; + a11 = L + (i*cs_a) + i; + b01 = B + ((n_remainder - 4)* cs_b) + i + D_MR; + b11 = B + ((n_remainder - 4)* cs_b) + i; + + k_iter = (m - i - D_MR) / D_MR; + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); ymm8 = _mm256_setzero_pd(); ymm9 = _mm256_setzero_pd(); ymm10 = _mm256_setzero_pd(); @@ -3073,109 +8627,58 @@ static err_t bli_dtrsm_small_AlXB_unitDiag( ymm15 = _mm256_setzero_pd(); ///GEMM code begins/// - - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter*4; k++) //loop for number of GEMM operations { - ptr_b01_dup = b01; + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); - ymm16 = _mm256_loadu_pd((double const *)(a10));//A10[0][0] A10[1][0] A10[2][0] A10[3][0] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[0][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[0][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[0][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[0][7] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); - b01 += 1; //mobe to next 
row of B + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4 + 4)); - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[0][4]*A10[0][0] B01[0][4]*A10[1][0] B01[0][4]*A10[2][0] B01[0][4]*A10[3][0]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[0][5]*A10[0][0] B01[0][5]*A10[1][0] B01[0][5]*A10[2][0] B01[0][5]*A10[3][0]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[0][6]*A10[0][0] B01[0][6]*A10[1][0] B01[0][6]*A10[2][0] B01[0][6]*A10[3][0]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[0][7]*A10[0][0] B01[0][7]*A10[1][0] B01[0][7]*A10[2][0] B01[0][7]*A10[3][0]) + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a));//A10[0][1] A10[1][1] A10[2][1] A10[3][1] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1 + 4)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2 + 4)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[1][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[1][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[1][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[1][7] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3 + 4)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); - b01 += 1; //mobe to next row of B + b01 += 1; //move to next row of B + a10 += p_lda; - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) - - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[1][4]*A10[0][1] B01[1][4]*A10[1][1] B01[1][4]*A10[2][1] B01[1][4]*A10[3][1]) - ymm13 = 
_mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[1][5]*A10[0][1] B01[1][5]*A10[1][1] B01[1][5]*A10[2][1] B01[1][5]*A10[3][1]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[1][6]*A10[0][1] B01[1][6]*A10[1][1] B01[1][6]*A10[2][1] B01[1][6]*A10[3][1]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[1][7]*A10[0][1] B01[1][7]*A10[1][1] B01[1][7]*A10[2][1] B01[1][7]*A10[3][1]) - - ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2));//A10[0][2] A10[1][2] A10[2][2] A10[3][2] - - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] - - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[2][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[2][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[2][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[2][7] - - b01 += 1; //mobe to next row of B - - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) - - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[2][4]*A10[0][2] B01[2][4]*A10[1][2] B01[2][4]*A10[2][2] B01[2][4]*A10[3][2]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[2][5]*A10[0][2] B01[2][5]*A10[1][2] B01[2][5]*A10[2][2] B01[2][5]*A10[3][2]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[2][6]*A10[0][2] B01[2][6]*A10[1][2] B01[2][6]*A10[2][2] B01[2][6]*A10[3][2]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[2][7]*A10[0][2] B01[2][7]*A10[1][2] B01[2][7]*A10[2][2] B01[2][7]*A10[3][2]) - - ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3));//A10[0][3] A10[1][3] A10[2][3] A10[3][3] - - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] - - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[3][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[3][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[3][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[3][7] - - b01 += 1; //mobe to next row of B - - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[3][0]*A10[0][3] B01[3][0]*A10[3][0] B01[3][0]*A10[2][3] B01[3][0]*A10[3][0]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[3][1]*A10[0][3] B01[3][1]*A10[3][0] B01[3][1]*A10[2][3] B01[3][1]*A10[3][0]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[3][2]*A10[0][3] B01[3][2]*A10[3][0] B01[3][2]*A10[2][3] B01[3][2]*A10[3][0]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[3][3]*A10[0][3] B01[3][3]*A10[3][0] 
B01[3][3]*A10[2][3] B01[3][3]*A10[3][0]) - - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[3][4]*A10[0][3] B01[3][4]*A10[3][0] B01[3][4]*A10[2][3] B01[3][4]*A10[3][3]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[3][5]*A10[0][3] B01[3][5]*A10[3][0] B01[3][5]*A10[2][3] B01[3][5]*A10[3][3]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[3][6]*A10[0][3] B01[3][6]*A10[3][0] B01[3][6]*A10[2][3] B01[3][6]*A10[3][3]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[3][7]*A10[0][3] B01[3][7]*A10[3][0] B01[3][7]*A10[2][3] B01[3][7]*A10[3][3]) - - a10 += D_MR * cs_a; //pointer math to calculate next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to calculate next block of B for GEMM + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + D_MR; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +D_MR; + ptr_a10_dup = a10; + } } ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha @@ -3184,10 +8687,11 @@ static err_t bli_dtrsm_small_AlXB_unitDiag( ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] - ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *6)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] - ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b *7)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] + ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b *3 + 4)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] @@ -3229,48 +8733,137 @@ static err_t bli_dtrsm_small_AlXB_unitDiag( ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] - ymm2 = _mm256_broadcast_sd((double const *)(a11 +1)); //A11[1][0] - ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][0] - ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][0] + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); - a11 += cs_a; + //perform mul operation + ymm15 = _mm256_mul_pd(ymm15, ymm1); - //(Row1): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); //B11[1][0-3] -= A11[1][0] * B11[0-3][0] - ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[2][0-3] -= A11[2][0] * B11[0-3][0] - ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); //B11[3][0-3] -= A11[3][0] * B11[0-3][0] + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); //B11[5][0-3] -= A11[1][0] * B11[0-3][4] - ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[6][0-3] -= A11[2][0] * B11[0-3][4] - ymm15 = _mm256_fnmadd_pd(ymm4, ymm12, ymm15); //B11[7][0-3] -= A11[3][0] * 
B11[0-3][4] + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6*cs_a + 7)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 7)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 7)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 7)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 7)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 7)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 7)); - ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][1] - ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][1] + //(ROw7): FMA operations + ymm14 = _mm256_fnmadd_pd(ymm2, ymm15, ymm14); + ymm13 = _mm256_fnmadd_pd(ymm3, ymm15, ymm13); + ymm12 = _mm256_fnmadd_pd(ymm4, ymm15, ymm12); + ymm11 = _mm256_fnmadd_pd(ymm5, ymm15, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm15, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm15, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm15, ymm8); - a11 += cs_a; + //perform mul operation + ymm14 = _mm256_mul_pd(ymm14, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 6)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 6)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 6)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 6)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 6)); + + //(ROw6): FMA operations + ymm13 = _mm256_fnmadd_pd(ymm3, ymm14, ymm13); + ymm12 = _mm256_fnmadd_pd(ymm4, ymm14, ymm12); + ymm11 = _mm256_fnmadd_pd(ymm5, ymm14, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm14, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm14, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm14, ymm8); + + //perform mul operation + ymm13 = _mm256_mul_pd(ymm13, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 5)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 5)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 5)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 5)); + + //(ROw5): FMA operations + ymm12 = _mm256_fnmadd_pd(ymm4, ymm13, ymm12); + ymm11 = _mm256_fnmadd_pd(ymm5, ymm13, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm13, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm13, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm13, ymm8); + + //perform mul operation + ymm12 = _mm256_mul_pd(ymm12, ymm1); + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 4)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 4)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 4)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 4)); + + //(ROw4): FMA operations + ymm11 = _mm256_fnmadd_pd(ymm5, ymm12, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm12, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm12, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm12, ymm8); + + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 3)); + + //(ROw3): FMA operations + ymm10 = 
_mm256_fnmadd_pd(ymm6, ymm11, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm11, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm11, ymm8); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 2)); //(ROw2): FMA operations - ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); //B11[2][0-3] -= A11[2][1] * B11[0-3][1] - ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); //B11[3][0-3] -= A11[3][1] * B11[0-3][1] + ymm9 = _mm256_fnmadd_pd(ymm7, ymm10, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm10, ymm8); - ymm14 = _mm256_fnmadd_pd(ymm3, ymm13, ymm14); //B11[6][0-3] -= A11[2][1] * B11[0-3][5] - ymm15 = _mm256_fnmadd_pd(ymm4, ymm13, ymm15); //B11[7][0-3] -= A11[3][1] * B11[0-3][5] + //perform mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); - ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][2] + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - a11 += cs_a; + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 1)); - //(ROw1): FMA operations - ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); //B11[3][0-3] -= A11[3][2] * B11[0-3][2] + //(ROw2): FMA operations + ymm8 = _mm256_fnmadd_pd(ymm16, ymm9, ymm8); - ymm15 = _mm256_fnmadd_pd(ymm4, ymm14, ymm15); //B11[7][0-3] -= A11[3][2] * B11[0-3][6] + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] - ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] + ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] //rearrange low elements ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] @@ -3280,49 +8873,43 @@ static err_t bli_dtrsm_small_AlXB_unitDiag( ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] - ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] //rearrange high elements ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + 
ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm4); //store B11[4][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm5); //store B11[5][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 6), ymm6); //store B11[6][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 7), ymm7); //store B11[7][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4-7][0] + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[4-7][1] + _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); //store B11[4-7][2] + _mm256_storeu_pd((double *)(b11 + cs_b * 3 + 4), ymm7); //store B11[4-7][3] + n_remainder -=4; } - if(m_remainder) //implementation for reamainder rows(when 'M' is not a multiple of D_MR) + if(n_remainder) //implementation for remaining columns (when 'N' is not a multiple of D_NR), n_remainder = 1, 2 or 3 { - a10 = L +i; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + a10 = D_A_pack; + a11 = L + (i*cs_a) + i; + b01 = B + i + D_MR; + b11 = B + i; - k_iter = i / D_MR; //number of times GEMM operation to be done(in blocks of 4x4) - - dim_t iter; - - if((j+D_NR) == n) - { - f_temp = f_t; - for(iter = 0; iter < m_remainder; iter++) - f_temp[iter] = (b11 + cs_b * 7)[iter]; - } - else - f_temp = (b11 + cs_b * 7); + k_iter = (m - i - D_MR) / D_MR; + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); ymm8 = _mm256_setzero_pd(); ymm9 = _mm256_setzero_pd(); ymm10 = _mm256_setzero_pd(); @@ -3332,1325 +8919,5030 @@ static err_t bli_dtrsm_small_AlXB_unitDiag( ymm14 = _mm256_setzero_pd(); ymm15 = _mm256_setzero_pd(); - ///GEMM code Begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + if(3 == n_remainder) { - ptr_b01_dup = b01; + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter*4; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); - ymm16 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3]
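+ /* Illustrative note (names are expository, not from the original patch): this GEMM accumulation is a sequence of rank-1 updates on the packed panel; for each k, eight packed elements of D_A_pack (two ymm loads) are scaled by broadcast B01 scalars and accumulated, i.e. in scalar terms acc[r][c] += Apack[r][k] * B01[k][c], with ymm8-ymm15 playing the role of acc. */ + ymm2 = _mm256_broadcast_sd((double const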
*)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[0][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[0][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[0][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[0][7] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); - b01 += 1; //move to next row of B + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4 + 4)); - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0] ) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[0][4]*A10[0][0] B01[0][4]*A10[1][0] B01[0][4]*A10[2][0] B01[0][4]*A10[3][0]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[0][5]*A10[0][0] B01[0][5]*A10[1][0] B01[0][5]*A10[2][0] B01[0][5]*A10[3][0]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[0][6]*A10[0][0] B01[0][6]*A10[1][0] B01[0][6]*A10[2][0] B01[0][6]*A10[3][0]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm16 += (B01[0][7]*A10[0][0] B01[0][7]*A10[1][0] B01[0][7]*A10[2][0] B01[0][7]*A10[3][0]) + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1 + 4)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 1)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2 + 4)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] + b01 += 1; //move to next row of B + a10 += p_lda; - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[1][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[1][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[1][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[1][7] - - b01 += 1; //move to next row of B01 - - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) - ymm11 = 
_mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) - - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[1][4]*A10[0][1] B01[1][4]*A10[1][1] B01[1][4]*A10[2][1] B01[1][4]*A10[3][1]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[1][5]*A10[0][1] B01[1][5]*A10[1][1] B01[1][5]*A10[2][1] B01[1][5]*A10[3][1]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[1][6]*A10[0][1] B01[1][6]*A10[1][1] B01[1][6]*A10[2][1] B01[1][6]*A10[3][1]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[1][7]*A10[0][1] B01[1][7]*A10[1][1] B01[1][7]*A10[2][1] B01[1][7]*A10[3][1]) - - ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] //A10[1][2] A10[2][2] A10[3][2] - - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] - - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[2][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[2][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[2][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[2][7] - - b01 += 1; //move to next row of B - - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) - - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[2][4]*A10[0][2] B01[2][4]*A10[1][2] B01[2][4]*A10[2][2] B01[2][0]*A10[3][2]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[2][5]*A10[0][2] B01[2][5]*A10[1][2] B01[2][5]*A10[2][2] B01[2][1]*A10[3][2]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[2][6]*A10[0][2] B01[2][6]*A10[1][2] B01[2][6]*A10[2][2] B01[2][2]*A10[3][2]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[2][7]*A10[0][2] B01[2][7]*A10[1][2] B01[2][7]*A10[2][2] B01[2][3]*A10[3][2]) - - ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] - - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] - - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[3][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[3][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[3][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[3][7] - - b01 += 1; //move to next row of B - - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm8 += (B01[3][1]*A10[0][3] 
B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm8 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm8 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) - - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm8 += (B01[3][0]*A10[0][3] B01[3][4]*A10[1][3] B01[3][4]*A10[2][3] B01[3][4]*A10[3][3]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm8 += (B01[3][1]*A10[0][3] B01[3][5]*A10[1][3] B01[3][5]*A10[2][3] B01[3][5]*A10[3][3]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm8 += (B01[3][2]*A10[0][3] B01[3][6]*A10[1][3] B01[3][6]*A10[2][3] B01[3][6]*A10[3][3]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm8 += (B01[3][3]*A10[0][3] B01[3][7]*A10[1][3] B01[3][7]*A10[2][3] B01[3][7]*A10[3][3]) - - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + D_MR; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +D_MR; + ptr_a10_dup = a10; + } } - ///GEMM code ends/// + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] + ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] * alpha -= B01[0-3][6] + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(2 == n_remainder) + { + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + + for(k = 0; k< k_iter*4; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1 + 4)); + ymm9 
= _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + b01 += 1; //move to next row of B + a10 += p_lda; + + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + D_MR; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +D_MR; + ptr_a10_dup = a10; + } + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] + ymm6 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(1 == n_remainder) + { + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + + for(k = 0; k< k_iter*4; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + b01 += 1; //move to next row of B + a10 += p_lda; + + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + D_MR; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +D_MR; + ptr_a10_dup = a10; + } + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_broadcast_sd((double const *)(&ones)); + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_broadcast_sd((double const *)(&ones)); + ymm6 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + } + ///implement TRSM/// + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[2][4] B11[2][5] + ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[2][6] B11[2][7] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] 
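+ /* Note on the transpose idiom: _mm256_permute2f128_pd with imm8 0x20 concatenates the low 128-bit halves of its two sources, and 0x31 the high halves; paired with the unpacklo/unpackhi steps above, the two permutes around this note form the 4x4 in-register transpose of B11 used throughout these kernels. */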
+ ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[1][4] B11[1][5] B11[3][4] B11[3][5] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[1][6] B11[1][7] B11[3][6] B11[3][7] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] + ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + + //perform mul operation + ymm15 = _mm256_mul_pd(ymm15, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6*cs_a + 7)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 7)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 7)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 7)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 7)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 7)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 7)); + + //(ROw7): FMA operations + ymm14 = _mm256_fnmadd_pd(ymm2, ymm15, ymm14); + ymm13 = _mm256_fnmadd_pd(ymm3, ymm15, ymm13); + ymm12 = _mm256_fnmadd_pd(ymm4, ymm15, ymm12); + ymm11 = _mm256_fnmadd_pd(ymm5, ymm15, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm15, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm15, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm15, ymm8); + + //perform mul operation + ymm14 = _mm256_mul_pd(ymm14, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 6)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 6)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 6)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 6)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 6)); + + //(ROw6): FMA operations + ymm13 = _mm256_fnmadd_pd(ymm3, ymm14, ymm13); + ymm12 = _mm256_fnmadd_pd(ymm4, ymm14, ymm12); + ymm11 = _mm256_fnmadd_pd(ymm5, ymm14, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm14, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm14, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm14, ymm8); + + //perform mul operation + ymm13 = _mm256_mul_pd(ymm13, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 5)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 5)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 5)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 5)); + + //(ROw5): FMA operations + ymm12 = _mm256_fnmadd_pd(ymm4, ymm13, ymm12); + ymm11 = _mm256_fnmadd_pd(ymm5, ymm13, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, 
ymm13, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm13, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm13, ymm8); + + //perform mul operation + ymm12 = _mm256_mul_pd(ymm12, ymm1); + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 4)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 4)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 4)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 4)); + + //(ROw4): FMA operations + ymm11 = _mm256_fnmadd_pd(ymm5, ymm12, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm12, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm12, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm12, ymm8); + + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 3)); + + //(ROw3): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm6, ymm11, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm11, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm11, ymm8); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 2)); + + //(ROw2): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm7, ymm10, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm10, ymm8); + + //perform mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 1)); + + //(ROw2): FMA operations + ymm8 = _mm256_fnmadd_pd(ymm16, ymm9, ymm8); + + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] + ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + + if(3 == n_remainder) + { + _mm256_storeu_pd((double 
*)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4-7][0] + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[4-7][1] + _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); //store B11[4-7][2] + } + else if(2 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4-7][0] + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[4-7][1] + } + else if(1 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4-7][0] + } + } + }// End of multiples of D_MR blocks in m-dimension + + // Repetitive A blocks will be 4x4 + dim_t m_remainder = i + D_MR; + if(m_remainder >= 4) + { + i = m_remainder - 4; + a10 = L + (i*cs_a) + i + 4; //pointer to block of A to be used for GEMM + a11 = L + (i*cs_a) + i; //pointer to block of A to be used for TRSM + + // Do transpose for a10 & store in D_A_pack + double *ptr_a10_dup = D_A_pack; + dim_t p_lda = m-i+4; // packed leading dimension + for(dim_t x =0;x < m-i+4;x+=4) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + + a10 += 4; + ptr_a10_dup += 4; + } + + ymm4 = _mm256_broadcast_sd((double const *)&ones); + if(!is_unitdiag) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a +1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11+cs_a*2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11+cs_a*3 + 3)); + + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); + ymm4 = _mm256_div_pd(ymm4, ymm1); + } + _mm256_storeu_pd((double *)(d11_pack), ymm4); + + //cols + for(j = (n - D_NR); (j + 1) > 0; j -= D_NR) //loop along 'N' dimension + { + a10 = D_A_pack; + a11 = L + (i*cs_a) + i; //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b) + i + 4; //pointer to block of B to be used for GEMM + b11 = B + (j* cs_b) + i; //pointer to block of B to be used for TRSM + + k_iter = (m - i - 4); //number of times GEMM to be performed (in blocks of 4x4) + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); +
_mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); + ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); + + b01 += 1; //move to next row of B + a10 += p_lda; + + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + 4; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +4; + ptr_a10_dup = a10; + } + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + + ymm16 = _mm256_broadcast_sd((double const *)(&ones)); + + ////unpacklow//// + ymm7 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + + //rearrange low elements + ymm4 = _mm256_permute2f128_pd(ymm7,ymm16,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm6 = _mm256_permute2f128_pd(ymm7,ymm16,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] 
B11[3][1] + + //rearrange high elements + ymm5 = _mm256_permute2f128_pd(ymm0,ymm16,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm7 = _mm256_permute2f128_pd(ymm0,ymm16,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); + ymm7 = _mm256_mul_pd(ymm7, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw3): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm11, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm7, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm11, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm7, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm11, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm7, ymm4); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + ymm6 = _mm256_mul_pd(ymm6, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm10, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm6, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm10, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm6, ymm4); + + //perform mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); + ymm5 = _mm256_mul_pd(ymm5, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm9, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm5, ymm4); + + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); + ymm4 = _mm256_mul_pd(ymm4, ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ///unpack high/// + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + 
//rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store B11[1][0-3] + } + dim_t n_remainder = j + D_NR; + if((n_remainder >= 4)) + { + a10 = D_A_pack; + a11 = L + (i*cs_a) + i; //pointer to block of A to be used for TRSM + b01 = B + ((n_remainder - 4)* cs_b) + i + 4; //pointer to block of B to be used for GEMM + b11 = B + ((n_remainder - 4)* cs_b) + i; //pointer to block of B to be used for TRSM + + k_iter = (m - i - 4); //number of times GEMM to be performed(in blocks of 4x4) + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + b01 += 1; //move to next row of B + a10 += p_lda; + + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + 4; + ptr_b01_dup = b01; + a10 = ptr_a10_dup + 4; + ptr_a10_dup = a10; + } + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 
3));
+
+    //perform mul operation
+    ymm11 = _mm256_mul_pd(ymm11, ymm1);
+
+    //extract a22
+    ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2));
+
+    //(Row3): FMA operations
+    ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3));
+    ymm10 = _mm256_fnmadd_pd(ymm2, ymm11, ymm10);
+    ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3));
+    ymm9 = _mm256_fnmadd_pd(ymm2, ymm11, ymm9);
+    ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3));
+    ymm8 = _mm256_fnmadd_pd(ymm2, ymm11, ymm8);
+
+    //perform mul operation
+    ymm10 = _mm256_mul_pd(ymm10, ymm1);
+
+    //extract a11
+    ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1));
+
+    //(Row2): FMA operations
+    ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2));
+    ymm9 = _mm256_fnmadd_pd(ymm2, ymm10, ymm9);
+    ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2));
+    ymm8 = _mm256_fnmadd_pd(ymm2, ymm10, ymm8);
+
+    //perform mul operation
+    ymm9 = _mm256_mul_pd(ymm9, ymm1);
+
+    //extract a00
+    ymm1 = _mm256_broadcast_sd((double const *)(d11_pack));
+
+    //(Row1): FMA operations
+    ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1));
+    ymm8 = _mm256_fnmadd_pd(ymm2, ymm9, ymm8);
+
+    //perform mul operation
+    ymm8 = _mm256_mul_pd(ymm8, ymm1);
+
+    //unpacklow//
+    ymm1 = _mm256_unpacklo_pd(ymm8, ymm9);    //B11[0][0] B11[1][0] B11[0][2] B11[1][2]
+    ymm3 = _mm256_unpacklo_pd(ymm10, ymm11);    //B11[2][0] B11[3][0] B11[2][2] B11[3][2]
+
+    //rearrange low elements
+    ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20);    //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
+    ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31);    //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
+
+    ///unpack high///
+    ymm8 = _mm256_unpackhi_pd(ymm8, ymm9);    //B11[0][1] B11[1][1] B11[0][3] B11[1][3]
+    ymm9 = _mm256_unpackhi_pd(ymm10, ymm11);    //B11[2][1] B11[3][1] B11[2][3] B11[3][3]
+
+    //rearrange high elements
+    ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20);    //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
+    ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31);    //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
+
+    _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0);    //store B11[0][0-3]
+    _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1);    //store B11[1][0-3]
+    _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2);    //store B11[2][0-3]
+    _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3);    //store B11[3][0-3]
+    n_remainder = n_remainder - 4;
+    }
+
+    if(n_remainder)    //implementation for remaining columns (when 'N' is not a multiple of D_NR), i.e. n_remainder = 3, 2 or 1
+    {
+    a10 = D_A_pack;
+    a11 = L + (i*cs_a) + i;
+    b01 = B + i + 4;
+    b11 = B + i;
+
+    k_iter = (m - i - 4);
+
+    ymm4 = _mm256_setzero_pd();
+    ymm5 = _mm256_setzero_pd();
+    ymm6 = _mm256_setzero_pd();
+    ymm7 = _mm256_setzero_pd();
+    ymm8 = _mm256_setzero_pd();
+    ymm9 = _mm256_setzero_pd();
+    ymm10 = _mm256_setzero_pd();
+    ymm11 = _mm256_setzero_pd();
+    ymm12 = _mm256_setzero_pd();
+    ymm13 = _mm256_setzero_pd();
+    ymm14 = _mm256_setzero_pd();
+    ymm15 = _mm256_setzero_pd();
+
+    if(3 == n_remainder)
+    {
+    ///GEMM code begins///
+    ptr_b01_dup = b01;
+    ptr_a10_dup = a10;
+    for(k = 0; k< k_iter; k++)    //loop for number of GEMM operations
+    {
+    ymm0 = _mm256_loadu_pd((double const *)(a10));
+
+    ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0));
+    ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8);
+
+    ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1));
+    ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);
+
+    ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2));
+    ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10);
+
+    b01 += 1;    //move to next row of B
+    a10 += p_lda;
+
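+    // Note: D_A_pack holds the GEMM operand packed in transposed 4-wide strips with
+    // leading dimension p_lda, so each k step reads four contiguous doubles and then
+    // advances by p_lda. The rebase below fires every fourth iteration: a10 has
+    // walked to base + 4*p_lda and is pulled back to base + 4 (the start of the next
+    // strip), while b01 has already advanced four rows on its own, so its rebase
+    // merely refreshes the checkpoint pointer.
+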
if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + 4; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +4; + ptr_a10_dup = a10; + } + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] - ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *6)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] - ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] *alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] *alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] *alpha -= B01[0-3][2] - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] *alpha -= B01[0-3][3] - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] *alpha -= B01[0-3][4] - ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] *alpha -= B01[0-3][5] - ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] *alpha -= B01[0-3][6] - ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); //B11[0-3][7] *alpha -= B01[0-3][7] - - if(3 == m_remainder) - { - ///implement TRSM/// - - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[1][4] B11[1][5] - ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[1][6] B11[1][7] - - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] - ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[5][0] B11[5][1] B11[7][0] B11[7][1] - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[5][2] B11[5][3] B11[7][2] B11[7][3] - - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] - ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] - - ymm2 = _mm256_broadcast_sd((double const *)(a11 +1)); //A11[1][0] - ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][0] - - a11 += cs_a; - - //(Row1): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); //B11[1][0-3] -= B11[0-3][0]*A11[1][0] - ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[2][0-3] 
-= B11[0-3][0]*A11[2][0] - - ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); //B11[5][0-3] -= B11[0-3][4]*A11[1][4] - ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[6][0-3] -= B11[0-3][4]*A11[2][4] - - ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][1] - - a11 += cs_a; - - //(ROw2): FMA operations - ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); //B11[2][0-3] -= A11[2][1] * B11[0-3][1] - - ymm14 = _mm256_fnmadd_pd(ymm3, ymm13, ymm14); //B11[6][0-3] -= A11[2][1] * B11[0-3][5] - - ymm11 = _mm256_broadcast_sd((double const *)(&ones)); - ymm15 = _mm256_broadcast_sd((double const *)(&ones)); - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] - ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] - ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] - - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[0][5] B11[1][5] B11[0][7] B11[1][7] - ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[2][5] B11[3][5] B11[2][7] B11[3][7] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] - ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] - - ymm8 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); //load B11[0-3][0] - ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); //load B11[0-3][1] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //load B11[0-3][2] - ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //load B11[0-3][3] - ymm12 = _mm256_loadu_pd((double const *)(b11 + cs_b * 4)); //load B11[0-3][4] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b * 5)); //load B11[0-3][5] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b * 6)); //load B11[0-3][6] - ymm15 = _mm256_loadu_pd((double const *)(f_temp)); //load B11[0-3][7] - - //determine correct values to store - ymm0 = _mm256_blend_pd(ymm0, ymm8, 0x08); - ymm1 = _mm256_blend_pd(ymm1, ymm9, 0x08); - ymm2 = _mm256_blend_pd(ymm2, ymm10, 0x08); - ymm3 = _mm256_blend_pd(ymm3, ymm11, 0x08); - ymm4 = _mm256_blend_pd(ymm4, ymm12, 0x08); - ymm5 = _mm256_blend_pd(ymm5, ymm13, 0x08); - ymm6 = _mm256_blend_pd(ymm6, ymm14, 0x08); - ymm7 = _mm256_blend_pd(ymm7, ymm15, 0x08); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); } - else if(2 == m_remainder) + else if(2 == n_remainder) { - ///implement 
TRSM/// - - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[1][4] B11[1][5] - ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[1][6] B11[1][7] - - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] - ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[5][0] B11[5][1] B11[7][0] B11[7][1] - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[5][2] B11[5][3] B11[7][2] B11[7][3] - - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] - ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] - - ymm2 = _mm256_broadcast_sd((double const *)(a11 +1)); //A11[1][0] - - a11 += cs_a; - - //(Row1): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); //B11[1][0-3] -= B11[0-3][0]*A11[1][0] - - ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); //B11[5][0-3] -= B11[0-3][4]*A11[1][4] - - ymm10 = _mm256_broadcast_sd((double const *)&ones); - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - - ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm10, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm10, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ymm4 = _mm256_permute2f128_pd(ymm5, ymm10, 0x20); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] - ymm6 = _mm256_permute2f128_pd(ymm5, ymm10, 0x31); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] - - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - - ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[0][5] B11[1][5] B11[0][7] B11[1][7] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm10, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm10, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - ymm5 = _mm256_permute2f128_pd(ymm12, ymm10, 0x20); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] - ymm7 = _mm256_permute2f128_pd(ymm12, ymm10, 0x31); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] - - ymm8 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); //load B11[0-3][0] - ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); //load B11[0-3][1] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //load B11[0-3][2] - ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //load B11[0-3][3] - ymm12 = _mm256_loadu_pd((double const *)(b11 + cs_b * 4)); //load B11[0-3][4] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b * 5)); //load B11[0-3][5] - 
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b * 6)); //load B11[0-3][6] - ymm15 = _mm256_loadu_pd((double const *)(f_temp)); //load B11[0-3][7] - - //determine correct values to store - ymm0 = _mm256_permute2f128_pd(ymm0, ymm8, 0x30); - ymm1 = _mm256_permute2f128_pd(ymm1, ymm9, 0x30); - ymm2 = _mm256_permute2f128_pd(ymm2, ymm10, 0x30); - ymm3 = _mm256_permute2f128_pd(ymm3, ymm11, 0x30); - ymm4 = _mm256_permute2f128_pd(ymm4, ymm12, 0x30); - ymm5 = _mm256_permute2f128_pd(ymm5, ymm13, 0x30); - ymm6 = _mm256_permute2f128_pd(ymm6, ymm14, 0x30); - ymm7 = _mm256_permute2f128_pd(ymm7, ymm15, 0x30); - - } - else if(1 == m_remainder) + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations { - ymm8 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); //load B11[0-3][0] - ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); //load B11[0-3][1] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //load B11[0-3][2] - ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //load B11[0-3][3] - ymm12 = _mm256_loadu_pd((double const *)(b11 + cs_b * 4)); //load B11[0-3][4] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b * 5)); //load B11[0-3][5] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b * 6)); //load B11[0-3][6] - ymm15 = _mm256_loadu_pd((double const *)(f_temp)); //load B11[0-3][7] + ymm0 = _mm256_loadu_pd((double const *)(a10)); - //determine correct values to store - ymm0 = _mm256_blend_pd(ymm0, ymm8, 0x0E); - ymm1 = _mm256_blend_pd(ymm1, ymm9, 0x0E); - ymm2 = _mm256_blend_pd(ymm2, ymm10, 0x0E); - ymm3 = _mm256_blend_pd(ymm3, ymm11, 0x0E); - ymm4 = _mm256_blend_pd(ymm4, ymm12, 0x0E); - ymm5 = _mm256_blend_pd(ymm5, ymm13, 0x0E); - ymm6 = _mm256_blend_pd(ymm6, ymm14, 0x0E); - ymm7 = _mm256_blend_pd(ymm7, ymm15, 0x0E); + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + b01 += 1; //move to next row of B + a10 += p_lda; + + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + 4; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +4; + ptr_a10_dup = a10; + } } - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm4); //store(B11[0-3][4]) - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm5); //store(B11[0-3][5]) - _mm256_storeu_pd((double *)(b11 + cs_b * 6), ymm6); //store(B11[0-3][6]) - _mm256_storeu_pd((double *)(f_temp), ymm7); //store(B11[0-3][7]) + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - if((j+D_NR) == n) - { - for(iter = 0; iter < m_remainder; iter++) - (b11 + cs_b * 7)[iter] = f_temp[iter]; + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); } - } - } - - if((n & 4)) //implementation for 
remainder columns(when 'n_remainder' is greater than 4) - { - for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' direction - { - a10 = L +i; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - - k_iter = i / D_MR; //number of times GEMM to be performed(in block of 4) - ///GEMM for previously calculated values /// - - //load 4x4 block from b11 - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + else if(1 == n_remainder) { - ptr_b01_dup = b01; - ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a*2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += 
(B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[1][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[2][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[3][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) - - - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + b01 += 1; //move to next row of B + a10 += p_lda; + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + 4; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +4; + ptr_a10_dup = a10; + } } - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha - ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 - ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B01[0-3][1] *alpha -= ymm5 - ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B01[0-3][2] *alpha -= ymm6 - ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B01[0-3][3] *alpha -= ymm7 + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_broadcast_sd((double const *)(&ones)); + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } ///implement TRSM/// - //1st col - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] - ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0] - ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][0] - //2nd col - a11 += cs_a; - ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][1] - - //3rd col - a11 += cs_a; - ymm12 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][2] - - ////unpacklow//// - ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] 
B11[0][1] B11[2][0] B11[2][1] - ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] //rearrange low elements - ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] //rearrange high elements - ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) - ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]*B11[0][0-3] - ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0]*B11[0][0-3] - ymm13 = _mm256_fnmadd_pd(ymm7, ymm4, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][0]*B11[0][0-3] + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1]*B11[1][0-3] - ymm13 = _mm256_fnmadd_pd(ymm10, ymm8, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][1]*B11[1][0-3] + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - ymm13 = _mm256_fnmadd_pd(ymm12, ymm11, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][2]*B11[2][0-3] + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 3)); + + //(ROw3): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm6, ymm11, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm11, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm11, ymm8); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 2)); + + //(ROw2): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm7, ymm10, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm10, ymm8); + + //perform 
mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 1)); + + //(ROw2): FMA operations + ymm8 = _mm256_fnmadd_pd(ymm16, ymm9, ymm8); + + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ////unpackhigh//// - ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - - ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[0-3][3]) - - } - if(m_remainder) //implementation for remainder rows(when 'M' is not a multiple of D_MR) - { - a10 = L +i; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - - k_iter = i / D_MR; //number of GEMM operations to be performed(in blocks of 4x4) - - dim_t iter; - - if((j+4) == n) - { - f_temp = f_t; - for(iter = 0; iter < m_remainder; iter++) - f_temp[iter] = (b11 + cs_b * 3)[iter]; - } - else - f_temp = (b11 + cs_b * 3); - ///GEMM for previously calculated values /// - - //load 4x4 block from b11 - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - for(k = 0; k < k_iter; k++) //looop for number of GEMM operations - { - ptr_b01_dup = b01; - - ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] 
A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += 
(B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) - - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM - - } - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha - - ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 - ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] *alpha -= ymm5 - ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[0-3][2] *alpha -= ymm6 - ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[0-3][3] *alpha -= ymm7 - - - if(3 == m_remainder) - { - ///implement TRSM/// - //1st col - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] - ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0] - - //2nd col - a11 += cs_a; - ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1] - - ////unpacklow//// - ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - //rearrange low elements - ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - //rearrange high elements - ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) - ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]* B11[0][0-3] - ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0]* B11[0][0-3] - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1]* B11[1][0-3] - - ymm13 = _mm256_broadcast_sd((double const *)(&ones)); - - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ////unpackhigh//// - ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - - ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - //load 4x4 block from b11 - ymm4 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][2] B11[3][3] - - 
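The remainder logic in this hunk, both the deleted path here and the new m_remainder path later, leans on two guards that are easy to miss in the intrinsics: a small stack buffer (f_t / f_temp) stands in for the final column of B so that 256-bit loads and stores never touch memory past the end of the matrix, and a reload-plus-blend merges freshly computed lanes with the original contents so that only the first m_remainder rows are modified. A minimal standalone sketch of the two idioms; the helper names are illustrative, not from the patch:

    #include <immintrin.h>

    /* Write only the first m_remainder lanes of 'val' into 'col', leaving the
       trailing lanes exactly as they were: reload, blend, then store 4-wide. */
    static void store_partial(double *col, __m256d val, int m_remainder)
    {
        __m256d old = _mm256_loadu_pd(col);   /* current contents of the tail lanes */
        __m256d out;
        if      (m_remainder == 3) out = _mm256_blend_pd(val, old, 0x08); /* keep lane 3    */
        else if (m_remainder == 2) out = _mm256_blend_pd(val, old, 0x0C); /* keep lanes 2-3 */
        else if (m_remainder == 1) out = _mm256_blend_pd(val, old, 0x0E); /* keep lanes 1-3 */
        else                       out = val;                             /* full column    */
        _mm256_storeu_pd(col, out);
    }

    /* Stage a short final column through a 4-wide scratch buffer so the kernel
       can keep using unmasked 256-bit accesses on it. */
    static void process_tail_column(double *last_col, int m_remainder)
    {
        double f_t[4] = { 0.0, 0.0, 0.0, 0.0 };   /* scratch, like f_t in the old code */
        for (int r = 0; r < m_remainder; r++)
            f_t[r] = last_col[r];                 /* copy the valid rows in */
        __m256d v = _mm256_loadu_pd(f_t);         /* safe 4-wide load */
        /* ... GEMM/TRSM arithmetic on v ... */
        _mm256_storeu_pd(f_t, v);                 /* safe 4-wide store */
        for (int r = 0; r < m_remainder; r++)
            last_col[r] = f_t[r];                 /* copy the valid rows back */
    }

The deleted two-row case reaches the same effect with _mm256_permute2f128_pd(x, y, 0x30), which takes the low 128 bits of x and the high 128 bits of y and is therefore equivalent to the 0x0C blend above.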
//determine correct values to store - ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x08); - ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x08); - ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x08); - ymm3 = _mm256_blend_pd(ymm3, ymm7, 0x08); - } - else if(2 == m_remainder) - { - ///implement TRSM/// - //1st col - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] - - ////unpacklow//// - ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - //rearrange low elements - ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - //rearrange high elements - ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) - ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]* B11[0][0-3] - - ymm11 = _mm256_broadcast_sd((double const *)(&ones)); - ymm13 = _mm256_broadcast_sd((double const *)(&ones)); - - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ////unpackhigh//// - ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - - ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - //load 4x4 block from b11 - ymm4 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][2] B11[3][3] - - //determine correct values to store - ymm0 = _mm256_permute2f128_pd(ymm0, ymm4,0x30); - ymm1 = _mm256_permute2f128_pd(ymm1, ymm5,0x30); - ymm2 = _mm256_permute2f128_pd(ymm2, ymm6,0x30); - ymm3 = _mm256_permute2f128_pd(ymm3, ymm7,0x30); - - } - else if(1 == m_remainder) - { - //load 4x4 block from b11 - ymm4 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][2] B11[3][3] - - //determine correct values to store - ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x0E); - ymm1 = _mm256_blend_pd(ymm1, 
ymm5, 0x0E); - ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x0E); - ymm3 = _mm256_blend_pd(ymm3, ymm7, 0x0E); - } - - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(f_temp), ymm3); //store(B11[0-3][3]) - - if((j+4) == n) - { - for(iter = 0; iter < m_remainder; iter++) - (b11 + cs_b * 3)[iter] = f_temp[iter]; - } - } - - n_remainder -= 4; - j += 4; - - } - - if(n_remainder) //implementation fo remaining columns(when 'N' is not a multiple of D_NR) - { - for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' direction - { - a10 = L +i; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - - k_iter = i / D_MR; //number of GEMM operations to be performed(in blocks of 4x4) - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM for previously calculated values /// - - //load 4x4 block from b11 if(3 == n_remainder) { - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + } + else if(2 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + } + else if(1 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + } + } + m_remainder -= 4; + } - for(k = 0; k < k_iter; k++) + if(m_remainder) + { + a10 = L + m_remainder; + + // Do transpose for a10 & store in D_A_pack + double *ptr_a10_dup = D_A_pack; + if(3 == m_remainder) // Repetative A blocks will be 3*3 + { + dim_t p_lda = m-m_remainder; // packed leading dimension + for(dim_t x =0;x < m-m_remainder;x+=4) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); + ymm3 = _mm256_broadcast_sd((double const *)&ones); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + + a10 += 4; + ptr_a10_dup += 4; + } + + //cols + for(j = (n - D_NR); (j + 1) > 0; j -= D_NR) //loop along 'N' dimension + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + (j* cs_b) + m_remainder; //pointer to block of B to be used for GEMM + b11 = B + (j* 
cs_b); //pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations { - ptr_b01_dup = b01; - ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - b01 += 1; + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - b01 += 1; + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); + ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] - ymm13 = _mm256_broadcast_sd((double const 
*)(b01 + cs_b * 1)); //B01[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] + b01 += 1; //move to next row of B + a10 += p_lda; - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) - - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + 4; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +4; + ptr_a10_dup = a10; + } } ///GEMM code ends/// - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 - ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] *alpha -= ymm5 - ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[0-3][2] *alpha -= ymm6 - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][3] *alpha -= ymm7 + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][3]) + + dtrsm_AltXB_ref(a11, b11, m_remainder, 6, cs_a, cs_b, is_unitdiag); } - else if(2 == n_remainder) - { - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] 
B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - for(k = 0; k < k_iter; k++) + dim_t n_remainder = j + D_NR; + if((n_remainder >= 4)) + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + ((n_remainder - 4)* cs_b) + m_remainder; //pointer to block of B to be used for GEMM + b11 = B + ((n_remainder - 4)* cs_b); //pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations { - ptr_b01_dup = b01; - ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - b01 += 1; + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - b01 += 1; + b01 += 1; //move to next row of B + a10 += p_lda; - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] 
B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + 4; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +4; + ptr_a10_dup = a10; + } } - ///GEMM code ends/// - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value - ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 - ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] *alpha -= ymm5 - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][2] *alpha -= ymm6 - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][3] *alpha -= ymm7 + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + + dtrsm_AltXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); + n_remainder -= 4; } - else if(1 == n_remainder) + if(n_remainder) { - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + m_remainder; //pointer to block of B to be used for GEMM + b11 = B; //pointer to block of B to be used for TRSM - for(k = 0; k < k_iter; k++) + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(3 == n_remainder) { - ptr_b01_dup = b01; + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm8 = 
_mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - b01 += 1; + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) + b01 += 1; //move to next row of B + a10 += p_lda; - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + 4; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +4; + ptr_a10_dup = a10; + } } - ///GEMM code ends/// - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value - ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 - ymm1 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][1] *alpha -= ymm5 - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][2] *alpha -= ymm6 - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][3] *alpha -= ymm7 + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + + dtrsm_AltXB_ref(a11, b11, m_remainder, 3, cs_a, cs_b, is_unitdiag); + } + else if(2 == n_remainder) + { + + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = 
_mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + b01 += 1; //move to next row of B + a10 += p_lda; + + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + 4; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +4; + ptr_a10_dup = a10; + } + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + + dtrsm_AltXB_ref(a11, b11, m_remainder, 2, cs_a, cs_b, is_unitdiag); + } + else if(1 == n_remainder) + { + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + b01 += 1; //move to next row of B + a10 += p_lda; + + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + 4; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +4; + ptr_a10_dup = a10; + } + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + + dtrsm_AltXB_ref(a11, b11, m_remainder, 1, cs_a, cs_b, is_unitdiag); + } } - - ///implement TRSM/// - //1st col - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] - ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0] - ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][0] - - //2nd col - a11 += cs_a; - ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][1] - - //3rd col - a11 += cs_a; - ymm12 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][2] - - ////unpacklow//// - ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - //rearrange low elements - ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - //rearrange high elements - ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) - ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0] * B11[0][0-3] - ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0] * B11[0][0-3] - ymm13 = _mm256_fnmadd_pd(ymm7, ymm4, ymm13);//d = c - (a*b) //B11[3][0-3] -= 
A11[3][0] * B11[0][0-3] - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1] * B11[1][0-3] - ymm13 = _mm256_fnmadd_pd(ymm10, ymm8, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][1] * B11[1][0-3] - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - ymm13 = _mm256_fnmadd_pd(ymm12, ymm11, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][2] * B11[2][0-3] - - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ////unpackhigh//// - ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - - ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - if(3 == n_remainder) - { - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - - } - else if(2 == n_remainder) - { - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) - - } - else if(1 == n_remainder) - { - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - } - } - if(m_remainder) //implementation for remainder rows(when 'M' is not a multiple of D_MR) + else if(2 == m_remainder) // Repetative A blocks will be 2*2 { - a10 = L +i; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - - - k_iter = i / D_MR; //number of times GEMM operations to be performed - - dim_t iter; - - if((j+n_remainder) == n) + dim_t p_lda = m-m_remainder; // packed leading dimension + for(dim_t x =0;x < m-m_remainder;x+=4) { - f_temp = f_t; - for(iter = 0; iter < m_remainder; iter++) - f_temp[iter] = (b11 + cs_b * (n_remainder -1))[iter]; + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); + ymm2 = _mm256_broadcast_sd((double const *)&ones); + ymm3 = _mm256_broadcast_sd((double const *)&ones); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + + a10 += 4; + ptr_a10_dup += 4; } - else - f_temp = (b11 + cs_b * (n_remainder 
-1)); - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM for previously calculated values /// - - - //load 4x4 block from b11 - if(3 == n_remainder) + //cols + for(j = (n - D_NR); (j + 1) > 0; j -= D_NR) //loop along 'N' dimension { - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + (j* cs_b) + m_remainder; //pointer to block of B to be used for GEMM + b11 = B + (j* cs_b); //pointer to block of B to be used for TRSM - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations { - ptr_b01_dup = b01; - ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[0][2] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - b01 += 1; //move to next row of B + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[1][2] + ymm2 = 
_mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - b01 += 1; //move to next row of B + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); + ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[2][2] + b01 += 1; //move to next row of B + a10 += p_lda; - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[3][2] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) - - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + 4; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +4; + ptr_a10_dup = a10; + } } - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha value - ymm8 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] * alpha -= ymm4 - ymm9 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] * alpha -= ymm5 - ymm10 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[0-3][2] * alpha -= ymm6 + ///GEMM code ends/// + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0C); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + 
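/* blend mask 0x0C: lanes 2-3 keep the values loaded from b11 while lanes
+ 0-1 take the alpha*B11 - GEMM result, so these stores update only the two
+ valid remainder rows (m_remainder == 2) and write memory beyond them
+ back unchanged. */
+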
_mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1])
+ _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2])
+ _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3])
+
+ ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4));
+ ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5));
+ ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4);
+ ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5);
+
+ ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C);
+ ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C);
+
+ _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][4])
+ _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][5])
+
+ dtrsm_AltXB_ref(a11, b11, m_remainder, 6, cs_a, cs_b, is_unitdiag);
+ }
+ dim_t n_remainder = j + D_NR;
+ if((n_remainder >= 4))
+ {
+ a10 = D_A_pack;
+ a11 = L; //pointer to block of A to be used for TRSM
+ b01 = B + ((n_remainder - 4)* cs_b) + m_remainder; //pointer to block of B to be used for GEMM
+ b11 = B + ((n_remainder - 4)* cs_b); //pointer to block of B to be used for TRSM
+
+ k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4)
+
+ #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL
+ _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0);
+ #endif
+
+ ymm4 = _mm256_setzero_pd();
+ ymm5 = _mm256_setzero_pd();
+ ymm6 = _mm256_setzero_pd();
+ ymm7 = _mm256_setzero_pd();
+ ymm8 = _mm256_setzero_pd();
+ ymm9 = _mm256_setzero_pd();
+ ymm10 = _mm256_setzero_pd();
+ ymm11 = _mm256_setzero_pd();
+ ymm12 = _mm256_setzero_pd();
+ ymm13 = _mm256_setzero_pd();
+ ymm14 = _mm256_setzero_pd();
+ ymm15 = _mm256_setzero_pd();
+
+ ///GEMM code begins///
+ ptr_b01_dup = b01;
+ ptr_a10_dup = a10;
+
+ for(k = 0; k< k_iter; k++) //loop for number of GEMM operations
+ {
+ ymm0 = _mm256_loadu_pd((double const *)(a10));
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0));
+ ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1));
+ ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2));
+ ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3));
+ ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11);
+
+ b01 += 1; //move to next row of B
+ a10 += p_lda;
+
+ if(!((k+1)&0x03))
+ {
+ b01 = ptr_b01_dup + 4;
+ ptr_b01_dup = b01;
+ a10 = ptr_a10_dup +4;
+ ptr_a10_dup = a10;
+ }
+
+ }
+ ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha
+
+ ///implement TRSM///
+
+ ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));
+ ymm1 = _mm256_loadu_pd((double
const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0C); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + + dtrsm_AltXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); + n_remainder -= 4; } - else if(2 == n_remainder) + if(n_remainder) { - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + m_remainder; //pointer to block of B to be used for GEMM + b11 = B; //pointer to block of B to be used for TRSM - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(3 == n_remainder) { - ptr_b01_dup = b01; + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; - ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[0][1] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - b01 += 1; //move to next row of B + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0] - ymm13 
= _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[1][1] + b01 += 1; //move to next row of B + a10 += p_lda; - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[2][1] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[3][1] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + 4; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +4; + ptr_a10_dup = a10; + } } - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha value - ymm8 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] * alpha -= ymm4 - ymm9 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] * alpha -= ymm5 + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha ///implement TRSM/// - //determine correct values to store - if(3 == m_remainder) - { - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + + dtrsm_AltXB_ref(a11, b11, m_remainder, 3, cs_a, cs_b, is_unitdiag); } - else if(2 == m_remainder) + else if(2 == n_remainder) { - ymm0 = _mm256_permute2f128_pd(ymm8, ymm0, 0x30); - ymm1 = _mm256_permute2f128_pd(ymm9, ymm1, 0x30); - } - else if(1 == m_remainder) + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations { - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, 
ymm0, ymm9); + + b01 += 1; //move to next row of B + a10 += p_lda; + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + 4; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +4; + ptr_a10_dup = a10; + } } - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(f_temp), ymm1); //store(B11[0-3][1]) - } - else if(1 == n_remainder) - { - ymm0 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - ptr_b01_dup = b01; - - ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM - - } - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha value - - ymm8 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] * alpha -= ymm4 + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha ///implement TRSM/// - //determine correct values to store - if(3 == m_remainder) - { - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + + dtrsm_AltXB_ref(a11, b11, m_remainder, 2, cs_a, cs_b, is_unitdiag); } - else if(2 == m_remainder) + else if(1 == n_remainder) { - ymm0 = _mm256_permute2f128_pd(ymm8, ymm0, 0x30); - } - else if(1 == m_remainder) + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations { - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + b01 += 1; //move to next row of B + a10 += p_lda; + 
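/* (k+1) & 0x03 below is a cheaper (k+1) % 4: the packed buffers are
+ consumed in panels four elements wide, so after every fourth k iteration
+ both read pointers rebase to the start of the next 4-wide panel. */
+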
if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + 4; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +4; + ptr_a10_dup = a10; + } + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + + dtrsm_AltXB_ref(a11, b11, m_remainder, 1, cs_a, cs_b, is_unitdiag); } - _mm256_storeu_pd((double *)(f_temp), ymm0); //store(B11[0-3][0]) } - if((j+n_remainder) == n) - { - for(iter = 0; iter < m_remainder; iter++) - (b11 + cs_b * (n_remainder-1))[iter] = f_temp[iter]; - } - ///scalar code for trsm without alpha/// - dtrsm_small_AlXB_unitDiag(a11, b11, m_remainder, n_remainder, cs_a, cs_b); } + else if(1 == m_remainder) // Repetative A blocks will be 1*1 + { + dim_t p_lda = m-m_remainder; // packed leading dimension + for(dim_t x =0;x < m-m_remainder;x+=4) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_broadcast_sd((double const *)&ones); + ymm2 = _mm256_broadcast_sd((double const *)&ones); + ymm3 = _mm256_broadcast_sd((double const *)&ones); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + + a10 += 4; + ptr_a10_dup += 4; + } + //cols + for(j = (n - D_NR); (j + 1) > 0; j -= D_NR) //loop along 'N' dimension + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + (j* cs_b) + m_remainder; //pointer to block of B to be used for GEMM + b11 = B + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + ymm2 = 
_mm256_broadcast_sd((double const *)(b01 + cs_b * 4));
+ ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5));
+ ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);
+
+ b01 += 1; //move to next row of B
+ a10 += p_lda;
+
+ if(!((k+1)&0x03))
+ {
+ b01 = ptr_b01_dup + 4;
+ ptr_b01_dup = b01;
+ a10 = ptr_a10_dup +4;
+ ptr_a10_dup = a10;
+ }
+ }
+
+ ///GEMM code ends///
+ ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value
+ ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));
+ ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));
+ ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2));
+ ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3));
+ ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);
+ ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9);
+ ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10);
+ ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11);
+
+ ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E);
+ ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E);
+ ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E);
+ ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E);
+
+ _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0])
+ _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1])
+ _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2])
+ _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3])
+
+ ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4));
+ ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5));
+ ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4);
+ ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5);
+
+ ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E);
+ ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E);
+
+ _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][4])
+ _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][5])
+
+ dtrsm_AltXB_ref(a11, b11, m_remainder, 6, cs_a, cs_b, is_unitdiag);
+ }
+ dim_t n_remainder = j + D_NR;
+ if((n_remainder >= 4))
+ {
+ a10 = D_A_pack;
+ a11 = L; //pointer to block of A to be used for TRSM
+ b01 = B + ((n_remainder - 4)* cs_b) + m_remainder; //pointer to block of B to be used for GEMM
+ b11 = B + ((n_remainder - 4)* cs_b); //pointer to block of B to be used for TRSM
+
+ k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4)
+
+ #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL
+ _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0);
+ #endif
+
+ ymm4 = _mm256_setzero_pd();
+ ymm5 = _mm256_setzero_pd();
+ ymm6 = _mm256_setzero_pd();
+ ymm7 = _mm256_setzero_pd();
+ ymm8 = _mm256_setzero_pd();
+ ymm9 = _mm256_setzero_pd();
+ ymm10 = _mm256_setzero_pd();
+ ymm11 = _mm256_setzero_pd();
+ ymm12 = _mm256_setzero_pd();
+ ymm13 = _mm256_setzero_pd();
+ ymm14 = _mm256_setzero_pd();
+ ymm15 = _mm256_setzero_pd();
+
+ ///GEMM code begins///
+ ptr_b01_dup = b01;
+ ptr_a10_dup = a10;
+ for(k = 0; k< k_iter; k++) //loop for number of GEMM operations
+ {
+ ymm0 = _mm256_loadu_pd((double const *)(a10));
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0));
+ ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1));
+ ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2));
+ ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3));
+ ymm11 =
_mm256_fmadd_pd(ymm2, ymm0, ymm11); + + b01 += 1; //move to next row of B + a10 += p_lda; + + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + 4; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +4; + ptr_a10_dup = a10; + } + + } + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + + dtrsm_AltXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); + n_remainder -= 4; + } + if(n_remainder) + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + m_remainder; //pointer to block of B to be used for GEMM + b11 = B; //pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(3 == n_remainder) + { + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + b01 += 1; //move to next row of B + a10 += p_lda; + + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + 4; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +4; + ptr_a10_dup = a10; + } + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + 
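/* blend mask 0x0E: only lane 0 (the single valid row when m_remainder == 1)
+ carries the computed value; lanes 1-3 keep the data loaded from b11, so
+ these stores leave memory past the remainder row unchanged. */
+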
_mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1])
+ _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2])
+
+ dtrsm_AltXB_ref(a11, b11, m_remainder, 3, cs_a, cs_b, is_unitdiag);
+ }
+ else if(2 == n_remainder)
+ {
+ ///GEMM code begins///
+ ptr_b01_dup = b01;
+ ptr_a10_dup = a10;
+ for(k = 0; k< k_iter; k++) //loop for number of GEMM operations
+ {
+ ymm0 = _mm256_loadu_pd((double const *)(a10));
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0));
+ ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1));
+ ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);
+
+ b01 += 1; //move to next row of B
+ a10 += p_lda;
+
+ if(!((k+1)&0x03))
+ {
+ b01 = ptr_b01_dup + 4;
+ ptr_b01_dup = b01;
+ a10 = ptr_a10_dup +4;
+ ptr_a10_dup = a10;
+ }
+ }
+
+ ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha
+
+ ///implement TRSM///
+
+ ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));
+ ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));
+ ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);
+ ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9);
+
+ ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E);
+ ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E);
+
+ _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0])
+ _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1])
+
+ dtrsm_AltXB_ref(a11, b11, m_remainder, 2, cs_a, cs_b, is_unitdiag);
+ }
+ else if(1 == n_remainder)
+ {
+ ///GEMM code begins///
+ ptr_b01_dup = b01;
+ ptr_a10_dup = a10;
+ for(k = 0; k< k_iter; k++) //loop for number of GEMM operations
+ {
+ ymm0 = _mm256_loadu_pd((double const *)(a10));
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0));
+ ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8);
+
+ b01 += 1; //move to next row of B
+ a10 += p_lda;
+
+ if(!((k+1)&0x03))
+ {
+ b01 = ptr_b01_dup + 4;
+ ptr_b01_dup = b01;
+ a10 = ptr_a10_dup +4;
+ ptr_a10_dup = a10;
+ }
+ }
+
+ //register to hold alpha
+ ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal);
+
+ ///implement TRSM///
+ ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));
+ ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);
+
+ ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E);
+
+ _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0])
+ dtrsm_AltXB_ref(a11, b11, m_remainder, 1, cs_a, cs_b, is_unitdiag);
+ }
+ }
+ }
+ }
+
+ if ((required_packing_A == 1) &&
+ bli_mem_is_alloc( &local_mem_buf_A_s ))
+ {
+ bli_membrk_release(&rntm,&local_mem_buf_A_s);
}
return BLIS_SUCCESS;
}
+/*
+ * TRSM for the case AX = alpha * B, Double precision
+ * A is upper-triangular, non-transpose, non-unit diagonal
+ * dimensions A: mxm X: mxn B: mxn
+*/
+static err_t bli_dtrsm_small_AuXB
+(
+ obj_t* AlphaObj,
+ obj_t* a,
+ obj_t* b,
+ cntx_t* cntx,
+ cntl_t* cntl
+)
+{
+ dim_t D_MR = 8; //size of block along 'M' dimension
+ dim_t D_NR = 6; //size of block along 'N' dimension
-/*implements TRSM for the case XA = alpha * B
+ dim_t m = bli_obj_length(b); // number of rows of matrix B
+ dim_t n = bli_obj_width(b); // number of columns of matrix B
+
+ dim_t cs_a = bli_obj_col_stride(a); // column stride of A
+ dim_t cs_b = bli_obj_col_stride(b); // column stride of B
+
+ dim_t i, j, k; //loop variables
+ dim_t k_iter; //number of times GEMM to be performed
+
+ double AlphaVal = *(double *)AlphaObj->buffer; //value of alpha
+ double *L = a->buffer; //pointer to matrix A
+ double *B = b->buffer; //pointer to matrix B
+
+ //pointers that point to blocks for GEMM and TRSM
+ double *a10, *a11,
*b01, *b11;
+ double *ptr_b01_dup;
+ double *ptr_a10_dup;
+
+ double ones = 1.0;
+ bool is_unitdiag = bli_obj_has_unit_diag(a);
+ //scratch registers
+ __m256d ymm0, ymm1, ymm2, ymm3;
+ __m256d ymm4, ymm5, ymm6, ymm7;
+ __m256d ymm8, ymm9, ymm10, ymm11;
+ __m256d ymm12, ymm13, ymm14, ymm15;
+ __m256d ymm16, ymm17, ymm18, ymm19;
+ __m256d ymm20;
+ gint_t required_packing_A = 1;
+ mem_t local_mem_buf_A_s;
+ double *D_A_pack = NULL;
+ double d11_pack[D_MR] __attribute__((aligned(64)));
+ rntm_t rntm;
+
+ bli_rntm_init_from_global( &rntm );
+ bli_rntm_set_num_threads_only( 1, &rntm );
+ bli_membrk_rntm_set_membrk( &rntm );
+
+ siz_t buffer_size = bli_pool_block_size(
+ bli_membrk_pool(
+ bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK),
+ bli_rntm_membrk(&rntm)));
+
+ if ( (D_MR * m * sizeof(double)) > buffer_size)
+ return BLIS_NOT_YET_IMPLEMENTED;
+
+ if (required_packing_A == 1)
+ {
+ // Get the buffer from the pool.
+ bli_membrk_acquire_m(&rntm,
+ buffer_size,
+ BLIS_BITVAL_BUFFER_FOR_A_BLOCK,
+ &local_mem_buf_A_s);
+
+ D_A_pack = bli_mem_buffer(&local_mem_buf_A_s);
+ }
+
+ /*
+ Performs the TRSM solve for 8 columns at a time, from 0 to m/8 in steps of D_MR
+ a. Load, transpose and pack A (the a10 block); the size of the packing runs from 8x6 to 8x(m-8)
+ For the first iteration there will be no GEMM and no packing of a10 because it is only TRSM
+ b. Using the packed a10 block and the b01 block, perform the GEMM operation
+ c. Use the GEMM outputs, perform the TRSM operation using a11, b11 and update B
+ d. Repeat b,c for n rows of B in steps of D_NR
+ */
+ for(i = (m - D_MR); (i + 1) > 0; i -= D_MR)
+ {
+ a10 = L + (i + D_MR)*cs_a + i; //pointer to block of A to be used for GEMM
+ a11 = L + (i*cs_a) + i; //pointer to block of A to be used for TRSM
+
+ // Do transpose for a10 & store in D_A_pack
+ ptr_a10_dup = D_A_pack; //ptr_a11_dup = a11;
+ dim_t p_lda = D_MR; // packed leading dimension
+
+ /*
+ Pack the current A block (a10) into the packed buffer memory D_A_pack
+ a. This a10 block is used in the GEMM portion only, and its
+ size increases by D_MR on every next iteration
+ until it reaches 8x(m-8), which is the maximum GEMM-only block size in A
+ b. This packed buffer is reused to calculate all n rows of B matrix
+ */
+ for(dim_t x =0;x < (m-i-D_MR);x++)
+ {
+ ymm16 = _mm256_loadu_pd((double const *)(a10 + (cs_a * x)));
+ _mm256_storeu_pd((double *)(ptr_a10_dup + (p_lda * x)), ymm16);
+ ymm16 = _mm256_loadu_pd((double const *)(a10 + (cs_a * x) + 4));
+ _mm256_storeu_pd((double *)(ptr_a10_dup + (p_lda * x) + 4), ymm16);
+ }
+
+ /*
+ Pack the 8 diagonal elements of the A block into an array
+ a. This helps utilize cache lines efficiently in the TRSM operation
+ b. Store ones when the input is unit diagonal
+ */
+ ymm4 = _mm256_broadcast_sd((double const *)&ones);
+ if(is_unitdiag)
+ {
+ _mm256_storeu_pd((double *)(d11_pack), ymm4);
+ _mm256_storeu_pd((double *)(d11_pack + 4), ymm4);
+ }else
+ {
+ //broadcast diagonal elements of A11
+ ymm0 = _mm256_broadcast_sd((double const *)(a11));
+ ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a +1));
+ ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a*2 + 2));
+ ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a*3 + 3));
+
+ ymm0 = _mm256_unpacklo_pd(ymm0, ymm1);
+ ymm1 = _mm256_unpacklo_pd(ymm2, ymm3);
+
+ ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C);
+ ymm0 = _mm256_div_pd(ymm4, ymm1);
+ _mm256_storeu_pd((double *)(d11_pack), ymm0);
+
+ //broadcast diagonal elements of A11
+ ymm0 = _mm256_broadcast_sd((double const *)(a11 + 4 + cs_a*4));
+ ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5 + cs_a*5));
+ ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6 + cs_a*6));
+ ymm3 = _mm256_broadcast_sd((double const *)(a11 + 7 + cs_a*7));
+
+ ymm0 = _mm256_unpacklo_pd(ymm0, ymm1);
+ ymm1 = _mm256_unpacklo_pd(ymm2, ymm3);
+
+ ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C);
+ ymm0 = _mm256_div_pd(ymm4, ymm1);
+ _mm256_storeu_pd((double *)(d11_pack + 4), ymm0);
+ }
+
+ /*
+ a. Perform GEMM using a10, b01.
+ b. Perform TRSM on a11, b11
+ c. This GEMM+TRSM loop operates with an 8x6 block size
+ along the n dimension for every D_NR rows of b01, where the
+ packed A buffer is reused in computing all n rows of B.
+ d. The same approach is used in the remaining fringe cases.
+ */
+ for(j = (n - D_NR); (j + 1) > 0; j -= D_NR) //loop along 'N' dimension
+ {
+ a10 = D_A_pack;
+ b01 = B + (j*cs_b) + i + D_MR; //pointer to block of B to be used for GEMM
+ b11 = B + (j* cs_b) + i; //pointer to block of B to be used for TRSM
+
+ k_iter = (m - i - D_MR) / D_MR; //number of times GEMM to be performed(in blocks of 4x4)
+
+ #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL
+ _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0);
+ #endif
+
+ ymm4 = _mm256_setzero_pd();
+ ymm5 = _mm256_setzero_pd();
+ ymm6 = _mm256_setzero_pd();
+ ymm7 = _mm256_setzero_pd();
+ ymm8 = _mm256_setzero_pd();
+ ymm9 = _mm256_setzero_pd();
+ ymm10 = _mm256_setzero_pd();
+ ymm11 = _mm256_setzero_pd();
+ ymm12 = _mm256_setzero_pd();
+ ymm13 = _mm256_setzero_pd();
+ ymm14 = _mm256_setzero_pd();
+ ymm15 = _mm256_setzero_pd();
+
+ /*
+ Perform GEMM between the a10 and b01 blocks
+ For the first iteration there will be no GEMM operation,
+ since k_iter is zero
+ */
+ ptr_b01_dup = b01;
+ ptr_a10_dup = a10;
+ for(k = 0; k< k_iter*4; k++) //loop for number of GEMM operations
+ {
+ ymm0 = _mm256_loadu_pd((double const *)(a10));
+ ymm1 = _mm256_loadu_pd((double const *)(a10 + 4));
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0));
+ ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8);
+ ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1));
+ ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);
+ ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2));
+ ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10);
+ ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3));
+ ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11);
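+ /* Accumulator layout for this 8x6 GEMM block: ymm8-ymm11 hold rows 0-3
+ of columns 0-3, ymm12-ymm15 hold rows 4-7 of the same columns, and
+ ymm4/ymm5 (rows 0-3) with ymm6/ymm7 (rows 4-7) accumulate columns 4-5. */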
+ ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4));
+ ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4);
+ ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5));
+ ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);
+ ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);
+
+ ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4));
+ ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4 + 4));
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4));
+ ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8);
+ ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1 + 4));
+ ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);
+ ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2 + 4));
+ ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10);
+ ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3 + 4));
+ ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11);
+ ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4 + 4));
+ ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4);
+ ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5 + 4));
+ ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);
+ ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);
+
+ b01 += 1;
+ a10 += p_lda;
+
+ if(!((k+1)&0x03))
+ {
+ b01 = ptr_b01_dup + D_MR;
+ ptr_b01_dup = b01;
+ a10 = ptr_a10_dup + D_MR*p_lda;
+ ptr_a10_dup = a10;
+ }
+ }
+ //GEMM block ends here
+
+ /*
+ Load b11 of size 6x8, multiply it with alpha,
+ add the GEMM output, and perform an in-register transpose of b11
+ to perform the TRSM operation.
+ */
+
+ ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal);
+
+ ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));
+ ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));
+ ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2));
+ ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3));
+ ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);
+ ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9);
+ ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10);
+ ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11);
+
+ ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1]
+ ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3]
+ ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3]
+ ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3]
+ ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1]
+ ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3]
+ ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3]
+ ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3]
+
+ ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4));
+ ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4));
+ ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4));
+ ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3 + 4));
+ ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm12);
+ ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm13);
+ ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm14);
+ ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm15);
+
+ ymm13 = _mm256_unpacklo_pd(ymm0, ymm1);
+ ymm15 = _mm256_unpacklo_pd(ymm2, ymm3);
+ ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20);
+ ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31);
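+ /* unpacklo/unpackhi followed by permute2f128 (imm 0x20 pairs the low
+ 128-bit halves, imm 0x31 the high halves) is the usual AVX 4x4 transpose
+ of doubles; it turns the loaded columns of b11 into row vectors. */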
_mm256_permute2f128_pd(ymm13,ymm15,0x31); + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm15 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *4 + 4)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *5 + 4)); + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); + + ymm16 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm4 = _mm256_permute2f128_pd(ymm7,ymm16,0x20); + ymm6 = _mm256_permute2f128_pd(ymm7,ymm16,0x31); + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + + ymm5 = _mm256_permute2f128_pd(ymm0,ymm16,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm16,0x31); + ymm18 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm17 = _mm256_permute2f128_pd(ymm18,ymm16,0x20); + ymm19 = _mm256_permute2f128_pd(ymm18,ymm16,0x31); + ymm20 = _mm256_unpackhi_pd(ymm2, ymm3); + ymm18 = _mm256_permute2f128_pd(ymm20,ymm16,0x20); + ymm20 = _mm256_permute2f128_pd(ymm20,ymm16,0x31); + + /* + Compute 8x6 TRSM block by using GEMM block output in register + a. The 8x6 input (gemm outputs) are stored in combinations of ymm registers + 1. ymm15, ymm20 2. ymm14, ymm19 3. ymm13, ymm18 , 4. ymm12, ymm17 + 5. ymm11, ymm7 6. ymm10, ymm6, 7.ymm9, ymm5 8. ymm8, ymm4 + where ymm15-ymm8 holds 8x4 data and reaming 8x2 will be hold by + other registers + b. Towards the end do in regiser transpose of TRSM output and store in b11 + */ + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + + //perform mul operation + ymm15 = _mm256_mul_pd(ymm15, ymm1); + ymm20 = _mm256_mul_pd(ymm20, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + + //(ROw7): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6 + 7*cs_a)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm15, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm20, ymm19); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5 + 7*cs_a)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm15, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm20, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4 + 7*cs_a)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm15, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm20, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3 + 7*cs_a)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm15, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm20, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 7*cs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm15, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm20, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 7*cs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm15, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm20, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7*cs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm15, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm20, ymm4); + + //perform mul operation + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm19 = _mm256_mul_pd(ymm19, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + //(ROw6): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5 + 6*cs_a)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm14, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm19, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4 + 
6*cs_a)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm14, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm19, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3 + 6*cs_a)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm14, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm19, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 6*cs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm14, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm19, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 6*cs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm14, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm19, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6*cs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm14, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm19, ymm4); + + //perform mul operation + ymm13 = _mm256_mul_pd(ymm13, ymm1); + ymm18 = _mm256_mul_pd(ymm18, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + //(ROw5): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4 + 5*cs_a)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm13, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm18, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3 + 5*cs_a)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm13, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm18, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 5*cs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm13, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm18, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 5*cs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm13, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm18, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm13, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm18, ymm4); + + //perform mul operation + ymm12 = _mm256_mul_pd(ymm12, ymm1); + ymm17 = _mm256_mul_pd(ymm17, ymm1); + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw4): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3 + 4*cs_a)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm12, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm17, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 4*cs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm12, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm17, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 4*cs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm12, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm17, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm12, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm17, ymm4); + + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); + ymm7 = _mm256_mul_pd(ymm7, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw3): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 3*cs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm11, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm7, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 3*cs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm11, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm7, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm11, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm7, ymm4); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + ymm6 = _mm256_mul_pd(ymm6, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 2*cs_a)); + ymm9 
= _mm256_fnmadd_pd(ymm2, ymm10, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm6, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm10, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm6, ymm4); + + //perform mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); + ymm5 = _mm256_mul_pd(ymm5, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm9, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm5, ymm4); + + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); + ymm4 = _mm256_mul_pd(ymm4, ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm12, ymm13); + ymm3 = _mm256_unpacklo_pd(ymm14, ymm15); + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); + + ///unpack high/// + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); + ymm3 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); + + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm0); + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm1); + _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm2); + _mm256_storeu_pd((double *)(b11 + cs_b * 3 + 4), ymm3); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm4, ymm5); + ymm3 = _mm256_unpacklo_pd(ymm6, ymm7); + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); + + ///unpack high/// + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); + + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm17, ymm18); + ymm3 = _mm256_unpacklo_pd(ymm19, ymm20); + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); + + ///unpack high/// + ymm17 = _mm256_unpackhi_pd(ymm17, ymm18); + ymm18 = _mm256_unpackhi_pd(ymm19, ymm20); + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm17, ymm18, 0x20); + + _mm256_storeu_pd((double *)(b11 + cs_b * 4 + 4), ymm0); + _mm256_storeu_pd((double *)(b11 + cs_b * 5 + 4), ymm1); + + } + + dim_t n_remainder = j + D_NR; + if(n_remainder >= 4) + { + a10 = D_A_pack; + a11 = L + (i*cs_a) + i; + b01 = B + ((n_remainder - 4)* cs_b) + i + D_MR; + b11 = B + ((n_remainder - 4)* cs_b) + i; + + k_iter = (m - i - D_MR) / D_MR; + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + 
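+            /*
+                ymm4-ymm15 are the GEMM accumulators for this fringe block;
+                they must start at zero because the GEMM k-loop below may
+                execute zero times, in which case the alpha*B11 update must
+                see no GEMM contribution.
+            */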
ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter*4; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1 + 4)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2 + 4)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3 + 4)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + + b01 += 1; //move to next row of B + a10 += p_lda; + + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + D_MR; + ptr_b01_dup = b01; + a10 = ptr_a10_dup + D_MR*p_lda; + ptr_a10_dup = a10; + } + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); + ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b *3 + 4)); + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); + ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); + ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); + + ///implement TRSM/// + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); + ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); + + 
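+            /*
+                The unpacklo/unpackhi + permute2f128 pairs around this point
+                implement the standard AVX 4x4 transpose of double-precision
+                data. As an illustrative sketch (c0..c3 are hypothetical
+                column vectors, not variables from this kernel):
+
+                    lo01 = _mm256_unpacklo_pd(c0, c1); // c0[0] c1[0] c0[2] c1[2]
+                    lo23 = _mm256_unpacklo_pd(c2, c3); // c2[0] c3[0] c2[2] c3[2]
+                    row0 = _mm256_permute2f128_pd(lo01, lo23, 0x20);
+                    row2 = _mm256_permute2f128_pd(lo01, lo23, 0x31);
+
+                and the unpackhi results combine the same way into rows 1 and 3.
+            */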
//rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + + //perform mul operation + ymm15 = _mm256_mul_pd(ymm15, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6 + 7*cs_a)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5 + 7*cs_a)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4 + 7*cs_a)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3 + 7*cs_a)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2 + 7*cs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + 1 + 7*cs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 7*cs_a)); + + //(ROw7): FMA operations + ymm14 = _mm256_fnmadd_pd(ymm2, ymm15, ymm14); + ymm13 = _mm256_fnmadd_pd(ymm3, ymm15, ymm13); + ymm12 = _mm256_fnmadd_pd(ymm4, ymm15, ymm12); + ymm11 = _mm256_fnmadd_pd(ymm5, ymm15, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm15, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm15, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm15, ymm8); + + //perform mul operation + ymm14 = _mm256_mul_pd(ymm14, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5 + 6*cs_a)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4 + 6*cs_a)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3 + 6*cs_a)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2 + 6*cs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + 1 + 6*cs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 6*cs_a)); + + //(ROw6): FMA operations + ymm13 = _mm256_fnmadd_pd(ymm3, ymm14, ymm13); + ymm12 = _mm256_fnmadd_pd(ymm4, ymm14, ymm12); + ymm11 = _mm256_fnmadd_pd(ymm5, ymm14, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm14, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm14, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm14, ymm8); + + //perform mul operation + ymm13 = _mm256_mul_pd(ymm13, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4 + 5*cs_a)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3 + 5*cs_a)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2 + 5*cs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + 1 + 5*cs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a)); + + //(ROw5): FMA operations + ymm12 = _mm256_fnmadd_pd(ymm4, ymm13, ymm12); + ymm11 = _mm256_fnmadd_pd(ymm5, ymm13, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm13, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm13, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm13, ymm8); + + 
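+            /*
+                Back-substitution pattern: with A upper-triangular, X is
+                solved bottom-up. Once row k of X has been scaled by
+                1/A[k][k] (from d11_pack), every row r < k is updated as
+                x[r] -= A[r][k] * x[k], which is what the fnmadd chain above
+                performs; the multiply below then finalizes the next row up.
+            */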
//perform mul operation + ymm12 = _mm256_mul_pd(ymm12, ymm1); + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3 + 4*cs_a)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2 + 4*cs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + 1 + 4*cs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a)); + + //(ROw4): FMA operations + ymm11 = _mm256_fnmadd_pd(ymm5, ymm12, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm12, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm12, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm12, ymm8); + + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2 + 3*cs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + 1 + 3*cs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); + + //(ROw3): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm6, ymm11, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm11, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm11, ymm8); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + ymm7 = _mm256_broadcast_sd((double const *)(a11 + 1 + 2*cs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); + + //(ROw2): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm7, ymm10, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm10, ymm8); + + //perform mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 1*cs_a)); + + //(ROw2): FMA operations + ymm8 = _mm256_fnmadd_pd(ymm16, ymm9, ymm8); + + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); + ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); + ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); + ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); + 
_mm256_storeu_pd((double *)(b11 + cs_b * 3 + 4), ymm7); + n_remainder -=4; + } + + if(n_remainder) //implementation fo remaining columns(when 'N' is not a multiple of D_NR)() n = 3 + { + a10 = D_A_pack; + a11 = L + (i*cs_a) + i; + b01 = B + i + D_MR; + b11 = B + i; + + k_iter = (m - i - D_MR) / D_MR; + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(3 == n_remainder) + { + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter*4; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1 + 4)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2 + 4)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + b01 += 1; //move to next row of B + a10 += p_lda; + + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + D_MR; + ptr_b01_dup = b01; + a10 = ptr_a10_dup + D_MR*p_lda; + ptr_a10_dup = a10; + } + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); + ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(2 == n_remainder) + { + + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + + for(k = 0; k< k_iter*4; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double 
const *)(a10 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1 + 4)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + b01 += 1; //move to next row of B + a10 += p_lda; + + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + D_MR; + ptr_b01_dup = b01; + a10 = ptr_a10_dup + D_MR*p_lda; + ptr_a10_dup = a10; + } + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); + ymm6 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(1 == n_remainder) + { + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + + for(k = 0; k< k_iter*4; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + b01 += 1; //move to next row of B + a10 += p_lda; + + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + D_MR; + ptr_b01_dup = b01; + a10 = ptr_a10_dup + D_MR*p_lda; + ptr_a10_dup = a10; + } + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_broadcast_sd((double const *)(&ones)); + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); + ymm5 = _mm256_broadcast_sd((double const *)(&ones)); + ymm6 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + } + ///implement TRSM/// + + ///transpose of B11// + 
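+            /*
+                The unused B11 columns of this fringe case were filled with
+                broadcast 'ones' above purely as padding, so the full-width
+                transpose and TRSM below can run unconditionally; only the
+                first n_remainder columns are stored back afterwards.
+            */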
///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); + ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + + //perform mul operation + ymm15 = _mm256_mul_pd(ymm15, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + + //(ROw7): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6 + 7*cs_a)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm15, ymm14); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5 + 7*cs_a)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm15, ymm13); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4 + 7*cs_a)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm15, ymm12); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3 + 7*cs_a)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm15, ymm11); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 7*cs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm15, ymm10); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 7*cs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm15, ymm9); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7*cs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm15, ymm8); + + //perform mul operation + ymm14 = _mm256_mul_pd(ymm14, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + //(ROw6): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5 + 6*cs_a)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm14, ymm13); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4 + 6*cs_a)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm14, ymm12); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3 + 6*cs_a)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm14, ymm11); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 6*cs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm14, ymm10); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 6*cs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm14, ymm9); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6*cs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm14, ymm8); + + //perform mul operation + ymm13 = _mm256_mul_pd(ymm13, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + //(ROw5): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4 + 5*cs_a)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm13, ymm12); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3 + 5*cs_a)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm13, ymm11); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 5*cs_a)); + ymm10 = 
_mm256_fnmadd_pd(ymm2, ymm13, ymm10); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 5*cs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm13, ymm9); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm13, ymm8); + + //perform mul operation + ymm12 = _mm256_mul_pd(ymm12, ymm1); + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw4): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3 + 4*cs_a)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm12, ymm11); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 4*cs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm12, ymm10); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 4*cs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm12, ymm9); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm12, ymm8); + + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw3): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 3*cs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm11, ymm10); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 3*cs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm11, ymm9); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm11, ymm8); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 2*cs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm10, ymm9); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm10, ymm8); + + //perform mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1*cs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm9, ymm8); + + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); + ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); + ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); + ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); + + if(3 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); 
//store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); + } + else if(2 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); + } + else if(1 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); + } + } + }// End of multiples of D_MR blocks in m-dimension + + // Repetative A blocks will be 4*4 + dim_t m_remainder = i + D_MR; + if(m_remainder >= 4) + { + i = m_remainder - 4; + a10 = L + (i + 4)*cs_a + i; //pointer to block of A to be used for GEMM + a11 = L + (i*cs_a) + i; //pointer to block of A to be used for TRSM + + // Do transpose for a10 & store in D_A_pack + double *ptr_a10_dup = D_A_pack; + double *ptr_a11_dup = a11; + dim_t p_lda = 4; // packed leading dimension + for(dim_t x =0;x < m-i-4;x++) + { + ymm0 = _mm256_loadu_pd((double const *)(a10 + x*cs_a)); + _mm256_storeu_pd((double *)(ptr_a10_dup + x*p_lda), ymm0); + } + + ymm4 = _mm256_broadcast_sd((double const *)&ones); + if(!is_unitdiag) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a +1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11+cs_a*2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11+cs_a*3 + 3)); + + //Pick one element each column and create a 4 element vector and store + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); + ymm4 = _mm256_div_pd(ymm4, ymm1); + } + _mm256_storeu_pd((double *)(d11_pack), ymm4); + + //cols + for(j = (n - D_NR); (j + 1) > 0; j -= D_NR) //loop along 'N' dimension + { + a10 = D_A_pack; + a11 = ptr_a11_dup; //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b) + i + 4; //pointer to block of B to be used for GEMM + b11 = B + (j* cs_b) + i; //pointer to block of B to be used for TRSM + + k_iter = (m - i - 4); //number of times GEMM to be performed(in blocks of 4x4) + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); 
+ ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); + ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); + + b01 += 1; //move to next row of B + a10 += p_lda; + + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + 4; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +4*p_lda; + ptr_a10_dup = a10; + } + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + + ymm16 = _mm256_broadcast_sd((double const *)(&ones)); + + ////unpacklow//// + ymm7 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + + //rearrange low elements + ymm4 = _mm256_permute2f128_pd(ymm7,ymm16,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm6 = _mm256_permute2f128_pd(ymm7,ymm16,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + + //rearrange high elements + ymm5 = _mm256_permute2f128_pd(ymm0,ymm16,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm7 = _mm256_permute2f128_pd(ymm0,ymm16,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); + ymm7 = _mm256_mul_pd(ymm7, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw3): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 3*cs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm11, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm7, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 3*cs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm11, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm7, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm11, ymm8); + ymm4 = 
_mm256_fnmadd_pd(ymm2, ymm7, ymm4); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + ymm6 = _mm256_mul_pd(ymm6, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 2*cs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm10, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm6, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm10, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm6, ymm4); + + //perform mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); + ymm5 = _mm256_mul_pd(ymm5, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm9, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm5, ymm4); + + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); + ymm4 = _mm256_mul_pd(ymm4, ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ///unpack high/// + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store B11[1][0-3] + } + + dim_t n_remainder = j + D_NR; + if((n_remainder >= 4)) + { + a10 = D_A_pack; + a11 = L + (i*cs_a) + i; //pointer to block of A to be used for TRSM + b01 = B + ((n_remainder - 4)* cs_b) + i + 4; //pointer to block of B to be used for GEMM + b11 = B + ((n_remainder - 4)* cs_b) + i; //pointer to block of B to be used for TRSM + + k_iter = (m - i - 4); //number of times GEMM to be performed + + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + #endif + + ymm4 = 
_mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw3): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 3*cs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm11, ymm10); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 3*cs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm11, ymm9); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm11, ymm8); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 2*cs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm10, ymm9); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm10, ymm8); + + //perform mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double 
const *)(a11 + cs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm9, ymm8); + + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + n_remainder = n_remainder - 4; + } + + if(n_remainder) //implementation fo remaining columns(when 'N' is not a multiple of D_NR)() n = 3 + { + a10 = D_A_pack; + a11 = L + (i*cs_a) + i; + b01 = B + i + 4; + b11 = B + i; + + k_iter = (m - i - 4); + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(3 == n_remainder) + { + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(2 == n_remainder) + { + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + b01 += 1; //move to next row of B + a10 
+= p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(1 == n_remainder) + { + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_broadcast_sd((double const *)(&ones)); + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } + + ///implement TRSM/// + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw3): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 3*cs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm11, ymm10); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 3*cs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm11, ymm9); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm11, ymm8); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 2*cs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm10, ymm9); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm10, ymm8); + + //perform mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm9, ymm8); + + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); + + 
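+                /*
+                    d11_pack holds the reciprocals of the A11 diagonal
+                    (exact 1.0 entries for the unit-diagonal case), computed
+                    once per diagonal block with _mm256_div_pd and reused
+                    across every column panel of B, so each solve row costs
+                    a single vector multiply instead of a division.
+                */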
///unpack low///
+ ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2]
+ ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2]
+
+ //rearrange low elements
+ ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
+ ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
+
+ ///unpack high///
+ ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3]
+ ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3]
+
+ //rearrange high elements
+ ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
+ ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
+
+ if(3 == n_remainder)
+ {
+ _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0-3][0]
+ _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[0-3][1]
+ _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[0-3][2]
+ }
+ else if(2 == n_remainder)
+ {
+ _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0-3][0]
+ _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[0-3][1]
+ }
+ else if(1 == n_remainder)
+ {
+ _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0-3][0]
+ }
+ }
+ m_remainder -= 4;
+ }
+
+ if(m_remainder)
+ {
+ a10 = L + m_remainder*cs_a;
+
+ // Do transpose for a10 & store in D_A_pack
+ double *ptr_a10_dup = D_A_pack;
+ if(3 == m_remainder) // Repetitive A blocks will be 3*3
+ {
+ dim_t p_lda = 4; // packed leading dimension
+ for(dim_t x = 0; x < m-m_remainder; x++)
+ {
+ ymm0 = _mm256_loadu_pd((double const *)(a10 + x*cs_a));
+ _mm256_storeu_pd((double *)(ptr_a10_dup + x*p_lda), ymm0);
+ }
+ //cols
+ for(j = (n - D_NR); (j + 1) > 0; j -= D_NR) //loop along 'N' dimension
+ {
+ a10 = D_A_pack;
+ a11 = L; //pointer to block of A to be used for TRSM
+ b01 = B + (j* cs_b) + m_remainder; //pointer to block of B to be used for GEMM
+ b11 = B + (j* cs_b); //pointer to block of B to be used for TRSM
+
+ k_iter = (m - m_remainder); //number of times GEMM to be performed
+
+ #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL
+ _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0);
+ #endif
+
+ ymm4 = _mm256_setzero_pd();
+ ymm5 = _mm256_setzero_pd();
+ ymm6 = _mm256_setzero_pd();
+ ymm7 = _mm256_setzero_pd();
+ ymm8 = _mm256_setzero_pd();
+ ymm9 = _mm256_setzero_pd();
+ ymm10 = _mm256_setzero_pd();
+ ymm11 = _mm256_setzero_pd();
+ ymm12 = _mm256_setzero_pd();
+ ymm13 = _mm256_setzero_pd();
+ ymm14 = _mm256_setzero_pd();
+ ymm15 = _mm256_setzero_pd();
+
+ ///GEMM code begins///
+ for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
+ {
+ ptr_b01_dup = b01;
+ ymm0 = _mm256_loadu_pd((double const *)(a10));
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0));
+ ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1));
+ ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2));
+ ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3));
+ ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4));
+ ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5));
+ ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);
+
+ b01 += 1; //move to next row of B
+ a10 += p_lda;
+ }
+
+ ///GEMM code ends///
+ ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value
+ ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));
+ ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));
+ ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2));
+ ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3));
+ ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);
+ ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9);
+ ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10);
+ ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11);
+
+ ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08);
+ ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08);
+ ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08);
+ ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08);
+
+ _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0])
+ _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1])
+ _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2])
+ _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3])
+
+ ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4));
+ ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5));
+ ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4);
+ ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5);
+
+ ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08);
+ ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08);
+
+ _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][4])
+ _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][5])
+
+ dtrsm_AuXB_ref(a11, b11, m_remainder, 6, cs_a, cs_b, is_unitdiag);
+ }
+
+ dim_t n_remainder = j + D_NR;
+ if((n_remainder >= 4))
+ {
+ a10 = D_A_pack;
+ a11 = L; //pointer to block of A to be used for TRSM
+ b01 = B + ((n_remainder - 4)* cs_b) + m_remainder; //pointer to block of B to be used for GEMM
+ b11 = B + ((n_remainder - 4)* cs_b); //pointer to block of B to be used for TRSM
+
+ k_iter = (m - m_remainder); //number of times GEMM to be performed
+
+ #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL
+ _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0);
+ #endif
+
+ ymm4 = _mm256_setzero_pd();
+ ymm5 = _mm256_setzero_pd();
+ ymm6 = _mm256_setzero_pd();
+ ymm7 = _mm256_setzero_pd();
+ ymm8 = _mm256_setzero_pd();
+ ymm9 = _mm256_setzero_pd();
+ ymm10 = _mm256_setzero_pd();
+ ymm11 = _mm256_setzero_pd();
+ ymm12 = _mm256_setzero_pd();
+ ymm13 = _mm256_setzero_pd();
+ ymm14 = _mm256_setzero_pd();
+ ymm15 = _mm256_setzero_pd();
+
+ ///GEMM code begins///
+ for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
+ {
+ ymm0 = _mm256_loadu_pd((double const *)(a10));
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0));
+ ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1));
+ ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2));
+ ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3));
+ ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11);
+
+ b01 += 1; //move to next row of B
+ a10 += p_lda;
+ }
+
+ ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha
+
+ ///implement TRSM///
+
+ ymm0 = 
_mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + + dtrsm_AuXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); + n_remainder -= 4; + } + if(n_remainder) + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + m_remainder; //pointer to block of B to be used for GEMM + b11 = B; //pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); //number of times GEMM to be performed + + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(3 == n_remainder) + { + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + + dtrsm_AuXB_ref(a11, b11, m_remainder, 3, cs_a, cs_b, is_unitdiag); + } + else if(2 == n_remainder) + { + + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = 
_mm256_broadcast_sd((double const *)(b01 + cs_b * 1));
+ ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);
+
+ b01 += 1; //move to next row of B
+ a10 += p_lda;
+ }
+
+ ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha
+
+ ///implement TRSM///
+
+ ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));
+ ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));
+ ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);
+ ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9);
+
+ ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08);
+ ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08);
+
+ _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0])
+ _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1])
+
+ dtrsm_AuXB_ref(a11, b11, m_remainder, 2, cs_a, cs_b, is_unitdiag);
+ }
+ else if(1 == n_remainder)
+ {
+ ///GEMM code begins///
+ for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
+ {
+ ymm0 = _mm256_loadu_pd((double const *)(a10));
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0));
+ ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8);
+
+ b01 += 1; //move to next row of B
+ a10 += p_lda;
+ }
+
+ ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha
+
+ ///implement TRSM///
+
+ ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));
+ ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);
+
+ ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08);
+
+ _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0])
+
+ dtrsm_AuXB_ref(a11, b11, m_remainder, 1, cs_a, cs_b, is_unitdiag);
+ }
+ }
+ }
+ else if(2 == m_remainder) // Repetitive A blocks will be 2*2
+ {
+ dim_t p_lda = 4; // packed leading dimension
+ for(dim_t x = 0; x < m-m_remainder; x++)
+ {
+ ymm0 = _mm256_loadu_pd((double const *)(a10 + x*cs_a));
+ _mm256_storeu_pd((double *)(ptr_a10_dup + x*p_lda), ymm0);
+ }
+ //cols
+ for(j = (n - D_NR); (j + 1) > 0; j -= D_NR) //loop along 'N' dimension
+ {
+ a10 = D_A_pack;
+ a11 = L; //pointer to block of A to be used for TRSM
+ b01 = B + (j* cs_b) + m_remainder; //pointer to block of B to be used for GEMM
+ b11 = B + (j* cs_b); //pointer to block of B to be used for TRSM
+
+ k_iter = (m - m_remainder); //number of times GEMM to be performed
+
+ #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL
+ _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0);
+ #endif
+
+ ymm4 = _mm256_setzero_pd();
+ ymm5 = _mm256_setzero_pd();
+ ymm6 = _mm256_setzero_pd();
+ ymm7 = _mm256_setzero_pd();
+ ymm8 = _mm256_setzero_pd();
+ ymm9 = _mm256_setzero_pd();
+ ymm10 = _mm256_setzero_pd();
+ ymm11 = _mm256_setzero_pd();
+ ymm12 = _mm256_setzero_pd();
+ ymm13 = _mm256_setzero_pd();
+ ymm14 = _mm256_setzero_pd();
+ ymm15 = _mm256_setzero_pd();
+
+ ///GEMM code begins///
+ for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
+ {
+ ptr_b01_dup = b01;
+ ymm0 = _mm256_loadu_pd((double const *)(a10));
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0));
+ ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1));
+ ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2));
+ ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3));
+ ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4));
+ ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5));
+ ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);
+
+ b01 += 1; //move to next row of B
+ a10 += p_lda;
+ }
+
+ ///GEMM code ends///
+ ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value
+ ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));
+ ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));
+ ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2));
+ ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3));
+ ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);
+ ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9);
+ ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10);
+ ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11);
+
+ ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C);
+ ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C);
+ ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C);
+ ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0C);
+
+ _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0])
+ _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1])
+ _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2])
+ _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3])
+
+ ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4));
+ ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5));
+ ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4);
+ ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5);
+
+ ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C);
+ ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C);
+
+ _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][4])
+ _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][5])
+
+ dtrsm_AuXB_ref(a11, b11, m_remainder, 6, cs_a, cs_b, is_unitdiag);
+ }
+ dim_t n_remainder = j + D_NR;
+ if((n_remainder >= 4))
+ {
+ a10 = D_A_pack;
+ a11 = L; //pointer to block of A to be used for TRSM
+ b01 = B + ((n_remainder - 4)* cs_b) + m_remainder; //pointer to block of B to be used for GEMM
+ b11 = B + ((n_remainder - 4)* cs_b); //pointer to block of B to be used for TRSM
+
+ k_iter = (m - m_remainder); //number of times GEMM to be performed (in blocks of 4x4)
+
+ #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL
+ _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0);
+ #endif
+
+ ymm4 = _mm256_setzero_pd();
+ ymm5 = _mm256_setzero_pd();
+ ymm6 = _mm256_setzero_pd();
+ ymm7 = _mm256_setzero_pd();
+ ymm8 = _mm256_setzero_pd();
+ ymm9 = _mm256_setzero_pd();
+ ymm10 = _mm256_setzero_pd();
+ ymm11 = _mm256_setzero_pd();
+ ymm12 = _mm256_setzero_pd();
+ ymm13 = _mm256_setzero_pd();
+ ymm14 = _mm256_setzero_pd();
+ ymm15 = _mm256_setzero_pd();
+
+ ///GEMM code begins///
+ for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
+ {
+ ymm0 = _mm256_loadu_pd((double const *)(a10));
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0));
+ ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1));
+ ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2));
+ ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3));
+ ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11);
+
+ b01 += 1; //move to next row of B
+ a10 += p_lda;
+ }
+
+ ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to 
hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0C); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + + dtrsm_AuXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); + n_remainder -= 4; + } + if(n_remainder) + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + m_remainder; //pointer to block of B to be used for GEMM + b11 = B; //pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(3 == n_remainder) + { + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + + dtrsm_AuXB_ref(a11, b11, m_remainder, 3, cs_a, cs_b, is_unitdiag); + } + else if(2 == n_remainder) + { + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + 
cs_b * 0));
+ ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1));
+ ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);
+
+ b01 += 1; //move to next row of B
+ a10 += p_lda;
+ }
+
+ ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha
+
+ ///implement TRSM///
+
+ ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));
+ ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));
+ ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);
+ ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9);
+
+ ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C);
+ ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C);
+
+ _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0])
+ _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1])
+
+ dtrsm_AuXB_ref(a11, b11, m_remainder, 2, cs_a, cs_b, is_unitdiag);
+ }
+ else if(1 == n_remainder)
+ {
+ ///GEMM code begins///
+ for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
+ {
+ ymm0 = _mm256_loadu_pd((double const *)(a10));
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0));
+ ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8);
+
+ b01 += 1; //move to next row of B
+ a10 += p_lda;
+ }
+
+ ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha
+
+ ///implement TRSM///
+
+ ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));
+ ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);
+
+ ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C);
+
+ _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0])
+
+ dtrsm_AuXB_ref(a11, b11, m_remainder, 1, cs_a, cs_b, is_unitdiag);
+ }
+ }
+
+ }
+ else if(1 == m_remainder) // Repetitive A blocks will be 1*1
+ {
+ dim_t p_lda = 4; // packed leading dimension
+ for(dim_t x = 0; x < m-m_remainder; x++)
+ {
+ ymm0 = _mm256_loadu_pd((double const *)(a10 + x*cs_a));
+ _mm256_storeu_pd((double *)(ptr_a10_dup + x*p_lda), ymm0);
+ }
+ //cols
+ for(j = (n - D_NR); (j + 1) > 0; j -= D_NR) //loop along 'N' dimension
+ {
+ a10 = D_A_pack;
+ a11 = L; //pointer to block of A to be used for TRSM
+ b01 = B + (j* cs_b) + m_remainder; //pointer to block of B to be used for GEMM
+ b11 = B + (j* cs_b); //pointer to block of B to be used for TRSM
+
+ k_iter = (m - m_remainder); //number of times GEMM to be performed
+
+ #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL
+ _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0);
+ #endif
+
+ ymm4 = _mm256_setzero_pd();
+ ymm5 = _mm256_setzero_pd();
+ ymm6 = _mm256_setzero_pd();
+ ymm7 = _mm256_setzero_pd();
+ ymm8 = _mm256_setzero_pd();
+ ymm9 = _mm256_setzero_pd();
+ ymm10 = _mm256_setzero_pd();
+ ymm11 = _mm256_setzero_pd();
+ ymm12 = _mm256_setzero_pd();
+ ymm13 = _mm256_setzero_pd();
+ ymm14 = _mm256_setzero_pd();
+ ymm15 = _mm256_setzero_pd();
+
+ ///GEMM code begins///
+ for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
+ {
+ ptr_b01_dup = b01;
+ ymm0 = _mm256_loadu_pd((double const *)(a10));
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0));
+ ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1));
+ ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2));
+ ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3));
+ ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4));
+ ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5));
+ ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);
+
+ b01 += 1; //move to next row of B
+ a10 += p_lda;
+ }
+
+ ///GEMM code ends///
+ ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value
+ ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));
+ ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));
+ ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2));
+ ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3));
+ ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);
+ ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9);
+ ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10);
+ ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11);
+
+ ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E);
+ ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E);
+ ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E);
+ ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E);
+
+ _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0])
+ _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1])
+ _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2])
+ _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3])
+
+ ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4));
+ ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5));
+ ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4);
+ ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5);
+
+ ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E);
+ ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E);
+
+ _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][4])
+ _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][5])
+
+ dtrsm_AuXB_ref(a11, b11, m_remainder, 6, cs_a, cs_b, is_unitdiag);
+ }
+ dim_t n_remainder = j + D_NR;
+ if((n_remainder >= 4))
+ {
+ a10 = D_A_pack;
+ a11 = L; //pointer to block of A to be used for TRSM
+ b01 = B + ((n_remainder - 4)* cs_b) + m_remainder; //pointer to block of B to be used for GEMM
+ b11 = B + ((n_remainder - 4)* cs_b); //pointer to block of B to be used for TRSM
+
+ k_iter = (m - m_remainder); //number of times GEMM to be performed
+
+ #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL
+ _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0);
+ _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0);
+ #endif
+
+ ymm4 = _mm256_setzero_pd();
+ ymm5 = _mm256_setzero_pd();
+ ymm6 = _mm256_setzero_pd();
+ ymm7 = _mm256_setzero_pd();
+ ymm8 = _mm256_setzero_pd();
+ ymm9 = _mm256_setzero_pd();
+ ymm10 = _mm256_setzero_pd();
+ ymm11 = _mm256_setzero_pd();
+ ymm12 = _mm256_setzero_pd();
+ ymm13 = _mm256_setzero_pd();
+ ymm14 = _mm256_setzero_pd();
+ ymm15 = _mm256_setzero_pd();
+
+ ///GEMM code begins///
+ for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
+ {
+ ymm0 = _mm256_loadu_pd((double const *)(a10));
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0));
+ ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1));
+ ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2));
+ ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10);
+
+ ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3));
+ ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11);
+
+ b01 += 1; //move to next row of B
+ a10 += p_lda;
+ }
+ ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha
+
+ ///implement TRSM///
+
+ ymm0 = 
_mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + + dtrsm_AuXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); + n_remainder -= 4; + } + if(n_remainder) + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + m_remainder; //pointer to block of B to be used for GEMM + b11 = B; //pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); //number of times GEMM to be performed + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(3 == n_remainder) + { + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + + dtrsm_AuXB_ref(a11, b11, m_remainder, 3, cs_a, cs_b, is_unitdiag); + } + else if(2 == n_remainder) + { + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = 
_mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + + dtrsm_AuXB_ref(a11, b11, m_remainder, 2, cs_a, cs_b, is_unitdiag); + } + else if(1 == n_remainder) + { + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + //register to hold alpha + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ///implement TRSM/// + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + dtrsm_AuXB_ref(a11, b11, m_remainder, 1, cs_a, cs_b, is_unitdiag); + } + } + } + } + + if ((required_packing_A == 1) && + bli_mem_is_alloc( &local_mem_buf_A_s )) + { + bli_membrk_release(&rntm, &local_mem_buf_A_s); + } + return BLIS_SUCCESS; +} + +/* TRSM for the case XA = alpha * B *A is upper triangular, non-unit diagonal, no transpose *dimensions: X:mxn A:nxn B: mxn - */ - -/* b11---> a01 ----> + * + * b11---> a01 ----> ***************** *********** *b01*b11* * * * * * * b11 * * * * * **a01 * * a11 @@ -4664,14 +13956,14 @@ b11 * * * * * **a01 * * a11 * */ -static err_t bli_dtrsm_small_XAuB( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ) +static err_t bli_dtrsm_small_XAuB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) { dim_t D_MR = 8; //block dimension along the rows dim_t D_NR = 4; //block dimension along the columns @@ -4685,18 +13977,6 @@ static err_t bli_dtrsm_small_XAuB( dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B -#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME - if((m>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME_ROW_PANEL_M && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME) - || (m>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME_COLUMN_PANEL_N) - ) - return BLIS_NOT_YET_IMPLEMENTED; -#else - if(bli_max(m,n)>D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES && (m/n) < D_BLIS_SMALL_MATRIX_THRES_TRSM_DIM_RATIO) - { - return BLIS_NOT_YET_IMPLEMENTED; - } -#endif - dim_t i, j, k; //loop variablse dim_t k_iter; //determines the number of GEMM operations to be done dim_t cs_b_offset[2]; //pre-calculated strides @@ -4710,8 +13990,8 @@ static err_t bli_dtrsm_small_XAuB( double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks double *ptr_a01_dup; - double f_t[4] __attribute__((aligned(64)));//buffer to store corner column when m_remainder !=0 - double* f_temp; + double f_t[4] __attribute__((aligned(64)));//buffer to store corner column when m_remainder !=0 + double* f_temp; 
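[Reviewer note] For orientation, this XAuB variant solves X A = alpha B by columns: with A upper triangular, column j of X depends only on columns 0..j-1 of X, which is what justifies the b01 (GEMM) / b11 (TRSM) split in the blocked loops that follow. A minimal scalar sketch of that recurrence, assuming column-major operands with strides cs_a and cs_b; the helper name xa_ub_ref_sketch is illustrative only, not the library's dtrsm_XAuB_ref signature:

// Sketch: X*A = alpha*B, A upper triangular, non-unit diagonal, no transpose.
// B (m x n) is overwritten with X. Column-major strides are an assumption.
static void xa_ub_ref_sketch( const double* a, double* b, double alpha,
                              dim_t m, dim_t n, dim_t cs_a, dim_t cs_b )
{
    for ( dim_t j = 0; j < n; ++j )            // columns of X, left to right
        for ( dim_t i = 0; i < m; ++i )
        {
            double t = alpha * b[ i + j*cs_b ];
            for ( dim_t k = 0; k < j; ++k )    // subtract already-solved columns
                t -= b[ i + k*cs_b ] * a[ k + j*cs_a ];
            b[ i + j*cs_b ] = t / a[ j + j*cs_a ];  // non-unit diagonal divide
        }
}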
cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2; cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3; @@ -5877,7 +15157,7 @@ static err_t bli_dtrsm_small_XAuB( } if(m_remainder) ///omplementation for remainder rows { - for(j = 0; j+D_NR-1 < n; j += D_NR) //loop along 'N' direction + for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' direction { a01 = L + j*cs_a; //pointer to block of A to be used for GEMM a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM @@ -6366,18 +15646,17 @@ static err_t bli_dtrsm_small_XAuB( (b11 + cs_b * (n_remainder-1))[iter] = f_temp[iter]; } //scalar code for TRSM - dtrsm_small_XAuB(a11, b11, m_remainder, n_remainder, cs_a, cs_b); + dtrsm_XAuB_ref(a11, b11, m_remainder, n_remainder, cs_a, cs_b); } } return BLIS_SUCCESS; } -/*implements TRSM for the case XA = alpha * B +/* TRSM for the case XA = alpha * B *A is upper triangular, unit-diagonal, no transpose *dimensions: X:mxn A:nxn B: mxn - */ - -/* b11---> a01 ----> + * + * b11---> a01 ----> ***************** *********** *b01*b11* * * * * * * b11 * * * * * **a01 * * a11 @@ -6391,14 +15670,14 @@ b11 * * * * * **a01 * * a11 * */ -static err_t bli_dtrsm_small_XAuB_unitDiag( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ) +static err_t bli_dtrsm_XAuB_unitDiag_ref +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) { dim_t D_MR = 8; //block dimension along the rows dim_t D_NR = 4; //block dimension along the columns @@ -6412,18 +15691,6 @@ static err_t bli_dtrsm_small_XAuB_unitDiag( dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B -#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME - if((m>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME_ROW_PANEL_M && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME) - || (m>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME_COLUMN_PANEL_N) - ) - return BLIS_NOT_YET_IMPLEMENTED; -#else - if(bli_max(m,n)>D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES && (m/n) < D_BLIS_SMALL_MATRIX_THRES_TRSM_DIM_RATIO) - { - return BLIS_NOT_YET_IMPLEMENTED; - } -#endif - dim_t i, j, k; //loop variablse dim_t k_iter; //determines the number of GEMM operations to be done dim_t cs_b_offset[2]; //pre-calculated strides @@ -7367,7 +16634,7 @@ static err_t bli_dtrsm_small_XAuB_unitDiag( } if(m_remainder) ///omplementation for remainder rows { - for(j = 0; j+D_NR-1 < n; j += D_NR) //loop along 'N' direction + for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' direction { a01 = L + j*cs_a; //pointer to block of A to be used for GEMM a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM @@ -7819,19 +17086,17 @@ static err_t bli_dtrsm_small_XAuB_unitDiag( (b11 + cs_b * (n_remainder-1))[iter] = f_temp[iter]; } //scalar code for TRSM - dtrsm_small_XAuB_unitDiag(a11, b11, m_remainder, n_remainder, cs_a, cs_b); + dtrsm_XAuB_unitDiag_ref(a11, b11, m_remainder, n_remainder, cs_a, cs_b); } } return BLIS_SUCCESS; } - -/*implements TRSM for the case XA = alpha * B +/* TRSM for the case XA = alpha * B *A is lower triangular, non-unit diagonal, transpose *dimensions: X:mxn A:nxn B: mxn - */ - -/* b11---> a01 ----> + * + * b11---> a01 ----> ***************** *********** *b01*b11* * * * * * * b11 * * * * * **a01 * * a11 @@ -7843,16 +17108,15 @@ b11 * * * * * **a01 * * a11 * * * * * * * ***************** * * * - */ -static err_t bli_dtrsm_small_XAltB( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* 
cntx, - cntl_t* cntl - ) +static err_t bli_dtrsm_small_XAltB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) { dim_t D_MR = 8; //block dimension along the rows dim_t D_NR = 4; //block dimension along the columns @@ -7866,21 +17130,6 @@ static err_t bli_dtrsm_small_XAltB( dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B -#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME - if((m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_ROW_PANEL_M && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_ROW_PANEL_N) - || (m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_SQUARE_M && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_SQUARE_N) - || (m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_COLUMN_PANEL_N && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME) - || (m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME) - || (m > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_COLUMN_PANEL_N) - ) - return BLIS_NOT_YET_IMPLEMENTED; -#else - if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES) - { - return BLIS_NOT_YET_IMPLEMENTED; - } -#endif - dim_t i, j, k; //loop variablse dim_t k_iter; //determines the number of GEMM operations to be done dim_t cs_b_offset[2]; //pre-calculated strides @@ -9074,7 +18323,7 @@ static err_t bli_dtrsm_small_XAltB( } if(m_remainder) ///omplementation for remainder rows { - for(j = 0; j+D_NR-1 < n; j += D_NR) //loop along 'N' direction + for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' direction { a01 = L + j; //pointer to block of A to be used for GEMM a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM @@ -9574,18 +18823,17 @@ static err_t bli_dtrsm_small_XAltB( (b11 + cs_b * (n_remainder-1))[iter] = f_temp[iter]; } //scalar code for TRSM - dtrsm_small_XAltB(a11, b11, m_remainder, n_remainder, cs_a, cs_b); + dtrsm_XAltB_ref(a11, b11, m_remainder, n_remainder, cs_a, cs_b); } } return BLIS_SUCCESS; } -/*implements TRSM for the case XA = alpha * B +/* TRSM for the case XA = alpha * B *A is lower triangular, unit-diagonal, transpose *dimensions: X:mxn A:nxn B: mxn - */ - -/* b11---> a01 ----> + * + * b11---> a01 ----> ***************** *********** *b01*b11* * * * * * * b11 * * * * * **a01 * * a11 @@ -9599,14 +18847,14 @@ b11 * * * * * **a01 * * a11 * */ -static err_t bli_dtrsm_small_XAltB_unitDiag( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ) +static err_t bli_dtrsm_small_XAltB_unitDiag +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) { dim_t D_MR = 8; //block dimension along the rows dim_t D_NR = 4; //block dimension along the columns @@ -9620,21 +18868,6 @@ static err_t bli_dtrsm_small_XAltB_unitDiag( dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B -#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME - if((m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_ROW_PANEL_M && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_ROW_PANEL_N) - || (m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_SQUARE_M && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_SQUARE_N) - || (m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_COLUMN_PANEL_N && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME) - || (m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME) - || (m > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME && n > 
D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_COLUMN_PANEL_N) - ) - return BLIS_NOT_YET_IMPLEMENTED; -#else - if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES) - { - return BLIS_NOT_YET_IMPLEMENTED; - } -#endif - dim_t i, j, k; //loop variablse dim_t k_iter; //determines the number of GEMM operations to be done dim_t cs_b_offset[2]; //pre-calculated strides @@ -10583,7 +19816,7 @@ static err_t bli_dtrsm_small_XAltB_unitDiag( } if(m_remainder) ///omplementation for remainder rows { - for(j = 0; j+D_NR-1 < n; j += D_NR) //loop along 'N' direction + for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' direction { a01 = L + j; //pointer to block of A to be used for GEMM a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM @@ -11043,19 +20276,17 @@ static err_t bli_dtrsm_small_XAltB_unitDiag( (b11 + cs_b * (n_remainder-1))[iter] = f_temp[iter]; } //scalar code for TRSM - dtrsm_small_XAltB_unitDiag(a11, b11, m_remainder, n_remainder, cs_a, cs_b); + dtrsm_XAltB_unitDiag_ref(a11, b11, m_remainder, n_remainder, cs_a, cs_b); } } return BLIS_SUCCESS; } - -/*implements TRSM for the case XA = alpha * B +/* TRSM for the case XA = alpha * B *A is lower triangular, non-unit diagonal, no transpose *dimensions: X:mxn A:nxn B: mxn - */ - -/* <---b11 <---a11 + * + * <---b11 <---a11 ***************** * *b01*b11* * * * * ^ * * * * * ^ * * @@ -11068,14 +20299,14 @@ b10 ***************** ************* ***************** ******************* */ -static err_t bli_dtrsm_small_XAlB( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ) +static err_t bli_dtrsm_small_XAlB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) { dim_t D_MR = 8; //block dimension along the rows dim_t D_NR = 4; //block dimension along the columns @@ -11089,19 +20320,6 @@ static err_t bli_dtrsm_small_XAlB( dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B - -#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME - if((m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME) - ||(m > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME_COL_PANEL_N) - ) - return BLIS_NOT_YET_IMPLEMENTED; -#else - if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES) - { - return BLIS_NOT_YET_IMPLEMENTED; - } -#endif - dim_t i, j, k; //loop variablse dim_t k_iter; //determines the number of GEMM operations to be done dim_t cs_b_offset[2]; //pre-calculated strides @@ -12294,17 +21512,16 @@ static err_t bli_dtrsm_small_XAlB( // if(i < 0) i = 0; if(m_remainder) ///implementation for remainder rows { - dtrsm_small_XAlB(L, B, AlphaVal, m_remainder, n, cs_a, cs_b); + dtrsm_XAlB_ref(L, B, AlphaVal, m_remainder, n, cs_a, cs_b); } return BLIS_SUCCESS; } -/*implements TRSM for the case XA = alpha * B +/* TRSM for the case XA = alpha * B *A is lower triangular, unit-diagonal, no transpose *dimensions: X:mxn A:nxn B: mxn - */ - -/* <---b11 <---a11 + * + * <---b11 <---a11 ***************** * *b01*b11* * * * * ^ * * * * * ^ * * @@ -12315,16 +21532,15 @@ b10 ***************** ************* * * * * * * * * * * * * * * * * * * ***************** ******************* - */ -static err_t bli_dtrsm_small_XAlB_unitDiag( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ) +static err_t bli_dtrsm_small_XAlB_unitDiag +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) { dim_t D_MR = 8; //block dimension 
along the rows dim_t D_NR = 4; //block dimension along the columns @@ -12338,19 +21554,6 @@ static err_t bli_dtrsm_small_XAlB_unitDiag( dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B - -#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME - if((m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME) - ||(m > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME_COL_PANEL_N) - ) - return BLIS_NOT_YET_IMPLEMENTED; -#else - if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES) - { - return BLIS_NOT_YET_IMPLEMENTED; - } -#endif - dim_t i, j, k; //loop variablse dim_t k_iter; //determines the number of GEMM operations to be done dim_t cs_b_offset[2]; //pre-calculated strides @@ -13288,18 +22491,16 @@ static err_t bli_dtrsm_small_XAlB_unitDiag( // if(i < 0) i = 0; if(m_remainder) ///implementation for remainder rows { - dtrsm_small_XAlB_unitDiag(L, B, AlphaVal, m_remainder, n, cs_a, cs_b); + dtrsm_XAlB_unitDiag_ref(L, B, AlphaVal, m_remainder, n, cs_a, cs_b); } return BLIS_SUCCESS; } - -/*implements TRSM for the case XA = alpha * B +/* TRSM for the case XA = alpha * B *A is lower triangular, non-unit diagonal, no transpose *dimensions: X:mxn A:nxn B: mxn - */ - -/* <---b11 <---a11 + * + * <---b11 <---a11 ***************** * *b01*b11* * * * * ^ * * * * * ^ * * @@ -13312,14 +22513,14 @@ b10 ***************** ************* ***************** ******************* */ -static err_t bli_dtrsm_small_XAutB( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ) +static err_t bli_dtrsm_small_XAutB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) { dim_t D_MR = 8; //block dimension along the rows dim_t D_NR = 4; //block dimension along the columns @@ -13333,18 +22534,6 @@ static err_t bli_dtrsm_small_XAutB( dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B -#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME - if((m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUTB_ROME) - ||(m > D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME_COL_PANEL_N) - ) - return BLIS_NOT_YET_IMPLEMENTED; -#else - if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES) - { - return BLIS_NOT_YET_IMPLEMENTED; - } -#endif - dim_t i, j, k; //loop variablse dim_t k_iter; //determines the number of GEMM operations to be done dim_t cs_b_offset[2]; //pre-calculated strides @@ -14558,17 +23747,16 @@ static err_t bli_dtrsm_small_XAutB( } if(m_remainder) ///implementation for remainder rows { - dtrsm_small_XAutB(L, B, AlphaVal, m_remainder, n, cs_a, cs_b); + dtrsm_XAutB_ref(L, B, AlphaVal, m_remainder, n, cs_a, cs_b); } return BLIS_SUCCESS; } -/*implements TRSM for the case XA = alpha * B +/* TRSM for the case XA = alpha * B *A is lower triangular, unit-diagonal, no transpose *dimensions: X:mxn A:nxn B: mxn - */ - -/* <---b11 <---a11 + * + * <---b11 <---a11 ***************** * *b01*b11* * * * * ^ * * * * * ^ * * @@ -14581,14 +23769,14 @@ b10 ***************** ************* ***************** ******************* */ -static err_t bli_dtrsm_small_XAutB_unitDiag( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ) +static err_t bli_dtrsm_small_XAutB_unitDiag +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) { dim_t D_MR = 8; //block 
dimension along the rows dim_t D_NR = 4; //block dimension along the columns @@ -14602,18 +23790,6 @@ static err_t bli_dtrsm_small_XAutB_unitDiag( dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B -#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME - if((m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUTB_ROME) - ||(m > D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME_COL_PANEL_N) - ) - return BLIS_NOT_YET_IMPLEMENTED; -#else - if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES) - { - return BLIS_NOT_YET_IMPLEMENTED; - } -#endif - dim_t i, j, k; //loop variablse dim_t k_iter; //determines the number of GEMM operations to be done dim_t cs_b_offset[2]; //pre-calculated strides @@ -15566,12255 +24742,9 @@ static err_t bli_dtrsm_small_XAutB_unitDiag( } if(m_remainder) ///implementation for remainder rows { - dtrsm_small_XAutB_unitDiag(L, B, AlphaVal, m_remainder, n, cs_a, cs_b); + dtrsm_XAutB_unitDiag_ref(L, B, AlphaVal, m_remainder, n, cs_a, cs_b); } return BLIS_SUCCESS; } - -/* - * AX = Alpha*B, Single precision, A: lower triangular - * This kernel implementation supports matrices A and B such that m is equal to BLI_AlXB_M_SP and n is mutiple of 8 - */ - -static err_t bli_strsm_small_AlXB ( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ) -{ - obj_t alpha, beta; // gemm parameters - obj_t Ga, Gb, Gc; // for GEMM - int m = bli_obj_length(b); // number of rows of matrix B - int n = bli_obj_width(b); // number of columns of matrix B - - int lda = bli_obj_col_stride(a); // column stride of A - int ldb = bli_obj_col_stride(b); // column stride of B - - int rsa = bli_obj_row_stride(a); // row stride of A - int rsb = bli_obj_row_stride(b); // row stride of B - - int i = 0; - int j; - int blk_size = 8; - int isUnitDiag = bli_obj_has_unit_diag(a); - - float alphaVal; - float* restrict L = a->buffer; - float* restrict B = b->buffer; - - if (m != BLI_AlXB_M_SP || (n&7) != 0) - { - return BLIS_NOT_YET_IMPLEMENTED; - } - if ( (m*(m + n)) > BLIS_SMALL_MATRIX_THRES_TRSM ) - { - return BLIS_NOT_YET_IMPLEMENTED; - } - - alphaVal = *((float *)bli_obj_buffer_for_const(BLIS_FLOAT, AlphaObj)); - - /* Small _GEMM preparation code */ - bli_obj_create( BLIS_FLOAT, 1, 1, 0, 0, &alpha ); - bli_obj_create( BLIS_FLOAT, 1, 1, 0, 0, &beta ); - - /* B = B - A*B */ - bli_setsc( -(1.0), 0.0, &alpha ); - bli_setsc( (1.0), 0.0, &beta ); - - - bli_obj_create_with_attached_buffer( BLIS_FLOAT, blk_size, blk_size, a->buffer, rsa, lda, &Ga); - bli_obj_create_with_attached_buffer( BLIS_FLOAT, blk_size, n, b->buffer, rsb, ldb, &Gb); - bli_obj_create_with_attached_buffer( BLIS_FLOAT, blk_size, n, b->buffer, rsb, ldb, &Gc); - - bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &Ga ); - bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &Gb ); - bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &Gc ); - - //first block of trsm - Gb.buffer = (void*)(B + i); - - //trsm of first 8xn block - if (alphaVal != 1) - { - if (isUnitDiag == 0) - { - blis_strsm_microkernel_alpha((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb, alphaVal); - fp_blis_strsm_microkernel = blis_strsm_microkernel; - } - else - { - blis_strsm_microkernel_alpha_unitDiag((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb, alphaVal); - fp_blis_strsm_microkernel = blis_strsm_microkernel_unitDiag; - } - bli_setsc( alphaVal, 0.0, &beta ); - } - else - { - if (isUnitDiag == 0) - { - blis_strsm_microkernel((L 
+ i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb); - fp_blis_strsm_microkernel = blis_strsm_microkernel; - } - else - { - blis_strsm_microkernel_unitDiag((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb); - fp_blis_strsm_microkernel = blis_strsm_microkernel_unitDiag; - } - } - - //gemm update - for (j = i + blk_size; j < m; j += blk_size) // for rows upto multiple of BLOCK_HEIGHT - { - Ga.buffer = (void*)(L + j + i*lda); - Gc.buffer = (void*)(B + j); - - bli_gemm_small(&alpha, &Ga, &Gb, &beta, &Gc, cntx, cntl ); // Gc = beta*Gc + alpha*Ga *Gb - } - - //trsm of remaining blocks - for (i = blk_size; i < m; i += blk_size) - { - Gb.buffer = (void*)(B + i); - - fp_blis_strsm_microkernel((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb); - - for (j = i + blk_size; j < m; j += blk_size) // for rows upto multiple of BLOCK_HEIGHT - { - Ga.buffer = (void*)(L + j + i*lda); - Gc.buffer = (void*)(B + j); - - bli_gemm_small(&alpha, &Ga, &Gb, &beta, &Gc, cntx, cntl ); // Gc = beta*Gc + alpha*Ga *Gb - } - - } // End of for loop - i - - return BLIS_SUCCESS; -} - - - -/* - * XA' = Alpha*B, Single precision, A: lower triangular - * This kernel implementation supports matrices A and B such that - * m and n are multiples of 8 and n is less than or equal to BLI_XAltB_N_SP - */ -static err_t bli_strsm_small_XAltB( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ) -{ - int m = bli_obj_length(a); // number of rows of matrix B - int n = bli_obj_length(b); // number of columns of matrix B - - int lda = bli_obj_col_stride(a); // column stride of A - int ldb = bli_obj_col_stride(b); // column stride of B - - int rsa = bli_obj_row_stride(a); // row stride of A - int rsb = bli_obj_row_stride(b); // row stride of B - - int i = 0; - int isUnitDiag = bli_obj_has_unit_diag(a); - - float alphaVal; - float *L = a->buffer; - float *B = b->buffer; - - if ((m&7) != 0 || (n&7) != 0) - { - return BLIS_NOT_YET_IMPLEMENTED; - } - if ( n > BLI_XAltB_N_SP || (m*(m + n)) > BLIS_SMALL_MATRIX_THRES_TRSM ) - { - return BLIS_NOT_YET_IMPLEMENTED; - } - - alphaVal = *((float *)bli_obj_buffer_for_const(BLIS_FLOAT, AlphaObj)); - - if (alphaVal != 1) - { - if (isUnitDiag == 0) - { - trsm_XAtB_block_allSmallSizedMatrices_alpha((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb, alphaVal); - } - else - { - trsm_XAtB_block_allSmallSizedMatrices_alpha_unitDiag((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb, alphaVal); - } - } - else - { - if (isUnitDiag == 0) - { - trsm_XAtB_block_allSmallSizedMatrices((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb); - } - else - { - trsm_XAtB_block_allSmallSizedMatrices_unitDiag((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb); - } - } - return BLIS_SUCCESS; -} - -/* - * A'X = Alpha*B, Single precision, A: upper triangular - * This kernel implementation supports matrices A and B such that - * m and n are multiples of 8, m is less than or equal to BLI_AutXB_M_SP and n is less than or equal to BLI_AutXB_N_SP - */ -static err_t bli_strsm_small_AutXB( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ) -{ - int m = bli_obj_width(a); // number of rows of matrix A (since At, so width is taken) - int n = bli_obj_width(b); // number of columns of matrix B - - int lda = bli_obj_col_stride(a); // column stride of A - int ldb = bli_obj_col_stride(b); // column stride of B - - int rsa = bli_obj_row_stride(a); // row stride of A - int rsb = bli_obj_row_stride(b); // row stride of B - - int i = 0; - int isUnitDiag = 
bli_obj_has_unit_diag(a); - - float alphaVal; - float *L = a->buffer; - float *B = b->buffer; - - if ((m&7) != 0 || (n&7) != 0) - { - return BLIS_NOT_YET_IMPLEMENTED; - } - if ( m > BLI_AutXB_M_SP || n > BLI_AutXB_N_SP || (m*(m + n)) > BLIS_SMALL_MATRIX_THRES_TRSM ) - { - return BLIS_NOT_YET_IMPLEMENTED; - } - - alphaVal = *((float *)bli_obj_buffer_for_const(BLIS_FLOAT, AlphaObj)); - - if (alphaVal != 1) - { - if (isUnitDiag == 0) - { - trsm_AutXB_block_allSmallSizedMatrices_alpha((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb, alphaVal); - } - else - { - trsm_AutXB_block_allSmallSizedMatrices_alpha_unitDiag((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb, alphaVal); - } - } - else - { - if (isUnitDiag == 0) - { - trsm_AutXB_block_allSmallSizedMatrices((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb); - } - else - { - trsm_AutXB_block_allSmallSizedMatrices_unitDiag((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb); - } - } - return BLIS_SUCCESS; -} - -///////////////////////////// AX=B /////////////////////////////// -static void blis_strsm_microkernel_alpha(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alphaVal) -{ - float ones = 1.0; - int j; - int cs_b_offset[6]; - //int row2, row4, row6; - float *ptr_b_dup; - - //70 number of ymm(256 bits) registers used - __m256 mat_b_col[8]; - __m256 mat_b_rearr[8]; - __m256 mat_a_cols[8]; - __m256 mat_a_cols_rearr[36]; - __m256 mat_a_diag_inv[8]; - __m256 reciprocal_diags; - __m256 alphaReg; - - cs_b_offset[0] = (cs_b << 1); - cs_b_offset[1] = cs_b + cs_b_offset[0]; - cs_b_offset[2] = (cs_b << 2); - cs_b_offset[3] = cs_b + cs_b_offset[2]; - cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; - cs_b_offset[5] = cs_b + cs_b_offset[4]; - - //reciprocal_diags = _mm256_loadu_ps((float const *)ones); - reciprocal_diags = _mm256_broadcast_ss((float const *)&ones); - alphaReg = _mm256_broadcast_ss((float const *)&alphaVal); - - // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // - - //read first set of 16x8 block of B into registers, where 16 is the blk_height and 8 is the blk_width for B - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b); - //_mm_prefetch((char*)(ptr_l + 0), _MM_HINT_T0); - //row2 = (cs_l << 1); - //row4 = (cs_l << 2); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b))); - //_mm_prefetch((char*)(ptr_l + cs_l), _MM_HINT_T0); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0])); - //_mm_prefetch((char*)(ptr_l + row2), _MM_HINT_T0); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1])); - //_mm_prefetch((char*)(ptr_l + row2 + cs_l), _MM_HINT_T0); - //row6 = row2 + row4; - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2])); - //_mm_prefetch((char*)(ptr_l + row4), _MM_HINT_T0); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3])); - //_mm_prefetch((char*)(ptr_l + row4 + cs_l), _MM_HINT_T0); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4])); - //_mm_prefetch((char*)(ptr_l + row6), _MM_HINT_T0); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5])); - //_mm_prefetch((char*)(ptr_l + row6 + cs_l), _MM_HINT_T0); - - //reciprocal_diags = _mm256_loadu_ps((float const *)ones); - - //read first set of 16x16 block of L, where 16 is the blk_height and 16 is the blk_width for L - /*mat_a_cols[0] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[1] = _mm256_loadu_ps((float const *)ptr_l); 
- ptr_l += cs_l; - mat_a_cols[2] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[3] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[4] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[5] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[6] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[7] = _mm256_loadu_ps((float const *)ptr_l);*/ - - //Shuffle to rearrange/transpose 16x16 block of L into contiguous row-wise registers - //tmpRegs[0] = _mm256_castps256_ps128(mat_a_cols[0]); //zero latency, no instruction added actually. - //mat_a_cols_rearr[0] = _mm256_broadcastss_ps(tmpRegs[0]); - //1st col - mat_a_cols_rearr[0] = _mm256_broadcast_ss((float const *)(ptr_l+0)); - mat_a_cols_rearr[1] = _mm256_broadcast_ss((float const *)(ptr_l+1)); - mat_a_cols_rearr[3] = _mm256_broadcast_ss((float const *)(ptr_l+2)); - mat_a_cols_rearr[6] = _mm256_broadcast_ss((float const *)(ptr_l+3)); - mat_a_cols_rearr[10] = _mm256_broadcast_ss((float const *)(ptr_l+4)); - mat_a_cols_rearr[15] = _mm256_broadcast_ss((float const *)(ptr_l+5)); - mat_a_cols_rearr[21] = _mm256_broadcast_ss((float const *)(ptr_l+6)); - mat_a_cols_rearr[28] = _mm256_broadcast_ss((float const *)(ptr_l+7)); - //2nd col - ptr_l += cs_l; - mat_a_cols_rearr[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_cols_rearr[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_cols_rearr[7] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_cols_rearr[11] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_cols_rearr[16] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[22] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[29] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //3rd col - ptr_l += cs_l; - mat_a_cols_rearr[5] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_cols_rearr[8] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_cols_rearr[12] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_cols_rearr[17] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[23] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[30] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //4rth col - ptr_l += cs_l; - mat_a_cols_rearr[9] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_cols_rearr[13] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_cols_rearr[18] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[24] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[31] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //5th col - ptr_l += cs_l; - mat_a_cols_rearr[14] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_cols_rearr[19] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[25] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[32] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //6th col - ptr_l += cs_l; - mat_a_cols_rearr[20] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[26] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[33] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //7th col - ptr_l += cs_l; - mat_a_cols_rearr[27] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[34] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //7th col - ptr_l += cs_l; - mat_a_cols_rearr[35] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - 
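[Reviewer note] The scattered mat_a_cols_rearr indices in the removed microkernel above follow row-major packed lower-triangular storage: element (i, j) of the 8x8 block of L, with i >= j, lands at index i*(i+1)/2 + j, so the diagonal sits at indices 0, 2, 5, 9, 14, 20, 27 and 35 — exactly the registers gathered next to build the reciprocals. A loop-form sketch of the unrolled broadcasts, starting from the block origin; idx(i,j) = i*(i+1)/2 + j is inferred from the index pattern, an assumption rather than a documented BLIS convention:

// Equivalent loop for the unrolled broadcast loads of L(i,j), i >= j.
for ( int j = 0; j < 8; ++j )
{
    if ( j > 0 ) ptr_l += cs_l;   // the unrolled code advances L one column
                                  // before each new column's loads
    for ( int i = j; i < 8; ++i )
        mat_a_cols_rearr[ i*(i+1)/2 + j ] =
            _mm256_broadcast_ss( (float const *)(ptr_l + i) );
}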
- numCols_b -= 8; // blk_width = 8
-
- //compute reciprocals of L(i,i) and broadcast in registers
- mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[2]);
- mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_cols_rearr[5], mat_a_cols_rearr[9]);
- mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_cols_rearr[14], mat_a_cols_rearr[20]);
- mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_cols_rearr[27], mat_a_cols_rearr[35]);
-
- //mat_a_diag_inv[1] = _mm256_permute_ps(mat_a_diag_inv[1], 0x55);
- //mat_a_diag_inv[3] = _mm256_permute_ps(mat_a_diag_inv[3], 0x55);
- mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);
- mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);
- mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0x20);
-
- //reciprocal of diagonal elements
- reciprocal_diags = _mm256_div_ps(reciprocal_diags, mat_a_diag_inv[0]);
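The diagonal handling above packs all eight L(i,i) values into one register so a single vector division (into reciprocal_diags, pre-loaded with 1.0 in every lane) yields every reciprocal at once; the solve then re-broadcasts one lane at a time with an in-lane permute followed by a cross-lane permute2f128. A minimal sketch of those two building blocks (helper names are illustrative):

```c
#include <immintrin.h>

/* One division produces all eight reciprocals at once. */
static inline __m256 recip_lanes(__m256 diag)
{
    return _mm256_div_ps(_mm256_set1_ps(1.0f), diag);
}

/* Splat lane 1 across the register, as the "extract diag a11" step does. */
static inline __m256 splat_lane1(__m256 v)
{
    v = _mm256_permute_ps(v, 0x55);            /* element 1 within each half */
    return _mm256_permute2f128_ps(v, v, 0x00); /* low half to both halves    */
}

/* Splat lane 5 (element 1 of the high half), as "extract diag a55" does. */
static inline __m256 splat_lane5(__m256 v)
{
    v = _mm256_permute_ps(v, 0x55);
    return _mm256_permute2f128_ps(v, v, 0x11); /* high half to both halves   */
}
```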
-
- //Start loop for cols of B to be processed in size of blk_width
- for (j = 0; j < numCols_b; j += 8)
- {
- ptr_b_dup = ptr_b;
-
- /*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/
-
- ////unpacklow////
- mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
- mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
- mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
- mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
-
- //Rearrange low elements
-#if REARRANGE_SHFL == 1
- mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
- mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
- mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
- mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
-#else
- mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
- mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
- mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
- mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
- mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
- mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
-#endif
- //Merge rearranged low elements into complete rows
- mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
- mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
- mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
- mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
-
- mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg);
- mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg);
- mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg);
- mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg);
-
- ////unpackhigh////
- mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
- mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
- mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
- mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
-
- //Rearrange high elements
-#if REARRANGE_SHFL == 1
- mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
- mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
- mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
- mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
-#else
- mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
- mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
- mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
- mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
- mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
- mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
-#endif
-
- //extract diag a00 from a
- mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00);
- mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
-
- //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
- mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
-
- //Merge rearranged high elements into complete rows
- mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
- mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
- mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
- mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
-
- mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg);
- mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg);
- mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg);
- mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg);
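Both REARRANGE_SHFL branches above implement the standard AVX 8x8 float transpose: unpacklo/unpackhi interleave row pairs, shuffle (or shuffle plus blend) combines quads, and permute2f128 exchanges the 128-bit halves. Isolated from the alpha scaling and Row0 work it is interleaved with, the pattern looks like this (a sketch, not the kernel's exact register allocation):

```c
#include <immintrin.h>

/* Standalone 8x8 float transpose using the same unpack/shuffle/
 * permute2f128 sequence as the REARRANGE_SHFL == 1 path. */
static void transpose8x8(__m256 r[8])
{
    __m256 t[8], s[8];
    t[0] = _mm256_unpacklo_ps(r[0], r[1]);
    t[1] = _mm256_unpackhi_ps(r[0], r[1]);
    t[2] = _mm256_unpacklo_ps(r[2], r[3]);
    t[3] = _mm256_unpackhi_ps(r[2], r[3]);
    t[4] = _mm256_unpacklo_ps(r[4], r[5]);
    t[5] = _mm256_unpackhi_ps(r[4], r[5]);
    t[6] = _mm256_unpacklo_ps(r[6], r[7]);
    t[7] = _mm256_unpackhi_ps(r[6], r[7]);
    s[0] = _mm256_shuffle_ps(t[0], t[2], 0x44);
    s[1] = _mm256_shuffle_ps(t[0], t[2], 0xEE);
    s[2] = _mm256_shuffle_ps(t[1], t[3], 0x44);
    s[3] = _mm256_shuffle_ps(t[1], t[3], 0xEE);
    s[4] = _mm256_shuffle_ps(t[4], t[6], 0x44);
    s[5] = _mm256_shuffle_ps(t[4], t[6], 0xEE);
    s[6] = _mm256_shuffle_ps(t[5], t[7], 0x44);
    s[7] = _mm256_shuffle_ps(t[5], t[7], 0xEE);
    r[0] = _mm256_permute2f128_ps(s[0], s[4], 0x20);
    r[1] = _mm256_permute2f128_ps(s[1], s[5], 0x20);
    r[2] = _mm256_permute2f128_ps(s[2], s[6], 0x20);
    r[3] = _mm256_permute2f128_ps(s[3], s[7], 0x20);
    r[4] = _mm256_permute2f128_ps(s[0], s[4], 0x31);
    r[5] = _mm256_permute2f128_ps(s[1], s[5], 0x31);
    r[6] = _mm256_permute2f128_ps(s[2], s[6], 0x31);
    r[7] = _mm256_permute2f128_ps(s[3], s[7], 0x31);
}
```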
-
- //extract diag a11 from a
- mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55);
- mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
-
- //(Row1): FMA operations of b1 with elements of indices from (1, 0) until (7, 0)
- mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
- mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
- mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
- mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
- mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
- mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
- mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]);
-
- //extract diag a22 from a
- mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA);
- mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
-
- //(Row2): FMA operations of b2 with elements of indices from (2, 0) until (7, 0)
- mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
- mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
- mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
- mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
- mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
- mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]);
-
- //extract diag a33 from a
- mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF);
- mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
-
- //(Row3): FMA operations of b3 with elements of indices from (3, 0) until (7, 0)
- mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
- mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
- mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
- mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(3, 3) element with 4th row elements of B
- mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]);
-
- //extract diag a44 from a
- mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00);
- mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
-
- //(Row4): FMA operations of b4 with elements of indices from (4, 0) until (7, 0)
- mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
- mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
- mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(4, 4) element with 5th row elements of B
- mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]);
-
- //extract diag a55 from a
- mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55);
- mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
-
- //(Row5): FMA operations of b5 with elements of indices from (5, 0) until (7, 0)
- mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
- mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(5, 5) element with 6th row elements of B
- mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]);
-
- //extract diag a66 from a
- mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA);
- mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
-
- //(Row6): FMA operations of b6 with elements of indices from (6, 0) until (7, 0)
- mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(6, 6) element with 7th row elements of B
- mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]);
-
- //extract diag a77 from a
- mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF);
- mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
-
- //(Row7): FMA operations of b7 with elements of index (7, 0)
- mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(7, 7) element with 8th row elements of B
- mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]);
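The Row0..Row7 sequence above is a vectorized forward substitution: for L*X = alpha*B it computes x_i = (alpha*b_i - sum over j < i of L(i,j)*x_j) * (1/L(i,i)), with FNMADD providing the c - a*b update and the precomputed reciprocals replacing divisions. A scalar reference for one column (the kernel solves eight columns at once, one per ymm lane):

```c
/* Scalar reference for the Row0..Row7 steps; a sketch, not the
 * kernel's exact code. */
static void forward_subst_8(const float L[8][8], const float rdiag[8],
                            float x[8])   /* in: alpha*b, out: solution */
{
    for (int i = 0; i < 8; i++) {
        for (int j = 0; j < i; j++)
            x[i] -= L[i][j] * x[j];       /* the fnmadd: c - (a*b) */
        x[i] *= rdiag[i];                 /* multiply by 1/L(i,i)  */
    }
}
```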
-
- //--> Transpose and store results of columns of B block <--//
- ////unpacklow////
- mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
- mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
- mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
- mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
-
- //Rearrange low elements
-#if REARRANGE_SHFL == 1
- mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44);
- mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE);
- mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44);
- mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE);
-#else
- mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E);
- mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E);
- mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC);
- mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33);
- mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC);
- mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33);
-#endif
- //Merge rearranged low elements into complete rows
- mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20);
- mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31);
- mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20);
- mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31);
-
- ////unpackhigh////
- mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
- mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
- mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
- mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
-
- //Rearrange high elements
-#if REARRANGE_SHFL == 1
- mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
- mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
- mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
- mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
-#else
- mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
- mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
- mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
- mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
- mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
- mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
-#endif
-
- //Merge rearranged high elements into complete rows
- mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
- mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
- mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
- mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
-
- //Read next set of B columns
- ptr_b += (cs_b + cs_b_offset[5]);
- mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b);
- mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b)));
- mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0]));
- mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1]));
- mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2]));
- mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3]));
- mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4]));
- mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5]));
-
- //Store the computed B columns
- _mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]);
- _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]);
-
- //end loop of cols
- }
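Two structural points in the loop above: the next 8 columns of B are loaded before the finished block is stored through ptr_b_dup (letting the loads overlap the stores), and one block was peeled off numCols_b up front so the final iteration, handled after the loop, issues no read past the end of B. A compilable skeleton of that shape, with a "block" shrunk to 8 floats and process8() as a hypothetical stand-in for the transpose-plus-solve:

```c
#include <string.h>

static void process8(float blk[8]) { (void)blk; }  /* hypothetical stand-in */

/* ncols is assumed a positive multiple of 8, mirroring the kernel's
 * size assumption. */
static void peeled_loop(float *b, int ncols)
{
    float cur[8], nxt[8];
    int j, tail = ncols - 8;                 /* mirrors numCols_b -= 8 */
    memcpy(cur, b, sizeof cur);              /* first block            */
    for (j = 0; j < tail; j += 8) {
        process8(cur);                       /* transpose + solve      */
        memcpy(nxt, b + j + 8, sizeof nxt);  /* read next block first  */
        memcpy(b + j, cur, sizeof cur);      /* then write this one    */
        memcpy(cur, nxt, sizeof nxt);
    }
    process8(cur);                           /* peeled last block      */
    memcpy(b + j, cur, sizeof cur);
}
```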
-
- //Last block trsm processing
- ptr_b_dup = ptr_b;
-
- /*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/
-
- ////unpacklow////
- mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
- mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
- mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
- mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
-
- //Rearrange low elements
-#if REARRANGE_SHFL == 1
- mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
- mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
- mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
- mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
-#else
- mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
- mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
- mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
- mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
- mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
- mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
-#endif
- //Merge rearranged low elements into complete rows
- mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
- mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
- mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
- mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
-
- mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg);
- mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg);
- mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg);
- mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg);
-
- ////unpackhigh////
- mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
- mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
- mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
- mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
-
- //Rearrange high elements
-#if REARRANGE_SHFL == 1
- mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
- mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
- mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
- mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
-#else
- mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
- mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
- mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
- mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
- mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
- mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
-#endif
-
- //extract diag a00 from a
- mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00);
- mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
-
- //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
- mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
-
- //Merge rearranged high elements into complete rows
- mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
- mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
- mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
- mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
-
- mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg);
- mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg);
- mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg);
- mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg);
-
- //extract diag a11 from a
- mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55);
- mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
-
- //(Row1): FMA operations of b1 with elements of indices from (1, 0) until (7, 0)
- mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
- mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
- mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
- mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
- mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
- mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
- mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]);
-
- //extract diag a22 from a
- mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA);
- mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
-
- //(Row2): FMA operations of b2 with elements of indices from (2, 0) until (7, 0)
- mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
- mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
- mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
- mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
- mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
- mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]);
-
- //extract diag a33 from a
- mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF);
- mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
-
- //(Row3): FMA operations of b3 with elements of indices from (3, 0) until (7, 0)
- mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
- mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
- mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
- mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(3, 3) element with 4th row elements of B
- mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]);
-
- //extract diag a44 from a
- mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00);
- mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
-
- //(Row4): FMA operations of b4 with elements of indices from (4, 0) until (7, 0)
- mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
- mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
- mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(4, 4) element with 5th row elements of B
- mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]);
-
- //extract diag a55 from a
- mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55);
- mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
-
- //(Row5): FMA operations of b5 with elements of indices from (5, 0) until (7, 0)
- mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
- mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(5, 5) element with 6th row elements of B
- mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]);
-
- //extract diag a66 from a
- mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA);
- mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
-
- //(Row6): FMA operations of b6 with elements of indices from (6, 0) until (7, 0)
- mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(6, 6) element with 7th row elements of B
- mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]);
-
- //extract diag a77 from a
- mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF);
- mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
-
- //(Row7): FMA operations of b7 with elements of index (7, 0)
- mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(7, 7) element with 8th row elements of B
- mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]);
-
- //--> Transpose and store results of columns of B block <--//
- ////unpacklow////
- mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
- mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
- mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
- mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
-
- //Rearrange low elements
-#if REARRANGE_SHFL == 1
- mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44);
- mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE);
- mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44);
- mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE);
-#else
- mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E);
- mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E);
- mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC);
- mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33);
- mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC);
- mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33);
-#endif
- //Merge rearranged low elements into complete rows
- mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20);
- mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31);
- mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20);
- mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31);
-
- ////unpackhigh////
- mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
- mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
- mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
- mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
-
- //Rearrange high elements
-#if REARRANGE_SHFL == 1
- mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
- mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
- mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
- mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
-#else
- mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
- mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
- mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
- mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
- mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
- mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
-#endif
-
- //Merge rearranged high elements into complete rows
- mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
- mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
- mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
- mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
-
- //Store the computed B columns
- _mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]);
- _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]);
-
- //end loop of cols
-}
-
-static void blis_strsm_microkernel_alpha_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alphaVal)
-{
- //float ones = 1.0;
- int j;
- int cs_b_offset[6];
- //int row2, row4, row6;
- float *ptr_b_dup;
-
- //70 ymm (256-bit) registers used in total
- __m256 mat_b_col[8];
- __m256 mat_b_rearr[8];
- __m256 mat_a_cols[8];
- __m256 mat_a_cols_rearr[36];
- //__m256 mat_a_diag_inv[8];
- //__m256 reciprocal_diags;
- __m256 alphaReg;
-
- cs_b_offset[0] = (cs_b << 1);
- cs_b_offset[1] = cs_b + cs_b_offset[0];
- cs_b_offset[2] = (cs_b << 2);
- cs_b_offset[3] = cs_b + cs_b_offset[2];
- cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
- cs_b_offset[5] = cs_b + cs_b_offset[4];
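The cs_b_offset[] table above precomputes the 2..7-column strides with shifts and adds so the load/store addressing avoids per-access multiplies. The equivalent arithmetic, spelled out in a small sketch (hypothetical helper name):

```c
/* Same offsets as cs_b_offset[], written as plain multiples. */
static void col_offsets(int cs_b, int off[6])
{
    off[0] = cs_b << 1;        /* 2*cs_b */
    off[1] = off[0] + cs_b;    /* 3*cs_b */
    off[2] = cs_b << 2;        /* 4*cs_b */
    off[3] = off[2] + cs_b;    /* 5*cs_b */
    off[4] = off[0] + off[2];  /* 6*cs_b */
    off[5] = off[4] + cs_b;    /* 7*cs_b */
}
```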
-
- //reciprocal_diags = _mm256_loadu_ps((float const *)ones);
- //reciprocal_diags = _mm256_broadcast_ss((float const *)&ones);
- alphaReg = _mm256_broadcast_ss((float const *)&alphaVal);
-
- // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
-
- //read first set of 16x8 block of B into registers, where 16 is the blk_height and 8 is the blk_width for B
- mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b);
- //_mm_prefetch((char*)(ptr_l + 0), _MM_HINT_T0);
- //row2 = (cs_l << 1);
- //row4 = (cs_l << 2);
- mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b)));
- //_mm_prefetch((char*)(ptr_l + cs_l), _MM_HINT_T0);
- mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0]));
- //_mm_prefetch((char*)(ptr_l + row2), _MM_HINT_T0);
- mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1]));
- //_mm_prefetch((char*)(ptr_l + row2 + cs_l), _MM_HINT_T0);
- //row6 = row2 + row4;
- mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2]));
- //_mm_prefetch((char*)(ptr_l + row4), _MM_HINT_T0);
- mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3]));
- //_mm_prefetch((char*)(ptr_l + row4 + cs_l), _MM_HINT_T0);
- mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4]));
- //_mm_prefetch((char*)(ptr_l + row6), _MM_HINT_T0);
- mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5]));
- //_mm_prefetch((char*)(ptr_l + row6 + cs_l), _MM_HINT_T0);
-
- //reciprocal_diags = _mm256_loadu_ps((float const *)ones);
-
- //read first set of 16x16 block of L, where 16 is the blk_height and 16 is the blk_width for L
- /*mat_a_cols[0] = _mm256_loadu_ps((float const *)ptr_l);
- ptr_l += cs_l;
- mat_a_cols[1] = _mm256_loadu_ps((float const *)ptr_l);
- ptr_l += cs_l;
- mat_a_cols[2] = _mm256_loadu_ps((float const *)ptr_l);
- ptr_l += cs_l;
- mat_a_cols[3] = _mm256_loadu_ps((float const *)ptr_l);
- ptr_l += cs_l;
- mat_a_cols[4] = _mm256_loadu_ps((float const *)ptr_l);
- ptr_l += cs_l;
- mat_a_cols[5] = _mm256_loadu_ps((float const *)ptr_l);
- ptr_l += cs_l;
- mat_a_cols[6] = _mm256_loadu_ps((float const *)ptr_l);
- ptr_l += cs_l;
- mat_a_cols[7] = _mm256_loadu_ps((float const *)ptr_l);*/
-
- //Shuffle to rearrange/transpose 16x16 block of L into contiguous row-wise registers
- //tmpRegs[0] = _mm256_castps256_ps128(mat_a_cols[0]); //zero latency, no instruction added actually.
- //mat_a_cols_rearr[0] = _mm256_broadcastss_ps(tmpRegs[0]);
- //1st col
- mat_a_cols_rearr[0] = _mm256_broadcast_ss((float const *)(ptr_l+0));
- mat_a_cols_rearr[1] = _mm256_broadcast_ss((float const *)(ptr_l+1));
- mat_a_cols_rearr[3] = _mm256_broadcast_ss((float const *)(ptr_l+2));
- mat_a_cols_rearr[6] = _mm256_broadcast_ss((float const *)(ptr_l+3));
- mat_a_cols_rearr[10] = _mm256_broadcast_ss((float const *)(ptr_l+4));
- mat_a_cols_rearr[15] = _mm256_broadcast_ss((float const *)(ptr_l+5));
- mat_a_cols_rearr[21] = _mm256_broadcast_ss((float const *)(ptr_l+6));
- mat_a_cols_rearr[28] = _mm256_broadcast_ss((float const *)(ptr_l+7));
- //2nd col
- ptr_l += cs_l;
- mat_a_cols_rearr[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
- mat_a_cols_rearr[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
- mat_a_cols_rearr[7] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
- mat_a_cols_rearr[11] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
- mat_a_cols_rearr[16] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
- mat_a_cols_rearr[22] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
- mat_a_cols_rearr[29] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
- //3rd col
- ptr_l += cs_l;
- mat_a_cols_rearr[5] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
- mat_a_cols_rearr[8] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
- mat_a_cols_rearr[12] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
- mat_a_cols_rearr[17] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
- mat_a_cols_rearr[23] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
- mat_a_cols_rearr[30] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
- //4th col
- ptr_l += cs_l;
- mat_a_cols_rearr[9] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
- mat_a_cols_rearr[13] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
- mat_a_cols_rearr[18] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
- mat_a_cols_rearr[24] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
- mat_a_cols_rearr[31] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
- //5th col
- ptr_l += cs_l;
- mat_a_cols_rearr[14] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
- mat_a_cols_rearr[19] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
- mat_a_cols_rearr[25] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
- mat_a_cols_rearr[32] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
- //6th col
- ptr_l += cs_l;
- mat_a_cols_rearr[20] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
- mat_a_cols_rearr[26] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
- mat_a_cols_rearr[33] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
- //7th col
- ptr_l += cs_l;
- mat_a_cols_rearr[27] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
- mat_a_cols_rearr[34] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
- //8th col
- //ptr_l += cs_l;
- //mat_a_cols_rearr[35] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
-
- numCols_b -= 8; // blk_width = 8
-
- //compute reciprocals of L(i,i) and broadcast in registers
- //mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[2]);
- //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_cols_rearr[5], mat_a_cols_rearr[9]);
- //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_cols_rearr[14], mat_a_cols_rearr[20]);
- //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_cols_rearr[27], mat_a_cols_rearr[35]);
-
- //mat_a_diag_inv[1] = _mm256_permute_ps(mat_a_diag_inv[1], 0x55);
- //mat_a_diag_inv[3] = _mm256_permute_ps(mat_a_diag_inv[3], 0x55);
- //mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);
- //mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);
- //mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0x20);
-
- //reciprocal of diagonal elements
- //reciprocal_diags = _mm256_div_ps(reciprocal_diags, mat_a_diag_inv[0]);
-
- //Start loop for cols of B to be processed in size of blk_width
- for (j = 0; j < numCols_b; j += 8)
- {
- ptr_b_dup = ptr_b;
-
- /*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/
-
- ////unpacklow////
- mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
- mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
- mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
- mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
-
- //Rearrange low elements
-#if REARRANGE_SHFL == 1
- mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
- mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
- mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
- mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
-#else
- mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
- mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
- mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
- mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
- mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
- mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
-#endif
- //Merge rearranged low elements into complete rows
- mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
- mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
- mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
- mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
-
- mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg);
- mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg);
- mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg);
- mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg);
-
- ////unpackhigh////
- mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
- mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
- mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
- mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
-
- //Rearrange high elements
-#if REARRANGE_SHFL == 1
- mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
- mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
- mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
- mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
-#else
- mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
- mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
- mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
- mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
- mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
- mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
-#endif
-
- //extract diag a00 from a
- //mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00);
- //mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
-
- //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
- //mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
-
- //Merge rearranged high elements into complete rows
- mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
- mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
- mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
- mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
-
- mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg);
- mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg);
- mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg);
- mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg);
-
- //extract diag a11 from a
- //mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55);
- //mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
-
- //(Row1): FMA operations of b1 with elements of indices from (1, 0) until (7, 0)
- mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
- mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
- mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
- mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
- mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
- mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
- //mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]);
-
- //extract diag a22 from a
- //mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA);
- //mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
-
- //(Row2): FMA operations of b2 with elements of indices from (2, 0) until (7, 0)
- mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
- mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
- mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
- mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
- mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
- //mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]);
-
- //extract diag a33 from a
- //mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF);
- //mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
-
- //(Row3): FMA operations of b3 with elements of indices from (3, 0) until (7, 0)
- mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
- mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
- mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
- mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(3, 3) element with 4th row elements of B
- //mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]);
-
- //extract diag a44 from a
- //mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00);
- //mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
-
- //(Row4): FMA operations of b4 with elements of indices from (4, 0) until (7, 0)
- mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
- mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
- mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(4, 4) element with 5th row elements of B
- //mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]);
-
- //extract diag a55 from a
- //mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55);
- //mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
-
- //(Row5): FMA operations of b5 with elements of indices from (5, 0) until (7, 0)
- mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
- mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(5, 5) element with 6th row elements of B
- //mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]);
-
- //extract diag a66 from a
- //mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA);
- //mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
-
- //(Row6): FMA operations of b6 with elements of indices from (6, 0) until (7, 0)
- mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(6, 6) element with 7th row elements of B
- //mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]);
-
- //extract diag a77 from a
- //mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF);
- //mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
-
- //(Row7): FMA operations of b7 with elements of index (7, 0)
- mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(7, 7) element with 8th row elements of B
- //mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]);
-
- //--> Transpose and store results of columns of B block <--//
- ////unpacklow////
- mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
- mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
- mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
- mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
-
- //Rearrange low elements
-#if REARRANGE_SHFL == 1
- mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44);
- mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE);
- mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44);
- mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE);
-#else
- mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E);
- mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E);
- mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC);
- mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33);
- mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC);
- mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33);
-#endif
- //Merge rearranged low elements into complete rows
- mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20);
- mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31);
- mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20);
- mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31);
-
- ////unpackhigh////
- mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
- mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
- mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
- mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
-
- //Rearrange high elements
-#if REARRANGE_SHFL == 1
- mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
- mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
- mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
- mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
-#else
- mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
- mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
- mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
- mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
- mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
- mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
-#endif
-
- //Merge rearranged high elements into complete rows
- mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
- mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
- mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
- mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
-
- //Read next set of B columns
- ptr_b += (cs_b + cs_b_offset[5]);
- mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b);
- mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b)));
- mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0]));
- mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1]));
- mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2]));
- mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3]));
- mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4]));
- mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5]));
-
- //Store the computed B columns
- _mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]);
- _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]);
-
- //end loop of cols
- }
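This unit-diagonal variant runs the same forward substitution but drops every reciprocal multiply, since L(i,i) == 1 (hence all the commented-out mat_a_diag_inv and reciprocal_diags work above). A scalar reference for one column, as a sketch:

```c
/* Unit-diagonal forward substitution: identical update pattern, but
 * with L(i,i) == 1 the per-row multiply by 1/L(i,i) disappears. */
static void forward_subst_8_unitdiag(const float L[8][8], float x[8])
{
    for (int i = 0; i < 8; i++)
        for (int j = 0; j < i; j++)
            x[i] -= L[i][j] * x[j];   /* fnmadd only, no divide */
}
```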
-
- //Last block trsm processing
- ptr_b_dup = ptr_b;
-
- /*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/
-
- ////unpacklow////
- mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
- mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
- mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
- mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
-
- //Rearrange low elements
-#if REARRANGE_SHFL == 1
- mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
- mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
- mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
- mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
-#else
- mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
- mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
- mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
- mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
- mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
- mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
-#endif
- //Merge rearranged low elements into complete rows
- mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
- mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
- mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
- mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
-
- mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg);
- mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg);
- mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg);
- mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg);
-
- ////unpackhigh////
- mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
- mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
- mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
- mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
-
- //Rearrange high elements
-#if REARRANGE_SHFL == 1
- mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
- mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
- mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
- mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
-#else
- mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
- mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
- mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
- mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
- mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
- mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
-#endif
-
- //extract diag a00 from a
- //mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00);
- //mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
-
- //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
- //mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
-
- //Merge rearranged high elements into complete rows
- mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
- mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
- mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
- mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
-
- mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg);
- mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg);
- mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg);
- mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg);
-
- //extract diag a11 from a
- //mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55);
- //mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
-
- //(Row1): FMA operations of b1 with elements of indices from (1, 0) until (7, 0)
- mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
- mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
- mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
- mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
- mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
- mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
- //mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]);
-
- //extract diag a22 from a
- //mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA);
- //mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
-
- //(Row2): FMA operations of b2 with elements of indices from (2, 0) until (7, 0)
- mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
- mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
- mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
- mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
- mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
- //mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]);
-
- //extract diag a33 from a
- //mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF);
- //mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
-
- //(Row3): FMA operations of b3 with elements of indices from (3, 0) until (7, 0)
- mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
- mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
- mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
- mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(3, 3) element with 4th row elements of B
- //mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]);
-
- //extract diag a44 from a
- //mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00);
- //mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
-
- //(Row4): FMA operations of b4 with elements of indices from (4, 0) until (7, 0)
- mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
- mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
- mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(4, 4) element with 5th row elements of B
- //mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]);
-
- //extract diag a55 from a
- //mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55);
- //mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
-
- //(Row5): FMA operations of b5 with elements of indices from (5, 0) until (7, 0)
- mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
- mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(5, 5) element with 6th row elements of B
- //mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]);
-
- //extract diag a66 from a
- //mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA);
- //mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
-
- //(Row6): FMA operations of b6 with elements of indices from (6, 0) until (7, 0)
- mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(6, 6) element with 7th row elements of B
- //mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]);
-
- //extract diag a77 from a
- //mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF);
- //mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
-
- //(Row7): FMA operations of b7 with elements of index (7, 0)
- mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(7, 7) element with 8th row elements of B
- //mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]);
-
- //--> Transpose and store results of columns of B block <--//
- ////unpacklow////
- mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
- mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
- mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
- mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
-
- //Rearrange low elements
REARRANGE_SHFL == 1 - mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44); - mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE); - mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44); - mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE); -#else - mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E); - mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E); - mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC); - mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33); - mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC); - mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20); - mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31); - mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20); - mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]); - - //end loop of cols -} - -static void blis_strsm_microkernel_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b) -{ - //float ones = 1.0; - int j; - int cs_b_offset[6]; - //int row2, row4, row6; - float *ptr_b_dup; - - //70 number of ymm(256 bits) registers used - 
__m256 mat_b_col[8]; - __m256 mat_b_rearr[8]; - __m256 mat_a_cols[8]; - __m256 mat_a_cols_rearr[36]; - //__m256 mat_a_diag_inv[8]; - //__m256 reciprocal_diags; - - cs_b_offset[0] = (cs_b << 1); - cs_b_offset[1] = cs_b + cs_b_offset[0]; - cs_b_offset[2] = (cs_b << 2); - cs_b_offset[3] = cs_b + cs_b_offset[2]; - cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; - cs_b_offset[5] = cs_b + cs_b_offset[4]; - - //reciprocal_diags = _mm256_loadu_ps((float const *)ones); - //reciprocal_diags = _mm256_broadcast_ss((float const *)&ones); - - // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // - - //read first set of 16x8 block of B into registers, where 16 is the blk_height and 8 is the blk_width for B - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b); - //_mm_prefetch((char*)(ptr_l + 0), _MM_HINT_T0); - //row2 = (cs_l << 1); - //row4 = (cs_l << 2); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b))); - //_mm_prefetch((char*)(ptr_l + cs_l), _MM_HINT_T0); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0])); - //_mm_prefetch((char*)(ptr_l + row2), _MM_HINT_T0); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1])); - //_mm_prefetch((char*)(ptr_l + row2 + cs_l), _MM_HINT_T0); - //row6 = row2 + row4; - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2])); - //_mm_prefetch((char*)(ptr_l + row4), _MM_HINT_T0); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3])); - //_mm_prefetch((char*)(ptr_l + row4 + cs_l), _MM_HINT_T0); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4])); - //_mm_prefetch((char*)(ptr_l + row6), _MM_HINT_T0); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5])); - //_mm_prefetch((char*)(ptr_l + row6 + cs_l), _MM_HINT_T0); - - //reciprocal_diags = _mm256_loadu_ps((float const *)ones); - - //read first set of 16x16 block of L, where 16 is the blk_height and 16 is the blk_width for L - /*mat_a_cols[0] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[1] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[2] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[3] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[4] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[5] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[6] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[7] = _mm256_loadu_ps((float const *)ptr_l);*/ - - //Shuffle to rearrange/transpose 16x16 block of L into contiguous row-wise registers - //tmpRegs[0] = _mm256_castps256_ps128(mat_a_cols[0]); //zero latency, no instruction added actually. 
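The unrolled broadcast sequence that follows loads the 8x8 lower triangle of L into 36 broadcast registers in row-major packed-triangular order: element L(i,j) lands at index i*(i+1)/2 + j, which is why the first column occupies the triangular numbers 0, 1, 3, 6, 10, 15, 21, 28. (In this unit-diagonal variant the final diagonal load, index 35, stays commented out because the diagonal is implicitly one.) A minimal C sketch of the equivalent gather loop, assuming unit row stride (rs_l == 1) as the kernel itself does; tril_idx and load_tril_8x8 are illustrative names, not identifiers from this source:

    #include <immintrin.h>

    /* Packed row-major index of the lower triangle: L(i,j) -> i*(i+1)/2 + j. */
    static inline int tril_idx(int i, int j) { return i * (i + 1) / 2 + j; }

    /* Equivalent of the unrolled broadcasts: walk the columns of a
       column-major 8x8 lower-triangular block (column stride cs_l) and
       broadcast each element across all eight lanes of a ymm register. */
    static void load_tril_8x8(const float *ptr_l, int cs_l, __m256 out[36])
    {
        for (int j = 0; j < 8; j++)         /* columns of L              */
            for (int i = j; i < 8; i++)     /* rows on/below the diagonal */
                out[tril_idx(i, j)] = _mm256_broadcast_ss(ptr_l + j * cs_l + i);
    }

This indexing keeps every L(i,j) the substitution needs at a fixed register index, so no further shuffling of A is required inside the column loop.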
- //mat_a_cols_rearr[0] = _mm256_broadcastss_ps(tmpRegs[0]); - //1st col - mat_a_cols_rearr[0] = _mm256_broadcast_ss((float const *)(ptr_l+0)); - mat_a_cols_rearr[1] = _mm256_broadcast_ss((float const *)(ptr_l+1)); - mat_a_cols_rearr[3] = _mm256_broadcast_ss((float const *)(ptr_l+2)); - mat_a_cols_rearr[6] = _mm256_broadcast_ss((float const *)(ptr_l+3)); - mat_a_cols_rearr[10] = _mm256_broadcast_ss((float const *)(ptr_l+4)); - mat_a_cols_rearr[15] = _mm256_broadcast_ss((float const *)(ptr_l+5)); - mat_a_cols_rearr[21] = _mm256_broadcast_ss((float const *)(ptr_l+6)); - mat_a_cols_rearr[28] = _mm256_broadcast_ss((float const *)(ptr_l+7)); - //2nd col - ptr_l += cs_l; - mat_a_cols_rearr[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_cols_rearr[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_cols_rearr[7] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_cols_rearr[11] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_cols_rearr[16] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[22] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[29] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //3rd col - ptr_l += cs_l; - mat_a_cols_rearr[5] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_cols_rearr[8] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_cols_rearr[12] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_cols_rearr[17] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[23] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[30] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //4rth col - ptr_l += cs_l; - mat_a_cols_rearr[9] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_cols_rearr[13] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_cols_rearr[18] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[24] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[31] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //5th col - ptr_l += cs_l; - mat_a_cols_rearr[14] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_cols_rearr[19] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[25] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[32] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //6th col - ptr_l += cs_l; - mat_a_cols_rearr[20] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[26] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[33] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //7th col - ptr_l += cs_l; - mat_a_cols_rearr[27] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[34] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //8th col - //ptr_l += cs_l; - //mat_a_cols_rearr[35] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - - numCols_b -= 8; // blk_width = 8 - - //compute reciprocals of L(i,i) and broadcast in registers - //mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[2]); - //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_cols_rearr[5], mat_a_cols_rearr[9]); - //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_cols_rearr[14], mat_a_cols_rearr[20]); - //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_cols_rearr[27], mat_a_cols_rearr[35]); - - //mat_a_diag_inv[1] = _mm256_permute_ps(mat_a_diag_inv[1], 0x55); - //mat_a_diag_inv[3] = _mm256_permute_ps(mat_a_diag_inv[3], 0x55); - //mat_a_diag_inv[0] = 
_mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC); - //mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC); - //mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0x20); - - //reciprocal of diagnol elements - //reciprocal_diags = _mm256_div_ps(reciprocal_diags, mat_a_diag_inv[0]); - - //Start loop for cols of B to be processed in size of blk_width - for (j = 0; j < numCols_b; j += 8) - { - ptr_b_dup = ptr_b; - - /*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/ - - ////unpacklow//// - mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); - mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); - mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); - mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - ////unpackhigh//// - mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); - mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); - mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); - mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - - //extract diag a00 from a - //mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00); - //mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - //mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); - - //Merge rearranged high elements into complete rows - mat_b_rearr[2] = 
_mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - //extract diag a11 from a - //mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55); - //mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - //mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); - - //extract diag a22 from a - //mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA); - //mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - //mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); - - //extract diag a33 from a - //mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF); - //mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - //mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); - - //extract diag a44 from a - //mat_a_diag_inv[4] = 
_mm256_permute_ps(reciprocal_diags, 0x00); - //mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - //mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); - - //extract diag a55 from a - //mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55); - //mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - //mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); - - //extract diag a66 from a - //mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA); - //mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - //mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); - - //extract diag a77 from a - //mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF); - //mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - //mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); - - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44); - mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE); - mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44); - mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE); -#else - mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E); - mat_a_cols[7] = 
_mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E); - mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC); - mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33); - mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC); - mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20); - mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31); - mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20); - mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - //Read next set of B columns - ptr_b += (cs_b + cs_b_offset[5]); - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b))); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0])); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1])); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2])); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3])); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4])); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5])); - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]); - //end loop of cols - } - - //Last block trsm processing - ptr_b_dup = ptr_b; - - /*Shuffle to 
rearrange/transpose 16x8 block of B into contiguous row-wise registers*/ - - ////unpacklow//// - mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); - mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); - mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); - mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - ////unpackhigh//// - mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); - mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); - mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); - mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - - //extract diag a00 from a - //mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00); - //mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - //mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); - - //Merge rearranged high elements into complete rows - mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - //extract diag a11 from a - //mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55); - //mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - - //(Row1): FMA operations of b1 with 
elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - //mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); - - //extract diag a22 from a - //mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA); - //mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - //mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); - - //extract diag a33 from a - //mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF); - //mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - //mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); - - //extract diag a44 from a - //mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00); - //mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], 
mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - //mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); - - //extract diag a55 from a - //mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55); - //mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - //mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); - - //extract diag a66 from a - //mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA); - //mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - //mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); - - //extract diag a77 from a - //mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF); - //mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - //mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); - - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44); - mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE); - mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44); - mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE); -#else - mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E); - mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E); - mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC); - mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33); - mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC); - mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20); - mat_a_cols[4] = 
_mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31); - mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20); - mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]); - //end loop of cols -} - -static void blis_strsm_microkernel(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b) -{ - float ones = 1.0; - int j; - int cs_b_offset[6]; - //int row2, row4, row6; - float *ptr_b_dup; - - //70 number of ymm(256 bits) registers used - __m256 mat_b_col[8]; - __m256 mat_b_rearr[8]; - __m256 mat_a_cols[8]; - __m256 mat_a_cols_rearr[36]; - __m256 mat_a_diag_inv[8]; - __m256 reciprocal_diags; - - cs_b_offset[0] = (cs_b << 1); - cs_b_offset[1] = cs_b + cs_b_offset[0]; - cs_b_offset[2] = (cs_b << 2); - cs_b_offset[3] = cs_b + cs_b_offset[2]; - cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; - cs_b_offset[5] = cs_b + cs_b_offset[4]; - - //reciprocal_diags = _mm256_loadu_ps((float const *)ones); - reciprocal_diags = _mm256_broadcast_ss((float const *)&ones); - - // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // - - //read first set of 16x8 block of B into registers, where 16 is the blk_height and 8 is the blk_width for B - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b); - //_mm_prefetch((char*)(ptr_l + 0), _MM_HINT_T0); - //row2 = (cs_l << 1); - //row4 = (cs_l << 2); - mat_b_col[1] = _mm256_loadu_ps((float 
const *)(ptr_b + (cs_b))); - //_mm_prefetch((char*)(ptr_l + cs_l), _MM_HINT_T0); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0])); - //_mm_prefetch((char*)(ptr_l + row2), _MM_HINT_T0); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1])); - //_mm_prefetch((char*)(ptr_l + row2 + cs_l), _MM_HINT_T0); - //row6 = row2 + row4; - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2])); - //_mm_prefetch((char*)(ptr_l + row4), _MM_HINT_T0); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3])); - //_mm_prefetch((char*)(ptr_l + row4 + cs_l), _MM_HINT_T0); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4])); - //_mm_prefetch((char*)(ptr_l + row6), _MM_HINT_T0); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5])); - //_mm_prefetch((char*)(ptr_l + row6 + cs_l), _MM_HINT_T0); - - //reciprocal_diags = _mm256_loadu_ps((float const *)ones); - - //read first set of 16x16 block of L, where 16 is the blk_height and 16 is the blk_width for L - /*mat_a_cols[0] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[1] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[2] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[3] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[4] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[5] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[6] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[7] = _mm256_loadu_ps((float const *)ptr_l);*/ - - //Shuffle to rearrange/transpose 16x16 block of L into contiguous row-wise registers - //tmpRegs[0] = _mm256_castps256_ps128(mat_a_cols[0]); //zero latency, no instruction added actually. 
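Unlike the unit-diagonal kernel above, blis_strsm_microkernel must divide each solved row of B by the corresponding diagonal entry L(i,i). Rather than dividing inside the substitution, it gathers the eight diagonal entries, which sit at packed indices i*(i+1)/2 + i = 0, 2, 5, 9, 14, 20, 27, 35, into a single vector via the unpacklo/blend/permute2f128 sequence below and performs one vector division up front; each later _mm256_mul_ps by a lane of reciprocal_diags is then a cheap multiply. A minimal sketch of the same precomputation done with a scalar gather, under the same packed-index layout (diag_reciprocal is an illustrative name):

    #include <immintrin.h>

    /* Invert the eight diagonal entries of the packed 8x8 triangle once.
       Each 'rearr' register holds one broadcast element, so lane 0 is the
       scalar value of L(i,i). */
    static __m256 diag_reciprocal(const __m256 rearr[36])
    {
        float d[8];
        for (int i = 0; i < 8; i++)
            d[i] = _mm256_cvtss_f32(rearr[i * (i + 1) / 2 + i]);
        return _mm256_div_ps(_mm256_set1_ps(1.0f), _mm256_loadu_ps(d));
    }

The kernel's register-only version avoids the round trip through memory, but the result is the same: one division per 8x8 block instead of one per solved row.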
- //mat_a_cols_rearr[0] = _mm256_broadcastss_ps(tmpRegs[0]); - //1st col - mat_a_cols_rearr[0] = _mm256_broadcast_ss((float const *)(ptr_l+0)); - mat_a_cols_rearr[1] = _mm256_broadcast_ss((float const *)(ptr_l+1)); - mat_a_cols_rearr[3] = _mm256_broadcast_ss((float const *)(ptr_l+2)); - mat_a_cols_rearr[6] = _mm256_broadcast_ss((float const *)(ptr_l+3)); - mat_a_cols_rearr[10] = _mm256_broadcast_ss((float const *)(ptr_l+4)); - mat_a_cols_rearr[15] = _mm256_broadcast_ss((float const *)(ptr_l+5)); - mat_a_cols_rearr[21] = _mm256_broadcast_ss((float const *)(ptr_l+6)); - mat_a_cols_rearr[28] = _mm256_broadcast_ss((float const *)(ptr_l+7)); - //2nd col - ptr_l += cs_l; - mat_a_cols_rearr[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_cols_rearr[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_cols_rearr[7] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_cols_rearr[11] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_cols_rearr[16] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[22] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[29] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //3rd col - ptr_l += cs_l; - mat_a_cols_rearr[5] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_cols_rearr[8] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_cols_rearr[12] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_cols_rearr[17] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[23] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[30] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //4rth col - ptr_l += cs_l; - mat_a_cols_rearr[9] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_cols_rearr[13] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_cols_rearr[18] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[24] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[31] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //5th col - ptr_l += cs_l; - mat_a_cols_rearr[14] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_cols_rearr[19] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[25] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[32] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //6th col - ptr_l += cs_l; - mat_a_cols_rearr[20] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[26] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[33] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //7th col - ptr_l += cs_l; - mat_a_cols_rearr[27] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[34] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //8th col - ptr_l += cs_l; - mat_a_cols_rearr[35] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - - numCols_b -= 8; // blk_width = 8 - - //compute reciprocals of L(i,i) and broadcast in registers - mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[2]); - mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_cols_rearr[5], mat_a_cols_rearr[9]); - mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_cols_rearr[14], mat_a_cols_rearr[20]); - mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_cols_rearr[27], mat_a_cols_rearr[35]); - - //mat_a_diag_inv[1] = _mm256_permute_ps(mat_a_diag_inv[1], 0x55); - //mat_a_diag_inv[3] = _mm256_permute_ps(mat_a_diag_inv[3], 0x55); - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], 
mat_a_diag_inv[1], 0xCC); - mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC); - mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0x20); - - //reciprocal of diagnol elements - reciprocal_diags = _mm256_div_ps(reciprocal_diags, mat_a_diag_inv[0]); - - //Start loop for cols of B to be processed in size of blk_width - for (j = 0; j < numCols_b; j += 8) - { - ptr_b_dup = ptr_b; - - /*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/ - - ////unpacklow//// - mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); - mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); - mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); - mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - ////unpackhigh//// - mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); - mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); - mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); - mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - - //extract diag a00 from a - mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00); - mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); - - //Merge rearranged high elements into complete rows - mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - 
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - //extract diag a11 from a - mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55); - mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); - - //extract diag a22 from a - mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA); - mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); - - //extract diag a33 from a - mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF); - mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); - - //extract diag a44 from a - mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00); - mat_a_diag_inv[4] = 
_mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); - - //extract diag a55 from a - mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55); - mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); - - //extract diag a66 from a - mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA); - mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); - - //extract diag a77 from a - mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF); - mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); - - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44); - mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE); - mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44); - mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE); -#else - mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E); - mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E); - mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], 
mat_a_cols[6], 0xCC); - mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33); - mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC); - mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20); - mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31); - mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20); - mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - //Read next set of B columns - ptr_b += (cs_b + cs_b_offset[5]); - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b))); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0])); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1])); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2])); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3])); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4])); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5])); - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]); - //end loop of cols - } - - //Last block trsm processing - ptr_b_dup = ptr_b; - - /*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/ - - ////unpacklow//// - 
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); - mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); - mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); - mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - ////unpackhigh//// - mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); - mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); - mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); - mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - - //extract diag a00 from a - mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00); - mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); - - //Merge rearranged high elements into complete rows - mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - //extract diag a11 from a - mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55); - mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], 
mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); - - //extract diag a22 from a - mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA); - mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); - - //extract diag a33 from a - mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF); - mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); - - //extract diag a44 from a - mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00); - mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) - 
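/* [Editorial note, not part of the original diff] The (RowK) groups in
   this block are the steps of a column-wise forward substitution for
   L * X = B with L lower-triangular: the FNMADDs subtract L(k,j) * b_j
   from every row below row j (d = c - a*b), and each row is then scaled
   by the precomputed reciprocal of its diagonal element, so one
   up-front _mm256_div_ps replaces eight per-row divisions. A scalar
   sketch of one 8-row block, assuming a hypothetical inv_diag[k]
   holding 1/L(k,k):

       for (int k = 0; k < 8; k++) {
           for (int j = 0; j < k; j++)
               b[k] -= L[k][j] * b[j];   // FNMADD step
           b[k] *= inv_diag[k];          // scale by 1/L(k,k)
       }
*/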
- //Perform mul operation of reciprocal of L(4, 4) element with 4th row elements of B
- mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]);
-
- //extract diag a55 from a
- mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55);
- mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
-
- //(Row5): FMA operations of b5 with elements of indices from (5, 0) until (7, 0)
- mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
- mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
- mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]);
-
- //extract diag a66 from a
- mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA);
- mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
-
- //(Row6): FMA operations of b6 with elements of indices from (6, 0) until (7, 0)
- mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
- mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]);
-
- //extract diag a77 from a
- mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF);
- mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
-
- //(Row7): FMA operations of b7 with elements of index (7, 0)
- mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
- mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]);
-
- //--> Transpose and store results of columns of B block <--//
- ////unpacklow////
- mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
- mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
- mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
- mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
-
- //Rearrange low elements
-#if REARRANGE_SHFL == 1
- mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44);
- mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE);
- mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44);
- mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE);
-#else
- mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E);
- mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E);
- mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC);
- mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33);
- mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC);
- mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33);
-#endif
- //Merge rearranged low elements into complete rows
- mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20);
- mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31);
- mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20);
- mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5],
mat_a_cols[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]); - //end loop of cols -} - -#if OPT_CACHE_BLOCKING_L1 //new intrinsic kernels -static void trsm_XAtB_block_allSmallSizedMatrices(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b) -{ - float ones = 1.0; - int i, i1, i2, i3, i4, j, k, l, r; - int cs_b_offset[7]; - int cs_l_offset[7]; - float *ptr_b_dup, *ptr_l_dup; - - //57 number of ymm(256 bits) registers used - __m256 mat_b_col[8]; - __m256 mat_b_rearr[8]; - __m256 mat_a_blk_elems[8]; - __m256 mat_a_diag_inv[8]; - __m256 reciprocal_diags[2]; - - reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones)); - - // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // - - //L matrix offsets - cs_l_offset[0] = (cs_l << 1); - cs_l_offset[1] = cs_l + cs_l_offset[0]; - cs_l_offset[2] = (cs_l << 2); - cs_l_offset[3] = cs_l + cs_l_offset[2]; - cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; - cs_l_offset[5] = cs_l + cs_l_offset[4]; - cs_l_offset[6] = (cs_l_offset[5] + cs_l); - - //read diag elems of L 16x16 block - mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l); - mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l); - mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]); - mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]); - mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]); - 
mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]);
- mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]);
- mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]);
-
- cs_b_offset[0] = (cs_b << 1);
- cs_b_offset[1] = cs_b + cs_b_offset[0];
- cs_b_offset[2] = (cs_b << 2);
- cs_b_offset[3] = cs_b + cs_b_offset[2];
- cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
- cs_b_offset[5] = cs_b + cs_b_offset[4];
- cs_b_offset[6] = (cs_b_offset[5] + cs_b);
-
- reciprocal_diags[1] = reciprocal_diags[0];
-
- //pack first 8 diags together
- mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1
- mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3
- mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5
- mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7
- mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3
- mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7
- mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7
-
- //reciprocal of diagonal elements 0,1,2,3,4,5,6,7
- reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]);
-
- //extract diag a00 from a
- mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
- mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
- //mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]);
- //extract diag a11 from a
- mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
- mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
- //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]);
- //extract diag a22 from a
- mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
- mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
- //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]);
- //extract diag a33 from a
- mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
- mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
- //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]);
- //extract diag a44 from a
- mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
- mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
- //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]);
- //extract diag a55 from a
- mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
- mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
- //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]);
- //extract diag a66 from a
- mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
- mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
- //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]);
- //extract diag a77 from a
- mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
- mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
- //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7],
mat_a_diag_inv[7]); - - - /***************** first set of 8 rows of B processing starts *****************/ - ptr_b_dup = ptr_b; - i = 0; - for (j = 0; j < numCols_b; j += 8) - { - /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A - //read 8x8 block of B into registers - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], mat_a_diag_inv[0]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], mat_a_diag_inv[1]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], 
mat_b_col[6]);//d = c - (a*b)
- mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
- mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], mat_a_diag_inv[2]);
-
- mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
- mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
- mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
- mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
- mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
-
- //(Row3): FMA operations of b3 with elements of indices from (3, 0) until (7, 0)
- mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b)
- mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b)
- mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b)
- mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b)
- mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(3, 3) element with 4th row elements of B
- mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], mat_a_diag_inv[3]);
-
- mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
- mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
- mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
- mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
-
- //(Row4): FMA operations of b4 with elements of indices from (4, 0) until (7, 0)
- mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b)
- mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b)
- mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b)
- mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(4, 4) element with 4th row elements of B
- mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], mat_a_diag_inv[4]);
-
- mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
- mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
- mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
-
- //(Row5): FMA operations of b5 with elements of indices from (5, 0) until (7, 0)
- mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b)
- mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b)
- mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
- mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], mat_a_diag_inv[5]);
-
- mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
- mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
-
- //(Row6): FMA operations of b6 with elements of indices from (6, 0) until (7, 0)
- mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b)
- mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
- mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], mat_a_diag_inv[6]);
-
- mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
-
- //(Row7): FMA operations of b7 with elements of index (7, 0)
- mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
- mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], mat_a_diag_inv[7]);
-
- ////////////////////////////////////////////////////////////////////////////////
-
- //Store the computed B columns
- _mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]);
- _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]);
-
- //i += cs_b_offset[6];
- //ptr_b_dup += cs_b_offset[6];
- i += 8;
- ptr_b_dup += 8;
- }
-
- //c = 0;
- /***************** first set of 8 rows of B processing done *****************/
- ptr_b_dup = ptr_b;
- i3 = 0;
- i1 = 0;
- //Start loop for rows of B to be processed in size of blk_width
- for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row
- {
- ptr_l += 8;
- //ptr_b += j;
- //ptr_b_dup += 8;
- ptr_b_dup += cs_b_offset[6];
- i1 += cs_b_offset[6];
-
- //Read next 8x8 block of A to get diag elements
- i3 += cs_l_offset[6];
- mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l + i3);
- mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l);
- mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]);
- mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]);
- mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]);
- mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]);
- mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]);
- mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]);
-
- //pack 8 diags of A together
- reciprocal_diags[0] = reciprocal_diags[1];
- mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1
- mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3
- mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5
- mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7
- mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3
- mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7
- mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7
-
- //reciprocal of diagonal elements of A :- 0,1,2,3,4,5,6,7
- reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]);
-
- //extract diag a00 from a
- mat_a_diag_inv[0] =
_mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - //mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]); - - //extract diag a11 from a - mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); - - //extract diag a22 from a - mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); - - //extract diag a33 from a - mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); - - //extract diag a44 from a - mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); - - //extract diag a55 from a - mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); - - //extract diag a66 from a - mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); - - //extract diag a77 from a - mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); - - for (r = 0; r < numCols_b; r += GEMM_BLK_V1) - { -#if GEMM_ACCUM_A - i = i1 + r; - //Read 8 cols of B columns of Block-to-be-solved - mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); -#endif - i = 0; - i2 = 0; - for (l = 0; l < j; l += 8) // move across m - { - //for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) - { - /////////////////// Partial Lower 8x8 block trsm of B - ptr_l_dup = ptr_l; - i4 = i2 + r; - //Read current 8 cols of B columns from specified 8x8 current-block of B - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + 
i4 + cs_b_offset[2])); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); - - //Broadcast A8,0 to A15,0 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - i4 = k >> 3; - ptr_l_dup += cs_l; - -#if GEMM_ACCUM_A - //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]); - mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]); - mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]); - mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]); - mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]); - mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]); - mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]); - mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]); -#endif - //Broadcast A21 to A71 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - 
(a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,2 to A15,2 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], 
mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,3 to A15,3 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,4 to A15,4 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], 
mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,5 to A15,5 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], 
mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,6 to A15,6 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,7 to A15,7 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const 
*)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) -#endif - //end loop of cols - } - i2 += cs_b_offset[6]; - i += cs_l_offset[6]; - } - //trsm solve - - k = 0; - //for (i2 = 0; i2 < numCols_b; i2 += 8) - { - i2 = i1 + r; - /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A -#if !GEMM_ACCUM_A - //Read 8 cols of B columns of Block-to-be-solved - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i2); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i2)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i2)); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i2)); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i2)); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i2)); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i2)); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i2)); -#endif - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - -#if GEMM_ACCUM_A - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); -#else - mat_b_rearr[0] = 
_mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]); - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); -#endif - -#if GEMM_ACCUM_A - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]); - mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]); - mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]); - mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]); - mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]); - mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]); - mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A21 to A71 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A32 to A72 to 
registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], 
mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A76 to register - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); - - //////////////////////////////////////////////////////////////////////////////// - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup + r, mat_b_rearr[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+r), mat_b_rearr[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + r), mat_b_rearr[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + r), mat_b_rearr[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + r), mat_b_rearr[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + r), mat_b_rearr[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + r), mat_b_rearr[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + r), mat_b_rearr[7]); - //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); - k++; - } - } - } //numRows of A - ///////////////////loop ends ///////////////////// -} - -static void trsm_XAtB_block_allSmallSizedMatrices_alpha(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha) -{ - float ones = 1.0; - int i, i1, i2, i3, i4, j, k, l, r; - int cs_b_offset[7]; - int cs_l_offset[7]; - float *ptr_b_dup, *ptr_l_dup; - - //57 number of ymm(256 bits) registers used - __m256 mat_b_col[8]; - __m256 mat_b_rearr[8]; - __m256 mat_a_blk_elems[8]; - __m256 mat_a_diag_inv[8]; - __m256 reciprocal_diags[2]; - __m256 alphaReg; - - reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones)); - alphaReg = _mm256_broadcast_ss((float const *)&alpha); - - // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // - - //L matrix offsets - cs_l_offset[0] = (cs_l << 1); - cs_l_offset[1] = cs_l + cs_l_offset[0]; - cs_l_offset[2] = (cs_l << 2); - cs_l_offset[3] = cs_l + cs_l_offset[2]; - cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; - cs_l_offset[5] = cs_l + cs_l_offset[4]; - cs_l_offset[6] = (cs_l_offset[5] + cs_l); - - //read diag elems of L 16x16 block - mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l); - mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l); - mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]); - mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]); - mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]); - mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]); - mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]); - mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]); - - cs_b_offset[0] = (cs_b << 1); - cs_b_offset[1] = cs_b + cs_b_offset[0]; - cs_b_offset[2] = (cs_b << 2); - cs_b_offset[3] = cs_b + cs_b_offset[2]; - cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; - cs_b_offset[5] = cs_b + cs_b_offset[4]; - cs_b_offset[6] = 
(cs_b_offset[5] + cs_b); - - reciprocal_diags[1] = reciprocal_diags[0]; - - //pack first 8 diags together - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1 - mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5 - mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 - - //reciprocal of diagnal elements 0,1,2,3,4,5,6,7 - reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); -#if 0 - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - - //Broadcast A21 to A71 to registers - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); - mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); - mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); - mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); - mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); - mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); - mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); - mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); - mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); - mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); - mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); - mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); - mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); - mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); - mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); - mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); - - //Broadcast A76 to register - mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); -#endif - //extract diag a00 
from a - mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - //mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]); - //extract diag a11 from a - mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); - //extract diag a22 from a - mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); - //extract diag a33 from a - mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); - //extract diag a44 from a - mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); - //extract diag a55 from a - mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); - //extract diag a66 from a - mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); - //extract diag a77 from a - mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); - - - /***************** first set of 8 rows of B processing starts *****************/ - ptr_b_dup = ptr_b; - i = 0; - for (j = 0; j < numCols_b; j += 8) - { - /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A - //read 8x8 block of B into registers - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg); - mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg); - mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg); - mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg); - mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg); - mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg); - mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg); - mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg); - - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 
1st row elements of B - mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], mat_a_diag_inv[0]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], mat_a_diag_inv[1]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], mat_a_diag_inv[2]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b) - 
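/*
 * The unrolled sequence above and below is one 8x8 forward substitution:
 * column k of this B tile is scaled by the precomputed reciprocal of
 * A(k,k), then eliminated from every later column with one broadcast +
 * fnmadd per column. A minimal scalar sketch of the same computation
 * (inv_d and a_ck are illustrative names; column-major A and B with
 * strides cs_l and cs_b, alpha already applied, as in the loads above):
 *
 *   for (k = 0; k < 8; k++) {
 *       float inv_d = 1.0f / ptr_l[k + k * cs_l];        // 1 / A(k,k)
 *       for (r = 0; r < 8; r++)                          // 8 B rows
 *           ptr_b[r + k * cs_b] *= inv_d;                // solve col k
 *       for (c = k + 1; c < 8; c++) {
 *           float a_ck = ptr_l[c + k * cs_l];            // A(c,k)
 *           for (r = 0; r < 8; r++)
 *               ptr_b[r + c * cs_b] -= a_ck * ptr_b[r + k * cs_b];
 *       }
 *   }
 *
 * Each ymm register holds the eight rows of one B column, so both
 * r-loops collapse into a single _mm256_mul_ps / _mm256_fnmadd_ps.
 */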
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], mat_a_diag_inv[3]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], mat_a_diag_inv[4]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], mat_a_diag_inv[5]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], mat_a_diag_inv[6]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], mat_a_diag_inv[7]); - - //////////////////////////////////////////////////////////////////////////////// - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + 
cs_b_offset[0]), mat_b_col[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]); - - //i += cs_b_offset[6]; - //ptr_b_dup += cs_b_offset[6]; - i += 8; - ptr_b_dup += 8; - } - - //c = 0; - /***************** first set of 8 cols of B processing done *****************/ - ptr_b_dup = ptr_b; - i3 = 0; - i1 = 0; - //Start loop for cols of B to be processed in size of blk_width - for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row - { - ptr_l += 8; - //ptr_b += j; - //ptr_b_dup += 8; - ptr_b_dup += cs_b_offset[6]; - i1 += cs_b_offset[6]; - - //Read next 8x8 block of A to get diag elements - i3 += cs_l_offset[6]; - mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l + i3); - mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l); - mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]); - mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]); - mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]); - mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]); - mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]); - mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]); - - //pack 8 diags of A together - reciprocal_diags[0] = reciprocal_diags[1]; - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1 - mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5 - mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 - - //reciprocal of diagnal elements of A :- 0,1,2,3,4,5,6,7 - reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); - - //extract diag a00 from a - mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - //mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]); - - //extract diag a11 from a - mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); - - //extract diag a22 from a - mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); - - //extract diag a33 from a - mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], 
mat_a_diag_inv[3]); - - //extract diag a44 from a - mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); - - //extract diag a55 from a - mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); - - //extract diag a66 from a - mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); - - //extract diag a77 from a - mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); - - for (r = 0; r < numCols_b; r += GEMM_BLK_V1) - { -#if GEMM_ACCUM_A - i = i1 + r; - //Read 8 cols of B columns of Block-to-be-solved - mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg); - mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg); - mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg); - mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg); - mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg); - mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg); - mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg); - mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg); -#endif - i = 0; - i2 = 0; - for (l = 0; l < j; l += 8) // move across m - { - //for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) - { - /////////////////// Partial Lower 8x8 block trsm of B - ptr_l_dup = ptr_l; - i4 = i2 + r; - //Read current 8 cols of B columns from specified 8x8 current-block of B - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); - - //Broadcast A8,0 to A15,0 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const 
*)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - i4 = k >> 3; - ptr_l_dup += cs_l; - -#if GEMM_ACCUM_A - //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]); - mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]); - mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]); - mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]); - mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]); - mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]); - mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]); - mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]); -#endif - //Broadcast A21 to A71 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) 
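/*
 * GEMM_ACCUM_A selects between two accumulation strategies for
 * acc = alpha*B - A_panel*X. When set, alpha*B is loaded into
 * mat_b_rearr[] up front and every product is subtracted in place with
 * fnmadd (d = c - a*b). In the #else form below, the first product
 * group is formed with _mm256_mul_ps, later groups are summed with
 * _mm256_fmadd_ps, and the subtraction from alpha*B happens once, at
 * the solve step. The copied "//d = c - (a*b)" annotations on the
 * _mm256_fmadd_ps lines are stale: fmadd computes d = c + (a*b).
 * Sketch of the two forms (a_bcast, x, acc, sum, alpha_b illustrative):
 *
 *   acc = _mm256_fnmadd_ps(a_bcast, x, acc);       // acc -= a * x
 *
 *   sum = _mm256_mul_ps(a_bcast, x);               // first term
 *   sum = _mm256_fmadd_ps(a_bcast2, x2, sum);      // sum += a * x
 *   acc = _mm256_sub_ps(alpha_b, sum);             // later: alpha*B - sum
 */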
-#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,2 to A15,2 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,3 to A15,3 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - 
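/*
 * Each broadcast group here applies one column of the A panel to all
 * eight accumulators: the scalar A(row, c) is splatted across a ymm and
 * folded in with a single fnmadd (or fmadd). In effect this is the
 * fully unrolled form of (l_panel, x_solved, acc are illustrative
 * names; column-major panel with stride cs_l):
 *
 *   for (c = 0; c < 8; c++) {                      // solved columns
 *       __m256 x = x_solved[c];                    // 8 rows of X(:,c)
 *       for (r = 0; r < 8; r++) {                  // rows of the panel
 *           __m256 a = _mm256_broadcast_ss(&l_panel[r + c * cs_l]);
 *           acc[r] = _mm256_fnmadd_ps(a, x, acc[r]);
 *       }
 *   }
 *
 * Unrolling keeps all eight accumulators live in registers for the
 * whole panel instead of reloading them every iteration.
 */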
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,4 to A15,4 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], 
mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,5 to A15,5 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,6 to A15,6 to registers - 
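/*
 * Zooming out, this j/r/l loop nest is a left-looking blocked solve of
 * X * A^T = alpha * B, advancing along 8-column blocks of X. In outline
 * (illustrative pseudo-C):
 *
 *   for (j = 8; j < numRows_lb; j += 8) {              // current block
 *       for (r = 0; r < numCols_b; r += GEMM_BLK_V1) { // B row panels
 *           acc = alpha * B(r, j:j+7);                 // GEMM_ACCUM_A path
 *           for (l = 0; l < j; l += 8)                 // solved blocks
 *               acc -= X(r, l:l+7) * A(j:j+7, l:l+7)^T;
 *           X(r, j:j+7) = solve_tri_8x8(acc, A(j:j+7, j:j+7));
 *       }
 *   }
 *
 * Every update from already-solved blocks is folded in before the 8x8
 * diagonal solve, so each panel of the current block is loaded,
 * updated, solved and stored exactly once per j.
 */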
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,7 to A15,7 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = 
_mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) -#endif - //end loop of cols - } - i2 += cs_b_offset[6]; - i += cs_l_offset[6]; - } - //trsm solve - - k = 0; - //for (i2 = 0; i2 < numCols_b; i2 += 8) - { - i2 = i1 + r; - /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A -#if !GEMM_ACCUM_A - //Read 8 cols of B columns of Block-to-be-solved - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i2); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i2)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i2)); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i2)); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i2)); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i2)); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i2)); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i2)); - - mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg); - mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg); - mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg); - mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg); - mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg); - mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg); - mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg); - mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg); -#endif - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - -#if GEMM_ACCUM_A - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); -#else - mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], 
mat_b_rearr[0]); - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); -#endif - -#if GEMM_ACCUM_A - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]); - mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]); - mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]); - mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]); - mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]); - mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]); - mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A21 to A71 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[0] = 
_mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) - 
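/*
 * The mat_a_diag_inv[] factors used throughout this solve come from one
 * vector division. The eight diagonal entries are gathered from the
 * eight column loads with three rounds of _mm256_blend_ps (0xAA merges
 * even/odd lanes, 0xCC merges pairs, 0xF0 merges halves), reciprocated
 * all at once, then splatted lane by lane. A minimal sketch (colk holds
 * rows 0..7 of L column k, so its diagonal entry sits in lane k; names
 * illustrative):
 *
 *   __m256 d01   = _mm256_blend_ps(col0, col1, 0xAA);  // lanes 0,1 good
 *   __m256 d23   = _mm256_blend_ps(col2, col3, 0xAA);
 *   __m256 d0123 = _mm256_blend_ps(d01, d23, 0xCC);    // lanes 0..3 good
 *   // ... same for lanes 4..7, then 0xF0 merges the two halves ...
 *   __m256 inv   = _mm256_div_ps(_mm256_set1_ps(1.0f), diag);
 *   // splat lane 2 to all lanes: in-lane permute, then duplicate the
 *   // low 128-bit half across the register
 *   __m256 inv2  = _mm256_permute_ps(inv, 0xAA);
 *   inv2 = _mm256_permute2f128_ps(inv2, inv2, 0x00);
 *
 * Multiplying by these reciprocals turns one divide per solved row into
 * a multiply, trading a little accuracy (two roundings instead of one)
 * for throughput.
 */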
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A76 to register - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); - - //////////////////////////////////////////////////////////////////////////////// - - //Store the computed B columns - - _mm256_storeu_ps((float *)ptr_b_dup + r, mat_b_rearr[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+r), mat_b_rearr[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + r), mat_b_rearr[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + r), mat_b_rearr[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + r), mat_b_rearr[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + r), mat_b_rearr[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + r), mat_b_rearr[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + r), mat_b_rearr[7]); - //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); - k++; - } - } - } //numRows of A - ///////////////////loop ends ///////////////////// -} - -static void trsm_XAtB_block_allSmallSizedMatrices_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b) -{ - //float ones = 1.0; - int i, i1, i2, i3, i4, j, k, l, r; - int cs_b_offset[7]; - int cs_l_offset[7]; - float *ptr_b_dup, *ptr_l_dup; - - //57 number of ymm(256 bits) registers used - __m256 mat_b_col[8]; - __m256 mat_b_rearr[8]; - __m256 mat_a_blk_elems[8]; - //__m256 mat_a_diag_inv[8]; - //__m256 reciprocal_diags[2]; - - // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // - - //L matrix offsets - cs_l_offset[0] = (cs_l << 1); - cs_l_offset[1] = cs_l + cs_l_offset[0]; - cs_l_offset[2] = (cs_l << 2); - cs_l_offset[3] = cs_l + cs_l_offset[2]; - cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; - cs_l_offset[5] = cs_l + cs_l_offset[4]; - cs_l_offset[6] = (cs_l_offset[5] + cs_l); - - cs_b_offset[0] = (cs_b << 1); - cs_b_offset[1] = cs_b + cs_b_offset[0]; - cs_b_offset[2] = (cs_b << 2); - cs_b_offset[3] = cs_b + cs_b_offset[2]; - cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; - cs_b_offset[5] = cs_b + cs_b_offset[4]; - cs_b_offset[6] = (cs_b_offset[5] + cs_b); - - /***************** first set of 8 rows of B processing starts *****************/ - ptr_b_dup = ptr_b; - i = 0; - for (j = 0; j < numCols_b; j += 8) - { - /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A - //read 8x8 block of B into registers - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + 
cs_b_offset[4] + i)); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - //(Row0) - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b) - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b) - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d 
= c - (a*b) - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b) - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b) - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b) - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b) - - //////////////////////////////////////////////////////////////////////////////// - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]); - - //i += cs_b_offset[6]; - //ptr_b_dup += cs_b_offset[6]; - i += 8; - ptr_b_dup += 8; - } - - //c = 0; - /***************** first set of 8 cols of B processing done *****************/ - ptr_b_dup = ptr_b; - i3 = 0; - i1 = 0; - //Start loop for cols of B to be processed in size of blk_width - for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row - { - ptr_l += 8; - //ptr_b += j; - //ptr_b_dup += 8; - ptr_b_dup += cs_b_offset[6]; - i1 += cs_b_offset[6]; - i3 += cs_l_offset[6]; - - i = 0; - i2 = 0; - for (r = 0; r < numCols_b; r += GEMM_BLK_V1) - { -#if GEMM_ACCUM_A - i = i1 + r; - //Read 8 cols of B columns of Block-to-be-solved - mat_b_rearr[0] 
= _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); -#endif - i = 0; - i2 = 0; - for (l = 0; l < j; l += 8) // move across m - { - //for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) - { - /////////////////// Partial Lower 8x8 block trsm of B - ptr_l_dup = ptr_l; - i4 = i2 + r; - //Read current 8 cols of B columns from specified 8x8 current-block of B - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); - - //Broadcast A8,0 to A15,0 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - i4 = k >> 3; - ptr_l_dup += cs_l; - -#if GEMM_ACCUM_A - //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]); - mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]); - mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]); - mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]); - mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]); - mat_b_rearr[5] = 
_mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]); - mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]); - mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]); -#endif - //Broadcast A21 to A71 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,2 to A15,2 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = 
_mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,3 to A15,3 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) - 
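/*
 * A note on the two compile-time paths above (inferred from the
 * surrounding kernel; a sketch of the intent, not a spec):
 *
 *   GEMM_ACCUM_A != 0 : partial products of the already-solved B rows
 *     are subtracted directly from the live right-hand-side registers,
 *     mat_b_rearr[r] -= a_elem * mat_b_col[c], via _mm256_fnmadd_ps,
 *     which computes d = c - (a * b).
 *
 *   GEMM_ACCUM_A == 0 : the first product is formed with _mm256_mul_ps
 *     and subsequent ones are accumulated with _mm256_fmadd_ps, which
 *     computes d = c + (a * b); the "d = c - (a*b)" annotations copied
 *     onto that path describe the fnmadd path, not fmadd. The running
 *     sum is then subtracted from the freshly loaded B block in one
 *     _mm256_sub_ps pass before the 8x8 triangular solve.
 */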
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,4 to A15,4 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,5 to A15,5 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if 
GEMM_ACCUM_A - //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,6 to A15,6 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c 
- (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,7 to A15,7 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) -#endif - //end loop of cols - } - i2 += cs_b_offset[6]; - i += cs_l_offset[6]; - } - //trsm solve - - k = 0; - //for (i2 = 0; i2 < numCols_b; i2 += 8) - { - i2 = i1 + r; - /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A -#if !GEMM_ACCUM_A - //Read 8 cols of B columns of Block-to-be-solved - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i2); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i2)); - mat_b_col[2] 
= _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i2)); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i2)); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i2)); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i2)); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i2)); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i2)); -#endif - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - -#if GEMM_ACCUM_A - //(Row0): already done -#else - mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]); -#endif - -#if GEMM_ACCUM_A - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]); - mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]); - mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]); - mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]); - mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]); - mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]); - mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A21 to A71 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - 
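/*
 * The broadcast/fnmadd ladder around this point is the standard forward
 * substitution for a lower-triangular 8x8 block, vectorized across
 * eight columns of B; no divide appears, consistent with a
 * unit-diagonal variant. A scalar sketch of the same recurrence
 * (illustrative only, x[] standing in for one column of B, l[] for the
 * column-major lower-triangular block):
 *
 *     for (int r = 1; r < 8; r++)          // rows below the diagonal
 *         for (int c = 0; c < r; c++)      // already-solved rows
 *             x[r] -= l[r + c * cs_l] * x[c];
 *
 * Each solved row mat_b_rearr[c] plays the role of x[c], and the
 * broadcast entries A(c+1,c)..A(7,c) supply l[r + c*cs_l].
 */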
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const 
*)(ptr_l + i + 6)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A76 to register - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) - - - //////////////////////////////////////////////////////////////////////////////// - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup + r, mat_b_rearr[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+r), mat_b_rearr[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + r), mat_b_rearr[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + r), mat_b_rearr[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + r), mat_b_rearr[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + r), mat_b_rearr[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + r), mat_b_rearr[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + r), mat_b_rearr[7]); - //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); - k++; - } - } - } //numRows of A - ///////////////////loop ends ///////////////////// -} - -static void trsm_XAtB_block_allSmallSizedMatrices_alpha_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha) -{ - //float ones = 1.0; - int i, i1, i2, i3, i4, j, k, l, r; - int cs_b_offset[7]; - int cs_l_offset[7]; - float *ptr_b_dup, *ptr_l_dup; - - //57 number of ymm(256 bits) registers used - __m256 mat_b_col[8]; - __m256 mat_b_rearr[8]; - __m256 mat_a_blk_elems[8]; - //__m256 mat_a_diag_inv[8]; - //__m256 reciprocal_diags[2]; - __m256 alphaReg; - alphaReg = _mm256_broadcast_ss((float const *)&alpha); - - // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // - - //L matrix offsets - cs_l_offset[0] = (cs_l << 1); - cs_l_offset[1] = cs_l + cs_l_offset[0]; - cs_l_offset[2] = (cs_l << 2); - cs_l_offset[3] = cs_l + cs_l_offset[2]; - cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; - cs_l_offset[5] = cs_l + cs_l_offset[4]; - cs_l_offset[6] = (cs_l_offset[5] + cs_l); - - cs_b_offset[0] = (cs_b << 1); - cs_b_offset[1] = cs_b + cs_b_offset[0]; - cs_b_offset[2] = (cs_b << 2); - cs_b_offset[3] = cs_b + cs_b_offset[2]; - cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; - cs_b_offset[5] = cs_b + cs_b_offset[4]; - cs_b_offset[6] = (cs_b_offset[5] + cs_b); - -#if 0 - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - - //Broadcast A21 to A71 to registers - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); - mat_a_blk_elems[8] = 
_mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); - mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); - mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); - mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); - mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); - mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); - mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); - mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); - mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); - mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); - mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); - mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); - mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); - mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); - mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); - - //Broadcast A76 to register - mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); -#endif - - - /***************** first set of 8 rows of B processing starts *****************/ - ptr_b_dup = ptr_b; - i = 0; - for (j = 0; j < numCols_b; j += 8) - { - /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A - //read 8x8 block of B into registers - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg); - mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg); - mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg); - mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg); - mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg); - mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg); - mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg); - mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg); - - //(Row0) - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_blk_elems[3] = 
_mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b) - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b) - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b) - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 
cs_l_offset[1] + 7)); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b) - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b) - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b) - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b) - - //////////////////////////////////////////////////////////////////////////////// - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]); - - //i += cs_b_offset[6]; - //ptr_b_dup += cs_b_offset[6]; - i += 8; - ptr_b_dup += 8; - } - - //c = 0; - /***************** first set of 8 cols of B processing done *****************/ - ptr_b_dup = ptr_b; - i3 = 0; - i1 = 0; - //Start loop for cols of B to be processed in size of blk_width - for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row - { - ptr_l += 8; - //ptr_b += j; - //ptr_b_dup += 8; - ptr_b_dup += cs_b_offset[6]; - i1 += cs_b_offset[6]; - i3 += cs_l_offset[6]; - - i = 0; - i2 = 0; - for (r = 0; r < numCols_b; r += GEMM_BLK_V1) - { -#if GEMM_ACCUM_A - i = i1 + r; - //Read 8 cols of B columns of Block-to-be-solved - mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + 
i)); - mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg); - mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg); - mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg); - mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg); - mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg); - mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg); - mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg); - mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg); -#endif - i = 0; - i2 = 0; - for (l = 0; l < j; l += 8) // move across m - { - //for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) - { - /////////////////// Partial Lower 8x8 block trsm of B - ptr_l_dup = ptr_l; - i4 = i2 + r; - //Read current 8 cols of B columns from specified 8x8 current-block of B - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); - - //Broadcast A8,0 to A15,0 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - i4 = k >> 3; - ptr_l_dup += cs_l; - -#if GEMM_ACCUM_A - //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]); - mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]); - mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]); - mat_b_rearr[3] = 
_mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]); - mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]); - mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]); - mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]); - mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]); -#endif - //Broadcast A21 to A71 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,2 to A15,2 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row10): FMA operations of b2 with elements of indices from (2, 0) 
uptill (7, 0)
- mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
- mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
- mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
- mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
- mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
- mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
- mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
-#else
- mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
- mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
- mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
- mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
- mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
- mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
- mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
-#endif
- //Broadcast A8,3 to A15,3 to registers
- mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
- mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
- mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
- mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
- mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
- mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
- mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
- mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
- ptr_l_dup += cs_l;
-#if GEMM_ACCUM_A
- //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
- mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
- mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
- mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
- mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
- mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
- mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
- mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
-#else
- mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
- mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
- mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
- mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
- mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
- mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
- mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
-#endif
- //Broadcast A8,4 to A15,4 to registers
- mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
- mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
- mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
- mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
- mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
- mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
- mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
- mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
- ptr_l_dup += cs_l;
-#if GEMM_ACCUM_A
- //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
- mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
- mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
- mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
- mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
- mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
- mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
- mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
-#else
- mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
- mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
- mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
- mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
- mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
- mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
- mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
-#endif
- //Broadcast A8,5 to A15,5 to registers
- mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
- mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
- mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
- mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
- mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
- mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
- mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
- mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
- ptr_l_dup += cs_l;
-#if GEMM_ACCUM_A
- //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
- mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
- mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
- mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
- mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
- mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
- mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
- mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
-#else
- mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
- mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
- mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
- mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
- mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
- mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
- mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
-#endif
- //Broadcast A8,6 to A15,6 to registers
- mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
- mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
- mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
- mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
- mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
- mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
- mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
- mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
- ptr_l_dup += cs_l;
-#if GEMM_ACCUM_A
- //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
- mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
- mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
- mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
- mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
- mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
- mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
- mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
-#else
- mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
- mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
- mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
- mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
- mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
- mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
- mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
-#endif
- //Broadcast A8,7 to A15,7 to registers
- mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
- mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
- mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
- mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
- mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
- mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
- mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
- mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
- ptr_l_dup += cs_l;
-#if GEMM_ACCUM_A
- //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
- mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
- mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
- mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
- mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
- mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
- mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
- mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
-#else
- mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
- mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
- mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
- mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
- mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
- mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
- mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
-#endif
- //end loop of cols
- }
- i2 += cs_b_offset[6];
- i += cs_l_offset[6];
- }
- //trsm solve
-
- k = 0;
- //for (i2 = 0; i2 < numCols_b; i2 += 8)
- {
- i2 = i1 + r;
- /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A
-#if !GEMM_ACCUM_A
- //Read 8 cols of B columns of Block-to-be-solved
- mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i2);
- mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i2));
- mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i2));
- mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i2));
- mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i2));
- mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i2));
- mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i2));
- mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i2));
-
- mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg);
- mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg);
- mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg);
- mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg);
- mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg);
- mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg);
- mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg);
- mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg);
-#endif
- //Broadcast A10 to A70 to registers
- mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1));
- mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
- mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
- mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
- mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
- mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
- mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
- i += cs_l;
-
-#if GEMM_ACCUM_A
- //(Row0): already done
-
-#else
- mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]);
-#endif
-
-#if GEMM_ACCUM_A
- mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
- mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
- mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
- mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
- mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
- mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
-#else
- mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]);
- mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]);
- mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]);
- mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]);
- mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]);
- mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]);
- mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]);
-
- //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
- mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
- mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
- mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
- mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
- mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
- mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
-#endif
- //Broadcast A21 to A71 to registers
- mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
- mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
- mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
- mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
- mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
- mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
- i += cs_l;
-
-
- //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
- mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
- mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
- mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
- mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
- mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
-
- //Broadcast A32 to A72 to registers
- mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
- mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
- mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
- mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
- mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
- i += cs_l;
-
-
- //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
- mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
- mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
- mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
- mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
-
- //Broadcast A43 to A73 to registers
- mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
- mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
- mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
- mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
- i += cs_l;
-
-
- //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
- mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
- mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
- mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
-
- //Broadcast A54 to A74 to registers
- mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
- mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
- mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
- i += cs_l;
-
-
- //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
- mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
- mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
-
- //Broadcast A65 to A75 to registers
- mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
- mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
- i += cs_l;
-
-
- //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
- mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
-
- //Broadcast A76 to register
- mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
-
-
- //(Row7): FMA operations of b7 with elements of index (7, 0)
- mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
-
-
- ////////////////////////////////////////////////////////////////////////////////
-
- //Store the computed B columns
- _mm256_storeu_ps((float *)ptr_b_dup + r, mat_b_rearr[0]);
- _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+r), mat_b_rearr[1]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + r), mat_b_rearr[2]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + r), mat_b_rearr[3]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + r), mat_b_rearr[4]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + r), mat_b_rearr[5]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + r), mat_b_rearr[6]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + r), mat_b_rearr[7]);
- //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k));
- k++;
- }
- }
- } //numRows of A
- ///////////////////loop ends /////////////////////
-}
-#else //rel 1.0 intrisic kernels (NOT OPT_CACHE_BLOCKING_L1)
-static void trsm_XAtB_block_allSmallSizedMatrices(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b)
-{
- float ones = 1.0;
- int i, i1, i2, i3, i4, j, k, l;
- int cs_b_offset[7];
- int cs_l_offset[7];
- float *ptr_b_dup;
-
- //57 number of ymm(256 bits) registers used
- __m256 mat_b_col[8];
- __m256 mat_b_rearr[16][8];
- __m256 mat_a_cols_rearr[8];
- __m256 mat_a_blk_elems[64];
- __m256 mat_a_diag_inv[8];
- __m256 reciprocal_diags[2];
-
- reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones));
-
- // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
-
- //L matrix offsets
- cs_l_offset[0] = (cs_l << 1);
- cs_l_offset[1] = cs_l + cs_l_offset[0];
- cs_l_offset[2] = (cs_l << 2);
- cs_l_offset[3] = cs_l + cs_l_offset[2];
- cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2];
- cs_l_offset[5] = cs_l + cs_l_offset[4];
- cs_l_offset[6] = (cs_l_offset[5] + cs_l);
-
- //read diag elems of L 16x16 block
- mat_a_cols_rearr[0] = _mm256_loadu_ps((float const *)ptr_l);
- mat_a_cols_rearr[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l);
- mat_a_cols_rearr[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]);
- mat_a_cols_rearr[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]);
- mat_a_cols_rearr[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]);
- mat_a_cols_rearr[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]);
- mat_a_cols_rearr[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]);
- mat_a_cols_rearr[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]);
-
- cs_b_offset[0] = (cs_b << 1);
- cs_b_offset[1] = cs_b + cs_b_offset[0];
- cs_b_offset[2] = (cs_b << 2);
- cs_b_offset[3] = cs_b + cs_b_offset[2];
- cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
- cs_b_offset[5] = cs_b + cs_b_offset[4];
- cs_b_offset[6] = (cs_b_offset[5] + cs_b);
-
- reciprocal_diags[1] = reciprocal_diags[0];
-
- //pack first 8 diags together
- mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[1], 0xAA);//diag 0,1
- mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_cols_rearr[2], mat_a_cols_rearr[3], 0xAA);//diag 2,3
- mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_cols_rearr[4], mat_a_cols_rearr[5], 0xAA);//diag 4,5
- mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_cols_rearr[6], mat_a_cols_rearr[7], 0xAA);//diag 6,7
- mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3
- mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7
- mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7
-
- //reciprocal of diagnal elements 0,1,2,3,4,5,6,7
- reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]);
-
- //Broadcast A10 to A70 to registers
- mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
- mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
- mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
- mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
- mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
- mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
- mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
-
- //Broadcast A21 to A71 to registers
- mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
- mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
- mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
- mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
- mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
- mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
-
- //Broadcast A32 to A72 to registers
- mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
- mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
- mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
- mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
- mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
-
- //Broadcast A43 to A73 to registers
- mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
- mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
- mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
- mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
-
- //Broadcast A54 to A74 to registers
- mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
- mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
- mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
-
- //Broadcast A65 to A75 to registers
- mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
- mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
-
- //Broadcast A76 to register
- mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
-
- //extract diag a00 from a
- mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
- mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
- //mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]);
- //extract diag a11 from a
- mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
- mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
- //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]);
- //extract diag a22 from a
- mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
- mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
- //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]);
- //extract diag a33 from a
- mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
- mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
- //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]);
- //extract diag a44 from a
- mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
- mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
- //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]);
- //extract diag a55 from a
- mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
- mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
- //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]);
- //extract diag a66 from a
- mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
- mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
- //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]);
- //extract diag a77 from a
- mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
- mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
- //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]);
-
-
- /***************** first set of 8 rows of B processing starts *****************/
- ptr_b_dup = ptr_b;
- i = 0;
- for (j = 0; j < numCols_b; j += 8)
- {
- /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A
- //read 8x8 block of B into registers
- mat_b_rearr[0][0] = _mm256_loadu_ps((float const *)ptr_b + i);
- mat_b_rearr[1][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
- mat_b_rearr[2][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
- mat_b_rearr[3][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
- mat_b_rearr[4][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
- mat_b_rearr[5][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
- mat_b_rearr[6][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
- mat_b_rearr[7][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
-
- //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
- mat_b_col[0] = _mm256_mul_ps(mat_b_rearr[0][0], mat_a_diag_inv[0]);
-
- //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
- mat_b_rearr[1][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[1][0]);//d = c - (a*b)
- mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[2][0]);//d = c - (a*b)
- mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[3][0]);//d = c - (a*b)
- mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[4][0]);//d = c - (a*b)
- mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[5][0]);//d = c - (a*b)
- mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[6][0]);//d = c - (a*b)
- mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[7][0]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
- mat_b_col[1] = _mm256_mul_ps(mat_b_rearr[1][0], mat_a_diag_inv[1]);
-
- //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
- mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[2][0]);//d = c - (a*b)
- mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[3][0]);//d = c - (a*b)
- mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[4][0]);//d = c - (a*b)
- mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[5][0]);//d = c - (a*b)
- mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[6][0]);//d = c - (a*b)
- mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[7][0]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
- mat_b_col[2] = _mm256_mul_ps(mat_b_rearr[2][0], mat_a_diag_inv[2]);
-
- //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
- mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[2], mat_b_rearr[3][0]);//d = c - (a*b)
- mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[2], mat_b_rearr[4][0]);//d = c - (a*b)
- mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[2], mat_b_rearr[5][0]);//d = c - (a*b)
- mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[6][0]);//d = c - (a*b)
- mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[7][0]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
- mat_b_col[3] = _mm256_mul_ps(mat_b_rearr[3][0], mat_a_diag_inv[3]);
-
- //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
- mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[3], mat_b_rearr[4][0]);//d = c - (a*b)
- mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[3], mat_b_rearr[5][0]);//d = c - (a*b)
- mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[3], mat_b_rearr[6][0]);//d = c - (a*b)
- mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[3], mat_b_rearr[7][0]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B
- mat_b_col[4] = _mm256_mul_ps(mat_b_rearr[4][0], mat_a_diag_inv[4]);
-
- //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
- mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[4], mat_b_rearr[5][0]);//d = c - (a*b)
- mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[4], mat_b_rearr[6][0]);//d = c - (a*b)
- mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[4], mat_b_rearr[7][0]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
- mat_b_col[5] = _mm256_mul_ps(mat_b_rearr[5][0], mat_a_diag_inv[5]);
-
- //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
- mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[5], mat_b_rearr[6][0]);//d = c - (a*b)
- mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[5], mat_b_rearr[7][0]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
- mat_b_col[6] = _mm256_mul_ps(mat_b_rearr[6][0], mat_a_diag_inv[6]);
-
- //(Row7): FMA operations of b7 with elements of index (7, 0)
- mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[6], mat_b_rearr[7][0]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
- mat_b_col[7] = _mm256_mul_ps(mat_b_rearr[7][0], mat_a_diag_inv[7]);
-
- ////////////////////////////////////////////////////////////////////////////////
-
- //Store the computed B columns
- _mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]);
- _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]);
-
- //i += cs_b_offset[6];
- //ptr_b_dup += cs_b_offset[6];
- i += 8;
- ptr_b_dup += 8;
- }
-
- //c = 0;
- /***************** first set of 8 cols of B processing done *****************/
- ptr_b_dup = ptr_b;
- i3 = 0;
- i1 = 0;
- //Start loop for cols of B to be processed in size of blk_width
- for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row
- {
- ptr_l += 8;
- //ptr_b += j;
- //ptr_b_dup += 8;
- ptr_b_dup += cs_b_offset[6];
- i1 += cs_b_offset[6];
-
- //Read next 8x8 block of A to get diag elements
- i3 += cs_l_offset[6];
- mat_a_cols_rearr[0] = _mm256_loadu_ps((float const *)ptr_l + i3);
- mat_a_cols_rearr[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l);
- mat_a_cols_rearr[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]);
- mat_a_cols_rearr[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]);
- mat_a_cols_rearr[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]);
- mat_a_cols_rearr[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]);
- mat_a_cols_rearr[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]);
- mat_a_cols_rearr[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]);
-
- //pack 8 diags of A together
- reciprocal_diags[0] = reciprocal_diags[1];
- mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[1], 0xAA);//diag 0,1
- mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_cols_rearr[2], mat_a_cols_rearr[3], 0xAA);//diag 2,3
- mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_cols_rearr[4], mat_a_cols_rearr[5], 0xAA);//diag 4,5
- mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_cols_rearr[6], mat_a_cols_rearr[7], 0xAA);//diag 6,7
- mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3
- mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7
- mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7
-
- //reciprocal of diagnal elements of A :- 0,1,2,3,4,5,6,7
- reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]);
-
- i = 0;
- i2 = 0;
- for (k = 0; k < numCols_b; k += 8)
- {
- i = i1 + k;
- //Read 8 cols of B columns of Block-to-be-solved
- mat_b_rearr[i2][0] = _mm256_loadu_ps((float const *)ptr_b + i);
- mat_b_rearr[i2][1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
- mat_b_rearr[i2][2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
- mat_b_rearr[i2][3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
- mat_b_rearr[i2][4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
- mat_b_rearr[i2][5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
- mat_b_rearr[i2][6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
- mat_b_rearr[i2][7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
- i2++;
- }
-
- i = 0;
- i2 = 0;
- for (l = 0; l < j; l += 8) // move across m
- {
- //Broadcast A8,0 to A15,0 to registers
- mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i));
- mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1));
- mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
- mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
- mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
- mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
- mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
- mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
-
- //Broadcast A21 to A71 to registers
- mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i));
- mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 1));
- mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 2));
- mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 3));
- mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 4));
- mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 5));
- mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 6));
- mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 7));
-
- //Broadcast A8,2 to A15,2 to registers
- mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i));
- mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 1));
- mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 2));
- mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 3));
- mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 4));
- mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 5));
- mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 6));
- mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 7));
-
- //Broadcast A8,3 to A15,3 to registers
- mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i));
- mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 1));
- mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 2));
- mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 3));
- mat_a_blk_elems[28] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 4));
- mat_a_blk_elems[29] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 5));
- mat_a_blk_elems[30] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 6));
- mat_a_blk_elems[31] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 7));
-
- // _mm256_permute2f128_ps()
-
- //Broadcast A8,4 to A15,4 to registers
- mat_a_blk_elems[32] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i));
- mat_a_blk_elems[33] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 1));
- mat_a_blk_elems[34] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 2));
- mat_a_blk_elems[35] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 3));
- mat_a_blk_elems[36] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 4));
- mat_a_blk_elems[37] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 5));
- mat_a_blk_elems[38] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 6));
- mat_a_blk_elems[39] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 7));
-
- //Broadcast A8,5 to A15,5 to registers
- mat_a_blk_elems[40] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i));
- mat_a_blk_elems[41] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 1));
- mat_a_blk_elems[42] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 2));
- mat_a_blk_elems[43] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 3));
- mat_a_blk_elems[44] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 4));
- mat_a_blk_elems[45] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 5));
- mat_a_blk_elems[46] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 6));
- mat_a_blk_elems[47] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 7));
-
- //Broadcast A8,6 to A15,6 to registers
- mat_a_blk_elems[48] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i));
- mat_a_blk_elems[49] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 1));
- mat_a_blk_elems[50] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 2));
- mat_a_blk_elems[51] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 3));
- mat_a_blk_elems[52] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 4));
- mat_a_blk_elems[53] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 5));
- mat_a_blk_elems[54] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 6));
- mat_a_blk_elems[55] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 7));
-
- //Broadcast A8,7 to A15,7 to registers
- mat_a_blk_elems[56] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i));
- mat_a_blk_elems[57] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 1));
- mat_a_blk_elems[58] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 2));
- mat_a_blk_elems[59] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 3));
- mat_a_blk_elems[60] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 4));
- mat_a_blk_elems[61] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 5));
- mat_a_blk_elems[62] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 6));
- mat_a_blk_elems[63] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 7));
-
- i += cs_l_offset[6];
-
-
- for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m)
- {
- /////////////////// Partial Lower 8x8 block trsm of B
-
- i4 = i2 + k;
- //Read current 8 cols of B columns from specified 8x8 current-block of B
- mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4);
- mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b));
- mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0]));
- mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1]));
- mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2]));
- mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3]));
- mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4]));
- mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5]));
-
- i4 = k >> 3;
-
- //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
- mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[i4][0]);//d = c - (a*b)
- mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[i4][1]);//d = c - (a*b)
- mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[i4][2]);//d = c - (a*b)
- mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[i4][3]);//d = c - (a*b)
- mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[i4][4]);//d = c - (a*b)
- mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[i4][5]);//d = c - (a*b)
- mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[i4][6]);//d = c - (a*b)
- mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[i4][7]);//d = c - (a*b)
-
- //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
- mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[i4][0]);//d = c - (a*b)
- mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[i4][1]);//d = c - (a*b)
- mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[i4][2]);//d = c - (a*b)
- mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[i4][3]);//d = c - (a*b)
- mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[i4][4]);//d = c - (a*b)
- mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[1], mat_b_rearr[i4][5]);//d = c - (a*b)
- mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[1], mat_b_rearr[i4][6]);//d = c - (a*b)
- mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[1], mat_b_rearr[i4][7]);//d = c - (a*b)
-
- //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
- mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[i4][0]);//d = c - (a*b)
- mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[i4][1]);//d = c - (a*b)
- mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[2], mat_b_rearr[i4][2]);//d = c - (a*b)
- mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[2], mat_b_rearr[i4][3]);//d = c - (a*b)
- mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[2], mat_b_rearr[i4][4]);//d = c - (a*b)
- mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[2], mat_b_rearr[i4][5]);//d = c - (a*b)
- mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[2], mat_b_rearr[i4][6]);//d = c - (a*b)
- mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[2], mat_b_rearr[i4][7]);//d = c - (a*b)
-
- //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
- mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[3], mat_b_rearr[i4][0]);//d = c - (a*b)
- mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[3], mat_b_rearr[i4][1]);//d = c - (a*b)
- mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[3], mat_b_rearr[i4][2]);//d = c - (a*b)
- mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[3], mat_b_rearr[i4][3]);//d = c - (a*b)
- mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[28], mat_b_col[3], mat_b_rearr[i4][4]);//d = c - (a*b)
- mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[29], mat_b_col[3], mat_b_rearr[i4][5]);//d = c - (a*b)
- mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[30], mat_b_col[3], mat_b_rearr[i4][6]);//d = c - (a*b)
- mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[31], mat_b_col[3], mat_b_rearr[i4][7]);//d = c - (a*b)
-
- //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
- mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[32], mat_b_col[4], mat_b_rearr[i4][0]);//d = c - (a*b)
- mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[33], mat_b_col[4], mat_b_rearr[i4][1]);//d = c - (a*b)
- mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[34], mat_b_col[4], mat_b_rearr[i4][2]);//d = c - (a*b)
- mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[35], mat_b_col[4], mat_b_rearr[i4][3]);//d = c - (a*b)
- mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[36], mat_b_col[4], mat_b_rearr[i4][4]);//d = c - (a*b)
- mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[37], mat_b_col[4], mat_b_rearr[i4][5]);//d = c - (a*b)
- mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[38], mat_b_col[4], mat_b_rearr[i4][6]);//d = c - (a*b)
- mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[39], mat_b_col[4], mat_b_rearr[i4][7]);//d = c - (a*b)
-
- //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
- mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[40], mat_b_col[5], mat_b_rearr[i4][0]);//d = c - (a*b)
- mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[41], mat_b_col[5], mat_b_rearr[i4][1]);//d = c - (a*b)
- mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[42], mat_b_col[5], mat_b_rearr[i4][2]);//d = c - (a*b)
- mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[43], mat_b_col[5], mat_b_rearr[i4][3]);//d = c - (a*b)
- mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[44], mat_b_col[5], mat_b_rearr[i4][4]);//d = c - (a*b)
- mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[45], mat_b_col[5], mat_b_rearr[i4][5]);//d = c - (a*b)
- mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[46], mat_b_col[5], mat_b_rearr[i4][6]);//d = c - (a*b)
- mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[47], mat_b_col[5], mat_b_rearr[i4][7]);//d = c - (a*b)
-
- //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
- mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[48], mat_b_col[6], mat_b_rearr[i4][0]);//d = c - (a*b)
- mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[49], mat_b_col[6], mat_b_rearr[i4][1]);//d = c - (a*b)
- mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[50], mat_b_col[6], mat_b_rearr[i4][2]);//d = c - (a*b)
- mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[51], mat_b_col[6], mat_b_rearr[i4][3]);//d = c - (a*b)
- mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[52], mat_b_col[6], mat_b_rearr[i4][4]);//d = c - (a*b)
- mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[53], mat_b_col[6], mat_b_rearr[i4][5]);//d = c - (a*b)
- mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[54], mat_b_col[6], mat_b_rearr[i4][6]);//d = c - (a*b)
- mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[55], mat_b_col[6], mat_b_rearr[i4][7]);//d = c - (a*b)
-
- //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
- mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[56], mat_b_col[7], mat_b_rearr[i4][0]);//d = c - (a*b)
- mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[57], mat_b_col[7], mat_b_rearr[i4][1]);//d = c - (a*b)
- mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[58], mat_b_col[7], mat_b_rearr[i4][2]);//d = c - (a*b)
- mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[59], mat_b_col[7], mat_b_rearr[i4][3]);//d = c - (a*b)
- mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[60], mat_b_col[7], mat_b_rearr[i4][4]);//d = c - (a*b)
- mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[61], mat_b_col[7], mat_b_rearr[i4][5]);//d = c - (a*b)
- mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[62], mat_b_col[7], mat_b_rearr[i4][6]);//d = c - (a*b)
- mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[63], mat_b_col[7], mat_b_rearr[i4][7]);//d = c - (a*b)
-
- //end loop of cols
- }
- i2 += cs_b_offset[6];
- }
-
- //Broadcast A10 to A70 to registers
- mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1));
- mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
- mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
- mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
- mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
- mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
- mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
- i += cs_l;
- //extract diag a00 from a
- mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
- mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
- //mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]);
-
- //Broadcast A21 to A71 to registers
- mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
- mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
- mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
- mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
- mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
- mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
- i += cs_l;
- //extract diag a11 from a
- mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
- mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
- //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]);
-
- //Broadcast A32 to A72 to registers
- mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
- mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
- mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
- mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
- mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
- i += cs_l;
- //extract diag a22 from a
- mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
- mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
- //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]);
-
- //Broadcast A43 to A73 to registers
- mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
- mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
- mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
- mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
- i += cs_l;
- //extract diag a33 from a
- mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
- mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
- //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]);
-
- //Broadcast A54 to A74 to registers
- mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
- mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
- mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
- i += cs_l;
- //extract diag a44 from a
- mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
- mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
- //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]);
-
- //Broadcast A65 to A75 to registers
- mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
- mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
- i += cs_l;
- //extract diag a55 from a
- mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
- mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
- //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]);
-
- //Broadcast A76 to register
- mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
- //extract diag a66 from a
- mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
- mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
- //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]);
-
- //extract diag a77 from a
- mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
- mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
- //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]);
-
- k = 0;
- for (i = 0; i < numCols_b; i+=8)
- {
- /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A
-
- //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
- mat_b_rearr[k][0] = _mm256_mul_ps(mat_b_rearr[k][0], mat_a_diag_inv[0]);
-
- //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
- mat_b_rearr[k][1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[k][0], mat_b_rearr[k][1]);//d = c - (a*b)
- mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[k][0], mat_b_rearr[k][2]);//d = c - (a*b)
- mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[k][0], mat_b_rearr[k][3]);//d = c - (a*b)
- mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[k][0], mat_b_rearr[k][4]);//d = c - (a*b)
- mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[k][0], mat_b_rearr[k][5]);//d = c - (a*b)
- mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[k][0], mat_b_rearr[k][6]);//d = c - (a*b)
- mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[k][0], mat_b_rearr[k][7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
- mat_b_rearr[k][1] = _mm256_mul_ps(mat_b_rearr[k][1], mat_a_diag_inv[1]);
-
- //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
- mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_rearr[k][1], mat_b_rearr[k][2]);//d = c - (a*b)
- mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_rearr[k][1], mat_b_rearr[k][3]);//d = c - (a*b)
- mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_rearr[k][1], mat_b_rearr[k][4]);//d = c - (a*b)
- mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_rearr[k][1], mat_b_rearr[k][5]);//d = c - (a*b)
- mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_rearr[k][1], mat_b_rearr[k][6]);//d = c - (a*b)
- mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_rearr[k][1], mat_b_rearr[k][7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
- mat_b_rearr[k][2] = _mm256_mul_ps(mat_b_rearr[k][2], mat_a_diag_inv[2]);
-
- //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
- mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_rearr[k][2], mat_b_rearr[k][3]);//d = c - (a*b)
- mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_rearr[k][2], mat_b_rearr[k][4]);//d = c - (a*b)
- mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_rearr[k][2], mat_b_rearr[k][5]);//d = c - (a*b)
- mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_rearr[k][2], mat_b_rearr[k][6]);//d = c - (a*b)
- mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_rearr[k][2], mat_b_rearr[k][7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
- mat_b_rearr[k][3] = _mm256_mul_ps(mat_b_rearr[k][3], mat_a_diag_inv[3]);
-
- //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
- mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_rearr[k][3], mat_b_rearr[k][4]);//d = c - (a*b)
- mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_rearr[k][3], mat_b_rearr[k][5]);//d = c - (a*b)
- mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_rearr[k][3], mat_b_rearr[k][6]);//d = c - (a*b)
- mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_rearr[k][3], mat_b_rearr[k][7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B
- mat_b_rearr[k][4] = _mm256_mul_ps(mat_b_rearr[k][4], mat_a_diag_inv[4]);
-
- //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
- mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_rearr[k][4], mat_b_rearr[k][5]);//d = c - (a*b)
- mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_rearr[k][4], mat_b_rearr[k][6]);//d = c - (a*b)
- mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_rearr[k][4], mat_b_rearr[k][7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
- mat_b_rearr[k][5] = _mm256_mul_ps(mat_b_rearr[k][5], mat_a_diag_inv[5]);
-
- //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
- mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_rearr[k][5], mat_b_rearr[k][6]);//d = c - (a*b)
- mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_rearr[k][5], mat_b_rearr[k][7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
- mat_b_rearr[k][6] = _mm256_mul_ps(mat_b_rearr[k][6], mat_a_diag_inv[6]);
-
- //(Row7): FMA operations of b7 with elements of index (7, 0)
- mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_rearr[k][6], mat_b_rearr[k][7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
- mat_b_rearr[k][7] = _mm256_mul_ps(mat_b_rearr[k][7], mat_a_diag_inv[7]);
-
- ////////////////////////////////////////////////////////////////////////////////
-
- //Store the computed B columns
-
- _mm256_storeu_ps((float *)ptr_b_dup + i, mat_b_rearr[k][0]);
- _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b) + i), mat_b_rearr[k][1]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i), mat_b_rearr[k][2]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i), mat_b_rearr[k][3]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i), mat_b_rearr[k][4]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i), mat_b_rearr[k][5]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i), mat_b_rearr[k][6]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i), mat_b_rearr[k][7]);
- //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k));
- k++;
- }
-
-
- }
- ///////////////////loop ends /////////////////////
-}
-
-static void trsm_XAtB_block_allSmallSizedMatrices_alpha(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha)
-{
- float ones = 1.0;
- int i, i1, i2, i3, i4, j, k, l;
- int cs_b_offset[7];
- int cs_l_offset[7];
- float *ptr_b_dup;
-
- //57 number of ymm(256 bits) registers used
- __m256 mat_b_col[8];
- __m256 mat_b_rearr[16][8];
- __m256 mat_a_cols_rearr[8];
- __m256 mat_a_blk_elems[64];
- __m256 mat_a_diag_inv[8];
- __m256 reciprocal_diags[2];
- __m256 alphaReg;
-
- reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones));
- alphaReg = _mm256_broadcast_ss((float const *)&alpha);
-
- // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
-
- //L matrix offsets
- cs_l_offset[0] = (cs_l << 1);
- cs_l_offset[1] = cs_l + cs_l_offset[0];
- cs_l_offset[2] = (cs_l << 2);
- cs_l_offset[3] = cs_l + cs_l_offset[2];
- cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2];
- cs_l_offset[5] = cs_l + cs_l_offset[4];
- cs_l_offset[6] = (cs_l_offset[5] + cs_l);
-
- //read diag elems of L 16x16 block
- mat_a_cols_rearr[0] = _mm256_loadu_ps((float const *)ptr_l);
- mat_a_cols_rearr[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l);
- //read diag elems of the first 8x8 block of L
- mat_a_cols_rearr[0] = _mm256_loadu_ps((float const *)ptr_l);
- mat_a_cols_rearr[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l);
- mat_a_cols_rearr[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]);
- mat_a_cols_rearr[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]);
- mat_a_cols_rearr[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]);
- mat_a_cols_rearr[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]);
- mat_a_cols_rearr[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]);
- mat_a_cols_rearr[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]);
-
- cs_b_offset[0] = (cs_b << 1);
- cs_b_offset[1] = cs_b + cs_b_offset[0];
- cs_b_offset[2] = (cs_b << 2);
- cs_b_offset[3] = cs_b + cs_b_offset[2];
- cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
- cs_b_offset[5] = cs_b + cs_b_offset[4];
- cs_b_offset[6] = (cs_b_offset[5] + cs_b);
-
- reciprocal_diags[1] = reciprocal_diags[0];
-
- //pack first 8 diags together
- mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[1], 0xAA);//diag 0,1
- mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_cols_rearr[2], mat_a_cols_rearr[3], 0xAA);//diag 2,3
- mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_cols_rearr[4], mat_a_cols_rearr[5], 0xAA);//diag 4,5
- mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_cols_rearr[6], mat_a_cols_rearr[7], 0xAA);//diag 6,7
- mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3
- mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7
- mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7
-
- //reciprocal of diagonal elements 0,1,2,3,4,5,6,7
- reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]);
-
- //Broadcast A10 to A70 to registers
- mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
- mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
- mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
- mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
- mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
- mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
- mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
-
- //Broadcast A21 to A71 to registers
- mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
- mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
- mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
- mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
- mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
- mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
-
- //Broadcast A32 to A72 to registers
- mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
- mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
- mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
- mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
- mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
-
- //Broadcast A43 to A73 to registers
- mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
- mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
- mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
- mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
-
- //Broadcast A54 to A74 to registers
- mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
- mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
- mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
-
- //Broadcast A65 to A75 to registers
- mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
- mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
-
- //Broadcast A76 to register
- mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
-
- //extract diag a00 from a
- mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
- mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
- //mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]);
- //extract diag a11 from a
- mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
- mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
- //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]);
- //extract diag a22 from a
- mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
- mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
- //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]);
- //extract diag a33 from a
- mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
- mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
- //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]);
- //extract diag a44 from a
- mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
- mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
- //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]);
- //extract diag a55 from a
- mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
- mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
- //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]);
- //extract diag a66 from a
- mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
- mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
- //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]);
- //extract diag a77 from a
- mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
- mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
- //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]);
-
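The permute/permute2f128 pairs above broadcast one packed reciprocal to all eight lanes: _mm256_permute_ps replicates the element within each 128-bit lane, then _mm256_permute2f128_ps copies the half that holds it to both halves, since AVX has no single cross-lane element broadcast from a register. A sketch of the idiom (the helper is illustrative, not from the patch):

#include <immintrin.h>

/* Broadcast lane k (0..7) of v to all 8 lanes. */
static inline __m256 bcast_lane(__m256 v, int k)
{
    __m256 t;
    switch (k & 3) {              /* replicate within each 128-bit lane */
    case 0:  t = _mm256_permute_ps(v, 0x00); break;
    case 1:  t = _mm256_permute_ps(v, 0x55); break;
    case 2:  t = _mm256_permute_ps(v, 0xAA); break;
    default: t = _mm256_permute_ps(v, 0xFF); break;
    }
    /* copy the 128-bit half containing the element to both halves */
    return (k < 4) ? _mm256_permute2f128_ps(t, t, 0x00)
                   : _mm256_permute2f128_ps(t, t, 0x11);
}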
-
- /***************** first set of 8 rows of B processing starts *****************/
- ptr_b_dup = ptr_b;
- i = 0;
- for (j = 0; j < numCols_b; j += 8)
- {
- /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A
- //read 8x8 block of B into registers
- mat_b_rearr[0][0] = _mm256_loadu_ps((float const *)ptr_b + i);
- mat_b_rearr[1][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
- mat_b_rearr[2][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
- mat_b_rearr[3][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
- mat_b_rearr[4][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
- mat_b_rearr[5][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
- mat_b_rearr[6][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
- mat_b_rearr[7][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
-
- mat_b_rearr[0][0] = _mm256_mul_ps(mat_b_rearr[0][0], alphaReg);
- mat_b_rearr[1][0] = _mm256_mul_ps(mat_b_rearr[1][0], alphaReg);
- mat_b_rearr[2][0] = _mm256_mul_ps(mat_b_rearr[2][0], alphaReg);
- mat_b_rearr[3][0] = _mm256_mul_ps(mat_b_rearr[3][0], alphaReg);
- mat_b_rearr[4][0] = _mm256_mul_ps(mat_b_rearr[4][0], alphaReg);
- mat_b_rearr[5][0] = _mm256_mul_ps(mat_b_rearr[5][0], alphaReg);
- mat_b_rearr[6][0] = _mm256_mul_ps(mat_b_rearr[6][0], alphaReg);
- mat_b_rearr[7][0] = _mm256_mul_ps(mat_b_rearr[7][0], alphaReg);
-
- //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
- mat_b_col[0] = _mm256_mul_ps(mat_b_rearr[0][0], mat_a_diag_inv[0]);
-
- //(Row1): FMA operations of b1 with elements of indices from (1, 0) up to (7, 0)
- mat_b_rearr[1][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[1][0]);//d = c - (a*b)
- mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[2][0]);//d = c - (a*b)
- mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[3][0]);//d = c - (a*b)
- mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[4][0]);//d = c - (a*b)
- mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[5][0]);//d = c - (a*b)
- mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[6][0]);//d = c - (a*b)
- mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[7][0]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
- mat_b_col[1] = _mm256_mul_ps(mat_b_rearr[1][0], mat_a_diag_inv[1]);
-
- //(Row2): FMA operations of b2 with elements of indices from (2, 0) up to (7, 0)
- mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[2][0]);//d = c - (a*b)
- mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[3][0]);//d = c - (a*b)
- mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[4][0]);//d = c - (a*b)
- mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[5][0]);//d = c - (a*b)
- mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[6][0]);//d = c - (a*b)
- mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[7][0]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
- mat_b_col[2] = _mm256_mul_ps(mat_b_rearr[2][0], mat_a_diag_inv[2]);
-
- //(Row3): FMA operations of b3 with elements of indices from (3, 0) up to (7, 0)
- mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[2], mat_b_rearr[3][0]);//d = c - (a*b)
- mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[2], mat_b_rearr[4][0]);//d = c - (a*b)
- mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[2], mat_b_rearr[5][0]);//d = c - (a*b)
- mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[6][0]);//d = c - (a*b)
- mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[7][0]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(3, 3) element with 4th row elements of B
- mat_b_col[3] = _mm256_mul_ps(mat_b_rearr[3][0], mat_a_diag_inv[3]);
-
- //(Row4): FMA operations of b4 with elements of indices from (4, 0) up to (7, 0)
- mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[3], mat_b_rearr[4][0]);//d = c - (a*b)
- mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[3], mat_b_rearr[5][0]);//d = c - (a*b)
- mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[3], mat_b_rearr[6][0]);//d = c - (a*b)
- mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[3], mat_b_rearr[7][0]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(4, 4) element with 5th row elements of B
- mat_b_col[4] = _mm256_mul_ps(mat_b_rearr[4][0], mat_a_diag_inv[4]);
-
- //(Row5): FMA operations of b5 with elements of indices from (5, 0) up to (7, 0)
- mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[4], mat_b_rearr[5][0]);//d = c - (a*b)
- mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[4], mat_b_rearr[6][0]);//d = c - (a*b)
- mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[4], mat_b_rearr[7][0]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(5, 5) element with 6th row elements of B
- mat_b_col[5] = _mm256_mul_ps(mat_b_rearr[5][0], mat_a_diag_inv[5]);
-
- //(Row6): FMA operations of b6 with elements of indices from (6, 0) up to (7, 0)
- mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[5], mat_b_rearr[6][0]);//d = c - (a*b)
- mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[5], mat_b_rearr[7][0]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(6, 6) element with 7th row elements of B
- mat_b_col[6] = _mm256_mul_ps(mat_b_rearr[6][0], mat_a_diag_inv[6]);
-
- //(Row7): FMA operations of b7 with elements of index (7, 0)
- mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[6], mat_b_rearr[7][0]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(7, 7) element with 8th row elements of B
- mat_b_col[7] = _mm256_mul_ps(mat_b_rearr[7][0], mat_a_diag_inv[7]);
-
- ////////////////////////////////////////////////////////////////////////////////
-
- //Store the computed B columns
- _mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]);
- _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]);
-
- //i += cs_b_offset[6];
- //ptr_b_dup += cs_b_offset[6];
- i += 8;
- ptr_b_dup += 8;
- }
-
- //c = 0;
- /***************** first set of 8 cols of B processing done *****************/
- ptr_b_dup = ptr_b;
- i3 = 0;
- i1 = 0;
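The block loop that follows alternates two phases per 8-row band: first accumulate fnmadd updates from the bands already solved, then solve the diagonal 8x8 block against the scaled right-hand side. A scalar forward-substitution reference for the kind of solve these kernels vectorize (illustrative only; the actual XAtB kernels traverse the operands in a transposed fashion, and all names here are assumptions):

/* Reference: overwrite B with X solving L*X = alpha*B, L lower triangular,
 * column-major, column strides cs_l and cs_b. */
void ref_trsm_lower(int m, int n, float alpha,
                    const float *L, int cs_l, float *B, int cs_b)
{
    for (int j = 0; j < n; j++) {            /* each right-hand side     */
        for (int i = 0; i < m; i++) {
            float x = alpha * B[i + j * cs_b];
            for (int k = 0; k < i; k++)      /* subtract solved rows     */
                x -= L[i + k * cs_l] * B[k + j * cs_b];
            B[i + j * cs_b] = x / L[i + i * cs_l]; /* scale by 1/L(i,i)  */
        }
    }
}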
- //Start loop for cols of B to be processed in size of blk_width
- for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row
- {
- ptr_l += 8;
- //ptr_b += j;
- //ptr_b_dup += 8;
- ptr_b_dup += cs_b_offset[6];
- i1 += cs_b_offset[6];
-
- //Read next 8x8 block of A to get diag elements
- i3 += cs_l_offset[6];
- mat_a_cols_rearr[0] = _mm256_loadu_ps((float const *)ptr_l + i3);
- mat_a_cols_rearr[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l);
- mat_a_cols_rearr[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]);
- mat_a_cols_rearr[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]);
- mat_a_cols_rearr[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]);
- mat_a_cols_rearr[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]);
- mat_a_cols_rearr[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]);
- mat_a_cols_rearr[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]);
-
- //pack 8 diags of A together
- reciprocal_diags[0] = reciprocal_diags[1];
- mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[1], 0xAA);//diag 0,1
- mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_cols_rearr[2], mat_a_cols_rearr[3], 0xAA);//diag 2,3
- mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_cols_rearr[4], mat_a_cols_rearr[5], 0xAA);//diag 4,5
- mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_cols_rearr[6], mat_a_cols_rearr[7], 0xAA);//diag 6,7
- mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3
- mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7
- mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7
-
- //reciprocal of diagonal elements of A :- 0,1,2,3,4,5,6,7
- reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]);
-
- i = 0;
- i2 = 0;
- for (k = 0; k < numCols_b; k += 8)
- {
- i = i1 + k;
- //Read 8 columns of the B block to be solved
- mat_b_rearr[i2][0] = _mm256_loadu_ps((float const *)ptr_b + i);
- mat_b_rearr[i2][1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
- mat_b_rearr[i2][2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
- mat_b_rearr[i2][3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
- mat_b_rearr[i2][4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
- mat_b_rearr[i2][5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
- mat_b_rearr[i2][6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
- mat_b_rearr[i2][7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
-
- mat_b_rearr[i2][0] = _mm256_mul_ps(mat_b_rearr[i2][0], alphaReg);
- mat_b_rearr[i2][1] = _mm256_mul_ps(mat_b_rearr[i2][1], alphaReg);
- mat_b_rearr[i2][2] = _mm256_mul_ps(mat_b_rearr[i2][2], alphaReg);
- mat_b_rearr[i2][3] = _mm256_mul_ps(mat_b_rearr[i2][3], alphaReg);
- mat_b_rearr[i2][4] = _mm256_mul_ps(mat_b_rearr[i2][4], alphaReg);
- mat_b_rearr[i2][5] = _mm256_mul_ps(mat_b_rearr[i2][5], alphaReg);
- mat_b_rearr[i2][6] = _mm256_mul_ps(mat_b_rearr[i2][6], alphaReg);
- mat_b_rearr[i2][7] = _mm256_mul_ps(mat_b_rearr[i2][7], alphaReg);
-
- i2++;
- }
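Note that alpha is applied once, in-register, as each B panel is loaded, so no separate scaling pass over B is needed. A small sketch of that load-and-scale step (the helper name is illustrative, not from the patch):

#include <immintrin.h>

/* Load 8 columns of 8 floats each from column-major B and scale by alpha
 * while the data is already in registers. */
static inline void load_scale_8(const float *b, int cs_b,
                                float alpha, __m256 out[8])
{
    __m256 va = _mm256_broadcast_ss(&alpha);
    for (int c = 0; c < 8; c++)
        out[c] = _mm256_mul_ps(_mm256_loadu_ps(b + c * cs_b), va);
}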
-
- i = 0;
- i2 = 0;
- for (l = 0; l < j; l += 8) // move across m
- {
- //Broadcast A8,0 to A15,0 to registers
- mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i));
- mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1));
- mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
- mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
- mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
- mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
- mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
- mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
-
- //Broadcast A8,1 to A15,1 to registers
- mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i));
- mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 1));
- mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 2));
- mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 3));
- mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 4));
- mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 5));
- mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 6));
- mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 7));
-
- //Broadcast A8,2 to A15,2 to registers
- mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i));
- mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 1));
- mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 2));
- mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 3));
- mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 4));
- mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 5));
- mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 6));
- mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 7));
-
- //Broadcast A8,3 to A15,3 to registers
- mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i));
- mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 1));
- mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 2));
- mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 3));
- mat_a_blk_elems[28] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 4));
- mat_a_blk_elems[29] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 5));
- mat_a_blk_elems[30] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 6));
- mat_a_blk_elems[31] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 7));
-
- // _mm256_permute2f128_ps()
-
- //Broadcast A8,4 to A15,4 to registers
- mat_a_blk_elems[32] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i));
- mat_a_blk_elems[33] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 1));
- mat_a_blk_elems[34] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 2));
- mat_a_blk_elems[35] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 3));
- mat_a_blk_elems[36] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 4));
- mat_a_blk_elems[37] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 5));
- mat_a_blk_elems[38] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 6));
- mat_a_blk_elems[39] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 7));
-
- //Broadcast A8,5 to A15,5 to registers
- mat_a_blk_elems[40] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i));
- mat_a_blk_elems[41] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 1));
- mat_a_blk_elems[42] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 2));
- mat_a_blk_elems[43] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 3));
- mat_a_blk_elems[44] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 4));
- mat_a_blk_elems[45] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 5));
- mat_a_blk_elems[46] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 6));
- mat_a_blk_elems[47] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 7));
-
- //Broadcast A8,6 to A15,6 to registers
- mat_a_blk_elems[48] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i));
- mat_a_blk_elems[49] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 1));
- mat_a_blk_elems[50] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 2));
- mat_a_blk_elems[51] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 3));
- mat_a_blk_elems[52] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 4));
- mat_a_blk_elems[53] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 5));
- mat_a_blk_elems[54] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 6));
- mat_a_blk_elems[55] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 7));
-
- //Broadcast A8,7 to A15,7 to registers
- mat_a_blk_elems[56] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i));
- mat_a_blk_elems[57] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 1));
- mat_a_blk_elems[58] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 2));
- mat_a_blk_elems[59] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 3));
- mat_a_blk_elems[60] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 4));
- mat_a_blk_elems[61] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 5));
- mat_a_blk_elems[62] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 6));
- mat_a_blk_elems[63] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 7));
-
- i += cs_l_offset[6];
-
-
- for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m)
- {
- /////////////////// Partial Lower 8x8 block trsm of B
-
- i4 = i2 + k;
- //Read the current 8 columns of the 8x8 block of B being updated
- mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4);
- mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b));
- mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0]));
- mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1]));
- mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2]));
- mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3]));
- mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4]));
- mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5]));
-
- i4 = k >> 3;
-
- //(Row8): FMA operations of b8 with elements of column 0 of the current block
- mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[i4][0]);//d = c - (a*b)
- mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[i4][1]);//d = c - (a*b)
- mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[i4][2]);//d = c - (a*b)
- mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[i4][3]);//d = c - (a*b)
- mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[i4][4]);//d = c - (a*b)
- mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[i4][5]);//d = c - (a*b)
- mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[i4][6]);//d = c - (a*b)
- mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[i4][7]);//d = c - (a*b)
-
- //(Row9): FMA operations of b9 with elements of column 1 of the current block
- mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[i4][0]);//d = c - (a*b)
- mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[i4][1]);//d = c - (a*b)
- mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[i4][2]);//d = c - (a*b)
- mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[i4][3]);//d = c - (a*b)
- mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[i4][4]);//d = c - (a*b)
- mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[1], mat_b_rearr[i4][5]);//d = c - (a*b)
- mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[1], mat_b_rearr[i4][6]);//d = c - (a*b)
- mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[1], mat_b_rearr[i4][7]);//d = c - (a*b)
-
- //(Row10): FMA operations of b10 with elements of column 2 of the current block
- mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[i4][0]);//d = c - (a*b)
- mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[i4][1]);//d = c - (a*b)
- mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[2], mat_b_rearr[i4][2]);//d = c - (a*b)
- mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[2], mat_b_rearr[i4][3]);//d = c - (a*b)
- mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[2], mat_b_rearr[i4][4]);//d = c - (a*b)
- mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[2], mat_b_rearr[i4][5]);//d = c - (a*b)
- mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[2], mat_b_rearr[i4][6]);//d = c - (a*b)
- mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[2], mat_b_rearr[i4][7]);//d = c - (a*b)
-
- //(Row11): FMA operations of b11 with elements of column 3 of the current block
- mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[3], mat_b_rearr[i4][0]);//d = c - (a*b)
- mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[3], mat_b_rearr[i4][1]);//d = c - (a*b)
- mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[3], mat_b_rearr[i4][2]);//d = c - (a*b)
- mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[3], mat_b_rearr[i4][3]);//d = c - (a*b)
- mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[28], mat_b_col[3], mat_b_rearr[i4][4]);//d = c - (a*b)
- mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[29], mat_b_col[3], mat_b_rearr[i4][5]);//d = c - (a*b)
- mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[30], mat_b_col[3], mat_b_rearr[i4][6]);//d = c - (a*b)
- mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[31], mat_b_col[3], mat_b_rearr[i4][7]);//d = c - (a*b)
-
- //(Row12): FMA operations of b12 with elements of column 4 of the current block
- mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[32], mat_b_col[4], mat_b_rearr[i4][0]);//d = c - (a*b)
- mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[33], mat_b_col[4], mat_b_rearr[i4][1]);//d = c - (a*b)
- mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[34], mat_b_col[4], mat_b_rearr[i4][2]);//d = c - (a*b)
- mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[35], mat_b_col[4], mat_b_rearr[i4][3]);//d = c - (a*b)
- mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[36], mat_b_col[4], mat_b_rearr[i4][4]);//d = c - (a*b)
- mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[37], mat_b_col[4], mat_b_rearr[i4][5]);//d = c - (a*b)
- mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[38], mat_b_col[4], mat_b_rearr[i4][6]);//d = c - (a*b)
- mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[39], mat_b_col[4], mat_b_rearr[i4][7]);//d = c - (a*b)
-
- //(Row13): FMA operations of b13 with elements of column 5 of the current block
- mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[40], mat_b_col[5], mat_b_rearr[i4][0]);//d = c - (a*b)
- mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[41], mat_b_col[5], mat_b_rearr[i4][1]);//d = c - (a*b)
- mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[42], mat_b_col[5], mat_b_rearr[i4][2]);//d = c - (a*b)
- mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[43], mat_b_col[5], mat_b_rearr[i4][3]);//d = c - (a*b)
- mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[44], mat_b_col[5], mat_b_rearr[i4][4]);//d = c - (a*b)
- mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[45], mat_b_col[5], mat_b_rearr[i4][5]);//d = c - (a*b)
- mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[46], mat_b_col[5], mat_b_rearr[i4][6]);//d = c - (a*b)
- mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[47], mat_b_col[5], mat_b_rearr[i4][7]);//d = c - (a*b)
-
- //(Row14): FMA operations of b14 with elements of column 6 of the current block
- mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[48], mat_b_col[6], mat_b_rearr[i4][0]);//d = c - (a*b)
- mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[49], mat_b_col[6], mat_b_rearr[i4][1]);//d = c - (a*b)
- mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[50], mat_b_col[6], mat_b_rearr[i4][2]);//d = c - (a*b)
- mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[51], mat_b_col[6], mat_b_rearr[i4][3]);//d = c - (a*b)
- mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[52], mat_b_col[6], mat_b_rearr[i4][4]);//d = c - (a*b)
- mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[53], mat_b_col[6], mat_b_rearr[i4][5]);//d = c - (a*b)
- mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[54], mat_b_col[6], mat_b_rearr[i4][6]);//d = c - (a*b)
- mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[55], mat_b_col[6], mat_b_rearr[i4][7]);//d = c - (a*b)
-
- //(Row15): FMA operations of b15 with elements of column 7 of the current block
- mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[56], mat_b_col[7], mat_b_rearr[i4][0]);//d = c - (a*b)
- mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[57], mat_b_col[7], mat_b_rearr[i4][1]);//d = c - (a*b)
- mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[58], mat_b_col[7], mat_b_rearr[i4][2]);//d = c - (a*b)
- mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[59], mat_b_col[7], mat_b_rearr[i4][3]);//d = c - (a*b)
- mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[60], mat_b_col[7], mat_b_rearr[i4][4]);//d = c - (a*b)
- mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[61], mat_b_col[7], mat_b_rearr[i4][5]);//d = c - (a*b)
- mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[62], mat_b_col[7], mat_b_rearr[i4][6]);//d = c - (a*b)
- mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[63], mat_b_col[7], mat_b_rearr[i4][7]);//d = c - (a*b)
-
- //end loop of cols
- }
- i2 += cs_b_offset[6];
- }
-
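Each iteration of the l-loop above is effectively a broadcast-FMA rank update: one element of the A block is broadcast per vfnmadd (d = c - a*b) against a vector of already-solved values. A generic sketch of the idiom (indexing in the real kernel differs; all names here are illustrative):

#include <immintrin.h>

/* Subtract the contribution of 8 solved vectors xs[] from 8 pending
 * vectors pend[], using one broadcast A element per FMA. */
static inline void update_8x8(const float *a, int cs_a,
                              const __m256 xs[8], __m256 pend[8])
{
    for (int c = 0; c < 8; c++)        /* column of A == solved index  */
        for (int r = 0; r < 8; r++) {  /* pending index                */
            __m256 aelem = _mm256_broadcast_ss(a + r + c * cs_a);
            pend[r] = _mm256_fnmadd_ps(aelem, xs[c], pend[r]);
        }
}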
- //Broadcast A10 to A70 to registers
- mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1));
- mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
- mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
- mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
- mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
- mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
- mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
- i += cs_l;
- //extract diag a00 from a
- mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
- mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
- //mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]);
-
- //Broadcast A21 to A71 to registers
- mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
- mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
- mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
- mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
- mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
- mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
- i += cs_l;
- //extract diag a11 from a
- mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
- mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
- //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]);
-
- //Broadcast A32 to A72 to registers
- mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
- mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
- mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
- mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
- mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
- i += cs_l;
- //extract diag a22 from a
- mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
- mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
- //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]);
-
- //Broadcast A43 to A73 to registers
- mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
- mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
- mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
- mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
- i += cs_l;
- //extract diag a33 from a
- mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
- mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
- //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]);
-
- //Broadcast A54 to A74 to registers
- mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
- mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
- mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
- i += cs_l;
- //extract diag a44 from a
- mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
- mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
- //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]);
-
- //Broadcast A65 to A75 to registers
- mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
- mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
- i += cs_l;
- //extract diag a55 from a
- mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
- mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
- //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]);
-
- //Broadcast A76 to register
- mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
- //extract diag a66 from a
- mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
- mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
- //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]);
-
- //extract diag a77 from a
- mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
- mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
- //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]);
-
- k = 0;
- for (i = 0; i < numCols_b; i+=8)
- {
- /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A
-
- //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
- mat_b_rearr[k][0] = _mm256_mul_ps(mat_b_rearr[k][0], mat_a_diag_inv[0]);
-
- //(Row1): FMA operations of b1 with elements of indices from (1, 0) up to (7, 0)
- mat_b_rearr[k][1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[k][0], mat_b_rearr[k][1]);//d = c - (a*b)
- mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[k][0], mat_b_rearr[k][2]);//d = c - (a*b)
- mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[k][0], mat_b_rearr[k][3]);//d = c - (a*b)
- mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[k][0], mat_b_rearr[k][4]);//d = c - (a*b)
- mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[k][0], mat_b_rearr[k][5]);//d = c - (a*b)
- mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[k][0], mat_b_rearr[k][6]);//d = c - (a*b)
- mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[k][0], mat_b_rearr[k][7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
- mat_b_rearr[k][1] = _mm256_mul_ps(mat_b_rearr[k][1], mat_a_diag_inv[1]);
-
- //(Row2): FMA operations of b2 with elements of indices from (2, 0) up to (7, 0)
- mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_rearr[k][1], mat_b_rearr[k][2]);//d = c - (a*b)
- mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_rearr[k][1], mat_b_rearr[k][3]);//d = c - (a*b)
- mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_rearr[k][1], mat_b_rearr[k][4]);//d = c - (a*b)
- mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_rearr[k][1], mat_b_rearr[k][5]);//d = c - (a*b)
- mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_rearr[k][1], mat_b_rearr[k][6]);//d = c - (a*b)
- mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_rearr[k][1], mat_b_rearr[k][7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
- mat_b_rearr[k][2] = _mm256_mul_ps(mat_b_rearr[k][2], mat_a_diag_inv[2]);
-
- //(Row3): FMA operations of b3 with elements of indices from (3, 0) up to (7, 0)
- mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_rearr[k][2], mat_b_rearr[k][3]);//d = c - (a*b)
- mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_rearr[k][2], mat_b_rearr[k][4]);//d = c - (a*b)
- mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_rearr[k][2], mat_b_rearr[k][5]);//d = c - (a*b)
- mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_rearr[k][2], mat_b_rearr[k][6]);//d = c - (a*b)
- mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_rearr[k][2], mat_b_rearr[k][7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(3, 3) element with 4th row elements of B
- mat_b_rearr[k][3] = _mm256_mul_ps(mat_b_rearr[k][3], mat_a_diag_inv[3]);
-
- //(Row4): FMA operations of b4 with elements of indices from (4, 0) up to (7, 0)
- mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_rearr[k][3], mat_b_rearr[k][4]);//d = c - (a*b)
- mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_rearr[k][3], mat_b_rearr[k][5]);//d = c - (a*b)
- mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_rearr[k][3], mat_b_rearr[k][6]);//d = c - (a*b)
- mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_rearr[k][3], mat_b_rearr[k][7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(4, 4) element with 5th row elements of B
- mat_b_rearr[k][4] = _mm256_mul_ps(mat_b_rearr[k][4], mat_a_diag_inv[4]);
-
- //(Row5): FMA operations of b5 with elements of indices from (5, 0) up to (7, 0)
- mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_rearr[k][4], mat_b_rearr[k][5]);//d = c - (a*b)
- mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_rearr[k][4], mat_b_rearr[k][6]);//d = c - (a*b)
- mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_rearr[k][4], mat_b_rearr[k][7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(5, 5) element with 6th row elements of B
- mat_b_rearr[k][5] = _mm256_mul_ps(mat_b_rearr[k][5], mat_a_diag_inv[5]);
-
- //(Row6): FMA operations of b6 with elements of indices from (6, 0) up to (7, 0)
- mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_rearr[k][5], mat_b_rearr[k][6]);//d = c - (a*b)
- mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_rearr[k][5], mat_b_rearr[k][7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(6, 6) element with 7th row elements of B
- mat_b_rearr[k][6] = _mm256_mul_ps(mat_b_rearr[k][6], mat_a_diag_inv[6]);
-
- //(Row7): FMA operations of b7 with elements of index (7, 0)
- mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_rearr[k][6], mat_b_rearr[k][7]);//d = c - (a*b)
-
- //Perform mul operation of reciprocal of L(7, 7) element with 8th row elements of B
- mat_b_rearr[k][7] = _mm256_mul_ps(mat_b_rearr[k][7], mat_a_diag_inv[7]);
-
- ////////////////////////////////////////////////////////////////////////////////
-
- //Store the computed B columns
-
- _mm256_storeu_ps((float *)ptr_b_dup + i, mat_b_rearr[k][0]);
- _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b) + i), mat_b_rearr[k][1]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i), mat_b_rearr[k][2]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i), mat_b_rearr[k][3]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i), mat_b_rearr[k][4]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i), mat_b_rearr[k][5]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i), mat_b_rearr[k][6]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i), mat_b_rearr[k][7]);
- k++;
- }
-
-
- }
- ///////////////////loop ends /////////////////////
-}
-
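The next removed function is the unit-diagonal variant of the same solver: with diag(L) = 1 the diagonal gather, the division, and the per-row reciprocal scaling all drop out, leaving pure fnmadd substitution. A hedged sketch of that simplification (the helper is illustrative, not from the patch):

#include <immintrin.h>

/* Unit-diagonal forward substitution over 8 stacked vectors: each row is
 * the pending value minus contributions of rows already solved; no
 * reciprocal or multiply on the diagonal is needed. */
static inline void solve_8_unitdiag(const float *a, int cs_a, __m256 x[8])
{
    for (int r = 1; r < 8; r++)
        for (int c = 0; c < r; c++) {
            __m256 l = _mm256_broadcast_ss(a + r + c * cs_a);
            x[r] = _mm256_fnmadd_ps(l, x[c], x[r]);   /* x[r] -= L(r,c)*x[c] */
        }
}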
-static void trsm_XAtB_block_allSmallSizedMatrices_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b)
-{
- //float ones = 1.0;
- int i, i1, i2, i3, i4, j, k, l;
- int cs_b_offset[7];
- int cs_l_offset[7];
- float *ptr_b_dup;
-
- //57 ymm (256-bit) registers used
- __m256 mat_b_col[8];
- __m256 mat_b_rearr[16][8];
- //__m256 mat_a_cols_rearr[8];
- __m256 mat_a_blk_elems[64];
- //__m256 mat_a_diag_inv[8];
- //__m256 reciprocal_diags[2];
-
- // ---> considering that the matrix size is a multiple of 16 rows and 8 cols <--- //
-
- //L matrix offsets
- cs_l_offset[0] = (cs_l << 1);
- cs_l_offset[1] = cs_l + cs_l_offset[0];
- cs_l_offset[2] = (cs_l << 2);
- cs_l_offset[3] = cs_l + cs_l_offset[2];
- cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2];
- cs_l_offset[5] = cs_l + cs_l_offset[4];
- cs_l_offset[6] = (cs_l_offset[5] + cs_l);
-
- cs_b_offset[0] = (cs_b << 1);
- cs_b_offset[1] = cs_b + cs_b_offset[0];
- cs_b_offset[2] = (cs_b << 2);
- cs_b_offset[3] = cs_b + cs_b_offset[2];
- cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
- cs_b_offset[5] = cs_b + cs_b_offset[4];
- cs_b_offset[6] = (cs_b_offset[5] + cs_b);
-
- //Broadcast A10 to A70 to registers
- mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
- mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
- mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
- mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
- mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
- mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
- mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
-
- //Broadcast A21 to A71 to registers
- mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
- mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
- mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
- mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
- mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
- mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
-
- //Broadcast A32 to A72 to registers
- mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
- mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
- mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
- mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
- mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
-
- //Broadcast A43 to A73 to registers
- mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
- mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
- mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
- mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
-
- //Broadcast A54 to A74 to registers
- mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
- mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
- mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
-
- //Broadcast A65 to A75 to registers
- mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
- mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
-
- //Broadcast A76 to register
- mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
-
-
- /***************** first set of 8 rows of B processing starts *****************/
- ptr_b_dup = ptr_b;
- i = 0;
- for (j = 0; j < numCols_b; j += 8)
- {
- /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A
- //read 8x8 block of B into registers
- mat_b_rearr[0][0] = _mm256_loadu_ps((float const *)ptr_b + i);
- mat_b_rearr[1][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
- mat_b_rearr[2][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
- mat_b_rearr[3][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
- mat_b_rearr[4][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
- mat_b_rearr[5][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
- mat_b_rearr[6][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
- mat_b_rearr[7][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
-
- //(Row0)
- mat_b_col[0] = mat_b_rearr[0][0];
-
- //(Row1): FMA operations of b1 with elements of indices from (1, 0) up to (7, 0)
- mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[1][0]);//d = c - (a*b)
- mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[2][0]);//d = c - (a*b)
- mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[3][0]);//d = c - (a*b)
- mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[4][0]);//d = c - (a*b)
- mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[5][0]);//d = c - (a*b)
- mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[6][0]);//d = c - (a*b)
- mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[7][0]);//d = c - (a*b)
-
- //(Row2): FMA operations of b2 with elements of indices from (2, 0) up to (7, 0)
- mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[2][0]);//d = c - (a*b)
- mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[3][0]);//d = c - (a*b)
- mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[4][0]);//d = c - (a*b)
- mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[5][0]);//d = c - (a*b)
- mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[6][0]);//d = c - (a*b)
- mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[7][0]);//d = c - (a*b)
-
- //(Row3): FMA operations of b3 with elements of indices from (3, 0) up to (7, 0)
- mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[2], mat_b_rearr[3][0]);//d = c - (a*b)
- mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[2], mat_b_rearr[4][0]);//d = c - (a*b)
- mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[2], mat_b_rearr[5][0]);//d = c - (a*b)
- mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[6][0]);//d = c - (a*b)
- mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[7][0]);//d = c - (a*b)
-
- //(Row4): FMA operations of b4 with elements of indices from (4, 0) up to (7, 0)
- mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[3], mat_b_rearr[4][0]);//d = c - (a*b)
- mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[3], mat_b_rearr[5][0]);//d = c - (a*b)
- mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[3], mat_b_rearr[6][0]);//d = c - (a*b)
- mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[3], mat_b_rearr[7][0]);//d = c - (a*b)
-
- //(Row5): FMA operations of b5 with elements of indices from (5, 0) up to (7, 0)
- mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[4], mat_b_rearr[5][0]);//d = c - (a*b)
- mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[4], mat_b_rearr[6][0]);//d = c - (a*b)
- mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[4], mat_b_rearr[7][0]);//d = c - (a*b)
-
- //(Row6): FMA operations of b6 with elements of indices from (6, 0) up to (7, 0)
- mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[5], mat_b_rearr[6][0]);//d = c - (a*b)
- mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[5], mat_b_rearr[7][0]);//d = c - (a*b)
-
- //(Row7): FMA operations of b7 with elements of index (7, 0)
- mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[6], mat_b_rearr[7][0]);//d = c - (a*b)
-
- ////////////////////////////////////////////////////////////////////////////////
-
- //Store the computed B columns
- _mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]);
- _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]);
- _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]);
-
- //i += cs_b_offset[6];
- //ptr_b_dup += cs_b_offset[6];
- i += 8;
- ptr_b_dup += 8;
- }
-
- //c = 0;
- /***************** first set of 8 cols of B processing done *****************/
- ptr_b_dup = ptr_b;
- i3 = 0;
- i1 = 0;
- //Start loop for cols of B to be processed in size of blk_width
- for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row
- {
- ptr_l += 8;
- //ptr_b += j;
- //ptr_b_dup += 8;
- ptr_b_dup += cs_b_offset[6];
- i1 += cs_b_offset[6];
- i3 += cs_l_offset[6];
-
- i = 0;
- i2 = 0;
- for (k = 0; k < numCols_b; k += 8)
- {
- i = i1 + k;
- //Read 8 columns of the B block to be solved
- mat_b_rearr[i2][0] = _mm256_loadu_ps((float const *)ptr_b + i);
- mat_b_rearr[i2][1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
- mat_b_rearr[i2][2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
- mat_b_rearr[i2][3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
- mat_b_rearr[i2][4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
- mat_b_rearr[i2][5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
- mat_b_rearr[i2][6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
- mat_b_rearr[i2][7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
- i2++;
- }
-
- i = 0;
- i2 = 0;
- for (l = 0; l < j; l += 8) // move across m
- {
- //Broadcast A8,0 to A15,0 to registers
- mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i));
- mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1));
- mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
- mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
- mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
- mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
- mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
- mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
-
- //Broadcast A8,1 to A15,1 to registers
- mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i));
- mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 1));
- mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 2));
- mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 3));
- mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 4));
- mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 5));
- mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 6));
- mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 7));
-
- //Broadcast A8,2 to A15,2 to registers
- mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i));
- mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 1));
- mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 2));
- mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 3));
- mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 4));
- mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 5));
- mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 6));
- mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 7));
-
- //Broadcast A8,3 to A15,3 to registers
- mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i));
- mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 1));
- mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 2));
- mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 3));
- mat_a_blk_elems[28] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 4));
- mat_a_blk_elems[29] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 5));
- mat_a_blk_elems[30] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 6));
- mat_a_blk_elems[31] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 7));
-
- // _mm256_permute2f128_ps()
-
- //Broadcast A8,4 to A15,4 to registers
- mat_a_blk_elems[32] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i));
- mat_a_blk_elems[33] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 1));
- mat_a_blk_elems[34] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 2));
- mat_a_blk_elems[35] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 3));
- mat_a_blk_elems[36] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 4));
- mat_a_blk_elems[37] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 5));
- mat_a_blk_elems[38] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 6));
- mat_a_blk_elems[39] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 7));
-
- //Broadcast A8,5 to A15,5 to registers
- mat_a_blk_elems[40] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i));
- mat_a_blk_elems[41] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 1));
- mat_a_blk_elems[42] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 2));
- mat_a_blk_elems[43] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 3));
- mat_a_blk_elems[44] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 4));
- mat_a_blk_elems[45] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 5));
- mat_a_blk_elems[46] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 6));
- mat_a_blk_elems[47] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 7));
-
- //Broadcast A8,6 to A15,6 to registers
- mat_a_blk_elems[48] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i));
- mat_a_blk_elems[49] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 1));
- mat_a_blk_elems[50] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 2));
- mat_a_blk_elems[51] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 3));
- mat_a_blk_elems[52] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 4));
- mat_a_blk_elems[53] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 5));
- mat_a_blk_elems[54] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 6));
- mat_a_blk_elems[55] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 7));
-
- //Broadcast A8,7 to A15,7 to registers
- mat_a_blk_elems[56] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i));
- mat_a_blk_elems[57] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 1));
- mat_a_blk_elems[58] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 2));
- mat_a_blk_elems[59] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 3));
- mat_a_blk_elems[60] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 4));
- mat_a_blk_elems[61] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 5));
- mat_a_blk_elems[62] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 6));
- mat_a_blk_elems[63] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 7));
-
- i += cs_l_offset[6];
-
- for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m)
- {
- /////////////////// Partial Lower 8x8 block trsm of B
-
- i4 = i2 + k;
- //Read the current 8 columns of the 8x8 block of B being updated
- mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4);
- mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b));
- mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0]));
- mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1]));
- mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2]));
- mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3]));
- mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4]));
- mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5]));
-
- i4 = k >> 3;
-
- //(Row8): FMA operations of b8 with elements of column 0 of the current block
- mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[i4][0]);//d = c - (a*b)
- mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[i4][1]);//d = c - (a*b)
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[1], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[1], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[1], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[2], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[2], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[2], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[2], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[2], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[2], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[3], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[3], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[3], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[3], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[28], mat_b_col[3], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[29], mat_b_col[3], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[30], mat_b_col[3], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[31], mat_b_col[3], 
mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[32], mat_b_col[4], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[33], mat_b_col[4], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[34], mat_b_col[4], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[35], mat_b_col[4], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[36], mat_b_col[4], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[37], mat_b_col[4], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[38], mat_b_col[4], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[39], mat_b_col[4], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[40], mat_b_col[5], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[41], mat_b_col[5], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[42], mat_b_col[5], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[43], mat_b_col[5], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[44], mat_b_col[5], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[45], mat_b_col[5], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[46], mat_b_col[5], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[47], mat_b_col[5], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[48], mat_b_col[6], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[49], mat_b_col[6], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[50], mat_b_col[6], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[51], mat_b_col[6], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[52], mat_b_col[6], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[53], mat_b_col[6], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[54], mat_b_col[6], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[55], mat_b_col[6], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[56], mat_b_col[7], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[57], mat_b_col[7], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[58], mat_b_col[7], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[59], mat_b_col[7], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[60], 
mat_b_col[7], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[61], mat_b_col[7], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[62], mat_b_col[7], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[63], mat_b_col[7], mat_b_rearr[i4][7]);//d = c - (a*b) - - //end loop of cols - } - i2 += cs_b_offset[6]; - } - - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Broadcast A21 to A71 to registers - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Broadcast A76 to register - mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - - k = 0; - for (i = 0; i < numCols_b; i+=8) - { - /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A - - //(Row0): already done - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[k][1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[k][0], mat_b_rearr[k][1]);//d = c - (a*b) - mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[k][0], mat_b_rearr[k][2]);//d = c - (a*b) - mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[k][0], mat_b_rearr[k][3]);//d = c - (a*b) - 
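Each "(RowN)" group above is one step of a rank-8 update: a block of already-solved values is multiplied by broadcast elements of A and subtracted from the eight accumulators with _mm256_fnmadd_ps, the "d = c - (a*b)" of the comments. A scalar model of the whole update follows; the array shapes and names are hypothetical and deliberately simplified, not the kernel's exact index mapping:

/* Scalar sketch of the rank-8 elimination: for every right-hand-side
 * column n, subtract the contribution of eight solved values from the
 * eight accumulators still being built (d = c - a*b). */
static void eliminate_8x8(const float a[8][8],   /* 8x8 block of the triangle */
                          const float xs[8][8],  /* solved values, per column */
                          float acc[8][8])       /* accumulators, per column  */
{
    for (int n = 0; n < 8; n++)
        for (int j = 0; j < 8; j++)
            for (int i = 0; i < 8; i++)
                acc[n][i] -= a[j][i] * xs[n][j];
}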
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[k][0], mat_b_rearr[k][4]);//d = c - (a*b) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[k][0], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[k][0], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[k][0], mat_b_rearr[k][7]);//d = c - (a*b) - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_rearr[k][1], mat_b_rearr[k][2]);//d = c - (a*b) - mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_rearr[k][1], mat_b_rearr[k][3]);//d = c - (a*b) - mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_rearr[k][1], mat_b_rearr[k][4]);//d = c - (a*b) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_rearr[k][1], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_rearr[k][1], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_rearr[k][1], mat_b_rearr[k][7]);//d = c - (a*b) - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_rearr[k][2], mat_b_rearr[k][3]);//d = c - (a*b) - mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_rearr[k][2], mat_b_rearr[k][4]);//d = c - (a*b) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_rearr[k][2], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_rearr[k][2], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_rearr[k][2], mat_b_rearr[k][7]);//d = c - (a*b) - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_rearr[k][3], mat_b_rearr[k][4]);//d = c - (a*b) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_rearr[k][3], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_rearr[k][3], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_rearr[k][3], mat_b_rearr[k][7]);//d = c - (a*b) - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_rearr[k][4], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_rearr[k][4], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_rearr[k][4], mat_b_rearr[k][7]);//d = c - (a*b) - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_rearr[k][5], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_rearr[k][5], mat_b_rearr[k][7]);//d = c - (a*b) - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_rearr[k][6], mat_b_rearr[k][7]);//d = c - (a*b) - - //////////////////////////////////////////////////////////////////////////////// - - //Store the computed B columns - - _mm256_storeu_ps((float *)ptr_b_dup + i, mat_b_rearr[k][0]); - _mm256_storeu_ps((float 
*)(ptr_b_dup + (cs_b) + i), mat_b_rearr[k][1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i), mat_b_rearr[k][2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i), mat_b_rearr[k][3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i), mat_b_rearr[k][4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i), mat_b_rearr[k][5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i), mat_b_rearr[k][6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i), mat_b_rearr[k][7]); - //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); - k++; - } - - - } - ///////////////////loop ends ///////////////////// -} - -static void trsm_XAtB_block_allSmallSizedMatrices_alpha_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha) -{ - //float ones = 1.0; - int i, i1, i2, i3, i4, j, k, l; - int cs_b_offset[7]; - int cs_l_offset[7]; - float *ptr_b_dup; - - //57 number of ymm(256 bits) registers used - __m256 mat_b_col[8]; - __m256 mat_b_rearr[16][8]; - //__m256 mat_a_cols_rearr[8]; - __m256 mat_a_blk_elems[64]; - //__m256 mat_a_diag_inv[8]; - //__m256 reciprocal_diags[2]; - __m256 alphaReg; - alphaReg = _mm256_broadcast_ss((float const *)&alpha); - - // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // - - //L matrix offsets - cs_l_offset[0] = (cs_l << 1); - cs_l_offset[1] = cs_l + cs_l_offset[0]; - cs_l_offset[2] = (cs_l << 2); - cs_l_offset[3] = cs_l + cs_l_offset[2]; - cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; - cs_l_offset[5] = cs_l + cs_l_offset[4]; - cs_l_offset[6] = (cs_l_offset[5] + cs_l); - - cs_b_offset[0] = (cs_b << 1); - cs_b_offset[1] = cs_b + cs_b_offset[0]; - cs_b_offset[2] = (cs_b << 2); - cs_b_offset[3] = cs_b + cs_b_offset[2]; - cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; - cs_b_offset[5] = cs_b + cs_b_offset[4]; - cs_b_offset[6] = (cs_b_offset[5] + cs_b); - - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - - //Broadcast A21 to A71 to registers - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); - mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); - mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); - mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); - mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); - mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); - mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); - mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); - mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); - mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); - - //Broadcast A43 to A73 
to registers - mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); - mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); - mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); - mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); - mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); - mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); - mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); - - //Broadcast A76 to register - mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); - - - /***************** first set of 8 rows of B processing starts *****************/ - ptr_b_dup = ptr_b; - i = 0; - for (j = 0; j < numCols_b; j += 8) - { - /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A - //read 8x8 block of B into registers - mat_b_rearr[0][0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[1][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[2][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[3][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[4][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[5][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[6][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[7][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - mat_b_rearr[0][0] = _mm256_mul_ps(mat_b_rearr[0][0], alphaReg); - mat_b_rearr[1][0] = _mm256_mul_ps(mat_b_rearr[1][0], alphaReg); - mat_b_rearr[2][0] = _mm256_mul_ps(mat_b_rearr[2][0], alphaReg); - mat_b_rearr[3][0] = _mm256_mul_ps(mat_b_rearr[3][0], alphaReg); - mat_b_rearr[4][0] = _mm256_mul_ps(mat_b_rearr[4][0], alphaReg); - mat_b_rearr[5][0] = _mm256_mul_ps(mat_b_rearr[5][0], alphaReg); - mat_b_rearr[6][0] = _mm256_mul_ps(mat_b_rearr[6][0], alphaReg); - mat_b_rearr[7][0] = _mm256_mul_ps(mat_b_rearr[7][0], alphaReg); - - //(Row0) - mat_b_col[0] = mat_b_rearr[0][0]; - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[1][0]);//d = c - (a*b) - mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[2][0]);//d = c - (a*b) - mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[3][0]);//d = c - (a*b) - mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[4][0]);//d = c - (a*b) - mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[5][0]);//d = c - (a*b) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[7][0]);//d = c - (a*b) - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], 
mat_b_rearr[2][0]);//d = c - (a*b) - mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[3][0]);//d = c - (a*b) - mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[4][0]);//d = c - (a*b) - mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[5][0]);//d = c - (a*b) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[7][0]);//d = c - (a*b) - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[2], mat_b_rearr[3][0]);//d = c - (a*b) - mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[2], mat_b_rearr[4][0]);//d = c - (a*b) - mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[2], mat_b_rearr[5][0]);//d = c - (a*b) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[7][0]);//d = c - (a*b) - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[3], mat_b_rearr[4][0]);//d = c - (a*b) - mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[3], mat_b_rearr[5][0]);//d = c - (a*b) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[3], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[3], mat_b_rearr[7][0]);//d = c - (a*b) - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[4], mat_b_rearr[5][0]);//d = c - (a*b) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[4], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[4], mat_b_rearr[7][0]);//d = c - (a*b) - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[5], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[5], mat_b_rearr[7][0]);//d = c - (a*b) - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[6], mat_b_rearr[7][0]);//d = c - (a*b) - - //////////////////////////////////////////////////////////////////////////////// - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]); - - //i += cs_b_offset[6]; - //ptr_b_dup += cs_b_offset[6]; - i += 8; - ptr_b_dup += 8; - } - - //c = 0; - /***************** first set of 8 cols of B processing done *****************/ - ptr_b_dup = ptr_b; - i3 = 0; - i1 = 0; - //Start loop for cols of B to be processed in size 
of blk_width - for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row - { - ptr_l += 8; - //ptr_b += j; - //ptr_b_dup += 8; - ptr_b_dup += cs_b_offset[6]; - i1 += cs_b_offset[6]; - i3 += cs_l_offset[6]; - - i = 0; - i2 = 0; - for (k = 0; k < numCols_b; k += 8) - { - i = i1 + k; - //Read 8 cols of B columns of Block-to-be-solved - mat_b_rearr[i2][0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[i2][1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[i2][2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[i2][3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[i2][4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[i2][5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[i2][6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[i2][7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - mat_b_rearr[i2][0] = _mm256_mul_ps(mat_b_rearr[i2][0], alphaReg); - mat_b_rearr[i2][1] = _mm256_mul_ps(mat_b_rearr[i2][1], alphaReg); - mat_b_rearr[i2][2] = _mm256_mul_ps(mat_b_rearr[i2][2], alphaReg); - mat_b_rearr[i2][3] = _mm256_mul_ps(mat_b_rearr[i2][3], alphaReg); - mat_b_rearr[i2][4] = _mm256_mul_ps(mat_b_rearr[i2][4], alphaReg); - mat_b_rearr[i2][5] = _mm256_mul_ps(mat_b_rearr[i2][5], alphaReg); - mat_b_rearr[i2][6] = _mm256_mul_ps(mat_b_rearr[i2][6], alphaReg); - mat_b_rearr[i2][7] = _mm256_mul_ps(mat_b_rearr[i2][7], alphaReg); - - i2++; - } - - i = 0; - i2 = 0; - for (l = 0; l < j; l += 8) // move across m - { - //Broadcast A8,0 to A15,0 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - - //Broadcast A21 to A71 to registers - mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i)); - mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 1)); - mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 2)); - mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 3)); - mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 4)); - mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 5)); - mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 6)); - mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 7)); - - //Broadcast A8,2 to A15,2 to registers - mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i)); - mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 1)); - mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 2)); - mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 3)); - mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 4)); - mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + 
cs_l_offset[0] + i + 5)); - mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 6)); - mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 7)); - - //Broadcast A8,3 to A15,3 to registers - mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i)); - mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 1)); - mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 2)); - mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 3)); - mat_a_blk_elems[28] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 4)); - mat_a_blk_elems[29] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 5)); - mat_a_blk_elems[30] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 6)); - mat_a_blk_elems[31] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 7)); - - // _mm256_permute2f128_ps() - - //Broadcast A8,4 to A15,4 to registers - mat_a_blk_elems[32] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i)); - mat_a_blk_elems[33] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 1)); - mat_a_blk_elems[34] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 2)); - mat_a_blk_elems[35] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 3)); - mat_a_blk_elems[36] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 4)); - mat_a_blk_elems[37] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 5)); - mat_a_blk_elems[38] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 6)); - mat_a_blk_elems[39] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 7)); - - //Broadcast A8,5 to A15,5 to registers - mat_a_blk_elems[40] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i)); - mat_a_blk_elems[41] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 1)); - mat_a_blk_elems[42] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 2)); - mat_a_blk_elems[43] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 3)); - mat_a_blk_elems[44] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 4)); - mat_a_blk_elems[45] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 5)); - mat_a_blk_elems[46] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 6)); - mat_a_blk_elems[47] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 7)); - - //Broadcast A8,6 to A15,6 to registers - mat_a_blk_elems[48] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i)); - mat_a_blk_elems[49] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 1)); - mat_a_blk_elems[50] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 2)); - mat_a_blk_elems[51] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 3)); - mat_a_blk_elems[52] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 4)); - mat_a_blk_elems[53] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 5)); - mat_a_blk_elems[54] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 6)); - mat_a_blk_elems[55] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 7)); - - //Broadcast A8,7 to A15,7 to registers - mat_a_blk_elems[56] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i)); - 
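The _alpha_unitDiag variant being removed here differs from the plain path in two ways: the BLAS alpha is folded into the first touch of each panel of B (alphaReg is broadcast once up front and every freshly loaded column is scaled by it before any elimination), and, because the diagonal is unit, no reciprocal-diagonal multiply appears anywhere in the solve. A small sketch of that pre-scaling step, with a hypothetical helper name:

#include <immintrin.h>

/* Sketch: broadcast alpha once, then scale each 8-wide column slice of
 * B in place before the solve, as the _mm256_mul_ps(..., alphaReg)
 * sequences above do. */
static inline void scale_panel_by_alpha(float *b, int cs_b, int ncols, float alpha)
{
    __m256 alpha_v = _mm256_broadcast_ss(&alpha);
    for (int j = 0; j < ncols; j++)
    {
        __m256 col = _mm256_loadu_ps(b + j * cs_b);
        _mm256_storeu_ps(b + j * cs_b, _mm256_mul_ps(col, alpha_v));
    }
}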
mat_a_blk_elems[57] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 1)); - mat_a_blk_elems[58] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 2)); - mat_a_blk_elems[59] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 3)); - mat_a_blk_elems[60] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 4)); - mat_a_blk_elems[61] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 5)); - mat_a_blk_elems[62] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 6)); - mat_a_blk_elems[63] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 7)); - - i += cs_l_offset[6]; - - for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) - { - /////////////////// Partial Lower 8x8 block trsm of B - - i4 = i2 + k; - //Read current 8 cols of B columns from specified 8x8 current-block of B - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); - - i4 = k >> 3; - - //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[1], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[1], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[1], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row10): FMA operations of b2 with elements of indices 
from (2, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[2], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[2], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[2], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[2], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[2], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[2], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[3], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[3], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[3], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[3], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[28], mat_b_col[3], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[29], mat_b_col[3], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[30], mat_b_col[3], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[31], mat_b_col[3], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[32], mat_b_col[4], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[33], mat_b_col[4], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[34], mat_b_col[4], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[35], mat_b_col[4], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[36], mat_b_col[4], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[37], mat_b_col[4], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[38], mat_b_col[4], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[39], mat_b_col[4], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[40], mat_b_col[5], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[41], mat_b_col[5], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[42], mat_b_col[5], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[43], mat_b_col[5], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[44], mat_b_col[5], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = 
_mm256_fnmadd_ps(mat_a_blk_elems[45], mat_b_col[5], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[46], mat_b_col[5], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[47], mat_b_col[5], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[48], mat_b_col[6], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[49], mat_b_col[6], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[50], mat_b_col[6], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[51], mat_b_col[6], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[52], mat_b_col[6], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[53], mat_b_col[6], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[54], mat_b_col[6], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[55], mat_b_col[6], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[56], mat_b_col[7], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[57], mat_b_col[7], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[58], mat_b_col[7], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[59], mat_b_col[7], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[60], mat_b_col[7], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[61], mat_b_col[7], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[62], mat_b_col[7], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[63], mat_b_col[7], mat_b_rearr[i4][7]);//d = c - (a*b) - - //end loop of cols - } - i2 += cs_b_offset[6]; - } - - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Broadcast A21 to A71 to registers - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[14] = 
_mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Broadcast A76 to register - mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - - k = 0; - for (i = 0; i < numCols_b; i+=8) - { - /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A - - //(Row0): already done - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[k][1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[k][0], mat_b_rearr[k][1]);//d = c - (a*b) - mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[k][0], mat_b_rearr[k][2]);//d = c - (a*b) - mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[k][0], mat_b_rearr[k][3]);//d = c - (a*b) - mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[k][0], mat_b_rearr[k][4]);//d = c - (a*b) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[k][0], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[k][0], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[k][0], mat_b_rearr[k][7]);//d = c - (a*b) - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_rearr[k][1], mat_b_rearr[k][2]);//d = c - (a*b) - mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_rearr[k][1], mat_b_rearr[k][3]);//d = c - (a*b) - mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_rearr[k][1], mat_b_rearr[k][4]);//d = c - (a*b) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_rearr[k][1], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_rearr[k][1], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_rearr[k][1], mat_b_rearr[k][7]);//d = c - (a*b) - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_rearr[k][2], mat_b_rearr[k][3]);//d = c - (a*b) - mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_rearr[k][2], mat_b_rearr[k][4]);//d = c - (a*b) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_rearr[k][2], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = 
_mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_rearr[k][2], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_rearr[k][2], mat_b_rearr[k][7]);//d = c - (a*b) - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_rearr[k][3], mat_b_rearr[k][4]);//d = c - (a*b) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_rearr[k][3], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_rearr[k][3], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_rearr[k][3], mat_b_rearr[k][7]);//d = c - (a*b) - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_rearr[k][4], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_rearr[k][4], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_rearr[k][4], mat_b_rearr[k][7]);//d = c - (a*b) - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_rearr[k][5], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_rearr[k][5], mat_b_rearr[k][7]);//d = c - (a*b) - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_rearr[k][6], mat_b_rearr[k][7]);//d = c - (a*b) - - //////////////////////////////////////////////////////////////////////////////// - - //Store the computed B columns - - _mm256_storeu_ps((float *)ptr_b_dup + i, mat_b_rearr[k][0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b) + i), mat_b_rearr[k][1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i), mat_b_rearr[k][2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i), mat_b_rearr[k][3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i), mat_b_rearr[k][4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i), mat_b_rearr[k][5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i), mat_b_rearr[k][6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i), mat_b_rearr[k][7]); - //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); - k++; - } - - - } - ///////////////////loop ends ///////////////////// -} -#endif //OPT_CACHE_BLOCKING_L1 - -//////////////////////////// AutX=B /////////////////////// -static void trsm_AutXB_block_allSmallSizedMatrices(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b) -{ - float ones = 1.0; - int i, i1, i2, i3, i4, j, k, l, r; - int cs_b_offset[7]; - int cs_l_offset[7]; - float *ptr_b_dup, *ptr_l_dup; - - //57 number of ymm(256 bits) registers used - __m256 mat_b_col[8]; - __m256 mat_b_rearr[8]; - __m256 mat_a_blk_elems[8]; - __m256 mat_a_diag_inv[8]; - __m256 reciprocal_diags[2]; - - reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones)); - - // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // - - //L matrix offsets - cs_l_offset[0] = (cs_l << 1); - cs_l_offset[1] = cs_l + cs_l_offset[0]; - cs_l_offset[2] = (cs_l << 2); - cs_l_offset[3] = cs_l + cs_l_offset[2]; - cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; - cs_l_offset[5] = cs_l + 
cs_l_offset[4];
- cs_l_offset[6] = (cs_l_offset[5] + cs_l);
-
- //read diag elems of L 16x16 block
- mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l);
- mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l);
- mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]);
- mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]);
- mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]);
- mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]);
- mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]);
- mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]);
-
- cs_b_offset[0] = (cs_b << 1);
- cs_b_offset[1] = cs_b + cs_b_offset[0];
- cs_b_offset[2] = (cs_b << 2);
- cs_b_offset[3] = cs_b + cs_b_offset[2];
- cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
- cs_b_offset[5] = cs_b + cs_b_offset[4];
- cs_b_offset[6] = (cs_b_offset[5] + cs_b);
-
- reciprocal_diags[1] = reciprocal_diags[0];
-
- //pack first 8 diags together
- mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1
- mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3
- mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5
- mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7
- mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3
- mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7
- mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7
-
- //reciprocal of diagonal elements 0,1,2,3,4,5,6,7
- reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]);
-#if 0
- //Broadcast A10 to A70 to registers
- mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
- mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
- mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
- mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
- mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
- mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
- mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
-
- //Broadcast A21 to A71 to registers
- mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
- mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
- mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
- mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
- mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
- mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
-
- //Broadcast A32 to A72 to registers
- mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
- mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
- mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
- mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
- mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
-
- //Broadcast A43 to A73 to registers
- mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
- mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
- mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
- mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
-
- //Broadcast A54 to A74 to registers
- mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
- mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
- mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
-
- //Broadcast A65 to A75 to registers
- mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
- mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
-
- //Broadcast A76 to register
- mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
-#endif
- //extract diag a00 from a
- mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
- mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
- //mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]);
- //extract diag a11 from a
- mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
- mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
- //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]);
- //extract diag a22 from a
- mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
- mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
- //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]);
- //extract diag a33 from a
- mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
- mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
- //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]);
- //extract diag a44 from a
- mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
- mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
- //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]);
- //extract diag a55 from a
- mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
- mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
- //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]);
- //extract diag a66 from a
- mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
- mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
- //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]);
- //extract diag a77 from a
- mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
- mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
- //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]);
-
-
- /***************** first set of 8 rows of B processing starts *****************/
- ptr_b_dup = ptr_b;
- i = 0;
- for (j = 0; j < numCols_b; j += 8)
- {
- /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A
- //read 8x8 block of B into registers
- mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i);
- mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
- mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
- mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
- mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
- mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
- mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
- mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
-
- /* transpose steps start */
- ////unpacklow////
- mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
- mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
- mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
- mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
-
- //Rearrange low elements
-#if REARRANGE_SHFL == 1
- mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
- mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
- mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
- mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
-#else
- mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
- mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
- mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
- mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
- mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
- mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
-#endif
- //Merge rearranged low elements into complete rows
- mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
- mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
- mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
- mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
-
- ////unpackhigh////
- mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
- mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
- mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
- mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
-
- //Rearrange high elements
-#if REARRANGE_SHFL == 1
- mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
- mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
- mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
- mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
-#else
- mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
- mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
- mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
- mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
- mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
- mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
-#endif
-
- //Merge rearranged high elements into complete rows
- mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
- mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
- mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
- mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
- /* transpose steps end */
-
-
- //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
- mat_b_col[0] =
_mm256_mul_ps(mat_b_col[0], mat_a_diag_inv[0]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5])); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], mat_a_diag_inv[1]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[0])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[1])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[2])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[3])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[4])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[5])); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], mat_a_diag_inv[2]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[1])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[2])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[3])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[4])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[5])); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b) - 
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], mat_a_diag_inv[3]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[2])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[3])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[4])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[5])); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], mat_a_diag_inv[4]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[3])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[4])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[5])); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], mat_a_diag_inv[5]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[4])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[5])); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], mat_a_diag_inv[6]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 6 + cs_l_offset[5])); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], mat_a_diag_inv[7]); - - //////////////////////////////////////////////////////////////////////////////// - - /* transpose steps start */ - ////unpacklow//// - mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], 
mat_b_col[1]); - mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); - mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); - mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - ////unpackhigh//// - mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); - mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); - mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); - mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - /* transpose steps end */ - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_b_rearr[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_rearr[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_rearr[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_rearr[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_rearr[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_rearr[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_rearr[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_rearr[7]); - - i += cs_b_offset[6]; - ptr_b_dup += cs_b_offset[6]; - //i += 8; - //ptr_b_dup += 8; - } - - 
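
(An aside on the pattern above, not part of the patch itself.) Both kernels in this hunk repeatedly load an 8x8 panel of column-major B with one ymm register per column, transpose it in registers so each register holds one row, run the row-wise substitution steps, and transpose back before storing. The REARRANGE_SHFL == 1 path is the classic AVX unpack/shuffle/permute2f128 transpose; below is a minimal self-contained sketch of that sequence (the helper name transpose8x8_ps is ours, not the patch's, and the register pairing differs from the kernel's hand allocation):

#include <immintrin.h>

/* Transpose an 8x8 block of floats held in r[0..7], one row per register. */
static void transpose8x8_ps(__m256 r[8])
{
    __m256 t[8], s[8];

    /* Stage 1: interleave adjacent rows, four floats at a time. */
    t[0] = _mm256_unpacklo_ps(r[0], r[1]);
    t[1] = _mm256_unpackhi_ps(r[0], r[1]);
    t[2] = _mm256_unpacklo_ps(r[2], r[3]);
    t[3] = _mm256_unpackhi_ps(r[2], r[3]);
    t[4] = _mm256_unpacklo_ps(r[4], r[5]);
    t[5] = _mm256_unpackhi_ps(r[4], r[5]);
    t[6] = _mm256_unpacklo_ps(r[6], r[7]);
    t[7] = _mm256_unpackhi_ps(r[6], r[7]);

    /* Stage 2: gather 4-element column groups (imm 0x44 picks the low
       pairs, 0xEE the high pairs, within each 128-bit lane). */
    s[0] = _mm256_shuffle_ps(t[0], t[2], 0x44);
    s[1] = _mm256_shuffle_ps(t[0], t[2], 0xEE);
    s[2] = _mm256_shuffle_ps(t[1], t[3], 0x44);
    s[3] = _mm256_shuffle_ps(t[1], t[3], 0xEE);
    s[4] = _mm256_shuffle_ps(t[4], t[6], 0x44);
    s[5] = _mm256_shuffle_ps(t[4], t[6], 0xEE);
    s[6] = _mm256_shuffle_ps(t[5], t[7], 0x44);
    s[7] = _mm256_shuffle_ps(t[5], t[7], 0xEE);

    /* Stage 3: swap 128-bit lanes to complete the columns
       (imm 0x20 joins the low lanes, 0x31 the high lanes). */
    r[0] = _mm256_permute2f128_ps(s[0], s[4], 0x20);
    r[1] = _mm256_permute2f128_ps(s[1], s[5], 0x20);
    r[2] = _mm256_permute2f128_ps(s[2], s[6], 0x20);
    r[3] = _mm256_permute2f128_ps(s[3], s[7], 0x20);
    r[4] = _mm256_permute2f128_ps(s[0], s[4], 0x31);
    r[5] = _mm256_permute2f128_ps(s[1], s[5], 0x31);
    r[6] = _mm256_permute2f128_ps(s[2], s[6], 0x31);
    r[7] = _mm256_permute2f128_ps(s[3], s[7], 0x31);
}

The #else branch in the hunk computes the same permutation with two shuffles and four blends per rearrange stage instead of four shuffles, presumably to spread the work across more execution ports; the two variants are interchangeable. For orientation, the substitution steps that the transposed rows then go through follow this scalar reference for a column-major lower-triangular solve L * X = B (names and layout ours; the kernels additionally precompute 1.0f / L[k][k] for a whole 8-row block with one _mm256_div_ps and broadcast each reciprocal, rather than dividing in the inner loop):

/* Solve L * X = B in place; B is n x ncols, column-major. Illustrative only. */
static void trsm_lower_ref(const float *L, int cs_l, float *B, int cs_b,
                           int n, int ncols)
{
    for (int c = 0; c < ncols; c++) {
        float *b = B + c * cs_b;
        for (int k = 0; k < n; k++) {
            b[k] /= L[k + k * cs_l];            /* reciprocal-diagonal multiply */
            for (int r = k + 1; r < n; r++)     /* eliminate from rows below    */
                b[r] -= L[r + k * cs_l] * b[k]; /* the fnmadd "d = c - a*b"     */
        }
    }
}
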
//c = 0; - /***************** first set of 8 cols of B processing done *****************/ - ptr_b_dup = ptr_b; - i3 = 0; - i1 = 0; - //Start loop for cols of B to be processed in size of blk_width - for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row - { - ptr_l += cs_l_offset[6]; - - //Read next 8x8 block of A to get diag elements - i3 += 8; - mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l + i3); - mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l); - mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]); - mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]); - mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]); - mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]); - mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]); - mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]); - - //pack 8 diags of A together - reciprocal_diags[0] = reciprocal_diags[1]; - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1 - mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5 - mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 - - //reciprocal of diagnal elements of A :- 0,1,2,3,4,5,6,7 - reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); - - //ptr_b += j; - //ptr_b_dup += 8; - ptr_b_dup += 8; - i1 += 8; - i = i1; - i2 = 0; - - //extract diag a00 from a - mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - //mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]); - - //extract diag a11 from a - mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); - - //extract diag a22 from a - mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); - - //extract diag a33 from a - mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); - - //extract diag a44 from a - mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); - - //extract diag a55 from a - mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - //mat_a_diag_inv[5] = 
_mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); - - //extract diag a66 from a - mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); - - //extract diag a77 from a - mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); - - for (r = 0; r < numCols_b; r += GEMM_BLK_V1) - { -#if GEMM_ACCUM_A - //Read 8 cols of B columns of Block-to-be-solved - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - /* transpose steps start */ - ////unpacklow//// - mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); - mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); - mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); - mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - ////unpackhigh//// - mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); - mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); - mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); - mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = 
_mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - /* transpose steps end */ -#endif - //i = 0; - ptr_l_dup = ptr_l; - i4 = i2; - for (l = 0; l < j; l += 8) // move across m - { - //for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) - //{ - /////////////////// Partial Lower 8x8 block trsm of B - //Read current 8 cols of B columns from specified 8x8 current-block of B - mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_b + i4); - mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); - mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); - mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); - mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); - mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); - mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); - mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); - - /* transpose steps start */ - ////unpacklow//// - mat_b_col[0] = _mm256_unpacklo_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]); - mat_b_col[1] = _mm256_unpacklo_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]); - mat_b_col[2] = _mm256_unpacklo_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]); - mat_b_col[3] = _mm256_unpacklo_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - ////unpackhigh//// - mat_a_blk_elems[0] = _mm256_unpackhi_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]); - mat_a_blk_elems[1] = _mm256_unpackhi_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]); - mat_a_blk_elems[2] = _mm256_unpackhi_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]); - mat_a_blk_elems[3] = _mm256_unpackhi_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]); - - 
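/* Note (editorial, not in the patch): the 8x8 panel loaded and transposed in
   this l-loop is a block of B that earlier j-iterations already solved; the
   single-element broadcasts of L that follow feed _mm256_fnmadd_ps updates,
   i.e. the GEMM-style elimination (current panel -= L_offdiag * solved panel)
   that precedes the triangular solve of the current 8x8 block. */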
//Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_a_blk_elems[4] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x44); - mat_a_blk_elems[5] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xEE); - mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x44); - mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xEE); -#else - mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x4E); - mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x4E); - mat_a_blk_elems[4] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[6], 0xCC); - mat_a_blk_elems[5] = _mm256_blend_ps(mat_a_blk_elems[1], mat_a_blk_elems[6], 0x33); - mat_a_blk_elems[6] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[7], 0xCC); - mat_a_blk_elems[7] = _mm256_blend_ps(mat_a_blk_elems[3], mat_a_blk_elems[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_col[2] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x20); - mat_b_col[6] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x31); - mat_b_col[3] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x20); - mat_b_col[7] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x31); - /* transpose steps end */ - - //Broadcast A8,0 to A15,0 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - //i4 = k >> 3; - ptr_l_dup++; - -#if GEMM_ACCUM_A - //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]); - mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]); - mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]); - mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]); - mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]); - mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]); - mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]); - mat_b_rearr[7] = 
_mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]); -#endif - //Broadcast A21 to A71 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,2 to A15,2 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], 
mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,3 to A15,3 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] 
= _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,4 to A15,4 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,5 to A15,5 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = 
_mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,6 to A15,6 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], 
mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,7 to A15,7 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) -#endif - //end loop of cols - //} - //i2 += cs_b_offset[6]; - i4 += 8; - } - //trsm solve - - k = 0; - //for (i2 = 0; i2 < numCols_b; i2 += 8) - //{ - //i2 = i1 + r; - /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A -#if !GEMM_ACCUM_A - //Read 8 cols of 
B columns of Block-to-be-solved - mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - /* transpose steps start */ - ////unpacklow//// - mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_col[7] = 
_mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - /* transpose steps end */ -#endif - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - //i += cs_l; - -#if GEMM_ACCUM_A - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); -#else - mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]); - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); -#endif - -#if GEMM_ACCUM_A - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]); - mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]); - mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]); - mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]); - mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]); - mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]); - mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A21 to A71 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[0])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[1])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[2])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[3])); - mat_a_blk_elems[4] = 
_mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[4])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[5])); - //i += cs_l; - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[1])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[2])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[3])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[4])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[5])); - //i += cs_l; - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[2])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[3])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[4])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[5])); - //i += cs_l; - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[3])); - 
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[4])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[5])); - //i += cs_l; - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[4])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[5])); - //i += cs_l; - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A76 to register - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 6 + cs_l_offset[5])); - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); - - //////////////////////////////////////////////////////////////////////////////// - - /* transpose steps start */ - ////unpacklow//// - mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_col[5] = 
_mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - /* transpose steps end */ - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup + i2, mat_b_col[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+i2), mat_b_col[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i2), mat_b_col[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i2), mat_b_col[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i2), mat_b_col[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i2), mat_b_col[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i2), mat_b_col[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i2), mat_b_col[7]); - //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); - k++; - //} - i += cs_b_offset[6]; - i2 += cs_b_offset[6]; - } - } //numRows of A - ///////////////////loop ends ///////////////////// -} - -static void trsm_AutXB_block_allSmallSizedMatrices_alpha(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha) -{ - float ones = 1.0; - int i, i1, i2, i3, i4, j, k, l, r; - int cs_b_offset[7]; - int cs_l_offset[7]; - float *ptr_b_dup, *ptr_l_dup; - - //57 number of ymm(256 bits) registers used - __m256 mat_b_col[8]; - __m256 mat_b_rearr[8]; - __m256 mat_a_blk_elems[8]; - __m256 mat_a_diag_inv[8]; - __m256 reciprocal_diags[2]; - __m256 alphaReg; - - reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones)); - alphaReg = _mm256_broadcast_ss((float const *)&alpha); - - // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // - - //L matrix offsets - cs_l_offset[0] = (cs_l << 1); - cs_l_offset[1] = cs_l + cs_l_offset[0]; - cs_l_offset[2] = (cs_l << 2); - cs_l_offset[3] = cs_l + cs_l_offset[2]; - cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; - cs_l_offset[5] = cs_l + cs_l_offset[4]; - cs_l_offset[6] = (cs_l_offset[5] + cs_l); - - //read diag elems of L 16x16 block - mat_a_blk_elems[0] = _mm256_loadu_ps((float const 
-
-    //read diag elems of L 16x16 block
-    mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l);
-    mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l);
-    mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]);
-    mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]);
-    mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]);
-    mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]);
-    mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]);
-    mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]);
-
-    cs_b_offset[0] = (cs_b << 1);
-    cs_b_offset[1] = cs_b + cs_b_offset[0];
-    cs_b_offset[2] = (cs_b << 2);
-    cs_b_offset[3] = cs_b + cs_b_offset[2];
-    cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
-    cs_b_offset[5] = cs_b + cs_b_offset[4];
-    cs_b_offset[6] = (cs_b_offset[5] + cs_b);
-
-    reciprocal_diags[1] = reciprocal_diags[0];
-
-    //pack first 8 diags together
-    mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1
-    mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3
-    mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5
-    mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7
-    mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3
-    mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7
-    mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7
-
-    //reciprocal of diagonal elements 0,1,2,3,4,5,6,7
-    reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]);
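Editor's note: the three-level blend tree gathers all eight diagonal elements of the column-major 8x8 block into one ymm register, so a single vdivps yields all eight reciprocals; each solved row then costs one multiply instead of a division. A self-contained sketch of that gather under the same column-major layout (the test matrix values are illustrative only):

```c
#include <immintrin.h>
#include <stdio.h>

int main(void)
{
    float a[8][8] = {{0}};           /* a[c][r]: column c, row r */
    for (int c = 0; c < 8; c++)
        a[c][c] = (float)(c + 1);    /* diagonal 1..8, rest zero */

    __m256 col[8], t[4];
    for (int c = 0; c < 8; c++)
        col[c] = _mm256_loadu_ps(a[c]);

    /* Each blend keeps the diagonal lane of one column and takes the
       neighbouring column's diagonal lane from the other operand. */
    t[0] = _mm256_blend_ps(col[0], col[1], 0xAA); /* diag 0,1 */
    t[1] = _mm256_blend_ps(col[2], col[3], 0xAA); /* diag 2,3 */
    t[2] = _mm256_blend_ps(col[4], col[5], 0xAA); /* diag 4,5 */
    t[3] = _mm256_blend_ps(col[6], col[7], 0xAA); /* diag 6,7 */
    t[0] = _mm256_blend_ps(t[0], t[1], 0xCC);     /* diag 0..3 */
    t[2] = _mm256_blend_ps(t[2], t[3], 0xCC);     /* diag 4..7 */
    t[0] = _mm256_blend_ps(t[0], t[2], 0xF0);     /* diag 0..7 */

    /* One division produces all eight reciprocals at once. */
    __m256 inv = _mm256_div_ps(_mm256_set1_ps(1.0f), t[0]);

    float out[8];
    _mm256_storeu_ps(out, inv);
    for (int i = 0; i < 8; i++)
        printf("1/a(%d,%d) = %f\n", i, i, out[i]); /* 1, 0.5, ... 0.125 */
    return 0;
}
```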
-#if 0
-    //Broadcast A10 to A70 to registers
-    mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
-    mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
-    mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
-    mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
-    mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
-    mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
-    mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
-
-    //Broadcast A21 to A71 to registers
-    mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
-    mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
-    mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
-    mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
-    mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
-    mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
-
-    //Broadcast A32 to A72 to registers
-    mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
-    mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
-    mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
-    mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
-    mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
-
-    //Broadcast A43 to A73 to registers
-    mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
-    mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
-    mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
-    mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
-
-    //Broadcast A54 to A74 to registers
-    mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
-    mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
-    mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
-
-    //Broadcast A65 to A75 to registers
-    mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
-    mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
-
-    //Broadcast A76 to register
-    mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
-#endif
-    //extract diag a00 from a
-    mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
-    mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
-    //mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]);
-    //extract diag a11 from a
-    mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
-    mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
-    //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]);
-    //extract diag a22 from a
-    mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
-    mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
-    //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]);
-    //extract diag a33 from a
-    mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
-    mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
-    //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]);
-    //extract diag a44 from a
-    mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
-    mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
-    //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]);
-    //extract diag a55 from a
-    mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
-    mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
-    //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]);
-    //extract diag a66 from a
-    mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
-    mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
-    //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]);
-    //extract diag a77 from a
-    mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
-    mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
-    //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]);
-
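Editor's note: each reciprocal is then splatted across a full register with an in-lane _mm256_permute_ps followed by _mm256_permute2f128_ps to replicate the selected 128-bit half, keeping the broadcast entirely in registers with no memory round trip. A small sketch of the two-instruction splat for one low-half and one high-half lane (the input values are illustrative):

```c
#include <immintrin.h>
#include <stdio.h>

int main(void)
{
    __m256 v = _mm256_setr_ps(10, 11, 12, 13, 14, 15, 16, 17);

    /* Lane 2 (value 12): replicate lane 2 within each 128-bit half
       (selector 0xAA = 2,2,2,2), then copy the low half to both halves. */
    __m256 b2 = _mm256_permute_ps(v, 0xAA);
    b2 = _mm256_permute2f128_ps(b2, b2, 0x00);

    /* Lane 6 (value 16): same in-half selector, but keep the high half. */
    __m256 b6 = _mm256_permute_ps(v, 0xAA);
    b6 = _mm256_permute2f128_ps(b6, b6, 0x11);

    float lo[8], hi[8];
    _mm256_storeu_ps(lo, b2);
    _mm256_storeu_ps(hi, b6);
    printf("%f %f\n", lo[0], hi[7]);  /* 12.000000 16.000000 */
    return 0;
}
```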
-
-    /***************** first set of 8 rows of B processing starts *****************/
-    ptr_b_dup = ptr_b;
-    i = 0;
-    for (j = 0; j < numCols_b; j += 8)
-    {
-        /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A
-        //read 8x8 block of B into registers
-        mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i);
-        mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
-        mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
-        mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
-        mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
-        mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
-        mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
-        mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
-
-        /* transpose steps start */
-        ////unpacklow////
-        mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
-        mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
-        mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
-        mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
-
-        //Rearrange low elements
-#if REARRANGE_SHFL == 1
-        mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
-        mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
-        mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
-        mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
-#else
-        mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
-        mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
-        mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
-        mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
-        mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
-        mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
-#endif
-        //Merge rearranged low elements into complete rows
-        mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
-        mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
-        mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
-        mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
-
-        ////unpackhigh////
-        mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
-        mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
-        mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
-        mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
-
-        //Rearrange high elements
-#if REARRANGE_SHFL == 1
-        mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
-        mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
-        mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
-        mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
-#else
-        mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
-        mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
-        mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
-        mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
-        mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
-        mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
-#endif
-
-        //Merge rearranged high elements into complete rows
-        mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
-        mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
-        mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
-        mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
-        /* transpose steps end */
-
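Editor's note: the unpack/shuffle/permute2f128 sequence is the standard AVX 8x8 float transpose: the unpacks interleave adjacent row pairs, the 0x44/0xEE shuffles assemble 4-element column runs, and permute2f128 joins the 128-bit halves into full columns. A self-contained version of the REARRANGE_SHFL == 1 path, written as a reusable helper (the name transpose8x8 is mine, not from the patch):

```c
#include <immintrin.h>
#include <stdio.h>

static void transpose8x8(__m256 r[8])
{
    __m256 u[8], s[8];
    /* Interleave row pairs: low and high 32-bit element pairs. */
    u[0] = _mm256_unpacklo_ps(r[0], r[1]);
    u[1] = _mm256_unpackhi_ps(r[0], r[1]);
    u[2] = _mm256_unpacklo_ps(r[2], r[3]);
    u[3] = _mm256_unpackhi_ps(r[2], r[3]);
    u[4] = _mm256_unpacklo_ps(r[4], r[5]);
    u[5] = _mm256_unpackhi_ps(r[4], r[5]);
    u[6] = _mm256_unpacklo_ps(r[6], r[7]);
    u[7] = _mm256_unpackhi_ps(r[6], r[7]);
    /* Build 4-element column runs in each 128-bit half. */
    s[0] = _mm256_shuffle_ps(u[0], u[2], 0x44);
    s[1] = _mm256_shuffle_ps(u[0], u[2], 0xEE);
    s[2] = _mm256_shuffle_ps(u[1], u[3], 0x44);
    s[3] = _mm256_shuffle_ps(u[1], u[3], 0xEE);
    s[4] = _mm256_shuffle_ps(u[4], u[6], 0x44);
    s[5] = _mm256_shuffle_ps(u[4], u[6], 0xEE);
    s[6] = _mm256_shuffle_ps(u[5], u[7], 0x44);
    s[7] = _mm256_shuffle_ps(u[5], u[7], 0xEE);
    /* Stitch the halves into complete transposed rows. */
    r[0] = _mm256_permute2f128_ps(s[0], s[4], 0x20);
    r[1] = _mm256_permute2f128_ps(s[1], s[5], 0x20);
    r[2] = _mm256_permute2f128_ps(s[2], s[6], 0x20);
    r[3] = _mm256_permute2f128_ps(s[3], s[7], 0x20);
    r[4] = _mm256_permute2f128_ps(s[0], s[4], 0x31);
    r[5] = _mm256_permute2f128_ps(s[1], s[5], 0x31);
    r[6] = _mm256_permute2f128_ps(s[2], s[6], 0x31);
    r[7] = _mm256_permute2f128_ps(s[3], s[7], 0x31);
}

int main(void)
{
    float m[8][8], t[8][8];
    for (int i = 0; i < 8; i++)
        for (int j = 0; j < 8; j++) m[i][j] = i * 8.0f + j;
    __m256 r[8];
    for (int i = 0; i < 8; i++) r[i] = _mm256_loadu_ps(m[i]);
    transpose8x8(r);
    for (int i = 0; i < 8; i++) _mm256_storeu_ps(t[i], r[i]);
    printf("t[3][5] = %g (expect m[5][3] = %g)\n", t[3][5], m[5][3]);
    return 0;
}
```

The kernel inlines this twice per block (once after the loads, once before the stores) because B is kept column-major in memory but row-major in registers during the solve.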
-        mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg);
-        mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg);
-        mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg);
-        mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg);
-        mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg);
-        mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg);
-        mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg);
-        mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg);
-
-        //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
-        mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], mat_a_diag_inv[0]);
-
-        mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l));
-        mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0]));
-        mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1]));
-        mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2]));
-        mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3]));
-        mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4]));
-        mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5]));
-
-        //(Row1): FMA operations of b1 with elements of indices from (1, 0) up to (7, 0)
-        mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b)
-        mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b)
-        mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b)
-        mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b)
-        mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b)
-        mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b)
-        mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b)
-
-        //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
-        mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], mat_a_diag_inv[1]);
-
-        mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[0]));
-        mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[1]));
-        mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[2]));
-        mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[3]));
-        mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[4]));
-        mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[5]));
-
-        //(Row2): FMA operations of b2 with elements of indices from (2, 0) up to (7, 0)
-        mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b)
-        mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b)
-        mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b)
-        mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b)
-        mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b)
-        mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b)
-
-        //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
-        mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], mat_a_diag_inv[2]);
-
-        mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[1]));
-        mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[2]));
-        mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[3]));
-        mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[4]));
-        mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[5]));
-
-        //(Row3): FMA operations of b3 with elements of indices from (3, 0) up to (7, 0)
-        mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b)
-        mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b)
-        mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b)
-        mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b)
-        mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b)
-
-        //Perform mul operation of reciprocal of L(3, 3) element with 4th row elements of B
-        mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], mat_a_diag_inv[3]);
-
-        mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[2]));
-        mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[3]));
-        mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[4]));
-        mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[5]));
-
-        //(Row4): FMA operations of b4 with elements of indices from (4, 0) up to (7, 0)
-        mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b)
-        mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b)
-        mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b)
-        mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b)
-
-        //Perform mul operation of reciprocal of L(4, 4) element with 5th row elements of B
-        mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], mat_a_diag_inv[4]);
-
-        mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[3]));
-        mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[4]));
-        mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[5]));
-
-        //(Row5): FMA operations of b5 with elements of indices from (5, 0) up to (7, 0)
-        mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b)
-        mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b)
-        mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b)
-
-        //Perform mul operation of reciprocal of L(5, 5) element with 6th row elements of B
-        mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], mat_a_diag_inv[5]);
-
-        mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[4]));
-        mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[5]));
-
-        //(Row6): FMA operations of b6 with elements of indices from (6, 0) up to (7, 0)
-        mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b)
-        mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b)
-
-        //Perform mul operation of reciprocal of L(6, 6) element with 7th row elements of B
-        mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], mat_a_diag_inv[6]);
-
-        mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 6 + cs_l_offset[5]));
-
-        //(Row7): FMA operations of b7 with elements of index (7, 0)
-        mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b)
-
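Editor's note: rows 0 through 7 above are one fully unrolled forward substitution. Row k is finished with a single multiply by the pre-broadcast reciprocal of L(k,k), then vfnmadd (d = c - a*b) eliminates it from every remaining row, so each instruction advances eight right-hand sides at once. A compact loop form of the same pattern, assuming a column-major 8x8 lower-triangular L; the helper name trsv8_rows is mine, and the build needs -mavx2 -mfma:

```c
#include <immintrin.h>
#include <stdio.h>

/* Solve L * X = B in place for 8 RHS; b[k] holds row k of B across
   the eight right-hand sides, L is column-major with lda = 8. */
static void trsv8_rows(const float *L, __m256 b[8])
{
    for (int k = 0; k < 8; k++) {
        /* The kernel computes all eight reciprocals up front with one
           vdivps; set1 with a scalar divide is the simple equivalent. */
        __m256 inv = _mm256_set1_ps(1.0f / L[k * 8 + k]);
        b[k] = _mm256_mul_ps(b[k], inv);
        for (int i = k + 1; i < 8; i++) {
            __m256 lik = _mm256_broadcast_ss(&L[k * 8 + i]); /* L(i,k) */
            b[i] = _mm256_fnmadd_ps(lik, b[k], b[i]); /* b_i -= L(i,k)*b_k */
        }
    }
}

int main(void)
{
    float L[64] = {0}, B[64];
    for (int c = 0; c < 8; c++)
        for (int r = c; r < 8; r++)
            L[c * 8 + r] = (r == c) ? 2.0f : 1.0f;
    /* B = L * ones, so the solution X is all ones. */
    for (int r = 0; r < 8; r++)
        for (int j = 0; j < 8; j++)
            B[r * 8 + j] = 2.0f + r;
    __m256 b[8];
    for (int r = 0; r < 8; r++) b[r] = _mm256_loadu_ps(&B[r * 8]);
    trsv8_rows(L, b);
    float x[8];
    _mm256_storeu_ps(x, b[7]);
    printf("x[7][0] = %f (expect 1.0)\n", x[0]);
    return 0;
}
```

The unrolled version in the patch also hoists the L(i,k) broadcasts for each k into mat_a_blk_elems[] ahead of the FMA group, which hides the load latency behind the previous row's arithmetic.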
-        //Perform mul operation of reciprocal of L(7, 7) element with 8th row elements of B
-        mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], mat_a_diag_inv[7]);
-
-        ////////////////////////////////////////////////////////////////////////////////
-
-        /* transpose steps start */
-        ////unpacklow////
-        mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
-        mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
-        mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
-        mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
-
-        //Rearrange low elements
-#if REARRANGE_SHFL == 1
-        mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
-        mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
-        mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
-        mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
-#else
-        mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
-        mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
-        mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
-        mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
-        mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
-        mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
-#endif
-        //Merge rearranged low elements into complete rows
-        mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
-        mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
-        mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
-        mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
-
-        ////unpackhigh////
-        mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
-        mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
-        mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
-        mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
-
-        //Rearrange high elements
-#if REARRANGE_SHFL == 1
-        mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
-        mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
-        mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
-        mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
-#else
-        mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
-        mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
-        mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
-        mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
-        mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
-        mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
-#endif
-
-        //Merge rearranged high elements into complete rows
-        mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
-        mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
-        mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
-        mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
-        /* transpose steps end */
-
-        //Store the computed B columns
-        _mm256_storeu_ps((float *)ptr_b_dup, mat_b_rearr[0]);
-        _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_rearr[1]);
-        _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_rearr[2]);
-        _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_rearr[3]);
-        _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]),
mat_b_rearr[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_rearr[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_rearr[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_rearr[7]); - - i += cs_b_offset[6]; - ptr_b_dup += cs_b_offset[6]; - //i += 8; - //ptr_b_dup += 8; - } - - //c = 0; - /***************** first set of 8 cols of B processing done *****************/ - ptr_b_dup = ptr_b; - i3 = 0; - i1 = 0; - //Start loop for cols of B to be processed in size of blk_width - for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row - { - ptr_l += cs_l_offset[6]; - - //Read next 8x8 block of A to get diag elements - i3 += 8; - mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l + i3); - mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l); - mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]); - mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]); - mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]); - mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]); - mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]); - mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]); - - //pack 8 diags of A together - reciprocal_diags[0] = reciprocal_diags[1]; - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1 - mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5 - mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 - - //reciprocal of diagnal elements of A :- 0,1,2,3,4,5,6,7 - reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); - - //ptr_b += j; - //ptr_b_dup += 8; - ptr_b_dup += 8; - i1 += 8; - i = i1; - i2 = 0; - - //extract diag a00 from a - mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - //mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]); - - //extract diag a11 from a - mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); - - //extract diag a22 from a - mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); - - //extract diag a33 from a - mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); - - //extract diag a44 from a - mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[4] = 
_mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); - - //extract diag a55 from a - mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); - - //extract diag a66 from a - mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); - - //extract diag a77 from a - mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); - - for (r = 0; r < numCols_b; r += GEMM_BLK_V1) - { -#if GEMM_ACCUM_A - //Read 8 cols of B columns of Block-to-be-solved - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - /* transpose steps start */ - ////unpacklow//// - mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); - mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); - mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); - mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - ////unpackhigh//// - mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); - mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); - mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); - mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - 
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - /* transpose steps end */ - - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg); - mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg); - mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg); - mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg); - mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg); - mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg); - mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg); - mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg); -#endif - - //i = 0; - ptr_l_dup = ptr_l; - i4 = i2; - for (l = 0; l < j; l += 8) // move across m - { - //for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) - //{ - /////////////////// Partial Lower 8x8 block trsm of B - //Read current 8 cols of B columns from specified 8x8 current-block of B - mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_b + i4); - mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); - mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); - mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); - mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); - mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); - mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); - mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); - - /* transpose steps start */ - ////unpacklow//// - mat_b_col[0] = _mm256_unpacklo_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]); - mat_b_col[1] = _mm256_unpacklo_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]); - mat_b_col[2] = _mm256_unpacklo_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]); - mat_b_col[3] = _mm256_unpacklo_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - 
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - ////unpackhigh//// - mat_a_blk_elems[0] = _mm256_unpackhi_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]); - mat_a_blk_elems[1] = _mm256_unpackhi_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]); - mat_a_blk_elems[2] = _mm256_unpackhi_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]); - mat_a_blk_elems[3] = _mm256_unpackhi_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_a_blk_elems[4] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x44); - mat_a_blk_elems[5] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xEE); - mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x44); - mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xEE); -#else - mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x4E); - mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x4E); - mat_a_blk_elems[4] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[6], 0xCC); - mat_a_blk_elems[5] = _mm256_blend_ps(mat_a_blk_elems[1], mat_a_blk_elems[6], 0x33); - mat_a_blk_elems[6] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[7], 0xCC); - mat_a_blk_elems[7] = _mm256_blend_ps(mat_a_blk_elems[3], mat_a_blk_elems[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_col[2] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x20); - mat_b_col[6] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x31); - mat_b_col[3] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x20); - mat_b_col[7] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x31); - /* transpose steps end */ - - //Broadcast A8,0 to A15,0 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - //i4 = k >> 3; - ptr_l_dup++; - -#if GEMM_ACCUM_A - //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], 
mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]); - mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]); - mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]); - mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]); - mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]); - mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]); - mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]); - mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]); -#endif - //Broadcast A21 to A71 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,2 to A15,2 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float 
const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,3 to A15,3 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], 
mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,4 to A15,4 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = 
_mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,5 to A15,5 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,6 to A15,6 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = 
_mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,7 to A15,7 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], 
mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) -#endif - //end loop of cols - //} - //i2 += cs_b_offset[6]; - i4 += 8; - } - //trsm solve - - k = 0; - //for (i2 = 0; i2 < numCols_b; i2 += 8) - //{ - //i2 = i1 + r; - /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A -#if !GEMM_ACCUM_A - //Read 8 cols of B columns of Block-to-be-solved - mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - /* transpose steps start */ - ////unpacklow//// - mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - 
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - /* transpose steps end */ - - mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg); - mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg); - mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg); - mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg); - mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg); - mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg); - mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg); - mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg); -#endif - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - //i += cs_l; - -#if GEMM_ACCUM_A - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); -#else - mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]); - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); -#endif - -#if GEMM_ACCUM_A - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]); - mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]); - mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]); - mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]); - mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]); - mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]); - mat_b_rearr[7] = 
_mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A21 to A71 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[0])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[1])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[2])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[3])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[4])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[5])); - //i += cs_l; - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[1])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[2])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[3])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[4])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[5])); - //i += cs_l; - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = 
_mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[2])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[3])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[4])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[5])); - //i += cs_l; - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[3])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[4])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[5])); - //i += cs_l; - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[4])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[5])); - //i += cs_l; - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A76 to register - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 6 + cs_l_offset[5])); - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); - - //////////////////////////////////////////////////////////////////////////////// - - /* transpose steps start */ - ////unpacklow//// - mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], 
mat_b_rearr[1]); - mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - /* transpose steps end */ - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup + i2, mat_b_col[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+i2), mat_b_col[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i2), mat_b_col[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i2), mat_b_col[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i2), mat_b_col[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i2), mat_b_col[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i2), mat_b_col[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i2), mat_b_col[7]); - //printf("writing B => m[%d], n[%d], [%f]\n", j, k, 
*(ptr_b_dup + k));
- k++;
- //}
- i += cs_b_offset[6];
- i2 += cs_b_offset[6];
- }
- } //numRows of A
- ///////////////////loop ends /////////////////////
-}
-
-static void trsm_AutXB_block_allSmallSizedMatrices_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b)
-{
- //float ones = 1.0;
- int i, i1, i2, i4, j, k, l, r;
- int cs_b_offset[7];
- int cs_l_offset[7];
- float *ptr_b_dup, *ptr_l_dup;
-
- //57 number of ymm(256 bits) registers used
- __m256 mat_b_col[8];
- __m256 mat_b_rearr[8];
- __m256 mat_a_blk_elems[8];
- //__m256 mat_a_diag_inv[8];
- //__m256 reciprocal_diags[2];
-
- // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
-
- //L matrix offsets
- cs_l_offset[0] = (cs_l << 1);
- cs_l_offset[1] = cs_l + cs_l_offset[0];
- cs_l_offset[2] = (cs_l << 2);
- cs_l_offset[3] = cs_l + cs_l_offset[2];
- cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2];
- cs_l_offset[5] = cs_l + cs_l_offset[4];
- cs_l_offset[6] = (cs_l_offset[5] + cs_l);
-
- cs_b_offset[0] = (cs_b << 1);
- cs_b_offset[1] = cs_b + cs_b_offset[0];
- cs_b_offset[2] = (cs_b << 2);
- cs_b_offset[3] = cs_b + cs_b_offset[2];
- cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
- cs_b_offset[5] = cs_b + cs_b_offset[4];
- cs_b_offset[6] = (cs_b_offset[5] + cs_b);
-
-#if 0
- //Broadcast A10 to A70 to registers
- mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
- mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
- mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
- mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
- mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
- mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
- mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
-
- //Broadcast A21 to A71 to registers
- mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
- mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
- mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
- mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
- mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
- mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
-
- //Broadcast A32 to A72 to registers
- mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
- mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
- mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
- mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
- mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
-
- //Broadcast A43 to A73 to registers
- mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
- mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
- mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
- mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
-
- //Broadcast A54 to A74 to registers
- mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
- mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
- mat_a_blk_elems[24] = _mm256_broadcast_ss((float
const *)(ptr_l + cs_l_offset[2] + 7)); - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); - mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); - - //Broadcast A76 to register - mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); -#endif - - - /***************** first set of 8 rows of B processing starts *****************/ - ptr_b_dup = ptr_b; - i = 0; - for (j = 0; j < numCols_b; j += 8) - { - /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A - //read 8x8 block of B into registers - mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - /* transpose steps start */ - ////unpacklow//// - mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = 
_mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - /* transpose steps end */ - - - //(Row0) - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5])); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b) - - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[0])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[1])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[2])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[3])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[4])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[5])); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b) - - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[1])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[2])); - mat_a_blk_elems[2] 
= _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[3])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[4])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[5])); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b) - - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[2])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[3])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[4])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[5])); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b) - - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[3])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[4])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[5])); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b) - - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[4])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[5])); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b) - - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 6 + cs_l_offset[5])); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b) - - - - //////////////////////////////////////////////////////////////////////////////// - - /* transpose steps start */ - ////unpacklow//// - mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); - mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); - mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); - mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], 
mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - ////unpackhigh//// - mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); - mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); - mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); - mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - /* transpose steps end */ - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_b_rearr[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_rearr[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_rearr[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_rearr[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_rearr[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_rearr[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_rearr[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_rearr[7]); - - i += cs_b_offset[6]; - ptr_b_dup += cs_b_offset[6]; - //i += 8; - //ptr_b_dup += 8; - } - - //c = 0; - /***************** first set of 8 cols of B processing done *****************/ - ptr_b_dup = ptr_b; - i1 = 0; - //Start loop for cols of B to be processed in size of blk_width - for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row - { - ptr_l += cs_l_offset[6]; - - - //ptr_b += j; - //ptr_b_dup += 8; - 
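/* Each iteration of this j-loop advances to the next 8-row block of the
 * solution: the inner l-loop below is a small GEMM that subtracts the
 * contribution of the already-solved rows, after which the diagonal 8x8
 * block is solved directly (unit diagonal, so no reciprocal multiplies).
 * A scalar sketch of the per-column recurrence it implements; the names
 * n, a, b and cs_a are illustrative, not identifiers from this diff. */
    /* Solve (A**T) * x = b in place for one column b of B, where A is
     * upper triangular with an implicit unit diagonal, stored column-major
     * with column stride cs_a, so (A**T)(r,c) == a[c + r * cs_a]. */
    for (int r = 0; r < n; ++r)
    {
        float x = b[r];
        for (int c = 0; c < r; ++c)
            x -= a[c + r * cs_a] * b[c];  /* subtract solved entries */
        b[r] = x;                         /* unit diagonal: no divide */
    }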
ptr_b_dup += 8; - i1 += 8; - i = i1; - i2 = 0; - - for (r = 0; r < numCols_b; r += GEMM_BLK_V1) - { -#if GEMM_ACCUM_A - //Read 8 cols of B columns of Block-to-be-solved - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - /* transpose steps start */ - ////unpacklow//// - mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); - mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); - mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); - mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - ////unpackhigh//// - mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); - mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); - mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); - mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - 
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - /* transpose steps end */ -#endif - - //i = 0; - ptr_l_dup = ptr_l; - i4 = i2; - for (l = 0; l < j; l += 8) // move across m - { - //for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) - //{ - /////////////////// Partial Lower 8x8 block trsm of B - //Read current 8 cols of B columns from specified 8x8 current-block of B - mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_b + i4); - mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); - mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); - mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); - mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); - mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); - mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); - mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); - - /* transpose steps start */ - ////unpacklow//// - mat_b_col[0] = _mm256_unpacklo_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]); - mat_b_col[1] = _mm256_unpacklo_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]); - mat_b_col[2] = _mm256_unpacklo_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]); - mat_b_col[3] = _mm256_unpacklo_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - ////unpackhigh//// - mat_a_blk_elems[0] = _mm256_unpackhi_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]); - mat_a_blk_elems[1] = _mm256_unpackhi_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]); - mat_a_blk_elems[2] = _mm256_unpackhi_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]); - mat_a_blk_elems[3] = _mm256_unpackhi_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_a_blk_elems[4] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x44); - mat_a_blk_elems[5] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xEE); - mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x44); - mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xEE); -#else - mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x4E); - mat_a_blk_elems[7] = 
_mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x4E); - mat_a_blk_elems[4] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[6], 0xCC); - mat_a_blk_elems[5] = _mm256_blend_ps(mat_a_blk_elems[1], mat_a_blk_elems[6], 0x33); - mat_a_blk_elems[6] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[7], 0xCC); - mat_a_blk_elems[7] = _mm256_blend_ps(mat_a_blk_elems[3], mat_a_blk_elems[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_col[2] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x20); - mat_b_col[6] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x31); - mat_b_col[3] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x20); - mat_b_col[7] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x31); - /* transpose steps end */ - - //Broadcast A8,0 to A15,0 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - //i4 = k >> 3; - ptr_l_dup++; - -#if GEMM_ACCUM_A - //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]); - mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]); - mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]); - mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]); - mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]); - mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]); - mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]); - mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]); -#endif - //Broadcast A21 to A71 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - 
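/* The #if GEMM_ACCUM_A / #else pairs in this inner loop differ only in how
 * the partial products are accumulated: with GEMM_ACCUM_A the update is
 * folded straight into the loaded panel via _mm256_fnmadd_ps, which computes
 * d = c - (a * b); the #else path builds the product sum with _mm256_mul_ps
 * and _mm256_fmadd_ps (d = c + (a * b)) and subtracts it from the
 * alpha-scaled right-hand side afterwards. Note that the "//d = c - (a*b)"
 * comments were copy-pasted onto the fmadd lines as well, where the actual
 * operation is an add. A minimal demonstration of the two FMA3 intrinsics
 * (both declared in <immintrin.h>): */
    __m256 va = _mm256_set1_ps(2.0f);
    __m256 vb = _mm256_set1_ps(3.0f);
    __m256 vc = _mm256_set1_ps(10.0f);
    __m256 vneg = _mm256_fnmadd_ps(va, vb, vc); /* every lane: 10 - 2*3 = 4 */
    __m256 vpos = _mm256_fmadd_ps(va, vb, vc);  /* every lane: 10 + 2*3 = 16 */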
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,2 to A15,2 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], 
mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,3 to A15,3 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,4 to A15,4 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - 
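/* Each unrolled broadcast group here loads eight L entries spaced cs_l
 * apart starting at ptr_l_dup (the code's comments read them as A8,k to
 * A15,k); cs_l_offset[k] was precomputed as (k + 2) * cs_l, so ptr_l_dup,
 * ptr_l_dup + cs_l and ptr_l_dup + cs_l_offset[0..5] cover all eight
 * strides. The group is equivalent to this loop (illustrative form only): */
    for (int e = 0; e < 8; ++e)
        mat_a_blk_elems[e] =
            _mm256_broadcast_ss(ptr_l_dup + (size_t)e * cs_l);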
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,5 to A15,5 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = 
_mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,6 to A15,6 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], 
mat_b_rearr[5]);//d = c - (a*b)
- mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
-#endif
- //Broadcast A8,7 to A15,7 to registers
- mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
- mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
- mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
- mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
- mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
- mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
- mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
- mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
- ptr_l_dup++;
-#if GEMM_ACCUM_A
- //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
- mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
- mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
- mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
- mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
- mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
- mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
- mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
-#else
- mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
- mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
- mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
- mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
- mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
- mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
- mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
- mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
-#endif
- //end loop of cols
- //}
- //i2 += cs_b_offset[6];
- i4 += 8;
- }
- //trsm solve
-
- k = 0;
- //for (i2 = 0; i2 < numCols_b; i2 += 8)
- //{
- //i2 = i1 + r;
- /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A
-#if !GEMM_ACCUM_A
- //Read 8 cols of B columns of Block-to-be-solved
- mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i);
- mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
- mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
- mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
- mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
- mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
- mat_b_rearr[6] =
_mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - /* transpose steps start */ - ////unpacklow//// - mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - /* transpose steps end */ -#endif - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[4] = 
_mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - //i += cs_l; - -#if GEMM_ACCUM_A - //(Row0): already done - -#else - mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]); -#endif - -#if GEMM_ACCUM_A - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]); - mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]); - mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]); - mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]); - mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]); - mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]); - mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A21 to A71 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[0])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[1])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[2])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[3])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[4])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[5])); - //i += cs_l; - - - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = 
_mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[1])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[2])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[3])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[4])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[5])); - //i += cs_l; - - - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[2])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[3])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[4])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[5])); - //i += cs_l; - - - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[3])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[4])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[5])); - //i += cs_l; - - - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[4])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[5])); - //i += cs_l; - - - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A76 to 
register - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 6 + cs_l_offset[5])); - - - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) - - - - //////////////////////////////////////////////////////////////////////////////// - - /* transpose steps start */ - ////unpacklow//// - mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - /* transpose steps end */ - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup + i2, mat_b_col[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+i2), mat_b_col[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + 
i2), mat_b_col[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i2), mat_b_col[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i2), mat_b_col[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i2), mat_b_col[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i2), mat_b_col[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i2), mat_b_col[7]); - //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); - k++; - //} - i += cs_b_offset[6]; - i2 += cs_b_offset[6]; - } - } //numRows of A - ///////////////////loop ends ///////////////////// -} - -static void trsm_AutXB_block_allSmallSizedMatrices_alpha_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha) -{ - //float ones = 1.0; - int i, i1, i2, i4, j, k, l, r; - int cs_b_offset[7]; - int cs_l_offset[7]; - float *ptr_b_dup, *ptr_l_dup; - - //57 number of ymm(256 bits) registers used - __m256 mat_b_col[8]; - __m256 mat_b_rearr[8]; - __m256 mat_a_blk_elems[8]; - //__m256 mat_a_diag_inv[8]; - //__m256 reciprocal_diags[2]; - __m256 alphaReg; - alphaReg = _mm256_broadcast_ss((float const *)&alpha); - - // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // - - //L matrix offsets - cs_l_offset[0] = (cs_l << 1); - cs_l_offset[1] = cs_l + cs_l_offset[0]; - cs_l_offset[2] = (cs_l << 2); - cs_l_offset[3] = cs_l + cs_l_offset[2]; - cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; - cs_l_offset[5] = cs_l + cs_l_offset[4]; - cs_l_offset[6] = (cs_l_offset[5] + cs_l); - - cs_b_offset[0] = (cs_b << 1); - cs_b_offset[1] = cs_b + cs_b_offset[0]; - cs_b_offset[2] = (cs_b << 2); - cs_b_offset[3] = cs_b + cs_b_offset[2]; - cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; - cs_b_offset[5] = cs_b + cs_b_offset[4]; - cs_b_offset[6] = (cs_b_offset[5] + cs_b); - -#if 0 - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - - //Broadcast A21 to A71 to registers - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); - mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); - mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); - mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); - mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); - mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); - mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); - mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); - mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); - mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + 
cs_l_offset[1] + 4)); - mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); - mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); - mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); - mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); - mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); - mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); - - //Broadcast A76 to register - mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); -#endif - - - /***************** first set of 8 rows of B processing starts *****************/ - ptr_b_dup = ptr_b; - i = 0; - for (j = 0; j < numCols_b; j += 8) - { - /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A - //read 8x8 block of B into registers - mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - /* transpose steps start */ - ////unpacklow//// - mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - 
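/*
 * Aside (illustrative sketch, not part of the patch): the unpacklo/unpackhi,
 * shuffle/blend, and permute2f128 steps around this point are the standard
 * AVX 8x8 float transpose. Eight column-major loads of B become eight row
 * vectors that the solver can then update with broadcast FMAs. A minimal
 * self-contained version of the same idiom is sketched below; the helper
 * name transpose_8x8_ps is hypothetical, and only <immintrin.h> is assumed.
 */
static inline void transpose_8x8_ps(__m256 r[8])
{
    /* Interleave adjacent rows: t0 = [r00 r10 r01 r11 | r04 r14 r05 r15], etc. */
    __m256 t0 = _mm256_unpacklo_ps(r[0], r[1]);
    __m256 t1 = _mm256_unpackhi_ps(r[0], r[1]);
    __m256 t2 = _mm256_unpacklo_ps(r[2], r[3]);
    __m256 t3 = _mm256_unpackhi_ps(r[2], r[3]);
    __m256 t4 = _mm256_unpacklo_ps(r[4], r[5]);
    __m256 t5 = _mm256_unpackhi_ps(r[4], r[5]);
    __m256 t6 = _mm256_unpacklo_ps(r[6], r[7]);
    __m256 t7 = _mm256_unpackhi_ps(r[6], r[7]);
    /* Gather 4-row groups per lane: 0x44 = _MM_SHUFFLE(1,0,1,0), 0xEE = (3,2,3,2). */
    __m256 u0 = _mm256_shuffle_ps(t0, t2, 0x44);
    __m256 u1 = _mm256_shuffle_ps(t0, t2, 0xEE);
    __m256 u2 = _mm256_shuffle_ps(t1, t3, 0x44);
    __m256 u3 = _mm256_shuffle_ps(t1, t3, 0xEE);
    __m256 u4 = _mm256_shuffle_ps(t4, t6, 0x44);
    __m256 u5 = _mm256_shuffle_ps(t4, t6, 0xEE);
    __m256 u6 = _mm256_shuffle_ps(t5, t7, 0x44);
    __m256 u7 = _mm256_shuffle_ps(t5, t7, 0xEE);
    /* Merge 128-bit lanes into full transposed rows: 0x20 = low lanes, 0x31 = high lanes. */
    r[0] = _mm256_permute2f128_ps(u0, u4, 0x20);
    r[1] = _mm256_permute2f128_ps(u1, u5, 0x20);
    r[2] = _mm256_permute2f128_ps(u2, u6, 0x20);
    r[3] = _mm256_permute2f128_ps(u3, u7, 0x20);
    r[4] = _mm256_permute2f128_ps(u0, u4, 0x31);
    r[5] = _mm256_permute2f128_ps(u1, u5, 0x31);
    r[6] = _mm256_permute2f128_ps(u2, u6, 0x31);
    r[7] = _mm256_permute2f128_ps(u3, u7, 0x31);
}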
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - /* transpose steps end */ - - mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg); - mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg); - mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg); - mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg); - mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg); - mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg); - mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg); - mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg); - - //(Row0) - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5])); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b) - - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[0])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[1])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[2])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[3])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + 
cs_l_offset[4])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[5])); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b) - - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[1])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[2])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[3])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[4])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[5])); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b) - - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[2])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[3])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[4])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[5])); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b) - - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[3])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[4])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[5])); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b) - - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[4])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[5])); - - //(Row6): FMA operations of b6 with elements of indices from (6, 
0) uptill (7, 0) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b) - - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 6 + cs_l_offset[5])); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b) - - - - //////////////////////////////////////////////////////////////////////////////// - - /* transpose steps start */ - ////unpacklow//// - mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); - mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); - mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); - mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - ////unpackhigh//// - mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); - mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); - mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); - mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - /* transpose steps end */ - - //Store the computed B columns - 
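/*
 * Aside (illustrative sketch, not part of the patch): after the solved rows
 * are transposed back, each mat_b_rearr[j] again holds one column of the 8x8
 * block, so the eight stores below simply walk B by its column stride, with
 * the offsets (cs_b, cs_b_offset[0..5]) precomputed once per function. The
 * same store pattern in loop form, with a hypothetical helper name and
 * assuming cs_b is the column stride in floats:
 */
static inline void store_b_block_cols(float *b, long cs_b, const __m256 col[8])
{
    for (int j = 0; j < 8; ++j)
        _mm256_storeu_ps(b + (long)j * cs_b, col[j]); /* column j of the 8x8 block */
}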
_mm256_storeu_ps((float *)ptr_b_dup, mat_b_rearr[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_rearr[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_rearr[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_rearr[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_rearr[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_rearr[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_rearr[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_rearr[7]); - - i += cs_b_offset[6]; - ptr_b_dup += cs_b_offset[6]; - //i += 8; - //ptr_b_dup += 8; - } - - //c = 0; - /***************** first set of 8 cols of B processing done *****************/ - ptr_b_dup = ptr_b; - i1 = 0; - //Start loop for cols of B to be processed in size of blk_width - for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row - { - ptr_l += cs_l_offset[6]; - - - //ptr_b += j; - //ptr_b_dup += 8; - ptr_b_dup += 8; - i1 += 8; - i = i1; - i2 = 0; - - for (r = 0; r < numCols_b; r += GEMM_BLK_V1) - { -#if GEMM_ACCUM_A - //Read 8 cols of B columns of Block-to-be-solved - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - /* transpose steps start */ - ////unpacklow//// - mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); - mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); - mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); - mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - ////unpackhigh//// - mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); - mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); - mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); - mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], 
mat_b_col[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - /* transpose steps end */ - - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg); - mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg); - mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg); - mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg); - mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg); - mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg); - mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg); - mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg); -#endif - - //i = 0; - ptr_l_dup = ptr_l; - i4 = i2; - for (l = 0; l < j; l += 8) // move across m - { - //for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) - //{ - /////////////////// Partial Lower 8x8 block trsm of B - //Read current 8 cols of B columns from specified 8x8 current-block of B - mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_b + i4); - mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); - mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); - mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); - mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); - mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); - mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); - mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); - - /* transpose steps start */ - ////unpacklow//// - mat_b_col[0] = _mm256_unpacklo_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]); - mat_b_col[1] = _mm256_unpacklo_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]); - mat_b_col[2] = _mm256_unpacklo_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]); - mat_b_col[3] = _mm256_unpacklo_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - 
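/*
 * Aside (illustrative sketch, not part of the patch): each "Broadcast A.. /
 * FMA" step in this GEMM-accumulation loop is a rank-1 update of the
 * not-yet-solved rows. For a solved row b_k of B, every pending row i
 * accumulates b_i -= A(i,k) * b_k. With GEMM_ACCUM_A the subtraction is
 * fused via _mm256_fnmadd_ps; in the other branch the products are summed
 * with _mm256_fmadd_ps and subtracted from the freshly loaded block in one
 * _mm256_sub_ps pass afterwards. A minimal sketch of the fnmadd form, with
 * a hypothetical helper name, where a_elems points at the first A element
 * of the step and stride is the distance in floats between consecutive
 * broadcast elements (cs_l in the removed code):
 */
static inline void rank1_update_8rows(__m256 b_pending[8], __m256 b_solved,
                                      const float *a_elems, long stride)
{
    for (int i = 0; i < 8; ++i)
    {
        __m256 aik = _mm256_broadcast_ss(a_elems + (long)i * stride);
        b_pending[i] = _mm256_fnmadd_ps(aik, b_solved, b_pending[i]); /* b_i -= aik * b_k */
    }
}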
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - ////unpackhigh//// - mat_a_blk_elems[0] = _mm256_unpackhi_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]); - mat_a_blk_elems[1] = _mm256_unpackhi_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]); - mat_a_blk_elems[2] = _mm256_unpackhi_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]); - mat_a_blk_elems[3] = _mm256_unpackhi_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_a_blk_elems[4] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x44); - mat_a_blk_elems[5] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xEE); - mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x44); - mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xEE); -#else - mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x4E); - mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x4E); - mat_a_blk_elems[4] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[6], 0xCC); - mat_a_blk_elems[5] = _mm256_blend_ps(mat_a_blk_elems[1], mat_a_blk_elems[6], 0x33); - mat_a_blk_elems[6] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[7], 0xCC); - mat_a_blk_elems[7] = _mm256_blend_ps(mat_a_blk_elems[3], mat_a_blk_elems[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_col[2] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x20); - mat_b_col[6] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x31); - mat_b_col[3] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x20); - mat_b_col[7] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x31); - /* transpose steps end */ - - //Broadcast A8,0 to A15,0 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - //i4 = k >> 3; - ptr_l_dup++; - -#if GEMM_ACCUM_A - //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = 
c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]); - mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]); - mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]); - mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]); - mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]); - mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]); - mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]); - mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]); -#endif - //Broadcast A21 to A71 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,2 to A15,2 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const 
*)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,3 to A15,3 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = 
_mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,4 to A15,4 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], 
mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,5 to A15,5 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,6 to A15,6 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row14): FMA operations of b2 with elements of indices from (2, 0) 
uptill (7, 0)
-            mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
-            mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
-            mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
-            mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
-            mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
-            mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
-            mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
-            mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
-#else
-            mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
-            mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
-            mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
-            mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
-            mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
-            mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
-            mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
-            mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
-#endif
-            //Broadcast A8,7 to A15,7 to registers
-            mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
-            mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
-            mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
-            mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
-            mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
-            mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
-            mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
-            mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
-            ptr_l_dup++;
-#if GEMM_ACCUM_A
-            //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
-            mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
-            mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
-            mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
-            mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
-            mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
-            mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
-            mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
-            mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
-#else
-            mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
-            mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
-            mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
-            mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
-            mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
-            mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
-            mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
-            mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
-#endif
-            //end loop of cols
-            //}
-            //i2 += cs_b_offset[6];
-            i4 += 8;
-        }
-        //trsm solve
-
-        k = 0;
-        //for (i2 = 0; i2 < numCols_b; i2 += 8)
-        //{
-            //i2 = i1 + r;
-            /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A
-#if !GEMM_ACCUM_A
-            //Read 8 cols of B columns of Block-to-be-solved
-            mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i);
-            mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
-            mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
-            mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
-            mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
-            mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
-            mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
-            mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
-
-            /* transpose steps start */
-            ////unpacklow////
-            mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
-            mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
-            mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
-            mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
-
-            //Rearrange low elements
-#if REARRANGE_SHFL == 1
-            mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
-            mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
-            mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
-            mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
-#else
-            mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
-            mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
-            mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
-            mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
-            mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
-            mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
-#endif
-            //Merge rearranged low elements into complete rows
-            mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
-            mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
-            mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
-            mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
-
-            ////unpackhigh////
-            mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
-            mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
-            mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
-            mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
-
-            //Rearrange high elements
-#if REARRANGE_SHFL == 1
-            mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
-            mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
-            mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
-            mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
-#else
-            mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
-            mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
-            mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
-            mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
-            mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
-            mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
-#endif
-
-            //Merge rearranged high elements into complete rows
-            mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
-            mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
-            mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
-            mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
-            /* transpose steps end */
-
-            mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg);
-            mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg);
-            mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg);
-            mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg);
-            mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg);
-            mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg);
-            mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg);
-            mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg);
-#endif
-            //Broadcast A10 to A70 to registers
-            mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
-            mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
-            mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
-            mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
-            mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
-            mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
-            mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
-            //i += cs_l;
-
-#if GEMM_ACCUM_A
-            //(Row0): already done
-
-#else
-            mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]);
-#endif
-
-#if GEMM_ACCUM_A
-            mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
-            mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
-            mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
-            mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
-            mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
-            mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
-            mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
-#else
-            mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]);
-            mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]);
-            mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]);
-            mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]);
-            mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]);
-            mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]);
-            mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]);
-
-            //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
-            mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
-            mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
-            mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
-            mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
-            mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
-            mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
-            mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
-#endif
-            //Broadcast A21 to A71 to registers
-            mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[0]));
-            mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[1]));
-            mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[2]));
-            mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[3]));
-            mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[4]));
-            mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[5]));
-            //i += cs_l;
-
-
-
-            //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
-            mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
-            mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
-            mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
-            mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
-            mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
-            mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
-
-            //Broadcast A32 to A72 to registers
-            mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[1]));
-            mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[2]));
-            mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[3]));
-            mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[4]));
-            mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[5]));
-            //i += cs_l;
-
-
-
-            //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
-            mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
-            mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
-            mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
-            mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
-            mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
-
-            //Broadcast A43 to A73 to registers
-            mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[2]));
-            mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[3]));
-            mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[4]));
-            mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[5]));
-            //i += cs_l;
-
-
-
-            //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
-            mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
-            mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
-            mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
-            mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
-
-            //Broadcast A54 to A74 to registers
-            mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[3]));
-            mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[4]));
-            mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[5]));
-            //i += cs_l;
-
-
-
-            //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
-            mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
-            mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
-            mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
-
-            //Broadcast A65 to A75 to registers
-            mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[4]));
-            mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[5]));
-            //i += cs_l;
-
-
-
-            //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
-            mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
-            mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
-
-            //Broadcast A76 to register
-            mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 6 + cs_l_offset[5]));
-
-
-
-            //(Row7): FMA operations of b7 with elements of index (7, 0)
-            mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
-
-
-
-            ////////////////////////////////////////////////////////////////////////////////
-
-            /* transpose steps start */
-            ////unpacklow////
-            mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
-            mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
-            mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
-            mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
-
-            //Rearrange low elements
-#if REARRANGE_SHFL == 1
-            mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
-            mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
-            mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
-            mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
-#else
-            mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
-            mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
-            mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
-            mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
-            mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
-            mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
-#endif
-            //Merge rearranged low elements into complete rows
-            mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
-            mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
-            mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
-            mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
-
-            ////unpackhigh////
-            mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
-            mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
-            mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
-            mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
-
-            //Rearrange high elements
-#if REARRANGE_SHFL == 1
-            mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
-            mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
-            mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
-            mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
-#else
-            mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
-            mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
-            mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
-            mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
-            mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
-            mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
-#endif
-
-            //Merge rearranged high elements into complete rows
-            mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
-            mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
-            mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
-            mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
-            /* transpose steps end */
-
-            //Store the computed B columns
-            _mm256_storeu_ps((float *)ptr_b_dup + i2, mat_b_col[0]);
-            _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+i2), mat_b_col[1]);
-            _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i2), mat_b_col[2]);
-            _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i2), mat_b_col[3]);
-            _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i2), mat_b_col[4]);
-            _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i2), mat_b_col[5]);
-            _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i2), mat_b_col[6]);
-            _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i2), mat_b_col[7]);
-            //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k));
-            k++;
-        //}
-        i += cs_b_offset[6];
-        i2 += cs_b_offset[6];
-    }
-    } //numRows of A
-    ///////////////////loop ends /////////////////////
-}
-#endif
+#endif //BLIS_ENABLE_SMALL_MATRIX_TRSM
diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h
index 9f3888340..e15b81e9c 100644
--- a/kernels/zen/bli_kernels_zen.h
+++ b/kernels/zen/bli_kernels_zen.h
@@ -260,3 +260,13 @@ void bli_dgemm_ref_k1_nn
   double* c, const inc_t ldc
 );
+err_t bli_trsm_small
+     (
+       side_t   side,
+       obj_t*   alpha,
+       obj_t*   a,
+       obj_t*   b,
+       cntx_t*  cntx,
+       cntl_t*  cntl
+     );
+
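The GEMM_ACCUM_A branches in the deleted kernel are two mathematically equivalent ways to form alpha*B minus a running sum of rank-one updates. One caveat for readers: the "//d = c - (a*b)" comment was copied onto the _mm256_fmadd_ps lines as well, but fmadd computes c + a*b; in that branch the subtraction from alpha*B happens later in one _mm256_sub_ps per register. A minimal sketch of the two forms follows; the helper names are hypothetical, not BLIS APIs.

#include <immintrin.h>

/* Illustrative sketch, not BLIS code.
   GEMM_ACCUM_A != 0: acc already holds alpha*B, so each update
   subtracts a*b on the fly (vfnmadd: d = c - a*b). */
static inline __m256 gemm_step_fnmadd(__m256 acc, __m256 a, __m256 b)
{
    return _mm256_fnmadd_ps(a, b, acc);   /* acc - a*b */
}

/* GEMM_ACCUM_A == 0: products are summed with vfmadd first... */
static inline __m256 gemm_step_fmadd(__m256 acc, __m256 a, __m256 b)
{
    return _mm256_fmadd_ps(a, b, acc);    /* acc + a*b */
}

/* ...and the sum is subtracted from alpha*B in a single step. */
static inline __m256 gemm_finish(__m256 alpha_b, __m256 acc)
{
    return _mm256_sub_ps(alpha_b, acc);   /* alpha*B - sum(a*b) */
}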
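The "transpose steps" blocks in the deleted kernel implement an 8x8 in-register float transpose out of _mm256_unpacklo_ps/_mm256_unpackhi_ps, _mm256_shuffle_ps, and _mm256_permute2f128_ps. The self-contained routine below shows the same idiom (the REARRANGE_SHFL == 1 flavor) on a plain row-major 8x8 tile; the function name and layout are assumptions for illustration, not part of the patch. Compile with, e.g., gcc -O2 -mavx2.

#include <immintrin.h>
#include <stdio.h>

/* Illustrative sketch, not BLIS code: transpose a row-major 8x8 tile. */
static void transpose_8x8_ps(float out[64], const float in[64])
{
    __m256 r[8], t[8], u[8];
    for (int i = 0; i < 8; i++)
        r[i] = _mm256_loadu_ps(in + 8 * i);

    /* Stage 1: interleave adjacent rows within each 128-bit lane. */
    t[0] = _mm256_unpacklo_ps(r[0], r[1]);
    t[1] = _mm256_unpackhi_ps(r[0], r[1]);
    t[2] = _mm256_unpacklo_ps(r[2], r[3]);
    t[3] = _mm256_unpackhi_ps(r[2], r[3]);
    t[4] = _mm256_unpacklo_ps(r[4], r[5]);
    t[5] = _mm256_unpackhi_ps(r[4], r[5]);
    t[6] = _mm256_unpacklo_ps(r[6], r[7]);
    t[7] = _mm256_unpackhi_ps(r[6], r[7]);

    /* Stage 2: 0x44 picks elements 1:0 of each source, 0xEE picks 3:2. */
    u[0] = _mm256_shuffle_ps(t[0], t[2], 0x44);
    u[1] = _mm256_shuffle_ps(t[0], t[2], 0xEE);
    u[2] = _mm256_shuffle_ps(t[1], t[3], 0x44);
    u[3] = _mm256_shuffle_ps(t[1], t[3], 0xEE);
    u[4] = _mm256_shuffle_ps(t[4], t[6], 0x44);
    u[5] = _mm256_shuffle_ps(t[4], t[6], 0xEE);
    u[6] = _mm256_shuffle_ps(t[5], t[7], 0x44);
    u[7] = _mm256_shuffle_ps(t[5], t[7], 0xEE);

    /* Stage 3: 0x20 joins the low 128-bit lanes, 0x31 the high lanes. */
    r[0] = _mm256_permute2f128_ps(u[0], u[4], 0x20);
    r[1] = _mm256_permute2f128_ps(u[1], u[5], 0x20);
    r[2] = _mm256_permute2f128_ps(u[2], u[6], 0x20);
    r[3] = _mm256_permute2f128_ps(u[3], u[7], 0x20);
    r[4] = _mm256_permute2f128_ps(u[0], u[4], 0x31);
    r[5] = _mm256_permute2f128_ps(u[1], u[5], 0x31);
    r[6] = _mm256_permute2f128_ps(u[2], u[6], 0x31);
    r[7] = _mm256_permute2f128_ps(u[3], u[7], 0x31);

    for (int i = 0; i < 8; i++)
        _mm256_storeu_ps(out + 8 * i, r[i]);
}

int main(void)
{
    float in[64], out[64];
    int ok = 1;
    for (int i = 0; i < 64; i++) in[i] = (float)i;
    transpose_8x8_ps(out, in);
    for (int i = 0; i < 8; i++)
        for (int j = 0; j < 8; j++)
            ok &= (out[8 * j + i] == in[8 * i + j]);
    printf(ok ? "transpose ok\n" : "transpose mismatch\n");
    return 0;
}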
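The (Row1)..(Row7) cascade in the deleted solve section is a unit-diagonal forward substitution: once row i of the 8x8 block is final, one vfnmadd per later row j subtracts L(j,i)*row(i), and because each ymm register holds row i across eight right-hand-side columns (courtesy of the transpose above), every fnmadd advances eight solves at once. A compact sketch under stated assumptions: row-major L, a hypothetical function name, and -mavx2 -mfma at compile time; the kernel itself reads L column-wise through cs_l strides.

#include <immintrin.h>

/* Illustrative sketch, not BLIS code: solve L * X = B in place for a
   unit-lower-triangular 8x8 L, where b[i] holds row i of the
   right-hand-side block (eight RHS columns per register). */
static void trsm_ll_unit_8x8(__m256 b[8], const float L[8][8])
{
    for (int i = 0; i < 8; i++)              /* row i is now final        */
        for (int j = i + 1; j < 8; j++)      /* eliminate from later rows */
            /* b[j] = b[j] - L[j][i] * b[i]   (vfnmadd: d = c - a*b)      */
            b[j] = _mm256_fnmadd_ps(_mm256_broadcast_ss(&L[j][i]),
                                    b[i], b[j]);
}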