From 3a2e4c3db89ea59226e509543ba7f0881fc5cd2b Mon Sep 17 00:00:00 2001 From: Nallani Bhaskar Date: Fri, 7 May 2021 18:41:25 +0530 Subject: [PATCH] Added optimized single threaded dtrsm small for left cases Details: 1. Added optimized dtrsm kernels for all 8 left side cases Below are few notable optimizations which improved performance a. Loading, transposing (for transa cases), packing and reusing of a10 block required for GEMM operation. The block size increases from 0 to 8X(m-8) in steps of 8x8 while solving TRSM from one end of A to other end of triangular A b. Performing inregister transpose whenever required c. Packing of 8 diagonal elements in one location helped to utilize cache line efficiently 2. Enabled calling dtrsm small for smaller sizes at cblas level itself to avoid frame work overhead, which is significant for very small sizes 3. Thanks to SatishKumar.Nuggu@amd.com for implementing lln, llt, lun and manideep.kurumella@amd.com for implementing lut kernels using intrinsics. 4. Removed all older implementations of strsm which are not developed as per the guide lines, can be refered from older releases if required. Change-Id: I66ad6ef364cbcf5c99a3c4a4dcac12929865ade6 --- frame/3/trsm/bli_trsm_front.c | 44 - frame/compat/bla_trsm.c | 267 +- kernels/zen/3/bli_trsm_small.c | 29610 ++++++++++++++----------------- kernels/zen/bli_kernels_zen.h | 10 + 4 files changed, 13546 insertions(+), 16385 deletions(-) diff --git a/frame/3/trsm/bli_trsm_front.c b/frame/3/trsm/bli_trsm_front.c index ff6264fa7..0c3bb11d2 100644 --- a/frame/3/trsm/bli_trsm_front.c +++ b/frame/3/trsm/bli_trsm_front.c @@ -55,50 +55,6 @@ void bli_trsm_front obj_t b_local; obj_t c_local; -#ifdef PRINT_SMALL_TRSM_INFO - printf("Side:: %c\n", side ? 'R' : 'L'); - if (bli_obj_datatype(*a) == BLIS_FLOAT) - printf("Alpha:: %9.2e\n", *((float *)bli_obj_buffer_for_const(BLIS_FLOAT, *alpha))); - else if (bli_obj_datatype(*a) == BLIS_DOUBLE) - printf("Alpha is double:: %9.2e\n", *((double *)bli_obj_buffer_for_const(BLIS_DOUBLE, *alpha))); - else - printf("Unsupported datatype for Alpha\n"); - - printf("A:: M = %d, N = %d, elem_size = %d, row_off = %ld, col_off = %ld, rs = %d, cs = %d, trans = %c, TRIANG = %c, unit diag = %c\n", a->dim[0], a->dim[1], bli_obj_elem_size(*a ), bli_obj_row_off(*a), bli_obj_col_off(*a), a->rs, a->cs, bli_obj_has_trans(*a) ? 'Y' : 'N', bli_obj_is_upper(*a) ? 'U' : bli_obj_is_lower(*a) ? 'L' : 'N', bli_obj_has_unit_diag(*a) ? 'Y' : 'N'); -#ifdef PRINT_SMALL_TRSM - //bli_printm("a", a, "%4.1f", ""); -#endif - printf("B:: M = %d, N = %d, elem_size = %d, row_off = %ld, col_off = %ld, rs = %d, cs = %d, trans = %c\n", b->dim[0], b->dim[1], bli_obj_elem_size(*a ), bli_obj_row_off(*a), bli_obj_col_off(*a), b->rs, b->cs, bli_obj_has_trans(*b) ? 'Y' : 'N'); -#ifdef PRINT_SMALL_TRSM - //bli_printm("b", b, "%4.1f", ""); -#endif - fflush(stdout); -#endif -#if 0 -for (i = 0; i < m; i++) //no. of cols of B -{ - for (j = 0; j < n; j++) //no. of rows of B - { - B[i*n + j] = 1001 + j + (i*n); - } -} -for (i = 0; i < m; i++) //no. of cols of B -{ - for (j = i; j < m; j++) //no. of rows of B - { - L[i*m + j] = 2001 + j + (i*m); - } -} -#endif -#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM - gint_t status = bli_trsm_small( side, alpha, a, b, cntx, cntl ); - if ( status == BLIS_SUCCESS ) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); - return; - } -#endif - // Check parameters. if ( bli_error_checking_is_enabled() ) bli_trsm_check( side, alpha, a, b, &BLIS_ZERO, b, cntx ); diff --git a/frame/compat/bla_trsm.c b/frame/compat/bla_trsm.c index fff59a351..f3edca62a 100644 --- a/frame/compat/bla_trsm.c +++ b/frame/compat/bla_trsm.c @@ -229,6 +229,7 @@ void PASTEF77(ch,blasname) \ (ftype*)b, rs_b, \ NULL \ ); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ return; \ } \ else if(bli_is_trans(blis_transa)) \ @@ -244,6 +245,7 @@ void PASTEF77(ch,blasname) \ (ftype*)b, rs_b, \ NULL \ ); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ return; \ } \ } \ @@ -268,6 +270,7 @@ void PASTEF77(ch,blasname) \ PASTEMAC(ch,invscals)( a_conj, b[indx] ); \ } \ }\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ return; \ } \ } \ @@ -290,6 +293,7 @@ void PASTEF77(ch,blasname) \ (ftype*)a, cs_a, rs_a, \ (ftype*)b, cs_b, \ NULL); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ return; \ } \ else if(bli_is_trans(blis_transa)) \ @@ -307,6 +311,7 @@ void PASTEF77(ch,blasname) \ (ftype*)a, cs_a, rs_a, \ (ftype*)b, cs_b, \ NULL); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ return; \ } \ } \ @@ -331,6 +336,7 @@ void PASTEF77(ch,blasname) \ PASTEMAC(ch,invscals)( a_conj, b[indx*cs_b] ); \ }\ } \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ return; \ } \ } \ @@ -374,6 +380,265 @@ void PASTEF77(ch,blasname) \ #endif #ifdef BLIS_ENABLE_BLAS -INSERT_GENTFUNC_BLAS( trsm, trsm ) +#ifdef BLIS_CONFIG_EPYC +void dtrsm_ +( + const f77_char* side, + const f77_char* uploa, + const f77_char* transa, + const f77_char* diaga, + const f77_int* m, + const f77_int* n, + const double* alpha, + const double* a, const f77_int* lda, + double* b, const f77_int* ldb +) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) + AOCL_DTL_LOG_TRSM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(d), + *side, *uploa,*transa, *diaga, *m, *n, + (void*)alpha,*lda, *ldb); + + side_t blis_side; + uplo_t blis_uploa; + trans_t blis_transa; + diag_t blis_diaga; + dim_t m0, n0; + conj_t conja = BLIS_NO_CONJUGATE ; + + /* Initialize BLIS. */ + bli_init_auto(); + + /* Perform BLAS parameter checking. */ + PASTEBLACHK(trsm) + ( + MKSTR(d), + MKSTR(trsm), + side, + uploa, + transa, + diaga, + m, + n, + lda, + ldb + ); + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_side( *side, &blis_side ); + bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); + bli_param_map_netlib_to_blis_diag( *diaga, &blis_diaga ); + + /* Typecast BLAS integers to BLIS integers. */ + bli_convert_blas_dim1( *m, m0 ); + bli_convert_blas_dim1( *n, n0 ); + + /* Set the row and column strides of the matrix operands. */ + const inc_t rs_a = 1; + const inc_t cs_a = *lda; + const inc_t rs_b = 1; + const inc_t cs_b = *ldb; + const num_t dt = BLIS_DOUBLE; + + if( n0 == 1 ) + { + if( blis_side == BLIS_LEFT ) + { + if(bli_is_notrans(blis_transa)) + { + bli_dtrsv_unf_var2 + ( + blis_uploa, + blis_transa, + blis_diaga, + m0, + (double*)alpha, + (double*)a, rs_a, cs_a, + (double*)b, rs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + else if(bli_is_trans(blis_transa)) + { + bli_dtrsv_unf_var1 + ( + blis_uploa, + blis_transa, + blis_diaga, + m0, + (double*)alpha, + (double*)a, rs_a, cs_a, + (double*)b, rs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + else if( ( blis_side == BLIS_RIGHT ) && ( m0 != 1 ) ) + { + /* b = alpha * b; */ + bli_dscalv_ex + ( + conja, + m0, + (double*)alpha, + b, rs_b, + NULL, + NULL + ); + if(blis_diaga == BLIS_NONUNIT_DIAG) + { + double inva = 1.0/ *a; + for(int indx = 0; indx < m0; indx ++) + { + b[indx] = ( inva * b[indx] ); + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + else if( m0 == 1 ) + { + if(blis_side == BLIS_RIGHT) + { + if(bli_is_notrans(blis_transa)) + { + if(blis_uploa == BLIS_UPPER) + blis_uploa = BLIS_LOWER; + else + blis_uploa = BLIS_UPPER; + + bli_dtrsv_unf_var1 + ( + blis_uploa, + blis_transa, + blis_diaga, + n0, + (double*)alpha, + (double*)a, cs_a, rs_a, + (double*)b, cs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + else if(bli_is_trans(blis_transa)) + { + if(blis_uploa == BLIS_UPPER) + blis_uploa = BLIS_LOWER; + else + blis_uploa = BLIS_UPPER; + + bli_dtrsv_unf_var2 + ( + blis_uploa, + blis_transa, + blis_diaga, + n0, + (double*)alpha, + (double*)a, cs_a, rs_a, + (double*)b, cs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + else if(( blis_side == BLIS_LEFT ) && ( n0 != 1 )) + { + /* b = alpha * b; */ + bli_dscalv_ex + ( + conja, + n0, + (double*)alpha, + b, cs_b, + NULL, + NULL + ); + if(blis_diaga == BLIS_NONUNIT_DIAG) + { + double inva = 1.0/ *a; + for(int indx = 0; indx < n0; indx ++) + { + b[indx*cs_b] = (inva * b[indx*cs_b] ); + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + + const struc_t struca = BLIS_TRIANGULAR; + + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t ao = BLIS_OBJECT_INITIALIZER; + obj_t bo = BLIS_OBJECT_INITIALIZER; + + dim_t mn0_a; + + bli_set_dim_with_side( blis_side, m0, n0, &mn0_a ); + + bli_obj_init_finish_1x1( dt, (double*)alpha, &alphao ); + + bli_obj_init_finish( dt, mn0_a, mn0_a, (double*)a, rs_a, cs_a, &ao ); + bli_obj_init_finish( dt, m0, n0, (double*)b, rs_b, cs_b, &bo ); + + bli_obj_set_uplo( blis_uploa, &ao ); + bli_obj_set_diag( blis_diaga, &ao ); + bli_obj_set_conjtrans( blis_transa, &ao ); + + bli_obj_set_struc( struca, &ao ); + +#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM + /* Irrespective of num threads single thread bli_dtrsm_small + * is performing better than other implementations for [m,n]<=128 */ + /* ToDo: This condition will be tunned for single thread */ + if(m0 <=128 && n0<=128) + { + err_t status; + status = bli_trsm_small + ( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL + ); + if (status == BLIS_SUCCESS) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + } #endif + bli_trsmnat + ( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL + ); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) + /* Finalize BLIS. */ + bli_finalize_auto(); +} + +GENTFUNC( float, s, trsm, trsm ) +INSERT_GENTFUNC_BLAS_CZ( trsm, trsm ) +#else +INSERT_GENTFUNC_BLAS( trsm, trsm ) +#endif +#endif diff --git a/kernels/zen/3/bli_trsm_small.c b/kernels/zen/3/bli_trsm_small.c index c6ea0d12b..195819647 100644 --- a/kernels/zen/3/bli_trsm_small.c +++ b/kernels/zen/3/bli_trsm_small.c @@ -1,703 +1,387 @@ /* -BLIS -An object-based framework for developing high-performance BLAS-like -libraries. + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. -Copyright (C) 2018-2019, Advanced Micro Devices, Inc. + Copyright (C) 2018-2021, Advanced Micro Devices, Inc. All rights reserved. -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: -- Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. -- Redistributions in binary form must reproduce the above copyright -notice, this list of conditions and the following disclaimer in the -documentation and/or other materials provided with the distribution. -- Neither the name of The University of Texas at Austin nor the names -of its contributors may be used to endorse or promote products -derived from this software without specific prior written permission. + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ #include "blis.h" #ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM #include "immintrin.h" -#define GEMM_BLK_V1 8 //Block size to perform gemm and apply trsm -#define GEMM_ACCUM_A 1 //Peform B1=B1-(B0*A0) operation instead of B1'=(B0*A0) and then B1=B1-B1' -#define OPT_CACHE_BLOCKING_L1 1 //Perform trsm block-wise in blocks of GEMM_BLK_V1 instead of all columns of B together. -#define REARRANGE_SHFL 0 //Rearrange operations using blend or shuffle -#define BLI_AlXB_M_SP 16 -#define BLI_XAltB_N_SP 128 -#define BLI_AutXB_M_SP 64 -#define BLI_AutXB_N_SP 128 -// XA = B; A is lower-traingular; No transpose; double precision; non-unit diagonal -static err_t bli_dtrsm_small_XAlB( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ); +#define BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL -//XA = B; A is lower triabgular; No transpose; double precision; unit-diagonal -static err_t bli_dtrsm_small_XAlB_unitDiag( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ); +/* + trsm kernels function pointer +*/ +typedef err_t (*trsmsmall_ker_ft) +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +); -//XA = B; A is lower-triangular; A is transposed; double precision; non-unit-diagonal -static err_t bli_dtrsm_small_XAltB( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ); +//AX = B; A is lower triangular; No transpose; +//double precision; non-unit diagonal +static err_t bli_dtrsm_small_AlXB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +); -//XA = B; A is lower-triangular; A is transposed; double precision; unit-diagonal -static err_t bli_dtrsm_small_XAltB_unitDiag( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ); +/* TRSM for the case AX = alpha * B, Double precision + * A is upper-triangular, non-transpose, non-unit diagonal + * dimensions A: mxm X: mxn B: mxn +*/ +static err_t bli_dtrsm_small_AuXB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +); -// XA = B; A is upper triangular; No transpose; double presicion; non-unit diagonal -static err_t bli_dtrsm_small_XAuB - ( - side_t side, - obj_t* alpha, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ); +//AX = B; A is lower triangular; transpose; +//double precision; non-unit diagonal +static err_t dtrsm_AltXB_ref +( + double *A, + double *B, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb, + bool is_unitdiag +); -//XA = B; A is upper triangular; No transpose; double precision; unit-diagonal -static err_t bli_dtrsm_small_XAuB_unitDiag( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ); - -//XA = B; A is upper-triangular; A is transposed; double precision; non-unit diagonal -static err_t bli_dtrsm_small_XAutB( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ); - -//XA = B; A is upper-triangular; A is transposed; double precision; unit diagonal -static err_t bli_dtrsm_small_XAutB_unitDiag( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ); - -//AX = B; A is lower triangular; No transpose; double precision; non-unit diagonal -static err_t bli_dtrsm_small_AlXB( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ); - -//AX = B; A is lower triangular; No transpose; double precision; unit diagonal -static err_t bli_dtrsm_small_AlXB_unitDiag( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ); - - - -static void (*fp_blis_strsm_microkernel)( float *ptr_l, - float *ptr_b, - int numRows_lb, - int numCols_b, - int rs_l, - int rs_b, - int cs_l, - int cs_b - ); -static void blis_strsm_microkernel( float *ptr_l, - float *ptr_b, - int numRows_lb, - int numCols_b, - int rs_l, - int rs_b, - int cs_l, - int cs_b - ); -static void blis_strsm_microkernel_alpha( float *ptr_l, - float *ptr_b, - int numRows_lb, - int numCols_b, - int rs_l, - int rs_b, - int cs_l, - int cs_b, - float alphaVal - ); -static void blis_strsm_microkernel_unitDiag( float *ptr_l, - float *ptr_b, - int numRows_lb, - int numCols_b, - int rs_l, - int rs_b, - int cs_l, - int cs_b - ); -static void blis_strsm_microkernel_alpha_unitDiag( float *ptr_l, - float *ptr_b, - int numRows_lb, - int numCols_b, - int rs_l, - int rs_b, - int cs_l, - int cs_b, - float alphaVal - ); -static void trsm_XAtB_block_allSmallSizedMatrices(float *ptr_l, - float *ptr_b, - int numRows_lb, - int numCols_b, - int rs_l, - int rs_b, - int cs_l, - int cs_b); -static void trsm_XAtB_block_allSmallSizedMatrices_alpha(float *ptr_l, - float *ptr_b, - int numRows_lb, - int numCols_b, - int rs_l, - int rs_b, - int cs_l, - int cs_b, - float alphaVal); -static void trsm_XAtB_block_allSmallSizedMatrices_unitDiag(float *ptr_l, - float *ptr_b, - int numRows_lb, - int numCols_b, - int rs_l, - int rs_b, - int cs_l, - int cs_b); -static void trsm_XAtB_block_allSmallSizedMatrices_alpha_unitDiag(float *ptr_l, - float *ptr_b, - int numRows_lb, - int numCols_b, - int rs_l, - int rs_b, - int cs_l, - int cs_b, - float alphaVal); - - -static void blis_dtrsm_microkernel( double *ptr_l, - double *ptr_b, - int numRows_lb, - int numCols_b, - int rs_l, - int rs_b, - int cs_l, - int cs_b - ); - -static void blis_dtrsm_microkernel_alpha( double *ptr_l, - double *ptr_b, - int numRows_lb, - int numCols_b, - int rs_l, - int rs_b, - int cs_l, - int cs_b, - double alphaVal - ); - -static void blis_dtrsm_microkernel_unitDiag( double *ptr_l, - double *ptr_b, - int numRows_lb, - int numCols_b, - int rs_l, - int rs_b, - int cs_l, - int cs_b - ); - -static void blis_dtrsm_microkernel_alpha_unitDiag( double *ptr_l, - double *ptr_b, - int numRows_lb, - int numCols_b, - int rs_l, - int rs_b, - int cs_l, - int cs_b, - double alphaVal - ); - -static void dtrsm_XAtB_block_allSmallSizedMatrices(double *ptr_l, - double *ptr_b, - int numRows_lb, - int numCols_b, - int rs_l, - int rs_b, - int cs_l, - int cs_b); -static void dtrsm_XAtB_block_allSmallSizedMatrices_alpha(double *ptr_l, - double *ptr_b, - int numRows_lb, - int numCols_b, - int rs_l, - int rs_b, - int cs_l, - int cs_b, - double alphaVal); -static void dtrsm_XAtB_block_allSmallSizedMatrices_unitDiag(double *ptr_l, - double *ptr_b, - int numRows_lb, - int numCols_b, - int rs_l, - int rs_b, - int cs_l, - int cs_b); -static void dtrsm_XAtB_block_allSmallSizedMatrices_alpha_unitDiag(double *ptr_l, - double *ptr_b, - int numRows_lb, - int numCols_b, - int rs_l, - int rs_b, - int cs_l, - int cs_b, - double alphaVal); -static void trsm_AutXB_block_allSmallSizedMatrices(float *ptr_l, - float *ptr_b, - int numRows_lb, - int numCols_b, - int rs_l, - int rs_b, - int cs_l, - int cs_b); -static void trsm_AutXB_block_allSmallSizedMatrices_alpha(float *ptr_l, - float *ptr_b, - int numRows_lb, - int numCols_b, - int rs_l, - int rs_b, - int cs_l, - int cs_b, - float alpha); -static void trsm_AutXB_block_allSmallSizedMatrices_unitDiag(float *ptr_l, - float *ptr_b, - int numRows_lb, - int numCols_b, - int rs_l, - int rs_b, - int cs_l, - int cs_b); -static void trsm_AutXB_block_allSmallSizedMatrices_alpha_unitDiag(float *ptr_l, - float *ptr_b, - int numRows_lb, - int numCols_b, - int rs_l, - int rs_b, - int cs_l, - int cs_b, - float alpha); - -//AX = B; A is lower triangular; No transpose; single precision -static err_t bli_strsm_small_AlXB - ( - side_t side, - obj_t* alpha, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ); -//A.'X = B; A is upper triangular; A has to be transposed; single precision -static err_t bli_strsm_small_AutXB - ( - side_t side, - obj_t* alpha, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ); - -//XA.' = B; A is lower triangular; A has to be transposed; single precision -static err_t bli_strsm_small_XAltB - ( - side_t side, - obj_t* alpha, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ); - -//A.'X = B; A is upper triangular; A has to be transposed; double precision +//A.'X = B; A is upper triangular; +//A has to be transposed; double precision static err_t bli_dtrsm_small_AutXB - ( - side_t side, - obj_t* alpha, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ); +( + obj_t* alpha, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +); + +//AX = B; A is lower triangular; transpose; double precision +static err_t bli_dtrsm_small_AltXB +( + obj_t* alpha, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +); /* -* The bli_trsm_small implements unpacked version of TRSM -* Currently only column-major is supported, A & B are column-major -* Input: A: MxM (triangular matrix) -* B: MxN matrix -* Output: X: MxN matrix such that AX = alpha*B or XA = alpha*B or A'X = alpha*B or XA' = alpha*B -* Here the output X is stored in B -* The custom-kernel will be called only when M*(M+N)* sizeof(Matrix Elements) < L3 cache + * Reference implementations + * ToDo: We can combine all these reference implementation + into a macro */ -err_t bli_trsm_small - ( - side_t side, - obj_t* alpha, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ) +//A'X = B; A is upper triangular; transpose; +//non-unitDiagonal double precision +static err_t dtrsm_AutXB_ref +( + double *A, + double *B, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb, + bool unitDiagonal +) { -#ifdef BLIS_ENABLE_MULTITHREADING - return BLIS_NOT_YET_IMPLEMENTED; -#endif - - dim_t m = bli_obj_length(b); - dim_t n = bli_obj_width(b); - - if(!(m && n)) - return BLIS_SUCCESS; - - - // If alpha is zero, B matrix will become zero after scaling & hence solution is also zero matrix - if (bli_obj_equals(alpha, &BLIS_ZERO)) + dim_t i, j, k; + for (k = 0; k < M; k++) { - return BLIS_NOT_YET_IMPLEMENTED; // scale B by alpha - } - // We have to call matrix scaling if alpha != 1.0 - - // if row major format return. Check this again. - if ((bli_obj_row_stride(a) != 1) || - (bli_obj_row_stride(b) != 1)) + double lkk_inv = 1.0; + if(!unitDiagonal) lkk_inv = 1.0/A[k+k*lda]; + for (j = 0; j < N; j++) + { + B[k + j*ldb] *= lkk_inv; + for (i = k+1; i < M; i++) + { + B[i + j*ldb] -= A[i*lda + k] * B[k + j*ldb]; + } + } + }// k -loop + return BLIS_SUCCESS; +}// end of function + +/* TRSM scalar code for the case AX = alpha * B + * A is upper-triangular, non-unit-diagonal + * Dimensions: A: mxm X: mxn B:mxn + */ +static err_t dtrsm_AuXB_ref +( + double *A, + double *B, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb, + bool is_unitdiag +) +{ + dim_t i, j, k; + for (k = M-1; k >= 0; k--) { - return BLIS_INVALID_ROW_STRIDE; - } - - num_t dt = ((*b).info & (0x7 << 0)); - - // only float and double datatypes are supported as of now. - if (dt != BLIS_DOUBLE && dt != BLIS_FLOAT) - { - return BLIS_EXPECTED_REAL_DATATYPE; - } - - // A is expected to be triangular in trsm - if (!bli_obj_is_upper_or_lower (a)) - { - return BLIS_EXPECTED_TRIANGULAR_OBJECT; - } - - // can use other control structs - even can use array of function pointers, - // indexed by a number with bits formed by f('side', 'uplo', 'transa', dt). - // In the below implementation, based on the number of finally implemented - // cases, can move the checks with more cases higher up. - - if(side == BLIS_LEFT) - { - if(bli_obj_has_trans(a)) - { - if(dt == BLIS_DOUBLE) - { - if(bli_obj_is_upper(a)) - { - //return bli_dtrsm_small_AutXB(side, alpha, a, b, cntx, cntl); - return BLIS_NOT_YET_IMPLEMENTED; - } - else - { - //return bli_dtrsm_small_AltXB(side, alpha, a, b, cntx, cntl); - return BLIS_NOT_YET_IMPLEMENTED; - } - } - else - { - if(bli_obj_is_upper(a)) - { - return bli_strsm_small_AutXB(side, alpha, a, b, cntx, cntl); - } - else - { - //return bli_strsm_small_AltXB(side, alpha, a, b, cntx, cntl); - return BLIS_NOT_YET_IMPLEMENTED; - } - - } - } - else - { - if(dt == BLIS_DOUBLE) - { - if(bli_obj_is_upper(a)) - { - //return bli_dtrsm_small_AuXB(side, alpha, a, b, cntx, cntl); - return BLIS_NOT_YET_IMPLEMENTED; - } - else - { - if(bli_obj_has_unit_diag(a)) - return bli_dtrsm_small_AlXB_unitDiag(side, alpha, a, b, cntx, cntl); - else - return bli_dtrsm_small_AlXB(side, alpha, a, b, cntx, cntl); - } - } - else - { - if(bli_obj_is_upper(a)) - { - //return bli_strsm_small_AuXB(side, alpha, a, b, cntx, cntl); - return BLIS_NOT_YET_IMPLEMENTED; - } - else - { - return bli_strsm_small_AlXB(side, alpha, a, b, cntx, cntl); - } - - } - - } - } - else - { - if(bli_obj_has_trans(a)) - { - if(dt == BLIS_DOUBLE) - { - if(bli_obj_is_upper(a)) - { - if(bli_obj_has_unit_diag(a)) - return bli_dtrsm_small_XAutB_unitDiag(side, alpha, a, b, cntx, cntl); - else - return bli_dtrsm_small_XAutB(side, alpha, a, b, cntx, cntl); - } - else - { - if(bli_obj_has_unit_diag(a)) - return bli_dtrsm_small_XAltB_unitDiag(side, alpha, a, b, cntx, cntl); - else - return bli_dtrsm_small_XAltB(side, alpha, a, b, cntx, cntl); - } - } - else - { - if(bli_obj_is_upper(a)) - { - //return bli_strsm_small_XAutB(side, alpha, a, b, cntx, cntl); - return BLIS_NOT_YET_IMPLEMENTED; - } - else - { - return bli_strsm_small_XAltB(side, alpha, a, b, cntx, cntl); - } - - } - } - else - { - if(dt == BLIS_DOUBLE) - { - if(bli_obj_is_upper(a)) - { - if(bli_obj_has_unit_diag(a)) - return bli_dtrsm_small_XAuB_unitDiag(side, alpha, a, b, cntx, cntl); - else - return bli_dtrsm_small_XAuB(side, alpha, a, b, cntx, cntl); - } - else - { - if(bli_obj_has_unit_diag(a)) - return bli_dtrsm_small_XAlB_unitDiag(side, alpha, a, b, cntx, cntl); - else - return bli_dtrsm_small_XAlB(side, alpha, a, b, cntx, cntl); - } - } - else - { - if(bli_obj_is_upper(a)) - { - //return bli_strsm_small_XAuB(side, alpha, a, b, cntx, cntl); - return BLIS_NOT_YET_IMPLEMENTED; - } - else - { - //return bli_strsm_small_XAlB(side, alpha, a, b, cntx, cntl); - return BLIS_NOT_YET_IMPLEMENTED; - } - - } - - } - } - return BLIS_NOT_YET_IMPLEMENTED; -}; + double lkk_inv = 1.0; + if(!is_unitdiag) lkk_inv = 1.0/A[k+k*lda]; + for (j = N -1; j >= 0; j--) + { + B[k + j*ldb] *= lkk_inv; + for (i = k-1; i >=0; i--) + { + B[i + j*ldb] -= A[i + k*lda] * B[k + j*ldb]; + } + } + }// k -loop + return BLIS_SUCCESS; +}// end of function /* TRSM scalar code for the case AX = alpha * B * A is lower-triangular, non-unit-diagonal, no transpose * Dimensions: A: mxm X: mxn B:mxn */ - -static err_t dtrsm_small_AlXB ( - double *A, - double *B, - dim_t M, - dim_t N, - dim_t lda, - dim_t ldb - ) +static err_t dtrsm_AlXB_ref +( + double *A, + double *B, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb, + bool is_unitdiag +) { - - dim_t i, j, k; - - for (k = 0; k < M; k++) - { - double lkk_inv = 1.0/A[k+k*lda]; - for (j = 0; j < N; j++) + dim_t i, j, k; + for (k = 0; k < M; k++) { - B[k + j*ldb] *= lkk_inv; - for (i = k+1; i < M; i++) + double lkk_inv = 1.0; + if(!is_unitdiag) lkk_inv = 1.0/A[k+k*lda]; + for (j = 0; j < N; j++) { - B[i + j*ldb] -= A[i + k*lda] * B[k + j*ldb]; + B[k + j*ldb] *= lkk_inv; + for (i = k+1; i < M; i++) + { + B[i + j*ldb] -= A[i + k*lda] * B[k + j*ldb]; + } } - } - }// k -loop - return BLIS_SUCCESS; + }// k -loop + return BLIS_SUCCESS; }// end of function /* TRSM scalar code for the case AX = alpha * B - * A is lower-triangular, unit-diagonal, no transpose + * A is lower-triangular, non-unit-diagonal, transpose * Dimensions: A: mxm X: mxn B:mxn */ - -static err_t dtrsm_small_AlXB_unitDiag ( - double *A, - double *B, - dim_t M, - dim_t N, - dim_t lda, - dim_t ldb - ) +static err_t dtrsm_AltXB_ref +( + double *A, + double *B, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb, + bool is_unitdiag +) { - - dim_t i, j, k; - - for (k = 0; k < M; k++) - { - for (j = 0; j < N; j++) - { - for (i = k+1; i < M; i++) + dim_t i, j, k; + for (k = M-1; k >= 0; k--) + { + double lkk_inv = 1.0; + if(!is_unitdiag) lkk_inv = 1.0/A[k+k*lda]; + for (j = N -1; j >= 0; j--) { - B[i + j*ldb] -= A[i + k*lda] * B[k + j*ldb]; + B[k + j*ldb] *= lkk_inv; + for (i = k-1; i >=0; i--) + { + B[i + j*ldb] -= A[i*lda + k] * B[k + j*ldb]; + } } - } - } - return BLIS_SUCCESS; + }// k -loop + return BLIS_SUCCESS; }// end of function +// XA = B; A is lower-traingular; No transpose; +//double precision; non-unit diagonal +static err_t bli_dtrsm_small_XAlB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +); + +//XA = B; A is lower triabgular; No transpose; +//double precision; unit-diagonal +static err_t bli_dtrsm_small_XAlB_unitDiag +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +); + +//XA' = B; A is lower-triangular; A is transposed; +// double precision; non-unit-diagonal +static err_t bli_dtrsm_small_XAltB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +); + +//XA' = B; A is lower-triangular; A is transposed; +//double precision; unit-diagonal +static err_t bli_dtrsm_small_XAltB_unitDiag +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +); + +// XA = B; A is upper triangular; No transpose; +//double presicion; non-unit diagonal +static err_t bli_dtrsm_small_XAuB +( + obj_t* alpha, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +); + +//XA = B; A is upper triangular; No transpose; +//double precision; unit-diagonal +static err_t bli_dtrsm_XAuB_unitDiag_ref +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +); + +//XA' = B; A is upper-triangular; A is transposed; +//double precision; non-unit diagonal +static err_t bli_dtrsm_small_XAutB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +); + +//XA' = B; A is upper-triangular; A is transposed; +//double precision; unit diagonal +static err_t bli_dtrsm_small_XAutB_unitDiag +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +); + + /* TRSM scalar code for the case XA = alpha * B * A is upper-triangular, non-unit-diagonal no transpose * Dimensions: X:mxn A:nxn B:mxn */ -static err_t dtrsm_small_XAuB ( - double *A, - double *B, - dim_t M, - dim_t N, - dim_t lda, - dim_t ldb +static err_t dtrsm_XAuB_ref +( + double *A, + double *B, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb ) { - - dim_t i, j, k; - for(k = 0; k < N; k++) - { - double lkk_inv = 1.0/A[k+k*lda]; - for(i = 0; i < M; i++) - { - B[i+k*ldb] *= lkk_inv; - for(j = k+1; j < N; j++) - { - B[i+j*ldb] -= B[i+k*ldb] * A[k+j*lda]; - } - } - + dim_t i, j, k; + for(k = 0; k < N; k++) + { + double lkk_inv = 1.0/A[k+k*lda]; + for(i = 0; i < M; i++) + { + B[i+k*ldb] *= lkk_inv; + for(j = k+1; j < N; j++) + { + B[i+j*ldb] -= B[i+k*ldb] * A[k+j*lda]; + } + } + } -return BLIS_SUCCESS; + return BLIS_SUCCESS; } /* TRSM scalar code for the case XA = alpha * B * A is lower-triangular, non-unit triangular, no transpose * Dimensions: X:mxn A:nxn B:mxn */ - -static err_t dtrsm_small_XAlB ( - double *A, - double *B, - double alpha, - dim_t M, - dim_t N, - dim_t lda, - dim_t ldb +static err_t dtrsm_XAlB_ref +( + double *A, + double *B, + double alpha, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb ) { - dim_t i, j, k; for(j = 0; j < N; j++) + { for(i = 0; i < M; i++) + { B[i+j*ldb] *= alpha; + } + } for(k = N;k--;) { @@ -711,32 +395,36 @@ static err_t dtrsm_small_XAlB ( } } } -return BLIS_SUCCESS; + return BLIS_SUCCESS; } /* TRSM scalar code for the case XA = alpha * B * A is lower-triangular, unit-diagonal, no transpose *Dimensions: X:mxn A:nxn B:mxn */ -static err_t dtrsm_small_XAlB_unitDiag( - double *A, - double *B, - double alpha, - dim_t M, - dim_t N, - dim_t lda, - dim_t ldb +static err_t dtrsm_XAlB_unitDiag_ref +( + double *A, + double *B, + double alpha, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb ) { - dim_t i, j, k; - for(j = 0 ; j < N; j++) + { for(i = 0; i < M; i++) + { B[i+j*ldb] *= alpha; + } + } + double A_k_j; - for(k = N; k--;) - { + for(k = N; k--;) + { for(j = k; j--;) { A_k_j = A[(k)+(j)*lda]; @@ -746,31 +434,32 @@ static err_t dtrsm_small_XAlB_unitDiag( } } } - - -return BLIS_SUCCESS; + return BLIS_SUCCESS; } /* TRSM scalar code for the case XA = alpha * B *A is upper-triangular, non-unit-diagonal, A is transposed * Dimensions: X:mxn A:nxn B:mxn */ -static err_t dtrsm_small_XAutB ( - double *A, - double *B, - double alpha, - dim_t M, - dim_t N, - dim_t lda, - dim_t ldb +static err_t dtrsm_XAutB_ref +( + double *A, + double *B, + double alpha, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb ) { - dim_t i, j, k; - for(j = 0; j < N; j++) + { for(i = 0; i < M; i++) + { B[i+j*ldb] *=alpha; + } + } for(k = N; k--;) { @@ -784,33 +473,37 @@ static err_t dtrsm_small_XAutB ( } } } -return BLIS_SUCCESS; + return BLIS_SUCCESS; } /* TRSM scalar code for the case XA = alpha * B * A is upper-triangular, unit-diagonal, A has to be transposed * Dimensions: X:mxn A:nxn B:mxn */ -static err_t dtrsm_small_XAutB_unitDiag( - double *A, - double *B, - double alpha, - dim_t M, - dim_t N, - dim_t lda, - dim_t ldb +static err_t dtrsm_XAutB_unitDiag_ref +( + double *A, + double *B, + double alpha, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb ) { - dim_t i, j, k; double A_k_j; for(j = 0; j< N; j++) + { for(i = 0; i< M; i++) + { B[i+j*ldb] *= alpha; + } + } - for(k = N; k--;) - { + for(k = N; k--;) + { for(j = k; j--;) { A_k_j = A[(j)+(k)*lda]; @@ -821,25 +514,24 @@ static err_t dtrsm_small_XAutB_unitDiag( } } } -return BLIS_SUCCESS; + return BLIS_SUCCESS; } /* TRSM scalar code for the case XA = alpha * B * A is lower-triangular, non-unit-diagonal, A has to be transposed * Dimensions: X:mxn A:nxn B:mxn */ -static err_t dtrsm_small_XAltB ( - double *A, - double *B, - dim_t M, - dim_t N, - dim_t lda, - dim_t ldb +static err_t dtrsm_XAltB_ref +( + double *A, + double *B, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb ) { - dim_t i, j, k; - for(k = 0; k < N; k++) { double lkk_inv = 1.0/A[k+k*lda]; @@ -852,14 +544,14 @@ static err_t dtrsm_small_XAltB ( } } } -return BLIS_SUCCESS; + return BLIS_SUCCESS; } /* TRSM scalar code for XA = alpha * B * A is lower-triangular, unit-diagonal, A has to be transposed * Dimensions: X:mxn A:nxn B:mxn */ -static err_t dtrsm_small_XAltB_unitDiag( +static err_t dtrsm_XAltB_unitDiag_ref( double *A, double *B, dim_t M, @@ -868,9 +560,7 @@ static err_t dtrsm_small_XAltB_unitDiag( dim_t ldb ) { - dim_t i, j, k; - for(k = 0; k < N; k++) { for(i = 0; i < M; i++) @@ -881,25 +571,24 @@ static err_t dtrsm_small_XAltB_unitDiag( } } } -return BLIS_SUCCESS; + return BLIS_SUCCESS; } /* TRSM scalar code for the case XA = alpha * B * A is upper-triangular, unit-diagonal, no transpose * Dimensions: X:mxn A:nxn B:mxn */ -static err_t dtrsm_small_XAuB_unitDiag ( - double *A, - double *B, - dim_t M, - dim_t N, - dim_t lda, - dim_t ldb +static err_t dtrsm_XAuB_unitDiag_ref +( + double *A, + double *B, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb ) { - dim_t i, j, k; - for(k = 0; k < N; k++) { for(i = 0; i < M; i++) @@ -910,13 +599,125 @@ static err_t dtrsm_small_XAuB_unitDiag ( } } } -return BLIS_SUCCESS; + return BLIS_SUCCESS; } +/* + * Kernels Table +*/ +trsmsmall_ker_ft ker_fps[16] = +{ + bli_dtrsm_small_AlXB, + bli_dtrsm_small_AltXB, + bli_dtrsm_small_AuXB, + bli_dtrsm_small_AutXB, + bli_dtrsm_small_AlXB, + bli_dtrsm_small_AltXB, + bli_dtrsm_small_AuXB, + bli_dtrsm_small_AutXB, + bli_dtrsm_small_XAlB, + bli_dtrsm_small_XAltB, + bli_dtrsm_small_XAuB, + bli_dtrsm_small_XAutB, + bli_dtrsm_small_XAlB_unitDiag, + bli_dtrsm_small_XAltB_unitDiag, + bli_dtrsm_XAuB_unitDiag_ref, + bli_dtrsm_small_XAutB_unitDiag +}; + +/* +* The bli_trsm_small implements a version of TRSM where A is packed and reused +* +* Input: A: MxM (triangular matrix) +* B: MxN matrix +* Output: X: MxN matrix such that + AX = alpha*B or XA = alpha*B or A'X = alpha*B or XA' = alpha*B +* Here the output X is stored in B +* +* Note: Currently only dtrsm is supported when A & B are column-major +*/ +err_t bli_trsm_small +( + side_t side, + obj_t* alpha, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) +{ + err_t err; + dim_t m = bli_obj_length(b); + dim_t n = bli_obj_width(b); + + if(!(m && n)) { + return BLIS_SUCCESS; + } + + bool unitdiag = bli_obj_has_unit_diag(a); + bool uplo = bli_obj_is_upper(a); + bool transa = bli_obj_has_trans(a); + + /* ToDo: Temporary threshold condition for trsm single thread. + * It will be updated with arch based threshold function which reads + * tunned thresholds for all 64 (datatype,side,uplo,transa,unit,) trsm + combinations + */ + if(m > 128 || n > 128) { + return BLIS_NOT_YET_IMPLEMENTED; + } + + /* If alpha is zero, B matrix will become zero after scaling + hence solution is also zero matrix */ + if (bli_obj_equals(alpha, &BLIS_ZERO)) { + return BLIS_NOT_YET_IMPLEMENTED; // scale B by alpha + } + + // Return if inputs are row major as currently + // we are supporing col major only + if ((bli_obj_row_stride(a) != 1) || + (bli_obj_row_stride(b) != 1)) { + return BLIS_INVALID_ROW_STRIDE; + } + + //Curretnly optimized for double data type only + num_t dt = bli_obj_dt(a); + if (dt != BLIS_DOUBLE) { + return BLIS_NOT_YET_IMPLEMENTED; + } + + // A is expected to be triangular in trsm + if (!bli_obj_is_upper_or_lower (a)) { + return BLIS_EXPECTED_TRIANGULAR_OBJECT; + } + + /* + * Compose kernel index based on inputs + */ + + dim_t keridx = ( (( side & 0x1) << 3) | ((unitdiag & 0x1) << 2) | + (( uplo & 0x1) << 1) | ( transa & 0x1) ); + + + trsmsmall_ker_ft ker_fp = ker_fps[ keridx ]; + + /*Call the kernel*/ + err = ker_fp + ( + alpha, + a, + b, + cntx, + cntl + ); + + return err; +}; + /* TRSM for the case AX = alpha * B, Double precision * A is lower-triangular, no-transpose, non-unit diagonal * dimensions A: mxm X: mxn B: mxn - + b01---> * ***************** ** * * * * * @@ -936,38 +737,2931 @@ a10 ****** b11 ***************** **************** ***************** a11---> */ -static err_t bli_dtrsm_small_AlXB( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ) +static err_t bli_dtrsm_small_AlXB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) { - - dim_t D_MR = 4; //size of block along 'M' dimpension - dim_t D_NR = 8; //size of block along 'N' dimension + dim_t D_MR = 8; //size of block along 'M' dimpension + dim_t D_NR = 6; //size of block along 'N' dimension dim_t m = bli_obj_length(b); // number of rows of matrix B dim_t n = bli_obj_width(b); // number of columns of matrix B + dim_t cs_a = bli_obj_col_stride(a); // column stride of A + dim_t cs_b = bli_obj_col_stride(b); // column stride of B -#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME - if((m>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME_ROW_PANEL_M && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME) - || (m> D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME_COLUMN_PANEL_N) - || (m>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME_COLUMN_PANEL_M && n D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_NAPLES) + dim_t i, j, k; //loop variables + dim_t k_iter; //number of times GEMM to be performed + + double AlphaVal = *(double *)AlphaObj->buffer; //value of alpha + double *L = a->buffer; //pointer to matrix A + double *B = b->buffer; //pointer to matrix B + + //pointers that point to blocks for GEMM and TRSM + double *a10, *a11, *b01, *b11; + double *ptr_b01_dup; + + double ones = 1.0; + bool is_unitdiag = bli_obj_has_unit_diag(a); + + //scratch registers + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15; + __m256d ymm16, ymm17, ymm18, ymm19; + __m256d ymm20; + + gint_t required_packing_A = 1; + mem_t local_mem_buf_A_s; + double *D_A_pack = NULL; + double d11_pack[D_MR] __attribute__((aligned(64))); + rntm_t rntm; + + bli_rntm_init_from_global( &rntm ); + bli_rntm_set_num_threads_only( 1, &rntm ); + bli_membrk_rntm_set_membrk( &rntm ); + + siz_t buffer_size = bli_pool_block_size( + bli_membrk_pool( + bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), + bli_rntm_membrk(&rntm))); + + if (required_packing_A == 1) + { + // Get the buffer from the pool. + bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &local_mem_buf_A_s); + + D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); + } + + if ( (D_MR * m * sizeof(double)) > buffer_size) + return BLIS_NOT_YET_IMPLEMENTED; + + /* + Performs solving TRSM for 8 colmns at a time from 0 to m/8 in steps of D_MR + a. Load and pack A (a10 block), the size of packing 8x6 to 8x (m-8) + First there will be no GEMM and no packing of a10 because it is only TRSM + b. Using packed a10 block and b01 block perform GEMM operation + c. Use GEMM outputs, perform TRSM operaton using a11, b11 and update B + d. Repeat b,c for n rows of B in steps of D_NR + */ + for(i = 0;(i+D_MR-1) < m; i += D_MR) //loop along 'M' dimension + { + a10 = L + (i); //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + double *ptr_a10_dup = D_A_pack; + + dim_t p_lda = D_MR; // packed leading dimension + + /* + Pack current A block (a10) into packed buffer memory D_A_pack + a. This a10 block is used in GEMM portion only and this + a10 block size will be increasing by D_MR for every next itteration + untill it reaches 8x(m-8) which is the maximum GEMM alone block size in A + b. This packed buffer is reused to calculate all n rows of B matrix + */ + for(dim_t x =0;x < i;x++) { - return BLIS_NOT_YET_IMPLEMENTED; + ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * x)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * x), ymm16); + ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * x + 4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * x + 4), ymm16); } -#endif + + /* + Pack 8 diagonal elements of A block into an array + a. This helps in utilze cache line efficiently in TRSM operation + b. store ones when input is unit diagonal + */ + ymm4 = _mm256_broadcast_sd((double const *)&ones); + if(is_unitdiag) + { + _mm256_storeu_pd((double *)(d11_pack), ymm4); + _mm256_storeu_pd((double *)(d11_pack + 4), ymm4); + }else + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a +1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a*2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a*3 + 3)); - dim_t m_remainder = m & 3; //number of remainder rows - dim_t n_remainder = n & 7; //number of remainder columns + //Pick one element each column and create a 4 element vector and store + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); + ymm0 = _mm256_div_pd(ymm4, ymm1); + _mm256_storeu_pd((double *)(d11_pack), ymm0); + + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4 + 4)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5 + 5)); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6 + 6)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7 + 7)); + + //Pick one element each column and create a 4 element vector and store + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); + ymm0 = _mm256_div_pd(ymm4, ymm1); + _mm256_storeu_pd((double *)(d11_pack + 4), ymm0); + } + + /* + a. Perform GEMM using a10, b01. + b. Perform TRSM on a11, b11 + c. This loop GEMM+TRSM loops operates with 8x6 block size + along n dimension for every D_NR columns of B01 where + packed A buffer is reused in computing all n rows of B. + d. Same approch is used in remaining fringe cases. + */ + for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' dimension + { + a10 = D_A_pack; + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + + k_iter = i >>3; //number of times GEMM to be performed + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); + #endif + + ymm0 = _mm256_setzero_pd(); + ymm1 = _mm256_setzero_pd(); + ymm2 = _mm256_setzero_pd(); + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + /* + Peform GEMM between a10 and b01 blocks + For first itteration there will be no GEMM operation + where k_iter are zero + */ + + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter*4; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); + ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); + ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4 + 4)); + + double *b01_temp = b01+ 4; + + ymm2 = _mm256_broadcast_sd((double const *)(b01_temp)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01_temp + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01_temp + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *)(b01_temp + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + + ymm2 = _mm256_broadcast_sd((double const *)(b01_temp + cs_b * 4)); + ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); + + ymm2 = _mm256_broadcast_sd((double const *)(b01_temp + cs_b * 5)); + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); + ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7); + + b01 += 1; //move to next row of B + a10 += p_lda; //pointer math to calculate next block of A for GEMM + if(!((k+1)&3)) + { + b01 = ptr_b01_dup + D_MR; + ptr_b01_dup = b01; + a10 = ptr_a10_dup + D_MR*p_lda; + ptr_a10_dup = a10; + } + } + ///GEMM code end/// + + /* + Load b11 of size 6x8 and multiply with alpha + Add the GEMM output and perform inregister transose of b11 + to peform TRSM operation. + */ + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] + + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm0 = _mm256_loadu_pd((double const *)(b11 + 4)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3 + 4)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm12); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm13); + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm14); + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm15); + + ymm13 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm15 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm15 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *4 + 4)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *5 + 4)); + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); + + ymm16 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm4 = _mm256_permute2f128_pd(ymm7,ymm16,0x20); + ymm6 = _mm256_permute2f128_pd(ymm7,ymm16,0x31); + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm5 = _mm256_permute2f128_pd(ymm0,ymm16,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm16,0x31); + ymm18 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm17 = _mm256_permute2f128_pd(ymm18,ymm16,0x20); + ymm19 = _mm256_permute2f128_pd(ymm18,ymm16,0x31); + ymm20 = _mm256_unpackhi_pd(ymm2, ymm3); + ymm18 = _mm256_permute2f128_pd(ymm20,ymm16,0x20); + ymm20 = _mm256_permute2f128_pd(ymm20,ymm16,0x31); + //b11 transpose end + + /* + Compute 8x6 TRSM block by using GEMM block output in register + a. The 8x6 input (gemm outputs) are stored in combinations of ymm registers + 1. ymm8, ymm4 2. ymm9, ymm5 3. ymm10, ymm6, 4. ymm11, ymm7 + 5. ymm12, ymm17 6. ymm13,ymm18, 7. ymm14,ymm19 8. ymm15, ymm20 + where ymm8-ymm15 holds 8x4 data and reaming 8x2 will be hold by + other registers + b. Towards the end do in regiser transpose of TRSM output and store in b11 + */ + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); + ymm4 = _mm256_mul_pd(ymm4, ymm1); + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + //(ROw1): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm4, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm8, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm4, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm4, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm8, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm4, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm8, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm4, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm8, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm4, ymm19); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm8, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm4, ymm20); + + //perform mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); + ymm5 = _mm256_mul_pd(ymm5, ymm1); + + a11 += cs_a; + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm9, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm5, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm9, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm5, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm9, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm5, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm9, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm5, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm9, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm5, ymm19); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm9, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm5, ymm20); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + ymm6 = _mm256_mul_pd(ymm6, ymm1); + + a11 += cs_a; + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw5): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm10, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm6, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm10, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm6, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm10, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm6, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm10, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm6, ymm19); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm10, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm6, ymm20); + + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); + ymm7 = _mm256_mul_pd(ymm7, ymm1); + + a11 += cs_a; + + //extract a44 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + //(ROw4): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm11, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm7, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm11, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm7, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm11, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm7, ymm19); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm11, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm7, ymm20); + + //perform mul operation + ymm12 = _mm256_mul_pd(ymm12, ymm1); + ymm17 = _mm256_mul_pd(ymm17, ymm1); + + a11 += cs_a; + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + //(ROw5): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm17, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm12, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm17, ymm19); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm17, ymm20); + + //perform mul operation + ymm13 = _mm256_mul_pd(ymm13, ymm1); + ymm18 = _mm256_mul_pd(ymm18, ymm1); + + a11 += cs_a; + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + + //(ROw6): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm13, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm18, ymm19); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm13, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm18, ymm20); + + //perform mul operation + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm19 = _mm256_mul_pd(ymm19, ymm1); + + a11 += cs_a; + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + + //(ROw7): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm14, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm19, ymm20); + + //perform mul operation + ymm15 = _mm256_mul_pd(ymm15, ymm1); + ymm20 = _mm256_mul_pd(ymm20, ymm1); + + a11 += cs_a; + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] + ymm3 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + + ///unpack high/// + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm3 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + + _mm256_storeu_pd((double *)(b11 + 4), ymm0); //store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm1); //store B11[5][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm2); //store B11[6][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3 + 4), ymm3); //store B11[7][0-3] + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm4, ymm5); + ymm3 = _mm256_unpacklo_pd(ymm6, ymm7); + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); + + ///unpack high/// + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); + + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm17, ymm18); + ymm3 = _mm256_unpacklo_pd(ymm19, ymm20); + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); + + ///unpack high/// + ymm17 = _mm256_unpackhi_pd(ymm17, ymm18); + ymm18 = _mm256_unpackhi_pd(ymm19, ymm20); + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm17, ymm18, 0x20); + + _mm256_storeu_pd((double *)(b11 + cs_b * 4 + 4), ymm0); + _mm256_storeu_pd((double *)(b11 + cs_b * 5 + 4), ymm1); + } + + dim_t n_rem = n-j; + if(n_rem >= 4) + { + a10 = D_A_pack; + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + + k_iter = i / D_MR; //number of times GEMM to be performed + + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter*4; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1 + 4)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2 + 4)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3 + 4)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + + b01 += 1; //move to next row of B + a10 += p_lda; + + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + D_MR; + ptr_b01_dup = b01; + a10 = ptr_a10_dup + D_MR*p_lda; + ptr_a10_dup = a10; + } + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); + ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b *3 + 4)); + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); + ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); + ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); + + ///implement TRSM/// + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); + ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(ROw1): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm8, ymm10); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm8, ymm12); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm8, ymm13); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm8, ymm14); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm8, ymm15); + + //perform mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); + + a11 += cs_a; + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm9, ymm10); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm9, ymm11); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm9, ymm12); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm9, ymm13); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm9, ymm14); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm9, ymm15); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + + a11 += cs_a; + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw5): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm10, ymm11); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm10, ymm12); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm10, ymm13); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm10, ymm14); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm10, ymm15); + + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); + + a11 += cs_a; + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + //(ROw4): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm11, ymm12); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm11, ymm13); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm11, ymm14); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm11, ymm15); + + //perform mul operation + ymm12 = _mm256_mul_pd(ymm12, ymm1); + + a11 += cs_a; + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + //(ROw5): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm12, ymm14); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); + + //perform mul operation + ymm13 = _mm256_mul_pd(ymm13, ymm1); + + a11 += cs_a; + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + + //(ROw6): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm13, ymm14); + ymm2 = _mm256_broadcast_sd((double const *)(a11 +7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm13, ymm15); + + //perform mul operation + ymm14 = _mm256_mul_pd(ymm14, ymm1); + + a11 += cs_a; + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + + //(ROw7): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm14, ymm15); + + //perform mul operation + ymm15 = _mm256_mul_pd(ymm15, ymm1); + + a11 += cs_a; + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); + ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); + ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); + ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); + _mm256_storeu_pd((double *)(b11 + cs_b * 3 + 4), ymm7); + + n_rem -=4; + j +=4; + + } + + if(n_rem) + { + a10 = D_A_pack; + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + + k_iter = i / D_MR; //number of times GEMM to be performed + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(3 == n_rem) + { + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter*4; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1 + 4)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2 + 4)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + b01 += 1; //move to next row of B + a10 += p_lda; + + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + D_MR; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +D_MR*p_lda; + ptr_a10_dup = a10; + } + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); + ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(2 == n_rem) + { + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter*4; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1 + 4)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + b01 += 1; //move to next row of B + a10 += p_lda; + + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + D_MR; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +D_MR*p_lda; + ptr_a10_dup = a10; + } + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); + ymm6 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(1 == n_rem) + { + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter*4; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + b01 += 1; //move to next row of B + a10 += p_lda; + + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + D_MR; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +D_MR*p_lda; + ptr_a10_dup = a10; + } + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_broadcast_sd((double const *)(&ones)); + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); + ymm5 = _mm256_broadcast_sd((double const *)(&ones)); + ymm6 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + } + ///implement TRSM/// + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); + ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(ROw1): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm8, ymm10); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm8, ymm12); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm8, ymm13); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm8, ymm14); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm8, ymm15); + + //perform mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); + + a11 += cs_a; + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm9, ymm10); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm9, ymm11); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm9, ymm12); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm9, ymm13); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm9, ymm14); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm9, ymm15); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + + a11 += cs_a; + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw5): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm10, ymm11); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm10, ymm12); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm10, ymm13); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm10, ymm14); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm10, ymm15); + + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); + + a11 += cs_a; + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + //(ROw4): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm11, ymm12); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm11, ymm13); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm11, ymm14); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm11, ymm15); + + //perform mul operation + ymm12 = _mm256_mul_pd(ymm12, ymm1); + + a11 += cs_a; + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + //(ROw5): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm12, ymm14); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); + + //perform mul operation + ymm13 = _mm256_mul_pd(ymm13, ymm1); + + a11 += cs_a; + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + + //(ROw6): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm13, ymm14); + ymm2 = _mm256_broadcast_sd((double const *)(a11 +7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm13, ymm15); + + //perform mul operation + ymm14 = _mm256_mul_pd(ymm14, ymm1); + + a11 += cs_a; + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + + //(ROw7): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm14, ymm15); + + //perform mul operation + ymm15 = _mm256_mul_pd(ymm15, ymm1); + + a11 += cs_a; + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); + ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); + ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); + ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); + + if(3 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); + } + else if(2 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); + } + else if(1 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); + } + } + } + + /* + Reminder cases starts here: + a. Similar logic and code flow used in computing full block (8x6) + above holds for reminder cases too. + */ + + dim_t m_rem = m-i; + //implementation for reamainder rows(when 'M' is not a multiple of D_MR) + if(m_rem>=4) + { + a10 = L + (i); //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + double *ptr_a10_dup = D_A_pack; + double *ptr_a11_dup = a11; + + dim_t p_lda = 4; // packed leading dimension + for(dim_t x =0;x < i;x++) + { + ymm0 = _mm256_loadu_pd((double const *)(a10 + cs_a * x)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * x), ymm0); + } + + ymm4 = _mm256_broadcast_sd((double const *)&ones); + if(!is_unitdiag) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a +1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11+cs_a*2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11+cs_a*3 + 3)); + + //Pick one element each column and create a 4 element vector and store + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); + ymm4 = _mm256_div_pd(ymm4, ymm1); + } + _mm256_storeu_pd((double *)(d11_pack), ymm4); + + for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' dimension + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = ptr_a11_dup; //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = i; //number of times GEMM operation to be done + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); + ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + + + ymm16 = _mm256_broadcast_sd((double const *)(&ones)); + + ////unpacklow//// + ymm7 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + + //rearrange low elements + ymm4 = _mm256_permute2f128_pd(ymm7,ymm16,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm6 = _mm256_permute2f128_pd(ymm7,ymm16,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + + //rearrange high elements + ymm5 = _mm256_permute2f128_pd(ymm0,ymm16,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm7 = _mm256_permute2f128_pd(ymm0,ymm16,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + //b11 transpose end + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); + ymm4 = _mm256_mul_pd(ymm4, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(ROw1): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm4, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm8, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm4, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm4, ymm7); + + //perform mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); + ymm5 = _mm256_mul_pd(ymm5, ymm1); + + a11 += cs_a; + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm9, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm5, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm9, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm5, ymm7); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + ymm6 = _mm256_mul_pd(ymm6, ymm1); + + a11 += cs_a; + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw5): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm10, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm6, ymm7); + + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); + ymm7 = _mm256_mul_pd(ymm7, ymm1); + + a11 += cs_a; + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ///unpack high/// + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store B11[1][0-3] + } + dim_t n_rem = n-j; + if(n_rem >= 4) + { + a10 = D_A_pack; + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + + k_iter = i; //number of times GEMM to be performed + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(ROw1): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm8, ymm10); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); + + //perform mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); + + a11 += cs_a; + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm9, ymm10); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm9, ymm11); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + + a11 += cs_a; + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw5): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm10, ymm11); + + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); + + a11 += cs_a; + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + + n_rem -= 4; + j += 4; + } + if(n_rem) + { + a10 = D_A_pack; + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + + k_iter = i; //number of times GEMM to be performed + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(3 == n_rem) + { + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(2 == n_rem) + { + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(1 == n_rem) + { + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_broadcast_sd((double const *)(&ones)); + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(ROw1): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm8, ymm10); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); + + //perform mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); + + a11 += cs_a; + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm9, ymm10); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm9, ymm11); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + + a11 += cs_a; + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw5): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm10, ymm11); + + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); + + a11 += cs_a; + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + if(3 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + } + else if(2 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + } + else if(1 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + } + } + m_rem -=4; + i +=4; + } + + if(m_rem) + { + a10 = L + (i); //pointer to block of A to be used for GEMM + // Do transpose for a10 & store in D_A_pack + double *ptr_a10_dup = D_A_pack; + if(3 == m_rem) // Repetative A blocks will be 3*3 + { + dim_t p_lda = 4; // packed leading dimension + for(dim_t x=0;x= 4)) + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = i; //number of times GEMM to be performed + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + + dtrsm_AlXB_ref(a11, b11, m_rem, 4, cs_a, cs_b, is_unitdiag); + n_rem -= 4; + j +=4; + } + if(n_rem) + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = i; //number of times GEMM to be performed + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(3 == n_rem) + { + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + + dtrsm_AlXB_ref(a11, b11, m_rem, 3, cs_a, cs_b, is_unitdiag); + } + else if(2 == n_rem) + { + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + + dtrsm_AlXB_ref(a11, b11, m_rem, 2, cs_a, cs_b, is_unitdiag); + } + else if(1 == n_rem) + { + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + + dtrsm_AlXB_ref(a11, b11, m_rem, 1, cs_a, cs_b, is_unitdiag); + } + } + } + else if(2 == m_rem) // Repetative A blocks will be 2*2 + { + dim_t p_lda = 4; // packed leading dimension + for(dim_t x=0;x= 4)) + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = i; //number of times GEMM to be performed + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0C); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + + dtrsm_AlXB_ref(a11, b11, m_rem, 4, cs_a, cs_b, is_unitdiag); + n_rem -= 4; + j +=4; + } + if(n_rem) + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = i; //number of times GEMM to be performed + + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(3 == n_rem) + { + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + + dtrsm_AlXB_ref(a11, b11, m_rem, 3, cs_a, cs_b, is_unitdiag); + } + else if(2 == n_rem) + { + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); + + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + + dtrsm_AlXB_ref(a11, b11, m_rem, 2, cs_a, cs_b, is_unitdiag); + } + else if(1 == n_rem) + { + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + + dtrsm_AlXB_ref(a11, b11, m_rem, 1, cs_a, cs_b, is_unitdiag); + } + + } + m_rem -=2; + i+=2; + } + else if(1 == m_rem) // Repetative A blocks will be 1*1 + { + dim_t p_lda = 4; // packed leading dimension + for(dim_t x=0;x= 4)) + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = i; //number of times GEMM to be performed + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + + dtrsm_AlXB_ref(a11, b11, m_rem, 4, cs_a, cs_b, is_unitdiag); + n_rem -= 4; + j+=4; + } + + if(n_rem) + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = i; //number of times GEMM to be performed + + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(3 == n_rem) + { + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + + dtrsm_AlXB_ref(a11, b11, m_rem, 3, cs_a, cs_b, is_unitdiag); + } + else if(2 == n_rem) + { + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + + dtrsm_AlXB_ref(a11, b11, m_rem, 2, cs_a, cs_b, is_unitdiag); + } + else if(1 == n_rem) + { + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + + dtrsm_AlXB_ref(a11, b11, m_rem, 1, cs_a, cs_b, is_unitdiag); + } + } + m_rem -=1; + i+=1; + } + } + + if ((required_packing_A == 1) && + bli_mem_is_alloc( &local_mem_buf_A_s )) + { + bli_membrk_release(&rntm, &local_mem_buf_A_s); + } + + return BLIS_SUCCESS; +} + +/* TRSM for the Left Upper case AX = alpha * B, Double precision + * A is Left side, upper-triangular, transpose, non-unit diagonal + * dimensions A: mxm X: mxn B: mxn + a10 ----> b11---> +*********** ***************** +* * * * *b01*b11* * * + **a10 * * a11 b11 * * * * * + ********* | | ***************** + *a11* * | | * * * * * + * * * | | * * * * * + ****** v v ***************** + * * * * * * * + * * * * * * * + * * ***************** + * + a11---> +*/ +static err_t bli_dtrsm_small_AutXB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) +{ + dim_t D_MR = 8; //size of block along 'M' dimpension + dim_t D_NR = 6; //size of block along 'N' dimension + + dim_t m = bli_obj_length(b); // number of rows of matrix B + dim_t n = bli_obj_width(b); // number of columns of matrix B dim_t cs_a = bli_obj_col_stride(a); // column stride of A dim_t cs_b = bli_obj_col_stride(b); // column stride of B @@ -982,31 +3676,2605 @@ static err_t bli_dtrsm_small_AlXB( double *a10, *a11, *b01, *b11; //pointers that point to blocks for GEMM and TRSM double *ptr_b01_dup; - double f_t[4] __attribute__((aligned(64)));//buffer to store corner column when m_remainder !=0 - double* f_temp; - double ones = 1.0; + bool is_unitdiag = bli_obj_has_unit_diag(a); //scratch registers __m256d ymm0, ymm1, ymm2, ymm3; __m256d ymm4, ymm5, ymm6, ymm7; __m256d ymm8, ymm9, ymm10, ymm11; __m256d ymm12, ymm13, ymm14, ymm15; - __m256d ymm16; + __m256d ymm16, ymm17, ymm18, ymm19; + __m256d ymm20; + gint_t required_packing_A = 1; + mem_t local_mem_buf_A_s; + double *D_A_pack = NULL; + double d11_pack[D_MR] __attribute__((aligned(64))); + rntm_t rntm; + bli_rntm_init_from_global( &rntm ); + bli_rntm_set_num_threads_only( 1, &rntm ); + bli_membrk_rntm_set_membrk( &rntm ); - for(j = 0; j+D_NR-1 < n; j += D_NR) //loop along 'N' dimension + siz_t buffer_size = bli_pool_block_size( + bli_membrk_pool( + bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), + bli_rntm_membrk(&rntm))); + + if ( (D_MR * m * sizeof(double)) > buffer_size) + return BLIS_NOT_YET_IMPLEMENTED; + + if (required_packing_A == 1) { - for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' dimension + // Get the buffer from the pool. + bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &local_mem_buf_A_s); + + D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); + } + + /* + Performs solving TRSM for 8 colmns at a time from 0 to m/8 in steps of D_MR + a. Load, transpose, Pack A (a10 block), the size of packing 8x6 to 8x (m-8) + First there will be no GEMM and no packing of a10 because it is only TRSM + b. Using packed a10 block and b01 block perform GEMM operation + c. Use GEMM outputs, perform TRSM operaton using a11, b11 and update B + d. Repeat b,c for n rows of B in steps of D_NR + */ + for(i = 0;(i+D_MR-1) < m; i += D_MR) //loop along 'M' dimension + { + a10 = L + (i*cs_a); //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); + double *ptr_a10_dup = D_A_pack; + dim_t p_lda = i; // packed leading dimension + + /* + Load, tranpose and pack current A block (a10) into packed buffer memory D_A_pack + a. This a10 block is used in GEMM portion only and this + a10 block size will be increasing by D_MR for every next itteration + untill it reaches 8x(m-8) which is the maximum GEMM alone block size in A + b. This packed buffer is reused to calculate all n rows of B matrix + */ + + for(dim_t x =0;x < i;x+=D_MR) { - a10 = L +i; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm10 = _mm256_loadu_pd((double const *)(a10 + 4)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); + ymm11 = _mm256_loadu_pd((double const *)(a10 + 4 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); + ymm12 = _mm256_loadu_pd((double const *)(a10 + 4 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); + ymm13 = _mm256_loadu_pd((double const *)(a10 + 4 + cs_a * 3)); - k_iter = i / D_MR; //number of times GEMM to be performed(in blocks of 4x4) + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + + + ymm4 = _mm256_unpacklo_pd(ymm10, ymm11); + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm10, ymm11); + ymm1 = _mm256_unpackhi_pd(ymm12, ymm13); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 4), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 5), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 6), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 7), ymm9); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + cs_a * 4)); + ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 4 + 4)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a * 5)); + ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 5 + 4)); + ymm2 = _mm256_loadu_pd((double const *)(a10 + cs_a * 6)); + ymm12 = _mm256_loadu_pd((double const *)(a10 + cs_a * 6 + 4)); + ymm3 = _mm256_loadu_pd((double const *)(a10 + cs_a * 7)); + ymm13 = _mm256_loadu_pd((double const *)(a10 + cs_a * 7 + 4)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup + 4), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), ymm9); + + + ymm4 = _mm256_unpacklo_pd(ymm10, ymm11); + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm10, ymm11); + ymm1 = _mm256_unpackhi_pd(ymm12, ymm13); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 4 + 4), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 5 + 4), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 6 + 4), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 7 + 4), ymm9); + + a10 += D_MR; + ptr_a10_dup += D_MR; + } + + /* + Pack 8 diagonal elements of A block into an array + a. This helps in utilze cache line efficiently in TRSM operation + b. store ones when input is unit diagonal + */ + ymm4 = _mm256_broadcast_sd((double const *)&ones); + if(is_unitdiag) + { + _mm256_storeu_pd((double *)(d11_pack), ymm4); + _mm256_storeu_pd((double *)(d11_pack + 4), ymm4); + }else + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a +1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a*2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a*3 + 3)); + + //Pick one element each column and create a 4 element vector and store + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); + ymm0 = _mm256_div_pd(ymm4, ymm1); + _mm256_storeu_pd((double *)(d11_pack), ymm0); + + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11 + 4 + cs_a*4)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5 + cs_a*5)); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6 + cs_a*6)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + 7 + cs_a*7)); + + //Pick one element each column and create a 4 element vector and store + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); + ymm0 = _mm256_div_pd(ymm4, ymm1); + _mm256_storeu_pd((double *)(d11_pack + 4), ymm0); + } + + /* + a. Perform GEMM using a10, b01. + b. Perform TRSM on a11, b11 + c. This loop GEMM+TRSM loops operates with 8x6 block size + along n dimension for every D_NR rows of b01 where + packed A buffer is reused in computing all n rows of B. + d. Same approch is used in remaining fringe cases. + */ + dim_t temp = n - D_NR + 1; + for(j = 0; j < temp; j += D_NR) //loop along 'N' dimension + { + a10 = D_A_pack; + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + + k_iter = i >>3; //number of times GEMM to be performed(in blocks of 4x4) + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + /* + Peform GEMM between a10 and b01 blocks + For first itteration there will be no GEMM operation + where k_iter are zero + */ + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter*4; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); + ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); + ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4 + 4)); + + double *b01_temp = b01+ 4; + + ymm2 = _mm256_broadcast_sd((double const *)(b01_temp)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01_temp + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01_temp + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *)(b01_temp + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + + ymm2 = _mm256_broadcast_sd((double const *)(b01_temp + cs_b * 4)); + ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); + + ymm2 = _mm256_broadcast_sd((double const *)(b01_temp + cs_b * 5)); + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); + ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7); + + b01 += 1; //move to next row of B + a10 += p_lda; //pointer math to calculate next block of A for GEMM + if(!((k+1)&3)) + { + b01 = ptr_b01_dup + D_MR; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +D_MR; + ptr_a10_dup = a10; + } + } + ///GEMM code end/// + + /* + Load b11 of size 6x8 and multiply with alpha + Add the GEMM output and perform inregister transose of b11 + to peform TRSM operation. + */ + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] + + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm0 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3 + 4)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm14); //B11[0-3][6] * alpha -= B01[0-3][6] + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm15); //B11[0-3][7] * alpha -= B01[0-3][7] + + ymm13 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][4] B11[0][5] B11[2][4] B11[2][5] + ymm15 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][6] B11[0][7] B11[2][6] B11[2][7] + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][4] B11[1][5] B11[3][4] B11[3][5] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][6] B11[1][7] B11[3][6] B11[3][7] + ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] + ymm15 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *4 + 4)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *5 + 4)); + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); + + ymm16 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm4 = _mm256_permute2f128_pd(ymm7,ymm16,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm6 = _mm256_permute2f128_pd(ymm7,ymm16,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm5 = _mm256_permute2f128_pd(ymm0,ymm16,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm7 = _mm256_permute2f128_pd(ymm0,ymm16,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + ymm18 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + + ymm17 = _mm256_permute2f128_pd(ymm18,ymm16,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm19 = _mm256_permute2f128_pd(ymm18,ymm16,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] + ymm20 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm18 = _mm256_permute2f128_pd(ymm20,ymm16,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm20 = _mm256_permute2f128_pd(ymm20,ymm16,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + /* + Compute 8x6 TRSM block by using GEMM block output in register + a. The 8x6 input (gemm outputs) are stored in combinations of ymm registers + 1. ymm8, ymm4 2. ymm9, ymm5 3. ymm10, ymm6, 4. ymm11, ymm7 + 5. ymm12, ymm17 6. ymm13,ymm18, 7. ymm14,ymm19 8. ymm15, ymm20 + where ymm8-ymm15 holds 8x4 data and reaming 8x2 will be hold by + other registers + b. Towards the end do in regiser transpose of TRSM output and store in b11 + */ + ////extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); + ymm4 = _mm256_mul_pd(ymm4, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(ROw1): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm4, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm8, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm4, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm4, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm8, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm4, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm8, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm4, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm8, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm4, ymm19); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm8, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm4, ymm20); + + + //perform mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); + ymm5 = _mm256_mul_pd(ymm5, ymm1); + + a11 += 1; + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm9, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm5, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm9, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm5, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm9, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm5, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm9, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm5, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm9, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm5, ymm19); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm9, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm5, ymm20); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + ymm6 = _mm256_mul_pd(ymm6, ymm1); + + a11 += 1; + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw5): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm10, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm6, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm10, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm6, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm10, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm6, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm10, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm6, ymm19); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm10, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm6, ymm20); + + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); + ymm7 = _mm256_mul_pd(ymm7, ymm1); + + a11 += 1; + + //extract a44 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + //(ROw4): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm11, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm7, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm11, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm7, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm11, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm7, ymm19); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm11, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm7, ymm20); + + //perform mul operation + ymm12 = _mm256_mul_pd(ymm12, ymm1); + ymm17 = _mm256_mul_pd(ymm17, ymm1); + + a11 += 1; + + //extract a55 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + //(ROw5): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm17, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm12, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm17, ymm19); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm17, ymm20); + + //perform mul operation + ymm13 = _mm256_mul_pd(ymm13, ymm1); + ymm18 = _mm256_mul_pd(ymm18, ymm1); + + a11 += 1; + + //extract a66 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + + //(ROw6): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm13, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm18, ymm19); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm13, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm18, ymm20); + + //perform mul operation + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm19 = _mm256_mul_pd(ymm19, ymm1); + + a11 += 1; + + //extract a77 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + + //(ROw7): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm14, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm19, ymm20); + + //perform mul operation + ymm15 = _mm256_mul_pd(ymm15, ymm1); + ymm20 = _mm256_mul_pd(ymm20, ymm1); + + a11 += 1; + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] + ymm3 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + + ///unpack high/// + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm3 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + + _mm256_storeu_pd((double *)(b11 + 4), ymm0); //store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm1); //store B11[5][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm2); //store B11[6][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3 + 4), ymm3); //store B11[7][0-3] + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ///unpack high/// + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store B11[1][0-3] + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm17, ymm18); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] + ymm3 = _mm256_unpacklo_pd(ymm19, ymm20); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + + ///unpack high/// + ymm17 = _mm256_unpackhi_pd(ymm17, ymm18); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] + ymm18 = _mm256_unpackhi_pd(ymm19, ymm20); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm17, ymm18, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + + _mm256_storeu_pd((double *)(b11 + cs_b * 4 + 4), ymm0); //store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 5 + 4), ymm1); //store B11[5][0-3] + } + + dim_t n_rem = n-j; + if(n_rem >= 4) + { + a10 = D_A_pack; + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + + k_iter = i / D_MR; //number of times GEMM to be performed(in blocks of 4x4) + + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + ///GEMM code begins/// + + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1 + 4)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2 + 4)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3 + 4)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 5)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 5 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1 + 4)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2 + 4)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3 + 4)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 6)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 6 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1 + 4)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2 + 4)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3 + 4)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 7)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 7 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1 + 4)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2 + 4)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3 + 4)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + + b01 += 1; //move to next row of B + + a10 += D_MR; //pointer math to calculate next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to calculate next block of B for GEMM + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] + ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b *3 + 4)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] + ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] * alpha -= B01[0-3][6] + ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); //B11[0-3][7] * alpha -= B01[0-3][7] + + ///implement TRSM/// + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[2][4] B11[2][5] + ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[2][6] B11[2][7] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[1][4] B11[1][5] B11[3][4] B11[3][5] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[1][6] B11[1][7] B11[3][6] B11[3][7] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] + ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] + + ymm0 = _mm256_broadcast_sd((double const *)&ones); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += 1; + + //(ROw1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm5, ymm8, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm8, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm8, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm8, ymm15); + + //perform mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); + + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += 1; + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw2): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm5, ymm9, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm9, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm9, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm9, ymm15); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += 1; + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw5): FMA operations + ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm5, ymm10, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm10, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm10, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm10, ymm15); + + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); + + ymm0 = _mm256_broadcast_sd((double const *)&ones); + + //extract a44 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += 1; + + //(ROw4): FMA operations + ymm12 = _mm256_fnmadd_pd(ymm5, ymm11, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm11, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm11, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm11, ymm15); + + //perform mul operation + ymm12 = _mm256_mul_pd(ymm12, ymm1); + + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += 1; + + //extract a55 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + + //(ROw5): FMA operations + ymm13 = _mm256_fnmadd_pd(ymm6, ymm12, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm12, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm12, ymm15); + + //perform mul operation + ymm13 = _mm256_mul_pd(ymm13, ymm1); + + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 +cs_a*7)); + + a11 += 1; + + //extract a66 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + + //(ROw6): FMA operations + ymm14 = _mm256_fnmadd_pd(ymm7, ymm13, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm13, ymm15); + + //perform mul operation + ymm14 = _mm256_mul_pd(ymm14, ymm1); + + //extract a77 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += 1; + //(ROw7): FMA operations + ymm15 = _mm256_fnmadd_pd(ymm16, ymm14, ymm15); + + //perform mul operation + ymm15 = _mm256_mul_pd(ymm15, ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] + ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[5][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); //store B11[6][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3 + 4), ymm7); //store B11[7][0-3] + + n_rem -=4; + j +=4; + + } + if(n_rem) + { + a10 = D_A_pack; + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + + k_iter = i / D_MR; //number of times GEMM to be performed(in blocks of 4x4) + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(3 == n_rem) + { + ///GEMM code begins/// + + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1 + 4)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2 + 4)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 5)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 5 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1 + 4)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2 + 4)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 6)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 6 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1 + 4)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2 + 4)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 7)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 7 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1 + 4)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2 + 4)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + b01 += 1; //move to next row of B + + a10 += D_MR; //pointer math to calculate next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to calculate next block of B for GEMM + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] + ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] * alpha -= B01[0-3][6] + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(2 == n_rem) + { + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1 + 4)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 5)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 5 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1 + 4)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 6)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 6 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1 + 4)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 7)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 7 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1 + 4)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + b01 += 1; //move to next row of B + + a10 += D_MR; //pointer math to calculate next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to calculate next block of B for GEMM + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] + ymm6 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(1 == n_rem) + { + ///GEMM code begins/// + + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 5)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 5 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 6)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 6 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 7)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 7 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + b01 += 1; //move to next row of B + + a10 += D_MR; //pointer math to calculate next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to calculate next block of B for GEMM + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_broadcast_sd((double const *)(&ones)); + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_broadcast_sd((double const *)(&ones)); + ymm6 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + } + ///implement TRSM/// + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[2][4] B11[2][5] + ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[2][6] B11[2][7] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[1][4] B11[1][5] B11[3][4] B11[3][5] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[1][6] B11[1][7] B11[3][6] B11[3][7] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] + ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] + + ymm0 = _mm256_broadcast_sd((double const *)&ones); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += 1; + + //(ROw1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm5, ymm8, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm8, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm8, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm8, ymm15); + + //perform mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); + + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += 1; + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw2): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm5, ymm9, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm9, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm9, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm9, ymm15); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += 1; + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw5): FMA operations + ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm5, ymm10, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm10, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm10, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm10, ymm15); + + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); + + ymm0 = _mm256_broadcast_sd((double const *)&ones); + + //extract a44 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += 1; + + //(ROw4): FMA operations + ymm12 = _mm256_fnmadd_pd(ymm5, ymm11, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm11, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm11, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm11, ymm15); + + //perform mul operation + ymm12 = _mm256_mul_pd(ymm12, ymm1); + + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += 1; + + //extract a55 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + + //(ROw5): FMA operations + ymm13 = _mm256_fnmadd_pd(ymm6, ymm12, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm12, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm12, ymm15); + + //perform mul operation + ymm13 = _mm256_mul_pd(ymm13, ymm1); + + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 +cs_a*7)); + + a11 += 1; + + //extract a66 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + + + //(ROw6): FMA operations + ymm14 = _mm256_fnmadd_pd(ymm7, ymm13, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm13, ymm15); + + //perform mul operation + ymm14 = _mm256_mul_pd(ymm14, ymm1); + + //extract a77 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += 1; + //(ROw7): FMA operations + ymm15 = _mm256_fnmadd_pd(ymm16, ymm14, ymm15); + + //perform mul operation + ymm15 = _mm256_mul_pd(ymm15, ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] + ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + + if(3 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[5][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); //store B11[6][0-3] + } + else if(2 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[5][0-3] + } + else if(1 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] + } + } + + } + //======================M remainder cases================================ + dim_t m_rem = m-i; + if(m_rem>=4) //implementation for reamainder rows(when 'M' is not a multiple of D_MR) + { + a10 = L + (i*cs_a); //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); + double *ptr_a10_dup = D_A_pack; + dim_t p_lda = i; // packed leading dimension + for(dim_t x =0;x < i;x+=4) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + + a10 += 4; + ptr_a10_dup += 4; + } + + ymm4 = _mm256_broadcast_sd((double const *)&ones); + if(!is_unitdiag) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a +1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11+cs_a*2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11+cs_a*3 + 3)); + + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); + ymm4 = _mm256_div_pd(ymm4, ymm1); + } + _mm256_storeu_pd((double *)(d11_pack), ymm4); + + for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' dimension + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = i / 4; //number of times GEMM operation to be done(in blocks of 4x4) + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); + ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); + ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); + ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); + ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); + + b01 += 1; //move to next row of B + + a10 += 4; //pointer math to calculate next block of A for GEMM + b01 = ptr_b01_dup + 4; //pointer math to calculate next block of B for GEMM + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + + + ymm16 = _mm256_broadcast_sd((double const *)(&ones)); + + ////unpacklow//// + ymm7 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + //ymm16; + + //rearrange low elements + ymm4 = _mm256_permute2f128_pd(ymm7,ymm16,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm6 = _mm256_permute2f128_pd(ymm7,ymm16,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + //ymm16; + + //rearrange high elements + ymm5 = _mm256_permute2f128_pd(ymm0,ymm16,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm7 = _mm256_permute2f128_pd(ymm0,ymm16,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + //b11 transpose end + + ////extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); + ymm4 = _mm256_mul_pd(ymm4, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(ROw1): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm4, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm8, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm4, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm4, ymm7); + + + //perform mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); + ymm5 = _mm256_mul_pd(ymm5, ymm1); + + a11 += 1; + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm9, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm5, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm9, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm5, ymm7); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + ymm6 = _mm256_mul_pd(ymm6, ymm1); + + a11 += 1; + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw5): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm10, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm6, ymm7); + + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); + ymm7 = _mm256_mul_pd(ymm7, ymm1); + + a11 += 1; + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ///unpack high/// + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store B11[1][0-3] + } + dim_t n_rem = n-j; + if(n_rem >= 4) + { + a10 = D_A_pack; + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + + k_iter = i / 4; //number of times GEMM to be performed(in blocks of 4x4) + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + b01 += 1; //move to next row of B + + a10 += 4; //pointer math to calculate next block of A for GEMM + b01 = ptr_b01_dup + 4; //pointer math to calculate next block of B for GEMM + } + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm0 = _mm256_broadcast_sd((double const *)&ones); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + + a11 += 1; + + //(ROw1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); + + //perform mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); + + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + + a11 += 1; + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw2): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + + a11 += 1; + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw5): FMA operations + ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); + + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + + n_rem -= 4; + j += 4; + } + if(n_rem) + { + a10 = D_A_pack; + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + + k_iter = i / 4; //number of times GEMM to be performed(in blocks of 4x4) + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(3 == n_rem) + { + ///GEMM code begins/// + + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + b01 += 1; //move to next row of B + + a10 += 4; //pointer math to calculate next block of A for GEMM + b01 = ptr_b01_dup + 4; //pointer math to calculate next block of B for GEMM + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(2 == n_rem) + { + ///GEMM code begins/// + + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + b01 += 1; //move to next row of B + + a10 += 4; //pointer math to calculate next block of A for GEMM + b01 = ptr_b01_dup + 4; //pointer math to calculate next block of B for GEMM + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(1 == n_rem) + { + ///GEMM code begins/// + + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + b01 += 1; //move to next row of B + + a10 += 4; //pointer math to calculate next block of A for GEMM + b01 = ptr_b01_dup + 4; //pointer math to calculate next block of B for GEMM + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_broadcast_sd((double const *)(&ones)); + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm0 = _mm256_broadcast_sd((double const *)&ones); + + ////extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + + a11 += 1; + + //(ROw1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); + + //perform mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); + + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + + a11 += 1; + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw2): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + + a11 += 1; + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw5): FMA operations + ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); + + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + if(3 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + } + else if(2 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + } + else if(1 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + } + + } + m_rem -=4; + i +=4; + } + + if(m_rem) + { + a10 = L + (i*cs_a); //pointer to block of A to be used for GEMM + // Do transpose for a10 & store in D_A_pack + double *ptr_a10_dup = D_A_pack; + if(3 == m_rem) // Repetative A blocks will be 3*3 + { + dim_t p_lda = i; // packed leading dimension + for(dim_t x=0;x= 4)) + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - k_iter = i / D_MR; //number of times GEMM operation to be done(in blocks of 4x4) + k_iter = i / 4; //number of times GEMM to be performed(in blocks of 4x4) - int iter; - - if((j+D_NR) == n) - { - for(iter = 0; iter < m_remainder; iter++) - f_t[iter] = (b11 + cs_b * 7)[iter]; - f_temp = f_t; - } - else - f_temp = (b11 + cs_b * 7); + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); + #endif + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); ymm8 = _mm256_setzero_pd(); ymm9 = _mm256_setzero_pd(); ymm10 = _mm256_setzero_pd(); @@ -1321,1748 +6449,2174 @@ static err_t bli_dtrsm_small_AlXB( ymm14 = _mm256_setzero_pd(); ymm15 = _mm256_setzero_pd(); - ///GEMM code Begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + ///GEMM code begins/// + + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations { + ptr_b01_dup = b01; + + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + b01 += 1; //move to next row of B + + a10 += 4; //pointer math to calculate next block of A for GEMM + b01 = ptr_b01_dup + 4; //pointer math to calculate next block of B for GEMM + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + + dtrsm_AutXB_ref(a11, b11, m_rem, 4, cs_a, cs_b,is_unitdiag); + n_rem -= 4; + j +=4; + } + if(n_rem) + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = i / 4; //number of times GEMM to be performed(in blocks of 4x4) + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(3 == n_rem) + { + ///GEMM code begins/// + + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { ptr_b01_dup = b01; - ymm16 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] + ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[0][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[0][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[0][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[0][7] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - b01 += 1; //move to next row of B + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0] ) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[0][4]*A10[0][0] B01[0][4]*A10[1][0] B01[0][4]*A10[2][0] B01[0][4]*A10[3][0]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[0][5]*A10[0][0] B01[0][5]*A10[1][0] B01[0][5]*A10[2][0] B01[0][5]*A10[3][0]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[0][6]*A10[0][0] B01[0][6]*A10[1][0] B01[0][6]*A10[2][0] B01[0][6]*A10[3][0]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm16 += (B01[0][7]*A10[0][0] B01[0][7]*A10[1][0] B01[0][7]*A10[2][0] B01[0][7]*A10[3][0]) + b01 += 1; //move to next row of B - ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 1)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda)); - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[1][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[1][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[1][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[1][7] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - b01 += 1; //move to next row of B01 + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[1][4]*A10[0][1] B01[1][4]*A10[1][1] B01[1][4]*A10[2][1] B01[1][4]*A10[3][1]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[1][5]*A10[0][1] B01[1][5]*A10[1][1] B01[1][5]*A10[2][1] B01[1][5]*A10[3][1]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[1][6]*A10[0][1] B01[1][6]*A10[1][1] B01[1][6]*A10[2][1] B01[1][6]*A10[3][1]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[1][7]*A10[0][1] B01[1][7]*A10[1][1] B01[1][7]*A10[2][1] B01[1][7]*A10[3][1]) + b01 += 1; //move to next row of B - ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] //A10[1][2] A10[2][2] A10[3][2] + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2)); - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[2][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[2][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[2][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[2][7] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - b01 += 1; //move to next row of B + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[2][4]*A10[0][2] B01[2][4]*A10[1][2] B01[2][4]*A10[2][2] B01[2][0]*A10[3][2]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[2][5]*A10[0][2] B01[2][5]*A10[1][2] B01[2][5]*A10[2][2] B01[2][1]*A10[3][2]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[2][6]*A10[0][2] B01[2][6]*A10[1][2] B01[2][6]*A10[2][2] B01[2][2]*A10[3][2]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[2][7]*A10[0][2] B01[2][7]*A10[1][2] B01[2][7]*A10[2][2] B01[2][3]*A10[3][2]) + b01 += 1; //move to next row of B - ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3)); - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[3][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[3][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[3][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[3][7] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - b01 += 1; //move to next row of B + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm8 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm8 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm8 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm8 += (B01[3][0]*A10[0][3] B01[3][4]*A10[1][3] B01[3][4]*A10[2][3] B01[3][4]*A10[3][3]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm8 += (B01[3][1]*A10[0][3] B01[3][5]*A10[1][3] B01[3][5]*A10[2][3] B01[3][5]*A10[3][3]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm8 += (B01[3][2]*A10[0][3] B01[3][6]*A10[1][3] B01[3][6]*A10[2][3] B01[3][6]*A10[3][3]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm8 += (B01[3][3]*A10[0][3] B01[3][7]*A10[1][3] B01[3][7]*A10[2][3] B01[3][7]*A10[3][3]) + b01 += 1; //move to next row of B - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + a10 += 4; //pointer math to calculate next block of A for GEMM + b01 = ptr_b01_dup + 4; //pointer math to calculate next block of B for GEMM + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + + dtrsm_AutXB_ref(a11, b11, m_rem, 3, cs_a, cs_b,is_unitdiag); + } + else if(2 == n_rem) + { + ///GEMM code begins/// + + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + b01 += 1; //move to next row of B + + a10 += 4; //pointer math to calculate next block of A for GEMM + b01 = ptr_b01_dup + 4; //pointer math to calculate next block of B for GEMM + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + + dtrsm_AutXB_ref(a11, b11, m_rem, 2, cs_a, cs_b,is_unitdiag); + } + else if(1 == n_rem) + { + ///GEMM code begins/// + + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + b01 += 1; //move to next row of B + + a10 += 4; //pointer math to calculate next block of A for GEMM + b01 = ptr_b01_dup + 4; //pointer math to calculate next block of B for GEMM + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + + dtrsm_AutXB_ref(a11, b11, m_rem, 1, cs_a, cs_b, is_unitdiag); + } + } + } + else if(2 == m_rem) // Repetative A blocks will be 2*2 + { + dim_t p_lda = i; // packed leading dimension + for(dim_t x=0;x= 4)) + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - k_iter = i / D_MR; //number of times GEMM to be performed(in block of 4) - ///GEMM for previously calculated values /// - - //load 4x4 block from b11 - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + k_iter = i / 4; //number of times GEMM to be performed(in blocks of 4x4) + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); + #endif ymm4 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm6 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + ///GEMM code begins/// + + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations { - ptr_b01_dup = b01; - ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a*2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + ptr_b01_dup = b01; - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] + ymm0 = _mm256_loadu_pd((double const *)(a10)); - b01 += 1; //move to next row of B + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - b01 += 1; + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) + b01 += 1; //move to next row of B - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda)); - b01 += 1; + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[1][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[2][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[3][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - b01 += 1; + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) + b01 += 1; //move to next row of B + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2)); - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + b01 += 1; //move to next row of B + + a10 += 4; //pointer math to calculate next block of A for GEMM + b01 = ptr_b01_dup + 4; //pointer math to calculate next block of B for GEMM } - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha - - ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 - ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B01[0-3][1] *alpha -= ymm5 - ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B01[0-3][2] *alpha -= ymm6 - ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B01[0-3][3] *alpha -= ymm7 + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha ///implement TRSM/// - //1st col - ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] - ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0] - ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][0] - //2nd col - a11 += cs_a; - ymm8 = _mm256_broadcast_sd((double const *)(a11 + 1)); //A11[1][1] - ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][1] + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - //3rd col - a11 += cs_a; - ymm11 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][2] - ymm12 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][2] + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0C); - //4th col - a11 += cs_a; - ymm13 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][3] - //compute reciprocals of L(i,i) and broadcast in registers - ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[0][0] A11[2][2] A11[2][2] - ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[1][1] A11[1][1] A11[3][3] A11[3][3] + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) - ymm14 = _mm256_broadcast_sd((double const *)&ones); + dtrsm_AutXB_ref(a11, b11, m_rem, 4, cs_a, cs_b, is_unitdiag); + n_rem -= 4; + j +=4; + } + if(n_rem) + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - ymm4 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] - ymm14 = _mm256_div_pd(ymm14, ymm4); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] - - ////unpacklow//// - ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - //rearrange low elements - ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - //rearrange high elements - ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - //extract a00 - ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0] - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - ymm4 = _mm256_mul_pd(ymm4, ymm15); //B11[0][0-3] /= A11[0][0] - - //extract diag a11 from a - ymm15 = _mm256_permute_pd(ymm14, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[3][3] 1/A11[3][3] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1] - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) - ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]*B11[0][0-3] - ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0]*B11[0][0-3] - ymm13 = _mm256_fnmadd_pd(ymm7, ymm4, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][0]*B11[0][0-3] - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - ymm8 = _mm256_mul_pd(ymm8, ymm15); //B11[1][0-3] /= A11[1][1] - - - //extract diag a22 from a - ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2] - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1]*B11[1][0-3] - ymm13 = _mm256_fnmadd_pd(ymm10, ymm8, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][1]*B11[1][0-3] - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - ymm11 = _mm256_mul_pd(ymm11, ymm15); //B11[2][0-3] /= A11[2][2] - - //extract diag a33 from a - ymm15 = _mm256_permute_pd(ymm14, 0x0C); //1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3] - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - ymm13 = _mm256_fnmadd_pd(ymm12, ymm11, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][2]*B11[2][0-3] - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - ymm13 = _mm256_mul_pd(ymm13, ymm15); //B11[3][0-3] /= A11[3][3] - - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ////unpackhigh//// - ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - - ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[0-3][3]) - - } - if(m_remainder) //implementation for remainder rows(when 'M' is not a multiple of D_MR) - { - a10 = L +i; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - - k_iter = i / D_MR; //number of GEMM operations to be performed(in blocks of 4x4) - - dim_t iter; - - if((j+4) == n) - { - f_temp = f_t; - for(iter = 0; iter < m_remainder; iter++) - f_temp[iter] = (b11 + cs_b * 3)[iter]; - } - else - f_temp = (b11 + cs_b * 3); - ///GEMM for previously calculated values /// - - //load 4x4 block from b11 - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + k_iter = i / 4; //number of times GEMM to be performed(in blocks of 4x4) + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); + #endif ymm4 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm6 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); - for(k = 0; k < k_iter; k++) //looop for number of GEMM operations + if(3 == n_rem) { + ///GEMM code begins/// + + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { ptr_b01_dup = b01; - ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - b01 += 1; + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] + b01 += 1; //move to next row of B - b01 += 1; + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda)); - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - b01 += 1; + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) + b01 += 1; //move to next row of B - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2)); - b01 += 1; + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + b01 += 1; //move to next row of B + + a10 += 4; //pointer math to calculate next block of A for GEMM + b01 = ptr_b01_dup + 4; //pointer math to calculate next block of B for GEMM + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + + dtrsm_AutXB_ref(a11, b11, m_rem, 3, cs_a, cs_b, is_unitdiag); } - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha - - ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 - ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] *alpha -= ymm5 - ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[0-3][2] *alpha -= ymm6 - ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[0-3][3] *alpha -= ymm7 - - - if(3 == m_remainder) + else if(2 == n_rem) { - ///implement TRSM/// - //1st col - ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] - ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0] + ///GEMM code begins/// - //2nd col - a11 += cs_a; - ymm8 = _mm256_broadcast_sd((double const *)(a11 + 1)); //A11[1][1] - ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1] + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; - //3rd col - a11 += cs_a; - ymm11 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][2] + ymm0 = _mm256_loadu_pd((double const *)(a10)); - //4th col - a11 += cs_a; - ymm13 = _mm256_broadcast_sd((double const *)(&ones)); //A11[3][3] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - //compute reciprocals of L(i,i) and broadcast in registers - ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[0][0] A11[1][1] A11[1][1] - ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[2][2] A11[3][3] A11[3][3] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - ymm14 = _mm256_broadcast_sd((double const *)&ones); + b01 += 1; //move to next row of B - ymm4 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] - ymm14 = _mm256_div_pd(ymm14, ymm4); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda)); - ////unpacklow//// - ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - //rearrange low elements - ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + b01 += 1; //move to next row of B - //rearrange high elements - ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2)); - //extract a00 - ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00);//1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - ymm4 = _mm256_mul_pd(ymm4, ymm15); //B11[0][0-3] /= A11[0][0] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - //extract diag a11 from a - ymm15 = _mm256_permute_pd(ymm14, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[3][3] 1/A11[3][3] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[][] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1] + b01 += 1; //move to next row of B - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) - ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]* B11[0][0-3] - ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0]* B11[0][0-3] + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3)); - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - ymm8 = _mm256_mul_pd(ymm8, ymm15); //B11[1][0-3] /= A11[1][1] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - //extract diag a22 from a - ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1]* B11[1][0-3] - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - ymm11 = _mm256_mul_pd(ymm11, ymm15); //B11[2][0-3] /= A11[2][2] + b01 += 1; //move to next row of B - ymm13 = _mm256_broadcast_sd((double const *)(&ones)); + a10 += 4; //pointer math to calculate next block of A for GEMM + b01 = ptr_b01_dup + 4; //pointer math to calculate next block of B for GEMM + } - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ///implement TRSM/// - ////unpackhigh//// - ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); - //load 4x4 block from b11 - ymm4 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][2] B11[3][3] + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - //determine correct values to store - ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x08); - ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x08); - ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x08); - ymm3 = _mm256_blend_pd(ymm3, ymm7, 0x08); + + dtrsm_AutXB_ref(a11, b11, m_rem, 2, cs_a, cs_b, is_unitdiag); } - else if( 2 == m_remainder ) + else if(1 == n_rem) { - ///implement TRSM/// - //1st col - ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] + ///GEMM code begins/// - //2nd col - a11 += cs_a; - ymm8 = _mm256_broadcast_sd((double const *)(a11 + 1)); //A11[1][1] + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; - //compute reciprocals of L(i,i) and broadcast in registers - ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[0][0] A11[1][1] A11[1][1] + ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm14 = _mm256_broadcast_sd((double const *)&ones); + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm4 = _mm256_blend_pd(ymm4, ymm14, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] - ymm14 = _mm256_div_pd(ymm14, ymm4); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + b01 += 1; //move to next row of B - ////unpacklow//// - ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda)); - //rearrange low elements - ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + b01 += 1; //move to next row of B - //rearrange high elements - ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2)); - //extract a00 - ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00);//1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - ymm4 = _mm256_mul_pd(ymm4, ymm15); //B11[0][0-3] /= A11[0][0] + b01 += 1; //move to next row of B - //extract diag a11 from a - ymm15 = _mm256_permute_pd(ymm14, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[3][3] 1/A11[3][3] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[][] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1] + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3)); - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) - ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]* B11[0][0-3] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - ymm8 = _mm256_mul_pd(ymm8, ymm15); //B11[1][0-3] /= A11[1][1] + b01 += 1; //move to next row of B - ymm11 = _mm256_broadcast_sd((double const *)(&ones)); - ymm13 = _mm256_broadcast_sd((double const *)(&ones)); + a10 += 4; //pointer math to calculate next block of A for GEMM + b01 = ptr_b01_dup + 4; //pointer math to calculate next block of B for GEMM + } - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ///implement TRSM/// - ////unpackhigh//// - ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - //load 4x4 block from b11 - ymm4 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][2] B11[3][3] - - //determine correct values to store - ymm0 = _mm256_permute2f128_pd(ymm0, ymm4,0x30); - ymm1 = _mm256_permute2f128_pd(ymm1, ymm5,0x30); - ymm2 = _mm256_permute2f128_pd(ymm2, ymm6,0x30); - ymm3 = _mm256_permute2f128_pd(ymm3, ymm7,0x30); - - } - else if(1 == m_remainder) - { - ///implement TRSM/// - //1st col - ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - - ymm14 = _mm256_broadcast_sd((double const *)&ones); - ymm14 = _mm256_div_pd(ymm14, ymm4); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] - - ////unpacklow//// - ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - //rearrange low elements - ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - //rearrange high elements - ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - //extract a00 - ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00);//1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0] - - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - ymm4 = _mm256_mul_pd(ymm4, ymm15); //B11[0][0-3] /= A11[0][0] - - ymm8 = _mm256_broadcast_sd((double const *)(&ones)); - ymm11 = _mm256_broadcast_sd((double const *)(&ones)); - ymm13 = _mm256_broadcast_sd((double const *)(&ones)); - - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ////unpackhigh//// - ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - - ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - //load 4x4 block from b11 - ymm4 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][2] B11[3][3] - - //determine correct values to store - ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x0E); - ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x0E); - ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x0E); - ymm3 = _mm256_blend_pd(ymm3, ymm7, 0x0E); + dtrsm_AutXB_ref(a11, b11, m_rem, 1, cs_a, cs_b, is_unitdiag); } - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(f_temp), ymm3); //store(B11[0-3][3]) - - if((j+4) == n) - { - for(iter = 0; iter < m_remainder; iter++) - (b11 + cs_b * 3)[iter] = f_temp[iter]; - } + } + m_rem -=2; + i+=2; } + else if(1 == m_rem) // Repetative A blocks will be 1*1 + { + dim_t p_lda = i; // packed leading dimension + for(dim_t x=0;x= 4)) + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = i / 4; //number of times GEMM to be performed(in blocks of 4x4) + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + ///GEMM code begins/// + + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + b01 += 1; //move to next row of B + + a10 += 4; //pointer math to calculate next block of A for GEMM + b01 = ptr_b01_dup + 4; //pointer math to calculate next block of B for GEMM + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + + dtrsm_AutXB_ref(a11, b11, m_rem, 4, cs_a, cs_b, is_unitdiag); + n_rem -= 4; + j+=4; + } + + if(n_rem) + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = i / 4; //number of times GEMM to be performed(in blocks of 4x4) + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(3 == n_rem) + { + ///GEMM code begins/// + + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + b01 += 1; //move to next row of B + + a10 += 4; //pointer math to calculate next block of A for GEMM + b01 = ptr_b01_dup + 4; //pointer math to calculate next block of B for GEMM + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + + dtrsm_AutXB_ref(a11, b11, m_rem, 3, cs_a, cs_b, is_unitdiag); + } + else if(2 == n_rem) + { + ///GEMM code begins/// + + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + b01 += 1; //move to next row of B + + a10 += 4; //pointer math to calculate next block of A for GEMM + b01 = ptr_b01_dup + 4; //pointer math to calculate next block of B for GEMM + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + + dtrsm_AutXB_ref(a11, b11, m_rem, 2, cs_a, cs_b, is_unitdiag); + } + else if(1 == n_rem) + { + ///GEMM code begins/// + + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 2)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + b01 += 1; //move to next row of B + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 3)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + b01 += 1; //move to next row of B + + a10 += 4; //pointer math to calculate next block of A for GEMM + b01 = ptr_b01_dup + 4; //pointer math to calculate next block of B for GEMM + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + + dtrsm_AutXB_ref(a11, b11, m_rem, 1, cs_a, cs_b, is_unitdiag); + } + } + m_rem -=1; + i+=1; + } } - if(n_remainder) //implementation fo remaining columns(when 'N' is not a multiple of D_NR) + if ((required_packing_A == 1) && + bli_mem_is_alloc( &local_mem_buf_A_s )) { - for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' direction - { - a10 = L +i; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - - k_iter = i / D_MR; //number of GEMM operations to be performed(in blocks of 4x4) - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM for previously calculated values /// - - //load 4x4 block from b11 - if(3 == n_remainder) - { - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - for(k = 0; k < k_iter; k++) - { - ptr_b01_dup = b01; - ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) - - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM - } - - ///GEMM code ends/// - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value - - ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 - ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] *alpha -= ymm5 - ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[0-3][2] *alpha -= ymm6 - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][3] *alpha -= ymm7 - } - else if(2 == n_remainder) - { - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - - for(k = 0; k < k_iter; k++) - { - ptr_b01_dup = b01; - ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM - } - ///GEMM code ends/// - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value - - ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 - ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] *alpha -= ymm5 - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][2] *alpha -= ymm6 - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][3] *alpha -= ymm7 - } - else if(1 == n_remainder) - { - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - - for(k = 0; k < k_iter; k++) - { - ptr_b01_dup = b01; - - ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM - } - ///GEMM code ends/// - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value - - ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 - ymm1 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][1] *alpha -= ymm5 - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][2] *alpha -= ymm6 - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][3] *alpha -= ymm7 - - } - - ///implement TRSM/// - //1st col - ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] - ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0] - ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][0] - - //2nd col - a11 += cs_a; - ymm8 = _mm256_broadcast_sd((double const *)(a11 + 1)); //A11[1][1] - ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][1] - - //3rd col - a11 += cs_a; - ymm11 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][2] - ymm12 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][2] - - //4th col - a11 += cs_a; - ymm13 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][3] - //compute reciprocals of L(i,i) and broadcast in registers - ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[0][0] A11[1][1] A11[1][1] - ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[2][2] A11[3][3] A11[3][3] - - ymm14 = _mm256_broadcast_sd((double const *)&ones); - - ymm4 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] - ymm14 = _mm256_div_pd(ymm14, ymm4); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] - - ////unpacklow//// - ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - //rearrange low elements - ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - //rearrange high elements - ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - //extract a00 - ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0] - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - ymm4 = _mm256_mul_pd(ymm4, ymm15); //B11[0][0-3] /= A11[0][0] - - //extract diag a11 from a - ymm15 = _mm256_permute_pd(ymm14, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1] - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) - ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0] * B11[0][0-3] - ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0] * B11[0][0-3] - ymm13 = _mm256_fnmadd_pd(ymm7, ymm4, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][0] * B11[0][0-3] - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - ymm8 = _mm256_mul_pd(ymm8, ymm15); //B11[1][0-3] /= A11[1][1] - - - //extract diag a22 from a - ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2] - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1] * B11[1][0-3] - ymm13 = _mm256_fnmadd_pd(ymm10, ymm8, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][1] * B11[1][0-3] - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - ymm11 = _mm256_mul_pd(ymm11, ymm15); //B11[2][0-3] /= A11[2][2] - - //extract diag a33 from a - ymm15 = _mm256_permute_pd(ymm14, 0x0C); //1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3] - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - ymm13 = _mm256_fnmadd_pd(ymm12, ymm11, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][2] * B11[2][0-3] - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - ymm13 = _mm256_mul_pd(ymm13, ymm15); //B11[3][0-3] /= A11[3][3] - - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ////unpackhigh//// - ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - - ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - if(3 == n_remainder) - { - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - - } - else if(2 == n_remainder) - { - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) - - } - else if(1 == n_remainder) - { - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - } - - } - if(m_remainder) //implementation for remainder rows(when 'M' is not a multiple of D_MR) - { - a10 = L +i; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - - - k_iter = i / D_MR; //number of times GEMM operations to be performed - - dim_t iter; - if((j+n_remainder) == n) - { - f_temp = f_t; - for(iter = 0; iter < m_remainder; iter++) - f_temp[iter] = (b11 + cs_b * (n_remainder -1))[iter]; - } - else - f_temp = (b11 + cs_b * (n_remainder -1)); - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM for previously calculated values /// - - - //load 4x4 block from b11 - if(3 == n_remainder) - { - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - ptr_b01_dup = b01; - ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[0][2] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[1][2] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[2][2] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[3][2] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) - - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM - } - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha value - - ymm8 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] * alpha -= ymm4 - ymm9 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] * alpha -= ymm5 - ymm10 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[0-3][2] * alpha -= ymm6 - - ///implement TRSM/// - //determine correct values to store - if(3 == m_remainder) - { - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); - } - else if(2 == m_remainder) - { - ymm0 = _mm256_permute2f128_pd(ymm8, ymm0, 0x30); - ymm1 = _mm256_permute2f128_pd(ymm9, ymm1, 0x30); - ymm2 = _mm256_permute2f128_pd(ymm10, ymm2, 0x30); - } - else if(1 == m_remainder) - { - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); - } - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(f_temp), ymm2); //store(B11[0-3][2]) - } - if(2 == n_remainder) - { - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - ptr_b01_dup = b01; - - ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[0][1] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[1][1] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[2][1] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[3][1] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM - } - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha value - - ymm8 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] * alpha -= ymm4 - ymm9 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] * alpha -= ymm5 - - ///implement TRSM/// - //determine correct values to store - if(3 == m_remainder) - { - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - } - else if(2 == m_remainder) - { - ymm0 = _mm256_permute2f128_pd(ymm8, ymm0, 0x30); - ymm1 = _mm256_permute2f128_pd(ymm9, ymm1, 0x30); - } - else if(1 == m_remainder) - { - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - } - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(f_temp), ymm1); //store(B11[0-3][1]) - } - if(n_remainder == 1) - { - ymm0 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - ptr_b01_dup = b01; - - ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM - - } - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha value - - ymm8 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] * alpha -= ymm4 - - ///implement TRSM/// - //determine correct values to store - if(3 == m_remainder) - { - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - } - else if(2 == m_remainder) - { - ymm0 = _mm256_permute2f128_pd(ymm8, ymm0, 0x30); - } - else if(1 == m_remainder) - { - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - } - _mm256_storeu_pd((double *)(f_temp), ymm0); //store(B11[0-3][0]) - } - - if((j+n_remainder) == n) - { - for(iter = 0; iter < m_remainder; iter++) - (b11 + cs_b * (n_remainder-1))[iter] = f_temp[iter]; - } - - ///scalar code for trsm without alpha/// - dtrsm_small_AlXB(a11, b11, m_remainder, n_remainder, cs_a, cs_b); - } + bli_membrk_release(&rntm, &local_mem_buf_A_s); } return BLIS_SUCCESS; } /* TRSM for the case AX = alpha * B, Double precision - * A is lower-triangular, no-transpose, unit diagonal + * A is lower-triangular, transpose, non-unit diagonal * dimensions A: mxm X: mxn B: mxn - - b01---> - * ***************** - ** * * * * * - * * * * * * * - * * *b01* * * * - * * * * * * * -a10 ****** b11 ***************** - | * * * | * * * * * - | * * * | * * * * * - | *a10*a11* | *b11* * * * - v * * * v * * * * * - *********** ***************** - * * * * * * * * * - * * * * * * * * * - * * * * * * * * * - * * * * * * * * * - **************** ***************** - a11---> */ - -static err_t bli_dtrsm_small_AlXB_unitDiag( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ) +static err_t bli_dtrsm_small_AltXB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) { + dim_t D_MR = 8; //size of block along 'M' dimpension + dim_t D_NR = 6; //size of block along 'N' dimension - dim_t D_MR = 4; //size of block along 'M' dimpension - dim_t D_NR = 8; //size of block along 'N' dimension + dim_t m = bli_obj_length(b); // number of rows of matrix B + dim_t n = bli_obj_width(b); // number of columns of matrix B - dim_t m = bli_obj_length(b); // number of rows of matrix B - dim_t n = bli_obj_width(b); // number of columns of matrix B + dim_t cs_a = bli_obj_col_stride(a); // column stride of A + dim_t cs_b = bli_obj_col_stride(b); // column stride of B -#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME - if((m>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME_ROW_PANEL_M && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME) - || (m> D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME_COLUMN_PANEL_N) - || (m>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME_COLUMN_PANEL_M && n D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_NAPLES) - { - return BLIS_NOT_YET_IMPLEMENTED; - } -#endif - - dim_t m_remainder = m & (3); //number of remainder rows - dim_t n_remainder = n & (7); //number of remainder columns - - dim_t cs_a = bli_obj_col_stride(a); // column stride of A - dim_t cs_b = bli_obj_col_stride(b); // column stride of B - - dim_t i, j, k; //loop variables - dim_t k_iter; //number of times GEMM to be performed + dim_t i, j, k; //loop variables + dim_t k_iter; //number of times GEMM to be performed double AlphaVal = *(double *)AlphaObj->buffer; //value of alpha - double *L = a->buffer; //pointer to matrix A - double *B = b->buffer; //pointer to matrix B + double *L = a->buffer; //pointer to matrix A + double *B = b->buffer; //pointer to matrix B - double *a10, *a11, *b01, *b11; //pointers that point to blocks for GEMM and TRSM + //pointers that point to blocks for GEMM and TRSM + double *a10, *a11, *b01, *b11; double *ptr_b01_dup; - - double f_t[4] __attribute__((aligned(64)));//buffer to store corner column when m_remainder !=0 - double* f_temp; + double *ptr_a10_dup; double ones = 1.0; - + bool is_unitdiag = bli_obj_has_unit_diag(a); //scratch registers __m256d ymm0, ymm1, ymm2, ymm3; __m256d ymm4, ymm5, ymm6, ymm7; __m256d ymm8, ymm9, ymm10, ymm11; __m256d ymm12, ymm13, ymm14, ymm15; - __m256d ymm16; + __m256d ymm16, ymm17, ymm18, ymm19; + __m256d ymm20; + gint_t required_packing_A = 1; + mem_t local_mem_buf_A_s; + double *D_A_pack = NULL; + double d11_pack[D_MR] __attribute__((aligned(64))); + rntm_t rntm; + bli_rntm_init_from_global( &rntm ); + bli_rntm_set_num_threads_only( 1, &rntm ); + bli_membrk_rntm_set_membrk( &rntm ); - for(j = 0; j+D_NR-1 < n; j += D_NR) //loop along 'N' dimension + siz_t buffer_size = bli_pool_block_size( + bli_membrk_pool( + bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), + bli_rntm_membrk(&rntm))); + + if((D_MR * m * sizeof(double)) > buffer_size) + return BLIS_NOT_YET_IMPLEMENTED; + + if (required_packing_A == 1) { - for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' dimension + // Get the buffer from the pool. + bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &local_mem_buf_A_s); + + D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); + } + + /* + Performs solving TRSM for 8 colmns at a time from 0 to m/8 in steps of D_MR + a. Load, transpose, Pack A (a10 block), the size of packing 8x6 to 8x (m-8) + First there will be no GEMM and no packing of a10 because it is only TRSM + b. Using packed a10 block and b01 block perform GEMM operation + c. Use GEMM outputs, perform TRSM operaton using a11, b11 and update B + d. Repeat b,c for n rows of B in steps of D_NR + */ + for(i = (m - D_MR); (i + 1) > 0; i -= D_MR) + { + a10 = L + (i*cs_a) + i + D_MR; //pointer to block of A to be used for GEMM + a11 = L + (i*cs_a) + i; //pointer to block of A to be used for TRSM + + // Do transpose for a10 & store in D_A_pack + ptr_a10_dup = D_A_pack; + dim_t p_lda = (m-i-D_MR); // packed leading dimension + /* + Load, transpose and pack current A block (a10) into packed buffer memory D_A_pack + a. This a10 block is used in GEMM portion only and this + a10 block size will be increasing by D_MR for every next itteration + untill it reaches 8x(m-8) which is the maximum GEMM alone block size in A + b. This packed buffer is reused to calculate all n rows of B matrix + */ + for(dim_t x =0;x < (m-i-D_MR);x+=D_MR) { - a10 = L +i; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm10 = _mm256_loadu_pd((double const *)(a10 + 4)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); + ymm11 = _mm256_loadu_pd((double const *)(a10 + 4 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); + ymm12 = _mm256_loadu_pd((double const *)(a10 + 4 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); + ymm13 = _mm256_loadu_pd((double const *)(a10 + 4 + cs_a * 3)); - k_iter = i / D_MR; //number of times GEMM to be performed(in blocks of 4x4) + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + + ymm4 = _mm256_unpacklo_pd(ymm10, ymm11); + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm10, ymm11); + ymm1 = _mm256_unpackhi_pd(ymm12, ymm13); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 4), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 5), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 6), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 7), ymm9); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + cs_a * 4)); + ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 4 + 4)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a * 5)); + ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 5 + 4)); + ymm2 = _mm256_loadu_pd((double const *)(a10 + cs_a * 6)); + ymm12 = _mm256_loadu_pd((double const *)(a10 + cs_a * 6 + 4)); + ymm3 = _mm256_loadu_pd((double const *)(a10 + cs_a * 7)); + ymm13 = _mm256_loadu_pd((double const *)(a10 + cs_a * 7 + 4)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup + 4), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), ymm9); + + ymm4 = _mm256_unpacklo_pd(ymm10, ymm11); + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm10, ymm11); + ymm1 = _mm256_unpackhi_pd(ymm12, ymm13); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 4 + 4), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 5 + 4), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 6 + 4), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 7 + 4), ymm9); + + a10 += D_MR; + ptr_a10_dup += D_MR; + } + + /* + Pack 8 diagonal elements of A block into an array + a. This helps in utilze cache line efficiently in TRSM operation + b. store ones when input is unit diagonal + */ + ymm4 = _mm256_broadcast_sd((double const *)&ones); + if(is_unitdiag) + { + _mm256_storeu_pd((double *)(d11_pack), ymm4); + _mm256_storeu_pd((double *)(d11_pack + 4), ymm4); + }else + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a +1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a*2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a*3 + 3)); + + //Pick one element each column and create a 4 element vector and store + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); + ymm0 = _mm256_div_pd(ymm4, ymm1); + _mm256_storeu_pd((double *)(d11_pack), ymm0); + + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11 + 4 + cs_a*4)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5 + cs_a*5)); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6 + cs_a*6)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + 7 + cs_a*7)); + + //Pick one element each column and create a 4 element vector and store + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); + ymm0 = _mm256_div_pd(ymm4, ymm1); + _mm256_storeu_pd((double *)(d11_pack + 4), ymm0); + } + + /* + a. Perform GEMM using a10, b01. + b. Perform TRSM on a11, b11 + c. This loop GEMM+TRSM loops operates with 8x6 block size + along n dimension for every D_NR rows of b01 where + packed A buffer is reused in computing all n rows of B. + d. Same approch is used in remaining fringe cases. + */ + for(j = (n - D_NR); (j + 1) > 0; j -= D_NR) + { + a10 = D_A_pack; + b01 = B + (j*cs_b) + i + D_MR; //pointer to block of B to be used for GEMM + b11 = B + (j* cs_b) + i; //pointer to block of B to be used for TRSM + + k_iter = (m - i - D_MR) / D_MR; //number of times GEMM to be performed + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + /* + Peform GEMM between a10 and b01 blocks + For first itteration there will be no GEMM operation + where k_iter are zero + */ + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter*4; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); + ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); + ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1 + 4)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2 + 4)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3 + 4)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4 + 4)); + ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5 + 4)); + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); + ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7); + + b01 += 1; + a10 += p_lda; + + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + D_MR; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +D_MR; + ptr_a10_dup = a10; + } + } + //GEMM block end here + + /* + Load b11 of size 6x8 and multiply with alpha + Add the GEMM output and perform inregister transose of b11 + to peform TRSM operation. + */ + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3 + 4)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm12); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm13); + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm14); + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm15); + + ymm13 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm15 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm15 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *4 + 4)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *5 + 4)); + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); + + ymm16 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm4 = _mm256_permute2f128_pd(ymm7,ymm16,0x20); + ymm6 = _mm256_permute2f128_pd(ymm7,ymm16,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm5 = _mm256_permute2f128_pd(ymm0,ymm16,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm16,0x31); + ymm18 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm17 = _mm256_permute2f128_pd(ymm18,ymm16,0x20); + ymm19 = _mm256_permute2f128_pd(ymm18,ymm16,0x31); + + ////unpackhigh//// + ymm20 = _mm256_unpackhi_pd(ymm2, ymm3); + + //rearrange high elements + ymm18 = _mm256_permute2f128_pd(ymm20,ymm16,0x20); + ymm20 = _mm256_permute2f128_pd(ymm20,ymm16,0x31); + + /* + Compute 8x6 TRSM block by using GEMM block output in register + a. The 8x6 input (gemm outputs) are stored in combinations of ymm registers + 1. ymm15, ymm20 2. ymm14, ymm19 3. ymm13, ymm18 , 4. ymm12, ymm17 + 5. ymm11, ymm7 6. ymm10, ymm6, 7.ymm9, ymm5 8. ymm8, ymm4 + where ymm15-ymm8 holds 8x4 data and reaming 8x2 will be hold by + other registers + b. Towards the end do in regiser transpose of TRSM output and store in b11 + */ + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + + //perform mul operation + ymm15 = _mm256_mul_pd(ymm15, ymm1); + ymm20 = _mm256_mul_pd(ymm20, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + + //(ROw7): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6*cs_a + 7)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm15, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm20, ymm19); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 7)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm15, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm20, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 7)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm15, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm20, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 7)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm15, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm20, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 7)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm15, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm20, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 7)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm15, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm20, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm15, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm20, ymm4); + + //perform mul operation + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm19 = _mm256_mul_pd(ymm19, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + //(ROw6): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 6)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm14, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm19, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 6)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm14, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm19, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 6)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm14, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm19, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 6)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm14, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm19, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 6)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm14, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm19, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm14, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm19, ymm4); + + //perform mul operation + ymm13 = _mm256_mul_pd(ymm13, ymm1); + ymm18 = _mm256_mul_pd(ymm18, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + //(ROw5): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 5)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm13, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm18, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 5)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm13, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm18, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 5)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm13, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm18, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 5)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm13, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm18, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm13, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm18, ymm4); + + //perform mul operation + ymm12 = _mm256_mul_pd(ymm12, ymm1); + ymm17 = _mm256_mul_pd(ymm17, ymm1); + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw4): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 4)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm12, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm17, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 4)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm12, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm17, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 4)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm12, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm17, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm12, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm17, ymm4); + + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); + ymm7 = _mm256_mul_pd(ymm7, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw3): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm11, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm7, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm11, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm7, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm11, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm7, ymm4); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + ymm6 = _mm256_mul_pd(ymm6, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm10, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm6, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm10, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm6, ymm4); + + //perform mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); + ymm5 = _mm256_mul_pd(ymm5, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm9, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm5, ymm4); + + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); + ymm4 = _mm256_mul_pd(ymm4, ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm12, ymm13); + ymm3 = _mm256_unpacklo_pd(ymm14, ymm15); + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); + + ///unpack high/// + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); + ymm3 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); + + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm0); + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm1); + _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm2); + _mm256_storeu_pd((double *)(b11 + cs_b * 3 + 4), ymm3); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm4, ymm5); + ymm3 = _mm256_unpacklo_pd(ymm6, ymm7); + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); + + ///unpack high/// + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); + + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm17, ymm18); + ymm3 = _mm256_unpacklo_pd(ymm19, ymm20); + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); + + ///unpack high/// + ymm17 = _mm256_unpackhi_pd(ymm17, ymm18); + ymm18 = _mm256_unpackhi_pd(ymm19, ymm20); + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm17, ymm18, 0x20); + + _mm256_storeu_pd((double *)(b11 + cs_b * 4 + 4), ymm0); + _mm256_storeu_pd((double *)(b11 + cs_b * 5 + 4), ymm1); + + } + + dim_t n_remainder = j + D_NR; + if(n_remainder >= 4) + { + a10 = D_A_pack; + a11 = L + (i*cs_a) + i; + b01 = B + ((n_remainder - 4)* cs_b) + i + D_MR; + b11 = B + ((n_remainder - 4)* cs_b) + i; + + k_iter = (m - i - D_MR) / D_MR; + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); ymm8 = _mm256_setzero_pd(); ymm9 = _mm256_setzero_pd(); ymm10 = _mm256_setzero_pd(); @@ -3073,109 +8627,58 @@ static err_t bli_dtrsm_small_AlXB_unitDiag( ymm15 = _mm256_setzero_pd(); ///GEMM code begins/// - - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter*4; k++) //loop for number of GEMM operations { - ptr_b01_dup = b01; + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); - ymm16 = _mm256_loadu_pd((double const *)(a10));//A10[0][0] A10[1][0] A10[2][0] A10[3][0] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[0][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[0][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[0][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[0][7] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); - b01 += 1; //mobe to next row of B + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4 + 4)); - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[0][4]*A10[0][0] B01[0][4]*A10[1][0] B01[0][4]*A10[2][0] B01[0][4]*A10[3][0]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[0][5]*A10[0][0] B01[0][5]*A10[1][0] B01[0][5]*A10[2][0] B01[0][5]*A10[3][0]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[0][6]*A10[0][0] B01[0][6]*A10[1][0] B01[0][6]*A10[2][0] B01[0][6]*A10[3][0]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[0][7]*A10[0][0] B01[0][7]*A10[1][0] B01[0][7]*A10[2][0] B01[0][7]*A10[3][0]) + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a));//A10[0][1] A10[1][1] A10[2][1] A10[3][1] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1 + 4)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2 + 4)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[1][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[1][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[1][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[1][7] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3 + 4)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); - b01 += 1; //mobe to next row of B + b01 += 1; //move to next row of B + a10 += p_lda; - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) - - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[1][4]*A10[0][1] B01[1][4]*A10[1][1] B01[1][4]*A10[2][1] B01[1][4]*A10[3][1]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[1][5]*A10[0][1] B01[1][5]*A10[1][1] B01[1][5]*A10[2][1] B01[1][5]*A10[3][1]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[1][6]*A10[0][1] B01[1][6]*A10[1][1] B01[1][6]*A10[2][1] B01[1][6]*A10[3][1]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[1][7]*A10[0][1] B01[1][7]*A10[1][1] B01[1][7]*A10[2][1] B01[1][7]*A10[3][1]) - - ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2));//A10[0][2] A10[1][2] A10[2][2] A10[3][2] - - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] - - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[2][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[2][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[2][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[2][7] - - b01 += 1; //mobe to next row of B - - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) - - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[2][4]*A10[0][2] B01[2][4]*A10[1][2] B01[2][4]*A10[2][2] B01[2][4]*A10[3][2]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[2][5]*A10[0][2] B01[2][5]*A10[1][2] B01[2][5]*A10[2][2] B01[2][5]*A10[3][2]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[2][6]*A10[0][2] B01[2][6]*A10[1][2] B01[2][6]*A10[2][2] B01[2][6]*A10[3][2]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[2][7]*A10[0][2] B01[2][7]*A10[1][2] B01[2][7]*A10[2][2] B01[2][7]*A10[3][2]) - - ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3));//A10[0][3] A10[1][3] A10[2][3] A10[3][3] - - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] - - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[3][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[3][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[3][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[3][7] - - b01 += 1; //mobe to next row of B - - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[3][0]*A10[0][3] B01[3][0]*A10[3][0] B01[3][0]*A10[2][3] B01[3][0]*A10[3][0]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[3][1]*A10[0][3] B01[3][1]*A10[3][0] B01[3][1]*A10[2][3] B01[3][1]*A10[3][0]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[3][2]*A10[0][3] B01[3][2]*A10[3][0] B01[3][2]*A10[2][3] B01[3][2]*A10[3][0]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[3][3]*A10[0][3] B01[3][3]*A10[3][0] B01[3][3]*A10[2][3] B01[3][3]*A10[3][0]) - - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[3][4]*A10[0][3] B01[3][4]*A10[3][0] B01[3][4]*A10[2][3] B01[3][4]*A10[3][3]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[3][5]*A10[0][3] B01[3][5]*A10[3][0] B01[3][5]*A10[2][3] B01[3][5]*A10[3][3]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[3][6]*A10[0][3] B01[3][6]*A10[3][0] B01[3][6]*A10[2][3] B01[3][6]*A10[3][3]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[3][7]*A10[0][3] B01[3][7]*A10[3][0] B01[3][7]*A10[2][3] B01[3][7]*A10[3][3]) - - a10 += D_MR * cs_a; //pointer math to calculate next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to calculate next block of B for GEMM + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + D_MR; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +D_MR; + ptr_a10_dup = a10; + } } ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha @@ -3184,10 +8687,11 @@ static err_t bli_dtrsm_small_AlXB_unitDiag( ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] - ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *6)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] - ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b *7)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] + ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b *3 + 4)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] @@ -3229,48 +8733,137 @@ static err_t bli_dtrsm_small_AlXB_unitDiag( ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] - ymm2 = _mm256_broadcast_sd((double const *)(a11 +1)); //A11[1][0] - ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][0] - ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][0] + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); - a11 += cs_a; + //perform mul operation + ymm15 = _mm256_mul_pd(ymm15, ymm1); - //(Row1): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); //B11[1][0-3] -= A11[1][0] * B11[0-3][0] - ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[2][0-3] -= A11[2][0] * B11[0-3][0] - ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); //B11[3][0-3] -= A11[3][0] * B11[0-3][0] + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); //B11[5][0-3] -= A11[1][0] * B11[0-3][4] - ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[6][0-3] -= A11[2][0] * B11[0-3][4] - ymm15 = _mm256_fnmadd_pd(ymm4, ymm12, ymm15); //B11[7][0-3] -= A11[3][0] * B11[0-3][4] + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6*cs_a + 7)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 7)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 7)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 7)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 7)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 7)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 7)); - ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][1] - ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][1] + //(ROw7): FMA operations + ymm14 = _mm256_fnmadd_pd(ymm2, ymm15, ymm14); + ymm13 = _mm256_fnmadd_pd(ymm3, ymm15, ymm13); + ymm12 = _mm256_fnmadd_pd(ymm4, ymm15, ymm12); + ymm11 = _mm256_fnmadd_pd(ymm5, ymm15, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm15, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm15, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm15, ymm8); - a11 += cs_a; + //perform mul operation + ymm14 = _mm256_mul_pd(ymm14, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 6)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 6)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 6)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 6)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 6)); + + //(ROw6): FMA operations + ymm13 = _mm256_fnmadd_pd(ymm3, ymm14, ymm13); + ymm12 = _mm256_fnmadd_pd(ymm4, ymm14, ymm12); + ymm11 = _mm256_fnmadd_pd(ymm5, ymm14, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm14, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm14, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm14, ymm8); + + //perform mul operation + ymm13 = _mm256_mul_pd(ymm13, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 5)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 5)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 5)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 5)); + + //(ROw5): FMA operations + ymm12 = _mm256_fnmadd_pd(ymm4, ymm13, ymm12); + ymm11 = _mm256_fnmadd_pd(ymm5, ymm13, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm13, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm13, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm13, ymm8); + + //perform mul operation + ymm12 = _mm256_mul_pd(ymm12, ymm1); + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 4)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 4)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 4)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 4)); + + //(ROw4): FMA operations + ymm11 = _mm256_fnmadd_pd(ymm5, ymm12, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm12, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm12, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm12, ymm8); + + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 3)); + + //(ROw3): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm6, ymm11, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm11, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm11, ymm8); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 2)); //(ROw2): FMA operations - ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); //B11[2][0-3] -= A11[2][1] * B11[0-3][1] - ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); //B11[3][0-3] -= A11[3][1] * B11[0-3][1] + ymm9 = _mm256_fnmadd_pd(ymm7, ymm10, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm10, ymm8); - ymm14 = _mm256_fnmadd_pd(ymm3, ymm13, ymm14); //B11[6][0-3] -= A11[2][1] * B11[0-3][5] - ymm15 = _mm256_fnmadd_pd(ymm4, ymm13, ymm15); //B11[7][0-3] -= A11[3][1] * B11[0-3][5] + //perform mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); - ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][2] + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - a11 += cs_a; + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 1)); - //(ROw1): FMA operations - ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); //B11[3][0-3] -= A11[3][2] * B11[0-3][2] + //(ROw2): FMA operations + ymm8 = _mm256_fnmadd_pd(ymm16, ymm9, ymm8); - ymm15 = _mm256_fnmadd_pd(ymm4, ymm14, ymm15); //B11[7][0-3] -= A11[3][2] * B11[0-3][6] + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] - ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] + ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] //rearrange low elements ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] @@ -3280,49 +8873,43 @@ static err_t bli_dtrsm_small_AlXB_unitDiag( ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] - ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] //rearrange high elements ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm4); //store B11[4][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm5); //store B11[5][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 6), ymm6); //store B11[6][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 7), ymm7); //store B11[7][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[5][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); //store B11[6][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3 + 4), ymm7); //store B11[7][0-3] + n_remainder -=4; } - if(m_remainder) //implementation for reamainder rows(when 'M' is not a multiple of D_MR) + if(n_remainder) //implementation fo remaining columns(when 'N' is not a multiple of D_NR)() n = 3 { - a10 = L +i; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + a10 = D_A_pack; + a11 = L + (i*cs_a) + i; + b01 = B + i + D_MR; + b11 = B + i; - k_iter = i / D_MR; //number of times GEMM operation to be done(in blocks of 4x4) - - dim_t iter; - - if((j+D_NR) == n) - { - f_temp = f_t; - for(iter = 0; iter < m_remainder; iter++) - f_temp[iter] = (b11 + cs_b * 7)[iter]; - } - else - f_temp = (b11 + cs_b * 7); + k_iter = (m - i - D_MR) / D_MR; + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); ymm8 = _mm256_setzero_pd(); ymm9 = _mm256_setzero_pd(); ymm10 = _mm256_setzero_pd(); @@ -3332,1325 +8919,5030 @@ static err_t bli_dtrsm_small_AlXB_unitDiag( ymm14 = _mm256_setzero_pd(); ymm15 = _mm256_setzero_pd(); - ///GEMM code Begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + if(3 == n_remainder) { - ptr_b01_dup = b01; + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter*4; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); - ymm16 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[0][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[0][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[0][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[0][7] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); - b01 += 1; //move to next row of B + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4 + 4)); - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0] ) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[0][4]*A10[0][0] B01[0][4]*A10[1][0] B01[0][4]*A10[2][0] B01[0][4]*A10[3][0]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[0][5]*A10[0][0] B01[0][5]*A10[1][0] B01[0][5]*A10[2][0] B01[0][5]*A10[3][0]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[0][6]*A10[0][0] B01[0][6]*A10[1][0] B01[0][6]*A10[2][0] B01[0][6]*A10[3][0]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm16 += (B01[0][7]*A10[0][0] B01[0][7]*A10[1][0] B01[0][7]*A10[2][0] B01[0][7]*A10[3][0]) + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1 + 4)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 1)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2 + 4)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] + b01 += 1; //move to next row of B + a10 += p_lda; - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[1][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[1][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[1][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[1][7] - - b01 += 1; //move to next row of B01 - - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) - - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[1][4]*A10[0][1] B01[1][4]*A10[1][1] B01[1][4]*A10[2][1] B01[1][4]*A10[3][1]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[1][5]*A10[0][1] B01[1][5]*A10[1][1] B01[1][5]*A10[2][1] B01[1][5]*A10[3][1]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[1][6]*A10[0][1] B01[1][6]*A10[1][1] B01[1][6]*A10[2][1] B01[1][6]*A10[3][1]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[1][7]*A10[0][1] B01[1][7]*A10[1][1] B01[1][7]*A10[2][1] B01[1][7]*A10[3][1]) - - ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] //A10[1][2] A10[2][2] A10[3][2] - - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] - - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[2][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[2][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[2][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[2][7] - - b01 += 1; //move to next row of B - - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) - - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[2][4]*A10[0][2] B01[2][4]*A10[1][2] B01[2][4]*A10[2][2] B01[2][0]*A10[3][2]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[2][5]*A10[0][2] B01[2][5]*A10[1][2] B01[2][5]*A10[2][2] B01[2][1]*A10[3][2]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[2][6]*A10[0][2] B01[2][6]*A10[1][2] B01[2][6]*A10[2][2] B01[2][2]*A10[3][2]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[2][7]*A10[0][2] B01[2][7]*A10[1][2] B01[2][7]*A10[2][2] B01[2][3]*A10[3][2]) - - ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] - - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] - - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[3][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[3][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[3][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[3][7] - - b01 += 1; //move to next row of B - - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm8 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm8 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm8 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) - - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm8 += (B01[3][0]*A10[0][3] B01[3][4]*A10[1][3] B01[3][4]*A10[2][3] B01[3][4]*A10[3][3]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm8 += (B01[3][1]*A10[0][3] B01[3][5]*A10[1][3] B01[3][5]*A10[2][3] B01[3][5]*A10[3][3]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm8 += (B01[3][2]*A10[0][3] B01[3][6]*A10[1][3] B01[3][6]*A10[2][3] B01[3][6]*A10[3][3]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm8 += (B01[3][3]*A10[0][3] B01[3][7]*A10[1][3] B01[3][7]*A10[2][3] B01[3][7]*A10[3][3]) - - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + D_MR; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +D_MR; + ptr_a10_dup = a10; + } } - ///GEMM code ends/// + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] + ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] * alpha -= B01[0-3][6] + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(2 == n_remainder) + { + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + + for(k = 0; k< k_iter*4; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1 + 4)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + b01 += 1; //move to next row of B + a10 += p_lda; + + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + D_MR; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +D_MR; + ptr_a10_dup = a10; + } + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] + ymm6 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(1 == n_remainder) + { + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + + for(k = 0; k< k_iter*4; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + b01 += 1; //move to next row of B + a10 += p_lda; + + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + D_MR; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +D_MR; + ptr_a10_dup = a10; + } + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_broadcast_sd((double const *)(&ones)); + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_broadcast_sd((double const *)(&ones)); + ymm6 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + } + ///implement TRSM/// + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[2][4] B11[2][5] + ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[2][6] B11[2][7] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[1][4] B11[1][5] B11[3][4] B11[3][5] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[1][6] B11[1][7] B11[3][6] B11[3][7] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] + ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + + //perform mul operation + ymm15 = _mm256_mul_pd(ymm15, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6*cs_a + 7)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 7)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 7)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 7)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 7)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 7)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 7)); + + //(ROw7): FMA operations + ymm14 = _mm256_fnmadd_pd(ymm2, ymm15, ymm14); + ymm13 = _mm256_fnmadd_pd(ymm3, ymm15, ymm13); + ymm12 = _mm256_fnmadd_pd(ymm4, ymm15, ymm12); + ymm11 = _mm256_fnmadd_pd(ymm5, ymm15, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm15, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm15, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm15, ymm8); + + //perform mul operation + ymm14 = _mm256_mul_pd(ymm14, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 6)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 6)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 6)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 6)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 6)); + + //(ROw6): FMA operations + ymm13 = _mm256_fnmadd_pd(ymm3, ymm14, ymm13); + ymm12 = _mm256_fnmadd_pd(ymm4, ymm14, ymm12); + ymm11 = _mm256_fnmadd_pd(ymm5, ymm14, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm14, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm14, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm14, ymm8); + + //perform mul operation + ymm13 = _mm256_mul_pd(ymm13, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 5)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 5)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 5)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 5)); + + //(ROw5): FMA operations + ymm12 = _mm256_fnmadd_pd(ymm4, ymm13, ymm12); + ymm11 = _mm256_fnmadd_pd(ymm5, ymm13, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm13, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm13, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm13, ymm8); + + //perform mul operation + ymm12 = _mm256_mul_pd(ymm12, ymm1); + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 4)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 4)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 4)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 4)); + + //(ROw4): FMA operations + ymm11 = _mm256_fnmadd_pd(ymm5, ymm12, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm12, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm12, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm12, ymm8); + + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 3)); + + //(ROw3): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm6, ymm11, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm11, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm11, ymm8); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 2)); + + //(ROw2): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm7, ymm10, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm10, ymm8); + + //perform mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 1)); + + //(ROw2): FMA operations + ymm8 = _mm256_fnmadd_pd(ymm16, ymm9, ymm8); + + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] + ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + + if(3 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[5][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); //store B11[6][0-3] + } + else if(2 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[5][0-3] + } + else if(1 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] + } + } + }// End of multiples of D_MR blocks in m-dimension + + // Repetative A blocks will be 4*4 + dim_t m_remainder = i + D_MR; + if(m_remainder >= 4) + { + i = m_remainder - 4; + a10 = L + (i*cs_a) + i + 4; //pointer to block of A to be used for GEMM + a11 = L + (i*cs_a) + i; //pointer to block of A to be used for TRSM + + // Do transpose for a10 & store in D_A_pack + double *ptr_a10_dup = D_A_pack; + dim_t p_lda = m-i+4; // packed leading dimension + for(dim_t x =0;x < m-i+4;x+=4) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + + a10 += 4; + ptr_a10_dup += 4; + } + + ymm4 = _mm256_broadcast_sd((double const *)&ones); + if(!is_unitdiag) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a +1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11+cs_a*2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11+cs_a*3 + 3)); + + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); + ymm4 = _mm256_div_pd(ymm4, ymm1); + } + _mm256_storeu_pd((double *)(d11_pack), ymm4); + + //cols + for(j = (n - D_NR); (j + 1) > 0; j -= D_NR) //loop along 'N' dimension + { + a10 = D_A_pack; + a11 = L + (i*cs_a) + i; //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b) + i + 4; //pointer to block of B to be used for GEMM + b11 = B + (j* cs_b) + i; //pointer to block of B to be used for TRSM + + k_iter = (m - i - 4); //number of times GEMM to be performed(in blocks of 4x4) + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); + ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); + + b01 += 1; //move to next row of B + a10 += p_lda; + + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + 4; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +4; + ptr_a10_dup = a10; + } + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + + ymm16 = _mm256_broadcast_sd((double const *)(&ones)); + + ////unpacklow//// + ymm7 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + + //rearrange low elements + ymm4 = _mm256_permute2f128_pd(ymm7,ymm16,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm6 = _mm256_permute2f128_pd(ymm7,ymm16,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + + //rearrange high elements + ymm5 = _mm256_permute2f128_pd(ymm0,ymm16,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm7 = _mm256_permute2f128_pd(ymm0,ymm16,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); + ymm7 = _mm256_mul_pd(ymm7, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw3): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm11, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm7, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm11, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm7, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm11, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm7, ymm4); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + ymm6 = _mm256_mul_pd(ymm6, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm10, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm6, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm10, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm6, ymm4); + + //perform mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); + ymm5 = _mm256_mul_pd(ymm5, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm9, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm5, ymm4); + + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); + ymm4 = _mm256_mul_pd(ymm4, ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ///unpack high/// + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store B11[1][0-3] + } + dim_t n_remainder = j + D_NR; + if((n_remainder >= 4)) + { + a10 = D_A_pack; + a11 = L + (i*cs_a) + i; //pointer to block of A to be used for TRSM + b01 = B + ((n_remainder - 4)* cs_b) + i + 4; //pointer to block of B to be used for GEMM + b11 = B + ((n_remainder - 4)* cs_b) + i; //pointer to block of B to be used for TRSM + + k_iter = (m - i - 4); //number of times GEMM to be performed(in blocks of 4x4) + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + b01 += 1; //move to next row of B + a10 += p_lda; + + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + 4; + ptr_b01_dup = b01; + a10 = ptr_a10_dup + 4; + ptr_a10_dup = a10; + } + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw3): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm11, ymm10); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm11, ymm9); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm11, ymm8); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm10, ymm9); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm10, ymm8); + + //perform mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm9, ymm8); + + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + n_remainder = n_remainder - 4; + } + + if(n_remainder) //implementation fo remaining columns(when 'N' is not a multiple of D_NR)() n = 3 + { + a10 = D_A_pack; + a11 = L + (i*cs_a) + i; + b01 = B + i + 4; + b11 = B + i; + + k_iter = (m - i - 4); + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(3 == n_remainder) + { + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + b01 += 1; //move to next row of B + a10 += p_lda; + + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + 4; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +4; + ptr_a10_dup = a10; + } + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] - ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *6)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] - ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] *alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] *alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] *alpha -= B01[0-3][2] - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] *alpha -= B01[0-3][3] - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] *alpha -= B01[0-3][4] - ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] *alpha -= B01[0-3][5] - ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] *alpha -= B01[0-3][6] - ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); //B11[0-3][7] *alpha -= B01[0-3][7] - - if(3 == m_remainder) - { - ///implement TRSM/// - - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[1][4] B11[1][5] - ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[1][6] B11[1][7] - - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] - ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[5][0] B11[5][1] B11[7][0] B11[7][1] - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[5][2] B11[5][3] B11[7][2] B11[7][3] - - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] - ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] - - ymm2 = _mm256_broadcast_sd((double const *)(a11 +1)); //A11[1][0] - ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][0] - - a11 += cs_a; - - //(Row1): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); //B11[1][0-3] -= B11[0-3][0]*A11[1][0] - ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[2][0-3] -= B11[0-3][0]*A11[2][0] - - ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); //B11[5][0-3] -= B11[0-3][4]*A11[1][4] - ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[6][0-3] -= B11[0-3][4]*A11[2][4] - - ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][1] - - a11 += cs_a; - - //(ROw2): FMA operations - ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); //B11[2][0-3] -= A11[2][1] * B11[0-3][1] - - ymm14 = _mm256_fnmadd_pd(ymm3, ymm13, ymm14); //B11[6][0-3] -= A11[2][1] * B11[0-3][5] - - ymm11 = _mm256_broadcast_sd((double const *)(&ones)); - ymm15 = _mm256_broadcast_sd((double const *)(&ones)); - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] - ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] - ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] - - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[0][5] B11[1][5] B11[0][7] B11[1][7] - ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[2][5] B11[3][5] B11[2][7] B11[3][7] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] - ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] - - ymm8 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); //load B11[0-3][0] - ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); //load B11[0-3][1] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //load B11[0-3][2] - ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //load B11[0-3][3] - ymm12 = _mm256_loadu_pd((double const *)(b11 + cs_b * 4)); //load B11[0-3][4] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b * 5)); //load B11[0-3][5] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b * 6)); //load B11[0-3][6] - ymm15 = _mm256_loadu_pd((double const *)(f_temp)); //load B11[0-3][7] - - //determine correct values to store - ymm0 = _mm256_blend_pd(ymm0, ymm8, 0x08); - ymm1 = _mm256_blend_pd(ymm1, ymm9, 0x08); - ymm2 = _mm256_blend_pd(ymm2, ymm10, 0x08); - ymm3 = _mm256_blend_pd(ymm3, ymm11, 0x08); - ymm4 = _mm256_blend_pd(ymm4, ymm12, 0x08); - ymm5 = _mm256_blend_pd(ymm5, ymm13, 0x08); - ymm6 = _mm256_blend_pd(ymm6, ymm14, 0x08); - ymm7 = _mm256_blend_pd(ymm7, ymm15, 0x08); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); } - else if(2 == m_remainder) + else if(2 == n_remainder) { - ///implement TRSM/// - - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[1][4] B11[1][5] - ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[1][6] B11[1][7] - - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] - ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[5][0] B11[5][1] B11[7][0] B11[7][1] - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[5][2] B11[5][3] B11[7][2] B11[7][3] - - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] - ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] - - ymm2 = _mm256_broadcast_sd((double const *)(a11 +1)); //A11[1][0] - - a11 += cs_a; - - //(Row1): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); //B11[1][0-3] -= B11[0-3][0]*A11[1][0] - - ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); //B11[5][0-3] -= B11[0-3][4]*A11[1][4] - - ymm10 = _mm256_broadcast_sd((double const *)&ones); - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - - ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm10, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm10, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ymm4 = _mm256_permute2f128_pd(ymm5, ymm10, 0x20); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] - ymm6 = _mm256_permute2f128_pd(ymm5, ymm10, 0x31); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] - - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - - ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[0][5] B11[1][5] B11[0][7] B11[1][7] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm10, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm10, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - ymm5 = _mm256_permute2f128_pd(ymm12, ymm10, 0x20); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] - ymm7 = _mm256_permute2f128_pd(ymm12, ymm10, 0x31); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] - - ymm8 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); //load B11[0-3][0] - ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); //load B11[0-3][1] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //load B11[0-3][2] - ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //load B11[0-3][3] - ymm12 = _mm256_loadu_pd((double const *)(b11 + cs_b * 4)); //load B11[0-3][4] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b * 5)); //load B11[0-3][5] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b * 6)); //load B11[0-3][6] - ymm15 = _mm256_loadu_pd((double const *)(f_temp)); //load B11[0-3][7] - - //determine correct values to store - ymm0 = _mm256_permute2f128_pd(ymm0, ymm8, 0x30); - ymm1 = _mm256_permute2f128_pd(ymm1, ymm9, 0x30); - ymm2 = _mm256_permute2f128_pd(ymm2, ymm10, 0x30); - ymm3 = _mm256_permute2f128_pd(ymm3, ymm11, 0x30); - ymm4 = _mm256_permute2f128_pd(ymm4, ymm12, 0x30); - ymm5 = _mm256_permute2f128_pd(ymm5, ymm13, 0x30); - ymm6 = _mm256_permute2f128_pd(ymm6, ymm14, 0x30); - ymm7 = _mm256_permute2f128_pd(ymm7, ymm15, 0x30); - - } - else if(1 == m_remainder) + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations { - ymm8 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); //load B11[0-3][0] - ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); //load B11[0-3][1] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //load B11[0-3][2] - ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //load B11[0-3][3] - ymm12 = _mm256_loadu_pd((double const *)(b11 + cs_b * 4)); //load B11[0-3][4] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b * 5)); //load B11[0-3][5] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b * 6)); //load B11[0-3][6] - ymm15 = _mm256_loadu_pd((double const *)(f_temp)); //load B11[0-3][7] + ymm0 = _mm256_loadu_pd((double const *)(a10)); - //determine correct values to store - ymm0 = _mm256_blend_pd(ymm0, ymm8, 0x0E); - ymm1 = _mm256_blend_pd(ymm1, ymm9, 0x0E); - ymm2 = _mm256_blend_pd(ymm2, ymm10, 0x0E); - ymm3 = _mm256_blend_pd(ymm3, ymm11, 0x0E); - ymm4 = _mm256_blend_pd(ymm4, ymm12, 0x0E); - ymm5 = _mm256_blend_pd(ymm5, ymm13, 0x0E); - ymm6 = _mm256_blend_pd(ymm6, ymm14, 0x0E); - ymm7 = _mm256_blend_pd(ymm7, ymm15, 0x0E); + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + b01 += 1; //move to next row of B + a10 += p_lda; + + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + 4; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +4; + ptr_a10_dup = a10; + } } - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm4); //store(B11[0-3][4]) - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm5); //store(B11[0-3][5]) - _mm256_storeu_pd((double *)(b11 + cs_b * 6), ymm6); //store(B11[0-3][6]) - _mm256_storeu_pd((double *)(f_temp), ymm7); //store(B11[0-3][7]) + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - if((j+D_NR) == n) - { - for(iter = 0; iter < m_remainder; iter++) - (b11 + cs_b * 7)[iter] = f_temp[iter]; + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); } - } - } - - if((n & 4)) //implementation for remainder columns(when 'n_remainder' is greater than 4) - { - for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' direction - { - a10 = L +i; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - - k_iter = i / D_MR; //number of times GEMM to be performed(in block of 4) - ///GEMM for previously calculated values /// - - //load 4x4 block from b11 - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + else if(1 == n_remainder) { - ptr_b01_dup = b01; - ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a*2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[1][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[2][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[3][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) - - - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + b01 += 1; //move to next row of B + a10 += p_lda; + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + 4; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +4; + ptr_a10_dup = a10; + } } - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha - ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 - ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B01[0-3][1] *alpha -= ymm5 - ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B01[0-3][2] *alpha -= ymm6 - ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B01[0-3][3] *alpha -= ymm7 + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_broadcast_sd((double const *)(&ones)); + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } ///implement TRSM/// - //1st col - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] - ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0] - ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][0] - //2nd col - a11 += cs_a; - ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][1] - - //3rd col - a11 += cs_a; - ymm12 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][2] - - ////unpacklow//// - ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] //rearrange low elements - ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] //rearrange high elements - ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) - ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]*B11[0][0-3] - ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0]*B11[0][0-3] - ymm13 = _mm256_fnmadd_pd(ymm7, ymm4, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][0]*B11[0][0-3] + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1]*B11[1][0-3] - ymm13 = _mm256_fnmadd_pd(ymm10, ymm8, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][1]*B11[1][0-3] + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - ymm13 = _mm256_fnmadd_pd(ymm12, ymm11, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][2]*B11[2][0-3] + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 3)); + + //(ROw3): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm6, ymm11, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm11, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm11, ymm8); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 2)); + + //(ROw2): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm7, ymm10, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm10, ymm8); + + //perform mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 1)); + + //(ROw2): FMA operations + ymm8 = _mm256_fnmadd_pd(ymm16, ymm9, ymm8); + + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ////unpackhigh//// - ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - - ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[0-3][3]) - - } - if(m_remainder) //implementation for remainder rows(when 'M' is not a multiple of D_MR) - { - a10 = L +i; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - - k_iter = i / D_MR; //number of GEMM operations to be performed(in blocks of 4x4) - - dim_t iter; - - if((j+4) == n) - { - f_temp = f_t; - for(iter = 0; iter < m_remainder; iter++) - f_temp[iter] = (b11 + cs_b * 3)[iter]; - } - else - f_temp = (b11 + cs_b * 3); - ///GEMM for previously calculated values /// - - //load 4x4 block from b11 - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - for(k = 0; k < k_iter; k++) //looop for number of GEMM operations - { - ptr_b01_dup = b01; - - ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) - - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM - - } - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha - - ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 - ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] *alpha -= ymm5 - ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[0-3][2] *alpha -= ymm6 - ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[0-3][3] *alpha -= ymm7 - - - if(3 == m_remainder) - { - ///implement TRSM/// - //1st col - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] - ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0] - - //2nd col - a11 += cs_a; - ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1] - - ////unpacklow//// - ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - //rearrange low elements - ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - //rearrange high elements - ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) - ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]* B11[0][0-3] - ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0]* B11[0][0-3] - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1]* B11[1][0-3] - - ymm13 = _mm256_broadcast_sd((double const *)(&ones)); - - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ////unpackhigh//// - ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - - ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - //load 4x4 block from b11 - ymm4 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][2] B11[3][3] - - //determine correct values to store - ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x08); - ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x08); - ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x08); - ymm3 = _mm256_blend_pd(ymm3, ymm7, 0x08); - } - else if(2 == m_remainder) - { - ///implement TRSM/// - //1st col - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] - - ////unpacklow//// - ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - //rearrange low elements - ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - //rearrange high elements - ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) - ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]* B11[0][0-3] - - ymm11 = _mm256_broadcast_sd((double const *)(&ones)); - ymm13 = _mm256_broadcast_sd((double const *)(&ones)); - - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ////unpackhigh//// - ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - - ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - //load 4x4 block from b11 - ymm4 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][2] B11[3][3] - - //determine correct values to store - ymm0 = _mm256_permute2f128_pd(ymm0, ymm4,0x30); - ymm1 = _mm256_permute2f128_pd(ymm1, ymm5,0x30); - ymm2 = _mm256_permute2f128_pd(ymm2, ymm6,0x30); - ymm3 = _mm256_permute2f128_pd(ymm3, ymm7,0x30); - - } - else if(1 == m_remainder) - { - //load 4x4 block from b11 - ymm4 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][2] B11[3][3] - - //determine correct values to store - ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x0E); - ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x0E); - ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x0E); - ymm3 = _mm256_blend_pd(ymm3, ymm7, 0x0E); - } - - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(f_temp), ymm3); //store(B11[0-3][3]) - - if((j+4) == n) - { - for(iter = 0; iter < m_remainder; iter++) - (b11 + cs_b * 3)[iter] = f_temp[iter]; - } - } - - n_remainder -= 4; - j += 4; - - } - - if(n_remainder) //implementation fo remaining columns(when 'N' is not a multiple of D_NR) - { - for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' direction - { - a10 = L +i; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - - k_iter = i / D_MR; //number of GEMM operations to be performed(in blocks of 4x4) - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM for previously calculated values /// - - //load 4x4 block from b11 if(3 == n_remainder) { - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + } + else if(2 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + } + else if(1 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + } + } + m_remainder -= 4; + } - for(k = 0; k < k_iter; k++) + if(m_remainder) + { + a10 = L + m_remainder; + + // Do transpose for a10 & store in D_A_pack + double *ptr_a10_dup = D_A_pack; + if(3 == m_remainder) // Repetative A blocks will be 3*3 + { + dim_t p_lda = m-m_remainder; // packed leading dimension + for(dim_t x =0;x < m-m_remainder;x+=4) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); + ymm3 = _mm256_broadcast_sd((double const *)&ones); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + + a10 += 4; + ptr_a10_dup += 4; + } + + //cols + for(j = (n - D_NR); (j + 1) > 0; j -= D_NR) //loop along 'N' dimension + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + (j* cs_b) + m_remainder; //pointer to block of B to be used for GEMM + b11 = B + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations { - ptr_b01_dup = b01; - ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - b01 += 1; + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - b01 += 1; + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); + ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] + b01 += 1; //move to next row of B + a10 += p_lda; - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) - - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + 4; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +4; + ptr_a10_dup = a10; + } } ///GEMM code ends/// - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 - ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] *alpha -= ymm5 - ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[0-3][2] *alpha -= ymm6 - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][3] *alpha -= ymm7 + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][3]) + + dtrsm_AltXB_ref(a11, b11, m_remainder, 6, cs_a, cs_b, is_unitdiag); } - else if(2 == n_remainder) - { - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - for(k = 0; k < k_iter; k++) + dim_t n_remainder = j + D_NR; + if((n_remainder >= 4)) + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + ((n_remainder - 4)* cs_b) + m_remainder; //pointer to block of B to be used for GEMM + b11 = B + ((n_remainder - 4)* cs_b); //pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations { - ptr_b01_dup = b01; - ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - b01 += 1; + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - b01 += 1; + b01 += 1; //move to next row of B + a10 += p_lda; - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + 4; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +4; + ptr_a10_dup = a10; + } } - ///GEMM code ends/// - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value - ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 - ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] *alpha -= ymm5 - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][2] *alpha -= ymm6 - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][3] *alpha -= ymm7 + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + + dtrsm_AltXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); + n_remainder -= 4; } - else if(1 == n_remainder) + if(n_remainder) { - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + m_remainder; //pointer to block of B to be used for GEMM + b11 = B; //pointer to block of B to be used for TRSM - for(k = 0; k < k_iter; k++) + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(3 == n_remainder) { - ptr_b01_dup = b01; + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - b01 += 1; + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) + b01 += 1; //move to next row of B + a10 += p_lda; - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + 4; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +4; + ptr_a10_dup = a10; + } } - ///GEMM code ends/// - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value - ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 - ymm1 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][1] *alpha -= ymm5 - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][2] *alpha -= ymm6 - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][3] *alpha -= ymm7 + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + + dtrsm_AltXB_ref(a11, b11, m_remainder, 3, cs_a, cs_b, is_unitdiag); + } + else if(2 == n_remainder) + { + + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + b01 += 1; //move to next row of B + a10 += p_lda; + + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + 4; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +4; + ptr_a10_dup = a10; + } + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + + dtrsm_AltXB_ref(a11, b11, m_remainder, 2, cs_a, cs_b, is_unitdiag); + } + else if(1 == n_remainder) + { + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + b01 += 1; //move to next row of B + a10 += p_lda; + + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + 4; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +4; + ptr_a10_dup = a10; + } + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + + dtrsm_AltXB_ref(a11, b11, m_remainder, 1, cs_a, cs_b, is_unitdiag); + } } - - ///implement TRSM/// - //1st col - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] - ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0] - ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][0] - - //2nd col - a11 += cs_a; - ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][1] - - //3rd col - a11 += cs_a; - ymm12 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][2] - - ////unpacklow//// - ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - //rearrange low elements - ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - //rearrange high elements - ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) - ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0] * B11[0][0-3] - ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0] * B11[0][0-3] - ymm13 = _mm256_fnmadd_pd(ymm7, ymm4, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][0] * B11[0][0-3] - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1] * B11[1][0-3] - ymm13 = _mm256_fnmadd_pd(ymm10, ymm8, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][1] * B11[1][0-3] - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - ymm13 = _mm256_fnmadd_pd(ymm12, ymm11, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][2] * B11[2][0-3] - - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ////unpackhigh//// - ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - - ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - if(3 == n_remainder) - { - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - - } - else if(2 == n_remainder) - { - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) - - } - else if(1 == n_remainder) - { - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - } - } - if(m_remainder) //implementation for remainder rows(when 'M' is not a multiple of D_MR) + else if(2 == m_remainder) // Repetative A blocks will be 2*2 { - a10 = L +i; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - - - k_iter = i / D_MR; //number of times GEMM operations to be performed - - dim_t iter; - - if((j+n_remainder) == n) + dim_t p_lda = m-m_remainder; // packed leading dimension + for(dim_t x =0;x < m-m_remainder;x+=4) { - f_temp = f_t; - for(iter = 0; iter < m_remainder; iter++) - f_temp[iter] = (b11 + cs_b * (n_remainder -1))[iter]; + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); + ymm2 = _mm256_broadcast_sd((double const *)&ones); + ymm3 = _mm256_broadcast_sd((double const *)&ones); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + + a10 += 4; + ptr_a10_dup += 4; } - else - f_temp = (b11 + cs_b * (n_remainder -1)); - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM for previously calculated values /// - - - //load 4x4 block from b11 - if(3 == n_remainder) + //cols + for(j = (n - D_NR); (j + 1) > 0; j -= D_NR) //loop along 'N' dimension { - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + (j* cs_b) + m_remainder; //pointer to block of B to be used for GEMM + b11 = B + (j* cs_b); //pointer to block of B to be used for TRSM - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations { - ptr_b01_dup = b01; - ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[0][2] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - b01 += 1; //move to next row of B + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[1][2] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - b01 += 1; //move to next row of B + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); + ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[2][2] + b01 += 1; //move to next row of B + a10 += p_lda; - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[3][2] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) - - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + 4; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +4; + ptr_a10_dup = a10; + } } - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha value - ymm8 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] * alpha -= ymm4 - ymm9 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] * alpha -= ymm5 - ymm10 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[0-3][2] * alpha -= ymm6 + ///GEMM code ends/// + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0C); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); + + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][3]) + + dtrsm_AltXB_ref(a11, b11, m_remainder, 6, cs_a, cs_b, is_unitdiag); + } + dim_t n_remainder = j + D_NR; + if((n_remainder >= 4)) + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + ((n_remainder - 4)* cs_b) + m_remainder; //pointer to block of B to be used for GEMM + b11 = B + ((n_remainder - 4)* cs_b); //pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + b01 += 1; //move to next row of B + a10 += p_lda; + + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + 4; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +4; + ptr_a10_dup = a10; + } + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha ///implement TRSM/// - //determine correct values to store - if(3 == m_remainder) - { - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); - } - else if(2 == m_remainder) - { - ymm0 = _mm256_permute2f128_pd(ymm8, ymm0, 0x30); - ymm1 = _mm256_permute2f128_pd(ymm9, ymm1, 0x30); - ymm2 = _mm256_permute2f128_pd(ymm10, ymm2, 0x30); - } - else if(1 == m_remainder) - { - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); - } - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(f_temp), ymm2); //store(B11[0-3][2]) + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0C); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + + dtrsm_AltXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); + n_remainder -= 4; } - else if(2 == n_remainder) + if(n_remainder) { - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + m_remainder; //pointer to block of B to be used for GEMM + b11 = B; //pointer to block of B to be used for TRSM - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(3 == n_remainder) { - ptr_b01_dup = b01; + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; - ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[0][1] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - b01 += 1; //move to next row of B + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[1][1] + b01 += 1; //move to next row of B + a10 += p_lda; - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[2][1] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[3][1] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + 4; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +4; + ptr_a10_dup = a10; + } } - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha value - ymm8 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] * alpha -= ymm4 - ymm9 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] * alpha -= ymm5 + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha ///implement TRSM/// - //determine correct values to store - if(3 == m_remainder) - { - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + + dtrsm_AltXB_ref(a11, b11, m_remainder, 3, cs_a, cs_b, is_unitdiag); } - else if(2 == m_remainder) + else if(2 == n_remainder) { - ymm0 = _mm256_permute2f128_pd(ymm8, ymm0, 0x30); - ymm1 = _mm256_permute2f128_pd(ymm9, ymm1, 0x30); - } - else if(1 == m_remainder) + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations { - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + b01 += 1; //move to next row of B + a10 += p_lda; + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + 4; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +4; + ptr_a10_dup = a10; + } } - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(f_temp), ymm1); //store(B11[0-3][1]) - } - else if(1 == n_remainder) - { - ymm0 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - ptr_b01_dup = b01; - - ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM - - } - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha value - - ymm8 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] * alpha -= ymm4 + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha ///implement TRSM/// - //determine correct values to store - if(3 == m_remainder) - { - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + + dtrsm_AltXB_ref(a11, b11, m_remainder, 2, cs_a, cs_b, is_unitdiag); } - else if(2 == m_remainder) + else if(1 == n_remainder) { - ymm0 = _mm256_permute2f128_pd(ymm8, ymm0, 0x30); - } - else if(1 == m_remainder) + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations { - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + b01 += 1; //move to next row of B + a10 += p_lda; + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + 4; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +4; + ptr_a10_dup = a10; + } + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + + dtrsm_AltXB_ref(a11, b11, m_remainder, 1, cs_a, cs_b, is_unitdiag); } - _mm256_storeu_pd((double *)(f_temp), ymm0); //store(B11[0-3][0]) } - if((j+n_remainder) == n) - { - for(iter = 0; iter < m_remainder; iter++) - (b11 + cs_b * (n_remainder-1))[iter] = f_temp[iter]; - } - ///scalar code for trsm without alpha/// - dtrsm_small_AlXB_unitDiag(a11, b11, m_remainder, n_remainder, cs_a, cs_b); } + else if(1 == m_remainder) // Repetative A blocks will be 1*1 + { + dim_t p_lda = m-m_remainder; // packed leading dimension + for(dim_t x =0;x < m-m_remainder;x+=4) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_broadcast_sd((double const *)&ones); + ymm2 = _mm256_broadcast_sd((double const *)&ones); + ymm3 = _mm256_broadcast_sd((double const *)&ones); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + + a10 += 4; + ptr_a10_dup += 4; + } + //cols + for(j = (n - D_NR); (j + 1) > 0; j -= D_NR) //loop along 'N' dimension + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + (j* cs_b) + m_remainder; //pointer to block of B to be used for GEMM + b11 = B + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); + ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); + + b01 += 1; //move to next row of B + a10 += p_lda; + + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + 4; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +4; + ptr_a10_dup = a10; + } + } + + ///GEMM code ends/// + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][3]) + + dtrsm_AltXB_ref(a11, b11, m_remainder, 6, cs_a, cs_b, is_unitdiag); + } + dim_t n_remainder = j + D_NR; + if((n_remainder >= 4)) + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + ((n_remainder - 4)* cs_b) + m_remainder; //pointer to block of B to be used for GEMM + b11 = B + ((n_remainder - 4)* cs_b); //pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + b01 += 1; //move to next row of B + a10 += p_lda; + + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + 4; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +4; + ptr_a10_dup = a10; + } + + } + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + + dtrsm_AltXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); + n_remainder -= 4; + } + if(n_remainder) + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + m_remainder; //pointer to block of B to be used for GEMM + b11 = B; //pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(3 == n_remainder) + { + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + b01 += 1; //move to next row of B + a10 += p_lda; + + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + 4; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +4; + ptr_a10_dup = a10; + } + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + + dtrsm_AltXB_ref(a11, b11, m_remainder, 3, cs_a, cs_b, is_unitdiag); + } + else if(2 == n_remainder) + { + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + b01 += 1; //move to next row of B + a10 += p_lda; + + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + 4; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +4; + ptr_a10_dup = a10; + } + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + + dtrsm_AltXB_ref(a11, b11, m_remainder, 2, cs_a, cs_b, is_unitdiag); + } + else if(1 == n_remainder) + { + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + b01 += 1; //move to next row of B + a10 += p_lda; + + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + 4; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +4; + ptr_a10_dup = a10; + } + } + + //register to hold alpha + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ///implement TRSM/// + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + dtrsm_AltXB_ref(a11, b11, m_remainder, 1, cs_a, cs_b, is_unitdiag); + } + } + } + } + + if ((required_packing_A == 1) && + bli_mem_is_alloc( &local_mem_buf_A_s )) + { + bli_membrk_release(&rntm,&local_mem_buf_A_s); } return BLIS_SUCCESS; } +/* + * TRSM for the case AX = alpha * B, Double precision + * A is upper-triangular, non-transpose, non-unit diagonal + * dimensions A: mxm X: mxn B: mxn +*/ +static err_t bli_dtrsm_small_AuXB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) +{ + dim_t D_MR = 8; //size of block along 'M' dimpension + dim_t D_NR = 6; //size of block along 'N' dimension -/*implements TRSM for the case XA = alpha * B + dim_t m = bli_obj_length(b); // number of rows of matrix B + dim_t n = bli_obj_width(b); // number of columns of matrix B + + dim_t cs_a = bli_obj_col_stride(a); // column stride of A + dim_t cs_b = bli_obj_col_stride(b); // column stride of B + + dim_t i, j, k; //loop variables + dim_t k_iter; //number of times GEMM to be performed + + double AlphaVal = *(double *)AlphaObj->buffer; //value of alpha + double *L = a->buffer; //pointer to matrix A + double *B = b->buffer; //pointer to matrix B + + //pointers that point to blocks for GEMM and TRSM + double *a10, *a11, *b01, *b11; + double *ptr_b01_dup; + double *ptr_a10_dup; + + double ones = 1.0; + bool is_unitdiag = bli_obj_has_unit_diag(a); + //scratch registers + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15; + __m256d ymm16, ymm17, ymm18, ymm19; + __m256d ymm20; + gint_t required_packing_A = 1; + mem_t local_mem_buf_A_s; + double *D_A_pack = NULL; + double d11_pack[D_MR] __attribute__((aligned(64))); + rntm_t rntm; + + bli_rntm_init_from_global( &rntm ); + bli_rntm_set_num_threads_only( 1, &rntm ); + bli_membrk_rntm_set_membrk( &rntm ); + + siz_t buffer_size = bli_pool_block_size( + bli_membrk_pool( + bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), + bli_rntm_membrk(&rntm))); + + if ( (D_MR * m * sizeof(double)) > buffer_size) + return BLIS_NOT_YET_IMPLEMENTED; + + if (required_packing_A == 1) + { + // Get the buffer from the pool. + bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &local_mem_buf_A_s); + + D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); + } + + /* + Performs solving TRSM for 8 colmns at a time from 0 to m/8 in steps of D_MR + a. Load, transpose, Pack A (a10 block), the size of packing 8x6 to 8x (m-8) + First there will be no GEMM and no packing of a10 because it is only TRSM + b. Using packed a10 block and b01 block perform GEMM operation + c. Use GEMM outputs, perform TRSM operaton using a11, b11 and update B + d. Repeat b,c for n row of B in steps of D_NR + */ + for(i = (m - D_MR); (i + 1) > 0; i -= D_MR) + { + a10 = L + (i + D_MR)*cs_a + i; //pointer to block of A to be used for GEMM + a11 = L + (i*cs_a) + i; //pointer to block of A to be used for TRSM + + // Do transpose for a10 & store in D_A_pack + ptr_a10_dup = D_A_pack; //ptr_a11_dup = a11; + dim_t p_lda = D_MR; // packed leading dimension + + /* + Pack current A block (a10) into packed buffer memory D_A_pack + a. This a10 block is used in GEMM portion only and this + a10 block size will be increasing by D_MR for every next itteration + untill it reaches 8x(m-8) which is the maximum GEMM alone block size in A + b. This packed buffer is reused to calculate all n rows of B matrix + */ + for(dim_t x =0;x < (m-i-D_MR);x++) + { + ymm16 = _mm256_loadu_pd((double const *)(a10 + (cs_a * x))); + _mm256_storeu_pd((double *)(ptr_a10_dup + (p_lda * x)), ymm16); + ymm16 = _mm256_loadu_pd((double const *)(a10 + (cs_a * x) + 4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + (p_lda * x) + 4), ymm16); + } + + /* + Pack 8 diagonal elements of A block into an array + a. This helps in utilze cache line efficiently in TRSM operation + b. store ones when input is unit diagonal + */ + ymm4 = _mm256_broadcast_sd((double const *)&ones); + if(is_unitdiag) + { + _mm256_storeu_pd((double *)(d11_pack), ymm4); + _mm256_storeu_pd((double *)(d11_pack + 4), ymm4); + }else + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a +1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a*2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a*3 + 3)); + + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); + ymm0 = _mm256_div_pd(ymm4, ymm1); + _mm256_storeu_pd((double *)(d11_pack), ymm0); + + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11 + 4 + cs_a*4)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5 + cs_a*5)); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6 + cs_a*6)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + 7 + cs_a*7)); + + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); + ymm0 = _mm256_div_pd(ymm4, ymm1); + _mm256_storeu_pd((double *)(d11_pack + 4), ymm0); + } + + /* + a. Perform GEMM using a10, b01. + b. Perform TRSM on a11, b11 + c. This loop GEMM+TRSM loops operates with 8x6 block size + along n dimension for every D_NR rows of b01 where + packed A buffer is reused in computing all n rows of B. + d. Same approch is used in remaining fringe cases. + */ + for(j = (n - D_NR); (j + 1) > 0; j -= D_NR) //loop along 'N' dimension + { + a10 = D_A_pack; + b01 = B + (j*cs_b) + i + D_MR; //pointer to block of B to be used for GEMM + b11 = B + (j* cs_b) + i; //pointer to block of B to be used for TRSM + + k_iter = (m - i - D_MR) / D_MR; //number of times GEMM to be performed(in blocks of 4x4) + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + /* + Peform GEMM between a10 and b01 blocks + For first itteration there will be no GEMM operation + where k_iter are zero + */ + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter*4; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); + ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); + ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1 + 4)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2 + 4)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3 + 4)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4 + 4)); + ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5 + 4)); + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); + ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7); + + b01 += 1; + a10 += p_lda; + + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + D_MR; + ptr_b01_dup = b01; + a10 = ptr_a10_dup + D_MR*p_lda; + ptr_a10_dup = a10; + } + } + //GEMM block ends here + + /* + Load b11 of size 6x8 and multiply with alpha + Add the GEMM output and perform inregister transose of b11 + to peform TRSM operation. + */ + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3 + 4)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm12); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm13); + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm14); + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm15); + + ymm13 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm15 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm15 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *4 + 4)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *5 + 4)); + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); + + ymm16 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm4 = _mm256_permute2f128_pd(ymm7,ymm16,0x20); + ymm6 = _mm256_permute2f128_pd(ymm7,ymm16,0x31); + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + + ymm5 = _mm256_permute2f128_pd(ymm0,ymm16,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm16,0x31); + ymm18 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm17 = _mm256_permute2f128_pd(ymm18,ymm16,0x20); + ymm19 = _mm256_permute2f128_pd(ymm18,ymm16,0x31); + ymm20 = _mm256_unpackhi_pd(ymm2, ymm3); + ymm18 = _mm256_permute2f128_pd(ymm20,ymm16,0x20); + ymm20 = _mm256_permute2f128_pd(ymm20,ymm16,0x31); + + /* + Compute 8x6 TRSM block by using GEMM block output in register + a. The 8x6 input (gemm outputs) are stored in combinations of ymm registers + 1. ymm15, ymm20 2. ymm14, ymm19 3. ymm13, ymm18 , 4. ymm12, ymm17 + 5. ymm11, ymm7 6. ymm10, ymm6, 7.ymm9, ymm5 8. ymm8, ymm4 + where ymm15-ymm8 holds 8x4 data and reaming 8x2 will be hold by + other registers + b. Towards the end do in regiser transpose of TRSM output and store in b11 + */ + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + + //perform mul operation + ymm15 = _mm256_mul_pd(ymm15, ymm1); + ymm20 = _mm256_mul_pd(ymm20, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + + //(ROw7): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6 + 7*cs_a)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm15, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm20, ymm19); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5 + 7*cs_a)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm15, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm20, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4 + 7*cs_a)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm15, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm20, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3 + 7*cs_a)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm15, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm20, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 7*cs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm15, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm20, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 7*cs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm15, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm20, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7*cs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm15, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm20, ymm4); + + //perform mul operation + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm19 = _mm256_mul_pd(ymm19, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + //(ROw6): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5 + 6*cs_a)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm14, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm19, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4 + 6*cs_a)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm14, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm19, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3 + 6*cs_a)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm14, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm19, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 6*cs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm14, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm19, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 6*cs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm14, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm19, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6*cs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm14, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm19, ymm4); + + //perform mul operation + ymm13 = _mm256_mul_pd(ymm13, ymm1); + ymm18 = _mm256_mul_pd(ymm18, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + //(ROw5): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4 + 5*cs_a)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm13, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm18, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3 + 5*cs_a)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm13, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm18, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 5*cs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm13, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm18, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 5*cs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm13, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm18, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm13, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm18, ymm4); + + //perform mul operation + ymm12 = _mm256_mul_pd(ymm12, ymm1); + ymm17 = _mm256_mul_pd(ymm17, ymm1); + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw4): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3 + 4*cs_a)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm12, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm17, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 4*cs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm12, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm17, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 4*cs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm12, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm17, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm12, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm17, ymm4); + + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); + ymm7 = _mm256_mul_pd(ymm7, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw3): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 3*cs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm11, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm7, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 3*cs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm11, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm7, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm11, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm7, ymm4); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + ymm6 = _mm256_mul_pd(ymm6, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 2*cs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm10, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm6, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm10, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm6, ymm4); + + //perform mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); + ymm5 = _mm256_mul_pd(ymm5, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm9, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm5, ymm4); + + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); + ymm4 = _mm256_mul_pd(ymm4, ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm12, ymm13); + ymm3 = _mm256_unpacklo_pd(ymm14, ymm15); + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); + + ///unpack high/// + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); + ymm3 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); + + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm0); + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm1); + _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm2); + _mm256_storeu_pd((double *)(b11 + cs_b * 3 + 4), ymm3); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm4, ymm5); + ymm3 = _mm256_unpacklo_pd(ymm6, ymm7); + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); + + ///unpack high/// + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); + + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm17, ymm18); + ymm3 = _mm256_unpacklo_pd(ymm19, ymm20); + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); + + ///unpack high/// + ymm17 = _mm256_unpackhi_pd(ymm17, ymm18); + ymm18 = _mm256_unpackhi_pd(ymm19, ymm20); + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm17, ymm18, 0x20); + + _mm256_storeu_pd((double *)(b11 + cs_b * 4 + 4), ymm0); + _mm256_storeu_pd((double *)(b11 + cs_b * 5 + 4), ymm1); + + } + + dim_t n_remainder = j + D_NR; + if(n_remainder >= 4) + { + a10 = D_A_pack; + a11 = L + (i*cs_a) + i; + b01 = B + ((n_remainder - 4)* cs_b) + i + D_MR; + b11 = B + ((n_remainder - 4)* cs_b) + i; + + k_iter = (m - i - D_MR) / D_MR; + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter*4; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1 + 4)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2 + 4)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3 + 4)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + + b01 += 1; //move to next row of B + a10 += p_lda; + + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + D_MR; + ptr_b01_dup = b01; + a10 = ptr_a10_dup + D_MR*p_lda; + ptr_a10_dup = a10; + } + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); + ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b *3 + 4)); + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); + ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); + ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); + + ///implement TRSM/// + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); + ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + + //perform mul operation + ymm15 = _mm256_mul_pd(ymm15, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6 + 7*cs_a)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5 + 7*cs_a)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4 + 7*cs_a)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3 + 7*cs_a)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2 + 7*cs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + 1 + 7*cs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 7*cs_a)); + + //(ROw7): FMA operations + ymm14 = _mm256_fnmadd_pd(ymm2, ymm15, ymm14); + ymm13 = _mm256_fnmadd_pd(ymm3, ymm15, ymm13); + ymm12 = _mm256_fnmadd_pd(ymm4, ymm15, ymm12); + ymm11 = _mm256_fnmadd_pd(ymm5, ymm15, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm15, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm15, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm15, ymm8); + + //perform mul operation + ymm14 = _mm256_mul_pd(ymm14, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5 + 6*cs_a)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4 + 6*cs_a)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3 + 6*cs_a)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2 + 6*cs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + 1 + 6*cs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 6*cs_a)); + + //(ROw6): FMA operations + ymm13 = _mm256_fnmadd_pd(ymm3, ymm14, ymm13); + ymm12 = _mm256_fnmadd_pd(ymm4, ymm14, ymm12); + ymm11 = _mm256_fnmadd_pd(ymm5, ymm14, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm14, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm14, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm14, ymm8); + + //perform mul operation + ymm13 = _mm256_mul_pd(ymm13, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4 + 5*cs_a)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3 + 5*cs_a)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2 + 5*cs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + 1 + 5*cs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a)); + + //(ROw5): FMA operations + ymm12 = _mm256_fnmadd_pd(ymm4, ymm13, ymm12); + ymm11 = _mm256_fnmadd_pd(ymm5, ymm13, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm13, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm13, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm13, ymm8); + + //perform mul operation + ymm12 = _mm256_mul_pd(ymm12, ymm1); + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3 + 4*cs_a)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2 + 4*cs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + 1 + 4*cs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a)); + + //(ROw4): FMA operations + ymm11 = _mm256_fnmadd_pd(ymm5, ymm12, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm12, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm12, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm12, ymm8); + + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2 + 3*cs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + 1 + 3*cs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); + + //(ROw3): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm6, ymm11, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm11, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm11, ymm8); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + ymm7 = _mm256_broadcast_sd((double const *)(a11 + 1 + 2*cs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); + + //(ROw2): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm7, ymm10, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm10, ymm8); + + //perform mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 1*cs_a)); + + //(ROw2): FMA operations + ymm8 = _mm256_fnmadd_pd(ymm16, ymm9, ymm8); + + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); + ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); + ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); + ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); + _mm256_storeu_pd((double *)(b11 + cs_b * 3 + 4), ymm7); + n_remainder -=4; + } + + if(n_remainder) //implementation fo remaining columns(when 'N' is not a multiple of D_NR)() n = 3 + { + a10 = D_A_pack; + a11 = L + (i*cs_a) + i; + b01 = B + i + D_MR; + b11 = B + i; + + k_iter = (m - i - D_MR) / D_MR; + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(3 == n_remainder) + { + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter*4; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1 + 4)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2 + 4)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + b01 += 1; //move to next row of B + a10 += p_lda; + + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + D_MR; + ptr_b01_dup = b01; + a10 = ptr_a10_dup + D_MR*p_lda; + ptr_a10_dup = a10; + } + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); + ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(2 == n_remainder) + { + + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + + for(k = 0; k< k_iter*4; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1 + 4)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + + b01 += 1; //move to next row of B + a10 += p_lda; + + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + D_MR; + ptr_b01_dup = b01; + a10 = ptr_a10_dup + D_MR*p_lda; + ptr_a10_dup = a10; + } + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); + ymm6 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(1 == n_remainder) + { + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + + for(k = 0; k< k_iter*4; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + p_lda * 4 + 4)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0 + 4)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + b01 += 1; //move to next row of B + a10 += p_lda; + + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + D_MR; + ptr_b01_dup = b01; + a10 = ptr_a10_dup + D_MR*p_lda; + ptr_a10_dup = a10; + } + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_broadcast_sd((double const *)(&ones)); + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); + ymm5 = _mm256_broadcast_sd((double const *)(&ones)); + ymm6 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + } + ///implement TRSM/// + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); + ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + + //perform mul operation + ymm15 = _mm256_mul_pd(ymm15, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + + //(ROw7): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6 + 7*cs_a)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm15, ymm14); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5 + 7*cs_a)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm15, ymm13); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4 + 7*cs_a)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm15, ymm12); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3 + 7*cs_a)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm15, ymm11); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 7*cs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm15, ymm10); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 7*cs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm15, ymm9); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7*cs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm15, ymm8); + + //perform mul operation + ymm14 = _mm256_mul_pd(ymm14, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + //(ROw6): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5 + 6*cs_a)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm14, ymm13); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4 + 6*cs_a)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm14, ymm12); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3 + 6*cs_a)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm14, ymm11); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 6*cs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm14, ymm10); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 6*cs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm14, ymm9); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6*cs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm14, ymm8); + + //perform mul operation + ymm13 = _mm256_mul_pd(ymm13, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + //(ROw5): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4 + 5*cs_a)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm13, ymm12); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3 + 5*cs_a)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm13, ymm11); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 5*cs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm13, ymm10); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 5*cs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm13, ymm9); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm13, ymm8); + + //perform mul operation + ymm12 = _mm256_mul_pd(ymm12, ymm1); + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw4): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3 + 4*cs_a)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm12, ymm11); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 4*cs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm12, ymm10); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 4*cs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm12, ymm9); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm12, ymm8); + + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw3): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 3*cs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm11, ymm10); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 3*cs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm11, ymm9); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm11, ymm8); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 2*cs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm10, ymm9); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm10, ymm8); + + //perform mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1*cs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm9, ymm8); + + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); + ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); + ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); + ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); + + if(3 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); + } + else if(2 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); + } + else if(1 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); + } + } + }// End of multiples of D_MR blocks in m-dimension + + // Repetative A blocks will be 4*4 + dim_t m_remainder = i + D_MR; + if(m_remainder >= 4) + { + i = m_remainder - 4; + a10 = L + (i + 4)*cs_a + i; //pointer to block of A to be used for GEMM + a11 = L + (i*cs_a) + i; //pointer to block of A to be used for TRSM + + // Do transpose for a10 & store in D_A_pack + double *ptr_a10_dup = D_A_pack; + double *ptr_a11_dup = a11; + dim_t p_lda = 4; // packed leading dimension + for(dim_t x =0;x < m-i-4;x++) + { + ymm0 = _mm256_loadu_pd((double const *)(a10 + x*cs_a)); + _mm256_storeu_pd((double *)(ptr_a10_dup + x*p_lda), ymm0); + } + + ymm4 = _mm256_broadcast_sd((double const *)&ones); + if(!is_unitdiag) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a +1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11+cs_a*2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11+cs_a*3 + 3)); + + //Pick one element each column and create a 4 element vector and store + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); + ymm4 = _mm256_div_pd(ymm4, ymm1); + } + _mm256_storeu_pd((double *)(d11_pack), ymm4); + + //cols + for(j = (n - D_NR); (j + 1) > 0; j -= D_NR) //loop along 'N' dimension + { + a10 = D_A_pack; + a11 = ptr_a11_dup; //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b) + i + 4; //pointer to block of B to be used for GEMM + b11 = B + (j* cs_b) + i; //pointer to block of B to be used for TRSM + + k_iter = (m - i - 4); //number of times GEMM to be performed(in blocks of 4x4) + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + ///GEMM code begins/// + ptr_b01_dup = b01; + ptr_a10_dup = a10; + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); + ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); + + b01 += 1; //move to next row of B + a10 += p_lda; + + if(!((k+1)&0x03)) + { + b01 = ptr_b01_dup + 4; + ptr_b01_dup = b01; + a10 = ptr_a10_dup +4*p_lda; + ptr_a10_dup = a10; + } + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + + ymm16 = _mm256_broadcast_sd((double const *)(&ones)); + + ////unpacklow//// + ymm7 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + + //rearrange low elements + ymm4 = _mm256_permute2f128_pd(ymm7,ymm16,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm6 = _mm256_permute2f128_pd(ymm7,ymm16,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + + //rearrange high elements + ymm5 = _mm256_permute2f128_pd(ymm0,ymm16,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm7 = _mm256_permute2f128_pd(ymm0,ymm16,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); + ymm7 = _mm256_mul_pd(ymm7, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw3): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 3*cs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm11, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm7, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 3*cs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm11, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm7, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm11, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm7, ymm4); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + ymm6 = _mm256_mul_pd(ymm6, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 2*cs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm10, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm6, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm10, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm6, ymm4); + + //perform mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); + ymm5 = _mm256_mul_pd(ymm5, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm9, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm5, ymm4); + + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); + ymm4 = _mm256_mul_pd(ymm4, ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ///unpack high/// + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store B11[1][0-3] + } + + dim_t n_remainder = j + D_NR; + if((n_remainder >= 4)) + { + a10 = D_A_pack; + a11 = L + (i*cs_a) + i; //pointer to block of A to be used for TRSM + b01 = B + ((n_remainder - 4)* cs_b) + i + 4; //pointer to block of B to be used for GEMM + b11 = B + ((n_remainder - 4)* cs_b) + i; //pointer to block of B to be used for TRSM + + k_iter = (m - i - 4); //number of times GEMM to be performed + + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw3): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 3*cs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm11, ymm10); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 3*cs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm11, ymm9); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm11, ymm8); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 2*cs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm10, ymm9); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm10, ymm8); + + //perform mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm9, ymm8); + + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + n_remainder = n_remainder - 4; + } + + if(n_remainder) //implementation fo remaining columns(when 'N' is not a multiple of D_NR)() n = 3 + { + a10 = D_A_pack; + a11 = L + (i*cs_a) + i; + b01 = B + i + 4; + b11 = B + i; + + k_iter = (m - i - 4); + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(3 == n_remainder) + { + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(2 == n_remainder) + { + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(1 == n_remainder) + { + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_broadcast_sd((double const *)(&ones)); + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } + + ///implement TRSM/// + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //perform mul operation + ymm11 = _mm256_mul_pd(ymm11, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw3): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 3*cs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm11, ymm10); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 3*cs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm11, ymm9); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm11, ymm8); + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 2*cs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm10, ymm9); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm10, ymm8); + + //perform mul operation + ymm9 = _mm256_mul_pd(ymm9, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm9, ymm8); + + //perform mul operation + ymm8 = _mm256_mul_pd(ymm8, ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + if(3 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + } + else if(2 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + } + else if(1 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + } + } + m_remainder -= 4; + } + + if(m_remainder) + { + a10 = L + m_remainder*cs_a; + + // Do transpose for a10 & store in D_A_pack + double *ptr_a10_dup = D_A_pack; + if(3 == m_remainder) // Repetative A blocks will be 3*3 + { + dim_t p_lda = 4; // packed leading dimension + for(dim_t x =0;x < m-m_remainder;x++) + { + ymm0 = _mm256_loadu_pd((double const *)(a10 + x*cs_a)); + _mm256_storeu_pd((double *)(ptr_a10_dup + x*p_lda), ymm0); + } + //cols + for(j = (n - D_NR); (j + 1) > 0; j -= D_NR) //loop along 'N' dimension + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + (j* cs_b) + m_remainder; //pointer to block of B to be used for GEMM + b11 = B + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); //number of times GEMM to be performed + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); + ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ///GEMM code ends/// + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][3]) + + dtrsm_AuXB_ref(a11, b11, m_remainder, 6, cs_a, cs_b, is_unitdiag); + } + + dim_t n_remainder = j + D_NR; + if((n_remainder >= 4)) + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + ((n_remainder - 4)* cs_b) + m_remainder; //pointer to block of B to be used for GEMM + b11 = B + ((n_remainder - 4)* cs_b); //pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); //number of times GEMM to be performed + + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + + dtrsm_AuXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); + n_remainder -= 4; + } + if(n_remainder) + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + m_remainder; //pointer to block of B to be used for GEMM + b11 = B; //pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); //number of times GEMM to be performed + + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(3 == n_remainder) + { + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + + dtrsm_AuXB_ref(a11, b11, m_remainder, 3, cs_a, cs_b, is_unitdiag); + } + else if(2 == n_remainder) + { + + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + + dtrsm_AuXB_ref(a11, b11, m_remainder, 2, cs_a, cs_b, is_unitdiag); + } + else if(1 == n_remainder) + { + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + + dtrsm_AuXB_ref(a11, b11, m_remainder, 1, cs_a, cs_b, is_unitdiag); + } + } + } + else if(2 == m_remainder) // Repetative A blocks will be 2*2 + { + dim_t p_lda = 4; // packed leading dimension + for(dim_t x =0;x < m-m_remainder;x++) + { + ymm0 = _mm256_loadu_pd((double const *)(a10 + x*cs_a)); + _mm256_storeu_pd((double *)(ptr_a10_dup + x*p_lda), ymm0); + } + //cols + for(j = (n - D_NR); (j + 1) > 0; j -= D_NR) //loop along 'N' dimension + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + (j* cs_b) + m_remainder; //pointer to block of B to be used for GEMM + b11 = B + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); //number of times GEMM to be performed + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); + ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ///GEMM code ends/// + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0C); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); + + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][3]) + + dtrsm_AuXB_ref(a11, b11, m_remainder, 6, cs_a, cs_b, is_unitdiag); + } + dim_t n_remainder = j + D_NR; + if((n_remainder >= 4)) + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + ((n_remainder - 4)* cs_b) + m_remainder; //pointer to block of B to be used for GEMM + b11 = B + ((n_remainder - 4)* cs_b); //pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0C); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + + dtrsm_AuXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); + n_remainder -= 4; + } + if(n_remainder) + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + m_remainder; //pointer to block of B to be used for GEMM + b11 = B; //pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(3 == n_remainder) + { + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + + dtrsm_AuXB_ref(a11, b11, m_remainder, 3, cs_a, cs_b, is_unitdiag); + } + else if(2 == n_remainder) + { + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + + dtrsm_AuXB_ref(a11, b11, m_remainder, 2, cs_a, cs_b, is_unitdiag); + } + else if(1 == n_remainder) + { + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + + dtrsm_AuXB_ref(a11, b11, m_remainder, 1, cs_a, cs_b, is_unitdiag); + } + } + + } + else if(1 == m_remainder) // Repetative A blocks will be 1*1 + { + dim_t p_lda = 4; // packed leading dimension + for(dim_t x =0;x < m-m_remainder;x++) + { + ymm0 = _mm256_loadu_pd((double const *)(a10 + x*cs_a)); + _mm256_storeu_pd((double *)(ptr_a10_dup + x*p_lda), ymm0); + } + //cols + for(j = (n - D_NR); (j + 1) > 0; j -= D_NR) //loop along 'N' dimension + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + (j* cs_b) + m_remainder; //pointer to block of B to be used for GEMM + b11 = B + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); //number of times GEMM to be performed + + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); + ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ///GEMM code ends/// + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][3]) + + dtrsm_AuXB_ref(a11, b11, m_remainder, 6, cs_a, cs_b, is_unitdiag); + } + dim_t n_remainder = j + D_NR; + if((n_remainder >= 4)) + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + ((n_remainder - 4)* cs_b) + m_remainder; //pointer to block of B to be used for GEMM + b11 = B + ((n_remainder - 4)* cs_b); //pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); //number of times GEMM to be performed + + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + + dtrsm_AuXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); + n_remainder -= 4; + } + if(n_remainder) + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + m_remainder; //pointer to block of B to be used for GEMM + b11 = B; //pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); //number of times GEMM to be performed + + #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); + #endif + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(3 == n_remainder) + { + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + + dtrsm_AuXB_ref(a11, b11, m_remainder, 3, cs_a, cs_b, is_unitdiag); + } + else if(2 == n_remainder) + { + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + + dtrsm_AuXB_ref(a11, b11, m_remainder, 2, cs_a, cs_b, is_unitdiag); + } + else if(1 == n_remainder) + { + ///GEMM code begins/// + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + + b01 += 1; //move to next row of B + a10 += p_lda; + } + + //register to hold alpha + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ///implement TRSM/// + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + dtrsm_AuXB_ref(a11, b11, m_remainder, 1, cs_a, cs_b, is_unitdiag); + } + } + } + } + + if ((required_packing_A == 1) && + bli_mem_is_alloc( &local_mem_buf_A_s )) + { + bli_membrk_release(&rntm, &local_mem_buf_A_s); + } + return BLIS_SUCCESS; +} + +/* TRSM for the case XA = alpha * B *A is upper triangular, non-unit diagonal, no transpose *dimensions: X:mxn A:nxn B: mxn - */ - -/* b11---> a01 ----> + * + * b11---> a01 ----> ***************** *********** *b01*b11* * * * * * * b11 * * * * * **a01 * * a11 @@ -4664,14 +13956,14 @@ b11 * * * * * **a01 * * a11 * */ -static err_t bli_dtrsm_small_XAuB( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ) +static err_t bli_dtrsm_small_XAuB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) { dim_t D_MR = 8; //block dimension along the rows dim_t D_NR = 4; //block dimension along the columns @@ -4685,18 +13977,6 @@ static err_t bli_dtrsm_small_XAuB( dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B -#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME - if((m>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME_ROW_PANEL_M && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME) - || (m>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME_COLUMN_PANEL_N) - ) - return BLIS_NOT_YET_IMPLEMENTED; -#else - if(bli_max(m,n)>D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES && (m/n) < D_BLIS_SMALL_MATRIX_THRES_TRSM_DIM_RATIO) - { - return BLIS_NOT_YET_IMPLEMENTED; - } -#endif - dim_t i, j, k; //loop variablse dim_t k_iter; //determines the number of GEMM operations to be done dim_t cs_b_offset[2]; //pre-calculated strides @@ -4710,8 +13990,8 @@ static err_t bli_dtrsm_small_XAuB( double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks double *ptr_a01_dup; - double f_t[4] __attribute__((aligned(64)));//buffer to store corner column when m_remainder !=0 - double* f_temp; + double f_t[4] __attribute__((aligned(64)));//buffer to store corner column when m_remainder !=0 + double* f_temp; cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2; cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3; @@ -5877,7 +15157,7 @@ static err_t bli_dtrsm_small_XAuB( } if(m_remainder) ///omplementation for remainder rows { - for(j = 0; j+D_NR-1 < n; j += D_NR) //loop along 'N' direction + for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' direction { a01 = L + j*cs_a; //pointer to block of A to be used for GEMM a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM @@ -6366,18 +15646,17 @@ static err_t bli_dtrsm_small_XAuB( (b11 + cs_b * (n_remainder-1))[iter] = f_temp[iter]; } //scalar code for TRSM - dtrsm_small_XAuB(a11, b11, m_remainder, n_remainder, cs_a, cs_b); + dtrsm_XAuB_ref(a11, b11, m_remainder, n_remainder, cs_a, cs_b); } } return BLIS_SUCCESS; } -/*implements TRSM for the case XA = alpha * B +/* TRSM for the case XA = alpha * B *A is upper triangular, unit-diagonal, no transpose *dimensions: X:mxn A:nxn B: mxn - */ - -/* b11---> a01 ----> + * + * b11---> a01 ----> ***************** *********** *b01*b11* * * * * * * b11 * * * * * **a01 * * a11 @@ -6391,14 +15670,14 @@ b11 * * * * * **a01 * * a11 * */ -static err_t bli_dtrsm_small_XAuB_unitDiag( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ) +static err_t bli_dtrsm_XAuB_unitDiag_ref +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) { dim_t D_MR = 8; //block dimension along the rows dim_t D_NR = 4; //block dimension along the columns @@ -6412,18 +15691,6 @@ static err_t bli_dtrsm_small_XAuB_unitDiag( dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B -#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME - if((m>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME_ROW_PANEL_M && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME) - || (m>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME_COLUMN_PANEL_N) - ) - return BLIS_NOT_YET_IMPLEMENTED; -#else - if(bli_max(m,n)>D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES && (m/n) < D_BLIS_SMALL_MATRIX_THRES_TRSM_DIM_RATIO) - { - return BLIS_NOT_YET_IMPLEMENTED; - } -#endif - dim_t i, j, k; //loop variablse dim_t k_iter; //determines the number of GEMM operations to be done dim_t cs_b_offset[2]; //pre-calculated strides @@ -7367,7 +16634,7 @@ static err_t bli_dtrsm_small_XAuB_unitDiag( } if(m_remainder) ///omplementation for remainder rows { - for(j = 0; j+D_NR-1 < n; j += D_NR) //loop along 'N' direction + for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' direction { a01 = L + j*cs_a; //pointer to block of A to be used for GEMM a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM @@ -7819,19 +17086,17 @@ static err_t bli_dtrsm_small_XAuB_unitDiag( (b11 + cs_b * (n_remainder-1))[iter] = f_temp[iter]; } //scalar code for TRSM - dtrsm_small_XAuB_unitDiag(a11, b11, m_remainder, n_remainder, cs_a, cs_b); + dtrsm_XAuB_unitDiag_ref(a11, b11, m_remainder, n_remainder, cs_a, cs_b); } } return BLIS_SUCCESS; } - -/*implements TRSM for the case XA = alpha * B +/* TRSM for the case XA = alpha * B *A is lower triangular, non-unit diagonal, transpose *dimensions: X:mxn A:nxn B: mxn - */ - -/* b11---> a01 ----> + * + * b11---> a01 ----> ***************** *********** *b01*b11* * * * * * * b11 * * * * * **a01 * * a11 @@ -7843,16 +17108,15 @@ b11 * * * * * **a01 * * a11 * * * * * * * ***************** * * * - */ -static err_t bli_dtrsm_small_XAltB( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ) +static err_t bli_dtrsm_small_XAltB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) { dim_t D_MR = 8; //block dimension along the rows dim_t D_NR = 4; //block dimension along the columns @@ -7866,21 +17130,6 @@ static err_t bli_dtrsm_small_XAltB( dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B -#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME - if((m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_ROW_PANEL_M && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_ROW_PANEL_N) - || (m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_SQUARE_M && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_SQUARE_N) - || (m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_COLUMN_PANEL_N && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME) - || (m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME) - || (m > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_COLUMN_PANEL_N) - ) - return BLIS_NOT_YET_IMPLEMENTED; -#else - if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES) - { - return BLIS_NOT_YET_IMPLEMENTED; - } -#endif - dim_t i, j, k; //loop variablse dim_t k_iter; //determines the number of GEMM operations to be done dim_t cs_b_offset[2]; //pre-calculated strides @@ -9074,7 +18323,7 @@ static err_t bli_dtrsm_small_XAltB( } if(m_remainder) ///omplementation for remainder rows { - for(j = 0; j+D_NR-1 < n; j += D_NR) //loop along 'N' direction + for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' direction { a01 = L + j; //pointer to block of A to be used for GEMM a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM @@ -9574,18 +18823,17 @@ static err_t bli_dtrsm_small_XAltB( (b11 + cs_b * (n_remainder-1))[iter] = f_temp[iter]; } //scalar code for TRSM - dtrsm_small_XAltB(a11, b11, m_remainder, n_remainder, cs_a, cs_b); + dtrsm_XAltB_ref(a11, b11, m_remainder, n_remainder, cs_a, cs_b); } } return BLIS_SUCCESS; } -/*implements TRSM for the case XA = alpha * B +/* TRSM for the case XA = alpha * B *A is lower triangular, unit-diagonal, transpose *dimensions: X:mxn A:nxn B: mxn - */ - -/* b11---> a01 ----> + * + * b11---> a01 ----> ***************** *********** *b01*b11* * * * * * * b11 * * * * * **a01 * * a11 @@ -9599,14 +18847,14 @@ b11 * * * * * **a01 * * a11 * */ -static err_t bli_dtrsm_small_XAltB_unitDiag( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ) +static err_t bli_dtrsm_small_XAltB_unitDiag +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) { dim_t D_MR = 8; //block dimension along the rows dim_t D_NR = 4; //block dimension along the columns @@ -9620,21 +18868,6 @@ static err_t bli_dtrsm_small_XAltB_unitDiag( dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B -#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME - if((m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_ROW_PANEL_M && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_ROW_PANEL_N) - || (m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_SQUARE_M && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_SQUARE_N) - || (m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_COLUMN_PANEL_N && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME) - || (m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME) - || (m > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_COLUMN_PANEL_N) - ) - return BLIS_NOT_YET_IMPLEMENTED; -#else - if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES) - { - return BLIS_NOT_YET_IMPLEMENTED; - } -#endif - dim_t i, j, k; //loop variablse dim_t k_iter; //determines the number of GEMM operations to be done dim_t cs_b_offset[2]; //pre-calculated strides @@ -10583,7 +19816,7 @@ static err_t bli_dtrsm_small_XAltB_unitDiag( } if(m_remainder) ///omplementation for remainder rows { - for(j = 0; j+D_NR-1 < n; j += D_NR) //loop along 'N' direction + for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' direction { a01 = L + j; //pointer to block of A to be used for GEMM a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM @@ -11043,19 +20276,17 @@ static err_t bli_dtrsm_small_XAltB_unitDiag( (b11 + cs_b * (n_remainder-1))[iter] = f_temp[iter]; } //scalar code for TRSM - dtrsm_small_XAltB_unitDiag(a11, b11, m_remainder, n_remainder, cs_a, cs_b); + dtrsm_XAltB_unitDiag_ref(a11, b11, m_remainder, n_remainder, cs_a, cs_b); } } return BLIS_SUCCESS; } - -/*implements TRSM for the case XA = alpha * B +/* TRSM for the case XA = alpha * B *A is lower triangular, non-unit diagonal, no transpose *dimensions: X:mxn A:nxn B: mxn - */ - -/* <---b11 <---a11 + * + * <---b11 <---a11 ***************** * *b01*b11* * * * * ^ * * * * * ^ * * @@ -11068,14 +20299,14 @@ b10 ***************** ************* ***************** ******************* */ -static err_t bli_dtrsm_small_XAlB( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ) +static err_t bli_dtrsm_small_XAlB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) { dim_t D_MR = 8; //block dimension along the rows dim_t D_NR = 4; //block dimension along the columns @@ -11089,19 +20320,6 @@ static err_t bli_dtrsm_small_XAlB( dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B - -#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME - if((m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME) - ||(m > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME_COL_PANEL_N) - ) - return BLIS_NOT_YET_IMPLEMENTED; -#else - if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES) - { - return BLIS_NOT_YET_IMPLEMENTED; - } -#endif - dim_t i, j, k; //loop variablse dim_t k_iter; //determines the number of GEMM operations to be done dim_t cs_b_offset[2]; //pre-calculated strides @@ -12294,17 +21512,16 @@ static err_t bli_dtrsm_small_XAlB( // if(i < 0) i = 0; if(m_remainder) ///implementation for remainder rows { - dtrsm_small_XAlB(L, B, AlphaVal, m_remainder, n, cs_a, cs_b); + dtrsm_XAlB_ref(L, B, AlphaVal, m_remainder, n, cs_a, cs_b); } return BLIS_SUCCESS; } -/*implements TRSM for the case XA = alpha * B +/* TRSM for the case XA = alpha * B *A is lower triangular, unit-diagonal, no transpose *dimensions: X:mxn A:nxn B: mxn - */ - -/* <---b11 <---a11 + * + * <---b11 <---a11 ***************** * *b01*b11* * * * * ^ * * * * * ^ * * @@ -12315,16 +21532,15 @@ b10 ***************** ************* * * * * * * * * * * * * * * * * * * ***************** ******************* - */ -static err_t bli_dtrsm_small_XAlB_unitDiag( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ) +static err_t bli_dtrsm_small_XAlB_unitDiag +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) { dim_t D_MR = 8; //block dimension along the rows dim_t D_NR = 4; //block dimension along the columns @@ -12338,19 +21554,6 @@ static err_t bli_dtrsm_small_XAlB_unitDiag( dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B - -#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME - if((m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME) - ||(m > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME_COL_PANEL_N) - ) - return BLIS_NOT_YET_IMPLEMENTED; -#else - if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES) - { - return BLIS_NOT_YET_IMPLEMENTED; - } -#endif - dim_t i, j, k; //loop variablse dim_t k_iter; //determines the number of GEMM operations to be done dim_t cs_b_offset[2]; //pre-calculated strides @@ -13288,18 +22491,16 @@ static err_t bli_dtrsm_small_XAlB_unitDiag( // if(i < 0) i = 0; if(m_remainder) ///implementation for remainder rows { - dtrsm_small_XAlB_unitDiag(L, B, AlphaVal, m_remainder, n, cs_a, cs_b); + dtrsm_XAlB_unitDiag_ref(L, B, AlphaVal, m_remainder, n, cs_a, cs_b); } return BLIS_SUCCESS; } - -/*implements TRSM for the case XA = alpha * B +/* TRSM for the case XA = alpha * B *A is lower triangular, non-unit diagonal, no transpose *dimensions: X:mxn A:nxn B: mxn - */ - -/* <---b11 <---a11 + * + * <---b11 <---a11 ***************** * *b01*b11* * * * * ^ * * * * * ^ * * @@ -13312,14 +22513,14 @@ b10 ***************** ************* ***************** ******************* */ -static err_t bli_dtrsm_small_XAutB( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ) +static err_t bli_dtrsm_small_XAutB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) { dim_t D_MR = 8; //block dimension along the rows dim_t D_NR = 4; //block dimension along the columns @@ -13333,18 +22534,6 @@ static err_t bli_dtrsm_small_XAutB( dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B -#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME - if((m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUTB_ROME) - ||(m > D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME_COL_PANEL_N) - ) - return BLIS_NOT_YET_IMPLEMENTED; -#else - if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES) - { - return BLIS_NOT_YET_IMPLEMENTED; - } -#endif - dim_t i, j, k; //loop variablse dim_t k_iter; //determines the number of GEMM operations to be done dim_t cs_b_offset[2]; //pre-calculated strides @@ -14558,17 +23747,16 @@ static err_t bli_dtrsm_small_XAutB( } if(m_remainder) ///implementation for remainder rows { - dtrsm_small_XAutB(L, B, AlphaVal, m_remainder, n, cs_a, cs_b); + dtrsm_XAutB_ref(L, B, AlphaVal, m_remainder, n, cs_a, cs_b); } return BLIS_SUCCESS; } -/*implements TRSM for the case XA = alpha * B +/* TRSM for the case XA = alpha * B *A is lower triangular, unit-diagonal, no transpose *dimensions: X:mxn A:nxn B: mxn - */ - -/* <---b11 <---a11 + * + * <---b11 <---a11 ***************** * *b01*b11* * * * * ^ * * * * * ^ * * @@ -14581,14 +23769,14 @@ b10 ***************** ************* ***************** ******************* */ -static err_t bli_dtrsm_small_XAutB_unitDiag( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ) +static err_t bli_dtrsm_small_XAutB_unitDiag +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) { dim_t D_MR = 8; //block dimension along the rows dim_t D_NR = 4; //block dimension along the columns @@ -14602,18 +23790,6 @@ static err_t bli_dtrsm_small_XAutB_unitDiag( dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B -#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME - if((m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUTB_ROME) - ||(m > D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME_COL_PANEL_N) - ) - return BLIS_NOT_YET_IMPLEMENTED; -#else - if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES) - { - return BLIS_NOT_YET_IMPLEMENTED; - } -#endif - dim_t i, j, k; //loop variablse dim_t k_iter; //determines the number of GEMM operations to be done dim_t cs_b_offset[2]; //pre-calculated strides @@ -15566,12255 +24742,9 @@ static err_t bli_dtrsm_small_XAutB_unitDiag( } if(m_remainder) ///implementation for remainder rows { - dtrsm_small_XAutB_unitDiag(L, B, AlphaVal, m_remainder, n, cs_a, cs_b); + dtrsm_XAutB_unitDiag_ref(L, B, AlphaVal, m_remainder, n, cs_a, cs_b); } return BLIS_SUCCESS; } - -/* - * AX = Alpha*B, Single precision, A: lower triangular - * This kernel implementation supports matrices A and B such that m is equal to BLI_AlXB_M_SP and n is mutiple of 8 - */ - -static err_t bli_strsm_small_AlXB ( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ) -{ - obj_t alpha, beta; // gemm parameters - obj_t Ga, Gb, Gc; // for GEMM - int m = bli_obj_length(b); // number of rows of matrix B - int n = bli_obj_width(b); // number of columns of matrix B - - int lda = bli_obj_col_stride(a); // column stride of A - int ldb = bli_obj_col_stride(b); // column stride of B - - int rsa = bli_obj_row_stride(a); // row stride of A - int rsb = bli_obj_row_stride(b); // row stride of B - - int i = 0; - int j; - int blk_size = 8; - int isUnitDiag = bli_obj_has_unit_diag(a); - - float alphaVal; - float* restrict L = a->buffer; - float* restrict B = b->buffer; - - if (m != BLI_AlXB_M_SP || (n&7) != 0) - { - return BLIS_NOT_YET_IMPLEMENTED; - } - if ( (m*(m + n)) > BLIS_SMALL_MATRIX_THRES_TRSM ) - { - return BLIS_NOT_YET_IMPLEMENTED; - } - - alphaVal = *((float *)bli_obj_buffer_for_const(BLIS_FLOAT, AlphaObj)); - - /* Small _GEMM preparation code */ - bli_obj_create( BLIS_FLOAT, 1, 1, 0, 0, &alpha ); - bli_obj_create( BLIS_FLOAT, 1, 1, 0, 0, &beta ); - - /* B = B - A*B */ - bli_setsc( -(1.0), 0.0, &alpha ); - bli_setsc( (1.0), 0.0, &beta ); - - - bli_obj_create_with_attached_buffer( BLIS_FLOAT, blk_size, blk_size, a->buffer, rsa, lda, &Ga); - bli_obj_create_with_attached_buffer( BLIS_FLOAT, blk_size, n, b->buffer, rsb, ldb, &Gb); - bli_obj_create_with_attached_buffer( BLIS_FLOAT, blk_size, n, b->buffer, rsb, ldb, &Gc); - - bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &Ga ); - bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &Gb ); - bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &Gc ); - - //first block of trsm - Gb.buffer = (void*)(B + i); - - //trsm of first 8xn block - if (alphaVal != 1) - { - if (isUnitDiag == 0) - { - blis_strsm_microkernel_alpha((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb, alphaVal); - fp_blis_strsm_microkernel = blis_strsm_microkernel; - } - else - { - blis_strsm_microkernel_alpha_unitDiag((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb, alphaVal); - fp_blis_strsm_microkernel = blis_strsm_microkernel_unitDiag; - } - bli_setsc( alphaVal, 0.0, &beta ); - } - else - { - if (isUnitDiag == 0) - { - blis_strsm_microkernel((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb); - fp_blis_strsm_microkernel = blis_strsm_microkernel; - } - else - { - blis_strsm_microkernel_unitDiag((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb); - fp_blis_strsm_microkernel = blis_strsm_microkernel_unitDiag; - } - } - - //gemm update - for (j = i + blk_size; j < m; j += blk_size) // for rows upto multiple of BLOCK_HEIGHT - { - Ga.buffer = (void*)(L + j + i*lda); - Gc.buffer = (void*)(B + j); - - bli_gemm_small(&alpha, &Ga, &Gb, &beta, &Gc, cntx, cntl ); // Gc = beta*Gc + alpha*Ga *Gb - } - - //trsm of remaining blocks - for (i = blk_size; i < m; i += blk_size) - { - Gb.buffer = (void*)(B + i); - - fp_blis_strsm_microkernel((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb); - - for (j = i + blk_size; j < m; j += blk_size) // for rows upto multiple of BLOCK_HEIGHT - { - Ga.buffer = (void*)(L + j + i*lda); - Gc.buffer = (void*)(B + j); - - bli_gemm_small(&alpha, &Ga, &Gb, &beta, &Gc, cntx, cntl ); // Gc = beta*Gc + alpha*Ga *Gb - } - - } // End of for loop - i - - return BLIS_SUCCESS; -} - - - -/* - * XA' = Alpha*B, Single precision, A: lower triangular - * This kernel implementation supports matrices A and B such that - * m and n are multiples of 8 and n is less than or equal to BLI_XAltB_N_SP - */ -static err_t bli_strsm_small_XAltB( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ) -{ - int m = bli_obj_length(a); // number of rows of matrix B - int n = bli_obj_length(b); // number of columns of matrix B - - int lda = bli_obj_col_stride(a); // column stride of A - int ldb = bli_obj_col_stride(b); // column stride of B - - int rsa = bli_obj_row_stride(a); // row stride of A - int rsb = bli_obj_row_stride(b); // row stride of B - - int i = 0; - int isUnitDiag = bli_obj_has_unit_diag(a); - - float alphaVal; - float *L = a->buffer; - float *B = b->buffer; - - if ((m&7) != 0 || (n&7) != 0) - { - return BLIS_NOT_YET_IMPLEMENTED; - } - if ( n > BLI_XAltB_N_SP || (m*(m + n)) > BLIS_SMALL_MATRIX_THRES_TRSM ) - { - return BLIS_NOT_YET_IMPLEMENTED; - } - - alphaVal = *((float *)bli_obj_buffer_for_const(BLIS_FLOAT, AlphaObj)); - - if (alphaVal != 1) - { - if (isUnitDiag == 0) - { - trsm_XAtB_block_allSmallSizedMatrices_alpha((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb, alphaVal); - } - else - { - trsm_XAtB_block_allSmallSizedMatrices_alpha_unitDiag((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb, alphaVal); - } - } - else - { - if (isUnitDiag == 0) - { - trsm_XAtB_block_allSmallSizedMatrices((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb); - } - else - { - trsm_XAtB_block_allSmallSizedMatrices_unitDiag((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb); - } - } - return BLIS_SUCCESS; -} - -/* - * A'X = Alpha*B, Single precision, A: upper triangular - * This kernel implementation supports matrices A and B such that - * m and n are multiples of 8, m is less than or equal to BLI_AutXB_M_SP and n is less than or equal to BLI_AutXB_N_SP - */ -static err_t bli_strsm_small_AutXB( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ) -{ - int m = bli_obj_width(a); // number of rows of matrix A (since At, so width is taken) - int n = bli_obj_width(b); // number of columns of matrix B - - int lda = bli_obj_col_stride(a); // column stride of A - int ldb = bli_obj_col_stride(b); // column stride of B - - int rsa = bli_obj_row_stride(a); // row stride of A - int rsb = bli_obj_row_stride(b); // row stride of B - - int i = 0; - int isUnitDiag = bli_obj_has_unit_diag(a); - - float alphaVal; - float *L = a->buffer; - float *B = b->buffer; - - if ((m&7) != 0 || (n&7) != 0) - { - return BLIS_NOT_YET_IMPLEMENTED; - } - if ( m > BLI_AutXB_M_SP || n > BLI_AutXB_N_SP || (m*(m + n)) > BLIS_SMALL_MATRIX_THRES_TRSM ) - { - return BLIS_NOT_YET_IMPLEMENTED; - } - - alphaVal = *((float *)bli_obj_buffer_for_const(BLIS_FLOAT, AlphaObj)); - - if (alphaVal != 1) - { - if (isUnitDiag == 0) - { - trsm_AutXB_block_allSmallSizedMatrices_alpha((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb, alphaVal); - } - else - { - trsm_AutXB_block_allSmallSizedMatrices_alpha_unitDiag((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb, alphaVal); - } - } - else - { - if (isUnitDiag == 0) - { - trsm_AutXB_block_allSmallSizedMatrices((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb); - } - else - { - trsm_AutXB_block_allSmallSizedMatrices_unitDiag((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb); - } - } - return BLIS_SUCCESS; -} - -///////////////////////////// AX=B /////////////////////////////// -static void blis_strsm_microkernel_alpha(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alphaVal) -{ - float ones = 1.0; - int j; - int cs_b_offset[6]; - //int row2, row4, row6; - float *ptr_b_dup; - - //70 number of ymm(256 bits) registers used - __m256 mat_b_col[8]; - __m256 mat_b_rearr[8]; - __m256 mat_a_cols[8]; - __m256 mat_a_cols_rearr[36]; - __m256 mat_a_diag_inv[8]; - __m256 reciprocal_diags; - __m256 alphaReg; - - cs_b_offset[0] = (cs_b << 1); - cs_b_offset[1] = cs_b + cs_b_offset[0]; - cs_b_offset[2] = (cs_b << 2); - cs_b_offset[3] = cs_b + cs_b_offset[2]; - cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; - cs_b_offset[5] = cs_b + cs_b_offset[4]; - - //reciprocal_diags = _mm256_loadu_ps((float const *)ones); - reciprocal_diags = _mm256_broadcast_ss((float const *)&ones); - alphaReg = _mm256_broadcast_ss((float const *)&alphaVal); - - // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // - - //read first set of 16x8 block of B into registers, where 16 is the blk_height and 8 is the blk_width for B - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b); - //_mm_prefetch((char*)(ptr_l + 0), _MM_HINT_T0); - //row2 = (cs_l << 1); - //row4 = (cs_l << 2); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b))); - //_mm_prefetch((char*)(ptr_l + cs_l), _MM_HINT_T0); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0])); - //_mm_prefetch((char*)(ptr_l + row2), _MM_HINT_T0); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1])); - //_mm_prefetch((char*)(ptr_l + row2 + cs_l), _MM_HINT_T0); - //row6 = row2 + row4; - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2])); - //_mm_prefetch((char*)(ptr_l + row4), _MM_HINT_T0); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3])); - //_mm_prefetch((char*)(ptr_l + row4 + cs_l), _MM_HINT_T0); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4])); - //_mm_prefetch((char*)(ptr_l + row6), _MM_HINT_T0); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5])); - //_mm_prefetch((char*)(ptr_l + row6 + cs_l), _MM_HINT_T0); - - //reciprocal_diags = _mm256_loadu_ps((float const *)ones); - - //read first set of 16x16 block of L, where 16 is the blk_height and 16 is the blk_width for L - /*mat_a_cols[0] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[1] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[2] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[3] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[4] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[5] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[6] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[7] = _mm256_loadu_ps((float const *)ptr_l);*/ - - //Shuffle to rearrange/transpose 16x16 block of L into contiguous row-wise registers - //tmpRegs[0] = _mm256_castps256_ps128(mat_a_cols[0]); //zero latency, no instruction added actually. - //mat_a_cols_rearr[0] = _mm256_broadcastss_ps(tmpRegs[0]); - //1st col - mat_a_cols_rearr[0] = _mm256_broadcast_ss((float const *)(ptr_l+0)); - mat_a_cols_rearr[1] = _mm256_broadcast_ss((float const *)(ptr_l+1)); - mat_a_cols_rearr[3] = _mm256_broadcast_ss((float const *)(ptr_l+2)); - mat_a_cols_rearr[6] = _mm256_broadcast_ss((float const *)(ptr_l+3)); - mat_a_cols_rearr[10] = _mm256_broadcast_ss((float const *)(ptr_l+4)); - mat_a_cols_rearr[15] = _mm256_broadcast_ss((float const *)(ptr_l+5)); - mat_a_cols_rearr[21] = _mm256_broadcast_ss((float const *)(ptr_l+6)); - mat_a_cols_rearr[28] = _mm256_broadcast_ss((float const *)(ptr_l+7)); - //2nd col - ptr_l += cs_l; - mat_a_cols_rearr[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_cols_rearr[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_cols_rearr[7] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_cols_rearr[11] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_cols_rearr[16] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[22] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[29] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //3rd col - ptr_l += cs_l; - mat_a_cols_rearr[5] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_cols_rearr[8] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_cols_rearr[12] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_cols_rearr[17] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[23] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[30] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //4rth col - ptr_l += cs_l; - mat_a_cols_rearr[9] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_cols_rearr[13] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_cols_rearr[18] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[24] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[31] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //5th col - ptr_l += cs_l; - mat_a_cols_rearr[14] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_cols_rearr[19] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[25] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[32] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //6th col - ptr_l += cs_l; - mat_a_cols_rearr[20] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[26] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[33] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //7th col - ptr_l += cs_l; - mat_a_cols_rearr[27] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[34] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //7th col - ptr_l += cs_l; - mat_a_cols_rearr[35] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - - numCols_b -= 8; // blk_width = 8 - - //compute reciprocals of L(i,i) and broadcast in registers - mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[2]); - mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_cols_rearr[5], mat_a_cols_rearr[9]); - mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_cols_rearr[14], mat_a_cols_rearr[20]); - mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_cols_rearr[27], mat_a_cols_rearr[35]); - - //mat_a_diag_inv[1] = _mm256_permute_ps(mat_a_diag_inv[1], 0x55); - //mat_a_diag_inv[3] = _mm256_permute_ps(mat_a_diag_inv[3], 0x55); - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC); - mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC); - mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0x20); - - //reciprocal of diagnol elements - reciprocal_diags = _mm256_div_ps(reciprocal_diags, mat_a_diag_inv[0]); - - //Start loop for cols of B to be processed in size of blk_width - for (j = 0; j < numCols_b; j += 8) - { - ptr_b_dup = ptr_b; - - /*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/ - - ////unpacklow//// - mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); - mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); - mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); - mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg); - mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg); - mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg); - mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg); - - ////unpackhigh//// - mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); - mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); - mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); - mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - - //extract diag a00 from a - mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00); - mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); - - //Merge rearranged high elements into complete rows - mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg); - mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg); - mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg); - mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg); - - //extract diag a11 from a - mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55); - mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); - - //extract diag a22 from a - mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA); - mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); - - //extract diag a33 from a - mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF); - mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); - - //extract diag a44 from a - mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00); - mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); - - //extract diag a55 from a - mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55); - mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); - - //extract diag a66 from a - mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA); - mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); - - //extract diag a77 from a - mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF); - mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); - - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44); - mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE); - mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44); - mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE); -#else - mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E); - mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E); - mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC); - mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33); - mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC); - mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20); - mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31); - mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20); - mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - //Read next set of B columns - ptr_b += (cs_b + cs_b_offset[5]); - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b))); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0])); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1])); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2])); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3])); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4])); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5])); - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]); - - //end loop of cols - } - - //Last block trsm processing - ptr_b_dup = ptr_b; - - /*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/ - - ////unpacklow//// - mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); - mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); - mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); - mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg); - mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg); - mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg); - mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg); - - ////unpackhigh//// - mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); - mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); - mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); - mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - - //extract diag a00 from a - mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00); - mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); - - //Merge rearranged high elements into complete rows - mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg); - mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg); - mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg); - mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg); - - //extract diag a11 from a - mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55); - mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); - - //extract diag a22 from a - mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA); - mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); - - //extract diag a33 from a - mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF); - mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); - - //extract diag a44 from a - mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00); - mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); - - //extract diag a55 from a - mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55); - mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); - - //extract diag a66 from a - mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA); - mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); - - //extract diag a77 from a - mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF); - mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); - - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44); - mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE); - mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44); - mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE); -#else - mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E); - mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E); - mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC); - mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33); - mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC); - mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20); - mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31); - mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20); - mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]); - - //end loop of cols -} - -static void blis_strsm_microkernel_alpha_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alphaVal) -{ - //float ones = 1.0; - int j; - int cs_b_offset[6]; - //int row2, row4, row6; - float *ptr_b_dup; - - //70 number of ymm(256 bits) registers used - __m256 mat_b_col[8]; - __m256 mat_b_rearr[8]; - __m256 mat_a_cols[8]; - __m256 mat_a_cols_rearr[36]; - //__m256 mat_a_diag_inv[8]; - //__m256 reciprocal_diags; - __m256 alphaReg; - - cs_b_offset[0] = (cs_b << 1); - cs_b_offset[1] = cs_b + cs_b_offset[0]; - cs_b_offset[2] = (cs_b << 2); - cs_b_offset[3] = cs_b + cs_b_offset[2]; - cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; - cs_b_offset[5] = cs_b + cs_b_offset[4]; - - //reciprocal_diags = _mm256_loadu_ps((float const *)ones); - //reciprocal_diags = _mm256_broadcast_ss((float const *)&ones); - alphaReg = _mm256_broadcast_ss((float const *)&alphaVal); - - // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // - - //read first set of 16x8 block of B into registers, where 16 is the blk_height and 8 is the blk_width for B - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b); - //_mm_prefetch((char*)(ptr_l + 0), _MM_HINT_T0); - //row2 = (cs_l << 1); - //row4 = (cs_l << 2); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b))); - //_mm_prefetch((char*)(ptr_l + cs_l), _MM_HINT_T0); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0])); - //_mm_prefetch((char*)(ptr_l + row2), _MM_HINT_T0); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1])); - //_mm_prefetch((char*)(ptr_l + row2 + cs_l), _MM_HINT_T0); - //row6 = row2 + row4; - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2])); - //_mm_prefetch((char*)(ptr_l + row4), _MM_HINT_T0); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3])); - //_mm_prefetch((char*)(ptr_l + row4 + cs_l), _MM_HINT_T0); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4])); - //_mm_prefetch((char*)(ptr_l + row6), _MM_HINT_T0); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5])); - //_mm_prefetch((char*)(ptr_l + row6 + cs_l), _MM_HINT_T0); - - //reciprocal_diags = _mm256_loadu_ps((float const *)ones); - - //read first set of 16x16 block of L, where 16 is the blk_height and 16 is the blk_width for L - /*mat_a_cols[0] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[1] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[2] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[3] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[4] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[5] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[6] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[7] = _mm256_loadu_ps((float const *)ptr_l);*/ - - //Shuffle to rearrange/transpose 16x16 block of L into contiguous row-wise registers - //tmpRegs[0] = _mm256_castps256_ps128(mat_a_cols[0]); //zero latency, no instruction added actually. - //mat_a_cols_rearr[0] = _mm256_broadcastss_ps(tmpRegs[0]); - //1st col - mat_a_cols_rearr[0] = _mm256_broadcast_ss((float const *)(ptr_l+0)); - mat_a_cols_rearr[1] = _mm256_broadcast_ss((float const *)(ptr_l+1)); - mat_a_cols_rearr[3] = _mm256_broadcast_ss((float const *)(ptr_l+2)); - mat_a_cols_rearr[6] = _mm256_broadcast_ss((float const *)(ptr_l+3)); - mat_a_cols_rearr[10] = _mm256_broadcast_ss((float const *)(ptr_l+4)); - mat_a_cols_rearr[15] = _mm256_broadcast_ss((float const *)(ptr_l+5)); - mat_a_cols_rearr[21] = _mm256_broadcast_ss((float const *)(ptr_l+6)); - mat_a_cols_rearr[28] = _mm256_broadcast_ss((float const *)(ptr_l+7)); - //2nd col - ptr_l += cs_l; - mat_a_cols_rearr[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_cols_rearr[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_cols_rearr[7] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_cols_rearr[11] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_cols_rearr[16] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[22] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[29] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //3rd col - ptr_l += cs_l; - mat_a_cols_rearr[5] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_cols_rearr[8] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_cols_rearr[12] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_cols_rearr[17] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[23] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[30] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //4rth col - ptr_l += cs_l; - mat_a_cols_rearr[9] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_cols_rearr[13] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_cols_rearr[18] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[24] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[31] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //5th col - ptr_l += cs_l; - mat_a_cols_rearr[14] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_cols_rearr[19] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[25] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[32] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //6th col - ptr_l += cs_l; - mat_a_cols_rearr[20] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[26] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[33] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //7th col - ptr_l += cs_l; - mat_a_cols_rearr[27] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[34] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //8th col - //ptr_l += cs_l; - //mat_a_cols_rearr[35] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - - numCols_b -= 8; // blk_width = 8 - - //compute reciprocals of L(i,i) and broadcast in registers - //mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[2]); - //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_cols_rearr[5], mat_a_cols_rearr[9]); - //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_cols_rearr[14], mat_a_cols_rearr[20]); - //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_cols_rearr[27], mat_a_cols_rearr[35]); - - //mat_a_diag_inv[1] = _mm256_permute_ps(mat_a_diag_inv[1], 0x55); - //mat_a_diag_inv[3] = _mm256_permute_ps(mat_a_diag_inv[3], 0x55); - //mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC); - //mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC); - //mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0x20); - - //reciprocal of diagnol elements - //reciprocal_diags = _mm256_div_ps(reciprocal_diags, mat_a_diag_inv[0]); - - //Start loop for cols of B to be processed in size of blk_width - for (j = 0; j < numCols_b; j += 8) - { - ptr_b_dup = ptr_b; - - /*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/ - - ////unpacklow//// - mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); - mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); - mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); - mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg); - mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg); - mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg); - mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg); - - ////unpackhigh//// - mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); - mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); - mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); - mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - - //extract diag a00 from a - //mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00); - //mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - //mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); - - //Merge rearranged high elements into complete rows - mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg); - mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg); - mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg); - mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg); - - //extract diag a11 from a - //mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55); - //mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - //mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); - - //extract diag a22 from a - //mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA); - //mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - //mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); - - //extract diag a33 from a - //mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF); - //mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - //mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); - - //extract diag a44 from a - //mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00); - //mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - //mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); - - //extract diag a55 from a - //mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55); - //mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - //mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); - - //extract diag a66 from a - //mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA); - //mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - //mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); - - //extract diag a77 from a - //mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF); - //mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - //mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); - - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44); - mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE); - mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44); - mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE); -#else - mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E); - mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E); - mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC); - mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33); - mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC); - mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20); - mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31); - mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20); - mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - //Read next set of B columns - ptr_b += (cs_b + cs_b_offset[5]); - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b))); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0])); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1])); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2])); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3])); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4])); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5])); - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]); - - //end loop of cols - } - - //Last block trsm processing - ptr_b_dup = ptr_b; - - /*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/ - - ////unpacklow//// - mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); - mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); - mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); - mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg); - mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg); - mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg); - mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg); - - ////unpackhigh//// - mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); - mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); - mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); - mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - - //extract diag a00 from a - //mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00); - //mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - //mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); - - //Merge rearranged high elements into complete rows - mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg); - mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg); - mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg); - mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg); - - //extract diag a11 from a - //mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55); - //mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - //mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); - - //extract diag a22 from a - //mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA); - //mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - //mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); - - //extract diag a33 from a - //mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF); - //mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - //mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); - - //extract diag a44 from a - //mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00); - //mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - //mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); - - //extract diag a55 from a - //mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55); - //mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - //mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); - - //extract diag a66 from a - //mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA); - //mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - //mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); - - //extract diag a77 from a - //mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF); - //mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - //mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); - - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44); - mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE); - mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44); - mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE); -#else - mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E); - mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E); - mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC); - mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33); - mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC); - mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20); - mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31); - mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20); - mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]); - - //end loop of cols -} - -static void blis_strsm_microkernel_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b) -{ - //float ones = 1.0; - int j; - int cs_b_offset[6]; - //int row2, row4, row6; - float *ptr_b_dup; - - //70 number of ymm(256 bits) registers used - __m256 mat_b_col[8]; - __m256 mat_b_rearr[8]; - __m256 mat_a_cols[8]; - __m256 mat_a_cols_rearr[36]; - //__m256 mat_a_diag_inv[8]; - //__m256 reciprocal_diags; - - cs_b_offset[0] = (cs_b << 1); - cs_b_offset[1] = cs_b + cs_b_offset[0]; - cs_b_offset[2] = (cs_b << 2); - cs_b_offset[3] = cs_b + cs_b_offset[2]; - cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; - cs_b_offset[5] = cs_b + cs_b_offset[4]; - - //reciprocal_diags = _mm256_loadu_ps((float const *)ones); - //reciprocal_diags = _mm256_broadcast_ss((float const *)&ones); - - // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // - - //read first set of 16x8 block of B into registers, where 16 is the blk_height and 8 is the blk_width for B - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b); - //_mm_prefetch((char*)(ptr_l + 0), _MM_HINT_T0); - //row2 = (cs_l << 1); - //row4 = (cs_l << 2); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b))); - //_mm_prefetch((char*)(ptr_l + cs_l), _MM_HINT_T0); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0])); - //_mm_prefetch((char*)(ptr_l + row2), _MM_HINT_T0); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1])); - //_mm_prefetch((char*)(ptr_l + row2 + cs_l), _MM_HINT_T0); - //row6 = row2 + row4; - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2])); - //_mm_prefetch((char*)(ptr_l + row4), _MM_HINT_T0); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3])); - //_mm_prefetch((char*)(ptr_l + row4 + cs_l), _MM_HINT_T0); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4])); - //_mm_prefetch((char*)(ptr_l + row6), _MM_HINT_T0); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5])); - //_mm_prefetch((char*)(ptr_l + row6 + cs_l), _MM_HINT_T0); - - //reciprocal_diags = _mm256_loadu_ps((float const *)ones); - - //read first set of 16x16 block of L, where 16 is the blk_height and 16 is the blk_width for L - /*mat_a_cols[0] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[1] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[2] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[3] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[4] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[5] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[6] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[7] = _mm256_loadu_ps((float const *)ptr_l);*/ - - //Shuffle to rearrange/transpose 16x16 block of L into contiguous row-wise registers - //tmpRegs[0] = _mm256_castps256_ps128(mat_a_cols[0]); //zero latency, no instruction added actually. - //mat_a_cols_rearr[0] = _mm256_broadcastss_ps(tmpRegs[0]); - //1st col - mat_a_cols_rearr[0] = _mm256_broadcast_ss((float const *)(ptr_l+0)); - mat_a_cols_rearr[1] = _mm256_broadcast_ss((float const *)(ptr_l+1)); - mat_a_cols_rearr[3] = _mm256_broadcast_ss((float const *)(ptr_l+2)); - mat_a_cols_rearr[6] = _mm256_broadcast_ss((float const *)(ptr_l+3)); - mat_a_cols_rearr[10] = _mm256_broadcast_ss((float const *)(ptr_l+4)); - mat_a_cols_rearr[15] = _mm256_broadcast_ss((float const *)(ptr_l+5)); - mat_a_cols_rearr[21] = _mm256_broadcast_ss((float const *)(ptr_l+6)); - mat_a_cols_rearr[28] = _mm256_broadcast_ss((float const *)(ptr_l+7)); - //2nd col - ptr_l += cs_l; - mat_a_cols_rearr[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_cols_rearr[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_cols_rearr[7] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_cols_rearr[11] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_cols_rearr[16] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[22] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[29] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //3rd col - ptr_l += cs_l; - mat_a_cols_rearr[5] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_cols_rearr[8] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_cols_rearr[12] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_cols_rearr[17] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[23] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[30] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //4rth col - ptr_l += cs_l; - mat_a_cols_rearr[9] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_cols_rearr[13] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_cols_rearr[18] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[24] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[31] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //5th col - ptr_l += cs_l; - mat_a_cols_rearr[14] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_cols_rearr[19] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[25] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[32] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //6th col - ptr_l += cs_l; - mat_a_cols_rearr[20] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[26] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[33] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //7th col - ptr_l += cs_l; - mat_a_cols_rearr[27] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[34] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //8th col - //ptr_l += cs_l; - //mat_a_cols_rearr[35] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - - numCols_b -= 8; // blk_width = 8 - - //compute reciprocals of L(i,i) and broadcast in registers - //mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[2]); - //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_cols_rearr[5], mat_a_cols_rearr[9]); - //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_cols_rearr[14], mat_a_cols_rearr[20]); - //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_cols_rearr[27], mat_a_cols_rearr[35]); - - //mat_a_diag_inv[1] = _mm256_permute_ps(mat_a_diag_inv[1], 0x55); - //mat_a_diag_inv[3] = _mm256_permute_ps(mat_a_diag_inv[3], 0x55); - //mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC); - //mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC); - //mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0x20); - - //reciprocal of diagnol elements - //reciprocal_diags = _mm256_div_ps(reciprocal_diags, mat_a_diag_inv[0]); - - //Start loop for cols of B to be processed in size of blk_width - for (j = 0; j < numCols_b; j += 8) - { - ptr_b_dup = ptr_b; - - /*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/ - - ////unpacklow//// - mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); - mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); - mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); - mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - ////unpackhigh//// - mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); - mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); - mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); - mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - - //extract diag a00 from a - //mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00); - //mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - //mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); - - //Merge rearranged high elements into complete rows - mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - //extract diag a11 from a - //mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55); - //mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - //mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); - - //extract diag a22 from a - //mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA); - //mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - //mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); - - //extract diag a33 from a - //mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF); - //mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - //mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); - - //extract diag a44 from a - //mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00); - //mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - //mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); - - //extract diag a55 from a - //mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55); - //mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - //mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); - - //extract diag a66 from a - //mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA); - //mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - //mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); - - //extract diag a77 from a - //mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF); - //mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - //mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); - - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44); - mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE); - mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44); - mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE); -#else - mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E); - mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E); - mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC); - mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33); - mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC); - mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20); - mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31); - mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20); - mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - //Read next set of B columns - ptr_b += (cs_b + cs_b_offset[5]); - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b))); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0])); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1])); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2])); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3])); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4])); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5])); - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]); - //end loop of cols - } - - //Last block trsm processing - ptr_b_dup = ptr_b; - - /*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/ - - ////unpacklow//// - mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); - mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); - mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); - mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - ////unpackhigh//// - mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); - mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); - mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); - mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - - //extract diag a00 from a - //mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00); - //mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - //mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); - - //Merge rearranged high elements into complete rows - mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - //extract diag a11 from a - //mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55); - //mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - //mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); - - //extract diag a22 from a - //mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA); - //mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - //mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); - - //extract diag a33 from a - //mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF); - //mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - //mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); - - //extract diag a44 from a - //mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00); - //mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - //mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); - - //extract diag a55 from a - //mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55); - //mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - //mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); - - //extract diag a66 from a - //mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA); - //mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - //mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); - - //extract diag a77 from a - //mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF); - //mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - //mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); - - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44); - mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE); - mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44); - mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE); -#else - mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E); - mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E); - mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC); - mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33); - mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC); - mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20); - mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31); - mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20); - mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]); - //end loop of cols -} - -static void blis_strsm_microkernel(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b) -{ - float ones = 1.0; - int j; - int cs_b_offset[6]; - //int row2, row4, row6; - float *ptr_b_dup; - - //70 number of ymm(256 bits) registers used - __m256 mat_b_col[8]; - __m256 mat_b_rearr[8]; - __m256 mat_a_cols[8]; - __m256 mat_a_cols_rearr[36]; - __m256 mat_a_diag_inv[8]; - __m256 reciprocal_diags; - - cs_b_offset[0] = (cs_b << 1); - cs_b_offset[1] = cs_b + cs_b_offset[0]; - cs_b_offset[2] = (cs_b << 2); - cs_b_offset[3] = cs_b + cs_b_offset[2]; - cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; - cs_b_offset[5] = cs_b + cs_b_offset[4]; - - //reciprocal_diags = _mm256_loadu_ps((float const *)ones); - reciprocal_diags = _mm256_broadcast_ss((float const *)&ones); - - // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // - - //read first set of 16x8 block of B into registers, where 16 is the blk_height and 8 is the blk_width for B - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b); - //_mm_prefetch((char*)(ptr_l + 0), _MM_HINT_T0); - //row2 = (cs_l << 1); - //row4 = (cs_l << 2); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b))); - //_mm_prefetch((char*)(ptr_l + cs_l), _MM_HINT_T0); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0])); - //_mm_prefetch((char*)(ptr_l + row2), _MM_HINT_T0); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1])); - //_mm_prefetch((char*)(ptr_l + row2 + cs_l), _MM_HINT_T0); - //row6 = row2 + row4; - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2])); - //_mm_prefetch((char*)(ptr_l + row4), _MM_HINT_T0); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3])); - //_mm_prefetch((char*)(ptr_l + row4 + cs_l), _MM_HINT_T0); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4])); - //_mm_prefetch((char*)(ptr_l + row6), _MM_HINT_T0); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5])); - //_mm_prefetch((char*)(ptr_l + row6 + cs_l), _MM_HINT_T0); - - //reciprocal_diags = _mm256_loadu_ps((float const *)ones); - - //read first set of 16x16 block of L, where 16 is the blk_height and 16 is the blk_width for L - /*mat_a_cols[0] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[1] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[2] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[3] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[4] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[5] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[6] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[7] = _mm256_loadu_ps((float const *)ptr_l);*/ - - //Shuffle to rearrange/transpose 16x16 block of L into contiguous row-wise registers - //tmpRegs[0] = _mm256_castps256_ps128(mat_a_cols[0]); //zero latency, no instruction added actually. - //mat_a_cols_rearr[0] = _mm256_broadcastss_ps(tmpRegs[0]); - //1st col - mat_a_cols_rearr[0] = _mm256_broadcast_ss((float const *)(ptr_l+0)); - mat_a_cols_rearr[1] = _mm256_broadcast_ss((float const *)(ptr_l+1)); - mat_a_cols_rearr[3] = _mm256_broadcast_ss((float const *)(ptr_l+2)); - mat_a_cols_rearr[6] = _mm256_broadcast_ss((float const *)(ptr_l+3)); - mat_a_cols_rearr[10] = _mm256_broadcast_ss((float const *)(ptr_l+4)); - mat_a_cols_rearr[15] = _mm256_broadcast_ss((float const *)(ptr_l+5)); - mat_a_cols_rearr[21] = _mm256_broadcast_ss((float const *)(ptr_l+6)); - mat_a_cols_rearr[28] = _mm256_broadcast_ss((float const *)(ptr_l+7)); - //2nd col - ptr_l += cs_l; - mat_a_cols_rearr[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_cols_rearr[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_cols_rearr[7] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_cols_rearr[11] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_cols_rearr[16] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[22] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[29] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //3rd col - ptr_l += cs_l; - mat_a_cols_rearr[5] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_cols_rearr[8] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_cols_rearr[12] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_cols_rearr[17] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[23] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[30] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //4rth col - ptr_l += cs_l; - mat_a_cols_rearr[9] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_cols_rearr[13] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_cols_rearr[18] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[24] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[31] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //5th col - ptr_l += cs_l; - mat_a_cols_rearr[14] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_cols_rearr[19] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[25] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[32] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //6th col - ptr_l += cs_l; - mat_a_cols_rearr[20] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[26] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[33] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //7th col - ptr_l += cs_l; - mat_a_cols_rearr[27] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[34] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //7th col - ptr_l += cs_l; - mat_a_cols_rearr[35] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - - numCols_b -= 8; // blk_width = 8 - - //compute reciprocals of L(i,i) and broadcast in registers - mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[2]); - mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_cols_rearr[5], mat_a_cols_rearr[9]); - mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_cols_rearr[14], mat_a_cols_rearr[20]); - mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_cols_rearr[27], mat_a_cols_rearr[35]); - - //mat_a_diag_inv[1] = _mm256_permute_ps(mat_a_diag_inv[1], 0x55); - //mat_a_diag_inv[3] = _mm256_permute_ps(mat_a_diag_inv[3], 0x55); - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC); - mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC); - mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0x20); - - //reciprocal of diagnol elements - reciprocal_diags = _mm256_div_ps(reciprocal_diags, mat_a_diag_inv[0]); - - //Start loop for cols of B to be processed in size of blk_width - for (j = 0; j < numCols_b; j += 8) - { - ptr_b_dup = ptr_b; - - /*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/ - - ////unpacklow//// - mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); - mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); - mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); - mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - ////unpackhigh//// - mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); - mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); - mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); - mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - - //extract diag a00 from a - mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00); - mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); - - //Merge rearranged high elements into complete rows - mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - //extract diag a11 from a - mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55); - mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); - - //extract diag a22 from a - mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA); - mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); - - //extract diag a33 from a - mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF); - mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); - - //extract diag a44 from a - mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00); - mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); - - //extract diag a55 from a - mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55); - mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); - - //extract diag a66 from a - mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA); - mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); - - //extract diag a77 from a - mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF); - mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); - - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44); - mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE); - mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44); - mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE); -#else - mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E); - mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E); - mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC); - mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33); - mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC); - mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20); - mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31); - mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20); - mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - //Read next set of B columns - ptr_b += (cs_b + cs_b_offset[5]); - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b))); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0])); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1])); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2])); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3])); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4])); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5])); - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]); - //end loop of cols - } - - //Last block trsm processing - ptr_b_dup = ptr_b; - - /*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/ - - ////unpacklow//// - mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); - mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); - mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); - mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - ////unpackhigh//// - mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); - mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); - mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); - mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - - //extract diag a00 from a - mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00); - mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); - - //Merge rearranged high elements into complete rows - mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - //extract diag a11 from a - mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55); - mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); - - //extract diag a22 from a - mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA); - mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); - - //extract diag a33 from a - mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF); - mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); - - //extract diag a44 from a - mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00); - mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); - - //extract diag a55 from a - mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55); - mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); - - //extract diag a66 from a - mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA); - mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); - - //extract diag a77 from a - mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF); - mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); - - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44); - mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE); - mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44); - mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE); -#else - mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E); - mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E); - mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC); - mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33); - mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC); - mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20); - mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31); - mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20); - mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]); - //end loop of cols -} - -#if OPT_CACHE_BLOCKING_L1 //new intrinsic kernels -static void trsm_XAtB_block_allSmallSizedMatrices(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b) -{ - float ones = 1.0; - int i, i1, i2, i3, i4, j, k, l, r; - int cs_b_offset[7]; - int cs_l_offset[7]; - float *ptr_b_dup, *ptr_l_dup; - - //57 number of ymm(256 bits) registers used - __m256 mat_b_col[8]; - __m256 mat_b_rearr[8]; - __m256 mat_a_blk_elems[8]; - __m256 mat_a_diag_inv[8]; - __m256 reciprocal_diags[2]; - - reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones)); - - // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // - - //L matrix offsets - cs_l_offset[0] = (cs_l << 1); - cs_l_offset[1] = cs_l + cs_l_offset[0]; - cs_l_offset[2] = (cs_l << 2); - cs_l_offset[3] = cs_l + cs_l_offset[2]; - cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; - cs_l_offset[5] = cs_l + cs_l_offset[4]; - cs_l_offset[6] = (cs_l_offset[5] + cs_l); - - //read diag elems of L 16x16 block - mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l); - mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l); - mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]); - mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]); - mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]); - mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]); - mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]); - mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]); - - cs_b_offset[0] = (cs_b << 1); - cs_b_offset[1] = cs_b + cs_b_offset[0]; - cs_b_offset[2] = (cs_b << 2); - cs_b_offset[3] = cs_b + cs_b_offset[2]; - cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; - cs_b_offset[5] = cs_b + cs_b_offset[4]; - cs_b_offset[6] = (cs_b_offset[5] + cs_b); - - reciprocal_diags[1] = reciprocal_diags[0]; - - //pack first 8 diags together - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1 - mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5 - mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 - - //reciprocal of diagnal elements 0,1,2,3,4,5,6,7 - reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); - - //extract diag a00 from a - mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - //mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]); - //extract diag a11 from a - mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); - //extract diag a22 from a - mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); - //extract diag a33 from a - mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); - //extract diag a44 from a - mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); - //extract diag a55 from a - mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); - //extract diag a66 from a - mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); - //extract diag a77 from a - mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); - - - /***************** first set of 8 rows of B processing starts *****************/ - ptr_b_dup = ptr_b; - i = 0; - for (j = 0; j < numCols_b; j += 8) - { - /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A - //read 8x8 block of B into registers - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], mat_a_diag_inv[0]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], mat_a_diag_inv[1]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], mat_a_diag_inv[2]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], mat_a_diag_inv[3]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], mat_a_diag_inv[4]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], mat_a_diag_inv[5]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], mat_a_diag_inv[6]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], mat_a_diag_inv[7]); - - //////////////////////////////////////////////////////////////////////////////// - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]); - - //i += cs_b_offset[6]; - //ptr_b_dup += cs_b_offset[6]; - i += 8; - ptr_b_dup += 8; - } - - //c = 0; - /***************** first set of 8 cols of B processing done *****************/ - ptr_b_dup = ptr_b; - i3 = 0; - i1 = 0; - //Start loop for cols of B to be processed in size of blk_width - for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row - { - ptr_l += 8; - //ptr_b += j; - //ptr_b_dup += 8; - ptr_b_dup += cs_b_offset[6]; - i1 += cs_b_offset[6]; - - //Read next 8x8 block of A to get diag elements - i3 += cs_l_offset[6]; - mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l + i3); - mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l); - mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]); - mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]); - mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]); - mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]); - mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]); - mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]); - - //pack 8 diags of A together - reciprocal_diags[0] = reciprocal_diags[1]; - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1 - mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5 - mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 - - //reciprocal of diagnal elements of A :- 0,1,2,3,4,5,6,7 - reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); - - //extract diag a00 from a - mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - //mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]); - - //extract diag a11 from a - mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); - - //extract diag a22 from a - mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); - - //extract diag a33 from a - mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); - - //extract diag a44 from a - mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); - - //extract diag a55 from a - mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); - - //extract diag a66 from a - mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); - - //extract diag a77 from a - mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); - - for (r = 0; r < numCols_b; r += GEMM_BLK_V1) - { -#if GEMM_ACCUM_A - i = i1 + r; - //Read 8 cols of B columns of Block-to-be-solved - mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); -#endif - i = 0; - i2 = 0; - for (l = 0; l < j; l += 8) // move across m - { - //for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) - { - /////////////////// Partial Lower 8x8 block trsm of B - ptr_l_dup = ptr_l; - i4 = i2 + r; - //Read current 8 cols of B columns from specified 8x8 current-block of B - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); - - //Broadcast A8,0 to A15,0 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - i4 = k >> 3; - ptr_l_dup += cs_l; - -#if GEMM_ACCUM_A - //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]); - mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]); - mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]); - mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]); - mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]); - mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]); - mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]); - mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]); -#endif - //Broadcast A21 to A71 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,2 to A15,2 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,3 to A15,3 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,4 to A15,4 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,5 to A15,5 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,6 to A15,6 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,7 to A15,7 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) -#endif - //end loop of cols - } - i2 += cs_b_offset[6]; - i += cs_l_offset[6]; - } - //trsm solve - - k = 0; - //for (i2 = 0; i2 < numCols_b; i2 += 8) - { - i2 = i1 + r; - /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A -#if !GEMM_ACCUM_A - //Read 8 cols of B columns of Block-to-be-solved - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i2); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i2)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i2)); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i2)); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i2)); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i2)); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i2)); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i2)); -#endif - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - -#if GEMM_ACCUM_A - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); -#else - mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]); - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); -#endif - -#if GEMM_ACCUM_A - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]); - mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]); - mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]); - mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]); - mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]); - mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]); - mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A21 to A71 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A76 to register - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); - - //////////////////////////////////////////////////////////////////////////////// - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup + r, mat_b_rearr[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+r), mat_b_rearr[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + r), mat_b_rearr[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + r), mat_b_rearr[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + r), mat_b_rearr[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + r), mat_b_rearr[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + r), mat_b_rearr[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + r), mat_b_rearr[7]); - //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); - k++; - } - } - } //numRows of A - ///////////////////loop ends ///////////////////// -} - -static void trsm_XAtB_block_allSmallSizedMatrices_alpha(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha) -{ - float ones = 1.0; - int i, i1, i2, i3, i4, j, k, l, r; - int cs_b_offset[7]; - int cs_l_offset[7]; - float *ptr_b_dup, *ptr_l_dup; - - //57 number of ymm(256 bits) registers used - __m256 mat_b_col[8]; - __m256 mat_b_rearr[8]; - __m256 mat_a_blk_elems[8]; - __m256 mat_a_diag_inv[8]; - __m256 reciprocal_diags[2]; - __m256 alphaReg; - - reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones)); - alphaReg = _mm256_broadcast_ss((float const *)&alpha); - - // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // - - //L matrix offsets - cs_l_offset[0] = (cs_l << 1); - cs_l_offset[1] = cs_l + cs_l_offset[0]; - cs_l_offset[2] = (cs_l << 2); - cs_l_offset[3] = cs_l + cs_l_offset[2]; - cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; - cs_l_offset[5] = cs_l + cs_l_offset[4]; - cs_l_offset[6] = (cs_l_offset[5] + cs_l); - - //read diag elems of L 16x16 block - mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l); - mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l); - mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]); - mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]); - mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]); - mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]); - mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]); - mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]); - - cs_b_offset[0] = (cs_b << 1); - cs_b_offset[1] = cs_b + cs_b_offset[0]; - cs_b_offset[2] = (cs_b << 2); - cs_b_offset[3] = cs_b + cs_b_offset[2]; - cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; - cs_b_offset[5] = cs_b + cs_b_offset[4]; - cs_b_offset[6] = (cs_b_offset[5] + cs_b); - - reciprocal_diags[1] = reciprocal_diags[0]; - - //pack first 8 diags together - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1 - mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5 - mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 - - //reciprocal of diagnal elements 0,1,2,3,4,5,6,7 - reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); -#if 0 - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - - //Broadcast A21 to A71 to registers - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); - mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); - mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); - mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); - mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); - mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); - mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); - mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); - mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); - mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); - mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); - mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); - mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); - mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); - mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); - mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); - - //Broadcast A76 to register - mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); -#endif - //extract diag a00 from a - mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - //mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]); - //extract diag a11 from a - mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); - //extract diag a22 from a - mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); - //extract diag a33 from a - mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); - //extract diag a44 from a - mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); - //extract diag a55 from a - mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); - //extract diag a66 from a - mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); - //extract diag a77 from a - mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); - - - /***************** first set of 8 rows of B processing starts *****************/ - ptr_b_dup = ptr_b; - i = 0; - for (j = 0; j < numCols_b; j += 8) - { - /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A - //read 8x8 block of B into registers - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg); - mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg); - mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg); - mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg); - mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg); - mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg); - mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg); - mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg); - - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], mat_a_diag_inv[0]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], mat_a_diag_inv[1]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], mat_a_diag_inv[2]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], mat_a_diag_inv[3]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], mat_a_diag_inv[4]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], mat_a_diag_inv[5]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], mat_a_diag_inv[6]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], mat_a_diag_inv[7]); - - //////////////////////////////////////////////////////////////////////////////// - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]); - - //i += cs_b_offset[6]; - //ptr_b_dup += cs_b_offset[6]; - i += 8; - ptr_b_dup += 8; - } - - //c = 0; - /***************** first set of 8 cols of B processing done *****************/ - ptr_b_dup = ptr_b; - i3 = 0; - i1 = 0; - //Start loop for cols of B to be processed in size of blk_width - for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row - { - ptr_l += 8; - //ptr_b += j; - //ptr_b_dup += 8; - ptr_b_dup += cs_b_offset[6]; - i1 += cs_b_offset[6]; - - //Read next 8x8 block of A to get diag elements - i3 += cs_l_offset[6]; - mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l + i3); - mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l); - mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]); - mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]); - mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]); - mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]); - mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]); - mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]); - - //pack 8 diags of A together - reciprocal_diags[0] = reciprocal_diags[1]; - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1 - mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5 - mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 - - //reciprocal of diagnal elements of A :- 0,1,2,3,4,5,6,7 - reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); - - //extract diag a00 from a - mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - //mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]); - - //extract diag a11 from a - mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); - - //extract diag a22 from a - mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); - - //extract diag a33 from a - mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); - - //extract diag a44 from a - mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); - - //extract diag a55 from a - mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); - - //extract diag a66 from a - mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); - - //extract diag a77 from a - mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); - - for (r = 0; r < numCols_b; r += GEMM_BLK_V1) - { -#if GEMM_ACCUM_A - i = i1 + r; - //Read 8 cols of B columns of Block-to-be-solved - mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg); - mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg); - mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg); - mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg); - mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg); - mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg); - mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg); - mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg); -#endif - i = 0; - i2 = 0; - for (l = 0; l < j; l += 8) // move across m - { - //for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) - { - /////////////////// Partial Lower 8x8 block trsm of B - ptr_l_dup = ptr_l; - i4 = i2 + r; - //Read current 8 cols of B columns from specified 8x8 current-block of B - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); - - //Broadcast A8,0 to A15,0 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - i4 = k >> 3; - ptr_l_dup += cs_l; - -#if GEMM_ACCUM_A - //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]); - mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]); - mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]); - mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]); - mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]); - mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]); - mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]); - mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]); -#endif - //Broadcast A21 to A71 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,2 to A15,2 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,3 to A15,3 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,4 to A15,4 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,5 to A15,5 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,6 to A15,6 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,7 to A15,7 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) -#endif - //end loop of cols - } - i2 += cs_b_offset[6]; - i += cs_l_offset[6]; - } - //trsm solve - - k = 0; - //for (i2 = 0; i2 < numCols_b; i2 += 8) - { - i2 = i1 + r; - /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A -#if !GEMM_ACCUM_A - //Read 8 cols of B columns of Block-to-be-solved - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i2); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i2)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i2)); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i2)); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i2)); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i2)); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i2)); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i2)); - - mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg); - mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg); - mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg); - mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg); - mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg); - mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg); - mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg); - mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg); -#endif - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - -#if GEMM_ACCUM_A - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); -#else - mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]); - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); -#endif - -#if GEMM_ACCUM_A - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]); - mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]); - mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]); - mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]); - mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]); - mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]); - mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A21 to A71 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A76 to register - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); - - //////////////////////////////////////////////////////////////////////////////// - - //Store the computed B columns - - _mm256_storeu_ps((float *)ptr_b_dup + r, mat_b_rearr[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+r), mat_b_rearr[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + r), mat_b_rearr[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + r), mat_b_rearr[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + r), mat_b_rearr[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + r), mat_b_rearr[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + r), mat_b_rearr[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + r), mat_b_rearr[7]); - //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); - k++; - } - } - } //numRows of A - ///////////////////loop ends ///////////////////// -} - -static void trsm_XAtB_block_allSmallSizedMatrices_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b) -{ - //float ones = 1.0; - int i, i1, i2, i3, i4, j, k, l, r; - int cs_b_offset[7]; - int cs_l_offset[7]; - float *ptr_b_dup, *ptr_l_dup; - - //57 number of ymm(256 bits) registers used - __m256 mat_b_col[8]; - __m256 mat_b_rearr[8]; - __m256 mat_a_blk_elems[8]; - //__m256 mat_a_diag_inv[8]; - //__m256 reciprocal_diags[2]; - - // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // - - //L matrix offsets - cs_l_offset[0] = (cs_l << 1); - cs_l_offset[1] = cs_l + cs_l_offset[0]; - cs_l_offset[2] = (cs_l << 2); - cs_l_offset[3] = cs_l + cs_l_offset[2]; - cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; - cs_l_offset[5] = cs_l + cs_l_offset[4]; - cs_l_offset[6] = (cs_l_offset[5] + cs_l); - - cs_b_offset[0] = (cs_b << 1); - cs_b_offset[1] = cs_b + cs_b_offset[0]; - cs_b_offset[2] = (cs_b << 2); - cs_b_offset[3] = cs_b + cs_b_offset[2]; - cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; - cs_b_offset[5] = cs_b + cs_b_offset[4]; - cs_b_offset[6] = (cs_b_offset[5] + cs_b); - - /***************** first set of 8 rows of B processing starts *****************/ - ptr_b_dup = ptr_b; - i = 0; - for (j = 0; j < numCols_b; j += 8) - { - /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A - //read 8x8 block of B into registers - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - //(Row0) - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b) - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b) - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b) - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b) - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b) - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b) - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b) - - //////////////////////////////////////////////////////////////////////////////// - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]); - - //i += cs_b_offset[6]; - //ptr_b_dup += cs_b_offset[6]; - i += 8; - ptr_b_dup += 8; - } - - //c = 0; - /***************** first set of 8 cols of B processing done *****************/ - ptr_b_dup = ptr_b; - i3 = 0; - i1 = 0; - //Start loop for cols of B to be processed in size of blk_width - for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row - { - ptr_l += 8; - //ptr_b += j; - //ptr_b_dup += 8; - ptr_b_dup += cs_b_offset[6]; - i1 += cs_b_offset[6]; - i3 += cs_l_offset[6]; - - i = 0; - i2 = 0; - for (r = 0; r < numCols_b; r += GEMM_BLK_V1) - { -#if GEMM_ACCUM_A - i = i1 + r; - //Read 8 cols of B columns of Block-to-be-solved - mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); -#endif - i = 0; - i2 = 0; - for (l = 0; l < j; l += 8) // move across m - { - //for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) - { - /////////////////// Partial Lower 8x8 block trsm of B - ptr_l_dup = ptr_l; - i4 = i2 + r; - //Read current 8 cols of B columns from specified 8x8 current-block of B - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); - - //Broadcast A8,0 to A15,0 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - i4 = k >> 3; - ptr_l_dup += cs_l; - -#if GEMM_ACCUM_A - //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]); - mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]); - mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]); - mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]); - mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]); - mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]); - mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]); - mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]); -#endif - //Broadcast A21 to A71 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,2 to A15,2 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,3 to A15,3 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,4 to A15,4 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,5 to A15,5 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,6 to A15,6 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,7 to A15,7 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) -#endif - //end loop of cols - } - i2 += cs_b_offset[6]; - i += cs_l_offset[6]; - } - //trsm solve - - k = 0; - //for (i2 = 0; i2 < numCols_b; i2 += 8) - { - i2 = i1 + r; - /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A -#if !GEMM_ACCUM_A - //Read 8 cols of B columns of Block-to-be-solved - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i2); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i2)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i2)); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i2)); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i2)); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i2)); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i2)); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i2)); -#endif - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - -#if GEMM_ACCUM_A - //(Row0): already done -#else - mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]); -#endif - -#if GEMM_ACCUM_A - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]); - mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]); - mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]); - mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]); - mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]); - mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]); - mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A21 to A71 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A76 to register - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) - - - //////////////////////////////////////////////////////////////////////////////// - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup + r, mat_b_rearr[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+r), mat_b_rearr[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + r), mat_b_rearr[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + r), mat_b_rearr[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + r), mat_b_rearr[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + r), mat_b_rearr[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + r), mat_b_rearr[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + r), mat_b_rearr[7]); - //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); - k++; - } - } - } //numRows of A - ///////////////////loop ends ///////////////////// -} - -static void trsm_XAtB_block_allSmallSizedMatrices_alpha_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha) -{ - //float ones = 1.0; - int i, i1, i2, i3, i4, j, k, l, r; - int cs_b_offset[7]; - int cs_l_offset[7]; - float *ptr_b_dup, *ptr_l_dup; - - //57 number of ymm(256 bits) registers used - __m256 mat_b_col[8]; - __m256 mat_b_rearr[8]; - __m256 mat_a_blk_elems[8]; - //__m256 mat_a_diag_inv[8]; - //__m256 reciprocal_diags[2]; - __m256 alphaReg; - alphaReg = _mm256_broadcast_ss((float const *)&alpha); - - // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // - - //L matrix offsets - cs_l_offset[0] = (cs_l << 1); - cs_l_offset[1] = cs_l + cs_l_offset[0]; - cs_l_offset[2] = (cs_l << 2); - cs_l_offset[3] = cs_l + cs_l_offset[2]; - cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; - cs_l_offset[5] = cs_l + cs_l_offset[4]; - cs_l_offset[6] = (cs_l_offset[5] + cs_l); - - cs_b_offset[0] = (cs_b << 1); - cs_b_offset[1] = cs_b + cs_b_offset[0]; - cs_b_offset[2] = (cs_b << 2); - cs_b_offset[3] = cs_b + cs_b_offset[2]; - cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; - cs_b_offset[5] = cs_b + cs_b_offset[4]; - cs_b_offset[6] = (cs_b_offset[5] + cs_b); - -#if 0 - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - - //Broadcast A21 to A71 to registers - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); - mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); - mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); - mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); - mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); - mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); - mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); - mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); - mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); - mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); - mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); - mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); - mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); - mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); - mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); - mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); - - //Broadcast A76 to register - mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); -#endif - - - /***************** first set of 8 rows of B processing starts *****************/ - ptr_b_dup = ptr_b; - i = 0; - for (j = 0; j < numCols_b; j += 8) - { - /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A - //read 8x8 block of B into registers - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg); - mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg); - mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg); - mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg); - mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg); - mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg); - mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg); - mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg); - - //(Row0) - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b) - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b) - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b) - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b) - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b) - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b) - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b) - - //////////////////////////////////////////////////////////////////////////////// - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]); - - //i += cs_b_offset[6]; - //ptr_b_dup += cs_b_offset[6]; - i += 8; - ptr_b_dup += 8; - } - - //c = 0; - /***************** first set of 8 cols of B processing done *****************/ - ptr_b_dup = ptr_b; - i3 = 0; - i1 = 0; - //Start loop for cols of B to be processed in size of blk_width - for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row - { - ptr_l += 8; - //ptr_b += j; - //ptr_b_dup += 8; - ptr_b_dup += cs_b_offset[6]; - i1 += cs_b_offset[6]; - i3 += cs_l_offset[6]; - - i = 0; - i2 = 0; - for (r = 0; r < numCols_b; r += GEMM_BLK_V1) - { -#if GEMM_ACCUM_A - i = i1 + r; - //Read 8 cols of B columns of Block-to-be-solved - mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg); - mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg); - mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg); - mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg); - mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg); - mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg); - mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg); - mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg); -#endif - i = 0; - i2 = 0; - for (l = 0; l < j; l += 8) // move across m - { - //for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) - { - /////////////////// Partial Lower 8x8 block trsm of B - ptr_l_dup = ptr_l; - i4 = i2 + r; - //Read current 8 cols of B columns from specified 8x8 current-block of B - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); - - //Broadcast A8,0 to A15,0 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - i4 = k >> 3; - ptr_l_dup += cs_l; - -#if GEMM_ACCUM_A - //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]); - mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]); - mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]); - mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]); - mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]); - mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]); - mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]); - mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]); -#endif - //Broadcast A21 to A71 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,2 to A15,2 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,3 to A15,3 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,4 to A15,4 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,5 to A15,5 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,6 to A15,6 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,7 to A15,7 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) -#endif - //end loop of cols - } - i2 += cs_b_offset[6]; - i += cs_l_offset[6]; - } - //trsm solve - - k = 0; - //for (i2 = 0; i2 < numCols_b; i2 += 8) - { - i2 = i1 + r; - /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A -#if !GEMM_ACCUM_A - //Read 8 cols of B columns of Block-to-be-solved - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i2); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i2)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i2)); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i2)); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i2)); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i2)); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i2)); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i2)); - - mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg); - mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg); - mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg); - mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg); - mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg); - mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg); - mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg); - mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg); -#endif - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - -#if GEMM_ACCUM_A - //(Row0): already done - -#else - mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]); -#endif - -#if GEMM_ACCUM_A - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]); - mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]); - mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]); - mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]); - mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]); - mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]); - mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A21 to A71 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A76 to register - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) - - - //////////////////////////////////////////////////////////////////////////////// - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup + r, mat_b_rearr[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+r), mat_b_rearr[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + r), mat_b_rearr[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + r), mat_b_rearr[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + r), mat_b_rearr[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + r), mat_b_rearr[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + r), mat_b_rearr[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + r), mat_b_rearr[7]); - //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); - k++; - } - } - } //numRows of A - ///////////////////loop ends ///////////////////// -} -#else //rel 1.0 intrisic kernels (NOT OPT_CACHE_BLOCKING_L1) -static void trsm_XAtB_block_allSmallSizedMatrices(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b) -{ - float ones = 1.0; - int i, i1, i2, i3, i4, j, k, l; - int cs_b_offset[7]; - int cs_l_offset[7]; - float *ptr_b_dup; - - //57 number of ymm(256 bits) registers used - __m256 mat_b_col[8]; - __m256 mat_b_rearr[16][8]; - __m256 mat_a_cols_rearr[8]; - __m256 mat_a_blk_elems[64]; - __m256 mat_a_diag_inv[8]; - __m256 reciprocal_diags[2]; - - reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones)); - - // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // - - //L matrix offsets - cs_l_offset[0] = (cs_l << 1); - cs_l_offset[1] = cs_l + cs_l_offset[0]; - cs_l_offset[2] = (cs_l << 2); - cs_l_offset[3] = cs_l + cs_l_offset[2]; - cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; - cs_l_offset[5] = cs_l + cs_l_offset[4]; - cs_l_offset[6] = (cs_l_offset[5] + cs_l); - - //read diag elems of L 16x16 block - mat_a_cols_rearr[0] = _mm256_loadu_ps((float const *)ptr_l); - mat_a_cols_rearr[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l); - mat_a_cols_rearr[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]); - mat_a_cols_rearr[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]); - mat_a_cols_rearr[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]); - mat_a_cols_rearr[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]); - mat_a_cols_rearr[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]); - mat_a_cols_rearr[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]); - - cs_b_offset[0] = (cs_b << 1); - cs_b_offset[1] = cs_b + cs_b_offset[0]; - cs_b_offset[2] = (cs_b << 2); - cs_b_offset[3] = cs_b + cs_b_offset[2]; - cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; - cs_b_offset[5] = cs_b + cs_b_offset[4]; - cs_b_offset[6] = (cs_b_offset[5] + cs_b); - - reciprocal_diags[1] = reciprocal_diags[0]; - - //pack first 8 diags together - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[1], 0xAA);//diag 0,1 - mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_cols_rearr[2], mat_a_cols_rearr[3], 0xAA);//diag 2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_cols_rearr[4], mat_a_cols_rearr[5], 0xAA);//diag 4,5 - mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_cols_rearr[6], mat_a_cols_rearr[7], 0xAA);//diag 6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 - - //reciprocal of diagnal elements 0,1,2,3,4,5,6,7 - reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); - - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - - //Broadcast A21 to A71 to registers - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); - mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); - mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); - mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); - mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); - mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); - mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); - mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); - mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); - mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); - mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); - mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); - mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); - mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); - mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); - mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); - - //Broadcast A76 to register - mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); - - //extract diag a00 from a - mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - //mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]); - //extract diag a11 from a - mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); - //extract diag a22 from a - mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); - //extract diag a33 from a - mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); - //extract diag a44 from a - mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); - //extract diag a55 from a - mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); - //extract diag a66 from a - mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); - //extract diag a77 from a - mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); - - - /***************** first set of 8 rows of B processing starts *****************/ - ptr_b_dup = ptr_b; - i = 0; - for (j = 0; j < numCols_b; j += 8) - { - /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A - //read 8x8 block of B into registers - mat_b_rearr[0][0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[1][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[2][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[3][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[4][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[5][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[6][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[7][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - mat_b_col[0] = _mm256_mul_ps(mat_b_rearr[0][0], mat_a_diag_inv[0]); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[1][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[1][0]);//d = c - (a*b) - mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[2][0]);//d = c - (a*b) - mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[3][0]);//d = c - (a*b) - mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[4][0]);//d = c - (a*b) - mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[5][0]);//d = c - (a*b) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[7][0]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - mat_b_col[1] = _mm256_mul_ps(mat_b_rearr[1][0], mat_a_diag_inv[1]); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[2][0]);//d = c - (a*b) - mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[3][0]);//d = c - (a*b) - mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[4][0]);//d = c - (a*b) - mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[5][0]);//d = c - (a*b) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[7][0]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - mat_b_col[2] = _mm256_mul_ps(mat_b_rearr[2][0], mat_a_diag_inv[2]); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[2], mat_b_rearr[3][0]);//d = c - (a*b) - mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[2], mat_b_rearr[4][0]);//d = c - (a*b) - mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[2], mat_b_rearr[5][0]);//d = c - (a*b) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[7][0]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - mat_b_col[3] = _mm256_mul_ps(mat_b_rearr[3][0], mat_a_diag_inv[3]); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[3], mat_b_rearr[4][0]);//d = c - (a*b) - mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[3], mat_b_rearr[5][0]);//d = c - (a*b) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[3], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[3], mat_b_rearr[7][0]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - mat_b_col[4] = _mm256_mul_ps(mat_b_rearr[4][0], mat_a_diag_inv[4]); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[4], mat_b_rearr[5][0]);//d = c - (a*b) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[4], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[4], mat_b_rearr[7][0]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - mat_b_col[5] = _mm256_mul_ps(mat_b_rearr[5][0], mat_a_diag_inv[5]); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[5], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[5], mat_b_rearr[7][0]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - mat_b_col[6] = _mm256_mul_ps(mat_b_rearr[6][0], mat_a_diag_inv[6]); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[6], mat_b_rearr[7][0]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - mat_b_col[7] = _mm256_mul_ps(mat_b_rearr[7][0], mat_a_diag_inv[7]); - - //////////////////////////////////////////////////////////////////////////////// - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]); - - //i += cs_b_offset[6]; - //ptr_b_dup += cs_b_offset[6]; - i += 8; - ptr_b_dup += 8; - } - - //c = 0; - /***************** first set of 8 cols of B processing done *****************/ - ptr_b_dup = ptr_b; - i3 = 0; - i1 = 0; - //Start loop for cols of B to be processed in size of blk_width - for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row - { - ptr_l += 8; - //ptr_b += j; - //ptr_b_dup += 8; - ptr_b_dup += cs_b_offset[6]; - i1 += cs_b_offset[6]; - - //Read next 8x8 block of A to get diag elements - i3 += cs_l_offset[6]; - mat_a_cols_rearr[0] = _mm256_loadu_ps((float const *)ptr_l + i3); - mat_a_cols_rearr[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l); - mat_a_cols_rearr[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]); - mat_a_cols_rearr[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]); - mat_a_cols_rearr[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]); - mat_a_cols_rearr[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]); - mat_a_cols_rearr[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]); - mat_a_cols_rearr[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]); - - //pack 8 diags of A together - reciprocal_diags[0] = reciprocal_diags[1]; - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[1], 0xAA);//diag 0,1 - mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_cols_rearr[2], mat_a_cols_rearr[3], 0xAA);//diag 2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_cols_rearr[4], mat_a_cols_rearr[5], 0xAA);//diag 4,5 - mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_cols_rearr[6], mat_a_cols_rearr[7], 0xAA);//diag 6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 - - //reciprocal of diagnal elements of A :- 0,1,2,3,4,5,6,7 - reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); - - i = 0; - i2 = 0; - for (k = 0; k < numCols_b; k += 8) - { - i = i1 + k; - //Read 8 cols of B columns of Block-to-be-solved - mat_b_rearr[i2][0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[i2][1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[i2][2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[i2][3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[i2][4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[i2][5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[i2][6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[i2][7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - i2++; - } - - i = 0; - i2 = 0; - for (l = 0; l < j; l += 8) // move across m - { - //Broadcast A8,0 to A15,0 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - - //Broadcast A21 to A71 to registers - mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i)); - mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 1)); - mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 2)); - mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 3)); - mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 4)); - mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 5)); - mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 6)); - mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 7)); - - //Broadcast A8,2 to A15,2 to registers - mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i)); - mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 1)); - mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 2)); - mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 3)); - mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 4)); - mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 5)); - mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 6)); - mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 7)); - - //Broadcast A8,3 to A15,3 to registers - mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i)); - mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 1)); - mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 2)); - mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 3)); - mat_a_blk_elems[28] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 4)); - mat_a_blk_elems[29] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 5)); - mat_a_blk_elems[30] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 6)); - mat_a_blk_elems[31] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 7)); - - // _mm256_permute2f128_ps() - - //Broadcast A8,4 to A15,4 to registers - mat_a_blk_elems[32] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i)); - mat_a_blk_elems[33] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 1)); - mat_a_blk_elems[34] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 2)); - mat_a_blk_elems[35] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 3)); - mat_a_blk_elems[36] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 4)); - mat_a_blk_elems[37] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 5)); - mat_a_blk_elems[38] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 6)); - mat_a_blk_elems[39] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 7)); - - //Broadcast A8,5 to A15,5 to registers - mat_a_blk_elems[40] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i)); - mat_a_blk_elems[41] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 1)); - mat_a_blk_elems[42] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 2)); - mat_a_blk_elems[43] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 3)); - mat_a_blk_elems[44] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 4)); - mat_a_blk_elems[45] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 5)); - mat_a_blk_elems[46] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 6)); - mat_a_blk_elems[47] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 7)); - - //Broadcast A8,6 to A15,6 to registers - mat_a_blk_elems[48] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i)); - mat_a_blk_elems[49] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 1)); - mat_a_blk_elems[50] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 2)); - mat_a_blk_elems[51] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 3)); - mat_a_blk_elems[52] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 4)); - mat_a_blk_elems[53] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 5)); - mat_a_blk_elems[54] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 6)); - mat_a_blk_elems[55] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 7)); - - //Broadcast A8,7 to A15,7 to registers - mat_a_blk_elems[56] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i)); - mat_a_blk_elems[57] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 1)); - mat_a_blk_elems[58] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 2)); - mat_a_blk_elems[59] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 3)); - mat_a_blk_elems[60] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 4)); - mat_a_blk_elems[61] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 5)); - mat_a_blk_elems[62] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 6)); - mat_a_blk_elems[63] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 7)); - - i += cs_l_offset[6]; - - - for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) - { - /////////////////// Partial Lower 8x8 block trsm of B - - i4 = i2 + k; - //Read current 8 cols of B columns from specified 8x8 current-block of B - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); - - i4 = k >> 3; - - //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[1], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[1], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[1], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[2], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[2], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[2], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[2], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[2], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[2], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[3], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[3], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[3], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[3], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[28], mat_b_col[3], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[29], mat_b_col[3], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[30], mat_b_col[3], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[31], mat_b_col[3], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[32], mat_b_col[4], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[33], mat_b_col[4], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[34], mat_b_col[4], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[35], mat_b_col[4], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[36], mat_b_col[4], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[37], mat_b_col[4], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[38], mat_b_col[4], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[39], mat_b_col[4], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[40], mat_b_col[5], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[41], mat_b_col[5], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[42], mat_b_col[5], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[43], mat_b_col[5], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[44], mat_b_col[5], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[45], mat_b_col[5], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[46], mat_b_col[5], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[47], mat_b_col[5], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[48], mat_b_col[6], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[49], mat_b_col[6], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[50], mat_b_col[6], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[51], mat_b_col[6], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[52], mat_b_col[6], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[53], mat_b_col[6], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[54], mat_b_col[6], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[55], mat_b_col[6], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[56], mat_b_col[7], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[57], mat_b_col[7], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[58], mat_b_col[7], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[59], mat_b_col[7], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[60], mat_b_col[7], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[61], mat_b_col[7], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[62], mat_b_col[7], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[63], mat_b_col[7], mat_b_rearr[i4][7]);//d = c - (a*b) - - //end loop of cols - } - i2 += cs_b_offset[6]; - } - - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - //extract diag a00 from a - mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - //mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]); - - //Broadcast A21 to A71 to registers - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - //extract diag a11 from a - mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - //extract diag a22 from a - mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - //extract diag a33 from a - mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - //extract diag a44 from a - mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - //extract diag a55 from a - mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); - - //Broadcast A76 to register - mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - //extract diag a66 from a - mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); - - //extract diag a77 from a - mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); - - k = 0; - for (i = 0; i < numCols_b; i+=8) - { - /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A - - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - mat_b_rearr[k][0] = _mm256_mul_ps(mat_b_rearr[k][0], mat_a_diag_inv[0]); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[k][1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[k][0], mat_b_rearr[k][1]);//d = c - (a*b) - mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[k][0], mat_b_rearr[k][2]);//d = c - (a*b) - mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[k][0], mat_b_rearr[k][3]);//d = c - (a*b) - mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[k][0], mat_b_rearr[k][4]);//d = c - (a*b) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[k][0], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[k][0], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[k][0], mat_b_rearr[k][7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - mat_b_rearr[k][1] = _mm256_mul_ps(mat_b_rearr[k][1], mat_a_diag_inv[1]); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_rearr[k][1], mat_b_rearr[k][2]);//d = c - (a*b) - mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_rearr[k][1], mat_b_rearr[k][3]);//d = c - (a*b) - mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_rearr[k][1], mat_b_rearr[k][4]);//d = c - (a*b) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_rearr[k][1], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_rearr[k][1], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_rearr[k][1], mat_b_rearr[k][7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - mat_b_rearr[k][2] = _mm256_mul_ps(mat_b_rearr[k][2], mat_a_diag_inv[2]); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_rearr[k][2], mat_b_rearr[k][3]);//d = c - (a*b) - mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_rearr[k][2], mat_b_rearr[k][4]);//d = c - (a*b) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_rearr[k][2], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_rearr[k][2], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_rearr[k][2], mat_b_rearr[k][7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - mat_b_rearr[k][3] = _mm256_mul_ps(mat_b_rearr[k][3], mat_a_diag_inv[3]); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_rearr[k][3], mat_b_rearr[k][4]);//d = c - (a*b) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_rearr[k][3], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_rearr[k][3], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_rearr[k][3], mat_b_rearr[k][7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - mat_b_rearr[k][4] = _mm256_mul_ps(mat_b_rearr[k][4], mat_a_diag_inv[4]); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_rearr[k][4], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_rearr[k][4], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_rearr[k][4], mat_b_rearr[k][7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - mat_b_rearr[k][5] = _mm256_mul_ps(mat_b_rearr[k][5], mat_a_diag_inv[5]); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_rearr[k][5], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_rearr[k][5], mat_b_rearr[k][7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - mat_b_rearr[k][6] = _mm256_mul_ps(mat_b_rearr[k][6], mat_a_diag_inv[6]); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_rearr[k][6], mat_b_rearr[k][7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - mat_b_rearr[k][7] = _mm256_mul_ps(mat_b_rearr[k][7], mat_a_diag_inv[7]); - - //////////////////////////////////////////////////////////////////////////////// - - //Store the computed B columns - - _mm256_storeu_ps((float *)ptr_b_dup + i, mat_b_rearr[k][0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b) + i), mat_b_rearr[k][1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i), mat_b_rearr[k][2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i), mat_b_rearr[k][3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i), mat_b_rearr[k][4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i), mat_b_rearr[k][5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i), mat_b_rearr[k][6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i), mat_b_rearr[k][7]); - //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); - k++; - } - - - } - ///////////////////loop ends ///////////////////// -} - -static void trsm_XAtB_block_allSmallSizedMatrices_alpha(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha) -{ - float ones = 1.0; - int i, i1, i2, i3, i4, j, k, l; - int cs_b_offset[7]; - int cs_l_offset[7]; - float *ptr_b_dup; - - //57 number of ymm(256 bits) registers used - __m256 mat_b_col[8]; - __m256 mat_b_rearr[16][8]; - __m256 mat_a_cols_rearr[8]; - __m256 mat_a_blk_elems[64]; - __m256 mat_a_diag_inv[8]; - __m256 reciprocal_diags[2]; - __m256 alphaReg; - - reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones)); - alphaReg = _mm256_broadcast_ss((float const *)&alpha); - - // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // - - //L matrix offsets - cs_l_offset[0] = (cs_l << 1); - cs_l_offset[1] = cs_l + cs_l_offset[0]; - cs_l_offset[2] = (cs_l << 2); - cs_l_offset[3] = cs_l + cs_l_offset[2]; - cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; - cs_l_offset[5] = cs_l + cs_l_offset[4]; - cs_l_offset[6] = (cs_l_offset[5] + cs_l); - - //read diag elems of L 16x16 block - mat_a_cols_rearr[0] = _mm256_loadu_ps((float const *)ptr_l); - mat_a_cols_rearr[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l); - mat_a_cols_rearr[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]); - mat_a_cols_rearr[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]); - mat_a_cols_rearr[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]); - mat_a_cols_rearr[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]); - mat_a_cols_rearr[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]); - mat_a_cols_rearr[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]); - - cs_b_offset[0] = (cs_b << 1); - cs_b_offset[1] = cs_b + cs_b_offset[0]; - cs_b_offset[2] = (cs_b << 2); - cs_b_offset[3] = cs_b + cs_b_offset[2]; - cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; - cs_b_offset[5] = cs_b + cs_b_offset[4]; - cs_b_offset[6] = (cs_b_offset[5] + cs_b); - - reciprocal_diags[1] = reciprocal_diags[0]; - - //pack first 8 diags together - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[1], 0xAA);//diag 0,1 - mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_cols_rearr[2], mat_a_cols_rearr[3], 0xAA);//diag 2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_cols_rearr[4], mat_a_cols_rearr[5], 0xAA);//diag 4,5 - mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_cols_rearr[6], mat_a_cols_rearr[7], 0xAA);//diag 6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 - - //reciprocal of diagnal elements 0,1,2,3,4,5,6,7 - reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); - - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - - //Broadcast A21 to A71 to registers - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); - mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); - mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); - mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); - mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); - mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); - mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); - mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); - mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); - mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); - mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); - mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); - mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); - mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); - mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); - mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); - - //Broadcast A76 to register - mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); - - //extract diag a00 from a - mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - //mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]); - //extract diag a11 from a - mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); - //extract diag a22 from a - mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); - //extract diag a33 from a - mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); - //extract diag a44 from a - mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); - //extract diag a55 from a - mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); - //extract diag a66 from a - mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); - //extract diag a77 from a - mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); - - - /***************** first set of 8 rows of B processing starts *****************/ - ptr_b_dup = ptr_b; - i = 0; - for (j = 0; j < numCols_b; j += 8) - { - /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A - //read 8x8 block of B into registers - mat_b_rearr[0][0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[1][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[2][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[3][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[4][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[5][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[6][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[7][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - mat_b_rearr[0][0] = _mm256_mul_ps(mat_b_rearr[0][0], alphaReg); - mat_b_rearr[1][0] = _mm256_mul_ps(mat_b_rearr[1][0], alphaReg); - mat_b_rearr[2][0] = _mm256_mul_ps(mat_b_rearr[2][0], alphaReg); - mat_b_rearr[3][0] = _mm256_mul_ps(mat_b_rearr[3][0], alphaReg); - mat_b_rearr[4][0] = _mm256_mul_ps(mat_b_rearr[4][0], alphaReg); - mat_b_rearr[5][0] = _mm256_mul_ps(mat_b_rearr[5][0], alphaReg); - mat_b_rearr[6][0] = _mm256_mul_ps(mat_b_rearr[6][0], alphaReg); - mat_b_rearr[7][0] = _mm256_mul_ps(mat_b_rearr[7][0], alphaReg); - - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - mat_b_col[0] = _mm256_mul_ps(mat_b_rearr[0][0], mat_a_diag_inv[0]); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[1][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[1][0]);//d = c - (a*b) - mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[2][0]);//d = c - (a*b) - mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[3][0]);//d = c - (a*b) - mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[4][0]);//d = c - (a*b) - mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[5][0]);//d = c - (a*b) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[7][0]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - mat_b_col[1] = _mm256_mul_ps(mat_b_rearr[1][0], mat_a_diag_inv[1]); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[2][0]);//d = c - (a*b) - mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[3][0]);//d = c - (a*b) - mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[4][0]);//d = c - (a*b) - mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[5][0]);//d = c - (a*b) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[7][0]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - mat_b_col[2] = _mm256_mul_ps(mat_b_rearr[2][0], mat_a_diag_inv[2]); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[2], mat_b_rearr[3][0]);//d = c - (a*b) - mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[2], mat_b_rearr[4][0]);//d = c - (a*b) - mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[2], mat_b_rearr[5][0]);//d = c - (a*b) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[7][0]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - mat_b_col[3] = _mm256_mul_ps(mat_b_rearr[3][0], mat_a_diag_inv[3]); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[3], mat_b_rearr[4][0]);//d = c - (a*b) - mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[3], mat_b_rearr[5][0]);//d = c - (a*b) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[3], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[3], mat_b_rearr[7][0]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - mat_b_col[4] = _mm256_mul_ps(mat_b_rearr[4][0], mat_a_diag_inv[4]); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[4], mat_b_rearr[5][0]);//d = c - (a*b) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[4], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[4], mat_b_rearr[7][0]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - mat_b_col[5] = _mm256_mul_ps(mat_b_rearr[5][0], mat_a_diag_inv[5]); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[5], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[5], mat_b_rearr[7][0]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - mat_b_col[6] = _mm256_mul_ps(mat_b_rearr[6][0], mat_a_diag_inv[6]); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[6], mat_b_rearr[7][0]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - mat_b_col[7] = _mm256_mul_ps(mat_b_rearr[7][0], mat_a_diag_inv[7]); - - //////////////////////////////////////////////////////////////////////////////// - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]); - - //i += cs_b_offset[6]; - //ptr_b_dup += cs_b_offset[6]; - i += 8; - ptr_b_dup += 8; - } - - //c = 0; - /***************** first set of 8 cols of B processing done *****************/ - ptr_b_dup = ptr_b; - i3 = 0; - i1 = 0; - //Start loop for cols of B to be processed in size of blk_width - for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row - { - ptr_l += 8; - //ptr_b += j; - //ptr_b_dup += 8; - ptr_b_dup += cs_b_offset[6]; - i1 += cs_b_offset[6]; - - //Read next 8x8 block of A to get diag elements - i3 += cs_l_offset[6]; - mat_a_cols_rearr[0] = _mm256_loadu_ps((float const *)ptr_l + i3); - mat_a_cols_rearr[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l); - mat_a_cols_rearr[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]); - mat_a_cols_rearr[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]); - mat_a_cols_rearr[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]); - mat_a_cols_rearr[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]); - mat_a_cols_rearr[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]); - mat_a_cols_rearr[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]); - - //pack 8 diags of A together - reciprocal_diags[0] = reciprocal_diags[1]; - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[1], 0xAA);//diag 0,1 - mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_cols_rearr[2], mat_a_cols_rearr[3], 0xAA);//diag 2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_cols_rearr[4], mat_a_cols_rearr[5], 0xAA);//diag 4,5 - mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_cols_rearr[6], mat_a_cols_rearr[7], 0xAA);//diag 6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 - - //reciprocal of diagnal elements of A :- 0,1,2,3,4,5,6,7 - reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); - - i = 0; - i2 = 0; - for (k = 0; k < numCols_b; k += 8) - { - i = i1 + k; - //Read 8 cols of B columns of Block-to-be-solved - mat_b_rearr[i2][0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[i2][1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[i2][2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[i2][3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[i2][4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[i2][5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[i2][6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[i2][7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - mat_b_rearr[i2][0] = _mm256_mul_ps(mat_b_rearr[i2][0], alphaReg); - mat_b_rearr[i2][1] = _mm256_mul_ps(mat_b_rearr[i2][1], alphaReg); - mat_b_rearr[i2][2] = _mm256_mul_ps(mat_b_rearr[i2][2], alphaReg); - mat_b_rearr[i2][3] = _mm256_mul_ps(mat_b_rearr[i2][3], alphaReg); - mat_b_rearr[i2][4] = _mm256_mul_ps(mat_b_rearr[i2][4], alphaReg); - mat_b_rearr[i2][5] = _mm256_mul_ps(mat_b_rearr[i2][5], alphaReg); - mat_b_rearr[i2][6] = _mm256_mul_ps(mat_b_rearr[i2][6], alphaReg); - mat_b_rearr[i2][7] = _mm256_mul_ps(mat_b_rearr[i2][7], alphaReg); - - i2++; - } - - i = 0; - i2 = 0; - for (l = 0; l < j; l += 8) // move across m - { - //Broadcast A8,0 to A15,0 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - - //Broadcast A21 to A71 to registers - mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i)); - mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 1)); - mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 2)); - mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 3)); - mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 4)); - mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 5)); - mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 6)); - mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 7)); - - //Broadcast A8,2 to A15,2 to registers - mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i)); - mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 1)); - mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 2)); - mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 3)); - mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 4)); - mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 5)); - mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 6)); - mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 7)); - - //Broadcast A8,3 to A15,3 to registers - mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i)); - mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 1)); - mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 2)); - mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 3)); - mat_a_blk_elems[28] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 4)); - mat_a_blk_elems[29] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 5)); - mat_a_blk_elems[30] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 6)); - mat_a_blk_elems[31] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 7)); - - // _mm256_permute2f128_ps() - - //Broadcast A8,4 to A15,4 to registers - mat_a_blk_elems[32] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i)); - mat_a_blk_elems[33] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 1)); - mat_a_blk_elems[34] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 2)); - mat_a_blk_elems[35] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 3)); - mat_a_blk_elems[36] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 4)); - mat_a_blk_elems[37] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 5)); - mat_a_blk_elems[38] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 6)); - mat_a_blk_elems[39] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 7)); - - //Broadcast A8,5 to A15,5 to registers - mat_a_blk_elems[40] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i)); - mat_a_blk_elems[41] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 1)); - mat_a_blk_elems[42] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 2)); - mat_a_blk_elems[43] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 3)); - mat_a_blk_elems[44] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 4)); - mat_a_blk_elems[45] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 5)); - mat_a_blk_elems[46] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 6)); - mat_a_blk_elems[47] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 7)); - - //Broadcast A8,6 to A15,6 to registers - mat_a_blk_elems[48] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i)); - mat_a_blk_elems[49] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 1)); - mat_a_blk_elems[50] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 2)); - mat_a_blk_elems[51] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 3)); - mat_a_blk_elems[52] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 4)); - mat_a_blk_elems[53] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 5)); - mat_a_blk_elems[54] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 6)); - mat_a_blk_elems[55] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 7)); - - //Broadcast A8,7 to A15,7 to registers - mat_a_blk_elems[56] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i)); - mat_a_blk_elems[57] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 1)); - mat_a_blk_elems[58] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 2)); - mat_a_blk_elems[59] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 3)); - mat_a_blk_elems[60] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 4)); - mat_a_blk_elems[61] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 5)); - mat_a_blk_elems[62] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 6)); - mat_a_blk_elems[63] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 7)); - - i += cs_l_offset[6]; - - - for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) - { - /////////////////// Partial Lower 8x8 block trsm of B - - i4 = i2 + k; - //Read current 8 cols of B columns from specified 8x8 current-block of B - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); - - i4 = k >> 3; - - //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[1], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[1], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[1], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[2], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[2], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[2], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[2], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[2], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[2], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[3], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[3], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[3], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[3], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[28], mat_b_col[3], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[29], mat_b_col[3], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[30], mat_b_col[3], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[31], mat_b_col[3], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[32], mat_b_col[4], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[33], mat_b_col[4], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[34], mat_b_col[4], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[35], mat_b_col[4], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[36], mat_b_col[4], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[37], mat_b_col[4], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[38], mat_b_col[4], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[39], mat_b_col[4], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[40], mat_b_col[5], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[41], mat_b_col[5], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[42], mat_b_col[5], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[43], mat_b_col[5], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[44], mat_b_col[5], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[45], mat_b_col[5], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[46], mat_b_col[5], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[47], mat_b_col[5], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[48], mat_b_col[6], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[49], mat_b_col[6], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[50], mat_b_col[6], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[51], mat_b_col[6], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[52], mat_b_col[6], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[53], mat_b_col[6], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[54], mat_b_col[6], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[55], mat_b_col[6], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[56], mat_b_col[7], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[57], mat_b_col[7], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[58], mat_b_col[7], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[59], mat_b_col[7], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[60], mat_b_col[7], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[61], mat_b_col[7], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[62], mat_b_col[7], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[63], mat_b_col[7], mat_b_rearr[i4][7]);//d = c - (a*b) - - //end loop of cols - } - i2 += cs_b_offset[6]; - } - - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - //extract diag a00 from a - mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - //mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]); - - //Broadcast A21 to A71 to registers - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - //extract diag a11 from a - mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - //extract diag a22 from a - mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - //extract diag a33 from a - mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - //extract diag a44 from a - mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - //extract diag a55 from a - mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); - - //Broadcast A76 to register - mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - //extract diag a66 from a - mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); - - //extract diag a77 from a - mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); - - k = 0; - for (i = 0; i < numCols_b; i+=8) - { - /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A - - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - mat_b_rearr[k][0] = _mm256_mul_ps(mat_b_rearr[k][0], mat_a_diag_inv[0]); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[k][1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[k][0], mat_b_rearr[k][1]);//d = c - (a*b) - mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[k][0], mat_b_rearr[k][2]);//d = c - (a*b) - mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[k][0], mat_b_rearr[k][3]);//d = c - (a*b) - mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[k][0], mat_b_rearr[k][4]);//d = c - (a*b) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[k][0], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[k][0], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[k][0], mat_b_rearr[k][7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - mat_b_rearr[k][1] = _mm256_mul_ps(mat_b_rearr[k][1], mat_a_diag_inv[1]); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_rearr[k][1], mat_b_rearr[k][2]);//d = c - (a*b) - mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_rearr[k][1], mat_b_rearr[k][3]);//d = c - (a*b) - mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_rearr[k][1], mat_b_rearr[k][4]);//d = c - (a*b) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_rearr[k][1], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_rearr[k][1], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_rearr[k][1], mat_b_rearr[k][7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - mat_b_rearr[k][2] = _mm256_mul_ps(mat_b_rearr[k][2], mat_a_diag_inv[2]); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_rearr[k][2], mat_b_rearr[k][3]);//d = c - (a*b) - mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_rearr[k][2], mat_b_rearr[k][4]);//d = c - (a*b) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_rearr[k][2], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_rearr[k][2], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_rearr[k][2], mat_b_rearr[k][7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - mat_b_rearr[k][3] = _mm256_mul_ps(mat_b_rearr[k][3], mat_a_diag_inv[3]); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_rearr[k][3], mat_b_rearr[k][4]);//d = c - (a*b) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_rearr[k][3], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_rearr[k][3], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_rearr[k][3], mat_b_rearr[k][7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - mat_b_rearr[k][4] = _mm256_mul_ps(mat_b_rearr[k][4], mat_a_diag_inv[4]); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_rearr[k][4], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_rearr[k][4], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_rearr[k][4], mat_b_rearr[k][7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - mat_b_rearr[k][5] = _mm256_mul_ps(mat_b_rearr[k][5], mat_a_diag_inv[5]); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_rearr[k][5], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_rearr[k][5], mat_b_rearr[k][7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - mat_b_rearr[k][6] = _mm256_mul_ps(mat_b_rearr[k][6], mat_a_diag_inv[6]); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_rearr[k][6], mat_b_rearr[k][7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - mat_b_rearr[k][7] = _mm256_mul_ps(mat_b_rearr[k][7], mat_a_diag_inv[7]); - - //////////////////////////////////////////////////////////////////////////////// - - //Store the computed B columns - - _mm256_storeu_ps((float *)ptr_b_dup + i, mat_b_rearr[k][0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b) + i), mat_b_rearr[k][1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i), mat_b_rearr[k][2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i), mat_b_rearr[k][3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i), mat_b_rearr[k][4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i), mat_b_rearr[k][5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i), mat_b_rearr[k][6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i), mat_b_rearr[k][7]); - k++; - } - - - } - ///////////////////loop ends ///////////////////// -} - -static void trsm_XAtB_block_allSmallSizedMatrices_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b) -{ - //float ones = 1.0; - int i, i1, i2, i3, i4, j, k, l; - int cs_b_offset[7]; - int cs_l_offset[7]; - float *ptr_b_dup; - - //57 number of ymm(256 bits) registers used - __m256 mat_b_col[8]; - __m256 mat_b_rearr[16][8]; - //__m256 mat_a_cols_rearr[8]; - __m256 mat_a_blk_elems[64]; - //__m256 mat_a_diag_inv[8]; - //__m256 reciprocal_diags[2]; - - // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // - - //L matrix offsets - cs_l_offset[0] = (cs_l << 1); - cs_l_offset[1] = cs_l + cs_l_offset[0]; - cs_l_offset[2] = (cs_l << 2); - cs_l_offset[3] = cs_l + cs_l_offset[2]; - cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; - cs_l_offset[5] = cs_l + cs_l_offset[4]; - cs_l_offset[6] = (cs_l_offset[5] + cs_l); - - cs_b_offset[0] = (cs_b << 1); - cs_b_offset[1] = cs_b + cs_b_offset[0]; - cs_b_offset[2] = (cs_b << 2); - cs_b_offset[3] = cs_b + cs_b_offset[2]; - cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; - cs_b_offset[5] = cs_b + cs_b_offset[4]; - cs_b_offset[6] = (cs_b_offset[5] + cs_b); - - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - - //Broadcast A21 to A71 to registers - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); - mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); - mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); - mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); - mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); - mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); - mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); - mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); - mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); - mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); - mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); - mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); - mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); - mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); - mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); - mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); - - //Broadcast A76 to register - mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); - - - /***************** first set of 8 rows of B processing starts *****************/ - ptr_b_dup = ptr_b; - i = 0; - for (j = 0; j < numCols_b; j += 8) - { - /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A - //read 8x8 block of B into registers - mat_b_rearr[0][0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[1][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[2][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[3][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[4][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[5][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[6][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[7][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - //(Row0) - mat_b_col[0] = mat_b_rearr[0][0]; - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[1][0]);//d = c - (a*b) - mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[2][0]);//d = c - (a*b) - mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[3][0]);//d = c - (a*b) - mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[4][0]);//d = c - (a*b) - mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[5][0]);//d = c - (a*b) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[7][0]);//d = c - (a*b) - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[2][0]);//d = c - (a*b) - mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[3][0]);//d = c - (a*b) - mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[4][0]);//d = c - (a*b) - mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[5][0]);//d = c - (a*b) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[7][0]);//d = c - (a*b) - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[2], mat_b_rearr[3][0]);//d = c - (a*b) - mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[2], mat_b_rearr[4][0]);//d = c - (a*b) - mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[2], mat_b_rearr[5][0]);//d = c - (a*b) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[7][0]);//d = c - (a*b) - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[3], mat_b_rearr[4][0]);//d = c - (a*b) - mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[3], mat_b_rearr[5][0]);//d = c - (a*b) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[3], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[3], mat_b_rearr[7][0]);//d = c - (a*b) - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[4], mat_b_rearr[5][0]);//d = c - (a*b) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[4], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[4], mat_b_rearr[7][0]);//d = c - (a*b) - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[5], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[5], mat_b_rearr[7][0]);//d = c - (a*b) - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[6], mat_b_rearr[7][0]);//d = c - (a*b) - - //////////////////////////////////////////////////////////////////////////////// - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]); - - //i += cs_b_offset[6]; - //ptr_b_dup += cs_b_offset[6]; - i += 8; - ptr_b_dup += 8; - } - - //c = 0; - /***************** first set of 8 cols of B processing done *****************/ - ptr_b_dup = ptr_b; - i3 = 0; - i1 = 0; - //Start loop for cols of B to be processed in size of blk_width - for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row - { - ptr_l += 8; - //ptr_b += j; - //ptr_b_dup += 8; - ptr_b_dup += cs_b_offset[6]; - i1 += cs_b_offset[6]; - i3 += cs_l_offset[6]; - - i = 0; - i2 = 0; - for (k = 0; k < numCols_b; k += 8) - { - i = i1 + k; - //Read 8 cols of B columns of Block-to-be-solved - mat_b_rearr[i2][0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[i2][1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[i2][2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[i2][3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[i2][4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[i2][5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[i2][6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[i2][7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - i2++; - } - - i = 0; - i2 = 0; - for (l = 0; l < j; l += 8) // move across m - { - //Broadcast A8,0 to A15,0 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - - //Broadcast A21 to A71 to registers - mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i)); - mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 1)); - mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 2)); - mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 3)); - mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 4)); - mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 5)); - mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 6)); - mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 7)); - - //Broadcast A8,2 to A15,2 to registers - mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i)); - mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 1)); - mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 2)); - mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 3)); - mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 4)); - mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 5)); - mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 6)); - mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 7)); - - //Broadcast A8,3 to A15,3 to registers - mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i)); - mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 1)); - mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 2)); - mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 3)); - mat_a_blk_elems[28] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 4)); - mat_a_blk_elems[29] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 5)); - mat_a_blk_elems[30] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 6)); - mat_a_blk_elems[31] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 7)); - - // _mm256_permute2f128_ps() - - //Broadcast A8,4 to A15,4 to registers - mat_a_blk_elems[32] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i)); - mat_a_blk_elems[33] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 1)); - mat_a_blk_elems[34] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 2)); - mat_a_blk_elems[35] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 3)); - mat_a_blk_elems[36] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 4)); - mat_a_blk_elems[37] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 5)); - mat_a_blk_elems[38] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 6)); - mat_a_blk_elems[39] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 7)); - - //Broadcast A8,5 to A15,5 to registers - mat_a_blk_elems[40] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i)); - mat_a_blk_elems[41] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 1)); - mat_a_blk_elems[42] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 2)); - mat_a_blk_elems[43] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 3)); - mat_a_blk_elems[44] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 4)); - mat_a_blk_elems[45] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 5)); - mat_a_blk_elems[46] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 6)); - mat_a_blk_elems[47] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 7)); - - //Broadcast A8,6 to A15,6 to registers - mat_a_blk_elems[48] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i)); - mat_a_blk_elems[49] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 1)); - mat_a_blk_elems[50] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 2)); - mat_a_blk_elems[51] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 3)); - mat_a_blk_elems[52] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 4)); - mat_a_blk_elems[53] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 5)); - mat_a_blk_elems[54] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 6)); - mat_a_blk_elems[55] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 7)); - - //Broadcast A8,7 to A15,7 to registers - mat_a_blk_elems[56] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i)); - mat_a_blk_elems[57] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 1)); - mat_a_blk_elems[58] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 2)); - mat_a_blk_elems[59] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 3)); - mat_a_blk_elems[60] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 4)); - mat_a_blk_elems[61] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 5)); - mat_a_blk_elems[62] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 6)); - mat_a_blk_elems[63] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 7)); - - i += cs_l_offset[6]; - - for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) - { - /////////////////// Partial Lower 8x8 block trsm of B - - i4 = i2 + k; - //Read current 8 cols of B columns from specified 8x8 current-block of B - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); - - i4 = k >> 3; - - //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[1], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[1], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[1], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[2], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[2], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[2], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[2], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[2], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[2], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[3], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[3], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[3], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[3], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[28], mat_b_col[3], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[29], mat_b_col[3], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[30], mat_b_col[3], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[31], mat_b_col[3], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[32], mat_b_col[4], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[33], mat_b_col[4], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[34], mat_b_col[4], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[35], mat_b_col[4], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[36], mat_b_col[4], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[37], mat_b_col[4], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[38], mat_b_col[4], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[39], mat_b_col[4], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[40], mat_b_col[5], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[41], mat_b_col[5], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[42], mat_b_col[5], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[43], mat_b_col[5], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[44], mat_b_col[5], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[45], mat_b_col[5], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[46], mat_b_col[5], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[47], mat_b_col[5], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[48], mat_b_col[6], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[49], mat_b_col[6], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[50], mat_b_col[6], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[51], mat_b_col[6], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[52], mat_b_col[6], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[53], mat_b_col[6], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[54], mat_b_col[6], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[55], mat_b_col[6], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[56], mat_b_col[7], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[57], mat_b_col[7], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[58], mat_b_col[7], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[59], mat_b_col[7], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[60], mat_b_col[7], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[61], mat_b_col[7], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[62], mat_b_col[7], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[63], mat_b_col[7], mat_b_rearr[i4][7]);//d = c - (a*b) - - //end loop of cols - } - i2 += cs_b_offset[6]; - } - - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Broadcast A21 to A71 to registers - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Broadcast A76 to register - mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - - k = 0; - for (i = 0; i < numCols_b; i+=8) - { - /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A - - //(Row0): already done - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[k][1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[k][0], mat_b_rearr[k][1]);//d = c - (a*b) - mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[k][0], mat_b_rearr[k][2]);//d = c - (a*b) - mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[k][0], mat_b_rearr[k][3]);//d = c - (a*b) - mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[k][0], mat_b_rearr[k][4]);//d = c - (a*b) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[k][0], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[k][0], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[k][0], mat_b_rearr[k][7]);//d = c - (a*b) - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_rearr[k][1], mat_b_rearr[k][2]);//d = c - (a*b) - mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_rearr[k][1], mat_b_rearr[k][3]);//d = c - (a*b) - mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_rearr[k][1], mat_b_rearr[k][4]);//d = c - (a*b) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_rearr[k][1], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_rearr[k][1], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_rearr[k][1], mat_b_rearr[k][7]);//d = c - (a*b) - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_rearr[k][2], mat_b_rearr[k][3]);//d = c - (a*b) - mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_rearr[k][2], mat_b_rearr[k][4]);//d = c - (a*b) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_rearr[k][2], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_rearr[k][2], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_rearr[k][2], mat_b_rearr[k][7]);//d = c - (a*b) - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_rearr[k][3], mat_b_rearr[k][4]);//d = c - (a*b) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_rearr[k][3], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_rearr[k][3], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_rearr[k][3], mat_b_rearr[k][7]);//d = c - (a*b) - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_rearr[k][4], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_rearr[k][4], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_rearr[k][4], mat_b_rearr[k][7]);//d = c - (a*b) - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_rearr[k][5], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_rearr[k][5], mat_b_rearr[k][7]);//d = c - (a*b) - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_rearr[k][6], mat_b_rearr[k][7]);//d = c - (a*b) - - //////////////////////////////////////////////////////////////////////////////// - - //Store the computed B columns - - _mm256_storeu_ps((float *)ptr_b_dup + i, mat_b_rearr[k][0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b) + i), mat_b_rearr[k][1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i), mat_b_rearr[k][2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i), mat_b_rearr[k][3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i), mat_b_rearr[k][4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i), mat_b_rearr[k][5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i), mat_b_rearr[k][6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i), mat_b_rearr[k][7]); - //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); - k++; - } - - - } - ///////////////////loop ends ///////////////////// -} - -static void trsm_XAtB_block_allSmallSizedMatrices_alpha_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha) -{ - //float ones = 1.0; - int i, i1, i2, i3, i4, j, k, l; - int cs_b_offset[7]; - int cs_l_offset[7]; - float *ptr_b_dup; - - //57 number of ymm(256 bits) registers used - __m256 mat_b_col[8]; - __m256 mat_b_rearr[16][8]; - //__m256 mat_a_cols_rearr[8]; - __m256 mat_a_blk_elems[64]; - //__m256 mat_a_diag_inv[8]; - //__m256 reciprocal_diags[2]; - __m256 alphaReg; - alphaReg = _mm256_broadcast_ss((float const *)&alpha); - - // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // - - //L matrix offsets - cs_l_offset[0] = (cs_l << 1); - cs_l_offset[1] = cs_l + cs_l_offset[0]; - cs_l_offset[2] = (cs_l << 2); - cs_l_offset[3] = cs_l + cs_l_offset[2]; - cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; - cs_l_offset[5] = cs_l + cs_l_offset[4]; - cs_l_offset[6] = (cs_l_offset[5] + cs_l); - - cs_b_offset[0] = (cs_b << 1); - cs_b_offset[1] = cs_b + cs_b_offset[0]; - cs_b_offset[2] = (cs_b << 2); - cs_b_offset[3] = cs_b + cs_b_offset[2]; - cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; - cs_b_offset[5] = cs_b + cs_b_offset[4]; - cs_b_offset[6] = (cs_b_offset[5] + cs_b); - - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - - //Broadcast A21 to A71 to registers - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); - mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); - mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); - mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); - mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); - mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); - mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); - mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); - mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); - mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); - mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); - mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); - mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); - mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); - mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); - mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); - - //Broadcast A76 to register - mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); - - - /***************** first set of 8 rows of B processing starts *****************/ - ptr_b_dup = ptr_b; - i = 0; - for (j = 0; j < numCols_b; j += 8) - { - /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A - //read 8x8 block of B into registers - mat_b_rearr[0][0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[1][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[2][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[3][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[4][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[5][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[6][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[7][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - mat_b_rearr[0][0] = _mm256_mul_ps(mat_b_rearr[0][0], alphaReg); - mat_b_rearr[1][0] = _mm256_mul_ps(mat_b_rearr[1][0], alphaReg); - mat_b_rearr[2][0] = _mm256_mul_ps(mat_b_rearr[2][0], alphaReg); - mat_b_rearr[3][0] = _mm256_mul_ps(mat_b_rearr[3][0], alphaReg); - mat_b_rearr[4][0] = _mm256_mul_ps(mat_b_rearr[4][0], alphaReg); - mat_b_rearr[5][0] = _mm256_mul_ps(mat_b_rearr[5][0], alphaReg); - mat_b_rearr[6][0] = _mm256_mul_ps(mat_b_rearr[6][0], alphaReg); - mat_b_rearr[7][0] = _mm256_mul_ps(mat_b_rearr[7][0], alphaReg); - - //(Row0) - mat_b_col[0] = mat_b_rearr[0][0]; - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[1][0]);//d = c - (a*b) - mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[2][0]);//d = c - (a*b) - mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[3][0]);//d = c - (a*b) - mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[4][0]);//d = c - (a*b) - mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[5][0]);//d = c - (a*b) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[7][0]);//d = c - (a*b) - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[2][0]);//d = c - (a*b) - mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[3][0]);//d = c - (a*b) - mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[4][0]);//d = c - (a*b) - mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[5][0]);//d = c - (a*b) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[7][0]);//d = c - (a*b) - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[2], mat_b_rearr[3][0]);//d = c - (a*b) - mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[2], mat_b_rearr[4][0]);//d = c - (a*b) - mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[2], mat_b_rearr[5][0]);//d = c - (a*b) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[7][0]);//d = c - (a*b) - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[3], mat_b_rearr[4][0]);//d = c - (a*b) - mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[3], mat_b_rearr[5][0]);//d = c - (a*b) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[3], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[3], mat_b_rearr[7][0]);//d = c - (a*b) - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[4], mat_b_rearr[5][0]);//d = c - (a*b) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[4], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[4], mat_b_rearr[7][0]);//d = c - (a*b) - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[5], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[5], mat_b_rearr[7][0]);//d = c - (a*b) - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[6], mat_b_rearr[7][0]);//d = c - (a*b) - - //////////////////////////////////////////////////////////////////////////////// - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]); - - //i += cs_b_offset[6]; - //ptr_b_dup += cs_b_offset[6]; - i += 8; - ptr_b_dup += 8; - } - - //c = 0; - /***************** first set of 8 cols of B processing done *****************/ - ptr_b_dup = ptr_b; - i3 = 0; - i1 = 0; - //Start loop for cols of B to be processed in size of blk_width - for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row - { - ptr_l += 8; - //ptr_b += j; - //ptr_b_dup += 8; - ptr_b_dup += cs_b_offset[6]; - i1 += cs_b_offset[6]; - i3 += cs_l_offset[6]; - - i = 0; - i2 = 0; - for (k = 0; k < numCols_b; k += 8) - { - i = i1 + k; - //Read 8 cols of B columns of Block-to-be-solved - mat_b_rearr[i2][0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[i2][1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[i2][2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[i2][3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[i2][4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[i2][5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[i2][6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[i2][7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - mat_b_rearr[i2][0] = _mm256_mul_ps(mat_b_rearr[i2][0], alphaReg); - mat_b_rearr[i2][1] = _mm256_mul_ps(mat_b_rearr[i2][1], alphaReg); - mat_b_rearr[i2][2] = _mm256_mul_ps(mat_b_rearr[i2][2], alphaReg); - mat_b_rearr[i2][3] = _mm256_mul_ps(mat_b_rearr[i2][3], alphaReg); - mat_b_rearr[i2][4] = _mm256_mul_ps(mat_b_rearr[i2][4], alphaReg); - mat_b_rearr[i2][5] = _mm256_mul_ps(mat_b_rearr[i2][5], alphaReg); - mat_b_rearr[i2][6] = _mm256_mul_ps(mat_b_rearr[i2][6], alphaReg); - mat_b_rearr[i2][7] = _mm256_mul_ps(mat_b_rearr[i2][7], alphaReg); - - i2++; - } - - i = 0; - i2 = 0; - for (l = 0; l < j; l += 8) // move across m - { - //Broadcast A8,0 to A15,0 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - - //Broadcast A21 to A71 to registers - mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i)); - mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 1)); - mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 2)); - mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 3)); - mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 4)); - mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 5)); - mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 6)); - mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 7)); - - //Broadcast A8,2 to A15,2 to registers - mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i)); - mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 1)); - mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 2)); - mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 3)); - mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 4)); - mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 5)); - mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 6)); - mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 7)); - - //Broadcast A8,3 to A15,3 to registers - mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i)); - mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 1)); - mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 2)); - mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 3)); - mat_a_blk_elems[28] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 4)); - mat_a_blk_elems[29] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 5)); - mat_a_blk_elems[30] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 6)); - mat_a_blk_elems[31] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 7)); - - // _mm256_permute2f128_ps() - - //Broadcast A8,4 to A15,4 to registers - mat_a_blk_elems[32] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i)); - mat_a_blk_elems[33] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 1)); - mat_a_blk_elems[34] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 2)); - mat_a_blk_elems[35] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 3)); - mat_a_blk_elems[36] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 4)); - mat_a_blk_elems[37] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 5)); - mat_a_blk_elems[38] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 6)); - mat_a_blk_elems[39] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 7)); - - //Broadcast A8,5 to A15,5 to registers - mat_a_blk_elems[40] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i)); - mat_a_blk_elems[41] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 1)); - mat_a_blk_elems[42] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 2)); - mat_a_blk_elems[43] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 3)); - mat_a_blk_elems[44] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 4)); - mat_a_blk_elems[45] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 5)); - mat_a_blk_elems[46] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 6)); - mat_a_blk_elems[47] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 7)); - - //Broadcast A8,6 to A15,6 to registers - mat_a_blk_elems[48] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i)); - mat_a_blk_elems[49] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 1)); - mat_a_blk_elems[50] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 2)); - mat_a_blk_elems[51] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 3)); - mat_a_blk_elems[52] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 4)); - mat_a_blk_elems[53] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 5)); - mat_a_blk_elems[54] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 6)); - mat_a_blk_elems[55] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 7)); - - //Broadcast A8,7 to A15,7 to registers - mat_a_blk_elems[56] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i)); - mat_a_blk_elems[57] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 1)); - mat_a_blk_elems[58] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 2)); - mat_a_blk_elems[59] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 3)); - mat_a_blk_elems[60] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 4)); - mat_a_blk_elems[61] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 5)); - mat_a_blk_elems[62] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 6)); - mat_a_blk_elems[63] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 7)); - - i += cs_l_offset[6]; - - for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) - { - /////////////////// Partial Lower 8x8 block trsm of B - - i4 = i2 + k; - //Read current 8 cols of B columns from specified 8x8 current-block of B - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); - - i4 = k >> 3; - - //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[1], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[1], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[1], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[2], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[2], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[2], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[2], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[2], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[2], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[3], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[3], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[3], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[3], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[28], mat_b_col[3], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[29], mat_b_col[3], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[30], mat_b_col[3], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[31], mat_b_col[3], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[32], mat_b_col[4], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[33], mat_b_col[4], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[34], mat_b_col[4], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[35], mat_b_col[4], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[36], mat_b_col[4], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[37], mat_b_col[4], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[38], mat_b_col[4], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[39], mat_b_col[4], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[40], mat_b_col[5], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[41], mat_b_col[5], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[42], mat_b_col[5], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[43], mat_b_col[5], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[44], mat_b_col[5], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[45], mat_b_col[5], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[46], mat_b_col[5], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[47], mat_b_col[5], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[48], mat_b_col[6], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[49], mat_b_col[6], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[50], mat_b_col[6], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[51], mat_b_col[6], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[52], mat_b_col[6], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[53], mat_b_col[6], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[54], mat_b_col[6], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[55], mat_b_col[6], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[56], mat_b_col[7], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[57], mat_b_col[7], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[58], mat_b_col[7], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[59], mat_b_col[7], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[60], mat_b_col[7], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[61], mat_b_col[7], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[62], mat_b_col[7], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[63], mat_b_col[7], mat_b_rearr[i4][7]);//d = c - (a*b) - - //end loop of cols - } - i2 += cs_b_offset[6]; - } - - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Broadcast A21 to A71 to registers - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Broadcast A76 to register - mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - - k = 0; - for (i = 0; i < numCols_b; i+=8) - { - /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A - - //(Row0): already done - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[k][1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[k][0], mat_b_rearr[k][1]);//d = c - (a*b) - mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[k][0], mat_b_rearr[k][2]);//d = c - (a*b) - mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[k][0], mat_b_rearr[k][3]);//d = c - (a*b) - mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[k][0], mat_b_rearr[k][4]);//d = c - (a*b) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[k][0], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[k][0], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[k][0], mat_b_rearr[k][7]);//d = c - (a*b) - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_rearr[k][1], mat_b_rearr[k][2]);//d = c - (a*b) - mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_rearr[k][1], mat_b_rearr[k][3]);//d = c - (a*b) - mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_rearr[k][1], mat_b_rearr[k][4]);//d = c - (a*b) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_rearr[k][1], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_rearr[k][1], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_rearr[k][1], mat_b_rearr[k][7]);//d = c - (a*b) - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_rearr[k][2], mat_b_rearr[k][3]);//d = c - (a*b) - mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_rearr[k][2], mat_b_rearr[k][4]);//d = c - (a*b) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_rearr[k][2], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_rearr[k][2], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_rearr[k][2], mat_b_rearr[k][7]);//d = c - (a*b) - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_rearr[k][3], mat_b_rearr[k][4]);//d = c - (a*b) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_rearr[k][3], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_rearr[k][3], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_rearr[k][3], mat_b_rearr[k][7]);//d = c - (a*b) - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_rearr[k][4], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_rearr[k][4], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_rearr[k][4], mat_b_rearr[k][7]);//d = c - (a*b) - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_rearr[k][5], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_rearr[k][5], mat_b_rearr[k][7]);//d = c - (a*b) - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_rearr[k][6], mat_b_rearr[k][7]);//d = c - (a*b) - - //////////////////////////////////////////////////////////////////////////////// - - //Store the computed B columns - - _mm256_storeu_ps((float *)ptr_b_dup + i, mat_b_rearr[k][0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b) + i), mat_b_rearr[k][1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i), mat_b_rearr[k][2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i), mat_b_rearr[k][3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i), mat_b_rearr[k][4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i), mat_b_rearr[k][5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i), mat_b_rearr[k][6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i), mat_b_rearr[k][7]); - //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); - k++; - } - - - } - ///////////////////loop ends ///////////////////// -} -#endif //OPT_CACHE_BLOCKING_L1 - -//////////////////////////// AutX=B /////////////////////// -static void trsm_AutXB_block_allSmallSizedMatrices(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b) -{ - float ones = 1.0; - int i, i1, i2, i3, i4, j, k, l, r; - int cs_b_offset[7]; - int cs_l_offset[7]; - float *ptr_b_dup, *ptr_l_dup; - - //57 number of ymm(256 bits) registers used - __m256 mat_b_col[8]; - __m256 mat_b_rearr[8]; - __m256 mat_a_blk_elems[8]; - __m256 mat_a_diag_inv[8]; - __m256 reciprocal_diags[2]; - - reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones)); - - // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // - - //L matrix offsets - cs_l_offset[0] = (cs_l << 1); - cs_l_offset[1] = cs_l + cs_l_offset[0]; - cs_l_offset[2] = (cs_l << 2); - cs_l_offset[3] = cs_l + cs_l_offset[2]; - cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; - cs_l_offset[5] = cs_l + cs_l_offset[4]; - cs_l_offset[6] = (cs_l_offset[5] + cs_l); - - //read diag elems of L 16x16 block - mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l); - mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l); - mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]); - mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]); - mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]); - mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]); - mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]); - mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]); - - cs_b_offset[0] = (cs_b << 1); - cs_b_offset[1] = cs_b + cs_b_offset[0]; - cs_b_offset[2] = (cs_b << 2); - cs_b_offset[3] = cs_b + cs_b_offset[2]; - cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; - cs_b_offset[5] = cs_b + cs_b_offset[4]; - cs_b_offset[6] = (cs_b_offset[5] + cs_b); - - reciprocal_diags[1] = reciprocal_diags[0]; - - //pack first 8 diags together - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1 - mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5 - mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 - - //reciprocal of diagnal elements 0,1,2,3,4,5,6,7 - reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); -#if 0 - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - - //Broadcast A21 to A71 to registers - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); - mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); - mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); - mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); - mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); - mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); - mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); - mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); - mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); - mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); - mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); - mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); - mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); - mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); - mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); - mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); - - //Broadcast A76 to register - mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); -#endif - //extract diag a00 from a - mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - //mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]); - //extract diag a11 from a - mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); - //extract diag a22 from a - mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); - //extract diag a33 from a - mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); - //extract diag a44 from a - mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); - //extract diag a55 from a - mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); - //extract diag a66 from a - mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); - //extract diag a77 from a - mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); - - - /***************** first set of 8 rows of B processing starts *****************/ - ptr_b_dup = ptr_b; - i = 0; - for (j = 0; j < numCols_b; j += 8) - { - /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A - //read 8x8 block of B into registers - mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - /* transpose steps start */ - ////unpacklow//// - mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - /* transpose steps end */ - - - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], mat_a_diag_inv[0]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5])); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], mat_a_diag_inv[1]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[0])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[1])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[2])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[3])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[4])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[5])); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], mat_a_diag_inv[2]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[1])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[2])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[3])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[4])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[5])); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], mat_a_diag_inv[3]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[2])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[3])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[4])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[5])); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], mat_a_diag_inv[4]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[3])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[4])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[5])); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], mat_a_diag_inv[5]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[4])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[5])); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], mat_a_diag_inv[6]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 6 + cs_l_offset[5])); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], mat_a_diag_inv[7]); - - //////////////////////////////////////////////////////////////////////////////// - - /* transpose steps start */ - ////unpacklow//// - mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); - mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); - mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); - mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - ////unpackhigh//// - mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); - mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); - mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); - mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - /* transpose steps end */ - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_b_rearr[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_rearr[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_rearr[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_rearr[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_rearr[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_rearr[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_rearr[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_rearr[7]); - - i += cs_b_offset[6]; - ptr_b_dup += cs_b_offset[6]; - //i += 8; - //ptr_b_dup += 8; - } - - //c = 0; - /***************** first set of 8 cols of B processing done *****************/ - ptr_b_dup = ptr_b; - i3 = 0; - i1 = 0; - //Start loop for cols of B to be processed in size of blk_width - for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row - { - ptr_l += cs_l_offset[6]; - - //Read next 8x8 block of A to get diag elements - i3 += 8; - mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l + i3); - mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l); - mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]); - mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]); - mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]); - mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]); - mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]); - mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]); - - //pack 8 diags of A together - reciprocal_diags[0] = reciprocal_diags[1]; - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1 - mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5 - mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 - - //reciprocal of diagnal elements of A :- 0,1,2,3,4,5,6,7 - reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); - - //ptr_b += j; - //ptr_b_dup += 8; - ptr_b_dup += 8; - i1 += 8; - i = i1; - i2 = 0; - - //extract diag a00 from a - mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - //mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]); - - //extract diag a11 from a - mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); - - //extract diag a22 from a - mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); - - //extract diag a33 from a - mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); - - //extract diag a44 from a - mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); - - //extract diag a55 from a - mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); - - //extract diag a66 from a - mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); - - //extract diag a77 from a - mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); - - for (r = 0; r < numCols_b; r += GEMM_BLK_V1) - { -#if GEMM_ACCUM_A - //Read 8 cols of B columns of Block-to-be-solved - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - /* transpose steps start */ - ////unpacklow//// - mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); - mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); - mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); - mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - ////unpackhigh//// - mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); - mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); - mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); - mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - /* transpose steps end */ -#endif - //i = 0; - ptr_l_dup = ptr_l; - i4 = i2; - for (l = 0; l < j; l += 8) // move across m - { - //for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) - //{ - /////////////////// Partial Lower 8x8 block trsm of B - //Read current 8 cols of B columns from specified 8x8 current-block of B - mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_b + i4); - mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); - mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); - mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); - mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); - mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); - mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); - mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); - - /* transpose steps start */ - ////unpacklow//// - mat_b_col[0] = _mm256_unpacklo_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]); - mat_b_col[1] = _mm256_unpacklo_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]); - mat_b_col[2] = _mm256_unpacklo_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]); - mat_b_col[3] = _mm256_unpacklo_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - ////unpackhigh//// - mat_a_blk_elems[0] = _mm256_unpackhi_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]); - mat_a_blk_elems[1] = _mm256_unpackhi_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]); - mat_a_blk_elems[2] = _mm256_unpackhi_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]); - mat_a_blk_elems[3] = _mm256_unpackhi_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_a_blk_elems[4] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x44); - mat_a_blk_elems[5] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xEE); - mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x44); - mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xEE); -#else - mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x4E); - mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x4E); - mat_a_blk_elems[4] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[6], 0xCC); - mat_a_blk_elems[5] = _mm256_blend_ps(mat_a_blk_elems[1], mat_a_blk_elems[6], 0x33); - mat_a_blk_elems[6] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[7], 0xCC); - mat_a_blk_elems[7] = _mm256_blend_ps(mat_a_blk_elems[3], mat_a_blk_elems[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_col[2] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x20); - mat_b_col[6] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x31); - mat_b_col[3] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x20); - mat_b_col[7] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x31); - /* transpose steps end */ - - //Broadcast A8,0 to A15,0 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - //i4 = k >> 3; - ptr_l_dup++; - -#if GEMM_ACCUM_A - //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]); - mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]); - mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]); - mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]); - mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]); - mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]); - mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]); - mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]); -#endif - //Broadcast A21 to A71 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,2 to A15,2 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,3 to A15,3 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,4 to A15,4 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,5 to A15,5 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,6 to A15,6 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,7 to A15,7 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) -#endif - //end loop of cols - //} - //i2 += cs_b_offset[6]; - i4 += 8; - } - //trsm solve - - k = 0; - //for (i2 = 0; i2 < numCols_b; i2 += 8) - //{ - //i2 = i1 + r; - /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A -#if !GEMM_ACCUM_A - //Read 8 cols of B columns of Block-to-be-solved - mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - /* transpose steps start */ - ////unpacklow//// - mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - /* transpose steps end */ -#endif - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - //i += cs_l; - -#if GEMM_ACCUM_A - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); -#else - mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]); - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); -#endif - -#if GEMM_ACCUM_A - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]); - mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]); - mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]); - mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]); - mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]); - mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]); - mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A21 to A71 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[0])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[1])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[2])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[3])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[4])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[5])); - //i += cs_l; - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[1])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[2])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[3])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[4])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[5])); - //i += cs_l; - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[2])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[3])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[4])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[5])); - //i += cs_l; - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[3])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[4])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[5])); - //i += cs_l; - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[4])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[5])); - //i += cs_l; - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A76 to register - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 6 + cs_l_offset[5])); - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); - - //////////////////////////////////////////////////////////////////////////////// - - /* transpose steps start */ - ////unpacklow//// - mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - /* transpose steps end */ - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup + i2, mat_b_col[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+i2), mat_b_col[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i2), mat_b_col[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i2), mat_b_col[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i2), mat_b_col[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i2), mat_b_col[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i2), mat_b_col[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i2), mat_b_col[7]); - //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); - k++; - //} - i += cs_b_offset[6]; - i2 += cs_b_offset[6]; - } - } //numRows of A - ///////////////////loop ends ///////////////////// -} - -static void trsm_AutXB_block_allSmallSizedMatrices_alpha(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha) -{ - float ones = 1.0; - int i, i1, i2, i3, i4, j, k, l, r; - int cs_b_offset[7]; - int cs_l_offset[7]; - float *ptr_b_dup, *ptr_l_dup; - - //57 number of ymm(256 bits) registers used - __m256 mat_b_col[8]; - __m256 mat_b_rearr[8]; - __m256 mat_a_blk_elems[8]; - __m256 mat_a_diag_inv[8]; - __m256 reciprocal_diags[2]; - __m256 alphaReg; - - reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones)); - alphaReg = _mm256_broadcast_ss((float const *)&alpha); - - // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // - - //L matrix offsets - cs_l_offset[0] = (cs_l << 1); - cs_l_offset[1] = cs_l + cs_l_offset[0]; - cs_l_offset[2] = (cs_l << 2); - cs_l_offset[3] = cs_l + cs_l_offset[2]; - cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; - cs_l_offset[5] = cs_l + cs_l_offset[4]; - cs_l_offset[6] = (cs_l_offset[5] + cs_l); - - //read diag elems of L 16x16 block - mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l); - mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l); - mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]); - mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]); - mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]); - mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]); - mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]); - mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]); - - cs_b_offset[0] = (cs_b << 1); - cs_b_offset[1] = cs_b + cs_b_offset[0]; - cs_b_offset[2] = (cs_b << 2); - cs_b_offset[3] = cs_b + cs_b_offset[2]; - cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; - cs_b_offset[5] = cs_b + cs_b_offset[4]; - cs_b_offset[6] = (cs_b_offset[5] + cs_b); - - reciprocal_diags[1] = reciprocal_diags[0]; - - //pack first 8 diags together - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1 - mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5 - mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 - - //reciprocal of diagnal elements 0,1,2,3,4,5,6,7 - reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); -#if 0 - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - - //Broadcast A21 to A71 to registers - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); - mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); - mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); - mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); - mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); - mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); - mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); - mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); - mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); - mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); - mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); - mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); - mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); - mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); - mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); - mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); - - //Broadcast A76 to register - mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); -#endif - //extract diag a00 from a - mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - //mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]); - //extract diag a11 from a - mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); - //extract diag a22 from a - mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); - //extract diag a33 from a - mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); - //extract diag a44 from a - mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); - //extract diag a55 from a - mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); - //extract diag a66 from a - mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); - //extract diag a77 from a - mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); - - - /***************** first set of 8 rows of B processing starts *****************/ - ptr_b_dup = ptr_b; - i = 0; - for (j = 0; j < numCols_b; j += 8) - { - /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A - //read 8x8 block of B into registers - mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - /* transpose steps start */ - ////unpacklow//// - mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - /* transpose steps end */ - - mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg); - mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg); - mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg); - mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg); - mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg); - mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg); - mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg); - mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg); - - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], mat_a_diag_inv[0]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5])); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], mat_a_diag_inv[1]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[0])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[1])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[2])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[3])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[4])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[5])); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], mat_a_diag_inv[2]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[1])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[2])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[3])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[4])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[5])); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], mat_a_diag_inv[3]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[2])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[3])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[4])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[5])); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], mat_a_diag_inv[4]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[3])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[4])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[5])); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], mat_a_diag_inv[5]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[4])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[5])); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], mat_a_diag_inv[6]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 6 + cs_l_offset[5])); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], mat_a_diag_inv[7]); - - //////////////////////////////////////////////////////////////////////////////// - - /* transpose steps start */ - ////unpacklow//// - mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); - mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); - mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); - mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - ////unpackhigh//// - mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); - mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); - mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); - mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - /* transpose steps end */ - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_b_rearr[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_rearr[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_rearr[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_rearr[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_rearr[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_rearr[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_rearr[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_rearr[7]); - - i += cs_b_offset[6]; - ptr_b_dup += cs_b_offset[6]; - //i += 8; - //ptr_b_dup += 8; - } - - //c = 0; - /***************** first set of 8 cols of B processing done *****************/ - ptr_b_dup = ptr_b; - i3 = 0; - i1 = 0; - //Start loop for cols of B to be processed in size of blk_width - for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row - { - ptr_l += cs_l_offset[6]; - - //Read next 8x8 block of A to get diag elements - i3 += 8; - mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l + i3); - mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l); - mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]); - mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]); - mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]); - mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]); - mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]); - mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]); - - //pack 8 diags of A together - reciprocal_diags[0] = reciprocal_diags[1]; - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1 - mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5 - mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 - - //reciprocal of diagnal elements of A :- 0,1,2,3,4,5,6,7 - reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); - - //ptr_b += j; - //ptr_b_dup += 8; - ptr_b_dup += 8; - i1 += 8; - i = i1; - i2 = 0; - - //extract diag a00 from a - mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - //mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]); - - //extract diag a11 from a - mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); - - //extract diag a22 from a - mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); - - //extract diag a33 from a - mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); - - //extract diag a44 from a - mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); - - //extract diag a55 from a - mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); - - //extract diag a66 from a - mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); - - //extract diag a77 from a - mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); - - for (r = 0; r < numCols_b; r += GEMM_BLK_V1) - { -#if GEMM_ACCUM_A - //Read 8 cols of B columns of Block-to-be-solved - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - /* transpose steps start */ - ////unpacklow//// - mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); - mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); - mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); - mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - ////unpackhigh//// - mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); - mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); - mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); - mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - /* transpose steps end */ - - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg); - mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg); - mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg); - mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg); - mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg); - mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg); - mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg); - mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg); -#endif - - //i = 0; - ptr_l_dup = ptr_l; - i4 = i2; - for (l = 0; l < j; l += 8) // move across m - { - //for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) - //{ - /////////////////// Partial Lower 8x8 block trsm of B - //Read current 8 cols of B columns from specified 8x8 current-block of B - mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_b + i4); - mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); - mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); - mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); - mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); - mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); - mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); - mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); - - /* transpose steps start */ - ////unpacklow//// - mat_b_col[0] = _mm256_unpacklo_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]); - mat_b_col[1] = _mm256_unpacklo_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]); - mat_b_col[2] = _mm256_unpacklo_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]); - mat_b_col[3] = _mm256_unpacklo_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - ////unpackhigh//// - mat_a_blk_elems[0] = _mm256_unpackhi_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]); - mat_a_blk_elems[1] = _mm256_unpackhi_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]); - mat_a_blk_elems[2] = _mm256_unpackhi_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]); - mat_a_blk_elems[3] = _mm256_unpackhi_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_a_blk_elems[4] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x44); - mat_a_blk_elems[5] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xEE); - mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x44); - mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xEE); -#else - mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x4E); - mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x4E); - mat_a_blk_elems[4] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[6], 0xCC); - mat_a_blk_elems[5] = _mm256_blend_ps(mat_a_blk_elems[1], mat_a_blk_elems[6], 0x33); - mat_a_blk_elems[6] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[7], 0xCC); - mat_a_blk_elems[7] = _mm256_blend_ps(mat_a_blk_elems[3], mat_a_blk_elems[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_col[2] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x20); - mat_b_col[6] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x31); - mat_b_col[3] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x20); - mat_b_col[7] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x31); - /* transpose steps end */ - - //Broadcast A8,0 to A15,0 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - //i4 = k >> 3; - ptr_l_dup++; - -#if GEMM_ACCUM_A - //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]); - mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]); - mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]); - mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]); - mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]); - mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]); - mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]); - mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]); -#endif - //Broadcast A21 to A71 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,2 to A15,2 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,3 to A15,3 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,4 to A15,4 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,5 to A15,5 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,6 to A15,6 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,7 to A15,7 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) -#endif - //end loop of cols - //} - //i2 += cs_b_offset[6]; - i4 += 8; - } - //trsm solve - - k = 0; - //for (i2 = 0; i2 < numCols_b; i2 += 8) - //{ - //i2 = i1 + r; - /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A -#if !GEMM_ACCUM_A - //Read 8 cols of B columns of Block-to-be-solved - mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - /* transpose steps start */ - ////unpacklow//// - mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - /* transpose steps end */ - - mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg); - mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg); - mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg); - mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg); - mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg); - mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg); - mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg); - mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg); -#endif - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - //i += cs_l; - -#if GEMM_ACCUM_A - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); -#else - mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]); - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); -#endif - -#if GEMM_ACCUM_A - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]); - mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]); - mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]); - mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]); - mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]); - mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]); - mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A21 to A71 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[0])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[1])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[2])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[3])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[4])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[5])); - //i += cs_l; - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[1])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[2])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[3])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[4])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[5])); - //i += cs_l; - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[2])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[3])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[4])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[5])); - //i += cs_l; - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[3])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[4])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[5])); - //i += cs_l; - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[4])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[5])); - //i += cs_l; - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A76 to register - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 6 + cs_l_offset[5])); - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); - - //////////////////////////////////////////////////////////////////////////////// - - /* transpose steps start */ - ////unpacklow//// - mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - /* transpose steps end */ - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup + i2, mat_b_col[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+i2), mat_b_col[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i2), mat_b_col[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i2), mat_b_col[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i2), mat_b_col[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i2), mat_b_col[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i2), mat_b_col[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i2), mat_b_col[7]); - //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); - k++; - //} - i += cs_b_offset[6]; - i2 += cs_b_offset[6]; - } - } //numRows of A - ///////////////////loop ends ///////////////////// -} - -static void trsm_AutXB_block_allSmallSizedMatrices_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b) -{ - //float ones = 1.0; - int i, i1, i2, i4, j, k, l, r; - int cs_b_offset[7]; - int cs_l_offset[7]; - float *ptr_b_dup, *ptr_l_dup; - - //57 number of ymm(256 bits) registers used - __m256 mat_b_col[8]; - __m256 mat_b_rearr[8]; - __m256 mat_a_blk_elems[8]; - //__m256 mat_a_diag_inv[8]; - //__m256 reciprocal_diags[2]; - - // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // - - //L matrix offsets - cs_l_offset[0] = (cs_l << 1); - cs_l_offset[1] = cs_l + cs_l_offset[0]; - cs_l_offset[2] = (cs_l << 2); - cs_l_offset[3] = cs_l + cs_l_offset[2]; - cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; - cs_l_offset[5] = cs_l + cs_l_offset[4]; - cs_l_offset[6] = (cs_l_offset[5] + cs_l); - - cs_b_offset[0] = (cs_b << 1); - cs_b_offset[1] = cs_b + cs_b_offset[0]; - cs_b_offset[2] = (cs_b << 2); - cs_b_offset[3] = cs_b + cs_b_offset[2]; - cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; - cs_b_offset[5] = cs_b + cs_b_offset[4]; - cs_b_offset[6] = (cs_b_offset[5] + cs_b); - -#if 0 - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - - //Broadcast A21 to A71 to registers - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); - mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); - mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); - mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); - mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); - mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); - mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); - mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); - mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); - mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); - mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); - mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); - mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); - mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); - mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); - mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); - - //Broadcast A76 to register - mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); -#endif - - - /***************** first set of 8 rows of B processing starts *****************/ - ptr_b_dup = ptr_b; - i = 0; - for (j = 0; j < numCols_b; j += 8) - { - /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A - //read 8x8 block of B into registers - mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - /* transpose steps start */ - ////unpacklow//// - mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - /* transpose steps end */ - - - //(Row0) - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5])); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b) - - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[0])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[1])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[2])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[3])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[4])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[5])); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b) - - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[1])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[2])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[3])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[4])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[5])); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b) - - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[2])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[3])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[4])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[5])); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b) - - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[3])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[4])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[5])); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b) - - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[4])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[5])); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b) - - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 6 + cs_l_offset[5])); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b) - - - - //////////////////////////////////////////////////////////////////////////////// - - /* transpose steps start */ - ////unpacklow//// - mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); - mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); - mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); - mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - ////unpackhigh//// - mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); - mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); - mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); - mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - /* transpose steps end */ - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_b_rearr[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_rearr[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_rearr[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_rearr[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_rearr[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_rearr[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_rearr[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_rearr[7]); - - i += cs_b_offset[6]; - ptr_b_dup += cs_b_offset[6]; - //i += 8; - //ptr_b_dup += 8; - } - - //c = 0; - /***************** first set of 8 cols of B processing done *****************/ - ptr_b_dup = ptr_b; - i1 = 0; - //Start loop for cols of B to be processed in size of blk_width - for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row - { - ptr_l += cs_l_offset[6]; - - - //ptr_b += j; - //ptr_b_dup += 8; - ptr_b_dup += 8; - i1 += 8; - i = i1; - i2 = 0; - - for (r = 0; r < numCols_b; r += GEMM_BLK_V1) - { -#if GEMM_ACCUM_A - //Read 8 cols of B columns of Block-to-be-solved - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - /* transpose steps start */ - ////unpacklow//// - mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); - mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); - mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); - mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - ////unpackhigh//// - mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); - mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); - mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); - mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - /* transpose steps end */ -#endif - - //i = 0; - ptr_l_dup = ptr_l; - i4 = i2; - for (l = 0; l < j; l += 8) // move across m - { - //for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) - //{ - /////////////////// Partial Lower 8x8 block trsm of B - //Read current 8 cols of B columns from specified 8x8 current-block of B - mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_b + i4); - mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); - mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); - mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); - mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); - mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); - mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); - mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); - - /* transpose steps start */ - ////unpacklow//// - mat_b_col[0] = _mm256_unpacklo_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]); - mat_b_col[1] = _mm256_unpacklo_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]); - mat_b_col[2] = _mm256_unpacklo_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]); - mat_b_col[3] = _mm256_unpacklo_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - ////unpackhigh//// - mat_a_blk_elems[0] = _mm256_unpackhi_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]); - mat_a_blk_elems[1] = _mm256_unpackhi_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]); - mat_a_blk_elems[2] = _mm256_unpackhi_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]); - mat_a_blk_elems[3] = _mm256_unpackhi_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_a_blk_elems[4] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x44); - mat_a_blk_elems[5] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xEE); - mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x44); - mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xEE); -#else - mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x4E); - mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x4E); - mat_a_blk_elems[4] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[6], 0xCC); - mat_a_blk_elems[5] = _mm256_blend_ps(mat_a_blk_elems[1], mat_a_blk_elems[6], 0x33); - mat_a_blk_elems[6] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[7], 0xCC); - mat_a_blk_elems[7] = _mm256_blend_ps(mat_a_blk_elems[3], mat_a_blk_elems[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_col[2] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x20); - mat_b_col[6] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x31); - mat_b_col[3] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x20); - mat_b_col[7] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x31); - /* transpose steps end */ - - //Broadcast A8,0 to A15,0 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - //i4 = k >> 3; - ptr_l_dup++; - -#if GEMM_ACCUM_A - //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]); - mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]); - mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]); - mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]); - mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]); - mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]); - mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]); - mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]); -#endif - //Broadcast A21 to A71 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,2 to A15,2 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,3 to A15,3 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,4 to A15,4 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,5 to A15,5 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,6 to A15,6 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,7 to A15,7 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) -#endif - //end loop of cols - //} - //i2 += cs_b_offset[6]; - i4 += 8; - } - //trsm solve - - k = 0; - //for (i2 = 0; i2 < numCols_b; i2 += 8) - //{ - //i2 = i1 + r; - /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A -#if !GEMM_ACCUM_A - //Read 8 cols of B columns of Block-to-be-solved - mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - /* transpose steps start */ - ////unpacklow//// - mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - /* transpose steps end */ -#endif - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - //i += cs_l; - -#if GEMM_ACCUM_A - //(Row0): already done - -#else - mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]); -#endif - -#if GEMM_ACCUM_A - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]); - mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]); - mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]); - mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]); - mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]); - mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]); - mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A21 to A71 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[0])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[1])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[2])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[3])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[4])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[5])); - //i += cs_l; - - - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[1])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[2])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[3])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[4])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[5])); - //i += cs_l; - - - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[2])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[3])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[4])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[5])); - //i += cs_l; - - - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[3])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[4])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[5])); - //i += cs_l; - - - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[4])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[5])); - //i += cs_l; - - - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A76 to register - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 6 + cs_l_offset[5])); - - - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) - - - - //////////////////////////////////////////////////////////////////////////////// - - /* transpose steps start */ - ////unpacklow//// - mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - /* transpose steps end */ - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup + i2, mat_b_col[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+i2), mat_b_col[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i2), mat_b_col[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i2), mat_b_col[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i2), mat_b_col[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i2), mat_b_col[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i2), mat_b_col[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i2), mat_b_col[7]); - //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); - k++; - //} - i += cs_b_offset[6]; - i2 += cs_b_offset[6]; - } - } //numRows of A - ///////////////////loop ends ///////////////////// -} - -static void trsm_AutXB_block_allSmallSizedMatrices_alpha_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha) -{ - //float ones = 1.0; - int i, i1, i2, i4, j, k, l, r; - int cs_b_offset[7]; - int cs_l_offset[7]; - float *ptr_b_dup, *ptr_l_dup; - - //57 number of ymm(256 bits) registers used - __m256 mat_b_col[8]; - __m256 mat_b_rearr[8]; - __m256 mat_a_blk_elems[8]; - //__m256 mat_a_diag_inv[8]; - //__m256 reciprocal_diags[2]; - __m256 alphaReg; - alphaReg = _mm256_broadcast_ss((float const *)&alpha); - - // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // - - //L matrix offsets - cs_l_offset[0] = (cs_l << 1); - cs_l_offset[1] = cs_l + cs_l_offset[0]; - cs_l_offset[2] = (cs_l << 2); - cs_l_offset[3] = cs_l + cs_l_offset[2]; - cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; - cs_l_offset[5] = cs_l + cs_l_offset[4]; - cs_l_offset[6] = (cs_l_offset[5] + cs_l); - - cs_b_offset[0] = (cs_b << 1); - cs_b_offset[1] = cs_b + cs_b_offset[0]; - cs_b_offset[2] = (cs_b << 2); - cs_b_offset[3] = cs_b + cs_b_offset[2]; - cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; - cs_b_offset[5] = cs_b + cs_b_offset[4]; - cs_b_offset[6] = (cs_b_offset[5] + cs_b); - -#if 0 - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - - //Broadcast A21 to A71 to registers - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); - mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); - mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); - mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); - mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); - mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); - mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); - mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); - mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); - mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); - mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); - mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); - mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); - mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); - mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); - mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); - - //Broadcast A76 to register - mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); -#endif - - - /***************** first set of 8 rows of B processing starts *****************/ - ptr_b_dup = ptr_b; - i = 0; - for (j = 0; j < numCols_b; j += 8) - { - /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A - //read 8x8 block of B into registers - mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - /* transpose steps start */ - ////unpacklow//// - mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - /* transpose steps end */ - - mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg); - mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg); - mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg); - mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg); - mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg); - mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg); - mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg); - mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg); - - //(Row0) - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5])); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b) - - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[0])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[1])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[2])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[3])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[4])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[5])); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b) - - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[1])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[2])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[3])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[4])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[5])); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b) - - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[2])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[3])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[4])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[5])); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b) - - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[3])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[4])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[5])); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b) - - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[4])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[5])); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b) - - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 6 + cs_l_offset[5])); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b) - - - - //////////////////////////////////////////////////////////////////////////////// - - /* transpose steps start */ - ////unpacklow//// - mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); - mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); - mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); - mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - ////unpackhigh//// - mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); - mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); - mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); - mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - /* transpose steps end */ - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_b_rearr[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_rearr[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_rearr[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_rearr[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_rearr[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_rearr[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_rearr[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_rearr[7]); - - i += cs_b_offset[6]; - ptr_b_dup += cs_b_offset[6]; - //i += 8; - //ptr_b_dup += 8; - } - - //c = 0; - /***************** first set of 8 cols of B processing done *****************/ - ptr_b_dup = ptr_b; - i1 = 0; - //Start loop for cols of B to be processed in size of blk_width - for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row - { - ptr_l += cs_l_offset[6]; - - - //ptr_b += j; - //ptr_b_dup += 8; - ptr_b_dup += 8; - i1 += 8; - i = i1; - i2 = 0; - - for (r = 0; r < numCols_b; r += GEMM_BLK_V1) - { -#if GEMM_ACCUM_A - //Read 8 cols of B columns of Block-to-be-solved - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - /* transpose steps start */ - ////unpacklow//// - mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); - mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); - mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); - mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - ////unpackhigh//// - mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); - mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); - mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); - mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - /* transpose steps end */ - - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg); - mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg); - mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg); - mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg); - mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg); - mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg); - mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg); - mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg); -#endif - - //i = 0; - ptr_l_dup = ptr_l; - i4 = i2; - for (l = 0; l < j; l += 8) // move across m - { - //for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) - //{ - /////////////////// Partial Lower 8x8 block trsm of B - //Read current 8 cols of B columns from specified 8x8 current-block of B - mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_b + i4); - mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); - mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); - mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); - mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); - mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); - mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); - mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); - - /* transpose steps start */ - ////unpacklow//// - mat_b_col[0] = _mm256_unpacklo_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]); - mat_b_col[1] = _mm256_unpacklo_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]); - mat_b_col[2] = _mm256_unpacklo_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]); - mat_b_col[3] = _mm256_unpacklo_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - ////unpackhigh//// - mat_a_blk_elems[0] = _mm256_unpackhi_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]); - mat_a_blk_elems[1] = _mm256_unpackhi_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]); - mat_a_blk_elems[2] = _mm256_unpackhi_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]); - mat_a_blk_elems[3] = _mm256_unpackhi_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_a_blk_elems[4] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x44); - mat_a_blk_elems[5] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xEE); - mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x44); - mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xEE); -#else - mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x4E); - mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x4E); - mat_a_blk_elems[4] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[6], 0xCC); - mat_a_blk_elems[5] = _mm256_blend_ps(mat_a_blk_elems[1], mat_a_blk_elems[6], 0x33); - mat_a_blk_elems[6] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[7], 0xCC); - mat_a_blk_elems[7] = _mm256_blend_ps(mat_a_blk_elems[3], mat_a_blk_elems[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_col[2] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x20); - mat_b_col[6] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x31); - mat_b_col[3] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x20); - mat_b_col[7] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x31); - /* transpose steps end */ - - //Broadcast A8,0 to A15,0 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - //i4 = k >> 3; - ptr_l_dup++; - -#if GEMM_ACCUM_A - //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]); - mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]); - mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]); - mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]); - mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]); - mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]); - mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]); - mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]); -#endif - //Broadcast A21 to A71 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,2 to A15,2 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,3 to A15,3 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,4 to A15,4 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,5 to A15,5 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,6 to A15,6 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,7 to A15,7 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) -#endif - //end loop of cols - //} - //i2 += cs_b_offset[6]; - i4 += 8; - } - //trsm solve - - k = 0; - //for (i2 = 0; i2 < numCols_b; i2 += 8) - //{ - //i2 = i1 + r; - /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A -#if !GEMM_ACCUM_A - //Read 8 cols of B columns of Block-to-be-solved - mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - /* transpose steps start */ - ////unpacklow//// - mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - /* transpose steps end */ - - mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg); - mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg); - mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg); - mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg); - mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg); - mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg); - mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg); - mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg); -#endif - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - //i += cs_l; - -#if GEMM_ACCUM_A - //(Row0): already done - -#else - mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]); -#endif - -#if GEMM_ACCUM_A - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]); - mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]); - mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]); - mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]); - mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]); - mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]); - mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A21 to A71 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[0])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[1])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[2])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[3])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[4])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[5])); - //i += cs_l; - - - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[1])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[2])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[3])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[4])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[5])); - //i += cs_l; - - - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[2])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[3])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[4])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[5])); - //i += cs_l; - - - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[3])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[4])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[5])); - //i += cs_l; - - - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[4])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[5])); - //i += cs_l; - - - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A76 to register - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 6 + cs_l_offset[5])); - - - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) - - - - //////////////////////////////////////////////////////////////////////////////// - - /* transpose steps start */ - ////unpacklow//// - mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - /* transpose steps end */ - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup + i2, mat_b_col[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+i2), mat_b_col[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i2), mat_b_col[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i2), mat_b_col[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i2), mat_b_col[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i2), mat_b_col[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i2), mat_b_col[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i2), mat_b_col[7]); - //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); - k++; - //} - i += cs_b_offset[6]; - i2 += cs_b_offset[6]; - } - } //numRows of A - ///////////////////loop ends ///////////////////// -} -#endif +#endif //BLIS_ENABLE_SMALL_MATRIX_TRSM diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index 9f3888340..e15b81e9c 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -260,3 +260,13 @@ void bli_dgemm_ref_k1_nn double* c, const inc_t ldc ); + err_t bli_trsm_small + ( + side_t side, + obj_t* alpha, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ); +