diff --git a/kernels/zen/3/bli_trsm_small.c b/kernels/zen/3/bli_trsm_small.c index ee4c07a49..b7edac319 100644 --- a/kernels/zen/3/bli_trsm_small.c +++ b/kernels/zen/3/bli_trsm_small.c @@ -46,43 +46,43 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // XA = B; A is lower-traingular; No transpose; double precision; non-unit diagonal static err_t bli_dtrsm_small_XAlB( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ); + side_t side, + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ); //XA = B; A is lower triabgular; No transpose; double precision; unit-diagonal static err_t bli_dtrsm_small_XAlB_unitDiag( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ); + side_t side, + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ); //XA = B; A is lower-triangular; A is transposed; double precision; non-unit-diagonal static err_t bli_dtrsm_small_XAltB( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ); + side_t side, + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ); //XA = B; A is lower-triangular; A is transposed; double precision; unit-diagonal static err_t bli_dtrsm_small_XAltB_unitDiag( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ); + side_t side, + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ); // XA = B; A is upper triangular; No transpose; double presicion; non-unit diagonal static err_t bli_dtrsm_small_XAuB @@ -97,53 +97,53 @@ static err_t bli_dtrsm_small_XAuB //XA = B; A is upper triangular; No transpose; double precision; unit-diagonal static err_t bli_dtrsm_small_XAuB_unitDiag( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ); + side_t side, + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ); //XA = B; A is upper-triangular; A is transposed; double precision; non-unit diagonal static err_t bli_dtrsm_small_XAutB( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ); + side_t side, + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ); //XA = B; A is upper-triangular; A is transposed; double precision; unit diagonal static err_t bli_dtrsm_small_XAutB_unitDiag( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ); + side_t side, + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ); //AX = B; A is lower triangular; No transpose; double precision; non-unit diagonal static err_t bli_dtrsm_small_AlXB( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ); + side_t side, + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ); //AX = B; A is lower triangular; No transpose; double precision; unit diagonal static err_t bli_dtrsm_small_AlXB_unitDiag( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ); + side_t side, + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ); @@ -447,141 +447,141 @@ err_t bli_trsm_small // In the below implementation, based on the number of finally implemented // cases, can move the checks with more cases higher up. - if(side == BLIS_LEFT) - { - if(bli_obj_has_trans(a)) - { - if(dt == BLIS_DOUBLE) - { - if(bli_obj_is_upper(a)) - { - //return bli_dtrsm_small_AutXB(side, alpha, a, b, cntx, cntl); - return BLIS_NOT_YET_IMPLEMENTED; - } - else - { - //return bli_dtrsm_small_AltXB(side, alpha, a, b, cntx, cntl); - return BLIS_NOT_YET_IMPLEMENTED; - } - } - else - { - if(bli_obj_is_upper(a)) - { - return bli_strsm_small_AutXB(side, alpha, a, b, cntx, cntl); - } - else - { - //return bli_strsm_small_AltXB(side, alpha, a, b, cntx, cntl); - return BLIS_NOT_YET_IMPLEMENTED; - } + if(side == BLIS_LEFT) + { + if(bli_obj_has_trans(a)) + { + if(dt == BLIS_DOUBLE) + { + if(bli_obj_is_upper(a)) + { + //return bli_dtrsm_small_AutXB(side, alpha, a, b, cntx, cntl); + return BLIS_NOT_YET_IMPLEMENTED; + } + else + { + //return bli_dtrsm_small_AltXB(side, alpha, a, b, cntx, cntl); + return BLIS_NOT_YET_IMPLEMENTED; + } + } + else + { + if(bli_obj_is_upper(a)) + { + return bli_strsm_small_AutXB(side, alpha, a, b, cntx, cntl); + } + else + { + //return bli_strsm_small_AltXB(side, alpha, a, b, cntx, cntl); + return BLIS_NOT_YET_IMPLEMENTED; + } - } - } - else - { - if(dt == BLIS_DOUBLE) - { - if(bli_obj_is_upper(a)) - { - //return bli_dtrsm_small_AuXB(side, alpha, a, b, cntx, cntl); - return BLIS_NOT_YET_IMPLEMENTED; - } - else - { + } + } + else + { + if(dt == BLIS_DOUBLE) + { + if(bli_obj_is_upper(a)) + { + //return bli_dtrsm_small_AuXB(side, alpha, a, b, cntx, cntl); + return BLIS_NOT_YET_IMPLEMENTED; + } + else + { if(bli_obj_has_unit_diag(a)) return bli_dtrsm_small_AlXB_unitDiag(side, alpha, a, b, cntx, cntl); else - return bli_dtrsm_small_AlXB(side, alpha, a, b, cntx, cntl); - } - } - else - { - if(bli_obj_is_upper(a)) - { - //return bli_strsm_small_AuXB(side, alpha, a, b, cntx, cntl); - return BLIS_NOT_YET_IMPLEMENTED; - } - else - { - return bli_strsm_small_AlXB(side, alpha, a, b, cntx, cntl); - } + return bli_dtrsm_small_AlXB(side, alpha, a, b, cntx, cntl); + } + } + else + { + if(bli_obj_is_upper(a)) + { + //return bli_strsm_small_AuXB(side, alpha, a, b, cntx, cntl); + return BLIS_NOT_YET_IMPLEMENTED; + } + else + { + return bli_strsm_small_AlXB(side, alpha, a, b, cntx, cntl); + } - } + } - } - } - else - { - if(bli_obj_has_trans(a)) - { - if(dt == BLIS_DOUBLE) - { - if(bli_obj_is_upper(a)) - { + } + } + else + { + if(bli_obj_has_trans(a)) + { + if(dt == BLIS_DOUBLE) + { + if(bli_obj_is_upper(a)) + { if(bli_obj_has_unit_diag(a)) - return bli_dtrsm_small_XAutB_unitDiag(side, alpha, a, b, cntx, cntl); + return bli_dtrsm_small_XAutB_unitDiag(side, alpha, a, b, cntx, cntl); else - return bli_dtrsm_small_XAutB(side, alpha, a, b, cntx, cntl); - } - else - { + return bli_dtrsm_small_XAutB(side, alpha, a, b, cntx, cntl); + } + else + { if(bli_obj_has_unit_diag(a)) return bli_dtrsm_small_XAltB_unitDiag(side, alpha, a, b, cntx, cntl); else - return bli_dtrsm_small_XAltB(side, alpha, a, b, cntx, cntl); - } - } - else - { - if(bli_obj_is_upper(a)) - { - //return bli_strsm_small_XAutB(side, alpha, a, b, cntx, cntl); - return BLIS_NOT_YET_IMPLEMENTED; - } - else - { - return bli_strsm_small_XAltB(side, alpha, a, b, cntx, cntl); - } + return bli_dtrsm_small_XAltB(side, alpha, a, b, cntx, cntl); + } + } + else + { + if(bli_obj_is_upper(a)) + { + //return bli_strsm_small_XAutB(side, alpha, a, b, cntx, cntl); + return BLIS_NOT_YET_IMPLEMENTED; + } + else + { + return bli_strsm_small_XAltB(side, alpha, a, b, cntx, cntl); + } - } - } - else - { - if(dt == BLIS_DOUBLE) - { - if(bli_obj_is_upper(a)) - { - if(bli_obj_has_unit_diag(a)) - return bli_dtrsm_small_XAuB_unitDiag(side, alpha, a, b, cntx, cntl); - else - return bli_dtrsm_small_XAuB(side, alpha, a, b, cntx, cntl); - } - else - { + } + } + else + { + if(dt == BLIS_DOUBLE) + { + if(bli_obj_is_upper(a)) + { + if(bli_obj_has_unit_diag(a)) + return bli_dtrsm_small_XAuB_unitDiag(side, alpha, a, b, cntx, cntl); + else + return bli_dtrsm_small_XAuB(side, alpha, a, b, cntx, cntl); + } + else + { if(bli_obj_has_unit_diag(a)) return bli_dtrsm_small_XAlB_unitDiag(side, alpha, a, b, cntx, cntl); else - return bli_dtrsm_small_XAlB(side, alpha, a, b, cntx, cntl); - } - } - else - { - if(bli_obj_is_upper(a)) - { - //return bli_strsm_small_XAuB(side, alpha, a, b, cntx, cntl); - return BLIS_NOT_YET_IMPLEMENTED; - } - else - { - //return bli_strsm_small_XAlB(side, alpha, a, b, cntx, cntl); - return BLIS_NOT_YET_IMPLEMENTED; - } + return bli_dtrsm_small_XAlB(side, alpha, a, b, cntx, cntl); + } + } + else + { + if(bli_obj_is_upper(a)) + { + //return bli_strsm_small_XAuB(side, alpha, a, b, cntx, cntl); + return BLIS_NOT_YET_IMPLEMENTED; + } + else + { + //return bli_strsm_small_XAlB(side, alpha, a, b, cntx, cntl); + return BLIS_NOT_YET_IMPLEMENTED; + } - } + } - } - } + } + } return BLIS_NOT_YET_IMPLEMENTED; }; @@ -652,27 +652,27 @@ static err_t dtrsm_small_AlXB_unitDiag ( * Dimensions: X:mxn A:nxn B:mxn */ static err_t dtrsm_small_XAuB ( - double *A, - double *B, - dim_t M, - dim_t N, - dim_t lda, - dim_t ldb + double *A, + double *B, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb ) { dim_t i, j, k; for(k = 0; k < N; k++) { - double lkk_inv = 1.0/A[k+k*lda]; - for(i = 0; i < M; i++) - { - B[i+k*ldb] *= lkk_inv; - for(j = k+1; j < N; j++) - { - B[i+j*ldb] -= B[i+k*ldb] * A[k+j*lda]; - } - } + double lkk_inv = 1.0/A[k+k*lda]; + for(i = 0; i < M; i++) + { + B[i+k*ldb] *= lkk_inv; + for(j = k+1; j < N; j++) + { + B[i+j*ldb] -= B[i+k*ldb] * A[k+j*lda]; + } + } } return BLIS_SUCCESS; @@ -684,13 +684,13 @@ return BLIS_SUCCESS; */ static err_t dtrsm_small_XAlB ( - double *A, - double *B, - double alpha, + double *A, + double *B, + double alpha, dim_t M, - dim_t N, - dim_t lda, - dim_t ldb + dim_t N, + dim_t lda, + dim_t ldb ) { @@ -701,15 +701,15 @@ static err_t dtrsm_small_XAlB ( for(k = N-1; k+1 > 0; k--) { - double lkk_inv = 1.0/A[k+k*lda]; - for(i = M-1; i+1 > 0; i--) - { - B[i+k*ldb] *= lkk_inv; - for(j = k-1; j+1 > 0; j--) - { - B[i+j*ldb] -= B[i+k*ldb] * A[k+j*lda]; - } - } + double lkk_inv = 1.0/A[k+k*lda]; + for(i = M-1; i+1 > 0; i--) + { + B[i+k*ldb] *= lkk_inv; + for(j = k-1; j+1 > 0; j--) + { + B[i+j*ldb] -= B[i+k*ldb] * A[k+j*lda]; + } + } } return BLIS_SUCCESS; } @@ -719,13 +719,13 @@ return BLIS_SUCCESS; *Dimensions: X:mxn A:nxn B:mxn */ static err_t dtrsm_small_XAlB_unitDiag( - double *A, - double *B, + double *A, + double *B, double alpha, - dim_t M, - dim_t N, - dim_t lda, - dim_t ldb + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb ) { @@ -737,13 +737,13 @@ static err_t dtrsm_small_XAlB_unitDiag( for(k = N-1; k+1 > 0; k--) { - for(i = M-1; i+1 > 0; i--) - { - for(j = k-1; j+1 > 0; j--) - { - B[i+j*ldb] -= B[i+k*ldb] * A[k+j*lda]; - } - } + for(i = M-1; i+1 > 0; i--) + { + for(j = k-1; j+1 > 0; j--) + { + B[i+j*ldb] -= B[i+k*ldb] * A[k+j*lda]; + } + } } return BLIS_SUCCESS; } @@ -753,13 +753,13 @@ return BLIS_SUCCESS; * Dimensions: X:mxn A:nxn B:mxn */ static err_t dtrsm_small_XAutB ( - double *A, - double *B, - double alpha, + double *A, + double *B, + double alpha, dim_t M, - dim_t N, - dim_t lda, - dim_t ldb + dim_t N, + dim_t lda, + dim_t ldb ) { @@ -772,14 +772,14 @@ static err_t dtrsm_small_XAutB ( for(k = N-1; k+1 > 0; k--) { double lkk_inv = 1.0/A[k+k*lda]; - for(i = M-1; i+1 > 0; i--) - { - B[i+k*ldb] *= lkk_inv; - for(j = k-1; j+1 > 0; j--) - { - B[i+j*ldb] -= B[i+k*ldb] * A[j+k*lda]; - } - } + for(i = M-1; i+1 > 0; i--) + { + B[i+k*ldb] *= lkk_inv; + for(j = k-1; j+1 > 0; j--) + { + B[i+j*ldb] -= B[i+k*ldb] * A[j+k*lda]; + } + } } return BLIS_SUCCESS; } @@ -789,13 +789,13 @@ return BLIS_SUCCESS; * Dimensions: X:mxn A:nxn B:mxn */ static err_t dtrsm_small_XAutB_unitDiag( - double *A, - double *B, + double *A, + double *B, double alpha, - dim_t M, - dim_t N, - dim_t lda, - dim_t ldb + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb ) { @@ -807,14 +807,14 @@ static err_t dtrsm_small_XAutB_unitDiag( for(i = M-1; i+1 > 0; i--) { - for(j = N-1; j+1 > 0; j--) - { - for(k = j-1; k+1 > 0; k--) - { - B[i+k*ldb] -= B[i+j*ldb] * A[k+j*lda]; + for(j = N-1; j+1 > 0; j--) + { + for(k = j-1; k+1 > 0; k--) + { + B[i+k*ldb] -= B[i+j*ldb] * A[k+j*lda]; - } - } + } + } } return BLIS_SUCCESS; } @@ -824,12 +824,12 @@ return BLIS_SUCCESS; * Dimensions: X:mxn A:nxn B:mxn */ static err_t dtrsm_small_XAltB ( - double *A, - double *B, - dim_t M, - dim_t N, - dim_t lda, - dim_t ldb + double *A, + double *B, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb ) { @@ -837,15 +837,15 @@ static err_t dtrsm_small_XAltB ( for(k = 0; k < N; k++) { - double lkk_inv = 1.0/A[k+k*lda]; - for(i = 0; i < M; i++) - { - B[i+k*ldb] *= lkk_inv; - for(j = k+1; j < N; j++) - { - B[i+j*ldb] -= B[i+k*ldb] * A[j+k*lda]; - } - } + double lkk_inv = 1.0/A[k+k*lda]; + for(i = 0; i < M; i++) + { + B[i+k*ldb] *= lkk_inv; + for(j = k+1; j < N; j++) + { + B[i+j*ldb] -= B[i+k*ldb] * A[j+k*lda]; + } + } } return BLIS_SUCCESS; } @@ -855,12 +855,12 @@ return BLIS_SUCCESS; * Dimensions: X:mxn A:nxn B:mxn */ static err_t dtrsm_small_XAltB_unitDiag( - double *A, - double *B, - dim_t M, - dim_t N, - dim_t lda, - dim_t ldb + double *A, + double *B, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb ) { @@ -869,12 +869,12 @@ static err_t dtrsm_small_XAltB_unitDiag( for(k = 0; k < N; k++) { for(i = 0; i < M; i++) - { - for(j = k+1; j < N; j++) - { - B[i+j*ldb] -= B[i+k*ldb] * A[j+k*lda]; - } - } + { + for(j = k+1; j < N; j++) + { + B[i+j*ldb] -= B[i+k*ldb] * A[j+k*lda]; + } + } } return BLIS_SUCCESS; } @@ -884,12 +884,12 @@ return BLIS_SUCCESS; * Dimensions: X:mxn A:nxn B:mxn */ static err_t dtrsm_small_XAuB_unitDiag ( - double *A, - double *B, - dim_t M, - dim_t N, - dim_t lda, - dim_t ldb + double *A, + double *B, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb ) { @@ -897,13 +897,13 @@ static err_t dtrsm_small_XAuB_unitDiag ( for(k = 0; k < N; k++) { - for(i = 0; i < M; i++) - { - for(j = k+1; j < N; j++) - { - B[i+j*ldb] -= B[i+k*ldb] * A[k+j*lda]; - } - } + for(i = 0; i < M; i++) + { + for(j = k+1; j < N; j++) + { + B[i+j*ldb] -= B[i+k*ldb] * A[k+j*lda]; + } + } } return BLIS_SUCCESS; } @@ -932,13 +932,13 @@ a10 ****** b11 ***************** a11---> */ static err_t bli_dtrsm_small_AlXB( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ) + side_t side, + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ) { dim_t D_MR = 4; //size of block along 'M' dimpension @@ -946,7 +946,7 @@ static err_t bli_dtrsm_small_AlXB( dim_t m = bli_obj_length(b); // number of rows of matrix B dim_t n = bli_obj_width(b); // number of columns of matrix B - +/* #ifdef BLIS_ENABLE_SMALL_MATRIX_ROME if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME) { @@ -958,7 +958,7 @@ static err_t bli_dtrsm_small_AlXB( return BLIS_NOT_YET_IMPLEMENTED; } #endif - +*/ dim_t m_remainder = m % D_MR; //number of remainder rows dim_t n_remainder = n % D_NR; //number of remainder columns @@ -990,10 +990,10 @@ static err_t bli_dtrsm_small_AlXB( { for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' dimension { - a10 = L +i; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + a10 = L +i; //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM k_iter = i / D_MR; //number of times GEMM to be performed(in blocks of 4x4) @@ -1012,7 +1012,7 @@ static err_t bli_dtrsm_small_AlXB( { ptr_b01_dup = b01; - ymm16 = _mm256_loadu_pd((double const *)(a10));//A10[0][0] A10[1][0] A10[2][0] A10[3][0] + ymm16 = _mm256_loadu_pd((double const *)(a10));//A10[0][0] A10[1][0] A10[2][0] A10[3][0] ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] @@ -1024,19 +1024,19 @@ static err_t bli_dtrsm_small_AlXB( ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[0][6] ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[0][7] - b01 += 1; //mobe to next row of B + b01 += 1; //mobe to next row of B - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) + ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) + ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) + ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) + ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[0][4]*A10[0][0] B01[0][4]*A10[1][0] B01[0][4]*A10[2][0] B01[0][4]*A10[3][0]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[0][5]*A10[0][0] B01[0][5]*A10[1][0] B01[0][5]*A10[2][0] B01[0][5]*A10[3][0]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[0][6]*A10[0][0] B01[0][6]*A10[1][0] B01[0][6]*A10[2][0] B01[0][6]*A10[3][0]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[0][7]*A10[0][0] B01[0][7]*A10[1][0] B01[0][7]*A10[2][0] B01[0][7]*A10[3][0]) + ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[0][4]*A10[0][0] B01[0][4]*A10[1][0] B01[0][4]*A10[2][0] B01[0][4]*A10[3][0]) + ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[0][5]*A10[0][0] B01[0][5]*A10[1][0] B01[0][5]*A10[2][0] B01[0][5]*A10[3][0]) + ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[0][6]*A10[0][0] B01[0][6]*A10[1][0] B01[0][6]*A10[2][0] B01[0][6]*A10[3][0]) + ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[0][7]*A10[0][0] B01[0][7]*A10[1][0] B01[0][7]*A10[2][0] B01[0][7]*A10[3][0]) - ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a));//A10[0][1] A10[1][1] A10[2][1] A10[3][1] + ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a));//A10[0][1] A10[1][1] A10[2][1] A10[3][1] ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] @@ -1048,19 +1048,19 @@ static err_t bli_dtrsm_small_AlXB( ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[1][6] ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[1][7] - b01 += 1; //mobe to next row of B + b01 += 1; //mobe to next row of B - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) + ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) + ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) + ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) + ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[1][4]*A10[0][1] B01[1][4]*A10[1][1] B01[1][4]*A10[2][1] B01[1][4]*A10[3][1]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[1][5]*A10[0][1] B01[1][5]*A10[1][1] B01[1][5]*A10[2][1] B01[1][5]*A10[3][1]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[1][6]*A10[0][1] B01[1][6]*A10[1][1] B01[1][6]*A10[2][1] B01[1][6]*A10[3][1]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[1][7]*A10[0][1] B01[1][7]*A10[1][1] B01[1][7]*A10[2][1] B01[1][7]*A10[3][1]) + ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[1][4]*A10[0][1] B01[1][4]*A10[1][1] B01[1][4]*A10[2][1] B01[1][4]*A10[3][1]) + ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[1][5]*A10[0][1] B01[1][5]*A10[1][1] B01[1][5]*A10[2][1] B01[1][5]*A10[3][1]) + ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[1][6]*A10[0][1] B01[1][6]*A10[1][1] B01[1][6]*A10[2][1] B01[1][6]*A10[3][1]) + ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[1][7]*A10[0][1] B01[1][7]*A10[1][1] B01[1][7]*A10[2][1] B01[1][7]*A10[3][1]) - ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2));//A10[0][2] A10[1][2] A10[2][2] A10[3][2] + ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2));//A10[0][2] A10[1][2] A10[2][2] A10[3][2] ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] @@ -1072,19 +1072,19 @@ static err_t bli_dtrsm_small_AlXB( ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[2][6] ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[2][7] - b01 += 1; //mobe to next row of B + b01 += 1; //mobe to next row of B - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) + ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) + ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) + ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) + ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[2][4]*A10[0][2] B01[2][4]*A10[1][2] B01[2][4]*A10[2][2] B01[2][4]*A10[3][2]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[2][5]*A10[0][2] B01[2][5]*A10[1][2] B01[2][5]*A10[2][2] B01[2][5]*A10[3][2]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[2][6]*A10[0][2] B01[2][6]*A10[1][2] B01[2][6]*A10[2][2] B01[2][6]*A10[3][2]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[2][7]*A10[0][2] B01[2][7]*A10[1][2] B01[2][7]*A10[2][2] B01[2][7]*A10[3][2]) + ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[2][4]*A10[0][2] B01[2][4]*A10[1][2] B01[2][4]*A10[2][2] B01[2][4]*A10[3][2]) + ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[2][5]*A10[0][2] B01[2][5]*A10[1][2] B01[2][5]*A10[2][2] B01[2][5]*A10[3][2]) + ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[2][6]*A10[0][2] B01[2][6]*A10[1][2] B01[2][6]*A10[2][2] B01[2][6]*A10[3][2]) + ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[2][7]*A10[0][2] B01[2][7]*A10[1][2] B01[2][7]*A10[2][2] B01[2][7]*A10[3][2]) - ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3));//A10[0][3] A10[1][3] A10[2][3] A10[3][3] + ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3));//A10[0][3] A10[1][3] A10[2][3] A10[3][3] ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] @@ -1096,20 +1096,20 @@ static err_t bli_dtrsm_small_AlXB( ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[3][6] ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[3][7] - b01 += 1; //mobe to next row of B + b01 += 1; //mobe to next row of B - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[3][0]*A10[0][3] B01[3][0]*A10[3][0] B01[3][0]*A10[2][3] B01[3][0]*A10[3][0]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[3][1]*A10[0][3] B01[3][1]*A10[3][0] B01[3][1]*A10[2][3] B01[3][1]*A10[3][0]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[3][2]*A10[0][3] B01[3][2]*A10[3][0] B01[3][2]*A10[2][3] B01[3][2]*A10[3][0]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[3][3]*A10[0][3] B01[3][3]*A10[3][0] B01[3][3]*A10[2][3] B01[3][3]*A10[3][0]) + ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[3][0]*A10[0][3] B01[3][0]*A10[3][0] B01[3][0]*A10[2][3] B01[3][0]*A10[3][0]) + ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[3][1]*A10[0][3] B01[3][1]*A10[3][0] B01[3][1]*A10[2][3] B01[3][1]*A10[3][0]) + ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[3][2]*A10[0][3] B01[3][2]*A10[3][0] B01[3][2]*A10[2][3] B01[3][2]*A10[3][0]) + ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[3][3]*A10[0][3] B01[3][3]*A10[3][0] B01[3][3]*A10[2][3] B01[3][3]*A10[3][0]) - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[3][4]*A10[0][3] B01[3][4]*A10[3][0] B01[3][4]*A10[2][3] B01[3][4]*A10[3][3]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[3][5]*A10[0][3] B01[3][5]*A10[3][0] B01[3][5]*A10[2][3] B01[3][5]*A10[3][3]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[3][6]*A10[0][3] B01[3][6]*A10[3][0] B01[3][6]*A10[2][3] B01[3][6]*A10[3][3]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[3][7]*A10[0][3] B01[3][7]*A10[3][0] B01[3][7]*A10[2][3] B01[3][7]*A10[3][3]) + ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[3][4]*A10[0][3] B01[3][4]*A10[3][0] B01[3][4]*A10[2][3] B01[3][4]*A10[3][3]) + ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[3][5]*A10[0][3] B01[3][5]*A10[3][0] B01[3][5]*A10[2][3] B01[3][5]*A10[3][3]) + ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[3][6]*A10[0][3] B01[3][6]*A10[3][0] B01[3][6]*A10[2][3] B01[3][6]*A10[3][3]) + ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[3][7]*A10[0][3] B01[3][7]*A10[3][0] B01[3][7]*A10[2][3] B01[3][7]*A10[3][3]) - a10 += D_MR * cs_a; //pointer math to calculate next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to calculate next block of B for GEMM + a10 += D_MR * cs_a; //pointer math to calculate next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to calculate next block of B for GEMM } ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha @@ -1246,50 +1246,56 @@ static err_t bli_dtrsm_small_AlXB( //unpacklow// ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] - ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] + ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] - ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm4); //store B11[4][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm5); //store B11[5][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 6), ymm6); //store B11[6][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 7), ymm7); //store B11[7][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm4); //store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm5); //store B11[5][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 6), ymm6); //store B11[6][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 7), ymm7); //store B11[7][0-3] } - if(m_remainder) //implementation for reamainder rows(when 'M' is not a multiple of D_MR) + if(m_remainder) //implementation for reamainder rows(when 'M' is not a multiple of D_MR) { - a10 = L +i; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + a10 = L +i; //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - k_iter = i / D_MR; //number of times GEMM operation to be done(in blocks of 4x4) + k_iter = i / D_MR; //number of times GEMM operation to be done(in blocks of 4x4) + + double f_temp[4]; + int iter; + + for(iter = 0; iter < m_remainder; iter++) + f_temp[iter] = (b11 + cs_b * 7)[iter]; ymm8 = _mm256_setzero_pd(); ymm9 = _mm256_setzero_pd(); @@ -1301,284 +1307,262 @@ static err_t bli_dtrsm_small_AlXB( ymm15 = _mm256_setzero_pd(); ///GEMM code Begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations { ptr_b01_dup = b01; - ymm16 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] + ymm16 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] + ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] + ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] + ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] + ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[0][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[0][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[0][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[0][7] + ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[0][4] + ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[0][5] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[0][6] + ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[0][7] - b01 += 1; //move to next row of B + b01 += 1; //move to next row of B - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0] ) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) + ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0] ) + ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) + ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) + ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[0][4]*A10[0][0] B01[0][4]*A10[1][0] B01[0][4]*A10[2][0] B01[0][4]*A10[3][0]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[0][5]*A10[0][0] B01[0][5]*A10[1][0] B01[0][5]*A10[2][0] B01[0][5]*A10[3][0]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[0][6]*A10[0][0] B01[0][6]*A10[1][0] B01[0][6]*A10[2][0] B01[0][6]*A10[3][0]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm16 += (B01[0][7]*A10[0][0] B01[0][7]*A10[1][0] B01[0][7]*A10[2][0] B01[0][7]*A10[3][0]) + ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[0][4]*A10[0][0] B01[0][4]*A10[1][0] B01[0][4]*A10[2][0] B01[0][4]*A10[3][0]) + ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[0][5]*A10[0][0] B01[0][5]*A10[1][0] B01[0][5]*A10[2][0] B01[0][5]*A10[3][0]) + ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[0][6]*A10[0][0] B01[0][6]*A10[1][0] B01[0][6]*A10[2][0] B01[0][6]*A10[3][0]) + ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm16 += (B01[0][7]*A10[0][0] B01[0][7]*A10[1][0] B01[0][7]*A10[2][0] B01[0][7]*A10[3][0]) - ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 1)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] + ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 1)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] + ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] + ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] + ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] + ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[1][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[1][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[1][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[1][7] + ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[1][4] + ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[1][5] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[1][6] + ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[1][7] - b01 += 1; //move to next row of B01 + b01 += 1; //move to next row of B01 - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) + ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) + ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) + ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) + ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[1][4]*A10[0][1] B01[1][4]*A10[1][1] B01[1][4]*A10[2][1] B01[1][4]*A10[3][1]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[1][5]*A10[0][1] B01[1][5]*A10[1][1] B01[1][5]*A10[2][1] B01[1][5]*A10[3][1]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[1][6]*A10[0][1] B01[1][6]*A10[1][1] B01[1][6]*A10[2][1] B01[1][6]*A10[3][1]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[1][7]*A10[0][1] B01[1][7]*A10[1][1] B01[1][7]*A10[2][1] B01[1][7]*A10[3][1]) + ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[1][4]*A10[0][1] B01[1][4]*A10[1][1] B01[1][4]*A10[2][1] B01[1][4]*A10[3][1]) + ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[1][5]*A10[0][1] B01[1][5]*A10[1][1] B01[1][5]*A10[2][1] B01[1][5]*A10[3][1]) + ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[1][6]*A10[0][1] B01[1][6]*A10[1][1] B01[1][6]*A10[2][1] B01[1][6]*A10[3][1]) + ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[1][7]*A10[0][1] B01[1][7]*A10[1][1] B01[1][7]*A10[2][1] B01[1][7]*A10[3][1]) - ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] //A10[1][2] A10[2][2] A10[3][2] + ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] //A10[1][2] A10[2][2] A10[3][2] - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] + ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] + ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] + ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] + ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[2][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[2][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[2][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[2][7] + ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[2][4] + ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[2][5] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[2][6] + ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[2][7] - b01 += 1; //move to next row of B + b01 += 1; //move to next row of B - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) + ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) + ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) + ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) + ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[2][4]*A10[0][2] B01[2][4]*A10[1][2] B01[2][4]*A10[2][2] B01[2][0]*A10[3][2]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[2][5]*A10[0][2] B01[2][5]*A10[1][2] B01[2][5]*A10[2][2] B01[2][1]*A10[3][2]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[2][6]*A10[0][2] B01[2][6]*A10[1][2] B01[2][6]*A10[2][2] B01[2][2]*A10[3][2]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[2][7]*A10[0][2] B01[2][7]*A10[1][2] B01[2][7]*A10[2][2] B01[2][3]*A10[3][2]) + ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[2][4]*A10[0][2] B01[2][4]*A10[1][2] B01[2][4]*A10[2][2] B01[2][0]*A10[3][2]) + ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[2][5]*A10[0][2] B01[2][5]*A10[1][2] B01[2][5]*A10[2][2] B01[2][1]*A10[3][2]) + ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[2][6]*A10[0][2] B01[2][6]*A10[1][2] B01[2][6]*A10[2][2] B01[2][2]*A10[3][2]) + ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[2][7]*A10[0][2] B01[2][7]*A10[1][2] B01[2][7]*A10[2][2] B01[2][3]*A10[3][2]) - ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] + ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] + ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] + ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] + ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[3][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[3][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[3][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[3][7] + ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[3][4] + ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[3][5] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[3][6] + ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[3][7] - b01 += 1; //move to next row of B + b01 += 1; //move to next row of B - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm8 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm8 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm8 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) + ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) + ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm8 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) + ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm8 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) + ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm8 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm8 += (B01[3][0]*A10[0][3] B01[3][4]*A10[1][3] B01[3][4]*A10[2][3] B01[3][4]*A10[3][3]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm8 += (B01[3][1]*A10[0][3] B01[3][5]*A10[1][3] B01[3][5]*A10[2][3] B01[3][5]*A10[3][3]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm8 += (B01[3][2]*A10[0][3] B01[3][6]*A10[1][3] B01[3][6]*A10[2][3] B01[3][6]*A10[3][3]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm8 += (B01[3][3]*A10[0][3] B01[3][7]*A10[1][3] B01[3][7]*A10[2][3] B01[3][7]*A10[3][3]) + ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm8 += (B01[3][0]*A10[0][3] B01[3][4]*A10[1][3] B01[3][4]*A10[2][3] B01[3][4]*A10[3][3]) + ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm8 += (B01[3][1]*A10[0][3] B01[3][5]*A10[1][3] B01[3][5]*A10[2][3] B01[3][5]*A10[3][3]) + ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm8 += (B01[3][2]*A10[0][3] B01[3][6]*A10[1][3] B01[3][6]*A10[2][3] B01[3][6]*A10[3][3]) + ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm8 += (B01[3][3]*A10[0][3] B01[3][7]*A10[1][3] B01[3][7]*A10[2][3] B01[3][7]*A10[3][3]) - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM } + ///GEMM code ends/// - ///GEMM code ends/// + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *6)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] + ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] - ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *6)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] - ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b *7)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] *alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] *alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] *alpha -= B01[0-3][2] + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] *alpha -= B01[0-3][3] + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] *alpha -= B01[0-3][4] + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] *alpha -= B01[0-3][5] + ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] *alpha -= B01[0-3][6] + ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); //B11[0-3][7] *alpha -= B01[0-3][7] - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] *alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] *alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] *alpha -= B01[0-3][2] - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] *alpha -= B01[0-3][3] - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] *alpha -= B01[0-3][4] - ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] *alpha -= B01[0-3][5] - ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] *alpha -= B01[0-3][6] - ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); //B11[0-3][7] *alpha -= B01[0-3][7] - - ///implement TRSM/// - - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[1][4] B11[1][5] - ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[1][6] B11[1][7] - - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] - ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[5][0] B11[5][1] B11[7][0] B11[7][1] - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[5][2] B11[5][3] B11[7][2] B11[7][3] - - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] - ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] - - - ymm0 = _mm256_broadcast_sd((double const *)&ones); - - //broadcast diagonal elements of A11 - ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_b +1)); //A11[1][1] - ymm3 = _mm256_broadcast_sd((double const *)(a11+cs_b*2 + 2)); //A11[2][2] - ymm4 = _mm256_broadcast_sd((double const *)(a11+cs_b*3 + 3)); //A11[3][3] - - ymm5 = _mm256_unpacklo_pd(ymm1, ymm2); //A11[0][0] A11[0][0] A11[1][1] A11[1][1] - ymm6 = _mm256_unpacklo_pd(ymm3, ymm4); //A11[2][2] A11[2][2] A11[3][3] A11[3][3] - - ymm5 = _mm256_blend_pd(ymm5, ymm6, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] - ymm0 = _mm256_div_pd(ymm0, ymm5); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] - - //extract a00 - ymm1 = _mm256_permute_pd(ymm0, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] - ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0] - - //(Row 0): perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - ymm8 = _mm256_mul_pd(ymm8, ymm1); //B11[0-3][0] /= A11[0][0] - ymm12 = _mm256_mul_pd(ymm12, ymm1); //B11[0-3][4] /= A11[0][0] - - //extract a11 - ymm1 = _mm256_permute_pd(ymm0, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[3][3] 1/A11[3][3] - ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x00); //1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1] - - ymm2 = _mm256_broadcast_sd((double const *)(a11 +1)); //A11[1][0] - ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][0] - ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][0] - - a11 += cs_a; - - //(Row1): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); //B11[1][0-3] -= B11[0-3][0]*A11[1][0] - ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[2][0-3] -= B11[0-3][0]*A11[2][0] - ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); //B11[3][0-3] -= B11[0-3][0]*A11[3][0] - - ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); //B11[5][0-3] -= B11[0-3][4]*A11[1][4] - ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[6][0-3] -= B11[0-3][4]*A11[2][4] - ymm15 = _mm256_fnmadd_pd(ymm4, ymm12, ymm15); //B11[7][0-3] -= B11[0-3][4]*A11[3][4] - - ymm9 = _mm256_mul_pd(ymm9, ymm1); //B11[0-3][1] /= A11[1][1] - ymm13 = _mm256_mul_pd(ymm13, ymm1); //B11[0-3][5] /= A11[1][1] - - ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][1] - ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][1] - - a11 += cs_a; - - //extract a22 - ymm1 = _mm256_permute_pd(ymm0, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] - ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x11); //1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2] - - //(ROw2): FMA operations - ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); //B11[2][0-3] -= A11[2][1] * B11[0-3][1] - ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); //B11[3][0-3] -= A11[3][1] * B11[0-3][1] - - ymm14 = _mm256_fnmadd_pd(ymm3, ymm13, ymm14); //B11[6][0-3] -= A11[2][1] * B11[0-3][5] - ymm15 = _mm256_fnmadd_pd(ymm4, ymm13, ymm15); //B11[7][0-3] -= A11[3][1] * B11[0-3][5] - - //perform mul operation - ymm10 = _mm256_mul_pd(ymm10, ymm1); //B11[0-3][2] /=A11[2][2] - ymm14 = _mm256_mul_pd(ymm14, ymm1); //B11[0-3][6] /= A11[2][2] - - ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][2] - - a11 += cs_a; - - //extract a33 - ymm1 = _mm256_permute_pd(ymm0, 0x0C); //1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] - ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x11); //1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3] - - //(ROw2): FMA operations - ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); //B11[0-3][3] -= A11[3][2]*B11[0-3][2] - - ymm15 = _mm256_fnmadd_pd(ymm4, ymm14, ymm15); //B11[0-3][7] -= A11[3][2]*B11[0-3][6] - - //perform mul operation - ymm11 = _mm256_mul_pd(ymm11, ymm1); //B11[0-3][3] /= A11[3][3] - ymm15 = _mm256_mul_pd(ymm15, ymm1); //B11[0-3][7] /= A11[3][3] - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] - ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] - ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] - - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[0][5] B11[1][5] B11[0][7] B11[1][7] - ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[2][5] B11[3][5] B11[2][7] B11[3][7] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] - ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] - - ymm8 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); //load B11[0-3][0] - ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); //load B11[0-3][1] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //load B11[0-3][2] - ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //load B11[0-3][3] - ymm12 = _mm256_loadu_pd((double const *)(b11 + cs_b * 4)); //load B11[0-3][4] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b * 5)); //load B11[0-3][5] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b * 6)); //load B11[0-3][6] - ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b * 7)); //load B11[0-3][7] - //determine correct values to store if(m_remainder == 3) { + ///implement TRSM/// + + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[1][4] B11[1][5] + ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[1][6] B11[1][7] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[5][0] B11[5][1] B11[7][0] B11[7][1] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[5][2] B11[5][3] B11[7][2] B11[7][3] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] + ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] + + ymm0 = _mm256_broadcast_sd((double const *)&ones); + + //broadcast diagonal elements of A11 + ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_b +1)); //A11[1][1] + ymm3 = _mm256_broadcast_sd((double const *)(a11+cs_b*2 + 2)); //A11[2][2] + + ymm5 = _mm256_unpacklo_pd(ymm1, ymm2); //A11[0][0] A11[0][0] A11[1][1] A11[1][1] + ymm6 = _mm256_unpacklo_pd(ymm3, ymm0); //A11[2][2] A11[2][2] A11[3][3] A11[3][3] + + ymm5 = _mm256_blend_pd(ymm5, ymm6, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm0 = _mm256_div_pd(ymm0, ymm5); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + + //extract a00 + ymm1 = _mm256_permute_pd(ymm0, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] + ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0] + + //(Row 0): perform mul operation of reciprocal of L(0,0) element with 1st row elements of B + ymm8 = _mm256_mul_pd(ymm8, ymm1); //B11[0-3][0] /= A11[0][0] + ymm12 = _mm256_mul_pd(ymm12, ymm1); //B11[0-3][4] /= A11[0][0] + + //extract a11 + ymm1 = _mm256_permute_pd(ymm0, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[3][3] 1/A11[3][3] + ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x00); //1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1] + + ymm2 = _mm256_broadcast_sd((double const *)(a11 +1)); //A11[1][0] + ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][0] + + a11 += cs_a; + + //(Row1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); //B11[1][0-3] -= B11[0-3][0]*A11[1][0] + ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[2][0-3] -= B11[0-3][0]*A11[2][0] + + ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); //B11[5][0-3] -= B11[0-3][4]*A11[1][4] + ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[6][0-3] -= B11[0-3][4]*A11[2][4] + + ymm9 = _mm256_mul_pd(ymm9, ymm1); //B11[0-3][1] /= A11[1][1] + ymm13 = _mm256_mul_pd(ymm13, ymm1); //B11[0-3][5] /= A11[1][1] + + ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][1] + + a11 += cs_a; + + //extract a22 + ymm1 = _mm256_permute_pd(ymm0, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] + ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x11); //1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2] + + //(ROw2): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); //B11[2][0-3] -= A11[2][1] * B11[0-3][1] + + ymm14 = _mm256_fnmadd_pd(ymm3, ymm13, ymm14); //B11[6][0-3] -= A11[2][1] * B11[0-3][5] + + //perform mul operation + ymm10 = _mm256_mul_pd(ymm10, ymm1); //B11[0-3][2] /=A11[2][2] + ymm14 = _mm256_mul_pd(ymm14, ymm1); //B11[0-3][6] /= A11[2][2] + + ymm11 = _mm256_broadcast_sd((double const *)(&ones)); + ymm15 = _mm256_broadcast_sd((double const *)(&ones)); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] + ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[0][5] B11[1][5] B11[0][7] B11[1][7] + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[2][5] B11[3][5] B11[2][7] B11[3][7] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] + + ymm8 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); //load B11[0-3][0] + ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); //load B11[0-3][1] + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //load B11[0-3][2] + ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //load B11[0-3][3] + ymm12 = _mm256_loadu_pd((double const *)(b11 + cs_b * 4)); //load B11[0-3][4] + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b * 5)); //load B11[0-3][5] + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b * 6)); //load B11[0-3][6] + ymm15 = _mm256_loadu_pd((double const *)(f_temp)); //load B11[0-3][7] + + //determine correct values to store ymm0 = _mm256_blend_pd(ymm0, ymm8, 0x08); ymm1 = _mm256_blend_pd(ymm1, ymm9, 0x08); ymm2 = _mm256_blend_pd(ymm2, ymm10, 0x08); @@ -1587,9 +1571,111 @@ static err_t bli_dtrsm_small_AlXB( ymm5 = _mm256_blend_pd(ymm5, ymm13, 0x08); ymm6 = _mm256_blend_pd(ymm6, ymm14, 0x08); ymm7 = _mm256_blend_pd(ymm7, ymm15, 0x08); + } if(m_remainder == 2) { + ///implement TRSM/// + + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[1][4] B11[1][5] + ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[1][6] B11[1][7] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[5][0] B11[5][1] B11[7][0] B11[7][1] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[5][2] B11[5][3] B11[7][2] B11[7][3] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] + ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] + + ymm0 = _mm256_broadcast_sd((double const *)&ones); + + //broadcast diagonal elements of A11 + ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_b +1)); //A11[1][1] + + ymm5 = _mm256_unpacklo_pd(ymm1, ymm2); //A11[0][0] A11[0][0] A11[1][1] A11[1][1] + + ymm5 = _mm256_blend_pd(ymm5, ymm0, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm0 = _mm256_div_pd(ymm0, ymm5); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + + //extract a00 + ymm1 = _mm256_permute_pd(ymm0, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] + ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0] + + //(Row 0): perform mul operation of reciprocal of L(0,0) element with 1st row elements of B + ymm8 = _mm256_mul_pd(ymm8, ymm1); //B11[0-3][0] /= A11[0][0] + ymm12 = _mm256_mul_pd(ymm12, ymm1); //B11[0-3][4] /= A11[0][0] + + //extract a11 + ymm1 = _mm256_permute_pd(ymm0, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[3][3] 1/A11[3][3] + ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x00); //1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1] + + ymm2 = _mm256_broadcast_sd((double const *)(a11 +1)); //A11[1][0] + + a11 += cs_a; + + //(Row1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); //B11[1][0-3] -= B11[0-3][0]*A11[1][0] + + ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); //B11[5][0-3] -= B11[0-3][4]*A11[1][4] + + ymm9 = _mm256_mul_pd(ymm9, ymm1); //B11[0-3][1] /= A11[1][1] + ymm13 = _mm256_mul_pd(ymm13, ymm1); //B11[0-3][5] /= A11[1][1] + + ymm10 = _mm256_broadcast_sd((double const *)&ones); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm10, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm10, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_permute2f128_pd(ymm5, ymm10, 0x20); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm6 = _mm256_permute2f128_pd(ymm5, ymm10, 0x31); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[0][5] B11[1][5] B11[0][7] B11[1][7] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm10, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm10, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm5 = _mm256_permute2f128_pd(ymm12, ymm10, 0x20); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + ymm7 = _mm256_permute2f128_pd(ymm12, ymm10, 0x31); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] + + ymm8 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); //load B11[0-3][0] + ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); //load B11[0-3][1] + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //load B11[0-3][2] + ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //load B11[0-3][3] + ymm12 = _mm256_loadu_pd((double const *)(b11 + cs_b * 4)); //load B11[0-3][4] + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b * 5)); //load B11[0-3][5] + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b * 6)); //load B11[0-3][6] + ymm15 = _mm256_loadu_pd((double const *)(f_temp)); //load B11[0-3][7] + + //determine correct values to store ymm0 = _mm256_permute2f128_pd(ymm0, ymm8, 0x30); ymm1 = _mm256_permute2f128_pd(ymm1, ymm9, 0x30); ymm2 = _mm256_permute2f128_pd(ymm2, ymm10, 0x30); @@ -1598,9 +1684,91 @@ static err_t bli_dtrsm_small_AlXB( ymm5 = _mm256_permute2f128_pd(ymm5, ymm13, 0x30); ymm6 = _mm256_permute2f128_pd(ymm6, ymm14, 0x30); ymm7 = _mm256_permute2f128_pd(ymm7, ymm15, 0x30); + } if(m_remainder == 1) { + ///implement TRSM/// + + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[1][4] B11[1][5] + ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[1][6] B11[1][7] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[5][0] B11[5][1] B11[7][0] B11[7][1] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[5][2] B11[5][3] B11[7][2] B11[7][3] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] + ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] + + ymm0 = _mm256_broadcast_sd((double const *)&ones); + + //broadcast diagonal elements of A11 + ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + + ymm0 = _mm256_div_pd(ymm0, ymm1); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + + //extract a00 + ymm1 = _mm256_permute_pd(ymm0, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] + ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0] + + //(Row 0): perform mul operation of reciprocal of L(0,0) element with 1st row elements of B + ymm8 = _mm256_mul_pd(ymm8, ymm1); //B11[0-3][0] /= A11[0][0] + ymm12 = _mm256_mul_pd(ymm12, ymm1); //B11[0-3][4] /= A11[0][0] + + ymm9 = _mm256_broadcast_sd((double const *)(&ones)); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + + ymm5 = _mm256_unpacklo_pd(ymm12, ymm9); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm9, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm9, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_permute2f128_pd(ymm5, ymm9, 0x20); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm6 = _mm256_permute2f128_pd(ymm5, ymm9, 0x31); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + + ymm12 = _mm256_unpackhi_pd(ymm12, ymm9); //B11[0][5] B11[1][5] B11[0][7] B11[1][7] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm5 = _mm256_permute2f128_pd(ymm12, ymm9, 0x20); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + ymm7 = _mm256_permute2f128_pd(ymm12, ymm9, 0x31); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] + + ymm8 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); //load B11[0-3][0] + ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); //load B11[0-3][1] + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //load B11[0-3][2] + ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //load B11[0-3][3] + ymm12 = _mm256_loadu_pd((double const *)(b11 + cs_b * 4)); //load B11[0-3][4] + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b * 5)); //load B11[0-3][5] + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b * 6)); //load B11[0-3][6] + ymm15 = _mm256_loadu_pd((double const *)(f_temp)); //load B11[0-3][7] + + //determine correct values to store ymm0 = _mm256_blend_pd(ymm0, ymm8, 0x0E); ymm1 = _mm256_blend_pd(ymm1, ymm9, 0x0E); ymm2 = _mm256_blend_pd(ymm2, ymm10, 0x0E); @@ -1611,35 +1779,37 @@ static err_t bli_dtrsm_small_AlXB( ymm7 = _mm256_blend_pd(ymm7, ymm15, 0x0E); } - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm4); //store(B11[0-3][4]) - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm5); //store(B11[0-3][5]) - _mm256_storeu_pd((double *)(b11 + cs_b * 6), ymm6); //store(B11[0-3][6]) - _mm256_storeu_pd((double *)(b11 + cs_b * 7), ymm7); //store(B11[0-3][7]) + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm4); //store(B11[0-3][4]) + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm5); //store(B11[0-3][5]) + _mm256_storeu_pd((double *)(b11 + cs_b * 6), ymm6); //store(B11[0-3][6]) + _mm256_storeu_pd((double *)(f_temp), ymm7); //store(B11[0-3][7]) + for(iter = 0; iter < m_remainder; iter++) + (b11 + cs_b * 7)[iter] = f_temp[iter]; } } - if((n & 4)) //implementation for remainder columns(when 'N' is a multiple of 4) + if((n & 4)) //implementation for remainder columns(when 'n_remainder' is greater than 4) { - for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' direction + for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' direction { - a10 = L +i; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + a10 = L +i; //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - k_iter = i / D_MR; //number of times GEMM to be performed(in block of 4) + k_iter = i / D_MR; //number of times GEMM to be performed(in block of 4) ///GEMM for previously calculated values /// //load 4x4 block from b11 - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] ymm4 = _mm256_setzero_pd(); @@ -1647,281 +1817,74 @@ static err_t bli_dtrsm_small_AlXB( ymm6 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha - - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations { - ptr_b01_dup = b01; - ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a*2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + ptr_b01_dup = b01; + ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a*2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] + ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] - b01 += 1; //move to next row of B + b01 += 1; //move to next row of B - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] + ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] - b01 += 1; + b01 += 1; - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] + ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] - b01 += 1; + b01 += 1; - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[1][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[2][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[3][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[1][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[2][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[3][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] + ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] - b01 += 1; + b01 += 1; - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM } - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); //B11[0-3][0] *alpha -= ymm4 - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); //B01[0-3][1] *alpha -= ymm5 - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); //B01[0-3][2] *alpha -= ymm6 - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); //B01[0-3][3] *alpha -= ymm7 + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha - ///implement TRSM/// - //1st col - ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] - ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0] - ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][0] - - //2nd col - a11 += cs_a; - ymm8 = _mm256_broadcast_sd((double const *)(a11 + 1)); //A11[1][1] - ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][1] - - //3rd col - a11 += cs_a; - ymm11 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][2] - ymm12 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][2] - - //4th col - a11 += cs_a; - ymm13 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][3] - //compute reciprocals of L(i,i) and broadcast in registers - ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[0][0] A11[2][2] A11[2][2] - ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[1][1] A11[1][1] A11[3][3] A11[3][3] - - ymm14 = _mm256_broadcast_sd((double const *)&ones); - - ymm4 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] - ymm14 = _mm256_div_pd(ymm14, ymm4); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] - - ////unpacklow//// - ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - //rearrange low elements - ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] -/* - mat_b_rearr[0] = _mm256_mul_pd(mat_b_rearr[0], alphaReg); - mat_b_rearr[2] = _mm256_mul_pd(mat_b_rearr[2], alphaReg); -*/ - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - //rearrange high elements - ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] -/* - mat_b_rearr[1] = _mm256_mul_pd(mat_b_rearr[1], alphaReg); - mat_b_rearr[3] = _mm256_mul_pd(mat_b_rearr[3], alphaReg); -*/ - //extract a00 - ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0] - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - ymm4 = _mm256_mul_pd(ymm4, ymm15); //B11[0][0-3] /= A11[0][0] - - //extract diag a11 from a - ymm15 = _mm256_permute_pd(ymm14, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[3][3] 1/A11[3][3] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1] - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) - ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]*B11[0][0-3] - ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0]*B11[0][0-3] - ymm13 = _mm256_fnmadd_pd(ymm7, ymm4, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][0]*B11[0][0-3] - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - ymm8 = _mm256_mul_pd(ymm8, ymm15); //B11[1][0-3] /= A11[1][1] - - - //extract diag a22 from a - ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2] - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1]*B11[1][0-3] - ymm13 = _mm256_fnmadd_pd(ymm10, ymm8, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][1]*B11[1][0-3] - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - ymm11 = _mm256_mul_pd(ymm11, ymm15); //B11[2][0-3] /= A11[2][2] - - //extract diag a33 from a - ymm15 = _mm256_permute_pd(ymm14, 0x0C); //1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3] - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - ymm13 = _mm256_fnmadd_pd(ymm12, ymm11, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][2]*B11[2][0-3] - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - ymm13 = _mm256_mul_pd(ymm13, ymm15); //B11[3][0-3] /= A11[3][3] - - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ////unpackhigh//// - ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - - ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[0-3][3]) - - } - if(m_remainder) //implementation for remainder rows(when 'M' is not a multiple of D_MR) - { - a10 = L +i; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha - - k_iter = i / D_MR; //number of GEMM operations to be performed(in blocks of 4x4) - - ///GEMM for previously calculated values /// - - //load 4x4 block from b11 - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - for(k = 0; k < k_iter; k++) //looop for number of GEMM operations - { - ptr_b01_dup = b01; - - ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) - - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM - - } - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); //B11[0-3][0] *alpha -= ymm4 - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); //B11[0-3][1] *alpha -= ymm5 - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); //B11[0-3][2] *alpha -= ymm6 - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); //B11[0-3][3] *alpha -= ymm7 + ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 + ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B01[0-3][1] *alpha -= ymm5 + ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B01[0-3][2] *alpha -= ymm6 + ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B01[0-3][3] *alpha -= ymm7 ///implement TRSM/// //1st col @@ -1932,116 +1895,305 @@ static err_t bli_dtrsm_small_AlXB( //2nd col a11 += cs_a; - ymm8 = _mm256_broadcast_sd((double const *)(a11 + 1)); //A11[1][1] - ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][1] + ymm8 = _mm256_broadcast_sd((double const *)(a11 + 1)); //A11[1][1] + ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1] + ymm10 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][1] //3rd col a11 += cs_a; - ymm11 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][2] - ymm12 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][2] + ymm11 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][2] + ymm12 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][2] //4th col a11 += cs_a; - ymm13 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][3] + ymm13 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][3] //compute reciprocals of L(i,i) and broadcast in registers - ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[0][0] A11[1][1] A11[1][1] - ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[2][2] A11[3][3] A11[3][3] + ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[0][0] A11[2][2] A11[2][2] + ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[1][1] A11[1][1] A11[3][3] A11[3][3] ymm14 = _mm256_broadcast_sd((double const *)&ones); - ymm4 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] - ymm14 = _mm256_div_pd(ymm14, ymm4); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + ymm4 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm14 = _mm256_div_pd(ymm14, ymm4); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] ////unpacklow//// - ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] //rearrange low elements - ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] -/* - mat_b_rearr[0] = _mm256_mul_pd(mat_b_rearr[0], alphaReg); - mat_b_rearr[2] = _mm256_mul_pd(mat_b_rearr[2], alphaReg); -*/ + ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] //rearrange high elements - ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] -/* - mat_b_rearr[1] = _mm256_mul_pd(mat_b_rearr[1], alphaReg); - mat_b_rearr[3] = _mm256_mul_pd(mat_b_rearr[3], alphaReg); -*/ + ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + //extract a00 - ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00);//1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0] + ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0] //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - ymm4 = _mm256_mul_pd(ymm4, ymm15); //B11[0][0-3] /= A11[0][0] + ymm4 = _mm256_mul_pd(ymm4, ymm15); //B11[0][0-3] /= A11[0][0] //extract diag a11 from a - ymm15 = _mm256_permute_pd(ymm14, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[3][3] 1/A11[3][3] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[][] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1] + ymm15 = _mm256_permute_pd(ymm14, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[3][3] 1/A11[3][3] + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1] //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) - ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]* B11[0][0-3] - ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0]* B11[0][0-3] - ymm13 = _mm256_fnmadd_pd(ymm7, ymm4, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][0]* B11[0][0-3] + ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]*B11[0][0-3] + ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0]*B11[0][0-3] + ymm13 = _mm256_fnmadd_pd(ymm7, ymm4, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][0]*B11[0][0-3] //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - ymm8 = _mm256_mul_pd(ymm8, ymm15); //B11[1][0-3] /= A11[1][1] + ymm8 = _mm256_mul_pd(ymm8, ymm15); //B11[1][0-3] /= A11[1][1] //extract diag a22 from a - ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2] + ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2] //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1]* B11[1][0-3] - ymm13 = _mm256_fnmadd_pd(ymm10, ymm8, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][1]* B11[1][0-3] + ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1]*B11[1][0-3] + ymm13 = _mm256_fnmadd_pd(ymm10, ymm8, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][1]*B11[1][0-3] //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - ymm11 = _mm256_mul_pd(ymm11, ymm15); //B11[2][0-3] /= A11[2][2] + ymm11 = _mm256_mul_pd(ymm11, ymm15); //B11[2][0-3] /= A11[2][2] //extract diag a33 from a - ymm15 = _mm256_permute_pd(ymm14, 0x0C); //1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3] + ymm15 = _mm256_permute_pd(ymm14, 0x0C); //1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3] //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - ymm13 = _mm256_fnmadd_pd(ymm12, ymm11, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][2]* B11[2][0-3] + ymm13 = _mm256_fnmadd_pd(ymm12, ymm11, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][2]*B11[2][0-3] //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B ymm13 = _mm256_mul_pd(ymm13, ymm15); //B11[3][0-3] /= A11[3][3] //--> Transpose and store results of columns of B block <--// ////unpacklow//// - ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ////unpackhigh//// - ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] //rearrange high elements ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - //load 4x4 block from b11 - ymm4 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][3] B11[1][3] B11[2][2] B11[3][3] + _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[0-3][3]) + + } + if(m_remainder) //implementation for remainder rows(when 'M' is not a multiple of D_MR) + { + a10 = L +i; //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + + k_iter = i / D_MR; //number of GEMM operations to be performed(in blocks of 4x4) + + double f_temp[4]; + int iter; + + for(iter = 0; iter < m_remainder; iter++) + f_temp[iter] = (b11 + cs_b * 3)[iter]; + + ///GEMM for previously calculated values /// + + //load 4x4 block from b11 + ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + + for(k = 0; k < k_iter; k++) //looop for number of GEMM operations + { + ptr_b01_dup = b01; + + ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] + ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] + ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] + ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] + ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) + + a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + + } + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha + + ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 + ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] *alpha -= ymm5 + ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[0-3][2] *alpha -= ymm6 + ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[0-3][3] *alpha -= ymm7 - //determine correct values to store if(m_remainder == 3) { + ///implement TRSM/// + //1st col + ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] + ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0] + + //2nd col + a11 += cs_a; + ymm8 = _mm256_broadcast_sd((double const *)(a11 + 1)); //A11[1][1] + ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1] + + //3rd col + a11 += cs_a; + ymm11 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][2] + + //4th col + a11 += cs_a; + ymm13 = _mm256_broadcast_sd((double const *)(&ones)); //A11[3][3] + + //compute reciprocals of L(i,i) and broadcast in registers + ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[0][0] A11[1][1] A11[1][1] + ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[2][2] A11[3][3] A11[3][3] + + ymm14 = _mm256_broadcast_sd((double const *)&ones); + + ymm4 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm14 = _mm256_div_pd(ymm14, ymm4); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + + ////unpacklow//// + ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + //extract a00 + ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00);//1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0] + + //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B + ymm4 = _mm256_mul_pd(ymm4, ymm15); //B11[0][0-3] /= A11[0][0] + + //extract diag a11 from a + ymm15 = _mm256_permute_pd(ymm14, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[3][3] 1/A11[3][3] + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[][] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1] + + //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) + ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]* B11[0][0-3] + ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0]* B11[0][0-3] + + //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B + ymm8 = _mm256_mul_pd(ymm8, ymm15); //B11[1][0-3] /= A11[1][1] + + //extract diag a22 from a + ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2] + + //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1]* B11[1][0-3] + + //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B + ymm11 = _mm256_mul_pd(ymm11, ymm15); //B11[2][0-3] /= A11[2][2] + + ymm13 = _mm256_broadcast_sd((double const *)(&ones)); + + //--> Transpose and store results of columns of B block <--// + ////unpacklow//// + ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ////unpackhigh//// + ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + + ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + //load 4x4 block from b11 + ymm4 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][2] B11[3][3] + + //determine correct values to store ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x08); ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x08); ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x08); @@ -2049,13 +2201,151 @@ static err_t bli_dtrsm_small_AlXB( } if(m_remainder == 2) { + ///implement TRSM/// + //1st col + ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] + + //2nd col + a11 += cs_a; + ymm8 = _mm256_broadcast_sd((double const *)(a11 + 1)); //A11[1][1] + + //compute reciprocals of L(i,i) and broadcast in registers + ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[0][0] A11[1][1] A11[1][1] + + ymm14 = _mm256_broadcast_sd((double const *)&ones); + + ymm4 = _mm256_blend_pd(ymm4, ymm14, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm14 = _mm256_div_pd(ymm14, ymm4); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + + ////unpacklow//// + ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + //extract a00 + ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00);//1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0] + + //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B + ymm4 = _mm256_mul_pd(ymm4, ymm15); //B11[0][0-3] /= A11[0][0] + + //extract diag a11 from a + ymm15 = _mm256_permute_pd(ymm14, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[3][3] 1/A11[3][3] + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[][] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1] + + //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) + ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]* B11[0][0-3] + + //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B + ymm8 = _mm256_mul_pd(ymm8, ymm15); //B11[1][0-3] /= A11[1][1] + + ymm11 = _mm256_broadcast_sd((double const *)(&ones)); + ymm13 = _mm256_broadcast_sd((double const *)(&ones)); + + //--> Transpose and store results of columns of B block <--// + ////unpacklow//// + ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ////unpackhigh//// + ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + + ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + //load 4x4 block from b11 + ymm4 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][2] B11[3][3] + + //determine correct values to store ymm0 = _mm256_permute2f128_pd(ymm0, ymm4,0x30); ymm1 = _mm256_permute2f128_pd(ymm1, ymm5,0x30); ymm2 = _mm256_permute2f128_pd(ymm2, ymm6,0x30); ymm3 = _mm256_permute2f128_pd(ymm3, ymm7,0x30); + } if(m_remainder == 1) { + ///implement TRSM/// + //1st col + ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + + ymm14 = _mm256_broadcast_sd((double const *)&ones); + ymm14 = _mm256_div_pd(ymm14, ymm4); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + + ////unpacklow//// + ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + //extract a00 + ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00);//1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0] + + //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B + ymm4 = _mm256_mul_pd(ymm4, ymm15); //B11[0][0-3] /= A11[0][0] + + ymm8 = _mm256_broadcast_sd((double const *)(&ones)); + ymm11 = _mm256_broadcast_sd((double const *)(&ones)); + ymm13 = _mm256_broadcast_sd((double const *)(&ones)); + + //--> Transpose and store results of columns of B block <--// + ////unpacklow//// + ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ////unpackhigh//// + ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + + ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + //load 4x4 block from b11 + ymm4 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][2] B11[3][3] + + //determine correct values to store ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x0E); ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x0E); ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x0E); @@ -2065,7 +2355,10 @@ static err_t bli_dtrsm_small_AlXB( _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + _mm256_storeu_pd((double *)(f_temp), ymm3); //store(B11[0-3][3]) + + for(iter = 0; iter < m_remainder; iter++) + (b11 + cs_b * 3)[iter] = f_temp[iter]; } @@ -2078,14 +2371,17 @@ static err_t bli_dtrsm_small_AlXB( { for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' direction { - a10 = L +i; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + a10 = L +i; //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM k_iter = i / D_MR; //number of GEMM operations to be performed(in blocks of 4x4) - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); ///GEMM for previously calculated values /// @@ -2095,95 +2391,173 @@ static err_t bli_dtrsm_small_AlXB( ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_broadcast_sd((double const *)&ones); + + for(k = 0; k < k_iter; k++) + { + ptr_b01_dup = b01; + ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) + + a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + } + + ///GEMM code ends/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value + + ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 + ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] *alpha -= ymm5 + ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[0-3][2] *alpha -= ymm6 + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][3] *alpha -= ymm7 } if(n_remainder == 2) { ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_broadcast_sd((double const *)&ones); - ymm3 = _mm256_broadcast_sd((double const *)&ones); + + for(k = 0; k < k_iter; k++) + { + ptr_b01_dup = b01; + ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) + + a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + } + ///GEMM code ends/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value + + ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 + ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] *alpha -= ymm5 + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][2] *alpha -= ymm6 + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][3] *alpha -= ymm7 } if(n_remainder == 1) { ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_broadcast_sd((double const *)&ones); - ymm2 = _mm256_broadcast_sd((double const *)&ones); - ymm3 = _mm256_broadcast_sd((double const*)&ones); + + for(k = 0; k < k_iter; k++) + { + ptr_b01_dup = b01; + + ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) + + a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + } + ///GEMM code ends/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value + + ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 + ymm1 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][1] *alpha -= ymm5 + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][2] *alpha -= ymm6 + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][3] *alpha -= ymm7 + } - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - for(k = 0; k < k_iter; k++) - { - ptr_b01_dup = b01; - ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) - - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM - } - - ///GEMM code ends/// - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); //B11[0-3][0] *alpha -= ymm4 - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); //B11[0-3][1] *alpha -= ymm5 - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); //B11[0-3][2] *alpha -= ymm6 - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); //B11[0-3][3] *alpha -= ymm7 - ///implement TRSM/// //1st col ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] @@ -2221,10 +2595,7 @@ static err_t bli_dtrsm_small_AlXB( //rearrange low elements ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] -/* - mat_b_rearr[0] = _mm256_mul_pd(mat_b_rearr[0], alphaReg); - mat_b_rearr[2] = _mm256_mul_pd(mat_b_rearr[2], alphaReg); -*/ + ////unpackhigh//// ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] @@ -2232,10 +2603,7 @@ static err_t bli_dtrsm_small_AlXB( //rearrange high elements ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] -/* - mat_b_rearr[1] = _mm256_mul_pd(mat_b_rearr[1], alphaReg); - mat_b_rearr[3] = _mm256_mul_pd(mat_b_rearr[3], alphaReg); -*/ + //extract a00 ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0] @@ -2314,7 +2682,7 @@ static err_t bli_dtrsm_small_AlXB( } if(m_remainder) //implementation for remainder rows(when 'M' is not a multiple of D_MR) { - a10 = L +i; //pointer to block of A to be used for GEMM + a10 = L +i; //pointer to block of A to be used for GEMM a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM b01 = B + j*cs_b; //pointer to block of B to be used for GEMM b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM @@ -2322,7 +2690,16 @@ static err_t bli_dtrsm_small_AlXB( k_iter = i / D_MR; //number of times GEMM operations to be performed - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha value + double f_temp[4]; + int iter; + + for(iter = 0; iter < m_remainder; iter++) + f_temp[iter] = (b11 + cs_b * (n_remainder -1))[iter]; + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); ///GEMM for previously calculated values /// @@ -2332,136 +2709,227 @@ static err_t bli_dtrsm_small_AlXB( { ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_broadcast_sd((double const *)&ones); + ymm2 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[0][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[0][2] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[1][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[1][2] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[2][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[2][2] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[3][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[3][2] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) + + a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + } + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha value + + ymm8 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] * alpha -= ymm4 + ymm9 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] * alpha -= ymm5 + ymm10 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[0-3][2] * alpha -= ymm6 + + ///implement TRSM/// + //determine correct values to store + if(m_remainder == 3) + { + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); + } + if(m_remainder == 2) + { + ymm0 = _mm256_permute2f128_pd(ymm8, ymm0, 0x30); + ymm1 = _mm256_permute2f128_pd(ymm9, ymm1, 0x30); + ymm2 = _mm256_permute2f128_pd(ymm10, ymm2, 0x30); + } + if(m_remainder == 1) + { + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); + } + _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(f_temp), ymm2); //store(B11[0-3][2]) } if(n_remainder == 2) { ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_broadcast_sd((double const *)&ones); - ymm3 = _mm256_broadcast_sd((double const *)&ones); + ymm1 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + + ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[0][1] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[1][1] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[2][1] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[3][1] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) + + a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + } + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha value + + ymm8 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] * alpha -= ymm4 + ymm9 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] * alpha -= ymm5 + + ///implement TRSM/// + //determine correct values to store + if(m_remainder == 3) + { + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + } + if(m_remainder == 2) + { + ymm0 = _mm256_permute2f128_pd(ymm8, ymm0, 0x30); + ymm1 = _mm256_permute2f128_pd(ymm9, ymm1, 0x30); + } + if(m_remainder == 1) + { + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + } + _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(f_temp), ymm1); //store(B11[0-3][1]) } if(n_remainder == 1) { - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_broadcast_sd((double const *)&ones); - ymm2 = _mm256_broadcast_sd((double const *)&ones); - ymm3 = _mm256_broadcast_sd((double const *)&ones); + ymm0 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + + ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) + + a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + + } + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha value + + ymm8 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] * alpha -= ymm4 + + ///implement TRSM/// + //determine correct values to store + if(m_remainder == 3) + { + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + } + if(m_remainder == 2) + { + ymm0 = _mm256_permute2f128_pd(ymm8, ymm0, 0x30); + } + if(m_remainder == 1) + { + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + } + _mm256_storeu_pd((double *)(f_temp), ymm0); //store(B11[0-3][0]) } - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - ptr_b01_dup = b01; - ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[0][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B10[0][3] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[1][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B10[1][3] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[2][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B10[2][3] - - b01 += 1; //move to next row of B - - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[3][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B10[3][3] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) - - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM - - } - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); //B11[0-3][0] * alpha -= ymm4 - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); //B11[0-3][1] * alpha -= ymm5 - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); //B11[0-3][2] * alpha -= ymm6 - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); //B11[0-3][3] * alpha -= ymm7 - - ///implement TRSM/// - //determine correct values to store - if(m_remainder == 3) - { - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); - - } - if(m_remainder == 2) - { - ymm0 = _mm256_permute2f128_pd(ymm8, ymm0, 0x30); - ymm1 = _mm256_permute2f128_pd(ymm9, ymm1, 0x30); - ymm2 = _mm256_permute2f128_pd(ymm10, ymm2, 0x30); - ymm3 = _mm256_permute2f128_pd(ymm11, ymm3, 0x30); - - } - if(m_remainder == 1) - { - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); - } - if(n_remainder == 3) - { - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - } - if(n_remainder == 2) - { - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) - } - if(n_remainder == 1) - { - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - } + for(iter = 0; iter < m_remainder; iter++) + (b11 + cs_b * (n_remainder-1))[iter] = f_temp[iter]; ///scalar code for trsm without alpha/// dtrsm_small_AlXB(a11, b11, m_remainder, n_remainder, cs_a, cs_b); @@ -2495,13 +2963,13 @@ a10 ****** b11 ***************** */ static err_t bli_dtrsm_small_AlXB_unitDiag( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ) + side_t side, + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ) { dim_t D_MR = 4; //size of block along 'M' dimpension @@ -2553,10 +3021,10 @@ static err_t bli_dtrsm_small_AlXB_unitDiag( { for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' dimension { - a10 = L +i; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + a10 = L +i; //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM k_iter = i / D_MR; //number of times GEMM to be performed(in blocks of 4x4) @@ -2575,7 +3043,7 @@ static err_t bli_dtrsm_small_AlXB_unitDiag( { ptr_b01_dup = b01; - ymm16 = _mm256_loadu_pd((double const *)(a10));//A10[0][0] A10[1][0] A10[2][0] A10[3][0] + ymm16 = _mm256_loadu_pd((double const *)(a10));//A10[0][0] A10[1][0] A10[2][0] A10[3][0] ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] @@ -2587,19 +3055,19 @@ static err_t bli_dtrsm_small_AlXB_unitDiag( ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[0][6] ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[0][7] - b01 += 1; //mobe to next row of B + b01 += 1; //mobe to next row of B - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) + ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) + ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) + ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) + ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[0][4]*A10[0][0] B01[0][4]*A10[1][0] B01[0][4]*A10[2][0] B01[0][4]*A10[3][0]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[0][5]*A10[0][0] B01[0][5]*A10[1][0] B01[0][5]*A10[2][0] B01[0][5]*A10[3][0]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[0][6]*A10[0][0] B01[0][6]*A10[1][0] B01[0][6]*A10[2][0] B01[0][6]*A10[3][0]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[0][7]*A10[0][0] B01[0][7]*A10[1][0] B01[0][7]*A10[2][0] B01[0][7]*A10[3][0]) + ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[0][4]*A10[0][0] B01[0][4]*A10[1][0] B01[0][4]*A10[2][0] B01[0][4]*A10[3][0]) + ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[0][5]*A10[0][0] B01[0][5]*A10[1][0] B01[0][5]*A10[2][0] B01[0][5]*A10[3][0]) + ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[0][6]*A10[0][0] B01[0][6]*A10[1][0] B01[0][6]*A10[2][0] B01[0][6]*A10[3][0]) + ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[0][7]*A10[0][0] B01[0][7]*A10[1][0] B01[0][7]*A10[2][0] B01[0][7]*A10[3][0]) - ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a));//A10[0][1] A10[1][1] A10[2][1] A10[3][1] + ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a));//A10[0][1] A10[1][1] A10[2][1] A10[3][1] ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] @@ -2611,19 +3079,19 @@ static err_t bli_dtrsm_small_AlXB_unitDiag( ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[1][6] ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[1][7] - b01 += 1; //mobe to next row of B + b01 += 1; //mobe to next row of B - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) + ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) + ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) + ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) + ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[1][4]*A10[0][1] B01[1][4]*A10[1][1] B01[1][4]*A10[2][1] B01[1][4]*A10[3][1]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[1][5]*A10[0][1] B01[1][5]*A10[1][1] B01[1][5]*A10[2][1] B01[1][5]*A10[3][1]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[1][6]*A10[0][1] B01[1][6]*A10[1][1] B01[1][6]*A10[2][1] B01[1][6]*A10[3][1]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[1][7]*A10[0][1] B01[1][7]*A10[1][1] B01[1][7]*A10[2][1] B01[1][7]*A10[3][1]) + ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[1][4]*A10[0][1] B01[1][4]*A10[1][1] B01[1][4]*A10[2][1] B01[1][4]*A10[3][1]) + ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[1][5]*A10[0][1] B01[1][5]*A10[1][1] B01[1][5]*A10[2][1] B01[1][5]*A10[3][1]) + ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[1][6]*A10[0][1] B01[1][6]*A10[1][1] B01[1][6]*A10[2][1] B01[1][6]*A10[3][1]) + ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[1][7]*A10[0][1] B01[1][7]*A10[1][1] B01[1][7]*A10[2][1] B01[1][7]*A10[3][1]) - ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2));//A10[0][2] A10[1][2] A10[2][2] A10[3][2] + ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2));//A10[0][2] A10[1][2] A10[2][2] A10[3][2] ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] @@ -2635,19 +3103,19 @@ static err_t bli_dtrsm_small_AlXB_unitDiag( ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[2][6] ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[2][7] - b01 += 1; //mobe to next row of B + b01 += 1; //mobe to next row of B - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) + ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) + ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) + ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) + ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[2][4]*A10[0][2] B01[2][4]*A10[1][2] B01[2][4]*A10[2][2] B01[2][4]*A10[3][2]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[2][5]*A10[0][2] B01[2][5]*A10[1][2] B01[2][5]*A10[2][2] B01[2][5]*A10[3][2]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[2][6]*A10[0][2] B01[2][6]*A10[1][2] B01[2][6]*A10[2][2] B01[2][6]*A10[3][2]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[2][7]*A10[0][2] B01[2][7]*A10[1][2] B01[2][7]*A10[2][2] B01[2][7]*A10[3][2]) + ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[2][4]*A10[0][2] B01[2][4]*A10[1][2] B01[2][4]*A10[2][2] B01[2][4]*A10[3][2]) + ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[2][5]*A10[0][2] B01[2][5]*A10[1][2] B01[2][5]*A10[2][2] B01[2][5]*A10[3][2]) + ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[2][6]*A10[0][2] B01[2][6]*A10[1][2] B01[2][6]*A10[2][2] B01[2][6]*A10[3][2]) + ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[2][7]*A10[0][2] B01[2][7]*A10[1][2] B01[2][7]*A10[2][2] B01[2][7]*A10[3][2]) - ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3));//A10[0][3] A10[1][3] A10[2][3] A10[3][3] + ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3));//A10[0][3] A10[1][3] A10[2][3] A10[3][3] ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] @@ -2659,20 +3127,20 @@ static err_t bli_dtrsm_small_AlXB_unitDiag( ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[3][6] ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[3][7] - b01 += 1; //mobe to next row of B + b01 += 1; //mobe to next row of B - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[3][0]*A10[0][3] B01[3][0]*A10[3][0] B01[3][0]*A10[2][3] B01[3][0]*A10[3][0]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[3][1]*A10[0][3] B01[3][1]*A10[3][0] B01[3][1]*A10[2][3] B01[3][1]*A10[3][0]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[3][2]*A10[0][3] B01[3][2]*A10[3][0] B01[3][2]*A10[2][3] B01[3][2]*A10[3][0]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[3][3]*A10[0][3] B01[3][3]*A10[3][0] B01[3][3]*A10[2][3] B01[3][3]*A10[3][0]) + ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[3][0]*A10[0][3] B01[3][0]*A10[3][0] B01[3][0]*A10[2][3] B01[3][0]*A10[3][0]) + ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[3][1]*A10[0][3] B01[3][1]*A10[3][0] B01[3][1]*A10[2][3] B01[3][1]*A10[3][0]) + ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[3][2]*A10[0][3] B01[3][2]*A10[3][0] B01[3][2]*A10[2][3] B01[3][2]*A10[3][0]) + ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[3][3]*A10[0][3] B01[3][3]*A10[3][0] B01[3][3]*A10[2][3] B01[3][3]*A10[3][0]) - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[3][4]*A10[0][3] B01[3][4]*A10[3][0] B01[3][4]*A10[2][3] B01[3][4]*A10[3][3]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[3][5]*A10[0][3] B01[3][5]*A10[3][0] B01[3][5]*A10[2][3] B01[3][5]*A10[3][3]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[3][6]*A10[0][3] B01[3][6]*A10[3][0] B01[3][6]*A10[2][3] B01[3][6]*A10[3][3]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[3][7]*A10[0][3] B01[3][7]*A10[3][0] B01[3][7]*A10[2][3] B01[3][7]*A10[3][3]) + ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[3][4]*A10[0][3] B01[3][4]*A10[3][0] B01[3][4]*A10[2][3] B01[3][4]*A10[3][3]) + ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[3][5]*A10[0][3] B01[3][5]*A10[3][0] B01[3][5]*A10[2][3] B01[3][5]*A10[3][3]) + ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[3][6]*A10[0][3] B01[3][6]*A10[3][0] B01[3][6]*A10[2][3] B01[3][6]*A10[3][3]) + ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[3][7]*A10[0][3] B01[3][7]*A10[3][0] B01[3][7]*A10[2][3] B01[3][7]*A10[3][3]) - a10 += D_MR * cs_a; //pointer math to calculate next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to calculate next block of B for GEMM + a10 += D_MR * cs_a; //pointer math to calculate next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to calculate next block of B for GEMM } ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha @@ -2726,12 +3194,6 @@ static err_t bli_dtrsm_small_AlXB_unitDiag( ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] - //broadcast diagonal elements of A11 - ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_b +1)); //A11[1][1] - ymm3 = _mm256_broadcast_sd((double const *)(a11+cs_b*2 + 2)); //A11[2][2] - ymm4 = _mm256_broadcast_sd((double const *)(a11+cs_b*3 + 3)); //A11[3][3] - ymm2 = _mm256_broadcast_sd((double const *)(a11 +1)); //A11[1][0] ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][0] ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][0] @@ -2763,57 +3225,63 @@ static err_t bli_dtrsm_small_AlXB_unitDiag( a11 += cs_a; - //(ROw2): FMA operations + //(ROw1): FMA operations ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); //B11[3][0-3] -= A11[3][2] * B11[0-3][2] ymm15 = _mm256_fnmadd_pd(ymm4, ymm14, ymm15); //B11[7][0-3] -= A11[3][2] * B11[0-3][6] //unpacklow// ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] - ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] + ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] - ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm4); //store B11[4][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm5); //store B11[5][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 6), ymm6); //store B11[6][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 7), ymm7); //store B11[7][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm4); //store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm5); //store B11[5][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 6), ymm6); //store B11[6][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 7), ymm7); //store B11[7][0-3] } - if(m_remainder) //implementation for reamainder rows(when 'M' is not a multiple of D_MR) + if(m_remainder) //implementation for reamainder rows(when 'M' is not a multiple of D_MR) { - a10 = L +i; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + a10 = L +i; //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - k_iter = i / D_MR; //number of times GEMM operation to be done(in blocks of 4x4) + k_iter = i / D_MR; //number of times GEMM operation to be done(in blocks of 4x4) + + double f_temp[4]; + int iter; + + for(iter = 0; iter < m_remainder; iter++) + f_temp[iter] = (b11 + cs_b * 7)[iter]; ymm8 = _mm256_setzero_pd(); ymm9 = _mm256_setzero_pd(); @@ -2825,244 +3293,226 @@ static err_t bli_dtrsm_small_AlXB_unitDiag( ymm15 = _mm256_setzero_pd(); ///GEMM code Begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + for(k = 0; k< k_iter; k++) //loop for number of GEMM operations { ptr_b01_dup = b01; - ymm16 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] + ymm16 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] + ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] + ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] + ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] + ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[0][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[0][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[0][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[0][7] + ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[0][4] + ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[0][5] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[0][6] + ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[0][7] - b01 += 1; //move to next row of B + b01 += 1; //move to next row of B - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0] ) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) + ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0] ) + ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) + ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) + ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[0][4]*A10[0][0] B01[0][4]*A10[1][0] B01[0][4]*A10[2][0] B01[0][4]*A10[3][0]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[0][5]*A10[0][0] B01[0][5]*A10[1][0] B01[0][5]*A10[2][0] B01[0][5]*A10[3][0]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[0][6]*A10[0][0] B01[0][6]*A10[1][0] B01[0][6]*A10[2][0] B01[0][6]*A10[3][0]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm16 += (B01[0][7]*A10[0][0] B01[0][7]*A10[1][0] B01[0][7]*A10[2][0] B01[0][7]*A10[3][0]) + ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[0][4]*A10[0][0] B01[0][4]*A10[1][0] B01[0][4]*A10[2][0] B01[0][4]*A10[3][0]) + ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[0][5]*A10[0][0] B01[0][5]*A10[1][0] B01[0][5]*A10[2][0] B01[0][5]*A10[3][0]) + ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[0][6]*A10[0][0] B01[0][6]*A10[1][0] B01[0][6]*A10[2][0] B01[0][6]*A10[3][0]) + ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm16 += (B01[0][7]*A10[0][0] B01[0][7]*A10[1][0] B01[0][7]*A10[2][0] B01[0][7]*A10[3][0]) - ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 1)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] + ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 1)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] + ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] + ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] + ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] + ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[1][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[1][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[1][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[1][7] + ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[1][4] + ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[1][5] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[1][6] + ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[1][7] - b01 += 1; //move to next row of B01 + b01 += 1; //move to next row of B01 - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) + ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) + ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) + ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) + ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[1][4]*A10[0][1] B01[1][4]*A10[1][1] B01[1][4]*A10[2][1] B01[1][4]*A10[3][1]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[1][5]*A10[0][1] B01[1][5]*A10[1][1] B01[1][5]*A10[2][1] B01[1][5]*A10[3][1]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[1][6]*A10[0][1] B01[1][6]*A10[1][1] B01[1][6]*A10[2][1] B01[1][6]*A10[3][1]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[1][7]*A10[0][1] B01[1][7]*A10[1][1] B01[1][7]*A10[2][1] B01[1][7]*A10[3][1]) + ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[1][4]*A10[0][1] B01[1][4]*A10[1][1] B01[1][4]*A10[2][1] B01[1][4]*A10[3][1]) + ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[1][5]*A10[0][1] B01[1][5]*A10[1][1] B01[1][5]*A10[2][1] B01[1][5]*A10[3][1]) + ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[1][6]*A10[0][1] B01[1][6]*A10[1][1] B01[1][6]*A10[2][1] B01[1][6]*A10[3][1]) + ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[1][7]*A10[0][1] B01[1][7]*A10[1][1] B01[1][7]*A10[2][1] B01[1][7]*A10[3][1]) - ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] //A10[1][2] A10[2][2] A10[3][2] + ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] //A10[1][2] A10[2][2] A10[3][2] - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] + ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] + ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] + ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] + ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[2][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[2][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[2][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[2][7] + ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[2][4] + ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[2][5] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[2][6] + ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[2][7] - b01 += 1; //move to next row of B + b01 += 1; //move to next row of B - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) + ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) + ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) + ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) + ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[2][4]*A10[0][2] B01[2][4]*A10[1][2] B01[2][4]*A10[2][2] B01[2][0]*A10[3][2]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[2][5]*A10[0][2] B01[2][5]*A10[1][2] B01[2][5]*A10[2][2] B01[2][1]*A10[3][2]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[2][6]*A10[0][2] B01[2][6]*A10[1][2] B01[2][6]*A10[2][2] B01[2][2]*A10[3][2]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[2][7]*A10[0][2] B01[2][7]*A10[1][2] B01[2][7]*A10[2][2] B01[2][3]*A10[3][2]) + ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[2][4]*A10[0][2] B01[2][4]*A10[1][2] B01[2][4]*A10[2][2] B01[2][0]*A10[3][2]) + ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[2][5]*A10[0][2] B01[2][5]*A10[1][2] B01[2][5]*A10[2][2] B01[2][1]*A10[3][2]) + ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[2][6]*A10[0][2] B01[2][6]*A10[1][2] B01[2][6]*A10[2][2] B01[2][2]*A10[3][2]) + ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[2][7]*A10[0][2] B01[2][7]*A10[1][2] B01[2][7]*A10[2][2] B01[2][3]*A10[3][2]) - ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] + ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] + ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] + ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] + ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[3][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[3][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[3][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[3][7] + ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[3][4] + ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[3][5] + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[3][6] + ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[3][7] - b01 += 1; //move to next row of B + b01 += 1; //move to next row of B - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm8 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm8 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm8 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) + ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) + ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm8 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) + ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm8 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) + ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm8 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm8 += (B01[3][0]*A10[0][3] B01[3][4]*A10[1][3] B01[3][4]*A10[2][3] B01[3][4]*A10[3][3]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm8 += (B01[3][1]*A10[0][3] B01[3][5]*A10[1][3] B01[3][5]*A10[2][3] B01[3][5]*A10[3][3]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm8 += (B01[3][2]*A10[0][3] B01[3][6]*A10[1][3] B01[3][6]*A10[2][3] B01[3][6]*A10[3][3]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm8 += (B01[3][3]*A10[0][3] B01[3][7]*A10[1][3] B01[3][7]*A10[2][3] B01[3][7]*A10[3][3]) + ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm8 += (B01[3][0]*A10[0][3] B01[3][4]*A10[1][3] B01[3][4]*A10[2][3] B01[3][4]*A10[3][3]) + ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm8 += (B01[3][1]*A10[0][3] B01[3][5]*A10[1][3] B01[3][5]*A10[2][3] B01[3][5]*A10[3][3]) + ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm8 += (B01[3][2]*A10[0][3] B01[3][6]*A10[1][3] B01[3][6]*A10[2][3] B01[3][6]*A10[3][3]) + ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm8 += (B01[3][3]*A10[0][3] B01[3][7]*A10[1][3] B01[3][7]*A10[2][3] B01[3][7]*A10[3][3]) - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM } + ///GEMM code ends/// - ///GEMM code ends/// + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *6)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] + ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] - ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *6)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] - ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b *7)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] *alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] *alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] *alpha -= B01[0-3][2] + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] *alpha -= B01[0-3][3] + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] *alpha -= B01[0-3][4] + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] *alpha -= B01[0-3][5] + ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] *alpha -= B01[0-3][6] + ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); //B11[0-3][7] *alpha -= B01[0-3][7] - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] *alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] *alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] *alpha -= B01[0-3][2] - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] *alpha -= B01[0-3][3] - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] *alpha -= B01[0-3][4] - ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] *alpha -= B01[0-3][5] - ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] *alpha -= B01[0-3][6] - ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); //B11[0-3][7] *alpha -= B01[0-3][7] - - ///implement TRSM/// - - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[1][4] B11[1][5] - ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[1][6] B11[1][7] - - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] - ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[5][0] B11[5][1] B11[7][0] B11[7][1] - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[5][2] B11[5][3] B11[7][2] B11[7][3] - - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] - ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] - - //broadcast diagonal elements of A11 - ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_b +1)); //A11[1][1] - ymm3 = _mm256_broadcast_sd((double const *)(a11+cs_b*2 + 2)); //A11[2][2] - ymm4 = _mm256_broadcast_sd((double const *)(a11+cs_b*3 + 3)); //A11[3][3] - - ymm2 = _mm256_broadcast_sd((double const *)(a11 +1)); //A11[1][0] - ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][0] - ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][0] - - a11 += cs_a; - - //(Row1): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); //B11[1][0-3] -= B11[0-3][0]*A11[1][0] - ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[2][0-3] -= B11[0-3][0]*A11[2][0] - ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); //B11[3][0-3] -= B11[0-3][0]*A11[3][0] - - ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); //B11[5][0-3] -= B11[0-3][4]*A11[1][4] - ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[6][0-3] -= B11[0-3][4]*A11[2][4] - ymm15 = _mm256_fnmadd_pd(ymm4, ymm12, ymm15); //B11[7][0-3] -= B11[0-3][4]*A11[3][4] - - ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][1] - ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][1] - - a11 += cs_a; - - //(ROw2): FMA operations - ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); //B11[2][0-3] -= A11[2][1] * B11[0-3][1] - ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); //B11[3][0-3] -= A11[3][1] * B11[0-3][1] - - ymm14 = _mm256_fnmadd_pd(ymm3, ymm13, ymm14); //B11[6][0-3] -= A11[2][1] * B11[0-3][5] - ymm15 = _mm256_fnmadd_pd(ymm4, ymm13, ymm15); //B11[7][0-3] -= A11[3][1] * B11[0-3][5] - - ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][2] - - a11 += cs_a; - - //(ROw2): FMA operations - ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); //B11[0-3][3] -= A11[3][2]*B11[0-3][2] - - ymm15 = _mm256_fnmadd_pd(ymm4, ymm14, ymm15); //B11[0-3][7] -= A11[3][2]*B11[0-3][6] - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] - ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] - ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] - - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[0][5] B11[1][5] B11[0][7] B11[1][7] - ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[2][5] B11[3][5] B11[2][7] B11[3][7] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] - ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] - - ymm8 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); //load B11[0-3][0] - ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); //load B11[0-3][1] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //load B11[0-3][2] - ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //load B11[0-3][3] - ymm12 = _mm256_loadu_pd((double const *)(b11 + cs_b * 4)); //load B11[0-3][4] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b * 5)); //load B11[0-3][5] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b * 6)); //load B11[0-3][6] - ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b * 7)); //load B11[0-3][7] - //determine correct values to store if(m_remainder == 3) { + ///implement TRSM/// + + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[1][4] B11[1][5] + ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[1][6] B11[1][7] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[5][0] B11[5][1] B11[7][0] B11[7][1] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[5][2] B11[5][3] B11[7][2] B11[7][3] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] + ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] + + ymm2 = _mm256_broadcast_sd((double const *)(a11 +1)); //A11[1][0] + ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][0] + + a11 += cs_a; + + //(Row1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); //B11[1][0-3] -= B11[0-3][0]*A11[1][0] + ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[2][0-3] -= B11[0-3][0]*A11[2][0] + + ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); //B11[5][0-3] -= B11[0-3][4]*A11[1][4] + ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[6][0-3] -= B11[0-3][4]*A11[2][4] + + ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][1] + + a11 += cs_a; + + //(ROw2): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); //B11[2][0-3] -= A11[2][1] * B11[0-3][1] + + ymm14 = _mm256_fnmadd_pd(ymm3, ymm13, ymm14); //B11[6][0-3] -= A11[2][1] * B11[0-3][5] + + ymm11 = _mm256_broadcast_sd((double const *)(&ones)); + ymm15 = _mm256_broadcast_sd((double const *)(&ones)); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] + ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[0][5] B11[1][5] B11[0][7] B11[1][7] + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[2][5] B11[3][5] B11[2][7] B11[3][7] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] + + ymm8 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); //load B11[0-3][0] + ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); //load B11[0-3][1] + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //load B11[0-3][2] + ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //load B11[0-3][3] + ymm12 = _mm256_loadu_pd((double const *)(b11 + cs_b * 4)); //load B11[0-3][4] + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b * 5)); //load B11[0-3][5] + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b * 6)); //load B11[0-3][6] + ymm15 = _mm256_loadu_pd((double const *)(f_temp)); //load B11[0-3][7] + + //determine correct values to store ymm0 = _mm256_blend_pd(ymm0, ymm8, 0x08); ymm1 = _mm256_blend_pd(ymm1, ymm9, 0x08); ymm2 = _mm256_blend_pd(ymm2, ymm10, 0x08); @@ -3071,9 +3521,85 @@ static err_t bli_dtrsm_small_AlXB_unitDiag( ymm5 = _mm256_blend_pd(ymm5, ymm13, 0x08); ymm6 = _mm256_blend_pd(ymm6, ymm14, 0x08); ymm7 = _mm256_blend_pd(ymm7, ymm15, 0x08); + } if(m_remainder == 2) { + ///implement TRSM/// + + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[1][4] B11[1][5] + ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[1][6] B11[1][7] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[5][0] B11[5][1] B11[7][0] B11[7][1] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[5][2] B11[5][3] B11[7][2] B11[7][3] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] + ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] + + ymm2 = _mm256_broadcast_sd((double const *)(a11 +1)); //A11[1][0] + + a11 += cs_a; + + //(Row1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); //B11[1][0-3] -= B11[0-3][0]*A11[1][0] + + ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); //B11[5][0-3] -= B11[0-3][4]*A11[1][4] + + ymm10 = _mm256_broadcast_sd((double const *)&ones); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm10, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm10, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_permute2f128_pd(ymm5, ymm10, 0x20); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm6 = _mm256_permute2f128_pd(ymm5, ymm10, 0x31); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[0][5] B11[1][5] B11[0][7] B11[1][7] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm10, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm10, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm5 = _mm256_permute2f128_pd(ymm12, ymm10, 0x20); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + ymm7 = _mm256_permute2f128_pd(ymm12, ymm10, 0x31); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] + + ymm8 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); //load B11[0-3][0] + ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); //load B11[0-3][1] + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //load B11[0-3][2] + ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //load B11[0-3][3] + ymm12 = _mm256_loadu_pd((double const *)(b11 + cs_b * 4)); //load B11[0-3][4] + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b * 5)); //load B11[0-3][5] + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b * 6)); //load B11[0-3][6] + ymm15 = _mm256_loadu_pd((double const *)(f_temp)); //load B11[0-3][7] + + //determine correct values to store ymm0 = _mm256_permute2f128_pd(ymm0, ymm8, 0x30); ymm1 = _mm256_permute2f128_pd(ymm1, ymm9, 0x30); ymm2 = _mm256_permute2f128_pd(ymm2, ymm10, 0x30); @@ -3082,9 +3608,20 @@ static err_t bli_dtrsm_small_AlXB_unitDiag( ymm5 = _mm256_permute2f128_pd(ymm5, ymm13, 0x30); ymm6 = _mm256_permute2f128_pd(ymm6, ymm14, 0x30); ymm7 = _mm256_permute2f128_pd(ymm7, ymm15, 0x30); + } if(m_remainder == 1) { + ymm8 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); //load B11[0-3][0] + ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); //load B11[0-3][1] + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //load B11[0-3][2] + ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //load B11[0-3][3] + ymm12 = _mm256_loadu_pd((double const *)(b11 + cs_b * 4)); //load B11[0-3][4] + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b * 5)); //load B11[0-3][5] + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b * 6)); //load B11[0-3][6] + ymm15 = _mm256_loadu_pd((double const *)(f_temp)); //load B11[0-3][7] + + //determine correct values to store ymm0 = _mm256_blend_pd(ymm0, ymm8, 0x0E); ymm1 = _mm256_blend_pd(ymm1, ymm9, 0x0E); ymm2 = _mm256_blend_pd(ymm2, ymm10, 0x0E); @@ -3095,35 +3632,37 @@ static err_t bli_dtrsm_small_AlXB_unitDiag( ymm7 = _mm256_blend_pd(ymm7, ymm15, 0x0E); } - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm4); //store(B11[0-3][4]) - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm5); //store(B11[0-3][5]) - _mm256_storeu_pd((double *)(b11 + cs_b * 6), ymm6); //store(B11[0-3][6]) - _mm256_storeu_pd((double *)(b11 + cs_b * 7), ymm7); //store(B11[0-3][7]) + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm4); //store(B11[0-3][4]) + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm5); //store(B11[0-3][5]) + _mm256_storeu_pd((double *)(b11 + cs_b * 6), ymm6); //store(B11[0-3][6]) + _mm256_storeu_pd((double *)(f_temp), ymm7); //store(B11[0-3][7]) + for(iter = 0; iter < m_remainder; iter++) + (b11 + cs_b * 7)[iter] = f_temp[iter]; } } - if((n & 4)) //implementation for remainder columns(when 'N' is a multiple of 4) + if((n & 4)) //implementation for remainder columns(when 'n_remainder' is greater than 4) { - for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' direction + for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' direction { - a10 = L +i; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + a10 = L +i; //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - k_iter = i / D_MR; //number of times GEMM to be performed(in block of 4) + k_iter = i / D_MR; //number of times GEMM to be performed(in block of 4) ///GEMM for previously calculated values /// //load 4x4 block from b11 - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] ymm4 = _mm256_setzero_pd(); @@ -3131,167 +3670,163 @@ static err_t bli_dtrsm_small_AlXB_unitDiag( ymm6 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha - - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations { - ptr_b01_dup = b01; - ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a*2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + ptr_b01_dup = b01; + ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a*2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] + ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] - b01 += 1; //move to next row of B + b01 += 1; //move to next row of B - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] + ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] - b01 += 1; + b01 += 1; - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] + ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] - b01 += 1; + b01 += 1; - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[1][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[2][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[3][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[1][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[2][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[3][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] + ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] - b01 += 1; + b01 += 1; - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM } + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); //B11[0-3][0] *alpha -= ymm4 - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); //B01[0-3][1] *alpha -= ymm5 - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); //B01[0-3][2] *alpha -= ymm6 - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); //B01[0-3][3] *alpha -= ymm7 + ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 + ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B01[0-3][1] *alpha -= ymm5 + ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B01[0-3][2] *alpha -= ymm6 + ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B01[0-3][3] *alpha -= ymm7 ///implement TRSM/// //1st col - ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] - ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0] - ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][0] + ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] + ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0] + ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][0] //2nd col a11 += cs_a; - ymm8 = _mm256_broadcast_sd((double const *)(a11 + 1)); //A11[1][1] - ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][1] + ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1] + ymm10 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][1] //3rd col a11 += cs_a; - ymm11 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][2] - ymm12 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][2] - - //4th col - a11 += cs_a; - ymm13 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][3] + ymm12 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][2] ////unpacklow//// - ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] //rearrange low elements - ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] //rearrange high elements - ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) - ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]*B11[0][0-3] - ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0]*B11[0][0-3] - ymm13 = _mm256_fnmadd_pd(ymm7, ymm4, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][0]*B11[0][0-3] + ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]*B11[0][0-3] + ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0]*B11[0][0-3] + ymm13 = _mm256_fnmadd_pd(ymm7, ymm4, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][0]*B11[0][0-3] //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1]*B11[1][0-3] - ymm13 = _mm256_fnmadd_pd(ymm10, ymm8, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][1]*B11[1][0-3] + ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1]*B11[1][0-3] + ymm13 = _mm256_fnmadd_pd(ymm10, ymm8, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][1]*B11[1][0-3] //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - ymm13 = _mm256_fnmadd_pd(ymm12, ymm11, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][2]*B11[2][0-3] + ymm13 = _mm256_fnmadd_pd(ymm12, ymm11, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][2]*B11[2][0-3] //--> Transpose and store results of columns of B block <--// ////unpacklow//// - ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ////unpackhigh//// - ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[0-3][3]) + _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[0-3][3]) } if(m_remainder) //implementation for remainder rows(when 'M' is not a multiple of D_MR) { - a10 = L +i; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha + a10 = L +i; //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM k_iter = i / D_MR; //number of GEMM operations to be performed(in blocks of 4x4) + double f_temp[4]; + int iter; + + for(iter = 0; iter < m_remainder; iter++) + f_temp[iter] = (b11 + cs_b * 3)[iter]; + ///GEMM for previously calculated values /// //load 4x4 block from b11 ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm3 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] ymm4 = _mm256_setzero_pd(); @@ -3301,149 +3836,134 @@ static err_t bli_dtrsm_small_AlXB_unitDiag( for(k = 0; k < k_iter; k++) //looop for number of GEMM operations { - ptr_b01_dup = b01; + ptr_b01_dup = b01; - ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] - b01 += 1; + b01 += 1; - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] - b01 += 1; + b01 += 1; - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] - b01 += 1; + b01 += 1; - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] - b01 += 1; + b01 += 1; - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM } + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); //B11[0-3][0] *alpha -= ymm4 - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); //B11[0-3][1] *alpha -= ymm5 - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); //B11[0-3][2] *alpha -= ymm6 - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); //B11[0-3][3] *alpha -= ymm7 + ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 + ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] *alpha -= ymm5 + ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[0-3][2] *alpha -= ymm6 + ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[0-3][3] *alpha -= ymm7 - ///implement TRSM/// - //1st col - ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] - ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0] - ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][0] - - //2nd col - a11 += cs_a; - ymm8 = _mm256_broadcast_sd((double const *)(a11 + 1)); //A11[1][1] - ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][1] - - //3rd col - a11 += cs_a; - ymm11 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][2] - ymm12 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][2] - - //4th col - a11 += cs_a; - ymm13 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][3] - - ////unpacklow//// - ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - //rearrange low elements - ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - //rearrange high elements - ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) - ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]* B11[0][0-3] - ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0]* B11[0][0-3] - ymm13 = _mm256_fnmadd_pd(ymm7, ymm4, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][0]* B11[0][0-3] - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1]* B11[1][0-3] - ymm13 = _mm256_fnmadd_pd(ymm10, ymm8, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][1]* B11[1][0-3] - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - ymm13 = _mm256_fnmadd_pd(ymm12, ymm11, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][2]* B11[2][0-3] - - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ////unpackhigh//// - ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - - ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - //load 4x4 block from b11 - ymm4 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][3] B11[1][3] B11[2][2] B11[3][3] - - //determine correct values to store if(m_remainder == 3) { + ///implement TRSM/// + //1st col + ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] + ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0] + + //2nd col + a11 += cs_a; + ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1] + + ////unpacklow//// + ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) + ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]* B11[0][0-3] + ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0]* B11[0][0-3] + + //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) + ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1]* B11[1][0-3] + + ymm13 = _mm256_broadcast_sd((double const *)(&ones)); + + //--> Transpose and store results of columns of B block <--// + ////unpacklow//// + ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ////unpackhigh//// + ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + + ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + //load 4x4 block from b11 + ymm4 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][2] B11[3][3] + + //determine correct values to store ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x08); ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x08); ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x08); @@ -3451,13 +3971,72 @@ static err_t bli_dtrsm_small_AlXB_unitDiag( } if(m_remainder == 2) { + ///implement TRSM/// + //1st col + ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] + + ////unpacklow//// + ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) + ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]* B11[0][0-3] + + ymm11 = _mm256_broadcast_sd((double const *)(&ones)); + ymm13 = _mm256_broadcast_sd((double const *)(&ones)); + + //--> Transpose and store results of columns of B block <--// + ////unpacklow//// + ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ////unpackhigh//// + ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + + ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + //load 4x4 block from b11 + ymm4 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][2] B11[3][3] + + //determine correct values to store ymm0 = _mm256_permute2f128_pd(ymm0, ymm4,0x30); ymm1 = _mm256_permute2f128_pd(ymm1, ymm5,0x30); ymm2 = _mm256_permute2f128_pd(ymm2, ymm6,0x30); ymm3 = _mm256_permute2f128_pd(ymm3, ymm7,0x30); + } if(m_remainder == 1) { + //load 4x4 block from b11 + ymm4 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][2] B11[3][3] + + //determine correct values to store ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x0E); ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x0E); ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x0E); @@ -3467,7 +4046,10 @@ static err_t bli_dtrsm_small_AlXB_unitDiag( _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + _mm256_storeu_pd((double *)(f_temp), ymm3); //store(B11[0-3][3]) + + for(iter = 0; iter < m_remainder; iter++) + (b11 + cs_b * 3)[iter] = f_temp[iter]; } @@ -3480,14 +4062,17 @@ static err_t bli_dtrsm_small_AlXB_unitDiag( { for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' direction { - a10 = L +i; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + a10 = L +i; //pointer to block of A to be used for GEMM + a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM k_iter = i / D_MR; //number of GEMM operations to be performed(in blocks of 4x4) - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); ///GEMM for previously calculated values /// @@ -3497,117 +4082,188 @@ static err_t bli_dtrsm_small_AlXB_unitDiag( ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_broadcast_sd((double const *)&ones); + + for(k = 0; k < k_iter; k++) + { + ptr_b01_dup = b01; + ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) + + a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + } + + ///GEMM code ends/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value + + ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 + ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] *alpha -= ymm5 + ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[0-3][2] *alpha -= ymm6 + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][3] *alpha -= ymm7 } if(n_remainder == 2) { ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_broadcast_sd((double const *)&ones); - ymm3 = _mm256_broadcast_sd((double const *)&ones); + + for(k = 0; k < k_iter; k++) + { + ptr_b01_dup = b01; + ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) + + a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + } + ///GEMM code ends/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value + + ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 + ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] *alpha -= ymm5 + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][2] *alpha -= ymm6 + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][3] *alpha -= ymm7 } if(n_remainder == 1) { ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_broadcast_sd((double const *)&ones); - ymm2 = _mm256_broadcast_sd((double const *)&ones); - ymm3 = _mm256_broadcast_sd((double const*)&ones); + + for(k = 0; k < k_iter; k++) + { + ptr_b01_dup = b01; + + ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] + + b01 += 1; + + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) + + a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + } + ///GEMM code ends/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value + + ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 + ymm1 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][1] *alpha -= ymm5 + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][2] *alpha -= ymm6 + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][3] *alpha -= ymm7 + } - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - for(k = 0; k < k_iter; k++) - { - ptr_b01_dup = b01; - ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) - - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM - } - - ///GEMM code ends/// - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); //B11[0-3][0] *alpha -= ymm4 - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); //B11[0-3][1] *alpha -= ymm5 - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); //B11[0-3][2] *alpha -= ymm6 - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); //B11[0-3][3] *alpha -= ymm7 - ///implement TRSM/// //1st col - ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0] ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][0] //2nd col a11 += cs_a; - ymm8 = _mm256_broadcast_sd((double const *)(a11 + 1)); //A11[1][1] ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1] ymm10 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][1] //3rd col a11 += cs_a; - ymm11 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][2] ymm12 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][2] - //4th col - a11 += cs_a; - ymm13 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][3] - ////unpacklow//// ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] @@ -3624,6 +4280,7 @@ static err_t bli_dtrsm_small_AlXB_unitDiag( ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0] * B11[0][0-3] ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0] * B11[0][0-3] @@ -3675,7 +4332,7 @@ static err_t bli_dtrsm_small_AlXB_unitDiag( } if(m_remainder) //implementation for remainder rows(when 'M' is not a multiple of D_MR) { - a10 = L +i; //pointer to block of A to be used for GEMM + a10 = L +i; //pointer to block of A to be used for GEMM a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM b01 = B + j*cs_b; //pointer to block of B to be used for GEMM b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM @@ -3683,7 +4340,16 @@ static err_t bli_dtrsm_small_AlXB_unitDiag( k_iter = i / D_MR; //number of times GEMM operations to be performed - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha value + double f_temp[4]; + int iter; + + for(iter = 0; iter < m_remainder; iter++) + f_temp[iter] = (b11 + cs_b * (n_remainder -1))[iter]; + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); ///GEMM for previously calculated values /// @@ -3693,136 +4359,227 @@ static err_t bli_dtrsm_small_AlXB_unitDiag( { ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_broadcast_sd((double const *)&ones); + ymm2 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[0][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[0][2] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[1][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[1][2] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[2][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[2][2] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[3][1] + ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[3][2] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) + + a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + } + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha value + + ymm8 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] * alpha -= ymm4 + ymm9 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] * alpha -= ymm5 + ymm10 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[0-3][2] * alpha -= ymm6 + + ///implement TRSM/// + //determine correct values to store + if(m_remainder == 3) + { + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); + } + if(m_remainder == 2) + { + ymm0 = _mm256_permute2f128_pd(ymm8, ymm0, 0x30); + ymm1 = _mm256_permute2f128_pd(ymm9, ymm1, 0x30); + ymm2 = _mm256_permute2f128_pd(ymm10, ymm2, 0x30); + } + if(m_remainder == 1) + { + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); + } + _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(f_temp), ymm2); //store(B11[0-3][2]) } if(n_remainder == 2) { ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_broadcast_sd((double const *)&ones); - ymm3 = _mm256_broadcast_sd((double const *)&ones); + ymm1 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + + ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[0][1] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[1][1] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[2][1] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0] + ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[3][1] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) + + a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + } + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha value + + ymm8 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] * alpha -= ymm4 + ymm9 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] * alpha -= ymm5 + + ///implement TRSM/// + //determine correct values to store + if(m_remainder == 3) + { + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + } + if(m_remainder == 2) + { + ymm0 = _mm256_permute2f128_pd(ymm8, ymm0, 0x30); + ymm1 = _mm256_permute2f128_pd(ymm9, ymm1, 0x30); + } + if(m_remainder == 1) + { + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + } + _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(f_temp), ymm1); //store(B11[0-3][1]) } if(n_remainder == 1) { - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_broadcast_sd((double const *)&ones); - ymm2 = _mm256_broadcast_sd((double const *)&ones); - ymm3 = _mm256_broadcast_sd((double const *)&ones); + ymm0 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_b01_dup = b01; + + ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) + + ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0] + + b01 += 1; //move to next row of B + + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) + + a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM + b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM + + } + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha value + + ymm8 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] * alpha -= ymm4 + + ///implement TRSM/// + //determine correct values to store + if(m_remainder == 3) + { + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + } + if(m_remainder == 2) + { + ymm0 = _mm256_permute2f128_pd(ymm8, ymm0, 0x30); + } + if(m_remainder == 1) + { + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + } + _mm256_storeu_pd((double *)(f_temp), ymm0); //store(B11[0-3][0]) } - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - ptr_b01_dup = b01; - ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[0][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B10[0][3] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[1][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B10[1][3] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[2][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B10[2][3] - - b01 += 1; //move to next row of B - - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[3][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B10[3][3] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) - - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM - - } - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); //B11[0-3][0] * alpha -= ymm4 - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); //B11[0-3][1] * alpha -= ymm5 - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); //B11[0-3][2] * alpha -= ymm6 - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); //B11[0-3][3] * alpha -= ymm7 - - ///implement TRSM/// - //determine correct values to store - if(m_remainder == 3) - { - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); - - } - if(m_remainder == 2) - { - ymm0 = _mm256_permute2f128_pd(ymm8, ymm0, 0x30); - ymm1 = _mm256_permute2f128_pd(ymm9, ymm1, 0x30); - ymm2 = _mm256_permute2f128_pd(ymm10, ymm2, 0x30); - ymm3 = _mm256_permute2f128_pd(ymm11, ymm3, 0x30); - - } - if(m_remainder == 1) - { - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); - } - if(n_remainder == 3) - { - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - } - if(n_remainder == 2) - { - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) - } - if(n_remainder == 1) - { - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - } + for(iter = 0; iter < m_remainder; iter++) + (b11 + cs_b * (n_remainder-1))[iter] = f_temp[iter]; ///scalar code for trsm without alpha/// dtrsm_small_AlXB_unitDiag(a11, b11, m_remainder, n_remainder, cs_a, cs_b); @@ -3852,13 +4609,13 @@ b11 * * * * * **a01 * * a11 */ static err_t bli_dtrsm_small_XAuB( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ) + side_t side, + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ) { dim_t D_MR = 8; //block dimension along the rows dim_t D_NR = 4; //block dimension along the columns @@ -3871,15 +4628,15 @@ static err_t bli_dtrsm_small_XAuB( dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B #ifdef BLIS_ENABLE_SMALL_MATRIX_ROME - if(bli_max(m,n)>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME && (m/n) < D_BLIS_SMALL_MATRIX_THRES_TRSM_DIM_RATIO) - { - return BLIS_NOT_YET_IMPLEMENTED; - } + if(bli_max(m,n)>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME && (m/n) < D_BLIS_SMALL_MATRIX_THRES_TRSM_DIM_RATIO) + { + return BLIS_NOT_YET_IMPLEMENTED; + } #else - if(bli_max(m,n)>D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES && (m/n) < D_BLIS_SMALL_MATRIX_THRES_TRSM_DIM_RATIO) - { - return BLIS_NOT_YET_IMPLEMENTED; - } + if(bli_max(m,n)>D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES && (m/n) < D_BLIS_SMALL_MATRIX_THRES_TRSM_DIM_RATIO) + { + return BLIS_NOT_YET_IMPLEMENTED; + } #endif dim_t i, j, k; //loop variablse @@ -3907,1169 +4664,1166 @@ static err_t bli_dtrsm_small_XAuB( for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction { - for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' direction - { - a01 = L + j*cs_a; //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' direction + { + a01 = L + j*cs_a; //pointer to block of A to be used in GEMM + a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - k_iter = j / D_NR; //number of GEMM operations to be done(in blocks of 4x4) + k_iter = j / D_NR; //number of GEMM operations to be done(in blocks of 4x4) - ymm0 = _mm256_setzero_pd(); - ymm1 = _mm256_setzero_pd(); - ymm2 = _mm256_setzero_pd(); - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); + ymm0 = _mm256_setzero_pd(); + ymm1 = _mm256_setzero_pd(); + ymm2 = _mm256_setzero_pd(); + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - ptr_a01_dup = a01; + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_a01_dup = a01; - //broadcast 1st row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] + //broadcast 1st row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] - a01 += 1; //move to next row + a01 += 1; //move to next row - //load 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) - //broadcast 2nd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] + //broadcast 2nd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] - a01 += 1; //move to next row of A + a01 += 1; //move to next row of A - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) - //broadcast 3rd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] + //broadcast 3rd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] - a01 += 1; //move to next row of A01 + a01 += 1; //move to next row of A01 - //load next 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) - ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) - //broadcast 4th row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] + //broadcast 4th row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] - a01 += 1; //move to next row of A01 + a01 += 1; //move to next row of A01 - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM } ///GEMM code ends/// - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); - //load 8x4 block of B11 - ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); + //load 8x4 block of B11 + ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0 - ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1 - ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2 - ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3 + ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0 + ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1 + ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2 + ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3 - ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4 - ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5 - ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6 - ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7 + ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4 + ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5 + ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6 + ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7 - ///implement TRSM/// + ///implement TRSM/// - ///read 4x4 block of A11/// + ///read 4x4 block of A11/// - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - //1st col - ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + //1st col + ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - //2nd col - a11 += cs_a; - ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][1] + //2nd col + a11 += cs_a; + ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] + ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][1] - //3rd col - a11 += cs_a; - ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2] - ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2] - ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][2] + //3rd col + a11 += cs_a; + ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2] + ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2] + ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][2] - //4th col - a11 += cs_a; - ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][3] + //4th col + a11 += cs_a; + ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][3] - //compute reciprocals of L(i,i) and broadcast in registers - ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] - ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3] + //compute reciprocals of L(i,i) and broadcast in registers + ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] + ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3] - ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] - ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) + ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) - ymm2 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][3] - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][3] - ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][3] + ymm2 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][3] + ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][3] + ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][3] - //extract a00 - ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) + //extract a00 + ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) - ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0] + ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0] - ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0] + ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0] - //extract a11 - ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) + //extract a11 + ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) - //(Row1): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] - ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2] - ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); //B11[0-3][3] -= B11[0-3][0] * A11[0][3] + //(Row1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] + ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2] + ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); //B11[0-3][3] -= B11[0-3][0] * A11[0][3] - ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] - ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2] - ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); //B11[4-7][3] -= B11[4-7][0] * A11[0][3] + ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] + ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2] + ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); //B11[4-7][3] -= B11[4-7][0] * A11[0][3] - ymm9 = _mm256_mul_pd(ymm9, ymm0); //B11[0-3][1] /= A11[1][1] + ymm9 = _mm256_mul_pd(ymm9, ymm0); //B11[0-3][1] /= A11[1][1] - ymm13 = _mm256_mul_pd(ymm13, ymm0); //B11[4-7][1] /= A11[1][1] + ymm13 = _mm256_mul_pd(ymm13, ymm0); //B11[4-7][1] /= A11[1][1] - //extract a22 - ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) + //extract a22 + ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) - //(Row2)FMA operations - ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2] - ymm11 = _mm256_fnmadd_pd(ymm5, ymm9, ymm11); //B11[0-3][3] -= B11[0-3][1] * A11[1][3] + //(Row2)FMA operations + ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2] + ymm11 = _mm256_fnmadd_pd(ymm5, ymm9, ymm11); //B11[0-3][3] -= B11[0-3][1] * A11[1][3] - ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2] - ymm15 = _mm256_fnmadd_pd(ymm5, ymm13, ymm15); //B11[4-7][3] -= B11[4-7][1] * A11[1][3] + ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2] + ymm15 = _mm256_fnmadd_pd(ymm5, ymm13, ymm15); //B11[4-7][3] -= B11[4-7][1] * A11[1][3] - ymm10 = _mm256_mul_pd(ymm10, ymm0); //B11[0-3][2] /= A11[2][2] + ymm10 = _mm256_mul_pd(ymm10, ymm0); //B11[0-3][2] /= A11[2][2] - ymm14 = _mm256_mul_pd(ymm14, ymm0); //B11[4-7][2] /= A11[2][2] + ymm14 = _mm256_mul_pd(ymm14, ymm0); //B11[4-7][2] /= A11[2][2] - //extract a33 - ymm0 = _mm256_permute_pd(ymm7, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) + //extract a33 + ymm0 = _mm256_permute_pd(ymm7, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) - //(Row3)FMA operations - ymm11 = _mm256_fnmadd_pd(ymm6, ymm10, ymm11); //B11[0-3][3] -= B11[0-3][2] * A11[2][3] + //(Row3)FMA operations + ymm11 = _mm256_fnmadd_pd(ymm6, ymm10, ymm11); //B11[0-3][3] -= B11[0-3][2] * A11[2][3] - ymm15 = _mm256_fnmadd_pd(ymm6, ymm14, ymm15); //B11[4-7][3] -= B11[4-7][2] * A11[2][3] + ymm15 = _mm256_fnmadd_pd(ymm6, ymm14, ymm15); //B11[4-7][3] -= B11[4-7][2] * A11[2][3] - ymm11 = _mm256_mul_pd(ymm11, ymm0); //B11[0-3][3] /= A11[3][3] + ymm11 = _mm256_mul_pd(ymm11, ymm0); //B11[0-3][3] /= A11[3][3] - ymm15 = _mm256_mul_pd(ymm15, ymm0); //B11[4-7][3] /= A11[3][3] + ymm15 = _mm256_mul_pd(ymm15, ymm0); //B11[4-7][3] /= A11[3][3] - _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) - _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b), ymm11); //store(B11[0-3][3]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b + D_NR), ymm15);//store(B11[4-7][3]) + _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) + _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][2]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b), ymm11); //store(B11[0-3][3]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b + D_NR), ymm15);//store(B11[4-7][3]) } - if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR) + if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR) { - a01 = L + j*cs_a; //pointer to block of A to be used for GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used for GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + a01 = L + j*cs_a; //pointer to block of A to be used for GEMM + a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used for GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - k_iter = j / D_NR; //number of GEMM operations to be performed(in blocks of 4x4) + k_iter = j / D_NR; //number of GEMM operations to be performed(in blocks of 4x4) - ///load 4x4 block of b11 + ///load 4x4 block of b11 - ymm0 = _mm256_setzero_pd(); - ymm1 = _mm256_setzero_pd(); - ymm2 = _mm256_setzero_pd(); - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); + ymm0 = _mm256_setzero_pd(); + ymm1 = _mm256_setzero_pd(); + ymm2 = _mm256_setzero_pd(); + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); //subtract the calculated GEMM block from current TRSM block - //load 8x4 block of B11 - if(n_remainder == 3) - { + //load 8x4 block of B11 + if(n_remainder == 3) + { ///GEMM implementation begins/// - for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations + for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations { - ptr_a01_dup = a01; + ptr_a01_dup = a01; - //broadcast 1st row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] + //broadcast 1st row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] - a01 += 1; //move to next row of A + a01 += 1; //move to next row of A - //load 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - //broadcast 2nd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] + //broadcast 2nd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] - a01 += 1; //move to next row of A + a01 += 1; //move to next row of A - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) - //broadcast 3rd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] + //broadcast 3rd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] - a01 += 1; //move to next row of A + a01 += 1; //move to next row of A - //load next 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) - ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) - - //broadcast 4th row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] - - a01 += 1; //move to next row of A - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) - - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM - } - - ///GEMM code ends/// - ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] - ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] - ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][2] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][2] - - ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] - ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0] - ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1] - ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] - ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2] - ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3] - - ///implement TRSM/// - - ///read 4x4 block of A11/// - - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - - //1st col - ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - - //2nd col - a11 += cs_a; - ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][1] - - //3rd col - a11 += cs_a; - ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2] - ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2] - ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][2] - - //4th col - a11 += cs_a; - ymm6 = _mm256_broadcast_sd((double const *)(&ones)); //A11[3][3] - - //compute reciprocals of L(i,i) and broadcast in registers - ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] - ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3] - - ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] - ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) - - //extract a00 - ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) - - ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0] - - ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0] - - //extract a11 - ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) - - //(Row1): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] - ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2] - - ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] - ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2] - - ymm9 = _mm256_mul_pd(ymm9, ymm0); //B11[0-3][1] /= A11[1][1] - - ymm13 = _mm256_mul_pd(ymm13, ymm0); //B11[4-7][1] /= A11[1][1] - - //extract a22 - ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) - - //(Row2)FMA operations - ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2] - - ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2] - - ymm10 = _mm256_mul_pd(ymm10, ymm0); //B11[0-3][2] /= A11[2][2] - - ymm14 = _mm256_mul_pd(ymm14, ymm0); //B11[4-7][2] /= A11[2][2] - - _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) - _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2]) - } - if(n_remainder == 2) - { - ///GEMM implementation begins/// - - for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations - { - ptr_a01_dup = a01; - - //broadcast 1st row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] - - a01 += 1; //move to next row of A - - //load 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - - //broadcast 2nd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] - - a01 += 1; //move to next row of A - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) - - //broadcast 3rd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] - - a01 += 1; //move to next row of A - - //load next 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) - ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) - - //broadcast 4th row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] - - a01 += 1; //move to next row of A - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) - - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM - } - - ///GEMM code ends/// - ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] - ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] - ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1] - - ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] - ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0] - ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] - ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2] - - ///implement TRSM/// - - ///read 4x4 block of A11/// - - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - - //1st col - ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - - //2nd col - a11 += cs_a; - ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][1] - - //compute reciprocals of L(i,i) and broadcast in registers - ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] - - ymm0 = _mm256_blend_pd(ymm0, ymm7, 0x0C); //A11[0][0] A11[1][1] 1 1 - ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/1 1/1) - - //extract a00 - ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) - - ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0] - - ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0] - - //extract a11 - ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) - - //(Row1): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] - - ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] - - ymm9 = _mm256_mul_pd(ymm9, ymm0); //B11[0-3][1] /= A11[1][1] - - ymm13 = _mm256_mul_pd(ymm13, ymm0); //B11[4-7][1] /= A11[1][1] - - - _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) - _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) - } - if(n_remainder == 1) - { - ///GEMM implementation begins/// - - for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations - { - ptr_a01_dup = a01; - - //broadcast 1st row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] - - a01 += 1; //move to next row of A - - //load 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - - //broadcast 2nd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] - - a01 += 1; //move to next row of A - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) - - //broadcast 3rd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] - - a01 += 1; //move to next row of A - - //load next 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) - //broadcast 4th row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] + //broadcast 4th row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] - a01 += 1; //move to next row of A + a01 += 1; //move to next row of A - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM } + + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); + ///GEMM code ends/// - ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] - ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] + ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] + ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] + ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1] + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1] + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][2] + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][2] - ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] - ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] + ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] + ymm9 = _mm256_fmsub_pd(ymm9, ymm15, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0] + ymm10 = _mm256_fmsub_pd(ymm10, ymm15, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1] + ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] + ymm13 = _mm256_fmsub_pd(ymm13, ymm15, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2] + ymm14 = _mm256_fmsub_pd(ymm14, ymm15, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3] - ///implement TRSM/// + ///implement TRSM/// - ///read 4x4 block of A11/// + ///read 4x4 block of A11/// - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - //1st col - ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + //1st col + ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) + //2nd col + a11 += cs_a; + ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] + ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][1] - //extract a00 - ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) + //3rd col + a11 += cs_a; + ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2] + ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2] + ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][2] - ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0] + //4th col + a11 += cs_a; + ymm6 = _mm256_broadcast_sd((double const *)(&ones)); //A11[3][3] - ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0] + //compute reciprocals of L(i,i) and broadcast in registers + ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] + ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3] - _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) + ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) + + //extract a00 + ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) + + ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0] + + ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0] + + //extract a11 + ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) + + //(Row1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] + ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2] + + ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] + ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2] + + ymm9 = _mm256_mul_pd(ymm9, ymm0); //B11[0-3][1] /= A11[1][1] + + ymm13 = _mm256_mul_pd(ymm13, ymm0); //B11[4-7][1] /= A11[1][1] + + //extract a22 + ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) + + //(Row2)FMA operations + ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2] + + ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2] + + ymm10 = _mm256_mul_pd(ymm10, ymm0); //B11[0-3][2] /= A11[2][2] + + ymm14 = _mm256_mul_pd(ymm14, ymm0); //B11[4-7][2] /= A11[2][2] + + _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) + _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2]) } - } + if(n_remainder == 2) + { + ///GEMM implementation begins/// + + for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] + + a01 += 1; //move to next row of A + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) + + //broadcast 2nd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] + + a01 += 1; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) + + //broadcast 3rd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] + + a01 += 1; //move to next row of A + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) + + //broadcast 4th row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] + + a01 += 1; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + } + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ///GEMM code ends/// + ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] + ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] + ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1] + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1] + + ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] + ymm9 = _mm256_fmsub_pd(ymm9, ymm15, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0] + ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] + ymm13 = _mm256_fmsub_pd(ymm13, ymm15, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2] + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + + //1st col + ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + + //2nd col + a11 += cs_a; + ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] + ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][1] + + //compute reciprocals of L(i,i) and broadcast in registers + ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] + + ymm0 = _mm256_blend_pd(ymm0, ymm7, 0x0C); //A11[0][0] A11[1][1] 1 1 + ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/1 1/1) + + //extract a00 + ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) + + ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0] + + ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0] + + //extract a11 + ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) + + //(Row1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] + + ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] + + ymm9 = _mm256_mul_pd(ymm9, ymm0); //B11[0-3][1] /= A11[1][1] + + ymm13 = _mm256_mul_pd(ymm13, ymm0); //B11[4-7][1] /= A11[1][1] + + + _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) + _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) + } + if(n_remainder == 1) + { + ///GEMM implementation begins/// + + for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] + + a01 += 1; //move to next row of A + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) + + //broadcast 2nd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] + + a01 += 1; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) + + //broadcast 3rd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] + + a01 += 1; //move to next row of A + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) + + //broadcast 4th row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] + + a01 += 1; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + } + + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); + ///GEMM code ends/// + ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] + ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] + + ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] + ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + + //1st col + ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + + ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) + + //extract a00 + ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) + + ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0] + + ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0] + + _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) + } + } } - if((m & 4)) ///implementation for remainder rows(when m_remainder is a multiple of 4) + if((m & 4)) ///implementation for remainder rows(when m_remainder is greater than 4) { - for(j = 0; (j+D_NR-1)D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME && (m/n) < D_BLIS_SMALL_MATRIX_THRES_TRSM_DIM_RATIO) - { - return BLIS_NOT_YET_IMPLEMENTED; - } + if(bli_max(m,n)>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME && (m/n) < D_BLIS_SMALL_MATRIX_THRES_TRSM_DIM_RATIO) + { + return BLIS_NOT_YET_IMPLEMENTED; + } #else - if(bli_max(m,n)>D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES && (m/n) < D_BLIS_SMALL_MATRIX_THRES_TRSM_DIM_RATIO) - { - return BLIS_NOT_YET_IMPLEMENTED; - } + if(bli_max(m,n)>D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES && (m/n) < D_BLIS_SMALL_MATRIX_THRES_TRSM_DIM_RATIO) + { + return BLIS_NOT_YET_IMPLEMENTED; + } #endif dim_t i, j, k; //loop variablse @@ -5616,1116 +6369,1361 @@ static err_t bli_dtrsm_small_XAuB_unitDiag( for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction { - for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' direction - { - a01 = L + j*cs_a; //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' direction + { + a01 = L + j*cs_a; //pointer to block of A to be used in GEMM + a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - k_iter = j / D_NR; //number of GEMM operations to be done(in blocks of 4x4) + k_iter = j / D_NR; //number of GEMM operations to be done(in blocks of 4x4) - ymm0 = _mm256_setzero_pd(); - ymm1 = _mm256_setzero_pd(); - ymm2 = _mm256_setzero_pd(); - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); + ymm0 = _mm256_setzero_pd(); + ymm1 = _mm256_setzero_pd(); + ymm2 = _mm256_setzero_pd(); + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - ptr_a01_dup = a01; - - //broadcast 1st row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] - - a01 += 1; //move to next row - - //load 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) - - //broadcast 2nd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] - - a01 += 1; //move to next row of A - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) - - //broadcast 3rd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] - - a01 += 1; //move to next row of A01 - - //load next 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) - ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) - - //broadcast 4th row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] - - a01 += 1; //move to next row of A01 - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) - - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM - } - - ///GEMM code ends/// - - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); - //load 8x4 block of B11 - ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - - - ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0 - ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1 - ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2 - ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3 - - ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4 - ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5 - ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6 - ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7 - - ///implement TRSM/// - - ///read 4x4 block of A11/// - - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - - //1st col - ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - - //2nd col - a11 += cs_a; - ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][1] - - //3rd col - a11 += cs_a; - ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2] - ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2] - ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][2] - - //4th col - a11 += cs_a; - ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][3] - - ymm2 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][3] - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][3] - ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][3] - - //(Row1): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] - ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2] - ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); //B11[0-3][3] -= B11[0-3][0] * A11[0][3] - - ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] - ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2] - ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); //B11[4-7][3] -= B11[4-7][0] * A11[0][3] - - - //(Row2)FMA operations - ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2] - ymm11 = _mm256_fnmadd_pd(ymm5, ymm9, ymm11); //B11[0-3][3] -= B11[0-3][1] * A11[1][3] - - ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2] - ymm15 = _mm256_fnmadd_pd(ymm5, ymm13, ymm15); //B11[4-7][3] -= B11[4-7][1] * A11[1][3] - - - //(Row3)FMA operations - ymm11 = _mm256_fnmadd_pd(ymm6, ymm10, ymm11); //B11[0-3][3] -= B11[0-3][2] * A11[2][3] - - ymm15 = _mm256_fnmadd_pd(ymm6, ymm14, ymm15); //B11[4-7][3] -= B11[4-7][2] * A11[2][3] - - _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) - _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b), ymm11); //store(B11[0-3][3]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b + D_NR), ymm15);//store(B11[4-7][3]) - } - if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR) - { - a01 = L + j*cs_a; //pointer to block of A to be used for GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used for GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j / D_NR; //number of GEMM operations to be performed(in blocks of 4x4) - - ///load 4x4 block of b11 - - ymm0 = _mm256_setzero_pd(); - ymm1 = _mm256_setzero_pd(); - ymm2 = _mm256_setzero_pd(); - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM implementation begins/// - - for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations { - ptr_a01_dup = a01; + ptr_a01_dup = a01; - //broadcast 1st row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] + //broadcast 1st row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] - a01 += 1; //move to next row of A + a01 += 1; //move to next row - //load 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) - //broadcast 2nd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] + //broadcast 2nd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] - a01 += 1; //move to next row of A + a01 += 1; //move to next row of A - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) - //broadcast 3rd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] + //broadcast 3rd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] - a01 += 1; //move to next row of A + a01 += 1; //move to next row of A01 - //load next 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) - ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) - //broadcast 4th row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] + //broadcast 4th row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] - a01 += 1; //move to next row of A + a01 += 1; //move to next row of A01 - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM } ///GEMM code ends/// - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); - - //subtract the calculated GEMM block from current TRSM block - //load 8x4 block of B11 - if(n_remainder == 3) - { - ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] - ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] - ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][2] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][2] - ymm11 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm15 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - } - if(n_remainder == 2) - { - ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] - ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] - ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1] - ymm10 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2] - ymm14 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2] - ymm11 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm15 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - } - if(n_remainder == 1) - { - ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] - ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] - ymm9 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][1] - ymm13 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][1] - ymm10 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2] - ymm14 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2] - ymm11 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm15 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - } - - ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] - ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0] - ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1] - ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1] - ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] - ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2] - ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3] - ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3] - - ///implement TRSM/// - - ///read 4x4 block of A11/// - - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - - //1st col - ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - - //2nd col - a11 += cs_a; - ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][1] - - //3rd col - a11 += cs_a; - ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2] - ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2] - ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][2] - - //4th col - a11 += cs_a; - ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][3] - - ymm2 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][3] - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][3] - ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][3] + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); + //load 8x4 block of B11 + ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - //(Row1): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] - ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2] - ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); //B11[0-3][3] -= B11[0-3][0] * A11[0][3] + ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0 + ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1 + ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2 + ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3 - ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] - ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2] - ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); //B11[4-7][3] -= B11[4-7][0] * A11[0][3] + ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4 + ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5 + ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6 + ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7 - //(Row2)FMA operations - ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2] - ymm11 = _mm256_fnmadd_pd(ymm5, ymm9, ymm11); //B11[0-3][3] -= B11[0-3][1] * A11[1][3] + ///implement TRSM/// - ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2] - ymm15 = _mm256_fnmadd_pd(ymm5, ymm13, ymm15); //B11[4-7][3] -= B11[4-7][1] * A11[1][3] + ///read 4x4 block of A11/// - //(Row3)FMA operations - ymm11 = _mm256_fnmadd_pd(ymm6, ymm10, ymm11); //B11[0-3][3] -= B11[0-3][2] * A11[2][3] + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - ymm15 = _mm256_fnmadd_pd(ymm6, ymm14, ymm15); //B11[4-7][3] -= B11[4-7][2] * A11[2][3] + //2nd col + a11 += cs_a; + ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm11 = _mm256_mul_pd(ymm11, ymm0); //B11[0-3][3] /= A11[3][3] + //3rd col + a11 += cs_a; + ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2] + ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2] - ymm15 = _mm256_mul_pd(ymm15, ymm0); //B11[4-7][3] /= A11[3][3] + //4th col + a11 += cs_a; + ymm2 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][3] + ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][3] + ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][3] - if(n_remainder == 3) - { - _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) - _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2]) - } - if(n_remainder == 2) - { - _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) - _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) - } - if(n_remainder == 1) - { - _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) - } - } - } - if((m & 4)) ///implementation for remainder rows(when m_remainder is a multiple of 4) - { - for(j = 0; (j+D_NR-1)buffer; //value of Alpha double *L = a->buffer; //pointer to matrix A double *B = b->buffer; //pointer to matrix B @@ -8535,546 +9522,635 @@ static err_t bli_dtrsm_small_XAltB_unitDiag( for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction { - for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' direction - { - a01 = L + j; //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' direction + { + a01 = L + j; //pointer to block of A to be used in GEMM + a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - k_iter = j / D_NR; //number of GEMM operations to be done(in blocks of 4x4) + k_iter = j / D_NR; //number of GEMM operations to be done(in blocks of 4x4) - ymm0 = _mm256_setzero_pd(); - ymm1 = _mm256_setzero_pd(); - ymm2 = _mm256_setzero_pd(); - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); + ymm0 = _mm256_setzero_pd(); + ymm1 = _mm256_setzero_pd(); + ymm2 = _mm256_setzero_pd(); + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - ptr_a01_dup = a01; - - //broadcast 1st row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] - - a01 += cs_a; //move to next row - - //load 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) - - //broadcast 2nd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] - - a01 += cs_a; //move to next row of A - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) - - //broadcast 3rd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] - - a01 += cs_a; //move to next row of A01 - - //load next 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) - ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) - - //broadcast 4th row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] - - a01 += cs_a; //move to next row of A01 - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) - - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM - } - - ///GEMM code ends/// - - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); - //load 8x4 block of B11 - ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - - - ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0 - ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1 - ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2 - ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3 - - ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4 - ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5 - ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6 - ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7 - - ///implement TRSM/// - - ///read 4x4 block of A11/// - - //1st col - ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - - //2nd col - a11 += 1; - ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1] - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][1] - - //3rd col - a11 += 1; - ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2] - ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2] - ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2] - - //4th col - a11 += 1; - ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3] - - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3] - ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3] - ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] - - //(Row1): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] - ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2] - ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); //B11[0-3][3] -= B11[0-3][0] * A11[0][3] - - ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] - ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2] - ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); //B11[4-7][3] -= B11[4-7][0] * A11[0][3] - - //(Row2)FMA operations - ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2] - ymm11 = _mm256_fnmadd_pd(ymm5, ymm9, ymm11); //B11[0-3][3] -= B11[0-3][1] * A11[1][3] - - ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2] - ymm15 = _mm256_fnmadd_pd(ymm5, ymm13, ymm15); //B11[4-7][3] -= B11[4-7][1] * A11[1][3] - - //(Row3)FMA operations - ymm11 = _mm256_fnmadd_pd(ymm6, ymm10, ymm11); //B11[0-3][3] -= B11[0-3][2] * A11[2][3] - - ymm15 = _mm256_fnmadd_pd(ymm6, ymm14, ymm15); //B11[4-7][3] -= B11[4-7][2] * A11[2][3] - - _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) - _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b), ymm11); //store(B11[0-3][3]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b + D_NR), ymm15);//store(B11[4-7][3]) - } - if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR) - { - a01 = L + j; //pointer to block of A to be used for GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used for GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j / D_NR; //number of GEMM operations to be performed(in blocks of 4x4) - - ///load 4x4 block of b11 - - ymm0 = _mm256_setzero_pd(); - ymm1 = _mm256_setzero_pd(); - ymm2 = _mm256_setzero_pd(); - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM implementation begins/// - - for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations { - ptr_a01_dup = a01; + ptr_a01_dup = a01; - //broadcast 1st row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] + //broadcast 1st row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] - a01 += cs_a; //move to next row of A + a01 += cs_a; //move to next row - //load 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) - //broadcast 2nd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] + //broadcast 2nd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] - a01 += cs_a; //move to next row of A + a01 += cs_a; //move to next row of A - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) - //broadcast 3rd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] + //broadcast 3rd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] - a01 += cs_a; //move to next row of A + a01 += cs_a; //move to next row of A01 - //load next 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) - ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) - //broadcast 4th row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] + //broadcast 4th row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] - a01 += cs_a; //move to next row of A + a01 += cs_a; //move to next row of A01 - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM } ///GEMM code ends/// - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); + //load 8x4 block of B11 + ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - //subtract the calculated GEMM block from current TRSM block - //load 8x4 block of B11 - if(n_remainder == 3) - { - ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] - ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] - ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][2] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][2] - ymm11 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm15 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - } - if(n_remainder == 2) - { - ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] - ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] - ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1] - ymm10 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2] - ymm14 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2] - ymm11 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm15 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - } - if(n_remainder == 1) - { - ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] - ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] - ymm9 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][1] - ymm13 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][1] - ymm10 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2] - ymm14 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2] - ymm11 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm15 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - } - ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] - ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0] - ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1] - ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1] - ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] - ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2] - ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3] - ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3] + ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0 + ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1 + ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2 + ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3 - ///implement TRSM/// + ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4 + ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5 + ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6 + ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7 - ///read 4x4 block of A11/// + ///implement TRSM/// - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + ///read 4x4 block of A11/// - //1st col - ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + //2nd col + a11 += 1; + ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1] - //2nd col - a11 += 1; - ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1] - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][1] + //3rd col + a11 += 1; + ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2] + ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2] - //3rd col - a11 += 1; - ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2] - ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2] - ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2] + //4th col + a11 += 1; + ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3] + ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3] + ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] - //4th col - a11 += 1; - ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3] + //(Row1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] + ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2] + ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); //B11[0-3][3] -= B11[0-3][0] * A11[0][3] - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3] - ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3] - ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] + ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] + ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2] + ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); //B11[4-7][3] -= B11[4-7][0] * A11[0][3] - //(Row1): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] - ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2] - ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); //B11[0-3][3] -= B11[0-3][0] * A11[0][3] + //(Row2)FMA operations + ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2] + ymm11 = _mm256_fnmadd_pd(ymm5, ymm9, ymm11); //B11[0-3][3] -= B11[0-3][1] * A11[1][3] - ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] - ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2] - ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); //B11[4-7][3] -= B11[4-7][0] * A11[0][3] + ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2] + ymm15 = _mm256_fnmadd_pd(ymm5, ymm13, ymm15); //B11[4-7][3] -= B11[4-7][1] * A11[1][3] - //(Row2)FMA operations - ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2] - ymm11 = _mm256_fnmadd_pd(ymm5, ymm9, ymm11); //B11[0-3][3] -= B11[0-3][1] * A11[1][3] + //(Row3)FMA operations + ymm11 = _mm256_fnmadd_pd(ymm6, ymm10, ymm11); //B11[0-3][3] -= B11[0-3][2] * A11[2][3] - ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2] - ymm15 = _mm256_fnmadd_pd(ymm5, ymm13, ymm15); //B11[4-7][3] -= B11[4-7][1] * A11[1][3] + ymm15 = _mm256_fnmadd_pd(ymm6, ymm14, ymm15); //B11[4-7][3] -= B11[4-7][2] * A11[2][3] - //(Row3)FMA operations - ymm11 = _mm256_fnmadd_pd(ymm6, ymm10, ymm11); //B11[0-3][3] -= B11[0-3][2] * A11[2][3] - - ymm15 = _mm256_fnmadd_pd(ymm6, ymm14, ymm15); //B11[4-7][3] -= B11[4-7][2] * A11[2][3] - - if(n_remainder == 3) - { - _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) - _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2]) - } - if(n_remainder == 2) - { - _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) - _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) - } - if(n_remainder == 1) - { - _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) - } - } - } - if((m & 4)) ///implementation for remainder rows(when m_remainder is a multiple of 4) - { - for(j = 0; (j+D_NR-1) 0; i -= D_MR) //loop along 'M' direction { - for(j = (n-D_NR); (j+1) > 0; j -= D_NR) //loop along 'N' direction - { - a01 = L + j*cs_a +(j+D_NR); //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM + for(j = (n-D_NR); (j+1) > 0; j -= D_NR) //loop along 'N' direction + { + a01 = L + j*cs_a +(j+D_NR); //pointer to block of A to be used in GEMM + a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM - k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4) + k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4) - ymm0 = _mm256_setzero_pd(); - ymm1 = _mm256_setzero_pd(); - ymm2 = _mm256_setzero_pd(); - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); + ymm0 = _mm256_setzero_pd(); + ymm1 = _mm256_setzero_pd(); + ymm2 = _mm256_setzero_pd(); + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - ptr_a01_dup = a01; + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_a01_dup = a01; - //broadcast 1st row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] + //broadcast 1st row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] - a01 += 1; //move to next row + a01 += 1; //move to next row - //load 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) - //broadcast 2nd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] + //broadcast 2nd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] - a01 += 1; //move to next row of A + a01 += 1; //move to next row of A - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) - //broadcast 3rd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] + //broadcast 3rd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] - a01 += 1; //move to next row of A01 + a01 += 1; //move to next row of A01 - //load next 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) - ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) - //broadcast 4th row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] + //broadcast 4th row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] - a01 += 1; //move to next row of A01 + a01 += 1; //move to next row of A01 - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM } ///GEMM code ends/// - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); - //load 8x4 block of B11 - ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); + //load 8x4 block of B11 + ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0 - ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1 - ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2 - ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3 + ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0 + ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1 + ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2 + ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3 - ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4 - ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5 - ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6 - ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7 + ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4 + ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5 + ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6 + ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7 - ///implement TRSM/// + ///implement TRSM/// - ///read 4x4 block of A11/// + ///read 4x4 block of A11/// - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - //1st col - ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + //1st col + ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - //2nd col - a11 += 1; - ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1] - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][1] + //2nd col + a11 += 1; + ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1] + ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][1] - //3rd col - a11 += 1; - ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2] - ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2] - ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2] + //3rd col + a11 += 1; + ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2] + ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2] + ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2] - //4th col - a11 += 1; - ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3] + //4th col + a11 += 1; + ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3] - //compute reciprocals of L(i,i) and broadcast in registers - ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] - ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3] + //compute reciprocals of L(i,i) and broadcast in registers + ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] + ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3] - ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] - ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) + ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3] - ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3] - ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] + ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3] + ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3] + ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] - //extract a33 - ymm0 = _mm256_permute_pd(ymm7, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) + //extract a33 + ymm0 = _mm256_permute_pd(ymm7, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) ymm11 = _mm256_mul_pd(ymm11, ymm0); ymm15 = _mm256_mul_pd(ymm15, ymm0); - //extract a22 - ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) + //extract a22 + ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) //(row 3):FMA operations ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); @@ -9905,9 +11162,9 @@ static err_t bli_dtrsm_small_XAlB( ymm14 = _mm256_mul_pd(ymm14, ymm0); - //extract a11 - ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) + //extract a11 + ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) //(Row 2): FMA operations ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9); @@ -9920,198 +11177,197 @@ static err_t bli_dtrsm_small_XAlB( ymm13 = _mm256_mul_pd(ymm13, ymm0); - //extract a00 - ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) + //extract a00 + ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) //(Row 1): FMA operations ymm8 = _mm256_fnmadd_pd(ymm9, ymm1, ymm8); ymm12 = _mm256_fnmadd_pd(ymm13, ymm1, ymm12); - ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0] + ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0] - ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0] + ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0] - _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) - _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b), ymm11); //store(B11[0-3][3]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b + D_NR), ymm15);//store(B11[4-7][3]) + _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) + _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][2]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b), ymm11); //store(B11[0-3][3]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b + D_NR), ymm15);//store(B11[4-7][3]) } - if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR) + if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR) { - a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i + (j + D_NR)*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM + a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + b10 = B + i + (j + D_NR)*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be performed(in blocks of 4x4) + k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be performed(in blocks of 4x4) - ///load 4x4 block of b11 + ///load 4x4 block of b11 - ymm0 = _mm256_setzero_pd(); - ymm1 = _mm256_setzero_pd(); - ymm2 = _mm256_setzero_pd(); - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); + ymm0 = _mm256_setzero_pd(); + ymm1 = _mm256_setzero_pd(); + ymm2 = _mm256_setzero_pd(); + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); //subtract the calculated GEMM block from current TRSM block - //load 8x4 block of B11 - if(n_remainder == 3) - { + //load 8x4 block of B11 + if(n_remainder == 3) + { ///GEMM implementation begins/// - for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations + for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations { - ptr_a01_dup = a01; + ptr_a01_dup = a01; - //broadcast 1st row of A01 - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] + //broadcast 1st row of A01 + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] - a01 += 1; //move to next row of A + a01 += 1; //move to next row of A - //load 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) - //broadcast 2nd row of A01 - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] + //broadcast 2nd row of A01 + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] - a01 += 1; //move to next row of A + a01 += 1; //move to next row of A - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) - //broadcast 3rd row of A01 - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] + //broadcast 3rd row of A01 + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] - a01 += 1; //move to next row of A + a01 += 1; //move to next row of A - //load next 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) - ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1] )); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1] + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1] )); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1] + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) - //broadcast 4th row of A01 - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] + //broadcast 4th row of A01 + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] - a01 += 1; //move to next row of A + a01 += 1; //move to next row of A - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM } ///GEMM code ends/// - ymm9 = _mm256_loadu_pd((double const *)(b11+cs_b)); //B11[0-3][0] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][0] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0-3][1] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + D_NR)); //B11[4-7][1] - ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][2] - ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][2] + ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal); - ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0] - ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1] - ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1] - ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2] - ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3] - ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3] + ymm9 = _mm256_loadu_pd((double const *)(b11+cs_b)); //B11[0-3][0] + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][0] + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0-3][1] + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + D_NR)); //B11[4-7][1] + ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][2] + ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][2] - ///implement TRSM/// + ymm9 = _mm256_fmsub_pd(ymm9, ymm8, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0] + ymm10 = _mm256_fmsub_pd(ymm10, ymm8, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1] + ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1] + ymm13 = _mm256_fmsub_pd(ymm13, ymm8, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2] + ymm14 = _mm256_fmsub_pd(ymm14, ymm8, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3] + ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3] - ///read 4x4 block of A11/// + ///implement TRSM/// - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + ///read 4x4 block of A11/// - //1st col - ymm0 = _mm256_broadcast_sd((double const *)(&ones)); //A11[0][0] + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - //2nd col - a11 += 1; - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][1] + //1st col + ymm0 = _mm256_broadcast_sd((double const *)(&ones)); //A11[0][0] - //3rd col - a11 += 1; - ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2] - ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2] + //2nd col + a11 += 1; + ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][1] - //4th col - a11 += 1; - ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3] + //3rd col + a11 += 1; + ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2] + ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2] - //compute reciprocals of L(i,i) and broadcast in registers - ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] - ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3] + //4th col + a11 += 1; + ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3] - ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] - ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) + //compute reciprocals of L(i,i) and broadcast in registers + ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] + ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3] - ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3] - ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] + ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) - //extract a33 - ymm0 = _mm256_permute_pd(ymm7, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) + ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3] + ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] + + //extract a33 + ymm0 = _mm256_permute_pd(ymm7, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) ymm11 = _mm256_mul_pd(ymm11, ymm0); ymm15 = _mm256_mul_pd(ymm15, ymm0); - //extract a22 - ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) + //extract a22 + ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) //(row 3):FMA operations ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); @@ -10124,9 +11380,9 @@ static err_t bli_dtrsm_small_XAlB( ymm14 = _mm256_mul_pd(ymm14, ymm0); - //extract a11 - ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) + //extract a11 + ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) //(Row 2): FMA operations ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9); @@ -10137,129 +11393,131 @@ static err_t bli_dtrsm_small_XAlB( ymm13 = _mm256_mul_pd(ymm13, ymm0); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) - } - if(n_remainder == 2) - { + _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) + } + if(n_remainder == 2) + { ///GEMM implementation begins/// - for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations + for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations { - ptr_a01_dup = a01; + ptr_a01_dup = a01; - //broadcast 1st row of A01 - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] + //broadcast 1st row of A01 + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] - a01 += 1; //move to next row of A + a01 += 1; //move to next row of A - //load 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) - //broadcast 2nd row of A01 - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] + //broadcast 2nd row of A01 + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] - a01 += 1; //move to next row of A + a01 += 1; //move to next row of A - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) - //broadcast 3rd row of A01 - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] + //broadcast 3rd row of A01 + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] - a01 += 1; //move to next row of A + a01 += 1; //move to next row of A - //load next 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) - ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) - //broadcast 4th row of A01 - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] + //broadcast 4th row of A01 + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] - a01 += 1; //move to next row of A + a01 += 1; //move to next row of A - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM } ///GEMM code ends/// - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][0] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][0] - ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][1] - ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][1] + ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal); - ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1] - ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1] + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][0] + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][0] + ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][1] + ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][1] - ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3] - ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3] + ymm10 = _mm256_fmsub_pd(ymm10, ymm8, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1] + ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1] - ///implement TRSM/// + ymm14 = _mm256_fmsub_pd(ymm14, ymm8, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3] + ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3] - ///read 4x4 block of A11/// + ///implement TRSM/// - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + ///read 4x4 block of A11/// - //3rd col - a11 += 2; - ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2] + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - //4th col - a11 += 1; - ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3] + //3rd col + a11 += 2; + ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2] - //compute reciprocals of L(i,i) and broadcast in registers - ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3] + //4th col + a11 += 1; + ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3] - ymm0 = _mm256_blend_pd(ymm7, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] - ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) + //compute reciprocals of L(i,i) and broadcast in registers + ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3] - ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] + ymm0 = _mm256_blend_pd(ymm7, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) - //extract a33 - ymm0 = _mm256_permute_pd(ymm7, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) + ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] + + //extract a33 + ymm0 = _mm256_permute_pd(ymm7, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) ymm11 = _mm256_mul_pd(ymm11, ymm0); ymm15 = _mm256_mul_pd(ymm15, ymm0); - //extract a22 - ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + //extract a22 + ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) //(row 3):FMA operations @@ -10271,253 +11529,249 @@ static err_t bli_dtrsm_small_XAlB( ymm14 = _mm256_mul_pd(ymm14, ymm0); - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) - } - if(n_remainder == 1) - { + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) + } + if(n_remainder == 1) + { ///GEMM implementation begins/// - for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations + for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations { - ptr_a01_dup = a01; + ptr_a01_dup = a01; - //broadcast 1st row of A01 - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] + //broadcast 1st row of A01 + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] - a01 += 1; //move to next row of A + a01 += 1; //move to next row of A - //load 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) - //broadcast 2nd row of A01 - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] + //broadcast 2nd row of A01 + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] - a01 += 1; //move to next row of A + a01 += 1; //move to next row of A - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) - //broadcast 3rd row of A01 - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] + //broadcast 3rd row of A01 + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] - a01 += 1; //move to next row of A + a01 += 1; //move to next row of A - //load next 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) - ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) - //broadcast 4th row of A01 - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] + //broadcast 4th row of A01 + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] - a01 += 1; //move to next row of A + a01 += 1; //move to next row of A - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM } ///GEMM code ends/// - ymm11 = _mm256_loadu_pd((double const *)(b11+cs_b_offset[1])); //B11[0-3][0] - ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] +D_NR)); //B11[4-7][0] + ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal); - ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1] - ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3] + ymm11 = _mm256_loadu_pd((double const *)(b11+cs_b_offset[1])); //B11[0-3][0] + ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] +D_NR)); //B11[4-7][0] - ///implement TRSM/// + ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1] + ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3] - ///read 4x4 block of A11/// + ///implement TRSM/// - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + ///read 4x4 block of A11/// - //4th col - a11 += 3; - ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3] + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - //compute reciprocals of L(i,i) and broadcast in registers - ymm7 = _mm256_div_pd(ymm7, ymm6); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) + //4th col + a11 += 3; + ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3] + + //compute reciprocals of L(i,i) and broadcast in registers + ymm7 = _mm256_div_pd(ymm7, ymm6); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) ymm11 = _mm256_mul_pd(ymm11, ymm7); ymm15 = _mm256_mul_pd(ymm15, ymm7); - _mm256_storeu_pd((double *)(b11+ cs_b_offset[1]), ymm11); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) - } - } + _mm256_storeu_pd((double *)(b11+ cs_b_offset[1]), ymm11); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) + } + } } if(i<0) i += D_NR; - if((m & 4)) ///implementation for remainder rows(when m_remainder is a multiple of 4) + if((m & 4)) ///implementation for remainder rows(when m_remainder is greater than 4) { - for(j = (n-D_NR); (j+1) > 0; j -=D_NR) //loop along n direction + for(j = (n-D_NR); (j+1) > 0; j -=D_NR) //loop along n direction { - a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM + a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM + b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4) + k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4) - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha - ///GEMM for previous blocks /// + ///GEMM for previous blocks /// - ///load 4x4 block of b11 - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ///load 4x4 block of b11 + ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - //multiply by alpha - ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] *= alpha - ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] *=alpha - ymm2 = _mm256_mul_pd(ymm2, ymm15); //B11[x][2] *= alpha - ymm3 = _mm256_mul_pd(ymm3, ymm15); //B11[x][3] *= alpha - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - ptr_a01_dup = a01; + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_a01_dup = a01; - //load 4x4 bblock of b10 - ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] + //load 4x4 bblock of b10 + ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] - //broadcast 1st row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] + //broadcast 1st row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] - a01 += 1; //move to next row of A + a01 += 1; //move to next row of A - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - //broadcast 2nd row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] + //broadcast 2nd row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] - a01 += 1; //move to next row of A + a01 += 1; //move to next row of A - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) - //braodcast 3rd row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] + //braodcast 3rd row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] - a01 += 1; //move to next row of A + a01 += 1; //move to next row of A - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) - //broadcast 4th row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] + //broadcast 4th row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] - a01 += 1; //move to next row of A + a01 += 1; //move to next row of A - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM - } + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + } ///GEMM code end/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha - ymm0 = _mm256_sub_pd(ymm0, ymm4); //B11[x][0] -=ymm4 - ymm1 = _mm256_sub_pd(ymm1, ymm5); //B11[x][1] -= ymm5 - ymm2 = _mm256_sub_pd(ymm2, ymm6); //B11[x][2] -= ymm6 - ymm3 = _mm256_sub_pd(ymm3, ymm7); //B11[x][3] -= ymm7 + ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -=ymm4 + ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5 + ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6 + ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 - ///implement TRSM/// + ///implement TRSM/// - ///read 4x4 block of A11/// + ///read 4x4 block of A11/// - //1st col - ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][0] - ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][0] - ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][0] + //1st col + ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][0] + ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][0] + ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][0] - //2nd col - a11 += cs_a; - ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[1][1] - ymm10 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][1] + //2nd col + a11 += cs_a; + ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] + ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[1][1] + ymm10 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][1] - //3rd col - a11 += cs_a; - ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][2] - ymm12 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][2] + //3rd col + a11 += cs_a; + ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][2] + ymm12 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][2] - //4th col - a11 += cs_a; - ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][3] + //4th col + a11 += cs_a; + ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][3] - ymm14 = _mm256_broadcast_sd((double const *)&ones); + ymm14 = _mm256_broadcast_sd((double const *)&ones); - //compute reciprocals of A(i,i) and broadcast in registers - ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] - ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3] + //compute reciprocals of A(i,i) and broadcast in registers + ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] + ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3] - ymm15 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] - ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + ymm15 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] - //extract a33 - ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]) - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) + //extract a33 + ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) ymm3 = _mm256_mul_pd(ymm3, ymm15); - //extract a22 - ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) + //extract a22 + ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) //(Row 3): FMA operations ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); @@ -10526,9 +11780,9 @@ static err_t bli_dtrsm_small_XAlB( ymm2 = _mm256_mul_pd(ymm2, ymm15); - //extract a11 - ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) + //extract a11 + ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) //(ROW 2): FMA operations ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1); @@ -10536,115 +11790,114 @@ static err_t bli_dtrsm_small_XAlB( ymm1 = _mm256_mul_pd(ymm1, ymm15); - //extract A00 - ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) + //extract A00 + ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) //(Row 1):FMA operations ymm0 = _mm256_fnmadd_pd(ymm1, ymm5, ymm0); ymm0 = _mm256_mul_pd(ymm0, ymm15); - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0]) - _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2])) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3]) + _mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0]) + _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2])) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3]) } - if(n_remainder) //implementation for remainder columns(when n is not a multiple of D_NR) - { - a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM - a11 = L + j*cs_a + j; //pointwr to block of A to be used for TRSM - b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + if(n_remainder) //implementation for remainder columns(when n is not a multiple of D_NR) + { + a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM + a11 = L + j*cs_a + j; //pointwr to block of A to be used for TRSM + b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4) + k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4) - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM for previous blocks /// - if(n_remainder == 3) - { - ///load 4x4 block of b11 - ymm0 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm1 = _mm256_loadu_pd((double const *)b11+ cs_b); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ///GEMM for previous blocks /// + if(n_remainder == 3) + { + ///load 4x4 block of b11 + ymm0 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm1 = _mm256_loadu_pd((double const *)b11+ cs_b); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ///GEMM processing stars/// - for(k = 0; k < k_iter; k++) - { - ptr_a01_dup = a01; + for(k = 0; k < k_iter; k++) + { + ptr_a01_dup = a01; - //load 4x4 bblock of b10 - ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] + //load 4x4 bblock of b10 + ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] - //broadcast 1st row of A01 - ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] + //broadcast 1st row of A01 + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] - a01 += 1; //move to next row of A + a01 += 1; //move to next row of A - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - //broadcast 2nd row of A01 - ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] + //broadcast 2nd row of A01 + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] - a01 += 1; //move to next row of A + a01 += 1; //move to next row of A - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) - //braodcast 3rd row of A01 - ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] + //braodcast 3rd row of A01 + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] - a01 += 1; //move to next row of A + a01 += 1; //move to next row of A - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) - //broadcast 4th row of A01 - ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] + //broadcast 4th row of A01 + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] - a01 += 1; //move to next row of A + a01 += 1; //move to next row of A - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM - } + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + } ///GEMM code ends/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); //B11[x][1] -= ymm5 - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); //B11[x][2] -= ymm6 - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); //B11[x][3] -= ymm7 + ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5 + ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6 + ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 - ///implement TRSM/// + ///implement TRSM/// - ///read 4x4 block of A11/// + ///read 4x4 block of A11/// //2nd col a11 += cs_a; ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] @@ -10662,22 +11915,22 @@ static err_t bli_dtrsm_small_XAlB( ymm14 = _mm256_broadcast_sd((double const *)&ones); - //compute reciprocals of A(i,i) and broadcast in registers - ymm4 = _mm256_unpacklo_pd(ymm14, ymm8); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] - ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3] + //compute reciprocals of A(i,i) and broadcast in registers + ymm4 = _mm256_unpacklo_pd(ymm14, ymm8); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] + ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3] - ymm15 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] - ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + ymm15 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] - //extract a33 - ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]) - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) + //extract a33 + ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) ymm3 = _mm256_mul_pd(ymm3, ymm15); - //extract a22 - ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) + //extract a22 + ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) //(Row 3): FMA operations ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); @@ -10685,9 +11938,9 @@ static err_t bli_dtrsm_small_XAlB( ymm2 = _mm256_mul_pd(ymm2, ymm15); - //extract a11 - ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) + //extract a11 + ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) //(ROW 2): FMA operations ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1); @@ -10695,75 +11948,76 @@ static err_t bli_dtrsm_small_XAlB( ymm1 = _mm256_mul_pd(ymm1, ymm15); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2])) + _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2])) _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[x][0]) - } - if(n_remainder == 2) - { - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + } + if(n_remainder == 2) + { + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ///GEMM processing stars/// - for(k = 0; k < k_iter; k++) - { - ptr_a01_dup = a01; + for(k = 0; k < k_iter; k++) + { + ptr_a01_dup = a01; - //load 4x4 bblock of b10 - ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] + //load 4x4 bblock of b10 + ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] - //broadcast 1st row of A01 - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] + //broadcast 1st row of A01 + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] - a01 += 1; //move to next row of A + a01 += 1; //move to next row of A - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - //broadcast 2nd row of A01 - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] + //broadcast 2nd row of A01 + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] - a01 += 1; //move to next row of A + a01 += 1; //move to next row of A - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) - //braodcast 3rd row of A01 - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] + //braodcast 3rd row of A01 + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] - a01 += 1; //move to next row of A + a01 += 1; //move to next row of A - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) - //broadcast 4th row of A01 - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] + //broadcast 4th row of A01 + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] - a01 += 1; //move to next row of A + a01 += 1; //move to next row of A - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM - } + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + } ///GEMM code ends/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); //B11[x][2] -= ymm6 - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); //B11[x][3] -= ymm7 + ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6 + ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 - ///implement TRSM/// + ///implement TRSM/// - ///read 4x4 block of A11/// + ///read 4x4 block of A11/// //3rd col a11 += 2 * cs_a; @@ -10776,86 +12030,87 @@ static err_t bli_dtrsm_small_XAlB( ymm14 = _mm256_broadcast_sd((double const *)&ones); - //compute reciprocals of A(i,i) and broadcast in registers - ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3] + //compute reciprocals of A(i,i) and broadcast in registers + ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3] - ymm15 = _mm256_blend_pd(ymm14, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] - ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + ymm15 = _mm256_blend_pd(ymm14, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] - //extract a33 - ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]) - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) + //extract a33 + ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) ymm3 = _mm256_mul_pd(ymm3, ymm15); - //extract a22 - ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) + //extract a22 + ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) //(Row 3): FMA operations ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); ymm2 = _mm256_mul_pd(ymm2, ymm15); - _mm256_storeu_pd((double *)(b11+ cs_b * 2), ymm2); //store(B11[x][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][1]) - } - if(n_remainder == 1) - { - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + _mm256_storeu_pd((double *)(b11+ cs_b * 2), ymm2); //store(B11[x][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][1]) + } + if(n_remainder == 1) + { + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ///GEMM processing stars/// - for(k = 0; k < k_iter; k++) - { - ptr_a01_dup = a01; + for(k = 0; k < k_iter; k++) + { + ptr_a01_dup = a01; - //load 4x4 bblock of b10 - ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] + //load 4x4 bblock of b10 + ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] - //broadcast 1st row of A01 - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] + //broadcast 1st row of A01 + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] - a01 += 1; //move to next row of A + a01 += 1; //move to next row of A - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - //broadcast 2nd row of A01 - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] + //broadcast 2nd row of A01 + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] - a01 += 1; //move to next row of A + a01 += 1; //move to next row of A - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) - //braodcast 3rd row of A01 - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] + //braodcast 3rd row of A01 + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] - a01 += 1; //move to next row of A + a01 += 1; //move to next row of A - ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) - //broadcast 4th row of A01 - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] + //broadcast 4th row of A01 + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] - a01 += 1; //move to next row of A + a01 += 1; //move to next row of A - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM - } + } ///GEMM code ends/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); //B11[x][3] -= ymm7 + ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 - ///implement TRSM/// + ///implement TRSM/// - ///read 4x4 block of A11/// + ///read 4x4 block of A11/// //4th col a11 += 3 * cs_a; @@ -10863,19 +12118,19 @@ static err_t bli_dtrsm_small_XAlB( ymm14 = _mm256_broadcast_sd((double const *)&ones); - //compute reciprocals of A(i,i) and broadcast in registers - ymm14 = _mm256_div_pd(ymm14, ymm13); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + //compute reciprocals of A(i,i) and broadcast in registers + ymm14 = _mm256_div_pd(ymm14, ymm13); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] - //extract a33 + //extract a33 ymm3 = _mm256_mul_pd(ymm3, ymm14); - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][0]) - } + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][0]) + } } - m_remainder -= 4; - i -= 4; + m_remainder -= 4; + i -= 4; } -// if(i < 0) i = 0; +// if(i < 0) i = 0; if(m_remainder) ///implementation for remainder rows { dtrsm_small_XAlB(L, B, AlphaVal, m_remainder, n, cs_a, cs_b); @@ -10902,13 +12157,13 @@ b10 ***************** ************* */ static err_t bli_dtrsm_small_XAlB_unitDiag( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ) + side_t side, + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ) { dim_t D_MR = 8; //block dimension along the rows dim_t D_NR = 4; //block dimension along the columns @@ -10920,6 +12175,7 @@ static err_t bli_dtrsm_small_XAlB_unitDiag( dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B + #ifdef BLIS_ENABLE_SMALL_MATRIX_ROME if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_ROME) { @@ -10936,8 +12192,6 @@ static err_t bli_dtrsm_small_XAlB_unitDiag( dim_t k_iter; //determines the number of GEMM operations to be done dim_t cs_b_offset[2]; //pre-calculated strides - double ones = 1.0; - double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha double *L = a->buffer; //pointer to matrix A double *B = b->buffer; //pointer to matrix B @@ -10957,167 +12211,160 @@ static err_t bli_dtrsm_small_XAlB_unitDiag( for(i = (m-D_MR); (i+1) > 0; i -= D_MR) //loop along 'M' direction { - for(j = (n-D_NR); (j+1) > 0; j -= D_NR) //loop along 'N' direction - { - a01 = L + j*cs_a +(j+D_NR); //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM + for(j = (n-D_NR); (j+1) > 0; j -= D_NR) //loop along 'N' direction + { + a01 = L + j*cs_a +(j+D_NR); //pointer to block of A to be used in GEMM + a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM - k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4) + k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4) - ymm0 = _mm256_setzero_pd(); - ymm1 = _mm256_setzero_pd(); - ymm2 = _mm256_setzero_pd(); - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); + ymm0 = _mm256_setzero_pd(); + ymm1 = _mm256_setzero_pd(); + ymm2 = _mm256_setzero_pd(); + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - ptr_a01_dup = a01; + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_a01_dup = a01; - //broadcast 1st row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] + //broadcast 1st row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] - a01 += 1; //move to next row + a01 += 1; //move to next row - //load 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) - //broadcast 2nd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] + //broadcast 2nd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] - a01 += 1; //move to next row of A + a01 += 1; //move to next row of A - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) - //broadcast 3rd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] + //broadcast 3rd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] - a01 += 1; //move to next row of A01 + a01 += 1; //move to next row of A01 - //load next 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) - ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) - //broadcast 4th row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] + //broadcast 4th row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] - a01 += 1; //move to next row of A01 + a01 += 1; //move to next row of A01 - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM } ///GEMM code ends/// - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); - //load 8x4 block of B11 - ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); + //load 8x4 block of B11 + ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0 - ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1 - ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2 - ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3 + ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0 + ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1 + ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2 + ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3 - ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4 - ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5 - ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6 - ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7 + ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4 + ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5 + ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6 + ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7 - ///implement TRSM/// + ///implement TRSM/// - ///read 4x4 block of A11/// + ///read 4x4 block of A11/// - //1st col - ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + //2nd col + a11 += 1; + ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1] - //2nd col - a11 += 1; - ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1] - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][1] + //3rd col + a11 += 1; + ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2] + ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2] - //3rd col - a11 += 1; - ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2] - ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2] - ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2] - - //4th col - a11 += 1; - ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3] - - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3] - ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3] - ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] + //4th col + a11 += 1; + ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3] + ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3] + ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] //(row 3):FMA operations ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); @@ -11140,559 +12387,473 @@ static err_t bli_dtrsm_small_XAlB_unitDiag( ymm12 = _mm256_fnmadd_pd(ymm13, ymm1, ymm12); - _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) - _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b), ymm11); //store(B11[0-3][3]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b + D_NR), ymm15);//store(B11[4-7][3]) + _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) + _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][2]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b), ymm11); //store(B11[0-3][3]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b + D_NR), ymm15);//store(B11[4-7][3]) } - if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR) + if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR) { - a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i + (j + D_NR)*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM + a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + b10 = B + i + (j + D_NR)*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be performed(in blocks of 4x4) + k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be performed(in blocks of 4x4) - ///load 4x4 block of b11 + ///load 4x4 block of b11 - ymm0 = _mm256_setzero_pd(); - ymm1 = _mm256_setzero_pd(); - ymm2 = _mm256_setzero_pd(); - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM implementation begins/// - - for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations - { - ptr_a01_dup = a01; - - //broadcast 1st row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] - - a01 += 1; //move to next row of A - - //load 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) - - //broadcast 2nd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] - - a01 += 1; //move to next row of A - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) - - //broadcast 3rd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] - - a01 += 1; //move to next row of A - - //load next 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) - ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) - - //broadcast 4th row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] - - a01 += 1; //move to next row of A - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) - - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM - } - - ///GEMM code ends/// - - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); + ymm0 = _mm256_setzero_pd(); + ymm1 = _mm256_setzero_pd(); + ymm2 = _mm256_setzero_pd(); + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); //subtract the calculated GEMM block from current TRSM block - //load 8x4 block of B11 - if(n_remainder == 3) - { - ymm8 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm12 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - ymm9 = _mm256_loadu_pd((double const *)(b11+cs_b)); //B11[0-3][0] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][0] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0-3][1] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + D_NR)); //B11[4-7][1] - ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][2] - ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][2] - } - if(n_remainder == 2) - { - ymm8 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2] - ymm12 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2] - ymm9 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm13 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][0] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][0] - ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][1] - ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][1] - } - if(n_remainder == 1) - { - ymm8 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][1] - ymm12 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][1] - ymm9 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2] - ymm13 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2] - ymm10 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm14 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - ymm11 = _mm256_loadu_pd((double const *)(b11+cs_b_offset[1])); //B11[0-3][0] - ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] +D_NR)); //B11[4-7][0] - } + //load 8x4 block of B11 + if(n_remainder == 3) + { + ///GEMM implementation begins/// - ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] - ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0] - ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1] - ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1] - ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] - ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2] - ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3] - ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3] + for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations + { + ptr_a01_dup = a01; - ///implement TRSM/// + //broadcast 1st row of A01 + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] - ///read 4x4 block of A11/// + a01 += 1; //move to next row of A - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] - //1st col - ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - //2nd col - a11 += 1; - ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1] - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][1] + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) - //3rd col - a11 += 1; - ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2] - ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2] - ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2] + //broadcast 2nd row of A01 + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] - //4th col - a11 += 1; - ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3] + a01 += 1; //move to next row of A + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3] - ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3] - ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) - //(row 3):FMA operations - ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm11, ymm2, ymm8); + //broadcast 3rd row of A01 + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] - ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14); - ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13); - ymm12 = _mm256_fnmadd_pd(ymm15, ymm2, ymm12); + a01 += 1; //move to next row of A - //(Row 2): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm10, ymm3, ymm8); + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1] )); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1] + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) - ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13); - ymm12 = _mm256_fnmadd_pd(ymm14, ymm3, ymm12); + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) - //(Row 1): FMA operations - ymm8 = _mm256_fnmadd_pd(ymm9, ymm1, ymm8); + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) - ymm12 = _mm256_fnmadd_pd(ymm13, ymm1, ymm12); + //broadcast 4th row of A01 + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] - if(n_remainder == 3) - { - _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) - } - if(n_remainder == 2) - { - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) - } - if(n_remainder == 1) - { - _mm256_storeu_pd((double *)(b11+ cs_b_offset[1]), ymm11); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) - } - } + a01 += 1; //move to next row of A + + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) + + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + + ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ymm9 = _mm256_loadu_pd((double const *)(b11+cs_b)); //B11[0-3][0] + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][0] + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0-3][1] + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + D_NR)); //B11[4-7][1] + ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][2] + ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][2] + + ymm9 = _mm256_fmsub_pd(ymm9, ymm8, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0] + ymm10 = _mm256_fmsub_pd(ymm10, ymm8, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1] + ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1] + ymm13 = _mm256_fmsub_pd(ymm13, ymm8, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2] + ymm14 = _mm256_fmsub_pd(ymm14, ymm8, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3] + ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3] + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + //3rd col + a11 += 2; + ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2] + + //4th col + a11 += 1; + ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3] + ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] + + //(row 3):FMA operations + ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9); + + ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14); + ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13); + + //(Row 2): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9); + + ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13); + + _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) + } + if(n_remainder == 2) + { + ///GEMM implementation begins/// + + for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] + + a01 += 1; //move to next row of A + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] + + a01 += 1; //move to next row of A + + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) + + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) + + //broadcast 3rd row of A01 + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] + + a01 += 1; //move to next row of A + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) + + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) + + //broadcast 4th row of A01 + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] + + a01 += 1; //move to next row of A + + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) + + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + + ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][0] + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][0] + ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][1] + ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][1] + + ymm10 = _mm256_fmsub_pd(ymm10, ymm8, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1] + ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1] + + ymm14 = _mm256_fmsub_pd(ymm14, ymm8, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3] + ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3] + + ///implement TRSM/// + + ///read 4x4 block of A11/// + //4th col + a11 += 3; + ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] + + //(row 3):FMA operations + ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); + + ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14); + + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) + } + if(n_remainder == 1) + { + ///GEMM implementation begins/// + + for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] + + a01 += 1; //move to next row of A + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] + + a01 += 1; //move to next row of A + + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) + + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) + + //broadcast 3rd row of A01 + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] + + a01 += 1; //move to next row of A + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) + + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) + + //broadcast 4th row of A01 + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] + + a01 += 1; //move to next row of A + + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) + + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + + ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ymm11 = _mm256_loadu_pd((double const *)(b11+cs_b_offset[1])); //B11[0-3][0] + ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] +D_NR)); //B11[4-7][0] + + ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1] + ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3] + + _mm256_storeu_pd((double *)(b11+ cs_b_offset[1]), ymm11); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) + } + } } if(i<0) i += D_NR; - if((m & 4)) ///implementation for remainder rows(when m_remainder is a multiple of 4) + if((m & 4)) ///implementation for remainder rows(when m_remainder is greater than 4) { - for(j = (n-D_NR); (j+1) > 0; j -=D_NR) //loop along n direction + for(j = (n-D_NR); (j+1) > 0; j -=D_NR) //loop along n direction { - a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM + a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM + b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4) + k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4) - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha - ///GEMM for previous blocks /// + ///GEMM for previous blocks /// - ///load 4x4 block of b11 - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ///load 4x4 block of b11 + ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - //multiply by alpha - ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] *= alpha - ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] *=alpha - ymm2 = _mm256_mul_pd(ymm2, ymm15); //B11[x][2] *= alpha - ymm3 = _mm256_mul_pd(ymm3, ymm15); //B11[x][3] *= alpha - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - ptr_a01_dup = a01; + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_a01_dup = a01; - //load 4x4 bblock of b10 - ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] + //load 4x4 bblock of b10 + ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] - //broadcast 1st row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] + //broadcast 1st row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] - a01 += 1; //move to next row of A + a01 += 1; //move to next row of A - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - //broadcast 2nd row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] + //broadcast 2nd row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] - a01 += 1; //move to next row of A + a01 += 1; //move to next row of A - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) - //braodcast 3rd row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] + //braodcast 3rd row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] - a01 += 1; //move to next row of A + a01 += 1; //move to next row of A - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) - //broadcast 4th row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] + //broadcast 4th row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] - a01 += 1; //move to next row of A + a01 += 1; //move to next row of A - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM - } + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + } ///GEMM code end/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha - ymm0 = _mm256_sub_pd(ymm0, ymm4); //B11[x][0] -=ymm4 - ymm1 = _mm256_sub_pd(ymm1, ymm5); //B11[x][1] -= ymm5 - ymm2 = _mm256_sub_pd(ymm2, ymm6); //B11[x][2] -= ymm6 - ymm3 = _mm256_sub_pd(ymm3, ymm7); //B11[x][3] -= ymm7 + ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -=ymm4 + ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5 + ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6 + ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 - ///implement TRSM/// + ///implement TRSM/// - ///read 4x4 block of A11/// - - - //1st col - ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][0] - ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][0] - ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][0] - - //2nd col - a11 += cs_a; - ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[1][1] - ymm10 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][1] - - //3rd col - a11 += cs_a; - ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][2] - ymm12 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][2] - - //4th col - a11 += cs_a; - ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][3] - - ymm14 = _mm256_broadcast_sd((double const *)&ones); - - - //(Row 3): FMA operations - ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); - ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1); - ymm0 = _mm256_fnmadd_pd(ymm3, ymm7, ymm0); - - //(ROW 2): FMA operations - ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1); - ymm0 = _mm256_fnmadd_pd(ymm2, ymm6, ymm0); - - //(Row 1):FMA operations - ymm0 = _mm256_fnmadd_pd(ymm1, ymm5, ymm0); - - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0]) - _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //store(B11[x][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3]) - - } - if(n_remainder) //implementation for remainder columns(when n is not a multiple of D_NR) - { - a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM - a11 = L + j*cs_a + j; //pointwr to block of A to be used for TRSM - b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4) - - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value - ///GEMM for previous blocks /// - - ///load 4x4 block of b11 - if(n_remainder == 3) - { - ymm0 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm1 = _mm256_loadu_pd((double const *)b11+ cs_b); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - } - if(n_remainder == 2) - { - ymm0 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm1 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - } - if(n_remainder == 1) - { - ymm0 = _mm256_broadcast_sd((double const *)&ones); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm1 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm2 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - } - //multiply by alpha - ymm0 = _mm256_mul_pd(ymm0, ymm16); //B11[x][0] *= alpha - ymm1 = _mm256_mul_pd(ymm1, ymm16); //B11[x][1] *=alpha - ymm2 = _mm256_mul_pd(ymm2, ymm16); //B11[x][2] *= alpha - ymm3 = _mm256_mul_pd(ymm3, ymm16); //B11[x][3] *= alpha - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - - ///GEMM processing stars/// - - for(k = 0; k < k_iter; k++) - { - ptr_a01_dup = a01; - - //load 4x4 bblock of b10 - ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] - - //broadcast 1st row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] - - a01 += 1; //move to next row of A - - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - //broadcast 2nd row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] - - a01 += 1; //move to next row of A - - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) - - //braodcast 3rd row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] - - a01 += 1; //move to next row of A - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) - - //broadcast 4th row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] - - a01 += 1; //move to next row of A - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) - - - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM - - } - - ///GEMM code ends/// - - ymm0 = _mm256_sub_pd(ymm0, ymm4); //B11[x][0] -= ymm4 - ymm1 = _mm256_sub_pd(ymm1, ymm5); //B11[x][1] -= ymm5 - ymm2 = _mm256_sub_pd(ymm2, ymm6); //B11[x][2] -= ymm6 - ymm3 = _mm256_sub_pd(ymm3, ymm7); //B11[x][3] -= ymm7 - - ///implement TRSM/// - - ///read 4x4 block of A11/// + ///read 4x4 block of A11/// //1st col - ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][0] ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][0] ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][0] //2nd col a11 += cs_a; - ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[1][1] ymm10 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][1] //3rd col a11 += cs_a; - ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][2] ymm12 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][2] - //4th col - a11 += cs_a; - ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][3] - - ymm14 = _mm256_broadcast_sd((double const *)&ones); - //(Row 3): FMA operations ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1); @@ -11705,27 +12866,264 @@ static err_t bli_dtrsm_small_XAlB_unitDiag( //(Row 1):FMA operations ymm0 = _mm256_fnmadd_pd(ymm1, ymm5, ymm0); - if(n_remainder == 3) - { - _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2])) - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[x][0]) - } - if(n_remainder == 2) - { - _mm256_storeu_pd((double *)(b11+ cs_b * 2), ymm2); //store(B11[x][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][1]) - } - if(n_remainder == 1) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][0]) - } + _mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0]) + _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2])) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3]) } - m_remainder -= 4; - i -= 4; + if(n_remainder) //implementation for remainder columns(when n is not a multiple of D_NR) + { + a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM + a11 = L + j*cs_a + j; //pointwr to block of A to be used for TRSM + b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4) + + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + + ///GEMM for previous blocks /// + if(n_remainder == 3) + { + ///load 4x4 block of b11 + ymm1 = _mm256_loadu_pd((double const *)b11+ cs_b); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ///GEMM processing stars/// + + for(k = 0; k < k_iter; k++) + { + ptr_a01_dup = a01; + + //load 4x4 bblock of b10 + ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] + + //broadcast 1st row of A01 + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] + + a01 += 1; //move to next row of A + + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] + + a01 += 1; //move to next row of A + + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) + + //braodcast 3rd row of A01 + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] + + a01 += 1; //move to next row of A + + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) + + //broadcast 4th row of A01 + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] + + a01 += 1; //move to next row of A + + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value + + ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5 + ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6 + ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 + + ///implement TRSM/// + + ///read 4x4 block of A11/// + //2nd col + a11 += cs_a; + ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[1][1] + ymm10 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][1] + + //3rd col + a11 += cs_a; + ymm12 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][2] + + //(Row 3): FMA operations + ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); + ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1); + + //(ROW 2): FMA operations + ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1); + + _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2])) + _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[x][0]) + } + if(n_remainder == 2) + { + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + ///GEMM processing stars/// + + for(k = 0; k < k_iter; k++) + { + ptr_a01_dup = a01; + + //load 4x4 bblock of b10 + ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] + + //broadcast 1st row of A01 + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] + + a01 += 1; //move to next row of A + + ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] + + a01 += 1; //move to next row of A + + ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) + + //braodcast 3rd row of A01 + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] + + a01 += 1; //move to next row of A + + ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) + + //broadcast 4th row of A01 + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] + + a01 += 1; //move to next row of A + + ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value + + ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6 + ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + //3rd col + a11 += 2 * cs_a; + ymm12 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][2] + + + //(Row 3): FMA operations + ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); + + _mm256_storeu_pd((double *)(b11+ cs_b * 2), ymm2); //store(B11[x][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][1]) + } + if(n_remainder == 1) + { + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ///GEMM processing stars/// + + for(k = 0; k < k_iter; k++) + { + ptr_a01_dup = a01; + + //load 4x4 bblock of b10 + ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] + + //broadcast 1st row of A01 + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] + + a01 += 1; //move to next row of A + + ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] + + a01 += 1; //move to next row of A + + ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) + + //braodcast 3rd row of A01 + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] + + a01 += 1; //move to next row of A + + ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) + + //broadcast 4th row of A01 + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] + + a01 += 1; //move to next row of A + + ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + + } + + ///GEMM code ends/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value + + ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 + + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][0]) + } + } + m_remainder -= 4; + i -= 4; } - if(m_remainder) +// if(i < 0) i = 0; + if(m_remainder) ///implementation for remainder rows { dtrsm_small_XAlB_unitDiag(L, B, AlphaVal, m_remainder, n, cs_a, cs_b); } @@ -11752,13 +13150,13 @@ b10 ***************** ************* */ static err_t bli_dtrsm_small_XAutB( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ) + side_t side, + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ) { dim_t D_MR = 8; //block dimension along the rows dim_t D_NR = 4; //block dimension along the columns @@ -11807,190 +13205,190 @@ static err_t bli_dtrsm_small_XAutB( for(i = (m-D_MR); (i+1) > 0; i -= D_MR) //loop along 'M' direction { - for(j = (n-D_NR); (j+1) > 0; j -= D_NR) //loop along 'N' direction - { - a01 = L + (j+D_NR)*cs_a +(j); //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM + for(j = (n-D_NR); (j+1) > 0; j -= D_NR) //loop along 'N' direction + { + a01 = L + (j+D_NR)*cs_a +(j); //pointer to block of A to be used in GEMM + a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM - k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4) + k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4) - ymm0 = _mm256_setzero_pd(); - ymm1 = _mm256_setzero_pd(); - ymm2 = _mm256_setzero_pd(); - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); + ymm0 = _mm256_setzero_pd(); + ymm1 = _mm256_setzero_pd(); + ymm2 = _mm256_setzero_pd(); + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - ptr_a01_dup = a01; + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_a01_dup = a01; - //broadcast 1st row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] + //broadcast 1st row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] - a01 += cs_a; //move to next row + a01 += cs_a; //move to next row - //load 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) - //broadcast 2nd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] + //broadcast 2nd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] - a01 += cs_a; //move to next row of A + a01 += cs_a; //move to next row of A - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) - //broadcast 3rd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] + //broadcast 3rd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] - a01 += cs_a; //move to next row of A01 + a01 += cs_a; //move to next row of A01 - //load next 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) - ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) - //broadcast 4th row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] + //broadcast 4th row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] - a01 += cs_a; //move to next row of A01 + a01 += cs_a; //move to next row of A01 - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM } ///GEMM code ends/// - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); - //load 8x4 block of B11 - ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); + //load 8x4 block of B11 + ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0 - ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1 - ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2 - ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3 + ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0 + ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1 + ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2 + ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3 - ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4 - ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5 - ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6 - ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7 + ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4 + ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5 + ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6 + ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7 - ///implement TRSM/// + ///implement TRSM/// - ///read 4x4 block of A11/// + ///read 4x4 block of A11/// - //1st col - ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + //1st col + ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - a11 += cs_a; + a11 += cs_a; - //2nd col - ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] + //2nd col + ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] + ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - a11 += cs_a; + a11 += cs_a; - //3rd col - ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] + //3rd col + ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] + ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] + ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] - a11 += cs_a; + a11 += cs_a; - //4th col - ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1] + //4th col + ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1] - ymm7 = _mm256_broadcast_sd((double const *)&ones); + ymm7 = _mm256_broadcast_sd((double const *)&ones); - //compute reciprocals of A(i,i) and broadcast in registers - ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] - ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[2][2] A11[3][3] + //compute reciprocals of A(i,i) and broadcast in registers + ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] + ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[2][2] A11[3][3] - ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] - ymm0 = _mm256_div_pd(ymm7, ymm0); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm0 = _mm256_div_pd(ymm7, ymm0); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] - ymm2 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] - //extract a33 - ymm7 = _mm256_permute_pd(ymm0, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]) - ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) + ymm2 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] + ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] + ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] + //extract a33 + ymm7 = _mm256_permute_pd(ymm0, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]) + ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) ymm11 = _mm256_mul_pd(ymm11, ymm7); ymm15 = _mm256_mul_pd(ymm15, ymm7); - //extract a22 - ymm7 = _mm256_permute_pd(ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) + //extract a22 + ymm7 = _mm256_permute_pd(ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) //(Row 3): FMA operations ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); @@ -12006,9 +13404,9 @@ static err_t bli_dtrsm_small_XAutB( ymm14 = _mm256_mul_pd(ymm14, ymm7); - //extract a11 - ymm7 = _mm256_permute_pd(ymm0, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) - ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) + //extract a11 + ymm7 = _mm256_permute_pd(ymm0, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) + ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) //(ROW 2): FMA operations ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9); @@ -12021,9 +13419,9 @@ static err_t bli_dtrsm_small_XAutB( ymm13 = _mm256_mul_pd(ymm13, ymm7); - //extract A00 - ymm7 = _mm256_permute_pd(ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) + //extract A00 + ymm7 = _mm256_permute_pd(ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) //(Row 1):FMA operations ymm8 = _mm256_fnmadd_pd(ymm9, ymm1, ymm8); @@ -12034,449 +13432,595 @@ static err_t bli_dtrsm_small_XAutB( ymm12 = _mm256_mul_pd(ymm12, ymm7); - _mm256_storeu_pd((double *)b11, ymm8); //store(B11[x][0]) - _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[x][0]) - _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[x][1]) - _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[x][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //(store(B11[x][2])) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //(store(B11[x][2])) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[x][3]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[x][3]) + _mm256_storeu_pd((double *)b11, ymm8); //store(B11[x][0]) + _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[x][0]) + _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[x][1]) + _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[x][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //(store(B11[x][2])) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //(store(B11[x][2])) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[x][3]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[x][3]) } - if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR) + if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR) { - a01 = L + (j+D_NR)*cs_a +(j); //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM + a01 = L + (j+D_NR)*cs_a +(j); //pointer to block of A to be used in GEMM + a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM - k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4) + k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4) - ymm0 = _mm256_setzero_pd(); - ymm1 = _mm256_setzero_pd(); - ymm2 = _mm256_setzero_pd(); - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); + ymm0 = _mm256_setzero_pd(); + ymm1 = _mm256_setzero_pd(); + ymm2 = _mm256_setzero_pd(); + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); - ///GEMM implementation starts/// + //load 8x4 block of B11 + if(n_remainder == 3) + { + ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - ptr_a01_dup = a01; + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_a01_dup = a01; - //broadcast 1st row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] + //broadcast 1st row of A01 + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] - a01 += cs_a; //move to next row + a01 += cs_a; //move to next row - //load 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) - //broadcast 2nd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] + //broadcast 2nd row of A01 + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] - a01 += cs_a; //move to next row of A + a01 += cs_a; //move to next row of A - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) - //broadcast 3rd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] + //broadcast 3rd row of A01 + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] - a01 += cs_a; //move to next row of A01 + a01 += cs_a; //move to next row of A01 - //load next 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) - ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) - //broadcast 4th row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] + //broadcast 4th row of A01 + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] - a01 += cs_a; //move to next row of A01 + a01 += cs_a; //move to next row of A01 - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ymm9 = _mm256_loadu_pd((double const *)(b11+cs_b)); //B11[0-3][0] + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][0] + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0-3][1] + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + D_NR)); //B11[4-7][1] + ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][2] + ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][2] + + ymm9 = _mm256_fmsub_pd(ymm9, ymm8, ymm1); //B11[4-7][0] * alpha-= ymm1 + ymm10 = _mm256_fmsub_pd(ymm10, ymm8, ymm2); //B11[0-3][1] * alpha-= ymm2 + ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= ymm3 + + ymm13 = _mm256_fmsub_pd(ymm13, ymm8, ymm5); //B11[4-7][2] * alpha -= ymm5 + ymm14 = _mm256_fmsub_pd(ymm14, ymm8, ymm6); //B11[0-3][3] * alpha -= ymm6 + ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= ymm7 + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + //1st col + ymm0 = _mm256_broadcast_sd((double const *)(&ones)); //A11[0][0] + + a11 += cs_a; + + //2nd col + ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] + + a11 += cs_a; + + //3rd col + ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] + ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] + + a11 += cs_a; + + //4th col + ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1] + + ymm7 = _mm256_broadcast_sd((double const *)&ones); + + //compute reciprocals of A(i,i) and broadcast in registers + ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] + ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[2][2] A11[3][3] + + ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm0 = _mm256_div_pd(ymm7, ymm0); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + + ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] + ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] + //extract a33 + ymm7 = _mm256_permute_pd(ymm0, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]) + ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) + + ymm11 = _mm256_mul_pd(ymm11, ymm7); + + ymm15 = _mm256_mul_pd(ymm15, ymm7); + + //extract a22 + ymm7 = _mm256_permute_pd(ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) + + //(Row 3): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9); + + //(Row 3): FMA operations + ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14); + ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13); + + ymm10 = _mm256_mul_pd(ymm10, ymm7); + + ymm14 = _mm256_mul_pd(ymm14, ymm7); + + //extract a11 + ymm7 = _mm256_permute_pd(ymm0, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) + ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) + + //(ROW 2): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9); + + ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13); + + ymm9 = _mm256_mul_pd(ymm9, ymm7); + + ymm13 = _mm256_mul_pd(ymm13, ymm7); + + _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) } + if(n_remainder == 2) + { + ///GEMM implementation starts/// - ///GEMM code ends/// + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_a01_dup = a01; - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); - //load 8x4 block of B11 - if(n_remainder == 3) - { - ymm8 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm12 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - ymm9 = _mm256_loadu_pd((double const *)(b11+cs_b)); //B11[0-3][0] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][0] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0-3][1] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + D_NR)); //B11[4-7][1] - ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][2] - ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][2] - } - if(n_remainder == 2) - { - ymm8 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2] - ymm12 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2] - ymm9 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm13 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][0] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][0] - ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][1] - ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][1] - } - if(n_remainder == 1) - { - ymm8 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][1] - ymm12 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][1] - ymm9 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2] - ymm13 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2] - ymm10 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm14 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - ymm11 = _mm256_loadu_pd((double const *)(b11+cs_b_offset[1])); //B11[0-3][0] - ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] +D_NR)); //B11[4-7][0] - } + //broadcast 1st row of A01 + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] - ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0 - ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1 - ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2 - ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3 + a01 += cs_a; //move to next row - ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4 - ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5 - ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6 - ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7 + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] - ///implement TRSM/// + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - ///read 4x4 block of A11/// + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) - //1st col - ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + //broadcast 2nd row of A01 + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] - a11 += cs_a; + a01 += cs_a; //move to next row of A - //2nd col - ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) - a11 += cs_a; + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) - //3rd col - ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] + //broadcast 3rd row of A01 + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] - a11 += cs_a; + a01 += cs_a; //move to next row of A01 - //4th col - ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1] + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) - ymm7 = _mm256_broadcast_sd((double const *)&ones); + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) - //compute reciprocals of A(i,i) and broadcast in registers - ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] - ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[2][2] A11[3][3] + //broadcast 4th row of A01 + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] - ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] - ymm0 = _mm256_div_pd(ymm7, ymm0); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + a01 += cs_a; //move to next row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] - //extract a33 - ymm7 = _mm256_permute_pd(ymm0, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]) - ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) - ymm11 = _mm256_mul_pd(ymm11, ymm7); + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) - ymm15 = _mm256_mul_pd(ymm15, ymm7); + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + } - //extract a22 - ymm7 = _mm256_permute_pd(ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) + ///GEMM code ends/// + ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal); - //(Row 3): FMA operations - ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm11, ymm2, ymm8); + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][0] + ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][1] + ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][1] - //(Row 3): FMA operations - ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14); - ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13); - ymm12 = _mm256_fnmadd_pd(ymm15, ymm2, ymm12); + ymm10 = _mm256_fmsub_pd(ymm10, ymm8, ymm2); //B11[0-3][1] * alpha-= ymm2 + ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= ymm3 - ymm10 = _mm256_mul_pd(ymm10, ymm7); + ymm14 = _mm256_fmsub_pd(ymm14, ymm8, ymm6); //B11[0-3][3] * alpha -= ymm6 + ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= ymm7 - ymm14 = _mm256_mul_pd(ymm14, ymm7); + ///implement TRSM/// - //extract a11 - ymm7 = _mm256_permute_pd(ymm0, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) - ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) + ///read 4x4 block of A11/// - //(ROW 2): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm10, ymm3, ymm8); + //1st col + a11 += 2 * cs_a; - ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13); - ymm12 = _mm256_fnmadd_pd(ymm14, ymm3, ymm12); + //3rd col + ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] - ymm9 = _mm256_mul_pd(ymm9, ymm7); + a11 += cs_a; - ymm13 = _mm256_mul_pd(ymm13, ymm7); + //4th col + ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1] - //extract A00 - ymm7 = _mm256_permute_pd(ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) + ymm7 = _mm256_broadcast_sd((double const *)&ones); - //(Row 1):FMA operations - ymm8 = _mm256_fnmadd_pd(ymm9, ymm1, ymm8); + //compute reciprocals of A(i,i) and broadcast in registers + ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[2][2] A11[3][3] - ymm12 = _mm256_fnmadd_pd(ymm13, ymm1, ymm12); + ymm0 = _mm256_blend_pd(ymm7, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm0 = _mm256_div_pd(ymm7, ymm0); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] - ymm8 = _mm256_mul_pd(ymm8, ymm7); + ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] + //extract a33 + ymm7 = _mm256_permute_pd(ymm0, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]) + ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) - ymm12 = _mm256_mul_pd(ymm12, ymm7); + ymm11 = _mm256_mul_pd(ymm11, ymm7); - if(n_remainder == 3) - { - _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) - } - if(n_remainder == 2) - { - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) - } - if(n_remainder == 1) - { - _mm256_storeu_pd((double *)(b11+ cs_b_offset[1]), ymm11); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) - } - } + ymm15 = _mm256_mul_pd(ymm15, ymm7); + + //extract a22 + ymm7 = _mm256_permute_pd(ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) + + //(Row 3): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); + + //(Row 3): FMA operations + ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14); + + ymm10 = _mm256_mul_pd(ymm10, ymm7); + + ymm14 = _mm256_mul_pd(ymm14, ymm7); + + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) + } + if(n_remainder == 1) + { + ///GEMM implementation starts/// + + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] + + a01 += cs_a; //move to next row + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] + + a01 += cs_a; //move to next row of A + + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) + + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) + + //broadcast 3rd row of A01 + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] + + a01 += cs_a; //move to next row of A01 + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) + + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) + + //broadcast 4th row of A01 + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] + + a01 += cs_a; //move to next row of A01 + + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) + + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ymm11 = _mm256_loadu_pd((double const *)(b11+cs_b_offset[1])); //B11[0-3][0] + ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] +D_NR)); //B11[4-7][0] + + ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= ymm3 + + ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= ymm7 + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + a11 += 3 * cs_a; + + //4th col + ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1] + + ymm7 = _mm256_broadcast_sd((double const *)&ones); + + ymm0 = _mm256_div_pd(ymm7, ymm6); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + + ymm11 = _mm256_mul_pd(ymm11, ymm0); + + ymm15 = _mm256_mul_pd(ymm15, ymm0); + + _mm256_storeu_pd((double *)(b11+ cs_b_offset[1]), ymm11); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) + } + } } if(i<0) i += D_NR; - if((m & 4)) ///implementation for remainder rows(when m_remainder is a multiple of 4) + if((m & 4)) ///implementation for remainder rows(when m_remainder is greater than 4) { - for(j = (n-D_NR); (j+1) > 0; j -=D_NR) //loop along n direction + for(j = (n-D_NR); (j+1) > 0; j -=D_NR) //loop along n direction { - a01 = L + (j+D_NR)*cs_a + (j); //pointer to block of A to be used for GEMM + a01 = L + (j+D_NR)*cs_a + (j); //pointer to block of A to be used for GEMM a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM + b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4) + k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4) - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha - ///GEMM for previous blocks /// + ///GEMM for previous blocks /// - ///load 4x4 block of b11 - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ///load 4x4 block of b11 + ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - //multiply by alpha - ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] *= alpha - ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] *=alpha - ymm2 = _mm256_mul_pd(ymm2, ymm15); //B11[x][2] *= alpha - ymm3 = _mm256_mul_pd(ymm3, ymm15); //B11[x][3] *= alpha - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - ptr_a01_dup = a01; + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_a01_dup = a01; - //load 4x4 bblock of b10 - ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] + //load 4x4 bblock of b10 + ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] - //broadcast 1st row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] + //broadcast 1st row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] - a01 += cs_a; //move to next row of A + a01 += cs_a; //move to next row of A - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - //broadcast 2nd row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] + //broadcast 2nd row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] - a01 += cs_a; //move to next row of A + a01 += cs_a; //move to next row of A - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) - //braodcast 3rd row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] + //braodcast 3rd row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] - a01 += cs_a; //move to next row of A + a01 += cs_a; //move to next row of A - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) - //broadcast 4th row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] + //broadcast 4th row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] - a01 += cs_a; //move to next row of A + a01 += cs_a; //move to next row of A - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + D_NR*cs_a; //pointer math to find next block of A for GEMM - } + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR*cs_a; //pointer math to find next block of A for GEMM + } ///GEMM code end/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha - ymm0 = _mm256_sub_pd(ymm0, ymm4); //B11[x][0] -=ymm4 - ymm1 = _mm256_sub_pd(ymm1, ymm5); //B11[x][1] -= ymm5 - ymm2 = _mm256_sub_pd(ymm2, ymm6); //B11[x][2] -= ymm6 - ymm3 = _mm256_sub_pd(ymm3, ymm7); //B11[x][3] -= ymm7 + ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -=ymm4 + ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5 + ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6 + ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 - ///implement TRSM/// + ///implement TRSM/// - ///read 4x4 block of A11/// + ///read 4x4 block of A11/// - //1st col - ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + //1st col + ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - a11 += cs_a; + a11 += cs_a; - //2nd col - ymm5 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] + //2nd col + ymm5 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] + ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - a11 += cs_a; + a11 += cs_a; - //3rd col - ymm6 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm9 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] + //3rd col + ymm6 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] + ymm9 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] + ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] - a11 += cs_a; + a11 += cs_a; - //4th col - ymm7 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm10 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] - ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1] + //4th col + ymm7 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] + ymm10 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] + ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] + ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1] + ymm14 = _mm256_broadcast_sd((double const *)&ones); - ymm14 = _mm256_broadcast_sd((double const *)&ones); + //compute reciprocals of A(i,i) and broadcast in registers + ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] + ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3] - //compute reciprocals of A(i,i) and broadcast in registers - ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] - ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3] + ymm15 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] - ymm15 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] - ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] - - //extract a33 - ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]) - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) + //extract a33 + ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) ymm3 = _mm256_mul_pd(ymm3, ymm15); - //extract a22 - ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) + //extract a22 + ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) //(Row 3): FMA operations ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); @@ -12485,9 +14029,9 @@ static err_t bli_dtrsm_small_XAutB( ymm2 = _mm256_mul_pd(ymm2, ymm15); - //extract a11 - ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) + //extract a11 + ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) //(ROW 2): FMA operations ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1); @@ -12495,242 +14039,361 @@ static err_t bli_dtrsm_small_XAutB( ymm1 = _mm256_mul_pd(ymm1, ymm15); - //extract A00 - ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) + //extract A00 + ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) //(Row 1):FMA operations ymm0 = _mm256_fnmadd_pd(ymm1, ymm5, ymm0); ymm0 = _mm256_mul_pd(ymm0, ymm15); - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0]) - _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2])) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3]) + _mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0]) + _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2])) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3]) } - if(n_remainder) //implementation for remainder columns(when n is not a multiple of D_NR) - { + if(n_remainder) //implementation for remainder columns(when n is not a multiple of D_NR) + { - a01 = L + (j+D_NR)*cs_a + (j); //pointer to block of A to be used for GEMM + a01 = L + (j+D_NR)*cs_a + (j); //pointer to block of A to be used for GEMM a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM + b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4) + k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4) - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha - ///GEMM for previous blocks /// + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha + ///GEMM for previous blocks /// - ///load 4x4 block of b11 - if(n_remainder == 3) - { - ymm0 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm1 = _mm256_loadu_pd((double const *)b11+ cs_b); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - } - if(n_remainder == 2) - { - ymm0 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm1 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - } - if(n_remainder == 1) - { - ymm0 = _mm256_broadcast_sd((double const *)&ones); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm1 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm2 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - } + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); - //multiply by alpha - ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] *= alpha - ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] *=alpha - ymm2 = _mm256_mul_pd(ymm2, ymm15); //B11[x][2] *= alpha - ymm3 = _mm256_mul_pd(ymm3, ymm15); //B11[x][3] *= alpha + ///load 4x4 block of b11 + if(n_remainder == 3) + { + ymm1 = _mm256_loadu_pd((double const *)b11+ cs_b); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); + ///GEMM implementation starts/// - ///GEMM implementation starts/// + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_a01_dup = a01; - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - ptr_a01_dup = a01; + //load 4x4 bblock of b10 + ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] - //load 4x4 bblock of b10 - ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] + //broadcast 1st row of A01 + ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] - //broadcast 1st row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] + a01 += cs_a; //move to next row of A - a01 += cs_a; //move to next row of A + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + //broadcast 2nd row of A01 + ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] - //broadcast 2nd row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] + a01 += cs_a; //move to next row of A - a01 += cs_a; //move to next row of A + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) + //braodcast 3rd row of A01 + ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] - //braodcast 3rd row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] + a01 += cs_a; //move to next row of A - a01 += cs_a; //move to next row of A + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) + //broadcast 4th row of A01 + ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] - //broadcast 4th row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] + a01 += cs_a; //move to next row of A - a01 += cs_a; //move to next row of A + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + } + + ///GEMM code end/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha + + ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5 + ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6 + ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + //1st col + ymm4 = _mm256_broadcast_sd((double const *)(&ones)); //A11[0][0] + + a11 += cs_a; + + //2nd col + ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] + + a11 += cs_a; + + //3rd col + ymm9 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] + ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] + + a11 += cs_a; + + //4th col + ymm10 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] + ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] + ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1] + + ymm14 = _mm256_broadcast_sd((double const *)&ones); + + //compute reciprocals of A(i,i) and broadcast in registers + ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] + ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3] + + ymm15 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + + //extract a33 + ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) + + ymm3 = _mm256_mul_pd(ymm3, ymm15); + + //extract a22 + ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) + + //(Row 3): FMA operations + ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); + ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1); + + ymm2 = _mm256_mul_pd(ymm2, ymm15); + + //extract a11 + ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) + + //(ROW 2): FMA operations + ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1); + + ymm1 = _mm256_mul_pd(ymm1, ymm15); + + _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2])) + _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[x][0]) + } + if(n_remainder == 2) + { + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + ///GEMM implementation starts/// + + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //load 4x4 bblock of b10 + ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] + + //broadcast 1st row of A01 + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] + + a01 += cs_a; //move to next row of A + + ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] + + a01 += cs_a; //move to next row of A + + ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) + + //braodcast 3rd row of A01 + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] + + a01 += cs_a; //move to next row of A + + ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) + + //broadcast 4th row of A01 + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] + + a01 += cs_a; //move to next row of A + + ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + } + + ///GEMM code end/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha + + ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6 + ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + //1st col + + a11 += 2 * cs_a; + + //3rd col + ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] + + a11 += cs_a; + + //4th col + ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] + ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1] + + ymm14 = _mm256_broadcast_sd((double const *)&ones); + + //compute reciprocals of A(i,i) and broadcast in registers + ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3] + + ymm15 = _mm256_blend_pd(ymm14, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + + //extract a33 + ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) + + ymm3 = _mm256_mul_pd(ymm3, ymm15); + + //extract a22 + ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) + + //(Row 3): FMA operations + ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); + + ymm2 = _mm256_mul_pd(ymm2, ymm15); - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM - } + _mm256_storeu_pd((double *)(b11+ cs_b * 2), ymm2); //store(B11[x][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][1]) + } + if(n_remainder == 1) + { + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ///GEMM code end/// + ///GEMM implementation starts/// - ymm0 = _mm256_sub_pd(ymm0, ymm4); //B11[x][0] -=ymm4 - ymm1 = _mm256_sub_pd(ymm1, ymm5); //B11[x][1] -= ymm5 - ymm2 = _mm256_sub_pd(ymm2, ymm6); //B11[x][2] -= ymm6 - ymm3 = _mm256_sub_pd(ymm3, ymm7); //B11[x][3] -= ymm7 + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_a01_dup = a01; - ///implement TRSM/// + //load 4x4 bblock of b10 + ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] - ///read 4x4 block of A11/// + //broadcast 1st row of A01 + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] + a01 += cs_a; //move to next row of A - //1st col - ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - a11 += cs_a; + //broadcast 2nd row of A01 + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] - //2nd col - ymm5 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] + a01 += cs_a; //move to next row of A - a11 += cs_a; + ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) - //3rd col - ymm6 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm9 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] + //braodcast 3rd row of A01 + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] - a11 += cs_a; + a01 += cs_a; //move to next row of A - //4th col - ymm7 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm10 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] - ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1] + ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) + //broadcast 4th row of A01 + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] - ymm14 = _mm256_broadcast_sd((double const *)&ones); + a01 += cs_a; //move to next row of A - //compute reciprocals of A(i,i) and broadcast in registers - ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] - ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3] + ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) - ymm15 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] - ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + } - //extract a33 - ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]) - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) + ///GEMM code end/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha - ymm3 = _mm256_mul_pd(ymm3, ymm15); + ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 - //extract a22 - ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) + ///implement TRSM/// - //(Row 3): FMA operations - ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); - ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1); - ymm0 = _mm256_fnmadd_pd(ymm3, ymm7, ymm0); + ///read 4x4 block of A11/// - ymm2 = _mm256_mul_pd(ymm2, ymm15); + a11 += 3 * cs_a; - //extract a11 - ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) + //4th col + ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1] - //(ROW 2): FMA operations - ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1); - ymm0 = _mm256_fnmadd_pd(ymm2, ymm6, ymm0); + ymm14 = _mm256_broadcast_sd((double const *)&ones); - ymm1 = _mm256_mul_pd(ymm1, ymm15); + //compute reciprocals of A(i,i) and broadcast in registers + ymm14 = _mm256_div_pd(ymm14, ymm13); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] - //extract A00 - ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) - - //(Row 1):FMA operations - ymm0 = _mm256_fnmadd_pd(ymm1, ymm5, ymm0); - - ymm0 = _mm256_mul_pd(ymm0, ymm15); - - if(n_remainder == 3) - { - _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2])) - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[x][0]) - } - if(n_remainder == 2) - { - _mm256_storeu_pd((double *)(b11+ cs_b * 2), ymm2); //store(B11[x][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][1]) - } - if(n_remainder == 1) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][0]) - } + ymm3 = _mm256_mul_pd(ymm3, ymm14); + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][0]) + } } - m_remainder -= 4; - i -= 4; + m_remainder -= 4; + i -= 4; } if(m_remainder) ///implementation for remainder rows { - dtrsm_small_XAutB(L, B, AlphaVal, m_remainder, n, cs_a, cs_b); + dtrsm_small_XAutB(L, B, AlphaVal, m_remainder, n, cs_a, cs_b); } return BLIS_SUCCESS; } @@ -12754,13 +14417,13 @@ b10 ***************** ************* */ static err_t bli_dtrsm_small_XAutB_unitDiag( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ) + side_t side, + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ) { dim_t D_MR = 8; //block dimension along the rows dim_t D_NR = 4; //block dimension along the columns @@ -12788,8 +14451,6 @@ static err_t bli_dtrsm_small_XAutB_unitDiag( dim_t k_iter; //determines the number of GEMM operations to be done dim_t cs_b_offset[2]; //pre-calculated strides - double ones = 1.0; - double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha double *L = a->buffer; //pointer to matrix A double *B = b->buffer; //pointer to matrix B @@ -12809,391 +14470,164 @@ static err_t bli_dtrsm_small_XAutB_unitDiag( for(i = (m-D_MR); (i+1) > 0; i -= D_MR) //loop along 'M' direction { - for(j = (n-D_NR); (j+1) > 0; j -= D_NR) //loop along 'N' direction - { - a01 = L + (j+D_NR)*cs_a +(j); //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM - - k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4) - - ymm0 = _mm256_setzero_pd(); - ymm1 = _mm256_setzero_pd(); - ymm2 = _mm256_setzero_pd(); - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - ptr_a01_dup = a01; - - //broadcast 1st row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] - - a01 += cs_a; //move to next row - - //load 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) - - //broadcast 2nd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] - - a01 += cs_a; //move to next row of A - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) - - //broadcast 3rd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] - - a01 += cs_a; //move to next row of A01 - - //load next 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) - ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) - - //broadcast 4th row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] - - a01 += cs_a; //move to next row of A01 - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) - - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM - } - - ///GEMM code ends/// - - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); - //load 8x4 block of B11 - ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - - ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0 - ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1 - ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2 - ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3 - - ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4 - ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5 - ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6 - ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7 - - ///implement TRSM/// - - ///read 4x4 block of A11/// - - //1st col - ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - - a11 += cs_a; - - //2nd col - ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - - a11 += cs_a; - - //3rd col - ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] - - a11 += cs_a; - - //4th col - ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1] - - ymm2 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] - - //(Row 3): FMA operations - ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm11, ymm2, ymm8); - - //(Row 3): FMA operations - ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14); - ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13); - ymm12 = _mm256_fnmadd_pd(ymm15, ymm2, ymm12); - - //(ROW 2): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm10, ymm3, ymm8); - - ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13); - ymm12 = _mm256_fnmadd_pd(ymm14, ymm3, ymm12); - - //(Row 1):FMA operations - ymm8 = _mm256_fnmadd_pd(ymm9, ymm1, ymm8); - - ymm12 = _mm256_fnmadd_pd(ymm13, ymm1, ymm12); - - _mm256_storeu_pd((double *)b11, ymm8); //store(B11[x][0]) - _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[x][0]) - _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[x][1]) - _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[x][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //(store(B11[x][2])) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //(store(B11[x][2])) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[x][3]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[x][3]) - - } - if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR) + for(j = (n-D_NR); (j+1) > 0; j -= D_NR) //loop along 'N' direction { + a01 = L + (j+D_NR)*cs_a +(j); //pointer to block of A to be used in GEMM + a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM - a01 = L + (j+D_NR)*cs_a +(j); //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM + k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4) - k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4) - - ymm0 = _mm256_setzero_pd(); - ymm1 = _mm256_setzero_pd(); - ymm2 = _mm256_setzero_pd(); - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); + ymm0 = _mm256_setzero_pd(); + ymm1 = _mm256_setzero_pd(); + ymm2 = _mm256_setzero_pd(); + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - ptr_a01_dup = a01; + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_a01_dup = a01; - //broadcast 1st row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] + //broadcast 1st row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] - a01 += cs_a; //move to next row + a01 += cs_a; //move to next row - //load 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) - //broadcast 2nd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] + //broadcast 2nd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] - a01 += cs_a; //move to next row of A + a01 += cs_a; //move to next row of A - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) - //broadcast 3rd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] + //broadcast 3rd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] - a01 += cs_a; //move to next row of A01 + a01 += cs_a; //move to next row of A01 - //load next 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) - ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) - //broadcast 4th row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] + //broadcast 4th row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] - a01 += cs_a; //move to next row of A01 + a01 += cs_a; //move to next row of A01 - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM } ///GEMM code ends/// - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); - //load 8x4 block of B11 - if(n_remainder == 3) - { - ymm8 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm12 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - ymm9 = _mm256_loadu_pd((double const *)(b11+cs_b)); //B11[0-3][0] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][0] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0-3][1] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + D_NR)); //B11[4-7][1] - ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][2] - ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][2] - } - if(n_remainder == 2) - { - ymm8 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2] - ymm12 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2] - ymm9 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm13 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][0] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][0] - ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][1] - ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][1] - } - if(n_remainder == 1) - { - ymm8 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][1] - ymm12 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][1] - ymm9 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2] - ymm13 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2] - ymm10 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm14 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - ymm11 = _mm256_loadu_pd((double const *)(b11+cs_b_offset[1])); //B11[0-3][0] - ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] +D_NR)); //B11[4-7][0] - } + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); + //load 8x4 block of B11 + ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0 - ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1 - ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2 - ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3 + ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0 + ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1 + ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2 + ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3 - ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4 - ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5 - ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6 - ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7 + ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4 + ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5 + ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6 + ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7 ///implement TRSM/// - ///read 4x4 block of A11/// + ///read 4x4 block of A11/// - //1st col - ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + //1st col - a11 += cs_a; + a11 += cs_a; - //2nd col - ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] + //2nd col + ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - a11 += cs_a; + a11 += cs_a; - //3rd col - ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] + //3rd col + ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] + ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - a11 += cs_a; + a11 += cs_a; - //4th col - ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1] - - ymm2 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] + //4th col + ymm2 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] + ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] + ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] //(Row 3): FMA operations ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); @@ -13217,168 +14651,478 @@ static err_t bli_dtrsm_small_XAutB_unitDiag( ymm12 = _mm256_fnmadd_pd(ymm13, ymm1, ymm12); - if(n_remainder == 3) - { - _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) - } - if(n_remainder == 2) - { - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) - } - if(n_remainder == 1) - { - _mm256_storeu_pd((double *)(b11+ cs_b_offset[1]), ymm11); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) - } + _mm256_storeu_pd((double *)b11, ymm8); //store(B11[x][0]) + _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[x][0]) + _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[x][1]) + _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[x][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //(store(B11[x][2])) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //(store(B11[x][2])) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[x][3]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[x][3]) - } + } + if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR) + { + + a01 = L + (j+D_NR)*cs_a +(j); //pointer to block of A to be used in GEMM + a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4) + + ymm0 = _mm256_setzero_pd(); + ymm1 = _mm256_setzero_pd(); + ymm2 = _mm256_setzero_pd(); + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + + //load 8x4 block of B11 + if(n_remainder == 3) + { + ///GEMM implementation starts/// + + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] + + a01 += cs_a; //move to next row + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] + + a01 += cs_a; //move to next row of A + + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) + + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) + + //broadcast 3rd row of A01 + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] + + a01 += cs_a; //move to next row of A01 + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) + + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) + + //broadcast 4th row of A01 + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] + + a01 += cs_a; //move to next row of A01 + + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) + + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ymm9 = _mm256_loadu_pd((double const *)(b11+cs_b)); //B11[0-3][0] + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][0] + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0-3][1] + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + D_NR)); //B11[4-7][1] + ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][2] + ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][2] + + ymm9 = _mm256_fmsub_pd(ymm9, ymm8, ymm1); //B11[4-7][0] * alpha-= ymm1 + ymm10 = _mm256_fmsub_pd(ymm10, ymm8, ymm2); //B11[0-3][1] * alpha-= ymm2 + ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= ymm3 + + ymm13 = _mm256_fmsub_pd(ymm13, ymm8, ymm5); //B11[4-7][2] * alpha -= ymm5 + ymm14 = _mm256_fmsub_pd(ymm14, ymm8, ymm6); //B11[0-3][3] * alpha -= ymm6 + ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= ymm7 + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + a11 += 2 * cs_a; + + //3rd col + ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] + + a11 += cs_a; + + //4th col + ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] + ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] + + //(Row 3): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9); + + //(Row 3): FMA operations + ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14); + ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13); + + //(ROW 2): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9); + + ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13); + + _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) + } + if(n_remainder == 2) + { + ///GEMM implementation starts/// + + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] + + a01 += cs_a; //move to next row + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] + + a01 += cs_a; //move to next row of A + + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) + + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) + + //broadcast 3rd row of A01 + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] + + a01 += cs_a; //move to next row of A01 + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) + + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) + + //broadcast 4th row of A01 + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] + + a01 += cs_a; //move to next row of A01 + + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) + + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][0] + ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][1] + ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][1] + + ymm10 = _mm256_fmsub_pd(ymm10, ymm8, ymm2); //B11[0-3][1] * alpha-= ymm2 + ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= ymm3 + + ymm14 = _mm256_fmsub_pd(ymm14, ymm8, ymm6); //B11[0-3][3] * alpha -= ymm6 + ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= ymm7 + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + + a11 += 3 * cs_a; + + //4th col + ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] + + //(Row 3): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); + + //(Row 3): FMA operations + ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14); + + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) + } + if(n_remainder == 1) + { + ///GEMM implementation starts/// + + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] + + a01 += cs_a; //move to next row + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] + + a01 += cs_a; //move to next row of A + + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) + + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) + + //broadcast 3rd row of A01 + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] + + a01 += cs_a; //move to next row of A01 + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) + + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) + + //broadcast 4th row of A01 + ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] + + a01 += cs_a; //move to next row of A01 + + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) + + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ymm11 = _mm256_loadu_pd((double const *)(b11+cs_b_offset[1])); //B11[0-3][0] + ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] +D_NR)); //B11[4-7][0] + + ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= ymm3 + + ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= ymm7 + + _mm256_storeu_pd((double *)(b11+ cs_b_offset[1]), ymm11); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) + } + } } if(i<0) i += D_NR; - if((m & 4)) ///implementation for remainder rows(when m_remainder is a multiple of 4) + if((m & 4)) ///implementation for remainder rows(when m_remainder is greater than 4) { - for(j = (n-D_NR); (j+1) > 0; j -=D_NR) //loop along n direction + for(j = (n-D_NR); (j+1) > 0; j -=D_NR) //loop along n direction { - a01 = L + (j+D_NR)*cs_a + (j); //pointer to block of A to be used for GEMM + a01 = L + (j+D_NR)*cs_a + (j); //pointer to block of A to be used for GEMM a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM + b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4) + k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4) - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha - ///GEMM for previous blocks /// + ///GEMM for previous blocks /// - ///load 4x4 block of b11 - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ///load 4x4 block of b11 + ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - //multiply by alpha - ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] *= alpha - ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] *=alpha - ymm2 = _mm256_mul_pd(ymm2, ymm15); //B11[x][2] *= alpha - ymm3 = _mm256_mul_pd(ymm3, ymm15); //B11[x][3] *= alpha - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - ptr_a01_dup = a01; + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_a01_dup = a01; - //load 4x4 bblock of b10 - ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] + //load 4x4 bblock of b10 + ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] - //broadcast 1st row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] + //broadcast 1st row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] - a01 += cs_a; //move to next row of A + a01 += cs_a; //move to next row of A - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - //broadcast 2nd row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] + //broadcast 2nd row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] - a01 += cs_a; //move to next row of A + a01 += cs_a; //move to next row of A - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) - //braodcast 3rd row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] + //braodcast 3rd row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] - a01 += cs_a; //move to next row of A + a01 += cs_a; //move to next row of A - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) - //broadcast 4th row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] + //broadcast 4th row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] - a01 += cs_a; //move to next row of A + a01 += cs_a; //move to next row of A - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + D_NR*cs_a; //pointer math to find next block of A for GEMM - } + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR*cs_a; //pointer math to find next block of A for GEMM + } ///GEMM code end/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha - ymm0 = _mm256_sub_pd(ymm0, ymm4); //B11[x][0] -=ymm4 - ymm1 = _mm256_sub_pd(ymm1, ymm5); //B11[x][1] -= ymm5 - ymm2 = _mm256_sub_pd(ymm2, ymm6); //B11[x][2] -= ymm6 - ymm3 = _mm256_sub_pd(ymm3, ymm7); //B11[x][3] -= ymm7 + ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -=ymm4 + ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5 + ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6 + ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 - ///implement TRSM/// + ///implement TRSM/// - ///read 4x4 block of A11/// + ///read 4x4 block of A11/// + a11 += cs_a; + //2nd col + ymm5 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - //1st col - ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + a11 += cs_a; - a11 += cs_a; + //3rd col + ymm6 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] + ymm9 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - //2nd col - ymm5 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - - a11 += cs_a; - - //3rd col - ymm6 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm9 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] - - a11 += cs_a; - - //4th col - ymm7 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm10 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] - ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1] + a11 += cs_a; + //4th col + ymm7 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] + ymm10 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] + ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] //(Row 3): FMA operations ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); @@ -13392,199 +15136,270 @@ static err_t bli_dtrsm_small_XAutB_unitDiag( //(Row 1):FMA operations ymm0 = _mm256_fnmadd_pd(ymm1, ymm5, ymm0); - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0]) - _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2])) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3]) + _mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0]) + _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2])) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3]) } - if(n_remainder) //implementation for remainder columns(when n is not a multiple of D_NR) - { + if(n_remainder) //implementation for remainder columns(when n is not a multiple of D_NR) + { - a01 = L + (j+D_NR)*cs_a + (j); //pointer to block of A to be used for GEMM + a01 = L + (j+D_NR)*cs_a + (j); //pointer to block of A to be used for GEMM a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM + b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4) + k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4) - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha - ///GEMM for previous blocks /// + ///GEMM for previous blocks /// - ///load 4x4 block of b11 - if(n_remainder == 3) - { - ymm0 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm1 = _mm256_loadu_pd((double const *)b11+ cs_b); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - } - if(n_remainder == 2) - { - ymm0 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm1 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - } - if(n_remainder == 1) - { - ymm0 = _mm256_broadcast_sd((double const *)&ones); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm1 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm2 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - } + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); - //multiply by alpha - ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] *= alpha - ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] *=alpha - ymm2 = _mm256_mul_pd(ymm2, ymm15); //B11[x][2] *= alpha - ymm3 = _mm256_mul_pd(ymm3, ymm15); //B11[x][3] *= alpha + ///load 4x4 block of b11 + if(n_remainder == 3) + { + ymm1 = _mm256_loadu_pd((double const *)b11+ cs_b); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); + ///GEMM implementation starts/// - ///GEMM implementation starts/// + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_a01_dup = a01; - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - ptr_a01_dup = a01; + //load 4x4 bblock of b10 + ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] - //load 4x4 bblock of b10 - ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] + //broadcast 1st row of A01 + ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] - //broadcast 1st row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] + a01 += cs_a; //move to next row of A - a01 += cs_a; //move to next row of A + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + //broadcast 2nd row of A01 + ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] - //broadcast 2nd row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] + a01 += cs_a; //move to next row of A - a01 += cs_a; //move to next row of A + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) + //braodcast 3rd row of A01 + ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] - //braodcast 3rd row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] + a01 += cs_a; //move to next row of A - a01 += cs_a; //move to next row of A + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) + //broadcast 4th row of A01 + ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] - //broadcast 4th row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] + a01 += cs_a; //move to next row of A - a01 += cs_a; //move to next row of A + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + } + + ///GEMM code end/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha + + ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5 + ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6 + ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + a11 += 2 * cs_a; + + //3rd col + ymm9 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] + + a11 += cs_a; + + //4th col + ymm10 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] + ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] + + //(Row 3): FMA operations + ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); + ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1); + + //(ROW 2): FMA operations + ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1); + + _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2])) + _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[x][0]) + } + if(n_remainder == 2) + { + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + ///GEMM implementation starts/// + + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //load 4x4 bblock of b10 + ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] + + //broadcast 1st row of A01 + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] + + a01 += cs_a; //move to next row of A + + ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] + + a01 += cs_a; //move to next row of A + + ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) + + //braodcast 3rd row of A01 + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] + + a01 += cs_a; //move to next row of A + + ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) + + //broadcast 4th row of A01 + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] + + a01 += cs_a; //move to next row of A + + ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + } + + ///GEMM code end/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha + + ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6 + ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 + + ///implement TRSM/// + + ///read 4x4 block of A11/// - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM - } + a11 += 3 * cs_a; - ///GEMM code end/// + //4th col + ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] - ymm0 = _mm256_sub_pd(ymm0, ymm4); //B11[x][0] -=ymm4 - ymm1 = _mm256_sub_pd(ymm1, ymm5); //B11[x][1] -= ymm5 - ymm2 = _mm256_sub_pd(ymm2, ymm6); //B11[x][2] -= ymm6 - ymm3 = _mm256_sub_pd(ymm3, ymm7); //B11[x][3] -= ymm7 + //(Row 3): FMA operations + ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); - ///implement TRSM/// + _mm256_storeu_pd((double *)(b11+ cs_b * 2), ymm2); //store(B11[x][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][1]) + } + if(n_remainder == 1) + { + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ///read 4x4 block of A11/// + ///GEMM implementation starts/// + for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + { + ptr_a01_dup = a01; - //1st col - ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + //load 4x4 bblock of b10 + ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] - a11 += cs_a; + //broadcast 1st row of A01 + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] - //2nd col - ymm5 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] + a01 += cs_a; //move to next row of A - a11 += cs_a; + ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - //3rd col - ymm6 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm9 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] + //broadcast 2nd row of A01 + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] - a11 += cs_a; + a01 += cs_a; //move to next row of A - //4th col - ymm7 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm10 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] - ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1] + ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) - //(Row 3): FMA operations - ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); - ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1); - ymm0 = _mm256_fnmadd_pd(ymm3, ymm7, ymm0); + //braodcast 3rd row of A01 + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] - //(ROW 2): FMA operations - ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1); - ymm0 = _mm256_fnmadd_pd(ymm2, ymm6, ymm0); + a01 += cs_a; //move to next row of A - //(Row 1):FMA operations - ymm0 = _mm256_fnmadd_pd(ymm1, ymm5, ymm0); + ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) - if(n_remainder == 3) - { - _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2])) - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[x][0]) - } - if(n_remainder == 2) - { - _mm256_storeu_pd((double *)(b11+ cs_b * 2), ymm2); //store(B11[x][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][1]) - } - if(n_remainder == 1) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][0]) - } + //broadcast 4th row of A01 + ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] + a01 += cs_a; //move to next row of A + + ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + } + + ///GEMM code end/// + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha + + ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 + + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][0]) + } } - m_remainder -= 4; - i -= 4; + m_remainder -= 4; + i -= 4; } - if(m_remainder) + if(m_remainder) ///implementation for remainder rows { - dtrsm_small_XAutB_unitDiag(a->buffer, b->buffer,AlphaVal, m_remainder, n, cs_a, cs_b); + dtrsm_small_XAutB_unitDiag(L, B, AlphaVal, m_remainder, n, cs_a, cs_b); } return BLIS_SUCCESS; } @@ -16451,7 +18266,7 @@ static void blis_dtrsm_microkernel_alpha(double *ptr_l, mat_b_col[0] = _mm256_loadu_pd((double const *)ptr_b); mat_b_col[1] = _mm256_loadu_pd((double const *)(ptr_b + (cs_b))); mat_b_col[2] = _mm256_loadu_pd((double const *)(ptr_b + cs_b_offset[0])); - mat_b_col[3] = _mm256_broadcast_sd((double const *)&ones); + mat_b_col[3] = _mm256_broadcast_sd((double const *)&ones); } if(n_remainder == 2) { @@ -16459,7 +18274,7 @@ static void blis_dtrsm_microkernel_alpha(double *ptr_l, mat_b_col[0] = _mm256_loadu_pd((double const *)ptr_b); mat_b_col[1] = _mm256_loadu_pd((double const *)(ptr_b + (cs_b))); mat_b_col[2] = _mm256_broadcast_sd((double const *)&ones); - mat_b_col[3] = _mm256_broadcast_sd((double const *)&ones); + mat_b_col[3] = _mm256_broadcast_sd((double const *)&ones); } if(n_remainder == 1) { @@ -16467,7 +18282,7 @@ static void blis_dtrsm_microkernel_alpha(double *ptr_l, mat_b_col[0] = _mm256_loadu_pd((double const *)ptr_b); mat_b_col[1] = _mm256_broadcast_sd((double const *)&ones); mat_b_col[2] = _mm256_broadcast_sd((double const *)&ones); - mat_b_col[3] = _mm256_broadcast_sd((double const *)&ones); + mat_b_col[3] = _mm256_broadcast_sd((double const *)&ones); } /*Shuffle to rearrange/transpose 8x4 block of B into contiguous row-wise registers*/ ////unpacklow////